查看原文
其他

因果学习新进展:深度稳定学习

张兴璇 集智俱乐部 2022-04-08


导语


大部分当前的机器学习算法都假设并依赖训练数据和测试数据满足独立同分布的性质,但是在现实应用中测试数据的分布往往不可预见,且易与训练数据分布产生偏移,导致这些算法的准确率显著下降。为了解决这一问题,清华大学崔鹏团队将稳定学习理论与深度学习框架相结合,提出深度稳定学习模型——StableNet。在现有存在分布迁移的数据集(如NICO、PACS、VLCS和MNIST-M)上均取得SOTA结果。该模型对不同深度学习模型具有广泛适用性,以ResNet18作为骨干网络的StableNet比ResNet18性能提升8%以上。论文“Deep Stable Learning for Out-Of-Distribution Generalization”已被CVPR2021接收。


集智俱乐部联合智源社区,以因果科学和Causal AI为主题举办系列读书会,精读基础教材、研读重要论文,探讨如何借助因果科学构建可解释的人工智能系统。详情见文末。后续会邀请崔鹏团队在读书会中深度解读这项研究,欢迎大家参与。

张兴璇 | 作者

邓一雪 | 编辑





1. 数据分布迁移与稳定学习




目前深度学习在很多研究领域特别是计算机视觉领域(如图像识别、物体检测等技术领域)取得了前所未有的进展,而深度模型性能依赖于模型对训练数据的拟合。当训练数据(应用前可获取的数据)与测试数据(实际应用中遇到的实例)分布不同时,传统深度模型对训练数据的充分拟合会造成其在测试数据上的预测失败,进而导致模型应用于不同环境时的可信度降低。为了解决模型在分布迁移下的泛化问题,崔鹏老师团队提出深度稳定学习,提高模型在任意未知应用环境中的准确率和稳定性。


图1. 独立同分布学习、迁移学习和稳定学习

上图给出了常见的独立同分布模型、迁移学习模型和稳定学习模型的异同。独立同分布模型的训练和测试都在相同分布的数据下完成,测试目标是提升模型在测试集上的准确度,对测试集环境有较高的要求;迁移学习同样期望提升模型在测试集上的准确度,但是允许测试集的样本分布与训练集不同。独立同分布学习和迁移学习都要求测试集样本分布已知。而稳定学习则希望在保证模型平均准确度的前提下,降低模型性能在各种不同样本分布下的准确率方差。理论上稳定学习可以在不同分布的测试集下都有较好的性能表现。


 



2. 基于本质特征的稳定学习




现有深度学习模型试图利用所有可观测到的特征与数据标签的相关性进行学习和预测,而在训练数据中与标签相关的特征并不一定是其对应类别的本质特征。如下图所示,左图为训练集中包含狗的图片集,其中大多数图片的背景是草地,深度模型会学习到草地的特征与标签“狗”之间的虚假关联,并以此关联为基础做出预测。所以在测试集中,模型对同样为草地背景的图片有良好的判断力(右图上);对非草地背景的图片判断准确度下降(右图中、下)

图2. 计算机视觉任务中的训练集(左图),测试集(右图)

深度稳定学习的基本思路是提取不同类别的本质特征,去除无关特征与虚假关联,并仅基于本质特征(与标签存在因果关联的特征)作出预测。如下图所示,当训练数据的环境较为复杂且与样本标签存在强关联时,ResNet等传统卷积网络无法将本质特征与环境特征区分开来,所以同时利用所有特征进行预测,而StbleNet则可将本质特征与环境特征区分开来,并仅关注本质特征而忽略环境特征,从而无论环境(域)如何变化,StableNet均能做出稳定的预测。

图3. 传统深度模型与深度稳定学习模型的saliency map,其中亮度越高的点对预测结果的贡献越大,可以看到两者特征的显著不同,StableNet更关注与物体本身而传统深度模型也会关注环境特征


目前已有的稳定学习方法多针对线性模型,通过干扰变量平衡(Confounder Balancing)的方法来使得神经网络模型能够推测因果关系[3][4]。具体而言,如果要推断变量A对变量B的因果关系(存在干扰变量C),以变量A是离散的二元变量(取值为0或1)为例,根据A的值将总体样本分为两组(A=0或A=1),并给每个样本赋予不同的权重,使得在A=0和A=1时干扰变量C的分布相同(即D(C|A=0) = D(C|A=1),其中D代表变量分布),此时判断D(B|A=0) 和D(B|A=1)是否相同可以得出A是否与B有因果关系。

而在计算机视觉相关的场景中,由于经卷积网络后的各维特征为连续值且存在复杂的非线性依赖关系,无法通过直接应用上述干扰变量平衡方法来消除特征间的相关性;另外由于用于深度学习的训练数据集通常尺寸较大,深度特征的维度也较大,所以无法直接计算出全局的样本权重。本文要解决的问题,就是如何在深度学习网络中找到一组样本权重,使得所有变量之间都可以做到互相独立,即任意选取一个变量为目标变量,目标变量的分布不随其它变量的值的改变而改变。

 



3. 基于随机傅立叶特征的深度特征去相关




去除特征间相关性的基本思路是干扰变量平衡,其基本原理如下图所示:

图4. 样本变量之间独立性函数(图左);神经网络优化公式(图右)

而深度网络的各维特征间存在复杂的依赖关系,仅去除变量间的线形相关性并不足以完全消除无关特征与标签之间的虚假关联,所以一个直接的想法就是通过kernel(核方法)将映射到高维空间,但是经过kernel映射后原始特征的特征图维度被扩大到无穷维,使得各维变量间的相关性无法计算。鉴于随机傅立叶特征(Random Fourier Feature, RFF)在近似核函数以及衡量特征独立性方面的优良性质[1] [2],本文采用RFF将原始特征映射到高维空间中(可以理解为在样本维度进行扩充),消除新特征间的线形相关性即可保证原始特征严格独立,如下图所示。

图5. 用于独立性检测的随机傅立叶特征(图左);StableNet网络与样本权重更新(图右)


 



4. 全局优化样本权重




上述公式要求在训练过程中为每个训练样本都学习一个特定的权重,但在实践中,尤其对于深度学习任务,要想利用全部样本全局地学习样本权重需要巨大的计算和存储开销。此外,使用SGD对网络进行优化时,每轮迭代中仅有部分样本对模型可见,因此无法获取全部样本的特征向量。本文提出了一种存储、重加载样本特征与样本权重的方法,在每个训练迭代的结束融合并保存当前的样本特征与权重,在下一个训练迭代开始时重加载,作为训练数据的全局先验知识优化新一轮的样本权重,如下图所示。

图6. 全局先验知识(图左);先验知识更新(图右)

StableNet的结构图如下图所示,输入图片经过卷积网络后提取得视觉特征,后经过两个分支。其中上方分支为样本权重学习子网络,下方分支为常规分类网络。最终训练损失为分类网络预测损失与样本权重的加权求和。其中LSWD为去相关样本权重学习模块(Learning Sample Weights for Decorrelation),利用RFF学习使特征各维独立的样本权重。

图7. StbelNet结构图

以识别狗的应用为例,如果训练样本中大部分的狗在草地上,少部分的狗在沙滩上,图片相应的视觉特征经样本重加权后各维独立,即狗对应的特征与草地、沙滩对应的特征在统计上不相关,所以分类器在预测狗是否存在时更容易关注与狗相关的特征(若关注草地、沙滩等特征会造成预测损失激增),所以测试时无论狗在草地上或沙滩上与否,StableNet均能依据本质特征给出较准确的预测,实现模型在OOD数据上的泛化。

图8. StbelNet训练流程


 



5. 含义更广泛的域泛化任务




在常规的域泛化(DG)任务中,训练集的不同源域容量相近且异质性清晰,然而在实际应用中,绝大部分数据集都是若干潜在源域的组合,当源域异质性不清晰或未被显式标注时,我们很难假定来自于各源域的数据数量大致相同。为了更加全面地验证StableNet的泛化性能,本文提出三种新的域泛化任务来仿真更加普适且挑战性更强的分布迁移泛化场景。

1)不均衡的域泛化


对于源域不明确的域泛化问题,假定源域容量相近过于理想化,一个更普适的假设为来自不同源域的数据量可能不同且可能差异巨大。在这种情况下,模型对于未知目标域的泛化能力更满足实际应用的需求。例如在识别狗的例子中,我们很难假定背景为草地、沙滩或水里的图片数量相同,实际情况下狗较多地出现在草地上而较少出现在水里。这就要求模型的预测不能被经常与狗一起出现的背景草地误导,所以本任务的普适性和难度显著高于均衡的域泛化。

使用ResNet18作为特征提取网络的实验结果如下表,在PACS和VLCS数据集上StableNet取得了最优性能。

表1. 不均衡的域泛化实验结果

2)部分类别缺失的域泛化


我们考虑一种挑战性更大且在现实场景中经常存在的情况,某些源域中有部分类别的数据缺失,而在测试集中模型需要识别所有类别。例如,鸟经常出现在树上而几乎不会出现在水里,鱼经常出现鱼缸里而几乎不会出现在树上,所以并不是所有源域都一定包含全部类别。这种场景要求更高的模型泛化能力,由于每个源域中仅有部分类别,所以域相关的特征与标签间的虚假关联更强且更易误导分类器。

下表为实验结果,由于对域异质性及类别完整性的要求,很多现有域泛化方法无法显著优于ResNet,而StableNet在PCAS,VLCS及NICO上均取得了最优结果。

表2. 部分类别缺失的域泛化实验结果

3)存在对抗的域泛化


一种难度更大的场景是任一给定类别的主导源域与主导目标域不同。例如,训练数据中的狗大多在草地上而猫大多在室内,而测试数据中的狗大多在室内而猫大多在草地上,这就导致如果模型不能区分本质特征与域相关特征,就会被域信息所误导而做出错误预测。下表为在MNIST-M数据集上的实验结果,StableNet仍显著优于其他方法,且可见随主导域比例升高,ResNet的表现显著下降,StableNet的优势也越发明显。

表3. 存在对抗的域泛化实验结果

4)NICO数据集简介


NICO是清华大学提出的一套用于图像分类数据集,适合验证模型在异分布环境下的性能,图中的数据包含了标注的背景信息,具体参考论文[5]。

图9. NICO数据集示例


 



6. 总 结




本文提出了一种基于样本重加权的深度网络框架StableNet,将稳定学习拓展到深度学习领域,并在一系列更广泛的域泛化实验中取得了当前最优效果。StableNet利用RFF及干扰变量平衡帮助网络动态地学习与类别相关联的本质特征,消除环境特征与样本标签间的虚假关联,已实现在不同未知目标环境中的稳定预测。

参考文献

1. Ali Rahimi and Benjamin Recht. Random features for largescale kernel machines. In Advances in neural information processing systems, pages 1177–1184, 2008.

2. Eric V Strobl, Kun Zhang, and Shyam Visweswaran. Approximate kernel-based conditional independence tests for fast non-parametric causal discovery. Journal of Causal Inference, 7(1), 2019.

3. Zheyan Shen, Peng Cui, Tong Zhang, and Kun Kuang. Stable learning via sample reweighting. In AAAI, pages 5692–5699, 2020.

4. Kun Kuang, Ruoxuan Xiong, Peng Cui, Susan Athey, and Bo Li. Stable prediction with model misspecification and agnostic distribution shift. In AAAI, pages 4485–4492, 2020.

5. Yue He, Zheyan Shen, and Peng Cui. Towards non-iid image classification: A dataset and baselines. Pattern Recognition, page 107383, 2020.

 

因果科学第二季读书会报名中


因果推断与机器学习领域的结合已经吸引了越来越多来自学界业界的关注,为深入探讨、普及推广因果科学议题,智源社区携手集智俱乐部将举办第二季「因果科学与CausalAI读书会」。本期读书会着力于实操性、基础性,将带领大家精读因果科学方向两本非常受广泛认可的入门教材。

1. Pearl, Judea, Madelyn Glymour, and Nicholas P. Jewell. Causal inference in statistics: A primer. John Wiley & Sons, 2016.(本书中译版《统计因果推理入门(翻译版)》已由高等教育出版社出版)

2. Peters, Jonas, Dominik Janzing, and Bernhard Schölkopf. Elements of causal inference: foundations and learning algorithms. The MIT Press, 2017.

读书会每周将进行直播讨论,进行问题交流、重点概念分享、阅读概览和编程实践内容分析。非常适合有机器学习背景,希望深入学习因果科学基础知识和重要模型方法,寻求解决相关研究问题的朋友参加。

目前因果科学读书会系列,已经有接近400多位的海内外高校科研院所的一线科研工作者以及互联网一线从业人员参与,吸引了国内和国际上大部分的因果科学领域的专业科研人员,如果你也对这个主题感兴趣,想要深度地参与,就快加入我们吧!

详情请点击:
连接统计学、机器学习与自动推理的新兴交叉领域——因果科学读书会再起航


推荐阅读



点击“阅读原文”,即可报名

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存