其他

谷歌大脑「辛顿」团队最新研究:将神经网络提炼成「软决策树」

2017-11-29 雷克世界


原文来源:arXiv

作者:Nicholas Frosst、 Geoffrey Hinton

「雷克世界」编译:嗯~阿童木呀


现如今,经实践证明,深度神经网络是执行分类任务的一种非常有效的方法。当输入数据是高维度,输入输出之间关系异常复杂,标注训练样本数量非常大的时候,深度神经网络的性能表现是非常好的。但是很难解释为什么学习网络在一个特定的测试用例做出特定的分类决策。这主要是由于它们对于分布式分层表示的依赖。如果我们能够充分利用从神经网络所获得的知识,并在一个依赖分层决策的模型中表达相同的知识,那么解释一个特定的决策将会容易得多。我们描述了一种使用已训练的神经网络创建一种软决策树的方法,该方法的泛化效果要比直接从训练数据中得以学习要好得多。



深度神经网络的优秀泛化能力取决于它们在隐藏层中分布式表示的使用,但这些表示难以理解。对于第一个隐藏层,我们可以理解是什么原因导致了一个单元的激活,而对于最后一个隐藏层,我们可以理解激活一个单元所产生的效果,但是对于其他隐藏层,理解一个特征激活的产生原因和造成的影响要困难得多,尤其是就输入和输出变量这些有意义的变量而言。与此同时,隐藏层中的单元将输入向量的表示分解为一组特征激活,通过这种方式,激活特征的组合效果能够在下一隐藏层中产生适当的分布式表示。这使得我们很难独立性地理解任何特定特征激活的函数作用,因为它的边际效应依赖于同一层中所有其他单元的影响。


这个图显示了一个软二进制决策树,其中有一个内部节点和两个叶节点。

 

深度网络通过对训练数据的输入和输出之间关系中的大量弱统计规律进行建模从而做出可靠的决策,基于这一事实,上述困难进一步加深,而且,神经网络中没有任何东西可以从训练集的抽样特性所产生的伪规律中区分这些弱规律,即数据的真实属性。面对所有这些困难,放弃理解深度神经网络是如何通过理解单一隐藏单元所作所为来进行一个分类决策的想法,似乎是明智的。

  

相比之下,决策树是如何进行任意特定的分类就很容易解释了,因为这取决于一个相对较短的决策序列,且每个决策都直接基于输入数据。然而,决策树通常不会像神经网络那样泛化。与神经网络中的隐藏单元不同的是,决策树较低级别的典型节点仅被一小部分训练数据所使用,因此决策树的较低部分倾向于过度拟合,除非与树的深度相比,训练集的大小大的程度能够呈现出指数级。

 

这是一个在MNIST上进行训练的深度为4的软决策树的可视化图。内部节点的图像是已学习过的过滤器,而叶部的图像是覆盖所有类的学习概率分布的可视化。而最后对每一个叶部的,以及对每条边缘的可能分类都已有注释。如果我们以最右边的内部结点为例,可以看到,在树的那个层级上,潜在的分类只有3或8,因此,已学习的过滤器只是简单地学习该如何区分这两个数字。结果是一个在寻找这个两个区域存在的过滤器,会连接到3的末端,从而生成8。

 

在本文中,我们提出了一种全新的解决泛化和可解释性之间矛盾的方法。我们不是试图了解深度神经网络是如何做出决策的,而是使用深度神经网络来训练一个决策树,它会对神经网络所发现的输入输出函数进行模仿,但是以一种完全不同的方式运行。如果有大量未标注的数据,则可以使用神经网络来创建一个更大的标注数据集用以训练决策树,从而克服决策树的统计无效性问题。即使未标注的数据不可用,也有可能利用生成建模方面所取得的最新进展,从一个类似数据分布的分布中生成合成的未标注数据。在不使用未标注的数据的情况下,我们可以通过使用一种叫做提炼(distillation)的技术以及一种能够做软决策的决策,将神经网络的泛化能力迁移到决策树中。


这是一个在Connect4数据集上进行训练的软决策树的前两层的可视化视图。通过检查学习过滤器,我们可以看到,该游戏可以分为两个不同的子类型游戏,其中一个游戏中,玩家已经把金币放在板的边缘,而另一个游戏中,玩家将金币放置在板的中心。

 

在测试期间,我们使用决策树作为我们的模型。它的执行效果可能会比神经网络稍微差一点,但它通常会快得多,而且现在我们有了一个模型,可以直接对其决策进行解释和参与其中。现在,我们首先对我们所使用的决策树的类型进行描述。我们之所以做出这个选择是为了便于将从深度神经网络获得的知识简化到决策树中。

 

我们已经描述了一种使用已训练的神经网络,以软决策树的形式创建一个更具可解释性的模型的方法,其中,决策树是通过随机梯度下降进行训练的,利用神经网络的预测以便提供更多的信息目标。软决策树使用已学习的过滤器做出一个基于输入样本的分层决策,最终选择一个特定的覆盖所有类的静态概率分布作为其输出。这种软决策树的泛化能力要比直接在数据上进行训练好得多,但性能表现要比用来提供对其进行训练的软目标的神经网络差得多。因此,如果能够解释一个模型为什么要以特定方式对特定测试用例进行分类是至关重要的话,那么我们就可以使用软决策树,但是,如果我们使用深度神经网络来改进这个具有可解释性模型的训练性能的话,我们仍然可以从中获益。


 欢迎个人分享,媒体转载请后台回复「转载」获得授权,微信搜索「BOBO_AI」关注公众号


中国人工智能产业创新联盟于2017年6月21日成立,超200家成员共推AI发展,相关动态:

中新网:中国人工智能产业创新联盟成立

ChinaDaily:China forms 1st AI alliance

证券时报:中国人工智能产业创新联盟成立 启动四大工程搭建产业生态“梁柱”

工信部网站:中国人工智能产业创新联盟与贵阳市政府、英特尔签署战略合作备忘录


点击下图加入联盟


下载中国人工智能产业创新联盟入盟申请表


关注“雷克世界”后不要忘记置顶

我们还在搜狐新闻、雷克世界官网、腾讯新闻、网易新闻、一点资讯、天天快报、今日头条、雪球财经……

↓↓↓点击阅读原文查看中国人工智能产业创新联盟手册

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存