查看原文
其他

mlr3:模型评价

阿越就是我 医学和生信笔记 2023-02-25


前面一篇介绍了如何使用mlr3创建任务和学习器、拟合模型、预测和简单的评价,本篇将模型评价的一些细节问题,展示mlr3如何使得这些步骤变得更加简单!

二分类变量和ROC曲线

对于二分类变量,结果有阴性和阳性两种,而且判定阴性和阳性的阈值是可以认为设定的。ROC曲线可以很好的帮助我们确定最佳的分割点。

首先看一下如何获取一个分类变量的混淆矩阵:

library(mlr3verse)
## 载入需要的程辑包:mlr3
data("Sonar", package = "mlbench")
task <- as_task_classif(Sonar, target = "Class", positive = "M"# 指定阳性

learner <- lrn("classif.rpart", predict_type = "prob"# 指定预测类型
prediction <- learner$train(task)$predict(task)
conf <- prediction$confusion
print(conf)
##         truth
## response  M  R
##        M 95 10
##        R 16 87

绘制ROC曲线也是非常方便:

autoplot(prediction, type = "roc")
plot of chunk unnamed-chunk-2

也可以非常方便的绘制PRC曲线:

autoplot(prediction, type = "prc")
plot of chunk unnamed-chunk-3

重抽样

mlr3支持的重抽样方法:

  • cross validation ("cv"),
  • leave-one-out cross validation ("loo"),
  • repeated cross validation ("repeated_cv"),
  • otstrapping ("bootstrap"),
  • subsampling ("subsampling"),
  • holdout ("holdout"),
  • in-sample resampling ("insample"),
  • custom resampling ("custom").

查看重抽样的方法:

library(mlr3verse)
as.data.table(mlr_resamplings)
##            key        params iters
## 1:   bootstrap ratio,repeats    30
## 2:      custom                  NA
## 3:   custom_cv                  NA
## 4:          cv         folds    10
## 5:     holdout         ratio     1
## 6:    insample                   1
## 7:         loo                  NA
## 8: repeated_cv folds,repeats   100
## 9: subsampling ratio,repeats    30

还有一些特殊类型的重抽样方法可以通过扩展包实现,比如mlr3spatiotemporal包。

默认的方法是holdout

resampling <- rsmp("holdout")
print(resampling)
## <ResamplingHoldout> with 1 iterations
## * Instantiated: FALSE
## * Parameters: ratio=0.6667

可以通过以下方法改变比例:

resampling$param_set$values <- list(ratio = 0.8)

# 或者
rsmp("holdout", ratio = 0.8)
## <ResamplingHoldout> with 1 iterations
## * Instantiated: FALSE
## * Parameters: ratio=0.8

下面一个例子使用5折交叉验证方法,建立一个决策树模型:

library(mlr3verse)
task <- tsk("penguins"# 创建任务
learner <- lrn("classif.rpart", predict_type = "prob"# 创建学习器,设定预测的结果是概率
resampling <- rsmp("cv", folds = 5# 选择重抽样方法

rr <- resample(task, learner, resampling, store_models = T# 1行代码搞定
## INFO  [20:47:12.966] [mlr3] Applying learner 'classif.rpart' on task 'penguins' (iter 5/5) 
## INFO  [20:47:12.996] [mlr3] Applying learner 'classif.rpart' on task 'penguins' (iter 1/5) 
## INFO  [20:47:13.010] [mlr3] Applying learner 'classif.rpart' on task 'penguins' (iter 2/5) 
## INFO  [20:47:13.019] [mlr3] Applying learner 'classif.rpart' on task 'penguins' (iter 4/5) 
## INFO  [20:47:13.029] [mlr3] Applying learner 'classif.rpart' on task 'penguins' (iter 3/5)
print(rr)
## <ResampleResult> of 5 iterations
## * Task: penguins
## * Learner: classif.rpart
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations

获得平均的模型表现

rr$aggregate(msr("classif.acc"))
## classif.acc 
##   0.9448423

获得单个模型的表现

rr$score(msr("classif.acc"))[,7:9]
##    iteration              prediction classif.acc
## 1:         1 <PredictionClassif[20]>   0.9710145
## 2:         2 <PredictionClassif[20]>   0.8985507
## 3:         3 <PredictionClassif[20]>   0.9130435
## 4:         4 <PredictionClassif[20]>   0.9710145
## 5:         5 <PredictionClassif[20]>   0.9705882

检查警告或者错误:

rr$warnings
## Empty data.table (0 rows and 2 cols): iteration,msg
rr$errors
## Empty data.table (0 rows and 2 cols): iteration,msg

取出单个模型

rr$learners[[5]]$model
## n= 276 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
## 1) root 276 158 Adelie (0.427536232 0.206521739 0.365942029)  
##   2) flipper_length< 206.5 170  54 Adelie (0.682352941 0.311764706 0.005882353)  
##     4) bill_length< 43.35 117   4 Adelie (0.965811966 0.034188034 0.000000000) *
##     5) bill_length>=43.35 53   4 Chinstrap (0.056603774 0.924528302 0.018867925) *
##   3) flipper_length>=206.5 106   6 Gentoo (0.018867925 0.037735849 0.943396226)  
##     6) bill_depth>=17.2 8   4 Chinstrap (0.250000000 0.500000000 0.250000000) *
##     7) bill_depth< 17.2 98   0 Gentoo (0.000000000 0.000000000 1.000000000) *

这个包也可以和其他决策树可视化R包无缝衔接,比如非常画图非常好看的rpart.plot:

library(rpart.plot)
## 载入需要的程辑包:rpart
rpart.plot(rr$learners[[5]]$model)
plot of chunk unnamed-chunk-12

查看预测结果:

rr$prediction()
## <PredictionClassif> for 344 observations:
##     row_ids     truth  response prob.Adelie prob.Chinstrap prob.Gentoo
##           1    Adelie    Adelie  0.96969697     0.03030303  0.00000000
##           4    Adelie    Adelie  0.96969697     0.03030303  0.00000000
##          26    Adelie    Adelie  0.96969697     0.03030303  0.00000000
## ---                                                                   
##         333 Chinstrap Chinstrap  0.05660377     0.92452830  0.01886792
##         334 Chinstrap Chinstrap  0.05660377     0.92452830  0.01886792
##         335 Chinstrap Chinstrap  0.05660377     0.92452830  0.01886792
# 查看单个预测结果
rr$predictions()[[1]]
## <PredictionClassif> for 69 observations:
##     row_ids     truth  response prob.Adelie prob.Chinstrap prob.Gentoo
##           1    Adelie    Adelie  0.96969697     0.03030303  0.00000000
##           4    Adelie    Adelie  0.96969697     0.03030303  0.00000000
##          26    Adelie    Adelie  0.96969697     0.03030303  0.00000000
## ---                                                                   
##         338 Chinstrap Chinstrap  0.08888889     0.88888889  0.02222222
##         342 Chinstrap Chinstrap  0.08888889     0.88888889  0.02222222
##         344 Chinstrap Chinstrap  0.08888889     0.88888889  0.02222222

提取特定iteration的结果

rr$filter(c(3,5))
print(rr)
## <ResampleResult> of 2 iterations
## * Task: penguins
## * Learner: classif.rpart
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations

可视化结果:

task <- tsk("pima"# 非常著名的糖尿病数据集
task$select(c("glucose","mass"))
learner <- lrn("classif.rpart", predict_type = "prob")
resampling <- rsmp("cv")
rr <- resample(task, learner, resampling, store_models = T)
## INFO  [20:47:13.436] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 5/10) 
## INFO  [20:47:13.449] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 6/10) 
## INFO  [20:47:13.461] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 9/10) 
## INFO  [20:47:13.473] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 8/10) 
## INFO  [20:47:13.488] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/10) 
## INFO  [20:47:13.501] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/10) 
## INFO  [20:47:13.513] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 10/10) 
## INFO  [20:47:13.524] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/10) 
## INFO  [20:47:13.536] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 7/10) 
## INFO  [20:47:13.548] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/10)

autoplot(rr, measure = msr("classif.auc"))
plot of chunk unnamed-chunk-16

ROC曲线:10折交叉验证平均后的:

autoplot(rr, type = "roc")
plot of chunk unnamed-chunk-17

树状图:

autoplot(rr, type = "prediction")
plot of chunk unnamed-chunk-18

可视化单个模型:

rr1 <- rr$filter(1)

autoplot(rr1, type = "prediction")
plot of chunk unnamed-chunk-19

所有支持的可视化类型可在此处找到:autoplot.ResampleResult

内容太多了,明天学习多个模型的比较!



以上就是今天的内容,希望对你有帮助哦!欢迎点赞、在看、关注、转发

欢迎在评论区留言或直接添加我的微信!




欢迎关注我的公众号:医学和生信笔记

医学和生信笔记 公众号主要分享:1.医学小知识、肛肠科小知识;2.R语言和Python相关的数据分析、可视化、机器学习等;3.生物信息学学习资料和自己的学习笔记!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存