概述
调用sklearn.linear_model中的LogisticRegression库,尝试通过对率回归对离散数据进行划分,对每个属性进行预测,选取正确率最大的属性作为根节点,并对该节点的每个属性取值进行划分选择,依此类推,最终绘制一棵决策树。
程序功能
对于给定西瓜数据集3.0,将字符串类型的属性取值转换为数值类型以便模型进行训练,并将连续属性离散化以便选取划分点,通过正确率来选取根节点,最终得到决策树数组。通过dealanddraw(n0, pngname)函数将数组转化为字典类型,绘制决策树,将决策树以图片形式保存在程序的同一目录下。
程序数据及代码
- 主程序
# -*- coding: utf-8 -*-
"""
Created on Sun Apr 21 11:57:22 2019
@author: lazyn
"""
import os
from sklearn.linear_model import LogisticRegression
import numpy as np
import pandas as pd
import warnings
from createPlot import createPlot
import matplotlib.pyplot as plt  
warnings.filterwarnings("ignore")
#定义连续值处理函数
def con_deal(temp_df, a):
    for j in range(0, len(temp_df)):
        temp_df.iat[j] = 0 if(temp_df.iat[j] < a) else 1
    return temp_df
#定义计算连续值正确率的函数
def con_acc(data, Y):
    a = np.sort(np.array(data))
    a = (a[0: len(a) - 1] + a[1: len(a)])/2
    max_acc, ind = 0, 0
    for i in range(0, len(a)):
        temp_df = con_deal(data.copy(), a[i]) 
        X0 = np.array(temp_df).reshape(-1, 1)
        logreg = LogisticRegression()
        logreg.fit(X0, Y)
        acc = logreg.score(X0, Y)
        if max_acc < acc:
            max_acc = acc
            ind = i
            temp_df0 = X0
    print(round(max_acc, 3), end = ', 判断结果为:\n')
    print(logreg.predict(temp_df0))
    return [max_acc, a[ind]]
#获取根节点函数
def getroot(X1, Y1, m):
    max_acc = 0
    for i in m:
        if i != '密度' and i != '含糖率':
            print(i + '节点, 正确率为', end = ':')
            X0 = np.array(X1[i]).reshape(-1, 1) 
            logreg = LogisticRegression()
            logreg.fit(X0, Y1)
            acc = logreg.score(X0, Y1)
            print(round(acc, 3), end = ', 判断结果为:\n')
            print(logreg.predict(X0))
            if max_acc < acc:
                max_acc = acc
                root = i
        else:
            print(i + '节点, 正确率为', end = ':')
            acc = con_acc(X1[i], Y1)[0]
            if max_acc < acc:
                max_acc = acc
                root = i
    return root
#获取决策树数组函数
def gettree(X, Xo, Y, m):
    n1, n2 = [], []  
    root = getroot(X, Y['好瓜'], m)
    print('故选择' + root + '为根节点')
    n1.append(root)
    m.remove(root) 
    if root == '密度' or root == '含糖率':
        div = con_acc(X[root], Y['好瓜'])[1]
        X[root], Xo[root], Y[root] = con_deal(X[root], div), con_deal(Xo[root], div), con_deal(X[root], div)
#    print(X, Xo)
    Attr, Attro = X[root].unique(), Xo[root].unique()
#    print(Attr, Attro)
    for j, jo in zip(Attr, Attro):
        n3 = []
        if root == '密度' or root == '含糖率':
            if j >= div:
                key = '≥' + str(div)
            else:
                key = '<' + str(div)
        else:
            key = jo
        print(root + '为' + key + '时:')  
        n3.append(key)
        X1 = X[X[root] == j]
        Xo1 = Xo[Xo[root] == jo]
        Y0 = Y[Y[root] == j]
        Y1 = Y0['好瓜']
        if Y1.unique().size > 1:
            Xn, Xon, Yn = X1, Xo1, Y0
            n3.append(gettree(Xn, Xon, Yn, m))   
        else:
            flag = '好瓜' if Y1.unique() == '是' else '坏瓜'
            print(flag)
            n3.append(flag)
        n2.append(n3)
    n1 += n2
    return n1        
 
#数组处理及绘制函数
def dealanddraw(n0, pngname):
    alstr = str(n0)
    alstr = alstr.replace(',', ':'); alstr = alstr.replace(']: [', ',')
    alstr = alstr.replace(']:', '],')
    alstr = alstr.replace('[', '{'); alstr = alstr.replace(']', '}')
    inTree = eval(alstr)
#    print(inTree)
    plt.figure(figsize = (10, 7))
    createPlot(inTree)
#    dpi, 控制每英寸长度上的分辨率;bbox_inches, 能删除figure周围的空白部分
    plt.savefig(pngname, dpi = 400, bbox_inches = 'tight')
        
f = open('watermelon3.txt')
watermelon3_df = pd.read_table(f)
Xo = watermelon3_df[['色泽',	'根蒂', '敲声', '纹理', '脐部', '触感', '密度', '含糖率']]
m = list(watermelon3_df.columns)
h = 0.001
for i in m:
    if i != '密度' and i != '含糖率' and i != '好瓜':
        size_mapping = {}
        m0 = watermelon3_df[i].unique()
        j = 1
        for i0 in m0:
            size_mapping[i0] = j
            j += 1
#        print(size_mapping)
        watermelon3_df[i] = watermelon3_df[i].map(size_mapping)
        
X = watermelon3_df[['色泽',	'根蒂', '敲声', '纹理', '脐部', '触感', '密度', '含糖率']]
Y = watermelon3_df
m = list(X.columns)
n0 = gettree(X, Xo, Y, m)
pngname = os.path.basename(os.path.realpath(__file__)).replace('py', 'png')
dealanddraw(n0, pngname)- 决策树绘制程序
参考自决策树的绘制_TaoTaoFu的博客-CSDN博客_绘制决策树,并做部分修改后代码如下:
import matplotlib.pyplot as plt
#用来正常显示中文
plt.rcParams['font.sans-serif'] = ['SimHei']
#用来正常显示负号
plt.rcParams['axes.unicode_minus'] = False
#设置画节点用的盒子的样式
decisionNode = dict(boxstyle = "sawtooth", color = '#3366FF')
leafNode = dict(boxstyle = "round4", color = '#FF6633')
#设置画箭头的样式
arrow_args = dict(arrowstyle="<-", color='g')
def getNumLeafs(myTree):
    #初始化树的叶子节点个数
    numLeafs = 0
    #myTree.keys()获取树的非叶子节点'no surfacing'和'flippers'
    #list(myTree.keys())[0]获取第一个键名'no surfacing'
    firstStr = list(myTree.keys())[0]
    #通过键名获取与之对应的值,即{0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}
    secondDict = myTree[firstStr]
    #遍历树,secondDict.keys()获取所有的键
    for key in secondDict.keys():
        #判断键是否为字典,键名1和其值就组成了一个字典,如果是字典则通过递归继续遍历,寻找叶子节点
        if type(secondDict[key]).__name__=='dict':
            numLeafs += getNumLeafs(secondDict[key])
        #如果不是字典,则叶子结点的数目就加1
        else:
            numLeafs += 1
    #返回叶子节点的数目
    return numLeafs
def getTreeDepth(myTree):
    #初始化树的深度
    maxDepth = 0
    #获取树的第一个键名
    firstStr = list(myTree.keys())[0]
    #获取键名所对应的值
    secondDict = myTree[firstStr]
    #遍历树
    for key in secondDict.keys():
        #如果获取的键是字典,树的深度加1
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        #去深度的最大值
        if thisDepth > maxDepth : maxDepth = thisDepth
    #返回树的深度
    return maxDepth
#绘图相关参数的设置
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    '''
    annotate函数是为绘制图上指定的数据点xy添加一个nodeTxt注释
    nodeTxt是给数据点xy添加一个注释,xy为数据点的开始绘制的坐标,位于节点的中间位置
    xycoords设置指定点xy的坐标类型,xytext为注释的中间点坐标,textcoords设置注释点坐标样式
    bbox设置装注释盒子的样式,arrowprops设置箭头的样式
    '''
    '''
    figure points:表示坐标原点在图的左下角的数据点
    figure pixels:表示坐标原点在图的左下角的像素点
    figure fraction:此时取值是小数,范围是([0,1],[0,1]),在图的左下角时xy是(0,0),最右上角是(1,1)
    其他位置是按相对图的宽高的比例取最小值
    axes points : 表示坐标原点在图中坐标的左下角的数据点
    axes pixels : 表示坐标原点在图中坐标的左下角的像素点
    axes fraction : 与figure fraction类似,只不过相对于图的位置改成是相对于坐标轴的位置
    '''
    createPlot.ax1.annotate(nodeTxt, xy = parentPt, xycoords = 'axes fraction', 
                            xytext = centerPt, textcoords = 'axes fraction',
                            va = "center", ha = "center", bbox = nodeType, arrowprops = arrow_args)
#绘制线中间的文字(0和1)的绘制
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]   #计算文字的x坐标
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]   #计算文字的y坐标
    createPlot.ax1.text(xMid, yMid, txtString, va = "center", ha = "center", rotation = 20)
    
#绘制树
def plotTree(myTree, parentPt, nodeTxt):
    #获取树的叶子节点
    numLeafs = getNumLeafs(myTree)
    #获取树的深度
    depth = getTreeDepth(myTree)
    #firstStr = myTree.keys()[0]
    #获取第一个键名
    firstStr = list(myTree.keys())[0]
    #计算子节点的坐标
    cntrPt = (plotTree.xoff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yoff)
    #绘制线上的文字
    plotMidText(cntrPt, parentPt, nodeTxt)
    #绘制节点
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    #获取第一个键值
    secondDict = myTree[firstStr]
    #计算节点y方向上的偏移量,根据树的深度
    plotTree.yoff = plotTree.yoff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            #递归绘制树
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            #更新x的偏移量,每个叶子结点x轴方向上的距离为 1/plotTree.totalW
            plotTree.xoff = plotTree.xoff + 1.0 / plotTree.totalW
            #绘制非叶子节点
            plotNode(secondDict[key], (plotTree.xoff, plotTree.yoff), cntrPt, leafNode)
            #绘制箭头上的标志
            plotMidText((plotTree.xoff, plotTree.yoff), cntrPt, str(key))
    plotTree.yoff = plotTree.yoff + 1.0 / plotTree.totalD
#绘制决策树
def createPlot(inTree):
    #清除figure
    plt.clf()
    axprops = dict(xticks = [], yticks = [])
    #创建一个1行1列1个figure,并把网格里面的第一个figure的Axes实例返回给ax1作为函数createPlot()
    #的属性,这个属性ax1相当于一个全局变量,可以给plotNode函数使用
    createPlot.ax1 = plt.subplot(frameon = False, **axprops)
    #获取树的叶子节点
    plotTree.totalW = float(getNumLeafs(inTree))
    #获取树的深度
    plotTree.totalD = float(getTreeDepth(inTree))
    #节点的x轴的偏移量为-1/plotTree.totlaW/2,1为x轴的长度,除以2保证每一个节点的x轴之间的距离为1/plotTree.totlaW*2
    plotTree.xoff = -0.5/plotTree.totalW
    plotTree.yoff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
#inTree = {'色泽': {'青绿': {'敲声': {'浊响': '好瓜', '清脆': '坏瓜', '沉闷': '坏瓜'}},\
#  '乌黑': {'根蒂': {'蜷缩': '好瓜', '稍蜷': {'纹理': {'稍糊': '好瓜', '清晰': '坏瓜'}}}},\
#  '浅白': '坏瓜'}}
#createPlot(inTree)- 西瓜数据集3.0
色泽	根蒂	敲声	纹理	脐部	触感	密度	含糖率	好瓜
青绿	蜷缩	浊响	清晰	凹陷	硬滑	0.697	0.46	是
乌黑	蜷缩	沉闷	清晰	凹陷	硬滑	0.774	0.376	是
乌黑	蜷缩	浊响	清晰	凹陷	硬滑	0.634	0.264	是
青绿	蜷缩	沉闷	清晰	凹陷	硬滑	0.608	0.318	是
浅白	蜷缩	浊响	清晰	凹陷	硬滑	0.556	0.215	是
青绿	稍蜷	浊响	清晰	稍凹	软粘	0.403	0.237	是
乌黑	稍蜷	浊响	稍糊	稍凹	软粘	0.481	0.149	是
乌黑	稍蜷	浊响	清晰	稍凹	硬滑	0.437	0.211	是
乌黑	稍蜷	沉闷	稍糊	稍凹	硬滑	0.666	0.091	否
青绿	硬挺	清脆	清晰	平坦	软粘	0.243	0.267	否
浅白	硬挺	清脆	模糊	平坦	硬滑	0.245	0.057	否
浅白	蜷缩	浊响	模糊	平坦	软粘	0.343	0.099	否
青绿	稍蜷	浊响	稍糊	凹陷	硬滑	0.639	0.161	否
浅白	稍蜷	沉闷	稍糊	凹陷	硬滑	0.657	0.198	否
乌黑	稍蜷	浊响	清晰	稍凹	软粘	0.36	0.37	否
浅白	蜷缩	浊响	模糊	平坦	硬滑	0.593	0.042	否
青绿	蜷缩	沉闷	稍糊	稍凹	硬滑	0.719	0.103	否结果
 

   对率回归决策树 
 
不足
对于正确率相同的节点,选取优先遍历的属性作为根节点,与基于信息增益进行划分选择的方法相比,通过下图,可知两种方法绘制的决策树正确率均为100%,但对率回归方法容易忽略在同一正确率下划分较佳的节点,从而使决策树层数增多,变得更加复杂。
 

   信息增益决策树 
 
数据集来源
《机器学习》周志华










