第四十二讲 R-回归预测模型的交叉验证
在“R与生物统计专题”中,我们会从介绍R的基本知识展开到生物统计原理及其在R中的实现。以从浅入深,层层递进的形式在投必得医学公众号更新。
在第四十一讲中,我们讲到了判断回归模型性能的指标(第四十一讲 R-判断回归模型性能的指标),但是,我们的例子都是展现在训练数据集(建立模型的数据集)中的拟合情况,也就是说,我们通过训练数据集建立了预测模型,然后在训练数据集中检测模型的拟合性能情况。那么,这个建立的预测模型在独立的另一个数据集,即测试集中的表现如何呢?在实际科研中,我们并不总是能获得一个或多个完全独立的样本作为“训练集”对模型进行验证。于是,我们有了交叉验证和自举重采样(bootstrap-resampling)验证方法来解决这个问题。交叉验证是指一组检测建立的给定预测模型用在新数据(测试数据集)中效果好坏的方法。
交叉验证法背后的基本思想是将数据分为两组:
1) 训练数据集,用于训练(即构建)模型;
2) 测试数据集(或验证集),用于测试(即验证)模型(通过估计预测误差) 。
交叉验证也称为重采样方法,它涉及使用数据的不同子集使用同一方法进行多次拟合。
评估模型性能的交叉验证法有多种方法:
验证集方法(或数据拆分):Validation set approach
单个剔除交叉验证: Leave One Out Cross Validation
k倍交叉验证(又叫k折交叉验证):k-fold Cross Validation
重复k倍交叉验证: Repeated k-fold Cross Validation
这些方法各有优缺点。通常,我们建议使用重复k倍交叉验证。
library(tidyverse)
library(caret)
我们将使用R的内置数据集swiss。
加载数据
library(datasets)
data("swiss")
输出结果
Fertility Agriculture Examination Education Catholic Infant.Mortality
77.3 89.7 5 2 100.00 18.3
76.1 35.3 9 7 90.57 26.6
83.1 45.1 6 9 84.84 22.2
library(datasets)
data("swiss")
Fertility Agriculture Examination Education Catholic Infant.Mortality
77.3 89.7 5 2 100.00 18.3
76.1 35.3 9 7 90.57 26.6
83.1 45.1 6 9 84.84 22.2
研究问题:根据社会经济相关的多个指标(Agriculture,Examination,Education,Catholic,Infant.Mortality)预测生育力得分(Fertility)。
为此,基本策略是:
1) 在训练数据集上建立模型
2) 将模型应用于测试数据集以进行预测
3) 计算预测误差(即预测结果与实际观察结果之间的变异情况)
简而言之,交叉验证算法可以总结如下:
1) 保留一小部分数据集作为测试集
2) 使用数据集(训练集)的其余部分构建(或训练)模型
3) 将训练集上训练好的模型用在测试集上,检查模型的性能。
6.1验证集方法(或数据拆分)
set.seed(123)
training.samples<-swiss$Fertility%>%
createDataPartition(p= 0.8,list=FALSE)
train.data <-swiss[training.samples,]
test.data<-swiss[-training.samples,]
model<-lm(Fertility~.,data=train.data)
predictions<-model%>%predict(test.data)data.frame(R2=R2(predictions,test.data$Fertility), RMSE=RMSE(predictions,test.data$Fertility),MAE=MAE(predictions,test.data$Fertility))
R2 RMSE MAE0.
1 0.39 9.11 7.48
RMSE(predictions,test.data$Fertility)/mean(test.data$Fertility)
请注意,仅当数据集含有较大样本,可以供拆分时,验证集方法才有效。
6.2单个剔除交叉验证
此方法的思路如下:
1) 剔除一个数据样本,并在其余数据集上建立模型
2) 针对在步骤1中剔除的单个数据样本进行模型测试,并记录下与预测相关的预测误差
3) 对所有样本重复该过程
4) 通过取在步骤2中记录的所有这些测试误差估计的平均值,计算总体预测误差。
train.control<-trainControl(method= "LOOCV")
model<-train(Fertility~.,data=swiss,method= "lm", trControl=train.control)
print(model)
Linear Regression
47 samples
5 predictor
No pre-processing
Resampling: Leave-One-Out Cross-Validation
Summary of sample sizes: 46, 46, 46, 46, 46, 46, ...
Resampling results:
RMSE Rsquared MAE
7.74 0.613 6.12
Tuning parameter 'intercept' was held constant at a value of TRUE
单个剔除交叉验证法的优点是我们利用了所有数据样本来减少潜在的误差。
缺点是该过程重复进行的次数与存在数据样本点的次数相同,因此当n极大时,将导致运行时间很长。
此外,我们在每次迭代时针对一个数据点测试模型性能。如果某些数据点是异常值,则可能导致预测误差的变化较大而不稳定。
6.3 K倍交叉验证
步骤如下:
1) 将数据集随机拆分为k个子集(或k倍)(例如5个子集)
2) 保留一个子集(测试集)并在所有其他子集(训练集)上训练模型
3) 在保留的子集(测试集)上测试模型并记录预测误差
4) 重复此过程,直到k个子集中的每一个都已用作一次测试集
5) 计算k个记录的预测误差的平均值。它代表了交叉验证误差情况,是评价模型性能的指标。
那么,如何选择正确的k值呢?
较低的k值更容易产生偏差;较高的k值偏差较小,但可能会出现较大的可变性。
当k值较小时,(例如k = 2)趋向于验证集方法;
当k值较大时,(例如k =数据样本点数)即为单个剔除交叉验证。
set.seed(123)
train.control<-trainControl(method= "cv",number= 10)
model<-train(Fertility~.,data=swiss,method= "lm", trControl=train.control)
print(model)
Linear Regression
47 samples
5 predictor
No pre-processing
Resampling: Cross-Validated (10 fold)
Summary of sample sizes: 43, 42, 42, 41, 43, 41, ...
Resampling results:
RMSE Rsquared MAE
7.38 0.751 6.03
Tuning parameter 'intercept' was held constant at a value of TRUE
6.4 重复K倍交叉验证
set.seed(123)
train.control<-trainControl(method= "repeatedcv", number= 10,repeats= 3)
model<-train(Fertility~.,data=swiss,method= "lm", trControl=train.control)
print(model)
Linear Regression 47 samples 5 predictor
No pre-processing
Resampling: Cross-Validated (10 fold, repeated 3 times)
Summary of sample sizes: 43, 42, 42, 41, 43, 41, ...
Resampling results: RMSE Rsquared MAE
7.319331 0.688556 6.093787
Tuning parameter 'intercept' was held constant at a value of TRUE