First Look at Decision Trees

静守幸福 · 2022-01-27

Building a decision tree involves the following two steps, both of which are usually carried out by learning from samples whose class labels are already known.

1. Node splitting: in general, when the attribute at a node cannot decide the classification by itself, the node is split into 2 child nodes (or into n child nodes when the tree is not binary).

2. Threshold selection: choose a threshold that minimizes the classification error rate on the training set (the training error); a minimal sketch of this step follows the list.
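
To make step 2 concrete, here is a sketch of picking a threshold for a single numeric feature by minimizing the training error. It assumes a fixed decision rule ("x <= t predicts class 0"); the name find_threshold and the toy data are my own illustration, not part of the original description.

def find_threshold(xs, ys):
    # Try a cut between each pair of adjacent sorted values and keep the
    # one with the fewest training errors under the rule "x <= t -> class 0".
    pairs = sorted(zip(xs, ys))
    best_t, best_err = None, len(ys) + 1
    for i in range(len(pairs) - 1):
        t = (pairs[i][0] + pairs[i + 1][0]) / 2.0   # candidate threshold
        err = sum((x <= t) != (y == 0) for x, y in pairs)
        if err < best_err:
            best_t, best_err = t, err
    return best_t, best_err

print(find_threshold([1.0, 2.0, 3.0, 4.0], [0, 0, 1, 1]))   # (2.5, 0)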

The most commonly used decision tree algorithms are ID3, C4.5, and CART (Classification And Regression Tree); CART generally classifies better than the other two.
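
As a point of comparison, scikit-learn's DecisionTreeClassifier implements an optimized CART-style tree. A minimal sketch, assuming scikit-learn is installed, run on the fish dataset used later in this post:

from sklearn.tree import DecisionTreeClassifier

X = [[1, 1], [1, 1], [1, 0], [0, 1], [0, 1]]   # the two binary features
y = ['yes', 'yes', 'no', 'no', 'no']           # the class labels
# criterion='entropy' makes the split criterion match ID3's information gain
clf = DecisionTreeClassifier(criterion='entropy').fit(X, y)
print(clf.predict([[1, 1]]))   # ['yes']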

ID3 uses the entropy principle to decide which attribute becomes the parent node, i.e., which node should be split. For a set of data, the smaller the entropy, the better the classification. Entropy is defined as:

Entropy = -∑ p(x_i) * log2(p(x_i))

where p(x_i) is the probability that x_i occurs. For a two-class problem, when class A and class B each account for 50% of the samples,

Entropy = -(0.5*log2(0.5) + 0.5*log2(0.5)) = 1

When only class A (or only class B) is present,

Entropy = -(1*log2(1) + 0) = 0
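
Both cases can be checked with a few lines of Python (the helper name entropy is mine, not from the post):

from math import log

def entropy(probs):
    # Shannon entropy in bits; terms with p == 0 are skipped, since
    # p*log2(p) -> 0 as p -> 0
    return sum(-p * log(p, 2) for p in probs if p > 0)

print(entropy([0.5, 0.5]))   # 1.0: the 50/50 case
print(entropy([1.0]))        # 0.0: only one class present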

ID3 enumerates every attribute and computes the entropy reduction (the information gain) of splitting on it.

The entropy reduction is the parent node's entropy minus the weighted sum of the child nodes' entropies, each child weighted by its share of the samples; the sketch below illustrates this.
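
A minimal sketch of this computation on the fish dataset introduced below (the names entropy_of and info_gain are mine). Splitting on the first feature sends the two 'yes' samples plus one 'no' to one child and the remaining two 'no' samples to the other:

from math import log

def entropy_of(labels):
    # Shannon entropy of a list of class labels
    n = len(labels)
    counts = {}
    for c in labels:
        counts[c] = counts.get(c, 0) + 1
    return sum(-(k / n) * log(k / n, 2) for k in counts.values())

def info_gain(parent, children):
    # parent entropy minus the sample-weighted sum of child entropies
    n = len(parent)
    weighted = sum(len(c) / n * entropy_of(c) for c in children)
    return entropy_of(parent) - weighted

# split on the first feature ('no surfacing') of the fish dataset below:
print(info_gain(['yes', 'yes', 'no', 'no', 'no'],
                [['yes', 'yes', 'no'], ['no', 'no']]))   # ~0.42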

import operator
from math import log

def createDataSet():
    # toy "fish" dataset: two binary features and a class label
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

def chooseMaxNum(classList):
    # majority vote: return the most frequent class label in classList
    cnt = {}
    for label in classList:
        if label not in cnt:
            cnt[label] = 0
        cnt[label] += 1
    tmp = sorted(cnt.items(), key=operator.itemgetter(1), reverse=True)
    return tmp[0][0]

def splitdata(dataset, bestFeat, value):
    # rows whose column bestFeat equals value, with that column removed
    tmp = []
    for row in dataset:
        if row[bestFeat] == value:
            atmp = row[:bestFeat]
            atmp.extend(row[bestFeat + 1:])
            tmp.append(atmp)
    return tmp

def cal(dataset):
    # Shannon entropy of the class-label column (the last column) of dataset
    n = len(dataset)
    cnt = {}
    for row in dataset:
        if row[-1] not in cnt:
            cnt[row[-1]] = 0
        cnt[row[-1]] += 1
    entropy = 0.0
    for label in cnt:
        p = float(cnt[label]) / n
        entropy -= log(p, 2) * p
    return entropy


def chooseBestFeat(dataset):
    # pick the feature whose split yields the largest information gain
    m = len(dataset[0]) - 1        # number of features (last column is the label)
    n = float(len(dataset))
    baseEntropy = cal(dataset)     # entropy of the parent node
    bestInfoGain, bestFeat = 0.0, -1
    for i in range(m):
        featset = set(row[i] for row in dataset)
        ans = 0.0
        for j in featset:
            splitset = splitdata(dataset, i, j)
            ans += cal(splitset) * len(splitset) / n   # weighted child entropy
        ans = baseEntropy - ans    # information gain of splitting on feature i
        if ans > bestInfoGain:
            bestInfoGain = ans
            bestFeat = i
    return bestFeat


def createTree(dataset, labels):
    result = [row[-1] for row in dataset]
    if result.count(result[0]) == len(result):   # all samples share one class
        return result[0]
    if len(dataset[0]) == 1:                     # no features left: majority vote
        return chooseMaxNum(result)
    bestFeat = chooseBestFeat(dataset)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]                         # this feature is now consumed
    featSet = set(x[bestFeat] for x in dataset)
    for i in featSet:
        sublabel = labels[:]                     # copy so sibling branches are unaffected
        myTree[bestFeatLabel][i] = createTree(splitdata(dataset, bestFeat, i), sublabel)
    return myTree


def classify(myTree, labels, testdata):
    nowfeat = list(myTree.keys())[0]   # feature tested at this node
    secondTree = myTree[nowfeat]
    featIndex = labels.index(nowfeat)
    featvalue = testdata[featIndex]
    ans = secondTree[featvalue]
    if isinstance(ans, dict):          # internal node: keep descending
        classLabel = classify(ans, labels, testdata)
    else:                              # leaf: the class label itself
        classLabel = ans
    return classLabel


def fishTest():
    dataset, labels = createDataSet()
    # createTree consumes its labels list, so pass a copy and keep the
    # original around for classify
    myTree = createTree(dataset, labels[:])
    print(myTree)
    print(classify(myTree, labels, [1, 1]))


if __name__ == '__main__':
    fishTest()

Output:

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
yes
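
The other feature combinations follow the same path: for example, classify(myTree, labels, [1, 0]) descends through 'no surfacing' = 1 to the 'flippers' node and returns 'no'.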
