博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
以鸢尾花数据集为例,用Python对决策树进行分类
阅读量:2088 次
发布时间:2019-04-29

本文共 5010 字,大约阅读时间需要 16 分钟。

全文共4730字,预计学习时长10分钟

 

图片来源:https://www.pexels.com/@andree-brennan-974943

 

基于多种原因,决策树是一种广受欢迎的监督学习方法。决策树的优点包括可以同时用于回归和分类,易于解释且不需要特征缩放。同时决策树也存在劣势,容易出现过度拟合就是其中之一。本教程主要介绍了用于分类的决策树,也称为分类树。

 

 

此外,本教程还将涵盖:

 

· 分类树的解剖结构(树的深度、根节点、决策节点、叶节点/终端节点)。

 

· 分类树如何进行预测

 

· 如何使用scikit-learn(Python)制作分类树

 

· 超参数调整

 

现在,让我们开始吧!

 

什么是分类树?

 

分类和回归树(CART)术语最早由利奥·布雷曼提出,用于指代可以用于分类或回归预测建模问题的决策树算法。本篇文章主要涵盖分类树。

 

分类树

 

本质上,分类树就是设计一系列问题来进行分类。下图是在鸢尾花数据集(花种)上训练的分类树。根(棕色)和决策(蓝色)节点包含分成子节点的问题。根节点是最顶层的决策节点。换句话说,它是开始进行分类的节点。叶节点(绿色),也称为终端节点,是没有子节点的节点。叶节点是通过多数投票分配类的地方。

 

 

分类树到三种花种之一的分类(鸢尾花数据集)

 

如何使用分类树

 

要使用分类树,请从根节点(棕色)开始,然后一直进行分类,直到到达叶(终端)节点。使用下图中的分类树,想象有一朵花,花瓣长4.5厘米,然后对其进行分类。从根节点开始,首先设置第一个问题,“花瓣长度(cm)是否≤2.45”?长度若大于2.45,则为假。接着进入下一个决策节点,“花瓣长度(cm)是否≤4.95”?这是真的,所以你可以推测花种为杂色。这只是一个例子。

 

 

分类树是如何生成的?(非数学版)

 

分类树学习一系列是否/然后问题,每个问题涉及一个特征和一个分裂点。查看一下部分树(A),问题:“花瓣长度(cm)≤2.45”根据某个值(在这种情况下为2.45)将数据分成两个分支。节点之间的值称为分裂点。分裂点的值取得好(导致最大信息增益的值)是分类能否有序进行的关键。查看下图中的B部分,分割点左侧的所有点都被归为蓝色鸢尾花,而分割点右侧的所有点都被分类为杂色鸢尾花。

 

 

该图显示38个点都被正确归为蓝色。它是一个纯节点。分类树不会在纯节点上分割。这不会产生下一步信息增益。但是,不纯的节点可以进行下一步分类。请注意,图B的右侧显示许多点被错误分类为杂色鸢尾花。换句话说,它包含两个不同类别的点(virginica和versicolor)。分类树是一种贪婪算法,默认情况下还将继续分裂,直到有一个纯节点。同样,算法选择也会为不纯节点选择最佳分类点(我们将在下一节中介绍数学方法)

 

 

在上图中,树的最大深度为2。树深度可以衡量在进行预测之前可以进行多少次分裂。这个过程可以继续进行更多分裂,直到树尽可能纯净。分裂过程中的大量重复可能会导致形成拥有很多节点的分类树。通常会导致训练数据集过度拟合。幸运的是,实现大多数分类树支持预剪枝来控制树的最大深度,从而减少过度拟合。例如,Python的scikit-learn就支持预剪枝。换句话说,你可以设置最大深度以防止决策树过深。为直观了解最大深度,可以查看下图。

 

 

鸢尾花数据集适合的不同深度分类树。

 

 

选择标准

 

 

本节主要解答如何计算信息增益和两个标准即基尼和熵。

 

本节实际上是关于如何理解分类树上根/决策节点的良好分割点。决策树在特征和相应的分裂点上分离,产生给定标准的最大信息增益(IG)(在该示例中为基尼或熵)。我们可以将信息收益大概定义为:

 

IG = information before splitting (parent) — information after splitting (children)

 

为了更清楚地了解父节点和子节点,请查看下列决策树。

 

 

以下是更恰当的信息增益公式。

 

 

由于分类树具有二进制分割,因此可以将公式简化如下。

 

 

可以使用基尼指数和熵作为度量不纯节点的标准。

 

 

为了更好理解这些公式,下图展示了如何使用基尼系数来计算决策树的信息增益。

 

 

下图显示了如何计算使用熵决策树的信息增益。

 

 

本文不打算详细讨论这个问题,因为应该注意的是,不同的不纯度测量(基尼指数和熵)通常会产生类似的结果。下图显示基尼指数和熵有相似的不纯度标准。scikit-learn中的默认设置为Gini的原因之一可能是计算熵会慢一点(因为它使用了对数)。

 

 

不同的不纯度测量(基尼指数和熵)通常产生类似的结果。感谢Data Science StackExchange和Sebastian Raschka提供此图。

 

在本节结束之前,应该注意到各种决策树算法彼此不同。ID3,C4.5和CART算法使用较为普遍。Scikit-learn使用CART算法的优化版本。

 

 

使用Pytho的分类树

 

前面几节介绍了分类树的理论。学习如何用编程语言制作决策树的好处之一是处理数据有助于理解算法。

 

加载数据集

 

鸢尾花数据集是scikit-learn附带的数据集之一,不需要从某个外部网站下载任何文件。下面的代码加载了鸢尾花数据集。

 

import pandas as pd

from sklearn.datasets import load_irisdata = load_iris()

data = load_iris()

df = pd.DataFrame(data.data, columns=data.feature_names)

df['target'] = data.target

 

 

原始df(功能+目标)

 

将数据拆分为训练和测试集

 

下面的代码将75%的数据放入训练集,25%的数据放入测试集。

 

X_train, X_test, Y_train, Y_test = train_test_split(df[data.feature_names], df['target'], random_state=0)

 

 

图像中的颜色表示来自数据帧df的数据用于此特定列车测试分割的变量(X_train,X_test,Y_train,Y_test)。

 

请注意,使用决策树的一个好处是不必将数据标准化,这与PCA和逻辑回归不同,后者对不标准化数据造成的影响异常敏感。

 

 

Scikit-learn 4步建模模式

 

第1步:导入要使用的模型

 

在scikit-learn中,所有机器学习模型都是作为Python类实现。

 

from sklearn.tree import DecisionTreeClassifier

 

第2步:创建模型的实例

 

在下列代码中,设置max_depth = 2来预剪枝,以确保树深不超过2。本教程的下一节将讨论如何为树选择最佳的最大纵深。

 

另请注意,在下列代码中,使random_state = 0,以便你可以得出相同的结果。

 

clf = DecisionTreeClassifier(max_depth = 2, 

                             random_state = 0)

 

第3步:在数据上训练模型

 

该模型正在学习X(萼片长度,萼片宽度,花瓣长度和花瓣宽度)与Y(鸢尾花种类)之间的关系

 

clf.fit(X_train, Y_train)

 

第4步:预测测试数据标签

 

# Predict for 1 observation

clf.predict(X_test.iloc[0].values.reshape(1, -1))

# Predict for multiple observations

clf.predict(X_test[0:10])

 

请记住,预测只是叶节点中实例的多数类。

 

 

测量模型性能

 

虽然还有其他方法可以测量模型性能(精度,召回率,F1分数,ROC曲线等),但我们将简单地将精度用作指标。

 

将准确度定义为:

 

(正确预测的分数):正确预测/数据点总数

 

# The score method returns the accuracy of the model

score = clf.score(X_test, Y_test)

print(score)

 

调整树的深度

 

找到最佳的max_depth是调整模型的一种方法。以下代码输出具有不同max_depth的决策树的准确性。

 

 max_depth.

# List of values to try for max_depth:

max_depth_range = list(range(1, 6))

 

# List to store the average RMSE for each value of max_depth:

accuracy = []

 

for depth in max_depth_range:

    

    clf = DecisionTreeClassifier(max_depth = depth, 

                             random_state = 0)

    clf.fit(X_train, Y_train) 

 

  score = clf.score(X_test, Y_test)

    accuracy.append(score)

 

由于下图显示当max_depth大于或等于3时,准确度最高,所以最好设置max_depth = 3才最简单。

 

 

选择max_depth = 3是因为这一模型既准确又不复杂。

 

记住max_depth与决策树的深度不同至关重要。设置max_depth是一种预剪枝决策树的方法。换句话说,如果树在深度上已经尽可能纯净,它将不会继续分裂。下图显示了max_depth值为3,4和5的决策树。请注意,max_depth为4和5的树是相同的,树深都为4。

 

 

请注意我们如何拥有两个完全相同的树。

 

如果你想知道训练决策树的深度是多少,可以使用get_depth方法。此外,你可以使用get_n_leaves方法获取受过训练的决策树的叶节点数。

 

虽然本教程已经介绍了改变选择标准(基尼指数,熵等)和max_depth,但请记住,你还可以调整要拆分的节点的最小样本(min_samples_leaf),最大叶节点数(max_leaf_nodes)和其他。

 

 

特征重要性

 

分类树的一个优点是相对而言其更容易解释。scikit-learn中的分类树可以计算特征重要性,即由于对给定特征进行拆分而导致的基尼指数或熵减少的总量。Scikit-learn输出的特征为0到1之间的数字。将所有特征重要性归化为总和1。下面的代码显示决策树模型中每个特征的特征重要性。

 

importances = pd.DataFrame({'feature':X_train.columns,'importance':np.round(clf.feature_importances_,3)})

importances = importances.sort_values('importance',ascending=False)

 

 

在上面的示例中(对于鸢尾花的特定训练测试分类),花瓣宽度具有最高的特征重要性。可以通过查看相应的决策树来确认。

 

 

这个决策树分裂的唯一两个特征是花瓣宽度(cm)和花瓣长度(cm)

 

请记住,如果某个特征重要性值较低,并不一定意味着该特征对预测不重要,这只是意味着在树早期的特定阶段不会选择该特定特征。也可能该特征与另一个信息特征相同或高度相关。特征重要性值也不会告诉你哪个类可以预测,或者哪些特征之间的关系可能会影响预测。请务必注意,在执行交叉验证或类似操作时,你可以使用多个训练测试分类中特征重要性的平均值。

 

 

结论

 

分类和回归树(CART)技术已有些年头了(1984)是复杂技术的基础。决策树的主要缺点之一是其通常不是最准确的算法。部分是由于决策树是高方差算法,这意味着训练数据中的不同分裂可能导致树的差异大相径庭。

 

 

留言 点赞 发个朋友圈

我们一起分享AI学习与发展的干货

 

编译组:宋兰欣、廖馨婷

相关链接:

https://towardsdatascience.com/understanding-decision-trees-for-classification-python-9663d683c952

 

如需转载,请后台留言,遵守转载规范

 

推荐文章阅读

 

长按识别二维码可添加关注

读芯君爱你

 

你可能感兴趣的文章
java静态代理与动态代理简单分析
查看>>
JTS Geometry关系判断和分析
查看>>
阿里巴巴十年Java架构师分享,会了这个知识点的人都去BAT了
查看>>
Intellij IDEA 使用技巧一
查看>>
idea如何显示git远程与本地的更改对比?
查看>>
Git 分支 - 分支的新建与合并
查看>>
git创建与合并分支
查看>>
23种设计模式介绍以及在Java中的实现
查看>>
如何把本地项目上传到Github
查看>>
Git的使用--如何将本地项目上传到Github
查看>>
zookeeper客户端命令行查看dubbo服务的生产者和消费者
查看>>
intellij idea 相关搜索快捷键
查看>>
oracle查看数据库连接池中最大连接数和当前用户连接数等信息
查看>>
oracle中创建同义词(synonyms)表
查看>>
建立DB-LINK和建立视图
查看>>
普通视图和物化视图的区别(转)
查看>>
物化视图加DBLINK实现数据的同步_20170216
查看>>
Redis在京东到家的订单中的使用
查看>>
idea 注释模板设置
查看>>
单例模式singleton为什么要加volatile
查看>>