[Python] Decision Tree 01: Building a Decision Tree with CART

Hi everyone, I'm Eric. This time I'll build a decision tree in Python.



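  • Motivation: decision trees are easier to interpret than most other machine learning models, which makes them a fairly common classification method.
  • Method: the tree.DecisionTreeClassifier class from Python's scikit-learn (sklearn) package.
  • Data: the California traffic collision dataset (collisions) on Kaggle, drawn from SWITRS, the Statewide Integrated Traffic Records System maintained by the California Highway Patrol.
  • Reference:
  • https://www.kaggle.com/alexgude/california-traffic-collision-data-from-switrs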
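scikit-learn's trees use an optimized version of CART, which grows the tree by repeatedly choosing the binary split that most lowers the weighted Gini impurity of the two child nodes (this is what criterion="gini" refers to later on). The minimal sketch below, in plain Python on a hypothetical toy sample rather than the SWITRS data, shows that calculation for one categorical feature using "equals this value vs. everything else" splits, which is effectively what a tree does when it splits on a single one-hot dummy column. The helper names gini, weighted_gini, and best_binary_split are made up for illustration.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a node: 1 minus the sum of squared class proportions."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def weighted_gini(left, right):
    """Impurity of a binary split: each child's Gini weighted by its share of rows."""
    n = len(left) + len(right)
    return len(left) / n * gini(left) + len(right) / n * gini(right)


def best_binary_split(feature, target):
    """Try every 'feature == value' vs 'feature != value' split and return the
    (value, weighted Gini) pair with the lowest impurity."""
    best_value, best_score = None, float("inf")
    for value in sorted(set(feature)):
        left = [t for f, t in zip(feature, target) if f == value]
        right = [t for f, t in zip(feature, target) if f != value]
        if not left or not right:
            continue  # a split must send rows to both children
        score = weighted_gini(left, right)
        if score < best_score:
            best_value, best_score = value, score
    return best_value, best_score


# Hypothetical toy sample, not taken from the SWITRS data
lighting = ["daylight", "daylight", "dark", "dark", "dusk"]
severity = ["not severe", "not severe", "severe", "severe", "not severe"]

print("root Gini:", gini(severity))
print("best split:", best_binary_split(lighting, severity))
```

Here splitting on "dark" separates the two classes perfectly, so its weighted impurity is 0. Full CART can also group several categories on one side of a split; with dummy variables, each split tests one category at a time.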

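0. Because of my computer's hardware limits, I first used EmEditor to do a rough pre-filter that cut the data down substantially, and only then loaded it into Python.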
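If an external editor is not at hand, one possible alternative inside pandas is to read the file in chunks and keep only the matching rows, so the whole file never sits in memory at once. This is a sketch under assumptions: the helper read_filtered is hypothetical, the in-memory CSV below stands in for the real export, and only the column name motorcycle_collision comes from this article. Passing usecols to pd.read_csv can shrink memory use further.

```python
import io

import pandas as pd


def read_filtered(source, column, value, chunksize=100_000, usecols=None):
    """Read a CSV piece by piece and keep only the rows where column == value,
    so the full file never has to sit in memory at once."""
    kept = []
    for chunk in pd.read_csv(source, chunksize=chunksize, usecols=usecols):
        kept.append(chunk[chunk[column] == value])
    return pd.concat(kept, ignore_index=True)


# Hypothetical in-memory CSV standing in for the real collisions export
csv_text = """case_id,motorcycle_collision,weather_1
1,1,clear
2,0,cloudy
3,1,raining
4,0,clear
5,1,fog
"""

motor = read_filtered(io.StringIO(csv_text), "motorcycle_collision", 1, chunksize=2)
print(motor)
```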

1. Load the packages and filter out the data to be used.

# Load the data
import pandas as pd
cm = pd.read_csv('colli_motor.txt', sep=",")

# Keep only the collisions that involved a motorcycle
cm_mot = cm[cm["motorcycle_collision"] == 1]
cm_mot_d = cm_mot.drop("motorcycle_collision", axis=1)  # the flag is now constant, so drop it

https://ithelp.ithome.com.tw/upload/images/20220129/20115774V5Ceozj256.png

2. Descriptive analysis: first look at the share of each value in every variable.

# Pie chart for each variable, using weather as the example; the other variables are handled the same way

import matplotlib.pyplot as plt

weather = cm_mot_d["weather_1"].value_counts()   # count per category, largest first
weather2 = weather.to_frame(name="count")
# Use the value_counts index as the labels so they always line up with the counts
weather2.insert(0, column="weather", value=weather2.index)
print(weather2)

plt.figure(figsize=(7, 10))    # figure size in inches

labels = weather2["weather"]                   # category labels for the slices
separated = (0, 0, 0.3, 0.7, 1.1, 1.5, 1.9)    # how far each slice is pulled out; one entry per category (7 here)
size = weather2["count"]                       # slice sizes

plt.pie(size,                           # values
        labels=labels,                  # labels
        autopct="%1.1f%%",              # show percentages with one decimal place
        explode=separated,              # offset of each slice from the center
        pctdistance=0.6,                # distance of the percentage text from the center
        textprops={"fontsize": 12},     # font size
        shadow=True)                    # draw a shadow

plt.axis("equal")                                              # keep the pie circular
plt.title("Pie chart of collision weather", {"fontsize": 30})  # title and its font size
plt.legend(loc="best")                                         # place the legend automatically

plt.savefig("Pie chart of collision weather.jpg",   # save the figure
            bbox_inches='tight',                    # trim the extra whitespace around the figure
            pad_inches=0.0)                         # no padding around the trimmed figure
plt.close()      # close the figure

https://ithelp.ithome.com.tw/upload/images/20220129/20115774np17CSjK01.jpg
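The percentages on the pie chart can also be read off as a table, which is easier to compare across variables. A small sketch, assuming a hypothetical Series in place of cm_mot_d["weather_1"]: value_counts(normalize=True) returns proportions directly, and its index already holds the category labels.

```python
import pandas as pd

# Hypothetical weather values standing in for cm_mot_d["weather_1"]
weather = pd.Series(["clear", "clear", "cloudy", "clear",
                     "raining", "cloudy", "fog", "clear"])

counts = weather.value_counts()                 # absolute counts, largest first
shares = weather.value_counts(normalize=True)   # proportions that sum to 1

# Percentages rounded to one decimal, matching autopct="%1.1f%%"
summary = pd.DataFrame({"count": counts, "percent": (shares * 100).round(1)})
print(summary)
```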

3.1 Data preprocessing: check for Null/NaN values.

# Load the preprocessing and decision-tree packages
from sklearn import tree
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn import metrics

# Check for Null and NaN values
import numpy as np

print(cm_mot_d.isna().any())     # True for every column that contains NaN/None
print()

https://ithelp.ithome.com.tw/upload/images/20220129/20115774tAPIIzV7fP.png

print(np.isfinite(cm_mot_d.select_dtypes(include="number")).all())  # True for every numeric column whose values are all finite (no inf or NaN)

https://ithelp.ithome.com.tw/upload/images/20220129/20115774y86hMHEQKn.png

print(cm_mot_d.isnull().sum())  # number of nulls in each column

# Drop the rows that have a null in any of the columns below
null_cols = ["population", "intersection", "weather_1", "type_of_collision",
             "road_surface", "road_condition_1", "lighting", "control_device"]
cm_mot_d9 = cm_mot_d.dropna(subset=null_cols).copy()

print(cm_mot_d9.isnull().any().any())  # check again whether any null remains
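Filtering column by column with notnull() and a single dropna(subset=...) call keep exactly the same rows. The sketch below checks this on a hypothetical toy frame, and also shows that infinite values are not counted as missing by isna(), so they need their own check on the numeric columns.

```python
import numpy as np
import pandas as pd

# Hypothetical frame with scattered missing and infinite values
df = pd.DataFrame({
    "population": ["1", None, "3", "4"],
    "lighting": ["daylight", "dark", np.nan, "dusk"],
    "speed": [30.0, 45.0, 50.0, np.inf],
})
cols = ["population", "lighting"]

# Chained notnull() filters, one column at a time
chained = df[df["population"].notnull()]
chained = chained[chained["lighting"].notnull()]

# One call that drops rows with a missing value in any listed column
dropped = df.dropna(subset=cols)
print("same rows:", chained.equals(dropped))

# Missing values per column (isna treats None and NaN alike)
print(df.isna().sum())

# Infinite values are not "missing", so count them separately on numeric columns
print(np.isinf(df.select_dtypes(include="number")).sum())
```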

3.2 Data preprocessing: convert the target variable to binary.

# Target preprocessing: CART grows a binary tree (every split has two branches), and here the target y
# is also collapsed into two classes. "severe" covers fatal and severe injury; "not severe" covers
# property damage only, pain, and other injury.
# (scikit-learn's DecisionTreeClassifier also accepts multi-class targets, so binarizing y is a modeling choice.)

cm_mot_d9 = cm_mot_d9.copy()  # work on an independent copy to avoid chained-assignment problems
severe_levels = ["fatal", "severe injury"]
# Vectorized relabel: severe levels become "severe", everything else becomes "not severe"
cm_mot_d9["collision_severity"] = np.where(
    cm_mot_d9["collision_severity"].isin(severe_levels), "severe", "not severe")
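A variant worth considering is an explicit mapping dictionary: Series.map turns any label that is not in the dictionary into NaN, so an unexpected category shows up instead of silently falling into "not severe" the way a catch-all else branch would. A sketch on a hypothetical sample; the five level names come from the comment above, while "unknown" is invented to show the effect.

```python
import pandas as pd

# Explicit mapping of the five severity levels named in the text
SEVERITY_MAP = {
    "fatal": "severe",
    "severe injury": "severe",
    "property damage only": "not severe",
    "pain": "not severe",
    "other injury": "not severe",
}

# Hypothetical sample standing in for cm_mot_d9["collision_severity"]
severity = pd.Series(["fatal", "pain", "other injury", "severe injury",
                      "property damage only", "unknown"])

binary = severity.map(SEVERITY_MAP)
print(binary)
# Labels missing from the dict become NaN, so unexpected values are easy to spot
print("unmapped rows:", binary.isna().sum())
```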

3.3 Data preprocessing: create dummy variables.

# Exclude "other" so that the dummy columns built from different variables do not end up sharing the name "other"
cm_mot_d9_2 = cm_mot_d9[cm_mot_d9["type_of_collision"] != "other"]
cm_mot_d9_3 = cm_mot_d9_2[cm_mot_d9_2["weather_1"] != "other"]
cm_mot_d9_4 = cm_mot_d9_3[cm_mot_d9_3["road_condition_1"] != "other"]

# Separate X and y
cm_mot_d9_4_y = cm_mot_d9_4["collision_severity"]
cm_mot_d9_4_X = cm_mot_d9_4.drop("collision_severity", axis=1)

cm_mot_d9_4.describe(include = ["object"]) # descriptive statistics for the categorical columns

https://ithelp.ithome.com.tw/upload/images/20220129/2011577459F1eu4Xnh.png

# X preprocessing: convert the categorical predictors to dummy variables, because scikit-learn's trees need numeric input (https://towardsdatascience.com/the-dummys-guide-to-creating-dummy-variables-f21faddb1d40)

population2 = pd.get_dummies(cm_mot_d9_4_X["population"])
weather_1_2 = pd.get_dummies(cm_mot_d9_4_X["weather_1"])
type_of_collision2 = pd.get_dummies(cm_mot_d9_4_X["type_of_collision"])
road_surface2 = pd.get_dummies(cm_mot_d9_4_X["road_surface"])
road_condition_1_2 = pd.get_dummies(cm_mot_d9_4_X["road_condition_1"])
lighting2 = pd.get_dummies(cm_mot_d9_4_X["lighting"])
control_device2 = pd.get_dummies(cm_mot_d9_4_X["control_device"])

cm_mot_d9_4_X2 = pd.concat([cm_mot_d9_4_X, population2, weather_1_2, type_of_collision2, road_surface2, road_condition_1_2, lighting2, control_device2], axis=1)
cm_mot_d9_4_X3 = cm_mot_d9_4_X2.drop("population", axis=1)
cm_mot_d9_4_X4 = cm_mot_d9_4_X3.drop("weather_1", axis=1)
cm_mot_d9_4_X5 = cm_mot_d9_4_X4.drop("type_of_collision", axis=1)
cm_mot_d9_4_X6 = cm_mot_d9_4_X5.drop("road_surface", axis=1)
cm_mot_d9_4_X7 = cm_mot_d9_4_X6.drop("road_condition_1", axis=1)
cm_mot_d9_4_X8 = cm_mot_d9_4_X7.drop("lighting", axis=1)
cm_mot_d9_4_X9 = cm_mot_d9_4_X8.drop("control_device", axis=1)

https://ithelp.ithome.com.tw/upload/images/20220129/20115774FQGJgqlx9h.png
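The seven get_dummies/concat/drop pairs can be collapsed into one pd.get_dummies call with the columns argument. By default it prefixes each dummy with its source column (for example weather_1_other), so the name clash that motivated dropping the "other" rows would no longer occur and those rows could be kept. A sketch on a hypothetical frame; note the resulting column names differ from the unprefixed ones used above.

```python
import pandas as pd

# Hypothetical predictors; both categorical columns contain the value "other"
X = pd.DataFrame({
    "weather_1": ["clear", "other", "raining"],
    "road_condition_1": ["normal", "other", "holes"],
    "intersection": [1, 0, 1],
})

# Encoding each column separately without a prefix repeats the name "other"
naive = pd.concat([pd.get_dummies(X["weather_1"]),
                   pd.get_dummies(X["road_condition_1"])], axis=1)
print("duplicate names:", naive.columns[naive.columns.duplicated()].tolist())

# One call: encodes the listed columns, drops the originals, and prefixes each
# dummy with its source column name
encoded = pd.get_dummies(X, columns=["weather_1", "road_condition_1"], dtype=int)
print(encoded.columns.tolist())
```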

4. Build the decision tree model.

# Split into training and test sets
train_X, test_X, train_y, test_y = train_test_split(cm_mot_d9_4_X9, cm_mot_d9_4_y, test_size = 0.3)

# First compare accuracy across tree depths to decide on max_depth
# List of values to try for max_depth:
max_depth_range = list(range(3, 13))
# List to store the accuracy for each value of max_depth:
accuracy = []
for depth in max_depth_range:
    clf = tree.DecisionTreeClassifier(criterion="gini", max_depth = depth)
    clf.fit(train_X, train_y)
    test_y_predicted = clf.predict(test_X)
    score = metrics.accuracy_score(test_y, test_y_predicted)
    accuracy.append(score)
    
print(accuracy)

https://ithelp.ithome.com.tw/upload/images/20220129/20115774gfGMET6qtV.png
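Printing training accuracy next to test accuracy for each depth makes overfitting visible: deeper trees keep fitting the training data better while test accuracy stalls or drops. A sketch on synthetic data from make_classification, used here only as a stand-in for the dummy-encoded collision table; get_depth() reports how deep each fitted tree actually grew.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the dummy-encoded collision data
X, y = make_classification(n_samples=1000, n_features=20, n_informative=5,
                           random_state=0)
train_X, test_X, train_y, test_y = train_test_split(X, y, test_size=0.3,
                                                    random_state=0)

results = []
for depth in range(3, 13):
    clf = DecisionTreeClassifier(criterion="gini", max_depth=depth, random_state=0)
    clf.fit(train_X, train_y)
    results.append((depth, clf.get_depth(),
                    clf.score(train_X, train_y), clf.score(test_X, test_y)))

# A growing gap between training and test accuracy is a sign of overfitting
for depth, actual_depth, train_acc, test_acc in results:
    print(f"max_depth={depth:2d}  actual={actual_depth:2d}  "
          f"train={train_acc:.3f}  test={test_acc:.3f}")
```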

# Compare accuracy across minimum samples per leaf (max_depth fixed at 3)
# List of values to try for min_samples_leaf:
min_leaf_range = list(range(5, 15))
# List to store the accuracy for each value of min_samples_leaf:
accuracy = []
for leaf in min_leaf_range:
    clf2 = tree.DecisionTreeClassifier(criterion="gini", max_depth = 3, min_samples_leaf = leaf)
    clf2.fit(train_X, train_y) 
    test_y_predicted2 = clf2.predict(test_X)
    score = metrics.accuracy_score(test_y, test_y_predicted2)
    accuracy.append(score)
    
print(accuracy)

https://ithelp.ithome.com.tw/upload/images/20220129/20115774ojqVVRBCrq.png
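Both loops above pick hyperparameters by looking at the test set, which makes the final test accuracy somewhat optimistic. One common alternative is cross-validation on the training set only, searching max_depth and min_samples_leaf jointly with GridSearchCV and touching the test set once at the end. A sketch on synthetic data (an assumption standing in for the real features); the grid mirrors the ranges used above.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the dummy-encoded collision data
X, y = make_classification(n_samples=1000, n_features=20, n_informative=5,
                           random_state=0)
train_X, test_X, train_y, test_y = train_test_split(X, y, test_size=0.3,
                                                    random_state=0)

param_grid = {
    "max_depth": list(range(3, 13)),
    "min_samples_leaf": list(range(5, 15)),
}
search = GridSearchCV(
    DecisionTreeClassifier(criterion="gini", random_state=0),
    param_grid,
    cv=5,                 # 5-fold cross-validation on the training set only
    scoring="accuracy",
)
search.fit(train_X, train_y)

print(search.best_params_)
print(f"CV accuracy:   {search.best_score_:.3f}")
# The best model is refit on all training data; evaluate it once on the test set
print(f"Test accuracy: {search.score(test_X, test_y):.3f}")
```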

# Build the classifier (http://www.taroballz.com/2019/05/15/ML_decision_tree_detail/)
clf3 = tree.DecisionTreeClassifier(criterion="gini", max_depth = 3)
clf3.fit(train_X, train_y)

# Predict
test_y_predicted3 = clf3.predict(test_X)

# Performance
print(confusion_matrix(test_y, test_y_predicted3))
print() 
print(classification_report(test_y, test_y_predicted3))
accuracy = metrics.accuracy_score(test_y, test_y_predicted3)
print(f"Accuracy is {accuracy}.")

https://ithelp.ithome.com.tw/upload/images/20220129/20115774GpxDbDww7V.png
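The per-class numbers in classification_report come straight from the confusion matrix. The worked example below uses hypothetical counts (not this article's results) in scikit-learn's layout, with "severe" as the positive class, and compares accuracy against the baseline of always predicting the majority class.

```python
import numpy as np

# Hypothetical counts (not the article's results) in scikit-learn's layout:
# rows = true class, columns = predicted class, order ["not severe", "severe"]
cm = np.array([[900, 20],
               [60, 20]])

# Treat "severe" as the positive class
tn, fp, fn, tp = cm.ravel()

accuracy = (tp + tn) / cm.sum()
precision = tp / (tp + fp)       # of the rows predicted severe, how many were severe
recall = tp / (tp + fn)          # of the truly severe rows, how many were caught
f1 = 2 * precision * recall / (precision + recall)

# Accuracy of always predicting the majority class "not severe"
baseline = cm[0].sum() / cm.sum()

print(f"accuracy={accuracy:.3f} precision={precision:.3f} "
      f"recall={recall:.3f} f1={f1:.3f} baseline={baseline:.3f}")
```

With these made-up counts the model's accuracy (0.92) is no better than always answering "not severe", while recall for "severe" is only 0.25. When the severe class is the minority, that gap is worth checking in the report rather than relying on accuracy alone.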

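# Confusion-matrix heatmap
import seaborn as sns

sns.set()
f, ax = plt.subplots()
class_labels = ["not severe", "severe"]
C2 = confusion_matrix(test_y, test_y_predicted3, labels=class_labels)
print(C2)  # print the raw counts
sns.heatmap(C2, annot=True, fmt="d", ax=ax, cmap="GnBu",
            xticklabels=class_labels, yticklabels=class_labels)  # draw the heatmap; fmt="d" shows whole counts

ax.set_title('confusion matrix')  # title
ax.set_xlabel('predict')  # x axis: predicted class
ax.set_ylabel('true')  # y axis: true class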

https://ithelp.ithome.com.tw/upload/images/20220129/20115774rErGHGsOuR.png
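When one class is much larger than the other, the raw-count heatmap is dominated by that class. Dividing each row by its total makes every row sum to 1, so the diagonal reads as per-class recall. A sketch with hypothetical counts; scikit-learn's confusion_matrix(..., normalize="true") produces the same matrix, and it can be drawn with sns.heatmap(cm_norm, annot=True, fmt=".2f").

```python
import numpy as np

labels = ["not severe", "severe"]
# Hypothetical counts in scikit-learn's layout (rows = true, columns = predicted)
cm = np.array([[900, 20],
               [60, 20]])

# Divide each row by its total so every row sums to 1;
# the diagonal is then the recall of each class
cm_norm = cm / cm.sum(axis=1, keepdims=True)
for label, row in zip(labels, cm_norm):
    print(label, np.round(row, 3))
```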

5. Visualization.

features = list(cm_mot_d9_4_X9.columns)  # feature names, in the column order used for training


# viz code

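from io import StringIO   # standard-library StringIO (six is only a Python 2 compatibility layer)
import pydotplus          # needs the Graphviz binaries installed to write the PDF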


dot_data = StringIO()
tree.export_graphviz(clf3,
        out_file=dot_data,
        feature_names=features,
        class_names=clf3.classes_,
        filled=True, rounded=True,
        impurity=False)

graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_pdf("cm7.pdf")

https://ithelp.ithome.com.tw/upload/images/20220129/201157744VIeRbFS7H.png
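export_graphviz plus pydotplus needs the Graphviz binaries installed. scikit-learn also ships plot_tree, which draws with matplotlib alone, and export_text, which prints the rules as plain text. A sketch on synthetic data with hypothetical feature names; the plot options mirror the export_graphviz call above, and class_names must follow the order of clf.classes_.

```python
import io

import matplotlib
matplotlib.use("Agg")  # draw off-screen; no display needed
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_text, plot_tree

# Synthetic data and hypothetical feature names
X, y = make_classification(n_samples=500, n_features=4, n_informative=3,
                           n_redundant=0, random_state=0)
feature_names = [f"x{i}" for i in range(4)]
class_names = ["not severe", "severe"]  # must follow the order of clf.classes_ (0, 1)

clf = DecisionTreeClassifier(criterion="gini", max_depth=3, random_state=0).fit(X, y)

# Plain-text rules, handy for logs or quick inspection
rules = export_text(clf, feature_names=feature_names)
print(rules)

# Matplotlib drawing of the same tree
fig, ax = plt.subplots(figsize=(16, 8))
plot_tree(clf, feature_names=feature_names, class_names=class_names,
          filled=True, rounded=True, impurity=False, ax=ax)
buffer = io.BytesIO()
fig.savefig(buffer, format="png", bbox_inches="tight")  # or fig.savefig("tree.pdf")
plt.close(fig)
```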

6. Done.
The key points when applying CART in this example: 1. binarize the target variable; 2. convert the predictors into dummy variables.

