从2021年多篇顶会论文看OOD泛化新理论、新方法和新讨论
©PaperWeekly 原创 · 作者 | 张一帆
学校 | 华南理工大学本科生
研究方向 | CV,Causality
arXiv 2021
论文标题:
Towards a Theoretical Framework of Out-of-Distribution Generalization
论文链接:
https://arxiv.org/abs/2106.04496
这篇文章应该是今年投稿 NeurIPS 的文章,文章贡献有两点:
在 OOD 泛化受到极大关注的今天,一个合适的理论框架是非常难得的,就像 DA 的泛化误差一样; 本文通过泛化误差提出了模型选择策略,不单纯使用验证集的精度,二十同时考虑验证集的精度和在各个 domain 验证精度的方差。
1.1 Preliminary
1.2 Framework of OOD Generalization Problem
作者先介绍了两个定义:特征的 “variation(变化)”和 “informativeness(信息量)”。前者是一个类似于 divergence 的概念,我们希望对同一个 label,在各个域上的特征变化不大。后者表示了这个特征要有足够的表示能力,包含了区分各个标签的能力。
Variation:给定如下定义,如果一个特征满足 ,那么我们说他是是 -invariant 的: Informativeness:给定如下定义,如果一个特征满足 ,那么我们说他是是 -Informative 的:
Expansion Function:这是一个函数 ,如果它满足:1)单调递增且 ;2),我们称之为一个扩增函数。
有了这三样东西,我们来定义最后一个最重要的概念
Learnability:对所有满足信息容量 的特征提取器而言,如果存在上述的 和一个扩增函数 ,使得 我们称一个 OOD 问题是可学习的。
1.3 Generalization Bound
看到这里可能有人疑惑了,上下界都和 variation 有关,但是和 Informativeness 无关,那我输出全 0 向量不就可以做到 invariant 了吗?答案是否定的,在 bound 的证明中总是假设该问题满足 Learnability,而 Learnability 关键的一点就是限制信息容量大于一个定值。
1.4 Variation as a Factor of Model Selection Criterion
本文中提出了一种新的模型选择策略,如果我们按照验证集的总体精确度来选择最终的模型,其实没有几个模型比 ERM 好很多,这一结果并不奇怪,因为传统的选择方法主要关注(验证)准确性,这在 OOD 概化中有偏倚。
ICML 2021
论文标题:
Can Subnetwork Structure be the Key to Out-of-Distribution Generalization?
论文链接:
https://arxiv.org/abs/2106.02890
给定数据,完整的网络,子网络的 logits ,logit 是一个用于产生 mask 的随机分布,比如网络第 层有 个参数,那么 。该层的 mask 通过从 中采样得到,mask 将完整网络转化为子网络(= 0 即忽略第 层的第 个参数); 我们对模型进行初始化然后使用 ERM 的目标进行训练 个 step; 我们从整个网络中采样子网络,结合交叉熵和稀疏正则化作为损失函数来学习有效的子网结构; 最后只需要简单地只使用所得到的子网中的权值重新进行训练,并将其他权值固定为零。
ICLR 2021
论文标题:
Understanding the failure modes of out-of-distribution generalization
论文链接:
https://arxiv.org/abs/2010.15775
代码链接:
https://github.com/google-research/OOD-failures
3.1 Motivation
现有的理论可以解释为什么当不变性特征本身信息不足时,分类器依赖于虚假特征(下图 a)。但是,当不变特征完全能够预测标签时,这些解释就不成立了。
3.2 Easy to learn domain generalization tasks
不变性特征和虚假特征都只能部分预测标签,因此一个优化负对数似然的分类器当然不能错过虚假特征包含的信息; 不变性特征和虚假特征都能完全预测标签,但是虚假特征更容易学习(更加线性),因此梯度下降会选择更容易学习的特征进行分类。
本文对这些假设进行了质疑,构造任务时针对以上每一点进行了回应,其任务有以下特点:
不变性的特征有足够的能力完成对标签的预测,虚假特征不能完全预测标签; 不变性特征有一个线性的分类 boundary,很好学习。
多数类 ,对应 cow/camel 在 green/yellow 背景下。 少数类 ,对应 cow/camel 在 yellow/green 背景下。
作者通过观察发现,即使我们的数据集不存在 geometric skew,即 max-margin 分类器不会失败,我们花费超长时间训练一个线性分类器使他收敛,他依然会依赖于虚假特征。作者在文章推导出了一个收敛性随伪相关而变化的 bound 来讨论使用梯度下降训练的过程中引入的伪相关。
总结一下,目前大部分注意力都集中在实用主义或启发式解决方案(设计或学习“不变”特性的各种技巧)上,而我们对 OOD 情况中出错原因的基本理解仍然不完整。本文旨在通过研究简化的设置来填补这些理解上的空白,并提出这样一个问题:当任务可以只使用安全的(“不变的”)特性来解决时,为什么统计模型要学习使用易变化的特性(“虚假的”特性)。在制定了多个约束条件(保证对容易学习的任务适用)后,他们表明失败有两种形式:几何倾斜和统计倾斜。他们依次进行分析和解释,同时也提供了说明性的实证结果。
ICML 2021 Oral
论文标题:
Domain Generalization using Causal Matching
论文链接:
https://arxiv.org/abs/2006.07500
代码链接:
https://github.com/microsoft/robustdg
这篇文章乍一看非常简单,但是细看之后发现其实有很多地方理解起来并不容易。
这篇文章的主要贡献在于:
作者 argue 了一件事情,我们以往学习的不变性特征表达包括 与 domain 无关还是 与 domain 无关其实都是有问题的,根据文中假设的因果图来看,要真正捕捉到域不变特征,我们需要约束 不变,其中 是图像的 object 信息。 作者加了一项看着很简单的约束:拥有相同的对象(object)的跨域图像也应该有相同的表示。
作者证明了:
满足上述约束的分类器中包含了最有分类器; 在具有虚假相关性的数据集中,优化如下的损失函数能够带来最优分类器。
到这里文章的内容好像已经完整了?其实不然,考虑一个数据非常不平衡的数据集,一个 domain 中拥有超多 object A,其他 domain 基本没有,那么上述的 match 其实是在不断地减小同一个 domain 下同一类的特征距离,这对泛化是没有太大好处的。
对于 Rotated MNIST 这类的数据集,因为是通过数据增强的方式构造的,因此非常的 balance,但是对于更加真实的数据集,这个关系显然是不成立的,这就是我对于文中 object information is not always available, and in many datasets there maynot be a perfect “counterfactual” match based on same object across domain 这句话的理解。
那么如何避免我们对 class-balance 的过度依赖,在没有非常好的 counterfactual sample 的情况下也能近似上述的约束呢?答案是学习一个 matching,这才是文章的关键。
具体的实现过程是这样的:
Initialization(构造 random match):首先我们对每一个类选择一个基域(包含该类元素最多的类),对基类的所有数据点进行遍历。对每个数据点,我们随机的在剩下 K-1 个域中给他匹配标签相同的元素,因此会构造出一个 (N',K) 大小的数据矩阵,这里 N' 即所有类的基域大小之和,K 是总共的域的数目。
Phase 1:采样一个 batch 的数据 (B,K),对 batch 中的每个数据点最小化对比损失,和他具有相同 object 不同域的样本作为正样本,不同 object 样本作为负样本。
每 t 个 epoch 使用通过对比学习学到的 representation 更新一次我们的 match。首先还是要选基域,但是在基域选定后,我们不再随机的在剩下域中挑选 sample,我们为基域中的该类的每个样本在其他域中找 representation 距离最近的点作为正样本。
简单看一下实验效果,对 MNIST 类的任务,存在 perfect match,效果非常显著。
论文标题:
Environment Inference for Invariant Learning
论文链接:
https://arxiv.org/abs/2010.07249
代码链接:
https://github.com/ecreager/eiil
没有 domain label 怎么做 OOD 泛化?这篇文章就回答了这样一个有趣的问题。给出的答案也非常的 interesting:我们自己推断 domain label 甚至能达到比使用真实域标签更好的性能。
首先文章的 motivation 在于,无论是从隐私还是标签的获取来看,域标签都是难以取得的。除此之外,在某些情况下,相关的信息或元数据(例如,人的注释、用于拍摄医疗图像的设备 ID、医院或部门 ID 等)可能非常丰富,但目前还不清楚如何最好地基于这些信息指定环境。设计算法避免人工定义环境是这篇文章的出发点。
所以很直观的,算法应该分成两部走:
推断环境标签;
利用环境标签学习域不变性特征。
在第一步推断标签的时候,我们选择最违背域不变特征的标签分配方式,分配标签使得 IRM,GroupDRO 这些算法的分类性能最差。即固定住模型 ,然后优化 EI(environment inference EI)目标,估计标签变量 最违背域不变特征。 固定住我们 inference 的标签 ,优化 invariant learning(IL)目标来产生新模型 。
论文标题:
Reducing Domain Gap by Reducing Style Bias
论文链接:
https://arxiv.org/abs/1910.11645
代码链接:
https://github.com/hyeonseobnam/sagnet
CNN 对图像纹理这类的风格元素具有很强的归纳偏置,因此对域变化非常敏感。相反其对物体形状这类真正和标签相关的元素却不敏感。本文提出了一种将 style和 content 分离开的简单方法,可以作为一种新的 backbone。
文章结构非常简单,一个 feature extractor 两个 head。content-bias head 想要做的事是将 style 信息打乱,同时还确保分类结果正确,也就是让这个 head 更关注于 content 信息。相反 style-bias head 将风格信息打乱,让这个 head 更关注于 style 信息,与此同时一个对抗学习就可以让 backbone 产生更少的 style-bias representation。
看到这里其实难点已经很明确了,如何将 style/content 信息打乱? 文章基于这样一个假设,channel-wise 的均值和方差作为风格信息,spatial configuration 作为 style 信息,这样一个假设已经被以往很多工作采用了,不过本文提出了一个更新的使用方式。首先我们来看如何打乱 style 信息。
最大熵其实是具有很不错的性质的,在我最近的一篇工作中我简单的分析了这个类型的损失函数,他能起到风格信息和 representation 互信息最小化的作用。
https://arxiv.org/abs/2103.15890
文章选择的 baseline 其实并不多,也没有 resnet50 这种大型 backbone 的结果,但是从文中展示的内容来看,SagNet 相比于现有的大多数方法还是有一定优势的。对我而言我觉得难得的是,它提供了一种 style/content 信息新的提取方式,以往的工作往往需要两个 encoder 来提取 content/style 信息。
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:hr@paperweekly.site
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧
关于PaperWeekly
PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。