神经网络相似性如何帮助我们理解训练和泛化?
文 / Google Brain 团队 Maithra Raghu 和 DeepMind 团队 Ari S. Morcos
为了完成任务,深度神经网络 (DNN) 将输入数据逐步转换为复杂表征序列(即,各个神经元之间的激活模式)。理解这些表征不仅对可解释性至关重要,而且还有助于我们更智能地设计机器学习系统。不过,目前已证明理解这些表征相当困难,特别是在跨网络比较表征时。在上一篇博文中,我们概述了利用典型相关分析 (CCA) 理解和比较卷积神经网络 (CNN) 表征的优势,表明它们按自下而上的模式收敛,在训练过程中,早期层先于后期层收敛到最终表征。
在 “Insights on Representational Similarity in Neural Networks with Canonical Correlation” 一文中,我们进一步推进此项研究,对 CNN 的表征相似性形成新的认识,包括记忆网络(例如,只能对之前看过的图像进行分类的网络)与泛化网络(例如,能够将之前未看过的图像正确分类的网络)之间的差异。重要的是,我们还对这一方法进行了扩展,以深入了解递归神经网络 (RNN) 的动态,这类模型对序列数据(如语言)特别有用。比较 RNN 在许多方面与 CNN 一样困难,但是 RNN 带来了额外的挑战,即它们的表征在序列化过程中会发生变化。这样一来,CCA 凭借其不变性这一优势,成了研究 RNN 和 CNN 的理想工具。因此,我们额外开源了用于在神经网络上应用 CCA 的代码,希望帮助研究社区更好地理解网络动态。
记忆与泛化 CNN 的表征相似性
最后,只有当机器学习系统能够泛化到之前未看到过的新情景时,它才是有用的。因此,了解泛化网络和非泛化网络的区分因素非常重要,并可能引出提高泛化性能的新方法。为了考察表征相似性是否可以预测泛化,我们研究了两种类型的 CNN:
泛化网络:此类 CNN 使用含有未修改的准确标签的数据进行训练,并学习泛化到新数据的解。
记忆网络:此类 CNN 使用具有随机标签的数据集进行训练,因此,它们必须记住训练数据,而不能根据定义进行泛化(如 Zhang et al., 2017 中所述)。
我们为每个网络训练了多个实例(只是网络权重的初始随机值和训练数据的顺序不同),并使用新的加权方法来计算 CCA 距离度量(详见我们的论文),以比较每组网络内的表征以及记忆和泛化网络之间的表征。
我们发现,与记忆网络组相比,不同的泛化网络组均一致地收敛到更相似的表征(尤其是在后期层中)(见下图)。在表示网络最终预测的 softmax 中,由于每个单独组中的网络做出的预测类似,因此每组泛化和记忆网络的 CCA 距离显著减小。
与记忆网络组(红色)相比,泛化网络组(蓝色)能够收敛到更相似的解。计算在真实 CIFAR-10 标签(“泛化”)或随机 CIFAR-10 标签(“记忆”)上训练的网络组之间的 CCA 距离以及记忆和泛化网络对之间(“组间”)的 CCA 距离
也许最令人惊讶的是,在后期隐藏层中,任一给定记忆网络对之间的表征距离与记忆和泛化网络之间的表征距离大致相同(上图中的“组间”),尽管这些网络是使用标签完全不同的数据训练的。
直观地说,这一结果表明,记忆训练数据的方法有很多(导致 CCA 距离更大),但是学习可泛化解的方法却很少。在未来的工作中,我们将探索是否可以利用这一发现来规范网络,以学习泛化能力更强的解。
理解递归神经网络的训练动态
到目前为止,我们只在使用图像数据训练的 CNN 上应用了 CCA。不过,CCA 也可以用于计算 RNN 中的表征相似性,无论是在训练过程中还是在序列化过程中。在将 CCA 应用于 RNN 前,我们首先考虑 RNN 是否同样显示出在关于 CNN 的先前研究中所观察到的自下而上收敛模式。
为了检验这一点,我们测量了训练过程中 RNN 各层的表征与训练结束时的最终表征之间的 CCA 距离。我们发现,在训练中,更接近输入的层的 CCA 距离比更深的层下降得更早,这表明,与 CNN 一样,RNN 也按自下而上的模式收敛(见下图)。
RNN 在训练过程中的收敛动态表现为自下而上的收敛,因为在训练中,相对于后期层,更接近输入的层更早地收敛到它们的最终表征。例如,在训练中,第 1 层收敛到最终表征的时间比第 2 层更早,第 2 层比第 3 层早,以此类推。周期表示模型看到整个训练集的次数,而不同的颜色表示不同层的收敛动态。
我们论文中的其他发现显示,与窄网络相比,较宽的网络(例如,每层具有更多神经元的网络)能够收敛到更相似的解。我们还发现,具有相同结构但学习率不同的训练后网络收敛到具有相似性能但表征差异非常大的不同集群。我们还在单个序列化过程中将 CCA 应用于 RNN 动态,而不仅仅是在训练过程中,对随时间推移影响 RNN 表征的各种因素形成了一些初步认识。
结论
这些发现强化了分析和比较 DNN 表征的效用,以便深入了解网络功能、泛化和收敛。不过,仍有许多未解决的问题:在未来的工作中,我们希望揭示对于 CNN 和 RNN,网络中保留了表征的哪些方面,以及这些分析结果是否可以用来提升网络性能。我们鼓励其他人尝试这篇论文中使用的代码,以便研究 CCA 在其他神经网络中的应用!
Be a Tensorflower