使用workflow一次完成多个模型的评价和比较
💡专注R语言在🩺生物医学中的使用
前面给大家介绍了使用tidymodels
搞定二分类资料的模型评价和比较。
简介的语法、统一的格式、优雅的操作,让人欲罢不能!
但是太费事儿了,同样的流程来了4遍,那要是选择10个模型,就得来10遍!无聊,非常的无聊。
所以个大家介绍简便方法,不用重复写代码,一次搞定多个模型!
本期目录:
加载数据和R包
数据预处理
选择模型
选择重抽样方法
构建workflow
运行模型
查看结果
可视化结果
选择最好的模型用于测试集
加载数据和R包
首先还是加载数据和R包,和前面的一模一样的操作,数据也没变。
suppressPackageStartupMessages(library(tidyverse))
suppressPackageStartupMessages(library(tidymodels))
library(kknn)
tidymodels_prefer()
all_plays <- read_rds("../000files/all_plays.rds")
set.seed(20220520)
split_pbp <- initial_split(all_plays, 0.75, strata = play_type)
train_data <- training(split_pbp)
test_data <- testing(split_pbp)
数据预处理
pbp_rec <- recipe(play_type ~ ., data = train_data) %>%
step_rm(half_seconds_remaining,yards_gained, game_id) %>%
step_string2factor(posteam, defteam) %>%
step_corr(all_numeric(), threshold = 0.7) %>%
step_center(all_numeric()) %>%
step_zv(all_predictors())
选择模型
直接选择4个模型,你想选几个都是可以的。
lm_mod <- logistic_reg(mode = "classification",engine = "glm")
knn_mod <- nearest_neighbor(mode = "classification", engine = "kknn")
rf_mod <- rand_forest(mode = "classification", engine = "ranger")
tree_mod <- decision_tree(mode = "classification",engine = "rpart")
选择重抽样方法
set.seed(20220520)
folds <- vfold_cv(train_data, v = 10)
folds
## # 10-fold cross-validation
## # A tibble: 10 × 2
## splits id
## <list> <chr>
## 1 <split [62082/6899]> Fold01
## 2 <split [62083/6898]> Fold02
## 3 <split [62083/6898]> Fold03
## 4 <split [62083/6898]> Fold04
## 5 <split [62083/6898]> Fold05
## 6 <split [62083/6898]> Fold06
## 7 <split [62083/6898]> Fold07
## 8 <split [62083/6898]> Fold08
## 9 <split [62083/6898]> Fold09
## 10 <split [62083/6898]> Fold10
构建workflow
这一步就是不用重复写代码的关键,把所有模型和数据预处理步骤自动连接起来。
library(workflowsets)
four_mods <- workflow_set(list(rec = pbp_rec),
list(lm = lm_mod,
knn = knn_mod,
rf = rf_mod,
tree = tree_mod
),
cross = T
)
four_mods
## # A workflow set/tibble: 4 × 4
## wflow_id info option result
## <chr> <list> <list> <list>
## 1 rec_lm <tibble [1 × 4]> <opts[0]> <list [0]>
## 2 rec_knn <tibble [1 × 4]> <opts[0]> <list [0]>
## 3 rec_rf <tibble [1 × 4]> <opts[0]> <list [0]>
## 4 rec_tree <tibble [1 × 4]> <opts[0]> <list [0]>
运行模型
首先是一些运行过程中的参数设置:
keep_pred <- control_resamples(save_pred = T, verbose = T)
然后就是运行4个模型(目前一直是在训练集中),我们给它加速一下:
library(doParallel)
## Loading required package: foreach
##
## Attaching package: 'foreach'
## The following objects are masked from 'package:purrr':
##
## accumulate, when
## Loading required package: iterators
## Loading required package: parallel
cl <- makePSOCKcluster(12) # 加速,用12个线程
registerDoParallel(cl)
four_fits <- four_mods %>%
workflow_map("fit_resamples",
seed = 0520,
verbose = T,
resamples = folds,
control = keep_pred
)
## i 1 of 4 resampling: rec_lm
## ✔ 1 of 4 resampling: rec_lm (18.4s)
## i 2 of 4 resampling: rec_knn
## ✔ 2 of 4 resampling: rec_knn (3m 51.9s)
## i 3 of 4 resampling: rec_rf
## ✔ 3 of 4 resampling: rec_rf (1m 15.6s)
## i 4 of 4 resampling: rec_tree
## ✔ 4 of 4 resampling: rec_tree (6.1s)
four_fits
## # A workflow set/tibble: 4 × 4
## wflow_id info option result
## <chr> <list> <list> <list>
## 1 rec_lm <tibble [1 × 4]> <opts[2]> <rsmp[+]>
## 2 rec_knn <tibble [1 × 4]> <opts[2]> <rsmp[+]>
## 3 rec_rf <tibble [1 × 4]> <opts[2]> <rsmp[+]>
## 4 rec_tree <tibble [1 × 4]> <opts[2]> <rsmp[+]>
stopCluster(cl)
需要很长时间!大家笔记本如果内存不够可能会失败哦~
查看结果
查看模型在训练集中的表现:
collect_metrics(four_fits)
## # A tibble: 8 × 9
## wflow_id .config preproc model .metric .estimator mean n std_err
## <chr> <chr> <chr> <chr> <chr> <chr> <dbl> <int> <dbl>
## 1 rec_lm Preprocessor1_M… recipe logi… accura… binary 0.724 10 1.91e-3
## 2 rec_lm Preprocessor1_M… recipe logi… roc_auc binary 0.781 10 1.88e-3
## 3 rec_knn Preprocessor1_M… recipe near… accura… binary 0.671 10 7.31e-4
## 4 rec_knn Preprocessor1_M… recipe near… roc_auc binary 0.716 10 1.28e-3
## 5 rec_rf Preprocessor1_M… recipe rand… accura… binary 0.732 10 1.48e-3
## 6 rec_rf Preprocessor1_M… recipe rand… roc_auc binary 0.799 10 1.90e-3
## 7 rec_tree Preprocessor1_M… recipe deci… accura… binary 0.720 10 1.97e-3
## 8 rec_tree Preprocessor1_M… recipe deci… roc_auc binary 0.704 10 2.01e-3
查看每一个预测结果,这个就不运行了,毕竟好几万行,太多了。。。
collect_predictions(four_fits)
可视化结果
直接可视化4个模型的结果,感觉比ROC曲线更好看,还给出了可信区间。
这个图可以自己用ggplot2
语法修改。
four_fits %>% autoplot(metric = "roc_auc")+theme_bw()
选择最好的模型用于测试集
选择表现最好的应用于测试集:
rand_res <- last_fit(rf_mod,pbp_rec,split_pbp)
查看在测试集的模型表现:
collect_metrics(rand_res) # test 中的模型表现
使用其他指标查看模型表现:
metricsets <- metric_set(accuracy, mcc, f_meas, j_index)
collect_predictions(rand_res) %>%
metricsets(truth = play_type, estimate = .pred_class)
可视化结果,喜闻乐见的混淆矩阵:
collect_predictions(rand_res) %>%
conf_mat(play_type,.pred_class) %>%
autoplot()
喜闻乐见的ROC曲线:
collect_predictions(rand_res) %>%
roc_curve(play_type,.pred_pass) %>%
autoplot()
还有非常多曲线和评价指标可选,大家可以看我之前的介绍推文~
是不是很神奇呢,完美符合一次挑选多个模型的要求,且步骤清稀,代码美观,非常适合进行多个模型的比较。
获取更多信息,欢迎加入🐧QQ交流群:613637742
“医学和生信笔记,专注R语言在临床医学中的使用、R语言数据分析和可视化。主要分享R语言做医学统计学、meta分析、网络药理学、临床预测模型、机器学习、生物信息学等。
往期回顾
ggplot2修改坐标轴详细介绍
ggplot2版本的热图-方便拼图!
相关矩阵的ggplot2版本,方便拼图
ggplot2版本的韦恩图画法
超详细教程:修改ggplot2图例
你没见过的ggplot2另类画图!