查看原文
其他

第8.5节 从零实现ID3与C4.5算法

空字符 月来客栈 2024-01-21

各位朋友大家好,欢迎来到月来客栈,我是掌柜空字符。

本期推送内容目录如下,如果本期内容对你有所帮助,欢迎点赞、转发支持掌柜!

  • 8.5 从零实现ID3与C4.5算法
    • 8.5.1 节点定义实现
    • 8.5.2 信息熵与条件熵实现
    • 8.5.3 决策树构建实现
    • 8.5.4 决策树遍历实现
    • 8.5.5 样本预测实现
    • 8.5.6 使用示例
    • 8.5.7 剪枝判断实现
    • 8.5.8 剪枝过程实现
    • 8.5.9 小结
    • 引用

8.5 从零实现ID3与C4.5算法

在前面几节内容中,笔者详细介绍了ID3与C4.5决策树算法的原理与计算示例,并且还介绍了如何借助开源的sklearn框架来完成整个建模的搭建流程。在接下来的这节内容中,笔者将会详细地来介绍如何从零一步步地实现ID3与C4.5这两种决策树算法。

8.5.1 节点定义实现

由ID3与C4.5决策树算法的原理可知,两则的唯一差别就是体现在对于特征划分标准的不同上,前者采用的信息增益,而后者则采用的是信息增益比来进行判断。因此,两者在代码实现时只需要将这部分内容单独抽象成一个函数即可,其它部分的代码可以保持不变。本节所有实现代码可参见Book/Chapter08/C13_id3_categorical.py文件。

在实现决策树之前,我们需要先来定义决策树中每个节点的构成。同时,参考图8-7中的节点信息,这里将决策树的节点定义为如下形式:

 1 class Node(object):
 2     def __init__(self, ):
 3         self.sample_index = None  # 保存当前节点中对应样本在数据集中的索引
 4         self.values = None  # 保存每个类别的数量
 5         self.features = None  # 保存当前节点状态时特征集中剩余特征
 6         self.feature_id = -1  # 保存当前节点对应划分特征的id
 7         self.label = None  # 保存当前节点对应的类别标签(叶子节点才有)
 8         self.n_samples = 0  # 保存当前节点对应的样本数量
 9         self.children = {}  # 保存当前节点对应的孩子节点
10         self.criterion_value = 0.
11         self.n_leaf = 0  # 以当前节点为根节点时其叶子节点的个数
12         self.leaf_costs = 0. 

在上述代码中,第3行sample_index用来保存当前节点中对应样本在数据集中的索引,这样我们在需要的时候可以直接通过索引去取到对应的样本而不是保存到每个节点中,同时也方便根据索引来取对应的样本标签;第4行values用来保存每个类别的数量,例如[10,4,6]则表示第0、1、2这三个类别在当前节点中的数量分别是10、4和6,其作用是根据这一结果可以知道当前叶子节点所代表的类别;第5行features用于保存在当前节点状态时特征集中剩余特征维度(即还剩下哪些特征没有被用于前面的划分中),例如[0,2,3]则表示对于当前节点来说,其备选特征为第0、2和3个;第6行feature_id用于保存当前节点对应划分特征的id,因为在决策树预测阶段时需要知道当前节点是用哪个特征来进行划分的;第7行label用来保存当前节点对应的类别标签(叶子节点才有),不过这个可要可不要,因为通过前面的values就已经能够得到当前叶子节点的类别;第8行n_samples用来保存当前节点对应的样本数量,用于分析观察;第9行children用来保存当前节点对应的所有孩子节点,因为利用ID3和C4.5生成的决策树为n叉树,所以笔者这里定义了一个字典来进行存储,其中key为特征取值,value为对应的孩子节点,值得一提的是在sklearn框架中均采用的是二叉树来进行实现;第10行criterion_value则是用来保存当前节点对应的信息熵;第11行是记录以当前节点为根节点时其叶子节点的个数;第12行是记录以当前节点为根节点时其所有叶子节点的损失和,这两行主要用于后续剪枝部分的代码实现。

同时,为了在打印输出构建好的决策树的能够更加方便,我们需要定义如下方法,代码如下:

 1     def __str__(self):
 2         return f"<======================>\n" \
 3                f"当前节点所有样本的索引({self.sample_index})\n" \
 4                f"当前节点的样本数量({self.n_samples})\n" \
 5                f"当前节点每个类别的样本数({self.values})\n" \
 6                f"当前节点对应的信息增益(比)({round(self.criterion_value, 3)})\n" \
 7                f"当前节点状态时特征集中剩余特征({self.features})\n" \
 8                f"当前节点状态时划分特征ID({self.feature_id})\n" \
 9                f"当前节点对应的类别标签为({self.label})\n" \
10                f"当前节点为根节点对应孩子节点数为({self.n_leaf})\n" \
11                f"当前节点为根节点对应孩子节点损失为({self.leaf_costs})\n" \
12                f"当前节点对应的孩子为({self.children.keys()})"

这里值得一提的是__str__方法是Python中每个类对象都有的一个方法,只是默认情况下没有进行实现。__str__方法的作用是在通过print函数打印类的实例化对象时输出的便是上面我们定义的信息,而不是默认情况下的对象内存地址。

为你认可的知识付费,欢迎订阅本专栏阅读更多优质内容!

例如:

继续滑动看下一个

第8.5节 从零实现ID3与C4.5算法

空字符 月来客栈
向上滑动看下一个

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存