机器学习三人行(系列八)----神奇的分类回归决策树(附代码)
系列五我们一起学习并实战了支持向量机的分类和回归,见下面链接:
机器学习三人行(系列七)----支持向量机实践指南(附代码)
文末附代码关键字,回复即可下载。
今天,我们一起学习下决策树算法,该算法和SVM一样,既可以用来分类,也可以用来回归。之前系列的文章,我们大多都是先学原理,再来实战,今天我们反着走一遭,先来实战,再看原理。因为决策树这个算法的模型是可以可视化的,所以看过模型之后,再去理解原理会easy些。今天的主要内容如下:
决策树分类实战
决策树算法简介
决策树回归实战
决策树稳定性分析
一. 决策树分类实战
决策树其实是一种很容易理解的一种算法,我们来从一个实例入手,来认识一下该算法。
1. 训练模型
该实验中,我们依然用大家熟悉的iris数据集,为了更好的可视化,我们仍然用花瓣的长和宽来进行多分类,直接调用sklearn中的DecisionTreeClassifier,如下代码:
我们看到上述分类器有很多参数,这些参数上面一般是选择默认,但是具体怎么调,需要详细深入其原理,本文主要是基于实战的,对原理不做太深入的介绍,感兴趣的请戳之前的文章链接。
一些常用的参数调试,我们下面聊,这里可视化决策树需要用到graphviz包,如果没有安装,则需要安装。可视化代码见文末关键字,公众号回复关键字下载。先看看训练好的决策树长什么样?
来,我们一起看下这个决策树是怎么分类的?
2. 模型预测
假如现在我们手里有一个样本,我们想把它进行分类,我们从根节点走(depth 0,顶部),这个节点的决策属性是看这个花的花瓣长度是否小于2.45cm,如果是,则往下走左侧(derth 1,左侧),此时我们碰到了一个叶子节点,不必再进行决策选择,因为该叶子节点就是一个类别,则该花的类别就为setosa。假如在根节点的时候,该花的花瓣长度大于2.45cm,那么我们则往右下方走,(depth 1,右侧),这里我们碰到一个子节点,该子节点决策是看花瓣的宽度是否小于1.75cm,如果是的话,则往下走(depth 2,左),分为versicolor,否则(depth 2,右)分为virginica。
知道了怎么分类一个新样本,我们来了解一下上面决策树中的参数都什么意思,以跟节点为例:
第一行为决策条件
第二行gini = 0.6667,为该节点的gini指标,即根据gini指标(也可以根据信息熵,戳上面文章链接)来进行节点的选择,gini指标的计算如下:
其中Pi,k表示第i个节点中类别k所占的比例。
第三行samples = 150表示该节点需要决策的样本数量
第四行value = [50,50,50],表示三个类别的样本数量
第五行class = setosa,表示哪一类样本数量最多
下面,我们来看下决策树的决策边界:
上图中竖直的实线表示根节点的决策(花瓣的长度<=2.45cm),右侧的水平虚线表示右侧花瓣宽度是否小于1.75cm的决策节点。如果我们的节点深度max_depth设为3的话,就会有多一层的决策边界,如上面竖直的两条虚线。
决策树当然也可以输出一个样本属于某类的概率,如果输出概率的话,首先还是需要根据节点选择走到叶子节点上,然后根据叶子节点中各类所占的比例进行概率输出。
比方说我们有一个样本:花瓣长度为5cm,宽度为1.5cm,那么走到叶子节点后的value为[0,49,5],那么对于三个类别输出的概率值为[0/54, 49/54, 5/54],如下:
通过上面的实战,对决策树有一些了解之后,知其然也要知其所以然,下面我们简单看一下它的原理,详细原理戳下面链接。
二. 决策树算法简介
2.1 节点选择原理
Scikit-Learn用CART(Classification And Regression Tree)算法来训练决策树。思想也很简单:算法用一个特征k和一个阈值t_k来把训练集分为两个子集,对子集采用同样的方法再分,直至决策树达到终止条件。那么它是根据上面来分为两个子集呢?当然是损失函数,如下:
其中G_left和G_right分别表示左右子集的gini指标,而m_left和m_right分别表示左右子集的样本数量。终止条件一般为最大深度max_depth,或者gini指标为0, 不能再分了,当然还有一些其他的参数,如min_samples_split, min_samples_leaf, min_weight_fraction_leaf, 和 max_leaf_nodes。
在Scikit-Learn中,默认节点选择标准为gini指标,这里我们可以通过修改菜蔬ctiterion将其改为entropy,即根据信息熵来进行节点选择,某一个节点的信息熵的公式如下:
其中n为参与决策的特征的个数,k为某一个特征,p_i,k表示节点i上在特征k下分类,不同类别的概率值。
那么我们到底是该选择gini指标还是信息熵作为节点选择的依据呢?事实上,大多数情况下,两者差别并不大,由于gini指标计算速度稍微快一些,所以,默认用gini指标,究其细微,两者对决策树的形状影响略微不同,gini指标一般是高频类别独树一枝,而信息熵的方法大多数生成的是平衡树。
决策树这样选择节点,那么它的时间复杂度怎么样呢?
2.2 算法复杂度
如果我们预测样本,那么我们需要从根节点遍历到叶子节点,进而得到预测结果,一般情况下,一颗决策树近似为平衡树,因此遍历该决策树(每个节点仅仅判断一个特征)所需的时间复杂度为O(log_2(m)),其中m为样本个数,所以决策树预测样本的速度还是很快的。
然而,我们训练样本的时候,由于我们需要在每一个节点上比较所有样本所有特征(除非我们设置了max_features参数),所以训练的时候还是比较耗时的,所需的时间复杂度为O(n*m*log(m)),n为用于决策的属性个数。
对于小训练集来说,Scikit-Learn可以通过预排序(参数presort=true)来进行加速,但是,对于大数据集的话,这样预排序的方法反而能增加训练时间。
2.3 正则化参数
对于决策树来说,对于训练集来说,我们几乎没有什么前提条件(像线性模型,我们常假设数据是线性的),但是如果我们对模型进行训练的时候,不加以任何限制的话,让算法根据数据进行自然训练,这样得到的模型将严重吻合训练数据,也就是造成过拟合。那么为什么线性模型不容易造成过拟合呢?
其实决策树算法是一种非参数模型算法,而线性模型属于参数模型算法,参数模型是指用代数方程、微分方程、微分方程组以及传递函数等描述的模型,建立模型在于确定已知模型结构中的各个参数。对于非参数模型算法,并不是说训练过程中没有参数限制(相反会有很多),而是书直接或间接地从实际系统的实验分析中得到的模型。由于没有参数限制,所以,非参数模型的自由度很大,一般容易造成过拟合。
那么为了避免过拟合的发生,一般在训练的过程中,进行一些参数限制来防止过拟合,常用的参数有:
max_depth: 树的最大深度;
min_samples_split: 每一个节点上最少样本数量
min_samples_leaf: 每一个叶子节点的最少样本
min_weight_fraction_leaf: 基本和上面一致,重点强调有权重的样本
max_leaf_nodes: 最大叶子节点数
max_features: 用到的最大特征数量
一般增加min_*参数或者减小max_*参数将会正则化该模型。
当然,也有一些决策树模型,并不是通过设置这些参数来进行正则化的,而是先任其生成,最后通过对生成的决策树进行剪枝,来达到正则模型的目的。详情戳下面链接。
我们看一个图(代码文末下载),下图中的左侧决策树模型是对模型的训练未加任何限制,我们可以看到起决策边界有很多,明显过拟合,而对于右侧模型,显然其泛化能力更胜一筹。
三. 决策树回归实战
决策树当然也有能力处理回归任务,接下来我们通过Scikit-Learn中的DecisionTreeRegressor来建立一颗回归树,这里的数据选择一个带噪声的二次方程,其中最大深度max_depth = 2,如下:
咋一看,上面的回归树跟前面的分类决策树蛮相似的,主要差别就是对某一个数据的预测结果是一个数值,而非一个类别。比方说,加入x=0.5,我们遍历该回归树之后得到的预测值为0.1106,其实value = 0.1106是这samples=110个样本的平均值,其(110个样本)均方误差mse=0.0151。
模型的具体预测如下图:
其中左侧为最大深度为2的回归模型,而右侧图为最大深度为3的回归模型。从上图中可以很明显的看出,其预测的值就是该区间的样本均值。
其实回归的原理和分类原理基本一致,唯一不同的地方就是损失函数的不同,这里分裂节点的依据就最小化训练集的均方误差,算法损失函数如下:
(上面函数中参数意义参考分类的损失函数,如有不解,欢迎进入公众号社区或留言交流)
和分类一样,如果不加以限制的话,回归树也容易出现过拟合,如下图,左图为对该回归不加任何参数限制,而右图则对min_samples_leaf进行设置为10.
四. 稳定性分析
现在我们基本上已经发现决策树算法不管是在分类还是回归上面都很容易理解,很好用。然而,这个算法也是有一些缺点的:
首先,我们也看到了,决策树模型的决策边界呈现出直角垂直的边界,这样导致模型对数据方向敏感,如下图:
上图左侧对数据的划分比较好,但是当我们把数据进行45°旋转之后,发现,上图右侧的决策边界呈现台阶状,虽然右侧对数据的分类还是挺不错的,但是泛化能力可能并不理想。我们可以通过降维(后面会讲)的方法进行降维,使数据有一个好的方向特征。
再一个,决策树通常对数据集中的小的变化比较敏感,比如,我们移除iris数据集中的某些数据(见代码)后,我们将会得到一个完全不同的决策树,如下图:
下集我们将介绍的集成算法可以通过多颗决策树共同决策的方法来解决稳定性的问题。
五. 小结
本节,我们先从决策树的实战开始,对决策树有一个直观的认识,介于之前已经详细介绍过决策树的原理,这里通过对其节点选择标准,时间复杂度,以及正则化模型对其原理进行简单介绍。然后我们又学习了一下决策树回归的相关知识,以及进行了相关实战。最后我们从决策树的稳定性方面对决策树模型进行了一些优缺点分析。希望通过本文我们能一起更清楚的了解决策树。
(如需更好的了解相关知识,欢迎加入智能算法社区,在“智能算法”公众号发送“社区”,即可加入算法微信群和QQ群)
本文代码回复关键字:decision tree
公众号回复关键字即可免费下载。