cart树怎么进行剪枝?

能否详细介绍一下cart剪枝算法流程? 推荐经典的剪枝论文? 谢谢。
关注者
233
被浏览
103,331

27 个回答

李航老师的《统计学习方法》第五章中有CART剪枝,主要思路是:

对于原始的CART树A0,先剪去一棵子树,生成子树A1,然后再从A1剪去一棵子树生成A2,直到最后剪到只剩一个根结点的子树An。于是得到了A0-AN一共n+1棵子树。然后再用n+1棵子树预测独立的验证数据集,谁的误差最小就选谁。

面临一个问题:一棵树,应该剪哪个结点?

这部分始终困扰我的是(5.31)式,即:

书里对g(t)的解释是:

它表示剪枝后整体损失函数减少的程度。

这么说应该剪掉以g(t)最大结点为根的子树,因为g(t)最大,那么剪枝后整体损失函数减少程度也最大。但书中的算法却说优先剪去g(t)最小的子树,困惑了我好久。

实际上这个g(t)表示剪枝的阈值,即对于某一结点a,当总体损失函数中的参数alpha = g(t)时,剪和不剪总体损失函数是一样的(这可以在书中(5.27)和(5.28)联立得到)。这时如果alpha稍稍增大,那么不剪的整体损失函数就大于剪去的。即alpha大于g(t)该剪,剪了会使整体损失函数减小;alpha小于g(t)不该剪,剪了会使整体损失函数增大。

(请注意上文中的总体损失函数,对象可以是以a为根的子树,也可以是整个CART树,对a剪枝前后二者的总体损失函数增减是相同的。)

对于同一棵树的结点,alpha都是一样的,当alpha从0开始缓慢增大,总会有某棵子树该剪,其他子树不该剪的情况,即alpha超过了某个结点的g(t),但还没有超过其他结点的g(t)。这样随着alpha不断增大,不断地剪枝,就得到了n+1棵子树,接下来只要用独立数据集测试这n+1棵子树,试试哪棵子树的误差最小就知道那棵是最好的方案了。

如有错误请一定指正。

CART剪枝算法分为两步:

1.首先从CART生成算法产生的决策树T0的底端开始,不断剪枝,直到T0的根节点,从而获得一个子树序列{T0,T1,...,Tn};

2.通过交叉验证子树序列中的每个子树进行测试,从中选择最优子树作为最终的剪枝结果;

那么如何构造1中的子树序列{T0,T1,...,Tn}呢:

原则上进行剪枝需要满足以下条件:

子树序列的损失函数至少要≤剪枝前的损失函数(剪枝前后损失函数不变,但剪枝后复杂度降低,亦可以剪枝);

CART树T剪枝时的损失函数如下:

C_\alpha(T)=C(T)+\alpha|T|

其中C(T)为训练数据的预测误差(分类可以是sum(每个叶子节点的gini 乘以 叶子节点样本数目),回归则是平方损失);α≥0为参数(正则化项的系数);|T|为树的叶子节点个数;

当α=0时,损失函数等于预测误差,相当于不进行剪枝,对应的最优子树即决策树本身(计为T0);

α越大,则惩罚越大,会得到更加简单的树,即剪枝幅度更大;

对于固定的α,一定存在一个使得损失函数Cα(T)最小的子树Tα,即对于α序列{α0,α1,...,αn},有最优子树序列{T0,T1,...,Tn}与其一一对应,因此我们只需要在最优子树序列中寻找交叉验证集效果最好的那个作为最终的剪枝结果就可以了;

到这里,子树序列的构造可以转变为α序列{α0,α1,...,αn}的构造;

下面阐述如何构造α序列:

我们将α序构造为递增的序列,则子树序列T是满树到根节点树的递减树序列;

首先令α0=0,决策树本身为T0,在T0上进行第一次剪枝(构造α1):

上面已经提到剪枝后的损失函数要≤剪枝前,

因此,如果我们想在内部节点t处进行剪枝(即将t的所有子节点剔除,将t设置为叶子节点),只需要将问题聚焦于节点t以及t对应的子树Tt上(Tt为在t处剪枝剪掉的部分):

若在节点t处进行剪枝,则t变成了单节点树,对应的损失函数为:

C_\alpha(t)=C(t)+\alpha*1

若不进行剪枝,t及以下部分构成的子树Tt的损失函数为:

C_\alpha(T_t)=C(T_t)+\alpha|T_t|

已知剪枝后的损失函数≤剪枝前的损失函数,则有:

C_\alpha(t) \leq C_\alpha(T_t)

即:

C(t)+\alpha \leq C(T_t)+\alpha|T_t|

可得:

α\geq\frac{C(t)-C(T_t)}{|T_t|-1}

即当 \alpha\in[0,\frac{C(t)-C(T_t)}{|T_t|-1}) 时,不满足剪之后的损失函数小于剪枝前的损失函数,不能进行剪枝;

因此对于T0,能够在t处剪枝的最小的 α为:\alpha_{min}=\frac{C(t)-C(T_t)}{|T_t|-1}

自下而上地对内部每个内部节点t计算可以剪枝的最小α_min,取其中最小的α_min作为α1(α序列逐渐增大,树序列T逐渐更简单,最小的α_min保证没有遗漏可以剪枝的子树),对应的子树为T1,剪切点为t1;

这样我们便得到了α序列中的α1以及对应的子树T1。

接着对子树T1执行以上过程便可以得到(α2和T2),不停地重复以上过程便可以得到最优子树序列{T0,T1,...,Tn},且每一颗树都是上一颗树的子集(在上一颗的基础上进行剪枝)。

获得了可以剪枝的最优子树序列{T0,T1,...,Tn}之后,再将每棵树进行交叉验证,交叉验证结果最好的那颗子树便是最终的剪枝结果。