mlr3 in practice: predicting house prices with decision trees and xgboost
The previous ten posts covered the basics and more advanced usage of the mlr3 package in detail. Today we walk through a simple end-to-end example: predicting house prices in the King County area. We will use mlr3 and its ecosystem for data preprocessing, modeling, resampling, and hyperparameter tuning, with a decision tree and xgboost as the learners.
Loading the data and R packages
library(mlr3verse)
## Loading required package: mlr3
set.seed(123) # set the seed so the results are reproducible
lgr::get_logger("mlr3")$set_threshold("warn") # reduce console logging
lgr::get_logger("bbotk")$set_threshold("warn")
data("kc_housing", package = "mlr3data") # load the data
Data exploration
str(kc_housing)
## 'data.frame': 21613 obs. of 20 variables:
## $ date : POSIXct, format: "2014-10-13" "2014-12-09" ...
## $ price : num 221900 538000 180000 604000 510000 ...
## $ bedrooms : int 3 3 2 4 3 4 3 3 3 3 ...
## $ bathrooms : num 1 2.25 1 3 2 4.5 2.25 1.5 1 2.5 ...
## $ sqft_living : int 1180 2570 770 1960 1680 5420 1715 1060 1780 1890 ...
## $ sqft_lot : int 5650 7242 10000 5000 8080 101930 6819 9711 7470 6560 ...
## $ floors : num 1 2 1 1 1 1 2 1 1 2 ...
## $ waterfront : logi FALSE FALSE FALSE FALSE FALSE FALSE ...
## $ view : int 0 0 0 0 0 0 0 0 0 0 ...
## $ condition : int 3 3 3 5 3 3 3 3 3 3 ...
## $ grade : int 7 7 6 7 8 11 7 7 7 7 ...
## $ sqft_above : int 1180 2170 770 1050 1680 3890 1715 1060 1050 1890 ...
## $ sqft_basement: int NA 400 NA 910 NA 1530 NA NA 730 NA ...
## $ yr_built : int 1955 1951 1933 1965 1987 2001 1995 1963 1960 2003 ...
## $ yr_renovated : int NA 1991 NA NA NA NA NA NA NA NA ...
## $ zipcode : int 98178 98125 98028 98136 98074 98053 98003 98198 98146 98038 ...
## $ lat : num 47.5 47.7 47.7 47.5 47.6 ...
## $ long : num -122 -122 -122 -122 -122 ...
## $ sqft_living15: int 1340 1690 2720 1360 1800 4760 2238 1650 1780 2390 ...
## $ sqft_lot15 : int 5650 7639 8062 5000 7503 101930 6819 9711 8113 7570 ...
## - attr(*, "index")= int(0)
dim(kc_housing) # 21613 rows, 20 columns
## [1] 21613 20
summary(kc_housing)
## date price bedrooms
## Min. :2014-05-02 00:00:00 Min. : 75000 Min. : 0.000
## 1st Qu.:2014-07-22 00:00:00 1st Qu.: 321950 1st Qu.: 3.000
## Median :2014-10-16 00:00:00 Median : 450000 Median : 3.000
## Mean :2014-10-29 03:58:09 Mean : 540088 Mean : 3.371
## 3rd Qu.:2015-02-17 00:00:00 3rd Qu.: 645000 3rd Qu.: 4.000
## Max. :2015-05-27 00:00:00 Max. :7700000 Max. :33.000
##
## bathrooms sqft_living sqft_lot floors
## Min. :0.000 Min. : 290 Min. : 520 Min. :1.000
## 1st Qu.:1.750 1st Qu.: 1427 1st Qu.: 5040 1st Qu.:1.000
## Median :2.250 Median : 1910 Median : 7618 Median :1.500
## Mean :2.115 Mean : 2080 Mean : 15107 Mean :1.494
## 3rd Qu.:2.500 3rd Qu.: 2550 3rd Qu.: 10688 3rd Qu.:2.000
## Max. :8.000 Max. :13540 Max. :1651359 Max. :3.500
##
## waterfront view condition grade
## Mode :logical Min. :0.0000 Min. :1.000 Min. : 1.000
## FALSE:21450 1st Qu.:0.0000 1st Qu.:3.000 1st Qu.: 7.000
## TRUE :163 Median :0.0000 Median :3.000 Median : 7.000
## Mean :0.2343 Mean :3.409 Mean : 7.657
## 3rd Qu.:0.0000 3rd Qu.:4.000 3rd Qu.: 8.000
## Max. :4.0000 Max. :5.000 Max. :13.000
##
## sqft_above sqft_basement yr_built yr_renovated zipcode
## Min. : 290 Min. : 10.0 Min. :1900 Min. :1934 Min. :98001
## 1st Qu.:1190 1st Qu.: 450.0 1st Qu.:1951 1st Qu.:1987 1st Qu.:98033
## Median :1560 Median : 700.0 Median :1975 Median :2000 Median :98065
## Mean :1788 Mean : 742.4 Mean :1971 Mean :1996 Mean :98078
## 3rd Qu.:2210 3rd Qu.: 980.0 3rd Qu.:1997 3rd Qu.:2007 3rd Qu.:98118
## Max. :9410 Max. :4820.0 Max. :2015 Max. :2015 Max. :98199
## NA's :13126 NA's :20699
## lat long sqft_living15 sqft_lot15
## Min. :47.16 Min. :-122.5 Min. : 399 Min. : 651
## 1st Qu.:47.47 1st Qu.:-122.3 1st Qu.:1490 1st Qu.: 5100
## Median :47.57 Median :-122.2 Median :1840 Median : 7620
## Mean :47.56 Mean :-122.2 Mean :1987 Mean : 12768
## 3rd Qu.:47.68 3rd Qu.:-122.1 3rd Qu.:2360 3rd Qu.: 10083
## Max. :47.78 Max. :-121.3 Max. :6210 Max. :871200
##
Data preprocessing
price is the target variable; all the remaining columns are features. The preprocessing steps are:
Convert the date column from a date type to numeric: days elapsed since the earliest date in the data.
Treat zipcode as a regional code (the code below leaves it as an integer, as the task output later confirms).
Add a new column renovated recording whether the house has ever been renovated.
Add a new column has_basement recording whether the house has a basement.
Rescale price to units of 1,000 dollars.
Finally, drop the two NA-bearing columns (yr_renovated and sqft_basement) once the indicator columns are created, so no missing values remain in the data.
library(anytime)
dates <- anytime(kc_housing$date)
kc_housing$date <- as.numeric(difftime(dates, min(dates), units = "days"))
kc_housing$renovated <- as.numeric(!is.na(kc_housing$yr_renovated))
kc_housing$has_basement <- as.numeric(!is.na(kc_housing$sqft_basement))
kc_housing$yr_renovated <- NULL
kc_housing$sqft_basement <- NULL
kc_housing$price <- kc_housing$price / 1000
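As a quick sanity check (my addition, not in the original post), we can confirm the preprocessing did what was intended:

```r
# Optional sanity checks on the preprocessed data
anyNA(kc_housing)           # should be FALSE: the two NA-bearing columns were dropped
range(kc_housing$date)      # days since the earliest sale date, starting at 0
table(kc_housing$renovated) # 0/1 indicator for renovation
summary(kc_housing$price)   # price is now in thousands of dollars
```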
A quick plot of the price distribution:
library(ggplot2)
ggplot(kc_housing, aes(x = price)) + geom_density() + theme_minimal()
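The density above is heavily right-skewed, so a log-scaled axis (a small addition of mine) makes the shape easier to read:

```r
# same density plot on a log10 price axis;
# the distribution looks roughly log-normal
ggplot(kc_housing, aes(x = price)) +
  geom_density() +
  scale_x_log10() +
  theme_minimal()
```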
Creating the task
task <- as_task_regr(kc_housing, target = "price")
task
## <TaskRegr:kc_housing> (21613 x 20)
## * Target: price
## * Properties: -
## * Features (19):
## - int (11): bedrooms, condition, grade, sqft_above, sqft_living,
## sqft_living15, sqft_lot, sqft_lot15, view, yr_built, zipcode
## - dbl (7): bathrooms, date, floors, has_basement, lat, long,
## renovated
## - lgl (1): waterfront
autoplot(task) + facet_wrap(~ condition)
# relationships between selected features
autoplot(task$clone()$select(task$feature_names[c(3, 17)]), type = "pairs")
## Registered S3 method overwritten by 'GGally':
## method from
## +.gg ggplot2
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
Splitting the data and starting to model
split <- partition(task, ratio = 0.7)
train_idx <- split$train
test_idx <- split$test
task_train <- task$clone()$filter(train_idx)
task_test <- task$clone()$filter(test_idx)
Decision tree
# leave out the zipcode column for now
task_nozip <- task_train$clone()$select(setdiff(task$feature_names, "zipcode"))
# fit the model (task_nozip already contains only the training rows,
# so no row_ids argument is needed)
lrn <- lrn("regr.rpart")
lrn$train(task_nozip)
# visualize the decision tree
library(rpart.plot)
## Loading required package: rpart
rpart.plot(lrn$model)
The tree splits on grade, sqft_living, lat, and similar features. Next, let's draw a map to see how latitude and longitude relate to price.
library(ggmap)
## Google's Terms of Service: https://cloud.google.com/maps-platform/terms/.
## Please cite ggmap if you use it! See citation("ggmap") for details.
qmplot(long, lat, maptype = "watercolor", color = log(price),
data = kc_housing[train_idx[1:3000],]) +
scale_colour_viridis_c()
Clearly, houses near the water are more expensive! Latitude and longitude also have some effect on price.
Next, let's look at price across different zip code areas.
qmplot(long, lat, maptype = "watercolor", color = zipcode,
data = kc_housing[train_idx[1:3000],]) + guides(color = "none")
Different zip code areas do seem to matter for price.
Let's now model on the data with zipcode included, using 3-fold cross-validation for a more stable performance estimate:
lrn_rpart <- lrn("regr.rpart")
cv3 <- rsmp("cv", folds = 3)
res <- resample(task_train, lrn_rpart, cv3, store_models = TRUE)
res$aggregate(msr("regr.rmse"))
## regr.rmse
## 221.0799
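Besides the aggregate, the per-fold scores are also available (a small addition; `$score()` is part of the standard mlr3 resampling API):

```r
# RMSE of each of the 3 folds; the aggregate above is their average
res$score(msr("regr.rmse"))[, c("iteration", "regr.rmse")]
```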
xgboost
lrn_xgboost <- lrn("regr.xgboost")
lrn_xgboost$param_set # inspect the tunable hyperparameters
## <ParamSet>
## id class lower upper nlevels default
## 1: alpha ParamDbl 0 Inf Inf 0
## 2: approxcontrib ParamLgl NA NA 2 FALSE
## 3: base_score ParamDbl -Inf Inf Inf 0.5
## 4: booster ParamFct NA NA 3 gbtree
## 5: callbacks ParamUty NA NA Inf <list[0]>
## 6: colsample_bylevel ParamDbl 0 1 Inf 1
## 7: colsample_bynode ParamDbl 0 1 Inf 1
## 8: colsample_bytree ParamDbl 0 1 Inf 1
## 9: disable_default_eval_metric ParamLgl NA NA 2 FALSE
## 10: early_stopping_rounds ParamInt 1 Inf Inf
## 11: eta ParamDbl 0 1 Inf 0.3
## 12: eval_metric ParamUty NA NA Inf rmse
## 13: feature_selector ParamFct NA NA 5 cyclic
## 14: feval ParamUty NA NA Inf
## 15: gamma ParamDbl 0 Inf Inf 0
## 16: grow_policy ParamFct NA NA 2 depthwise
## 17: interaction_constraints ParamUty NA NA Inf <NoDefault[3]>
## 18: iterationrange ParamUty NA NA Inf <NoDefault[3]>
## 19: lambda ParamDbl 0 Inf Inf 1
## 20: lambda_bias ParamDbl 0 Inf Inf 0
## 21: max_bin ParamInt 2 Inf Inf 256
## 22: max_delta_step ParamDbl 0 Inf Inf 0
## 23: max_depth ParamInt 0 Inf Inf 6
## 24: max_leaves ParamInt 0 Inf Inf 0
## 25: maximize ParamLgl NA NA 2
## 26: min_child_weight ParamDbl 0 Inf Inf 1
## 27: missing ParamDbl -Inf Inf Inf NA
## 28: monotone_constraints ParamUty NA NA Inf 0
## 29: normalize_type ParamFct NA NA 2 tree
## 30: nrounds ParamInt 1 Inf Inf <NoDefault[3]>
## 31: nthread ParamInt 1 Inf Inf 1
## 32: ntreelimit ParamInt 1 Inf Inf
## 33: num_parallel_tree ParamInt 1 Inf Inf 1
## 34: objective ParamUty NA NA Inf reg:squarederror
## 35: one_drop ParamLgl NA NA 2 FALSE
## 36: outputmargin ParamLgl NA NA 2 FALSE
## 37: predcontrib ParamLgl NA NA 2 FALSE
## 38: predictor ParamFct NA NA 2 cpu_predictor
## 39: predinteraction ParamLgl NA NA 2 FALSE
## 40: predleaf ParamLgl NA NA 2 FALSE
## 41: print_every_n ParamInt 1 Inf Inf 1
## 42: process_type ParamFct NA NA 2 default
## 43: rate_drop ParamDbl 0 1 Inf 0
## 44: refresh_leaf ParamLgl NA NA 2 TRUE
## 45: reshape ParamLgl NA NA 2 FALSE
## 46: sample_type ParamFct NA NA 2 uniform
## 47: sampling_method ParamFct NA NA 2 uniform
## 48: save_name ParamUty NA NA Inf
## 49: save_period ParamInt 0 Inf Inf
## 50: scale_pos_weight ParamDbl -Inf Inf Inf 1
## 51: seed_per_iteration ParamLgl NA NA 2 FALSE
## 52: single_precision_histogram ParamLgl NA NA 2 FALSE
## 53: sketch_eps ParamDbl 0 1 Inf 0.03
## 54: skip_drop ParamDbl 0 1 Inf 0
## 55: strict_shape ParamLgl NA NA 2 FALSE
## 56: subsample ParamDbl 0 1 Inf 1
## 57: top_k ParamInt 0 Inf Inf 0
## 58: training ParamLgl NA NA 2 FALSE
## 59: tree_method ParamFct NA NA 5 auto
## 60: tweedie_variance_power ParamDbl 1 2 Inf 1.5
## 61: updater ParamUty NA NA Inf <NoDefault[3]>
## 62: verbose ParamInt 0 2 3 1
## 63: watchlist ParamUty NA NA Inf
## 64: xgb_model ParamUty NA NA Inf
## id class lower upper nlevels default
## parents value
## 1:
## 2:
## 3:
## 4:
## 5:
## 6:
## 7:
## 8:
## 9:
## 10:
## 11:
## 12:
## 13: booster
## 14:
## 15:
## 16: tree_method
## 17:
## 18:
## 19:
## 20: booster
## 21: tree_method
## 22:
## 23:
## 24: grow_policy
## 25:
## 26:
## 27:
## 28:
## 29: booster
## 30: 1
## 31: 1
## 32:
## 33:
## 34:
## 35: booster
## 36:
## 37:
## 38:
## 39:
## 40:
## 41: verbose
## 42:
## 43: booster
## 44:
## 45:
## 46: booster
## 47: booster
## 48:
## 49:
## 50:
## 51:
## 52: tree_method
## 53: tree_method
## 54: booster
## 55:
## 56:
## 57: booster,feature_selector
## 58:
## 59: booster
## 60: objective
## 61:
## 62: 0
## 63:
## 64:
## parents value
search_space <- ps(
  eta = p_dbl(lower = 0.2, upper = 0.4),
  min_child_weight = p_dbl(lower = 1, upper = 20),
  subsample = p_dbl(lower = 0.7, upper = 0.8),
  colsample_bytree = p_dbl(lower = 0.9, upper = 1),
  colsample_bylevel = p_dbl(lower = 0.5, upper = 0.7),
  nrounds = p_int(lower = 1L, upper = 25L)
)
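Before tuning, it can help to preview what random search will draw from this space. A sketch using paradox's `generate_design_random()` (loaded with mlr3verse; sampled values vary with the seed):

```r
# draw a few random configurations from the search space
design <- generate_design_random(search_space, 3)
design$data
```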
at <- auto_tuner(
method = "random_search",
learner = lrn_xgboost,
resampling = rsmp("holdout"),
measure = msr("regr.rmse"),
search_space = search_space,
term_evals = 10,
batch_size = 8
)
res <- resample(task_nozip, at, cv3, store_models = TRUE)
res$aggregate()
## regr.mse
## 19122.95
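Because we set `store_models = TRUE`, the hyperparameters selected in each outer fold can be inspected with mlr3tuning's `extract_inner_tuning_results()`:

```r
# best configuration found by the inner random search in each outer fold
extract_inner_tuning_results(res)
```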
Note that res$aggregate() defaults to MSE here (19122.95), which corresponds to an RMSE of about 138, much better than the decision tree's 221!
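One loose end: the 70/30 split at the top reserved `task_test`, which has not been used yet. A sketch of a final evaluation step (my addition, not in the original post): fit the AutoTuner on the full training task, then score once on the held-out test task.

```r
# tune and refit on the whole training task,
# then evaluate a single time on the untouched test set
at$train(task_train)
pred <- at$predict(task_test)
pred$score(msr("regr.rmse"))
```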
That's all for today. I hope you found it helpful! Follow the WeChat account 医学和生信笔记 for more notes on medicine, R and Python data analysis, and bioinformatics.