机器学习之决策树详解
摘要决策树Decision Tree是一种基于树结构进行决策的机器学习算法广泛应用于分类与回归任务。其核心思想是通过对特征空间进行递归分裂构建一棵能够对数据进行高效预测的树形模型。本文系统讲解决策树的基本原理、分裂准则信息增益、基尼系数、信息增益率、经典算法ID3、C4.5、CART及其对比、剪枝策略并结合 scikit-learn 提供完整的实战代码示例帮助读者从理论到实践全面掌握决策树算法。本文适合机器学习初学者及希望深入理解决策树原理的开发者参考。关键词决策树、信息增益、信息熵、基尼系数、CART、scikit-learn、剪枝一、决策树概述1.1 什么是决策树决策树是一种监督学习算法可用于分类Classification和回归Regression任务。它模拟人类决策过程通过一系列的是/否问题对特征值的判断将数据逐层划分最终得到预测结果。因其模型结构形似一棵倒置的树而得名。决策树的组成要素根节点Root Node树的最顶端包含整个数据集是分裂的起点。内部节点Internal Node对应特征的判断节点表示对某个特征的测试。叶节点Leaf Node树的末端节点代表最终的分类标签或回归值。分支Branch节点的输出路径对应特征的不同取值。1.2 决策树的工作原理给定一个输入样本决策树从根节点开始根据该样本在各特征上的取值沿着对应的分支向下递归直到抵达叶节点叶节点的标签即为预测结果。这一过程类似于医生诊断疾病医生依次询问症状特征测试根据回答逐步缩小可能病因的范围最终确定诊断结论叶节点标签。1.3 决策树的优势与局限优势局限易于理解和解释能可视化容易过拟合泛化能力差训练和预测速度快对数据敏感微小变化可能导致树结构大幅改变支持连续值和离散值特征偏向于选择取值更多的特征能处理多分类问题不擅长处理不平衡数据二、决策树原理详解2.1 树结构基础决策树通过递归分裂Recursive Splitting构建基本算法如下输入数据集 D特征集 A 1. 从根节点开始用全部数据构建树 2. 如果节点中所有样本属于同一类别 C则该节点为叶节点标记为 C 3. 如果特征集为空或数据集为空则停止分裂 4. 选择最优分裂特征 a*分裂准则信息增益最大或基尼系数最小 5. 按特征 a* 的取值将数据划分为若干子集 6. 对每个子集递归执行步骤 2-52.2 分裂准则信息增益与信息熵熵Entropy是度量数据混乱程度的指标。熵越高数据越混乱熵越低数据越纯净。信息熵的公式定义为$$H(X) -\sum_{i1}^{n} p(x_i) \log_2 p(x_i)$$其中 $p(x_i)$ 表示事件 $x_i$ 发生的概率。对于二分类问题设正类比例为 $p$则$$H(X) -p \log_2 p - (1-p) \log_2 (1-p)$$信息增益Information Gain表示在已知某个特征后数据集不确定性的减少量。计算公式为$$IG(D, a) H(D) - \sum_{v \in \text{Values}(a)} \frac{|D_v|}{|D|} \cdot H(D_v)$$其中$H(D)$ 为分裂前数据集的熵$H(D_v)$ 为按特征 $a$ 的取值 $v$ 分裂后子集的熵$\frac{|D_v|}{|D|}$ 为子集权重ID3 算法选择信息增益最大的特征作为当前最优分裂特征。2.3 分裂准则基尼系数基尼系数Gini Impurity是 CARTClassification and Regression Tree算法使用的分裂准则度量从数据集中随机抽取两个样本、其类别不一致的概率$$Gini(D) 1 - \sum_{k1}^{K} p_k^2$$其中 $p_k$ 为第 $k$ 类样本在数据集 $D$ 中的比例。使用特征 $a$ 分裂后的加权基尼系数$$Gini_a(D) \sum_{v \in \text{Values}(a)} \frac{|D_v|}{|D|} \cdot Gini(D_v)$$CART 选择基尼系数最小的特征进行分裂。2.4 分裂准则信息增益率ID3 算法存在一个明显缺陷倾向于选择取值更多的特征。例如为每个样本赋予唯一 ID则按 ID 分裂的信息增益最大但毫无泛化能力。为解决这一问题C4.5 算法引入信息增益率Gain Ratio$$GainRatio(D, a) \frac{IG(D, a)}{IV(a)}$$其中 $IV(a)$ 为特征 $a$ 的固有值Intrinsic Value$$IV(a) -\sum_{v \in \text{Values}(a)} \frac{|D_v|}{|D|} \log_2 \frac{|D_v|}{|D|}$$特征取值越多固有值越大从而抑制信息增益的偏好。C4.5 算法通过先筛选信息增益高于平均水平的特征再选择增益率最高的特征来解决这一偏置问题。三、决策树经典算法3.1 ID3 算法ID3Iterative Dichotomiser 3由 Ross Quinlan 于 1986 年提出是最早的决策树算法。核心特点使用信息增益作为分裂准则仅支持分类任务仅支持离散型类别型特征不支持连续值特征、缺失值和剪枝3.2 C4.5 算法C4.5 是 ID3 的改进版本由 Quinlan 于 1993 年提出。核心改进使用信息增益率替代信息增益支持连续值特征通过二分阈值处理支持缺失值处理支持后剪枝基于错误率的剪枝3.3 CART 算法CARTClassification and Regression Tree由 Breiman 等人于 1984 年提出是目前应用最广泛的决策树算法。核心特点使用基尼系数分类或方差回归作为分裂准则二叉树结构每个内部节点只有两个分支既支持分类也支持回归内置剪枝机制3.4 三种算法对比特性ID3C4.5CART提出年份198619931984分裂准则信息增益信息增益率基尼系数 / 方差树结构多叉树多叉树二叉树支持分类✅✅✅支持回归❌❌✅支持连续值❌✅✅支持缺失值❌✅✅剪枝策略无后剪枝后剪枝 / 预剪枝四、剪枝策略决策树如果不加限制地分裂会完全拟合训练数据导致过拟合。剪枝Pruning是解决这一问题的核心手段。4.1 预剪枝Pre-pruning预剪枝在决策树构建过程中通过设置停止条件来提前终止分裂。常用停止条件树的深度达到设定阈值节点样本数少于设定阈值分裂后信息增益基尼系数减少量低于设定阈值优点计算效率高适合大规模数据。缺点可能过早终止欠拟合风险较高。4.2 后剪枝Post-pruning后剪枝先让决策树充分生长再自底向上地将某些子树替换为叶节点通过验证集评估剪枝效果。REPReduced Error Pruning自底向上尝试剪枝如果剪枝后验证集精度不下降则保留剪枝。CCPCost-Complexity Pruning代价复杂度剪枝在 CART 中常用定义代价复杂度指标$$R\alpha(T) R(T) \alpha \cdot |T{leaf}|$$其中 $R(T)$ 为树的训练误差$|T_{leaf}|$ 为叶节点数$\alpha$ 为复杂度参数。通过逐步增加 $\alpha$生成一系列逐渐简化的树序列再用验证集选择最优树。五、决策树使用场景5.1 客户分群在电商和金融领域决策树可用于根据用户的年龄、收入、消费行为等特征将客户划分为不同群体为精准营销提供依据。决策树规则易于业务人员理解便于落地执行。5.2 信用评估银行和金融机构利用决策树评估借款人的信用风险。通过分析申请人的收入水平、工作年限、负债比例、历史逾期记录等特征构建信用评分模型决定是否放贷及贷款额度。5.3 医疗诊断决策树在医学辅助诊断中应用广泛。根据患者的症状、检查指标和病史数据决策树可以构建疾病筛查模型辅助医生进行早期诊断。例如判断患者是否患有糖尿病、心脏病等。5.4 规则提取决策树的叶节点路径可以直接转化为 if-then 业务规则。例如如果客户购买频率 10次/月 且 平均订单金额 500元则为高价值客户。这类规则无需建模背景知识的业务人员也能理解和使用。六、实战代码鸢尾花分类本节使用 scikit-learn 内置的鸢尾花数据集展示决策树分类器的完整使用流程。6.1 基础分类器构建# 导入必要的库 from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split from sklearn.tree import DecisionTreeClassifier from sklearn.metrics import accuracy_score, classification_report, confusion_matrix # 1. 加载数据集 iris load_iris() X iris.data # 特征矩阵花萼长度、花萼宽度、花瓣长度、花瓣宽度 y iris.target # 目标标签0-Setosa, 1-Versicolor, 2-Virginica # 2. 划分训练集和测试集80%训练20%测试 X_train, X_test, y_train, y_test train_test_split( X, y, test_size0.2, random_state42, stratifyy ) # 3. 创建决策树分类器使用基尼系数CART算法默认 # max_depth: 限制树的最大深度防止过拟合 # random_state: 设置随机种子保证结果可复现 clf DecisionTreeClassifier( criteriongini, # 分裂准则gini基尼系数或 entropy信息增益 max_depth4, # 树的最大深度 min_samples_split5, # 节点分裂所需最少样本数 min_samples_leaf2, # 叶节点最少样本数 random_state42 ) # 4. 训练模型 clf.fit(X_train, y_train) # 5. 在测试集上进行预测 y_pred clf.predict(X_test) # 6. 评估模型性能 accuracy accuracy_score(y_test, y_pred) print(f模型准确率{accuracy:.4f}) print(\n混淆矩阵) print(confusion_matrix(y_test, y_pred)) print(\n分类报告) print(classification_report(y_test, y_pred, target_namesiris.target_names))运行结果示例模型准确率1.0000 混淆矩阵 [[10 0 0] [ 0 10 0] [ 0 0 10]] 分类报告 precision recall f1-score support setosa 1.00 1.00 1.00 10 versicolor 1.00 1.00 1.00 10 virginica 1.00 1.00 1.00 10 accuracy 1.00 306.2 决策树可视化使用plot_tree函数将决策树结构可视化直观理解模型的分裂逻辑。import matplotlib.pyplot as plt from sklearn.tree import plot_tree # 设置中文字体 plt.rcParams[font.sans-serif] [SimHei, DejaVu Sans] plt.rcParams[axes.unicode_minus] False # 创建画布 fig, ax plt.subplots(figsize(24, 12)) # 绘制决策树 plot_tree( clf, feature_namesiris.feature_names, # 特征名称 class_namesiris.target_names, # 类别名称 filledTrue, # 用颜色填充节点颜色深浅表示类别纯度 roundedTrue, # 圆角矩形 fontsize10, axax ) plt.title(鸢尾花数据集决策树分类器, fontsize16) plt.tight_layout() plt.savefig(decision_tree_visualization.png, dpi150, bbox_inchestight) plt.show() print(决策树可视化图已保存为 decision_tree_visualization.png)6.3 不同树深度对模型的影响树的深度是影响模型复杂度最重要的超参数。本节实验不同深度下的训练集和测试集准确率观察过拟合与欠拟合现象。import numpy as np # 测试不同深度的准确率 depths range(1, 11) train_accuracies [] test_accuracies [] for depth in depths: # 使用不同的随机种子进行多次实验取平均 temp_train_acc [] temp_test_acc [] for seed in range(10): dt DecisionTreeClassifier(max_depthdepth, random_stateseed) dt.fit(X_train, y_train) temp_train_acc.append(dt.score(X_train, y_train)) temp_test_acc.append(dt.score(X_test, y_test)) train_accuracies.append(np.mean(temp_train_acc)) test_accuracies.append(np.mean(temp_test_acc)) # 绘制准确率曲线 fig, ax plt.subplots(figsize(10, 6)) ax.plot(depths, train_accuracies, o-, label训练集准确率, linewidth2, markersize8) ax.plot(depths, test_accuracies, s-, label测试集准确率, linewidth2, markersize8) ax.set_xlabel(决策树最大深度, fontsize12) ax.set_ylabel(准确率, fontsize12) ax.set_title(决策树深度与模型准确率的关系, fontsize14) ax.set_xticks(depths) ax.legend(fontsize11) ax.grid(True, linestyle--, alpha0.7) ax.set_ylim(0.85, 1.02) plt.tight_layout() plt.savefig(depth_vs_accuracy.png, dpi150, bbox_inchestight) plt.show() # 打印详细数据 print(\n深度 | 训练集准确率 | 测试集准确率) print(- * 40) for d, train_acc, test_acc in zip(depths, train_accuracies, test_accuracies): print(f {d:2d} | {train_acc:.4f} | {test_acc:.4f})结果分析当深度为 1 时模型过于简单训练集和测试集准确率都较低存在欠拟合。当深度增加时训练集准确率持续上升并趋近于 1.0完全拟合。测试集准确率在某个最优深度处达到峰值后开始下降表明过拟合开始出现。建议选择测试集准确率最高对应的深度作为最优超参数。6.4 特征重要性分析决策树提供了特征重要性Feature Importance指标度量每个特征对分类任务的贡献程度。# 获取特征重要性 importances clf.feature_names_in_ importances clf.feature_importances_ # 打印各特征的重要性 print(特征重要性排名) print(- * 45) for name, importance in sorted( zip(iris.feature_names, importances), keylambda x: x[1], reverseTrue ): bar █ * int(importance * 40) print(f{name:15s} : {importance:.4f} {bar}) # 可视化特征重要性 fig, ax plt.subplots(figsize(8, 5)) colors plt.cm.Reds(np.linspace(0.4, 0.9, len(iris.feature_names))) sorted_idx np.argsort(importances) ax.barh( [iris.feature_names[i] for i in sorted_idx], importances[sorted_idx], colorcolors[sorted_idx] ) ax.set_xlabel(重要性, fontsize12) ax.set_title(决策树特征重要性分析, fontsize14) ax.set_xlim(0, max(importances) * 1.15) plt.tight_layout() plt.savefig(feature_importance.png, dpi150, bbox_inchestight) plt.show()结果解读特征重要性之和为 1数值越大表示该特征在分类决策中越关键。在鸢尾花数据集中花瓣长度Petal length通常是最重要的特征对分类的贡献最大。6.5 使用信息增益熵准则除了默认的基尼系数我们还可以使用信息增益criterionentropy构建决策树。# 使用信息增益熵作为分裂准则 clf_entropy DecisionTreeClassifier( criterionentropy, # 使用信息增益替代基尼系数 max_depth4, min_samples_split5, min_samples_leaf2, random_state42 ) clf_entropy.fit(X_train, y_train) # 比较两种准则的准确率 acc_gini clf.score(X_test, y_test) acc_entropy clf_entropy.score(X_test, y_test) print(f基尼系数准则 - 测试集准确率{acc_gini:.4f}) print(f信息增益准则 - 测试集准确率{acc_entropy:.4f}) print(f\n两种准则在鸢尾花数据集上准确率差异{abs(acc_gini - acc_entropy):.4f})七、完整实战使用决策树进行信用评估本节以模拟的信用评估数据集为例展示决策树在金融场景中的完整应用流程。import pandas as pd from sklearn.model_selection import train_test_split from sklearn.tree import DecisionTreeClassifier, export_text from sklearn.preprocessing import LabelEncoder # 1. 创建模拟的信用评估数据集 data { 年龄: [25, 35, 45, 28, 52, 38, 42, 50, 23, 33], 月收入: [8000, 15000, 20000, 6000, 30000, 12000, 18000, 25000, 5000, 9000], 工作年限: [2, 5, 10, 1, 20, 7, 8, 15, 0, 4], 负债比例: [0.2, 0.1, 0.3, 0.5, 0.15, 0.4, 0.25, 0.2, 0.6, 0.35], 有房产: [否, 是, 是, 否, 是, 否, 是, 是, 否, 否], 信用评级: [低, 高, 高, 低, 高, 中, 高, 高, 低, 中] } df pd.DataFrame(data) print(数据集前5行) print(df.head()) # 2. 数据预处理 # 对分类特征进行标签编码 le_house LabelEncoder() df[有房产] le_house.fit_transform(df[有房产]) # 否0, 是1 le_credit LabelEncoder() df[信用评级] le_credit.fit_transform(df[信用评级]) # 低0, 中1, 高2 # 3. 划分特征和标签 X df.drop(信用评级, axis1) y df[信用评级] # 4. 划分训练集和测试集 X_train, X_test, y_train, y_test train_test_split( X, y, test_size0.3, random_state42, stratifyy ) # 5. 训练决策树 credit_tree DecisionTreeClassifier( criteriongini, max_depth4, min_samples_split2, random_state42 ) credit_tree.fit(X_train, y_train) # 6. 评估模型 train_acc credit_tree.score(X_train, y_train) test_acc credit_tree.score(X_test, y_test) print(f\n训练集准确率{train_acc:.4f}) print(f测试集准确率{test_acc:.4f}) # 7. 提取并打印决策规则文本形式 feature_names list(X.columns) rules export_text(credit_tree, feature_namesfeature_names) print(\n决策树规则文本形式) print(rules) # 8. 打印特征重要性 print(\n信用评估特征重要性) for name, imp in sorted(zip(feature_names, credit_tree.feature_importances_), keylambda x: x[1], reverseTrue): if imp 0: print(f {name}: {imp:.4f})运行结果示例数据集前5行 年龄 月收入 工作年限 负债比例 有房产 信用评级 0 25 8000 2 0.2 否 低 1 35 15000 5 0.1 是 高 2 45 20000 10 0.3 是 高 ... 训练集准确率1.0000 测试集准确率0.6667 决策树规则文本形式 |--- 负债比例 0.30 | |--- 月收入 12500 | | |--- class: 中 | |--- 月收入 12500 | | |--- class: 高 |--- 负债比例 0.30 | |--- class: 低 信用评估特征重要性 负债比例: 0.7111 月收入: 0.2889八、总结本文系统介绍了决策树算法的核心概念与实战应用原理基础决策树通过递归分裂构建树结构每个节点对特征进行判断最终叶节点输出预测结果。分裂准则信息增益ID3、信息增益率C4.5和基尼系数CART是三种主流的分裂准则分别从不同角度度量数据纯度的提升。算法对比ID3、C4.5、CART 在分裂准则、树结构、任务类型和剪枝能力上各有差异实际应用中 CART 使用最为广泛。剪枝策略预剪枝通过提前停止分裂控制复杂度后剪枝通过验证集评估剪枝效果。CART 中的代价复杂度剪枝CCP是经典的后剪枝方法。实战要点使用 scikit-learn 构建决策树时max_depth、min_samples_split、min_samples_leaf是防止过拟合的关键超参数。特征重要性分析有助于理解模型决策依据。决策树不仅是高效的机器学习模型更是理解更复杂集成算法如随机森林、梯度提升树的重要基础。希望本文能帮助读者建立对决策树的完整认知并在实际项目中灵活运用。参考库版本本文代码适用scikit-learn 1.0 matplotlib 3.5 pandas 1.3 numpy 1.20