查看原文
其他

第5.3节 sklearn接口与示例代码

空字符 月来客栈 2024-01-19

各位朋友大家好,欢迎来到月来客栈,我是掌柜空字符。

本期推送内容目录如下,如果你觉得本期内容对你所有帮助欢迎点个赞、关个注、下回更新不迷路。为方便大家提问交流,专栏订阅用户可微信私信掌柜“MLWM”进入专属群答疑。

  • 5.3 sklearn接口与示例代码
    • 5.3.1 sklearn接口介绍
    • 5.3.2 近邻示例代码
    • 5.3.3 小结
  • 引用

5.3 sklearn接口与示例代码

5.3.1 sklearn接口介绍

在正式介绍如何使用sklearn库完成近邻的建模任务前,笔者先来总结一下sklearn的使用方法,这样更加有利于对后续内容的学习。

根据第2、3、4章中的示例代码可以发现,sklearn在实现各类算法模型时基本上遵循了统一的接口风格,这使我们在刚开始学习的时候很容易入门。总结起来,在sklearn中对于各类模型的使用,基本上遵循着以下3个步骤。

1. 建立模型

这一步通常来讲在对应的路径下导入我们需要用到的模型类,例如可以通过代码来导入一个基于梯度下降算法优化的分类器,代码如下:

from sklearn.linear_model import SGDClassifier

在导入模型类后,需要通过传入模型对应的参数来实例化这个模型,例如可以通过代码实例化一个逻辑回归模型,代码如下:

model = SGDClassifier(loss='log',penalty='l2', alpha=0.5)

同时,由于sklearn在迭代更新中可能会更改一些接口的名称或者位置,所以具体的路径信息可以查看官方的API说明文档 [1]。

2. 训练模型

在sklearn中,所有模型的训练(或者计算)过程都是通过model.fit()方法来完成的,并且一般情况下需要按实际情况在调用model.fit()时传入相应参数。如果是有监督模型,则一般是model.fit(x,y),如果是无监督模型,则一般是mode.fit(x)。同时,还可以调用model.score(x,y)来对模型的结果进行评估。

3. 模型预测

在训练好一个模型后,通常要对测试集或者新输入的数据进行预测。在sklearn中一般通过模型类对应的model.predict(x)方法实现,但这也不是绝对的,例如在对数据进行预处理时,调用model.fit()方法在训练集上计算并得到相应的参数后,往往通过model.transform()方法来对测试集(或新数据)进行变换。

总体上来讲,在sklearn中基本上算法模型可以通过上面这3个步骤来完成对模型的建立、训练与预测。

5.3.2 近邻示例代码

从5.2节的分析可知,其实近邻在分类过程中同之前算法模型不一样,即有一个训练求解参数的过程。这是因为近邻算法根本就没有可训练的参数,只有3个超参数,而近邻算法的核心在于如何快速地找到距离任意一个样本最近的个样本点。当然,最直接的办法就是遍历样本点进行距离计算,但是当样本点达到一定数量级后这种做法显然是行不通的,所以此时可以通过建立KD树(KD Tree)或者Ball树(Ball Tree)来解决这一问题。不过这里暂时不做介绍,先直接通过开源库sklearn实现。

本次示例的数据集仍旧采用第4章中所介绍的手写体分类数据集。同时,在下面示例中笔者将介绍如何通过网格搜索(Grid Search)来快速完成模型的选择,完整代码见Book/Chapter05/01_knn_train.py文件。

1. 模型选择

由于在4.6.1节中已经详细介绍了该数据集的载入和预处理方法,所以这里就不再赘述了。从上面的分析可以得知,近邻算法有两个(另外一个暂不考虑)超参数,即值和度量方式P值。假设两者的取值分别为n_neighbors = [5, 6, 7, 8, 9, 10]p=[1, 2],则此时一共有12个备选模型。同时,如果采用5折交叉验证,则一共需要进行60次模型拟合,并且需要3个循环实现。不过好在sklearn中提供了网格搜索的功能可以帮助我们通过4行代码实现上述功能,代码如下:

from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV
def model_selection(x_train, y_train):
    paras = {'n_neighbors': [5678910], 'p': [12]}
    model = KNeighborsClassifier()
    gs = GridSearchCV(model, paras, verbose=2, cv=5)
    gs.fit(x_train, y_train)
    print('最佳模型:', gs.best_params_, ‘准确率:’,gs.best_score_)

在上述代码中,第4行以字典的形式定义了超参数的取值情况,并且需要注意的是,字典的key必须是类KNeighborsClassifier中参数的名字。其中,类KNeighborsClassifier就是sklearn中所实现的K近邻算法模型,因此,该类也包含了K近邻中最基本的两个参数K值和P值。第5行定义了一个K近邻模型,但值得注意的是此时并没有在定义模型的时候就传入相应的参数,即以KNeighborsClassifier(n_neighbors=2, p=2)的形式(其中n_neighbors就是值)来实例化这个类。因为在使用网格搜索时,需要将模型作为一个参数传入GridSearchCV类中,同时也需要将模型对应的超参数以字典的形式传入。第6行在实例化GridSearchCV类时便传入了定义的K近邻模型及参数字典,其中verbose用来控制训练过程中输出提示信息的详细程度,cv=5表示在训练过程中使用5折交叉验证。最后,根据传入的训练集,便可以对模型进行训练。

在模型训练完成后,便可以输出最佳模型(超参数组合),以及此时对应的模型得分,结果如下:

继续滑动看下一个

第5.3节 sklearn接口与示例代码

空字符 月来客栈
向上滑动看下一个

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存