博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
sklearn实现决策树算法
阅读量:4551 次
发布时间:2019-06-08

本文共 3032 字,大约阅读时间需要 10 分钟。

1、决策树算法是一种非参数的决策算法,它根据数据的不同特征进行多层次的分类和判断,最终决策出所需要预测的结果。它既可以解决分类算法,也可以解决回归问题,具有很好的解释能力。另外,对于决策树的构建方法具有多种出发点,它具有多种构建方式,如何构建决策树的出发点主要在于决策树每一个决策点上需要在哪些维度上进行划分以及在这些维度的哪些阈值节点做划分等细节问题。

具体在sklearn中调用决策树算法解决分类问题和回归问题的程序代码如下所示:

#1-1导入基础训练数据集 import numpy as np import matplotlib.pyplot as plt from sklearn import datasets d=datasets.load_iris() x=d.data[:,2:] y=d.target plt.figure() plt.scatter(x[y==0,0],x[y==0,1],color="r") plt.scatter(x[y==1,0],x[y==1,1],color="g") plt.scatter(x[y==2,0],x[y==2,1],color="b") plt.show() #1-2导入sklearn中的决策树算法进行数据的分类问题实现训练预测 from sklearn.tree import DecisionTreeClassifier dt1=DecisionTreeClassifier(max_depth=2,criterion="entropy")  #定义决策树的分类器相关决策超参数 dt1.fit(x,y) def plot_decision_boundary(model,axis):   #决策边界输出函数(二维数据点)     x0,x1=np.meshgrid(         np.linspace(axis[0],axis[1],int((axis[1]-axis[0])*100)).reshape(-1,1),         np.linspace(axis[2],axis[3], int((axis[3] - axis[2]) * 100)).reshape(-1,1)     )     x_new=np.c_[x0.ravel(),x1.ravel()]     y_pre=model.predict(x_new)     zz=y_pre.reshape(x0.shape)     from matplotlib.colors import ListedColormap     cus=ListedColormap(["#EF9A9A","#FFF59D","#90CAF9"])     plt.contourf(x0,x1,zz,cmap=cus) plot_decision_boundary(dt1,axis=[0.5,8,0,3]) plt.scatter(x[y==0,0],x[y==0,1],color="r") plt.scatter(x[y==1,0],x[y==1,1],color="g") plt.scatter(x[y==2,0],x[y==2,1],color="b") plt.show() #定义二分类问题的信息熵计算函数np.sum(-p*np.log(p)) def entropy(p):     return -p*np.log(p)-(1-p)*np.log(1-p) x1=np.linspace(0.01,0.99,100) y1=entropy(x1) plt.plot(x1,y1,"r") plt.show() #利用信息熵的原理对数据进行实现划分,决策树信息熵构建方式的原理实现代码 def split(x,y,d,value):     index_a=(x[:,d]<=value)     index_b=(x[:,d]>value)     return x[index_a],x[index_b],y[index_a],y[index_b] from collections import Counter def entropy(y):     Counter1=Counter(y)     res=0.0     for num in Counter1.values():         p=num/len(y)         res+=-p*np.log(p)     return res def try_spit(x,y):     best_entropy=float("inf")     best_d,best_v=-1,-1     for d in range(x.shape[1]):         sorted_index=np.argsort(x[:,d])         for i in range(1,len(x)):             if x[sorted_index[i-1],d] != x[sorted_index[i],d]:                 v=(x[sorted_index[i-1],d]+x[sorted_index[i],d])/2                 x_l,x_r,y_l,y_r=split(x,y,d,v)                 e=entropy(y_l)+entropy(y_r)                 if e
value) return x[index_a],x[index_b],y[index_a],y[index_b] from collections import Counter def gini(y): Counter1 = Counter(y) res = 1.0 for num in Counter1.values(): p = num / len(y) res -= p**2 return res def try_spit1(x,y): best_gini=float("inf") best_d,best_v=-1,-1 for d in range(x.shape[1]): sorted_index=np.argsort(x[:,d]) for i in range(1,len(x)): if x[sorted_index[i-1],d] != x[sorted_index[i],d]: v=(x[sorted_index[i-1],d]+x[sorted_index[i],d])/2 x_l,x_r,y_l,y_r=split(x,y,d,v) g=gini(y_l)+gini(y_r) if g

 

转载于:https://www.cnblogs.com/Yanjy-OnlyOne/p/11372286.html

你可能感兴趣的文章
jQuery Ajax 回调函数中调用$(this)的问题 [ 转 ]
查看>>
thymeleaf:字符串拼接+输出单引号
查看>>
springboot:集成fastjson(教训)
查看>>
网络流 Edmons-Karp 算法讲解
查看>>
「NOIP2018模拟9.10」公约数 - 找规律 - gcd
查看>>
使用java理解程序逻辑(15)
查看>>
bzoj 1879 状压dp
查看>>
python 一些特殊用法和坑
查看>>
WIFI密码破解全攻略
查看>>
c++string各种函数
查看>>
errno.h含义
查看>>
字典树(模型体)
查看>>
盒模型详解
查看>>
bzoj2157 旅游
查看>>
bzoj5016 [Snoi2017]一个简单的询问
查看>>
poj2417 bzoj3239 Discrete Logging(bsgs)
查看>>
UVa10054 - The Necklace(欧拉回路【输出带来的麻烦)
查看>>
string和stringbuffer的区别 集合的作用 ArrayList vector linklist hashmap hashtable collection和collections...
查看>>
6月27日 ajax
查看>>
iOS开发之画图板(贝塞尔曲线)
查看>>