查看原文
其他

因果推断:因果表征学习的CV落地

Ostrich PaperWeekly 2022-07-06


©作者 | Ostrich
单位 | 阿里巴巴算法工程师
研究方向 | 自然语言处理/搜索算法

本文主要梳理因果推断与机器学习相结合的一些比较新的工作思路,也是尝试回答自己在学习因果推断基础知识时的一些疑问:“突然”被广泛谈及的因果可以以什么样的方式落地。文章主要介绍因果表征学习在 CV 领域的应用,后续也会学习一下其他领域的应用方案。

这里将直接从因果表征讲起,篇幅原因,不再介绍因果推断的基本概念和原理,速食可参考:
https://zhuanlan.zhihu.com/p/111306353

如果兴趣浓厚也可以参考另一个不错的系列:
https://www.zhihu.com/column/c_1217887302124773376



因果表征学习


因果模型并不能处理机器学习中常见底层的原始数据,例如图像、文本,而因果表征学习则可以将图像、文本这样的原始数据转化为可用于因果模型的结构化变量。其目标是从数据中学习任务的影响因子及其无偏的关联关系。换句话说是基于低级的观测数据学习高级的因果变量

因果表征学习不同于传统的机器学习,可以看作是一种不依赖 IID(独立同分布)假设的新学习范式,其潜在假设是:应用训练模型的训练数据和测试数据可能呈现不同的数据分布,但产生两者所涉及的因果机制(大多)相同。以此为假设,因果表征天然适合解决机器学习中小样本、样本不均衡、数据观测偏置等问题,下面将要介绍的工作也主要以此展开。


因果表征学习在CV领域的应用


接下来将介绍几篇因果表征学习在 CV 领域的应用,各篇文章从动机、因果干预方法、具体实现方法三块介绍。

2.1 Visual Commonsense R-CNN [1]

2.1.1 动机

作者认为观察偏差(observational bias)导致模型会倾向于根据共现信息做任务预测,而忽略一些常识性的因果关系。因此本文希望通过因果干预训练一个蕴含常识的视觉特征,新的视觉表征可以方便应用于如 VQA、Caption 之类的下游任务。这里的视觉表征期望能学到一些知识,如“椅子可以坐”,而不是简单的根据观察得到的“桌子和椅子”的共现关系。

2.1.2 因果干预



物体 A 出现情况下 B 也出现的真正原因可能被一些虚假的共现观察混淆,如“键盘”和“鼠标”和“桌子”一起出现的频率通常要高于其他物体,如果模型严重依赖于观察样本,“键盘和鼠标是计算机的一部分”这一基本常识将被错误地归因于桌子。

因此这里的混杂因子为数据集中出现过的全部物体结合 Z,Z 与物体 X 和 Y 存在一条后门路径,X<-Z->Y。要做的事情也很明确,通过 Do 算子进行后门调整,对 X 进行干预,切断 Z 和 X 之间的因果联系:



举个栗子:

不做干预情况下的 P(Y|X):数据集中出现“马桶”的地方有”人“的概率可以统计得到,因为比较私密,“人”和“马桶”的图像会偏少,因此 P(Person|Toilet) 的概率会偏小。


加了干预后 P(Y|do(X)):遍历 Z 集合中所有的物体,计算干预后的条件概率,可以看到 P(Y|do(X))>P(Y|X),直觉上在“人”和“马桶”观测数据不足的情况下,因果干预后的条件概率有相应正向调整。


2.1.3 实现

训练目标为预测指定 ROI 的类别,Loss 包括两部分任务,1)Self Predict:直接 ROI 特征 x 通过全连接层预测其 label,2)Context Predict:基于待识别物体 y 的 ROI 以及其上下文物体 ROI 特征,预测 y 的 label,一张图片中 K 个上下文物体特征求和取评价。完整 Loss 如下:



其中 ,而:


由于直接计算 需要比较大的采样计算成本,所以文章采用了Normalized Weighted Geometric Mean(NWGM)来做近似:



全连接网络实现方法有:



其中 为 y 对 Z 中所有 z 的 Attention 的加权平均:



其中 是先验统计值, 表示 A 和 Z 的 element-wise 乘,具体 A 和 Z 的来源细节如图。


2.2 Causal Intervention for Weakly-Supervised Semantic Segmentation [2]

本文将因果干预应用于弱监督的图像语义分割任务。语义分割任务需要对输入图像中的每一个像素都进行类别预测,识别出图像中物体的像素点,如图所示:

▲ 语义分隔任务


语义分割数据标注成本很高,相应的就有通过弱监督进行语义分分割任务优化的经典方法,大致流程如下:

1)首先通过多标签图像分类模型获取图像的类响应激活图(Class Activation Map)作为种子区域(Seed Area);2)在种子区域的基础上,通过计算像素之间的语义相似性对种子区域进行扩张(Expansion)得到图像的伪标签(Pseudo-Mask);3)使用伪标签作为 Ground-Truth 训练一个全监督的语义分割模型,并在训练好的模型上对 val/test 集合进行预测。



本文目标是通过因果推断方法优化伪标签生成的过程,即上图绿色部分。

2.2.1 动机

文章通过通过 image classification 模型进行 CAM 种子扩张的方法存在明显的问题,得到的伪标签各物体边界并不清晰,影响分割模型的训练效果。文章认为这里的上下文(其他物体、背景)是混杂因子,误导图像级别的分类模型学到了像素与 label 之间的伪相关关系。举例来说,图像分类模型预测“沙发”的时候也会考虑与“沙发”经常一起出现的“地板”特征,并将“地板”的特征也作为预测“沙发”的相关依据。因此本文的目标则通过因果干预的方法消除混杂因子(上下文特征)对模型的影响。


2.2.2 因果干预

下图(a)是文章“设计“的结构化因果关系图:

  • C:上下文先验,这里为数据集所有的类别特征表示的集合,各类别特征表示为数据集中各类别所有特征的平均。
  • X:输入图像。
  • M:输入图像 X 在 C 下的具体表示,可以理解为以 C 集合中的类别特征表示为基表示 X,如 X=0.12“bird” + 0.13“bottle” + ... + 0.29“person”。
  • Y:图像对应的标签。



由因果图可以看出 X 与 Y 存在以 C 为节点的后门路径,因此本文使用后门调整方法进行因果干预:



2.2.3 实现

具体实现是一个循环的过程:首先,通过初始化弱监督语义分割模型获取图像的 mask 信息;然后,构建 Confounder set 并去除 confounder;最后将去除 confounder 后的 M 拼接到下一轮的分类模型的 backbone 中以产生更高质量的CAM。产生的 CAM 又可以用来产生更高质量的 M,以此形成一个良性循环。如下图,P(c) 为先验的 1/n,T+1 时刻的 M 可以看作是一个 Confounder set 对 T 时刻 image Mask 的一个 Attention 表示。


2.3 Two Causal Principles for Improving Visual Dialog [3]

2.3.1 动机

主流的 VQA、VisDial 领域主流模型一般采用 Encoder-Decoder 结构,以 VisDial模 型为例:先是 Encoder 将 <I,Q, H> 编码为向量,然后利用 Decoder 解码得到 A(答案),Baseline 因果图如下:



其中 H 为历史问答、I 为图像、Q 为当前问题、V 是视觉知识表征、A 为答案。这样 VisDial 很像是带 History 的 VQA,而本文作者强调:VisDial 本质上并非带有 History 的 VQA,并提出两个因果原则给予修正:



  • P1:  删除链接 H—>A:直接使用 History 作为输入建模 A 时,模型会过度关注历史问答的词汇和句式,二者不应该直接链接。
  • P2:添加一个新的节点 U(标注偏好)和三条新的链接:U<—H, U—>Q, U—>A,答案标注者基于 History 进行标注会有一定偏好,倾向于标注历史问答中出现过的表达,是因果图中的混杂因子,且不可观测。


P1 和 P2 影响的实例如下,本文要做的事情也比较明确,基于新的因果图,使用因果干预剔除混杂因子以提升效果:


2.3.2 因果干预

本文的因果干预同样为切断 U 与 Q 和 A 的后门路径,消除混杂因子 U 的影响,后门干预方式推倒如下:


2.3.3 实现

先简化一下上面的公式方便描述:



其中



因为 U 无法观测,文中提出了三种近似方法:Question Type、Answer Score Sampling、Hidden Dictionary Learning。这里简单介绍相比之下效果较好的 Hidden Dictionary Learning(HDL)方法。

和《Visual Commonsense R-CNN》类似,HDL 方法使用数据集中最热门的 100 个答案作为 集合近似。由于模型的最后一层输入是 softmax 层,因此有 ,其中 是候选 的 embedding, 集合中采样, 是 {Q, I, H} 的联合表示 embedding(encoder 的输出)。 为各参数的网络计算。


通过 NWGM 近似有:


这篇文章里 ,由于期望计算线性可加,上述公式可进一步计算为:



本文实现里, 使用内积 Attention 计算得到,,其中 ,h 是 History 的 embedding, 是元素乘计算, 是模型参数权重。

2.4 Interventional Few-Shot Learning [4]

2.4.1 动机

这里将因果干预的思路用到了小样本学习,文章首先梳理了目前主流的小样本学习方法:1)预训练,然后小样本任务微调;2)预训练后通过 meta-learning 等技术进一步进行小样本学习。



两种方式都依赖于预训练策略。而文章认为基于预训练数据虽然量大,但如果存在与小样本数据集分布不一致的情况下,反而会因为预训练数据集的观测偏差,在小样本微调过程带来一些错误归因。换句话说,预训练在带来丰富的先验知识的同时,这是成了学习过程中的一个混杂因子(confounder)。因此文章希望通过因果干预剔除混杂因子以提升小样本学习的效果。



2.4.2 因果干预

本文建立的因果图以及调整方法如下:



  • 为预训练的先验知识(如“草”、“狮子”、“车”等表示信息), 为图片的特征表示。
  • :其中 代表一个样本 在预训练数据流形上面的投影,可以理解为以 中的知识为向量基表示
  • :分类器预测标签 时,会不可避免的使用 里面的信息,即 则是 里没有包含的冗余信息带来的影响。

相应的,干预方法依然也是通过 do 算子进行后门调整:



2.4.3 实现

利用 do 算子对 X 进行干预即对 D 进行分层计算条件概率并加权求和,文章提出了三种实现方法:

1)基于特征维度分层:将特征向量 X 分成 N 份,每一份对应一个分类器, 即为 N 个分类器输出概率的平均,如 ResNet-10 的 512 维分层 4 份,学习四个分类器。

2)基于类别的调整:所有 M 个类别的预先算一个表示 (数据集中所有该类别特征的平均);用预训练的模型预测 X 属于各个类别的概率 ,以 加权平均得到 c 的表示 ,而后同 x 和 c 的特征拼接来训练分类器。

3)二者结合。

本文方案相较前面提到的方案,做了一些简化的假设,选择了简单的实现方法。不过这里的近似方法也比较有意思,比如基于特征维度分层,也可以解释为 4 个小模型的融合,或者 multi-head 结构,又像是已知方法的因果角度的一种解释。

2.5 Long-Tailed Classification by Keeping the Good and Removing the Bad Momentum Causal Effect [5]

2.5.1 动机

如题,本文将因果推断应用于数据不均衡的长尾分类问题。文章先介绍了样本不均衡带来的预测偏置问题,而后分析了常用解决长尾问题的方法,如 re-sampling、re-weighting 以及目前比较优秀的 2-stage 模式的 Decoupling 方法(原始数据的长尾分布用 backbone 学,分类器用 re-balancing 后的数据学)。



作者希望用更简单的方法 end2end 地实现长尾问题的优化,指出优化器的动量项在训练数据时会引入数据分布(混杂因子),是模型对头部类目偏好的一大原因。但直接去除动量项,又会使模型收敛难度增加。因为动量可以大大提升训练的稳定性,其带来的好处要高于其带来的损失。由此,文章另辟蹊径,提出用因果推断中的技术,尝试在保持动量项的同时,训练阶段引入因果干预,并在测试阶段进一步剔除预测偏置,做到取其精华,去其糟粕。

▲ 带动量的梯度更新

2.5.2 因果干预

本文建立的因果图如下:


其中 M 就是优化器的动量,X 是模型提取的特征,Y 是预测值。D 是 X 在动量特征下的带偏置的表示(基于 M 和 X 的特性 D 更偏好头部类别)。优化器的动量 M 包含了数据集的分布信息,他的动态平均会显著地将优化方向倾向于多数类,这也就造成了模型中的参数会有生成头部类特征的倾向,该部分偏好则体现在 D 中。

类似的,消除混淆因子 M 带来的头部类别偏置的方法也比较明确,即通过后门调整进行因果干预:

训练阶段优化 ,以断开 M->X 之间的路径:


模型训练好后的预测阶段,由 得到最终预测结果,文中中 为 0 向量,其子项表示 M 对头部的偏置项,二者相减表示剔除了头部偏置的预测结果。


2.5.3 实现

训练时,后门调整需要计算:



由于无法得到 M 的分布,文章使用带 normalize 的 multi-head 来近似,这点类似 [4] 中的“基于特征维度分层”。具体的,将分类器输出维度等分成 K 份,可以看作是对模型进行了更细粒度的采样,而 P(M) 近似为 1/K,因此有:



从表达式上来看可以近似的理解为输入特征 x 和类别 的表示权重向量 的余弦相似度(表示该类别的置信概率),同时需要注意的是这里 x 依赖 m,而在公式中没有明示。

预测阶段:



公式比较复杂推导过程论文中有详细推倒过程,其意义是从 training 的 logits 中剔除代表对头部类过度倾向的部分。前者为 后者为 为协调因子进行 trade off。



小结


几篇论文读下来,感觉各文章的思路有不少共通之处(可能也与论文师出同源有关,哈哈)。大致“套路”是:

1)从数据和现象出发,分析基线模型的因果图假设和混杂因子;2)建立新的因果图(opt);3)因果干预;4)近似 do 算子网络计算;5)效果提升。话说回来,虽然脉络大致相同,但每个工作也都能给人新奇的感觉,尤其 [5] 想到从优化器的角度切入,实在佩服。

由于本人也是初次了解该领域,文章读下来应该有一些理解不到位之处,后续也会随着理解的深入再回来修正,同时,后续也会花时间整理一下因果推断其他研究领域的应用案例。


参考文献

[1] Wang T, Huang J, Zhang H, et al. Visual commonsense r-cnn[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2020: 10760-10770.

[2] Zhang D, Zhang H, Tang J, et al. Causal intervention for weakly-supervised semantic segmentation[J]. arXiv preprint arXiv:2009.12547, 2020.

[3] Qi J, Niu Y, Huang J, et al. Two causal principles for improving visual dialog[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2020: 10860-10869.

[4] Yue Z, Zhang H, Sun Q, et al. Interventional few-shot learning[J]. arXiv preprint arXiv:2009.13000, 2020.

[5] Tang K, Huang J, Zhang H. Long-tailed classification by keeping the good and removing the bad momentum causal effect[J]. arXiv preprint arXiv:2009.12991, 2020.

[6] Yang X, Zhang H, Cai J. Deconfounded image captioning: A causal retrospect[J]. arXiv preprint arXiv:2003.03923, 2020.

[7] https://zhuanlan.zhihu.com/p/111306353

[8] https://zhuanlan.zhihu.com/p/260967655

[9] https://zhuanlan.zhihu.com/p/356278102

[10] https://zhuanlan.zhihu.com/p/260876366

[11] https://zhuanlan.zhihu.com/p/259569655

[12] https://zhuanlan.zhihu.com/p/359033591


更多阅读




#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编




🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧



·

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存