查看原文
其他

基于网格搜索优化的SVM分类预测

爬虫俱乐部 Stata and Python数据分析 2023-10-24

本文作者:许林丽,中南财经政法大学统计与数学学院

本文编辑:周一鸣

技术总编:孙一博

Stata and Python 数据分析

爬虫俱乐部Stata基础课程Stata进阶课程Python课程可在小鹅通平台查看,欢迎大家多多支持订阅!如需了解详情,可以通过课程链接(https://appbqiqpzi66527.h5.xiaoeknow.com/homepage/10)或课程二维码进行访问哦~
引言


支持向量机(SVM)是一种二分类模型,可以在高维或无限维空间中找到间隔最大的超平面进行分类,并使用这些超平面对数据点进行分类。 支持向量机模型分为线性支持向量机和非线性支持向量机,在实际应用中大多数问题都是非线性的,对于这种情况,我们可以引入核函数。核函数的本质是将每个样本点从低维空间映射到高维空间来使其能够被线性分离,从而可以处理线性不可分割的数据。常用的核函数主要包括线性核函数(linear)、多项式核函数(poly)、高斯径向基核函数(RBF)和 Sigmoid 核函数 。

      网格搜索(GridSearch)是一种调参方法,基本原理是在所有候选的参数选择中,通过循环遍历,尝试每一种可能性,表现最好的参数就是最终的结果,因此也被称为“穷举搜索”和“暴力搜索”。 此外,使用交叉验证可以使得评分更加严谨,因此交叉验证经常与网格搜索一起结合使用,即GridSearchCV,最后可以从列出的超参数中选择最佳参数。可以看出,需要遍历所有可能的参数组合的网格搜索的缺点就是非常耗时!!特别是在处理大数据集和多参数时。

      本次小编将带着大家学习如何用支持向量机对数据集进行分类预测,并且使用网格搜索进行调参,使模型更加精确。

说明


01数据说明
本次分类预测使用的是威斯康星州乳腺癌(Breast Cancer)数据集,该数据集一共有569个样本,30个特征,标签为二分类。其中良性benign为357个,恶性malignant为212个。下图是乳腺癌数据集展示以及30个特征的具体描述。

02参数说明
(1)SVC语法在python中,可以通过调用sklearn模块下的svm.SVC()函数实现支持向量分类的基本功能,函数参数说明如下:
sklearn.svm.SVC(C=1.0,kernel='rbf', degree=3, gamma='auto',coef0=0.0,shrinking=True,probability=False,tol=0.001,cache_size=200, class_weight=None,verbose=False,max_iter=-1,decision_function_shape=None,random_state=None)C:目标函数的惩罚系数,默认值为1.0。C越大,表示在训练样本中准确率越高,但泛化能力低,即对测试数据的分类准确率降低。kernel:核函数类型,默认为‘rbf’。常用的可选参数有linear(线性核函数)、poly(多项式核函数)、rbf(径像核函数/高斯核)和sigmoid(双曲正切核函数)。degree:使用kernel为 ‘poly’时,给定多项式的项数,默认为3。若指定kernel为其他核函数则忽略该参数。gamma:表示当kernel为‘rbf’, ‘poly’或‘sigmoid’时的kernel系数,默认为 ‘auto’,即样本特征数的倒数。coef0:核函数的常数项,只有在 kernel为‘poly’或‘sigmoid’时有效,默认为0.0。shrinking:是否采用启发式,默认为True。probability:是否启用概率估计,默认为False。tol:训练结束要求的精度,默认为0.001。cache_size:指定训练所需要的内存,以MB为单位,默认为200MB。class_weight:给定各个类别的权重,默认为1。verbose:是否详细输出训练过程,默认为False。max_iter:最大迭代次数,默认为-1,表示无穷大迭代次数。decision_function_shape:多分类时选择的方式,有‘ovo’、‘ovr’和None三种,默认为None。random_state:将训练集打乱顺序时使用的伪随机数生成器的种子,默认为None。主要需要调节的参数有:C、kernel、degree、gamma、coef0。(2)GridSearchCV语法GridSearchCV是python里sklearn库中的一个函数,常用的GridSearchCV的参数说明如下。
sklearn.model_selection.GridSearchCV(estimator, param_grid=None, scoring=None, cv=None, verbose=0)estimator:参数针对的搜索对象,即所使用的分类器。param_grid:需要最优化的参数的取值,值可以是字典或者列表。scoring:模型评价标准,根据所选模型不同,评价准则不同。默认None,使用estimator的误差估计函数。n_jobs:与并行运行相关,可以提高搜索速度,取值为整数,默认为1,大于1的整数表示运行核数(但不能超过运行主机有的核数),取-1则代表使用主机所有的核数。cv:交叉验证参数,默认None,使用三折交叉验证。verbose:日志冗长度,取值为整数。取值为0,则不输出训练过程;取值为1,则偶尔输出;取值大于1,则对每个子模型都输出。常用评价标准参数说明:grid_search.best_estimator_:查看带有最优超参的搜索器的相关信息。grid_search.best_score_:查看当前最优超参情况下的得分。grid_search.best_params_:输出当前由最优的超参及其取值组成的字典。算法实现


01导入数据
import numpy as npimport pandas as pdfrom sklearn.datasets import load_breast_cancer ##导入数据集from sklearn import svmfrom sklearn.svm import SVC ##导入SVC函数from sklearn.metrics import classification_report ##导入模型评估函数from sklearn.model_selection import GridSearchCV ##导入网格搜索函数from sklearn.model_selection import train_test_split ##导入划分数据集函数from sklearn.metrics import roc_curve, auc ##用于计算roc和auc
data = load_breast_cancer() ##乳腺癌数据集X = data.data ##数据特征Y = data.target ##数据标签02划分训练集、验证集和测试集验证集可以用于调整模型的超参数和用于对模型的能力进行初步评估,常用来在模型迭代训练时,验证当前模型泛化能力(准确率,召回率等),防止过拟合的现象出现,并决定如何调整超参数。对于小规模样本集,常用的分配比例是 60% 训练集(train)、20% 验证集(val)、20% 测试集(test)。
x_train, x_test_val, y_train, y_test_val = train_test_split(X, Y, test_size=0.4, random_state=0)x_val, x_test, y_val, y_test = train_test_split(x_test_val, y_test_val, test_size=0.5, random_state=0)03数据标准化处理
def zscore_normalize_features(X): mu = np.mean(X, axis=0) sigma = np.std(X ,axis=0) X_norm = (X - mu) / sigma return (X_norm)
x_train = zscore_normalize_features(x_train)x_val = zscore_normalize_features(x_val)x_test = zscore_normalize_features(x_test)04模型训练

机器学习中常使用精确度、 查准率、召回率以及 F1 得分作为评估分类效率的评价指标。在这里,我们使用精确度和F1得分来对模型性能进行评估。

print("————————————调参前————————————")clf = svm.SVC()clf.fit(x_train, y_train)predictions = clf.predict(x_val)print(classification_report(y_val, predictions))使用网格搜索对 ‘kernel’、‘C’、‘degree’和‘gamma’进行调参,选择表现最好的参数,结果为‘kernel’选择‘rbf'核函数,‘C’取值为10,‘gamma’取值为0.01。
params = [{'kernel':['linear'], 'C':[1,10,100]},{'kernel':['poly'], 'C':[1,10], 'degree':[2,3]},{'kernel':['rbf'], 'C':[1,10,100], 'gamma':[1, 0.1, 0.01, 0.001]}]model = GridSearchCV(SVC(), param_grid=params, cv=5)model.fit(x_train, y_train)print(model.best_params_)

利用验证集验证调参后模型精度是否有提高,结果可以看出精确度由调参前的0.97提高到了0.98,说明分类正确的样本比例增加。F1得分由调参前的0.98提高到了0.99,说明在调整参数后,模型的拟合效果提高,具有较好的可解释性。
print("————————————调参后————————————")predictions_val = model.predict(x_val)print(classification_report(y_val, predictions_val))

最后,利用测试集对已训练好的模型做最后评估,由结果可以看出模型精度和F1得分较高,说明模型分类效果较好。

print("————————————测试集————————————")predictions_test = model.predict(x_test)print(classification_report(y_test, predictions_test))

END


重磅福利!为了更好地服务各位同学的研究,爬虫俱乐部将在小鹅通平台上持续提供金融研究所需要的各类指标,包括上市公司十大股东、股价崩盘、投资效率、融资约束、企业避税、分析师跟踪、净资产收益率、资产回报率、国际四大审计、托宾Q值、第一大股东持股比例、账面市值比、沪深A股上市公司研究常用控制变量等一系列深加工数据,基于各交易所信息披露的数据利用Stata在实现数据实时更新的同时还将不断上线更多的数据指标。我们以最前沿的数据处理技术、最好的服务质量、最大的诚意望能助力大家的研究工作!相关数据链接,请大家访问:(https://appbqiqpzi66527.h5.xiaoeknow.com/homepage/10)或扫描二维码:

最后,我们为大家揭秘雪球网(https://xueqiu.com/)最新所展示的沪深证券和港股关注人数增长Top10。




对我们的推文累计打赏超过1000元,我们即可给您开具发票,发票类别为“咨询费”。用心做事,不负您的支持!







往期推文推荐Stata18之dtas——The new in data management
定制属于自己的“贾维斯”——Python调用Chat
学会format,数据格式任你拿捏【Python实战】游客最青睐的城市,你的家乡上榜了吗?

What’ new ? 速通Stata 18

【爬虫实战】Python爬取美食菜谱揭秘网络中心人物,你会是其中之一吗?考研之后,文科生需以“do”躬“do”!焕新升级!轻松获取港股、权证的历史交易数据爬虫俱乐部的精彩答疑---cntraveltime【爬虫俱乐部新命令速递】在Stata中与ChatGPT对话

用`fs`命令批量获取文件夹和不同文件夹下的excel文件

自然语言处理之实例应用

JSON帮手,FeHelper

最新、最热门的命令这里都有!

Python实现微信自动回复告诉python,我想“狂飙”了——线程池与异步协程为爬虫提速高级函数——map()和reduce()

Stata绘制条形图的进阶用法

快来看看武汉的房价是不是又双叒叕涨了!
     关于我们 

   微信公众号“Stata and Python数据分析”分享实用的Stata、Python等软件的数据处理知识,欢迎转载、打赏。我们是由李春涛教授领导下的研究生及本科生组成的大数据处理和分析团队。

   武汉字符串数据科技有限公司一直为广大用户提供数据采集和分析的服务工作,如果您有这方面的需求,请发邮件到statatraining@163.com,或者直接联系我们的数据中台总工程司海涛先生,电话:18203668525,wechat: super4ht。海涛先生曾长期在香港大学从事研究工作,现为知名985大学的博士生,爬虫俱乐部网络爬虫技术和正则表达式的课程负责人。



此外,欢迎大家踊跃投稿,介绍一些关于Stata和Python的数据处理和分析技巧。

投稿邮箱:statatraining@163.com投稿要求:1)必须原创,禁止抄袭;2)必须准确,详细,有例子,有截图;注意事项:1)所有投稿都会经过本公众号运营团队成员的审核,审核通过才可录用,一经录用,会在该推文里为作者署名,并有赏金分成。2)邮件请注明投稿,邮件名称为“投稿+推文名称”。3)应广大读者要求,现开通有偿问答服务,如果大家遇到有关数据处理、分析等问题,可以在公众号中提出,只需支付少量赏金,我们会在后期的推文里给予解答。

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存