查看原文
其他

“抽丝剥茧”,层层下分——机器学习基本算法之决策树

爬虫俱乐部 Stata and Python数据分析 2022-03-15

本文作者:王   歌

文字编辑:孙晓玲

技术总编:张   邯

写在前面:



在《“物以类聚”、“近朱者赤”——机器学习初探之KNN中我们一起了解了机器学习的基本概念以及KNN算法的思想,并将该算法应用到了一个小例子中,我们知道KNN算法并没有显式训练过程,因而存储训练集进行判断时会占用很大内存,尽管这样训练时间为0,但我们还是想改变一下,毕竟在实际中我们的数据不止150个鸢尾花。今天我们要学习的决策树算法则不同,虽然都属于有监督学习,都可以处理分类和回归的问题,但这一算法不仅可以通过训练建立模型,而且得到的模型很直观,下面就让我们一起来看看它到底是怎么进行划分的。


算  法  介  绍 


我们在生活中都会有这样的思维方式:出门之前看一眼天气,如果是阴雨天就带伞,晴天就不带伞,其实这就是一个很简单的决策树。我们通过对天气的判断,得出带或不带伞的决策。决策树,顾名思义,它的思想就是通过对样本按某些属性进行层层划分,最后得到每种划分的类别标签(在上面的例子中就是是否带伞),生成类似于树状图的模型结构。在决策树中,我们将整个样本称为根节点,从根结点出发,每个用来做出判断、可以继续向下分的结点称为决策结点,不再向下分割的结点称为叶结点。当因变量连续时,生成回归树,不连续则生成分类树。
根据选择划分特征的方法的不同,可以分为三种算法:ID3、C4.5和CART。我们分别来介绍一下。
(1)ID3算法
其算法的核心是不断在各结点上利用信息增益准则从样本特征集中选择未使用过的特征来对样本进行划分,直至样本全部分完或信息增益达不到要求时停止。这里要解释两点,一是信息增益准则,要了解信息增益,首先要知道熵,在自然科学中我们用熵来衡量系统的无序程度,类比过来,我们认为当某结点下的样本均为同一类时无序程度最小,纯度最大,根据这一思路,我们定义随机变量X的取值概率为,则X的熵定义为:

可以看到熵的计算本身是一种期望值的计算,所以在样本集合D中,若我们以某个属性特征A作为划分属性,该属性有m个取值,其中类别i的样本比例为,对应的样本集为,则特征A对数据集D的条件熵为:

判断在划分前后熵值的变化即为信息增益:

信息增益越大,利用A属性划分后纯度的提高程度越大。二是要注意一旦按某特征划分数据后,该特征将不会在后面的算法中再次作为划分的结点出现,并且生成的每个结点可以有多个分支,同时更倾向于选取取值较多的属性作为划分结点。为克服这一缺点,产生了C4.5算法

(2)C4.5算法

该算法在ID3的基础上,使用信息增益率:


作为划分的标准,其中为属性A作为随机变量时的熵。但由于加入惩罚参数后信息增益率倾向于选择取值少的属性,因此通常不直接使用它,而是先选出信息增益高出平均水平的属性,再从中选择信息增益率高的。相比于ID3,除了克服了选择取值多的属性值的缺点外,还可以处理非离散数据和不完整数据。

(3)CART算法

相比于ID3和C4.5,CART构建的则是二叉树,每个结点有两个分支。这一算法首先计算了数据集D的纯度(称为基尼值):


其中为属于类别k的概率,因此属性A的基尼指数:


由于基尼值表示两样本属于不同类别的概率,因此基尼值越小纯度越高,所以我们要选择基尼指数最小的属性来进行划分。这一算法既可以用于分类也可以用于回归,相较于前两种算法也更为常用。

在训练决策树的过程中,很可能因为训练样本中的噪声数据,或者样本本身存在总体所不具备的特征,这时可能会影响到决策树的分支,发生过拟合的现象。所以在使用这一算法时,通常会对决策树进行“剪枝”,包括预剪枝和后剪枝两种。预剪枝就是在每个结点划分前先对其进行统计显著性、信息增益等指标估计,若划分后这些指标值低于预先设定的阈值,则停止划分,但通常阈值的选取较为困难,也容易发生欠拟合的情况;后剪枝则是在生成决策树后,自下而上对每个非叶结点进行检查,若将该非叶结点换成叶结点后的性能优于非叶结点,则剪掉这一枝,并将其改为叶结点,但这样做要比预剪枝花费的时间更长并且仍然有发生过拟合的风险。

算  法  实  例 


在了解了这一算法的基本理论以后,我们来看看在sklearn库中具体是如何操作的。与上节相同,我们所使用的数据依然是sklearn中自带的iris鸢尾花数据,并且以分类问题为例,回归可自行尝试。首先我们导入所使用的库和数据,程序如下:

from sklearn.datasets import load_irisfrom sklearn.tree import DecisionTreeClassifierfrom sklearn.model_selection import train_test_splitiris_sample = load_iris()  # 导入数据集

这里我们首先介绍一下所用的分类器DecisionTreeClassifier里面的参数,其基本语法如下:

classsklearn.tree.DecisionTreeClassifier(criterion=‘gini’,splitter=‘best’,max_depth=None,min_samples_split=2,min_samples_leaf=1,min_weight_fraction_leaf=0.0,max_features=None,random_state=None,max_leaf_nodes=None,min_impurity_decrease=0.0,class_weight=None,presort=False)
其中criterion是选择划分属性的标准,默认基尼指数(‘gini’),也可以选择信息增益(‘entropy’);splitter用来确定是优先选择重要特征还是随机选择,默认选择最优划分(‘best’),也可选择随机(‘random’);max_depth用来限制树的最大深度,防止过拟合;min_samples_split表示一个结点最少应包含多少样本才会继续下分,取值可以是int或float,当传入float时向上取整;min_samples_leaf表示叶结点的最小样本数;min_weight_fraction_leaf用来限制叶结点所有样本的权重和的最小值,当小于该值是会与兄弟结点一同剪枝,默认样本的权重相同;max_features表示在寻找最优划分时搜索的最大属性个数,默认为None,即搜索全部属性;当将参数splitter设置为‘random’时,可以通过random_state设置种子,默认为None,表示使用np.random产生随机种子;max_leaf_nodes表示最大的叶结点数;min_impurity_decrease结点划分时的最小不纯度;class_weight设置样本数据中每个类的权重;presort表示在划分前是否先对训练数据进行排序。在以上介绍的基础上,我们对数据集拟合决策树,分类器使用默认参数,程序如下:
x_train, x_test, y_train, y_test = train_test_split( iris_sample.data, iris_sample.target, test_size=0.25, random_state=123)treeclf = DecisionTreeClassifier() 得到分类器treeclftreeclf.fit(x_train, y_train) # 拟合决策树y_test_pre = treeclf.predict(x_test) #预测score = treeclf.score(x_test, y_test) #评估预测结果print('测试集预测结果为:', y_test_pre)print('测试集正确结果为:', y_test)print('测试集准确度为:', score)

运行结果如下:

可以看到准确率为92.11%,有三个样本分类错误。大家也可通过自己的需要更改上面的参数来达到自己的要求,由于篇幅限制小编就不再一一演示了。同时我们在拟合了决策树后也可使用graphviz生成可视化图形,程序如下:

from sklearn import tree # 需要导入的包with open('../iris_tree.dot', 'w') as f: f=tree.export_graphviz(treeclf, out_file=f)
生成的dot文件直接打开是这样的:


对于这个文件我们可以安装graphviz后可查看树的结构,大家来试试吧。







对我们的推文累计打赏超过1000元,我们即可给您开具发票,发票类别为“咨询费”。用心做事,不负您的支持!
往期推文推荐

爬取东方财富网当日股票交易情况

stata调用python爬取时间数据——借他山之石以攻玉

全国31省GDP排行强势登场!
接力《发哨子的人》Stata版
批量实现WORD转PDF

Stata有问必答环节

我听到了企业的哀鸣
“物以类聚”、“近朱者赤”——机器学习初探之KNN
SFI:Stata与Python的数据交互手册(二)

从流调数据中寻找感染真相

熟悉又陌生的reshape

NBA球员薪资分析——基于随机森林算法(二)

NBA球员薪资分析——基于随机森林算法(一)

高亮输出之唐诗作者

湖北省各市疫情数据爬取

古代诗人总去的这些地方你一定要知道!

DataFrame数组常用方法(二)

ftools命令——畅游大数据时代的加速器

关于我们



微信公众号“Stata and Python数据分析”分享实用的stata、python等软件的数据处理知识,欢迎转载、打赏。我们是由李春涛教授领导下的研究生及本科生组成的大数据处理和分析团队。

此外,欢迎大家踊跃投稿,介绍一些关于stata和python的数据处理和分析技巧。
投稿邮箱:statatraining@163.com
投稿要求:
1)必须原创,禁止抄袭;
2)必须准确,详细,有例子,有截图;
注意事项:
1)所有投稿都会经过本公众号运营团队成员的审核,审核通过才可录用,一经录用,会在该推文里为作者署名,并有赏金分成。
2)邮件请注明投稿,邮件名称为“投稿+推文名称”。
3)应广大读者要求,现开通有偿问答服务,如果大家遇到有关数据处理、分析等问题,可以在公众号中提出,只需支付少量赏金,我们会在后期的推文里给予解答。

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存