河南南阳收割机被堵事件:官员缺德,祸患无穷

极目新闻领导公开“记者毕节采访被打”细节:他们打人后擦去指纹

突发!员工跳楼!只拿低保工资!央企设计院集体罢工!

退休后的温家宝

突发!北京某院集体罢工!

生成图片,分享到微信朋友圈

自由微信安卓APP发布,立即下载! | 提交文章网址
查看原文

Python数据分析之决策树(进阶篇)

胡萝卜酱 DataGo数据狗 2022-07-01

Python数据分析之基础篇中,我们介绍了通过信息增益划分分类,接下来我们将介绍如何构建,可视化决策树。本文的代码是基于基础篇的代码扩展的。

 递归构建决策树 


构建决策树的方法有很多,本文将采用ID3算法。ID3算法的核心是在决策树各个结点上对应信息增益准则选择特征,递归地构建决策树。具体方法是:从根结点(root node)开始,对结点计算所有可能的特征的信息增益,选择信息增益最大的特征作为结点的特征,由该特征的不同取值建立子节点;再对子结点递归地调用以上方法,构建决策树;直到所有特征的信息增益均很小或没有特征可以选择为止,最后得到一个决策树。ID3相当于用极大似然法进行概率模型的选择。

首先我们创建一个数据集:

代码如下:

def createDataSet():
    dataSet = [[0000'no'],                        #数据集
            [0001'no'],
            [0101'yes'],
            [0110'yes'],
            [0000'no'],
            [1000'no'],
            [1001'no'],
            [1111'yes'],
            [1012'yes'],
            [1012'yes'],
            [2012'yes'],
            [2011'yes'],
            [2101'yes'],
            [2102'yes'],
            [2000'no']]
    labels = ['年龄''有工作''有自己的房子''信贷情况']        #特征标签
    return dataSet, labels          

然后我们将基于基础篇已经创建好的函数,来创建决策树。代码如下:

def majorityCnt(classList):
    classCount={}
    for vote in classList:    #统计classList中每个元素出现的次数
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0#返回classList中出现次数最多的元素

def createTree(dataSet, labels, featLabels):
    classList = [example[-1for example in dataSet]   #取分类标签(是否放贷:yes or no)
    if classList.count(classList[0]) == len(classList):   #如果类别完全相同则停止继续划分
        return classList[0]
    if len(dataSet[0]) == 1:   #遍历完所有特征时返回出现次数最多的类标签
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)   #选择最优特征
    bestFeatLabel = labels[bestFeat]    #最优特征的标签
    featLabels.append(bestFeatLabel)
    myTree = {bestFeatLabel:{}}      #根据最优特征的标签生成树
    del(labels[bestFeat])     #删除已经使用特征标签
    featValues = [example[bestFeat] for example in dataSet]   #得到训练集中所有最优特征的属性值
    uniqueVals = set(featValues)  #去掉重复的属性值
    for value in uniqueVals:       #遍历特征,创建决策树。                       
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), labels, featLabels)
    return myTree                          

if __name__ == '__main__':
    dataSet, labels = createDataSet()
    featLabels = []
    myTree = createTree(dataSet, labels, featLabels)
    print(myTree)         

结果输出:

 可视化决策树 


我们将使用Matplotlib库创建树形图。代码如下:

from matplotlib.font_manager import FontProperties
import matplotlib.pyplot as plt

def getNumLeafs(myTree):
    numLeafs = 0     #初始化叶子
    firstStr = next(iter(myTree)) 
    secondDict = myTree[firstStr]       #获取下一组字典
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':     #测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
            numLeafs += getNumLeafs(secondDict[key])
        else:   numLeafs +=1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0      #初始化决策树深度
    firstStr = next(iter(myTree)) 
    secondDict = myTree[firstStr]    #获取下一个字典
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':    #测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:   thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth   #更新层数
    return maxDepth

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    arrow_args = dict(arrowstyle="<-")                                            #定义箭头格式
    font = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=14)        #设置中文字体
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',    #绘制结点
        xytext=centerPt, textcoords='axes fraction',
        va="center", ha="center", bbox=nodeType, arrowprops=arrow_args, FontProperties=font)

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]  #计算标注位置                   
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):
    decisionNode = dict(boxstyle="sawtooth", fc="0.8")       #设置结点格式
    leafNode = dict(boxstyle="round4", fc="0.8")    #设置叶结点格式
    numLeafs = getNumLeafs(myTree)       #获取决策树叶结点数目,决定了树的宽度
    depth = getTreeDepth(myTree)       #获取决策树层数
    firstStr = next(iter(myTree))         #下个字典                                                 
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)    #中心位置
    plotMidText(cntrPt, parentPt, nodeTxt)    #标注有向边属性值
    plotNode(firstStr, cntrPt, parentPt, decisionNode)   #绘制结点
    secondDict = myTree[firstStr]     #下一个字典,也就是继续绘制子结点
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD        #y偏移
    for key in secondDict.keys():                               
        if type(secondDict[key]).__name__=='dict':      #测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
            plotTree(secondDict[key],cntrPt,str(key))      #不是叶结点,递归调用继续绘制
        else:        #如果是叶结点,绘制叶结点,并标注有向边属性值                                             
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')    #创建fig
    fig.clf()                  #清空fig
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #去掉x、y轴
    plotTree.totalW = float(getNumLeafs(inTree))  #获取决策树叶结点数目
    plotTree.totalD = float(getTreeDepth(inTree))   #获取决策树层数
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;     #x偏移
    plotTree(inTree, (0.5,1.0), '')      #绘制决策树
    plt.show()      #显示绘制结果     

if __name__ == '__main__':
    dataSet, labels = createDataSet()
    featLabels = []
    myTree = createTree(dataSet, labels, featLabels)
    print(myTree)  
    createPlot(myTree) 

结果输出:

 结语 


也许你能够快速的把上述代码运用到自己想要的数据集上面,但其实并不理解这些函数的含义,那么也没有关系,你能通过from sklearn import tree来训练模型,然后用Graphviz来进行决策树可视化。

文章有问题?点此查看未经处理的缓存