查看原文
其他

漫谈图上的分布外泛化:不变性视角下的求解

吴齐天 PaperWeekly 2023-02-02


©作者 | 吴齐天
单位 | 上海交通大学博士生
研究方向 | 机器学习与图深度学习


本文旨在深入浅出的介绍图上的分布外泛化问题(一个最近刚火的研究方向)与基于(因果)不变性原理的求解思路,对相关领域研究者提供 easy-to-follow 的讲解。本文内容主要基于今年年初笔者发表于 ICLR‘22 的论文《Handling Distribution Shifts on Graphs: An Invariance Perspective》。

原文链接:https://zhuanlan.zhihu.com/p/580112987


论文标题:
Handling Distribution Shifts on Graphs: An Invariance Perspective

论文链接:

https://openreview.net/pdf?id=FQOC5u-1egI

代码链接:

https://github.com/qitianwu/GraphOOD-EERM


这项工作首次对图上的节点级任务的分布外泛化问题给出了一般化定义,并基于不变性原理给出了有理论保障的解决思路。文末还会简单介绍笔者合作参与的三个刚被 NeurIPS‘22 接收的相关工作,并讨论可以进一步探索的方向。





写在前面


图机器学习目前依然是炙手可热的研究领域,但不少的已有方向都遇到了瓶颈期。本文将要重点介绍的分布外泛化问题(Out-of-distribution Generalization,简称 OOD 泛化)也为图学习引入了一个新的子赛道,与现有的很多场景和设定都存在可能的交叉,目前有很大的研究空间。




为什么要考虑分布外泛化的问题?

如何提高在新数据(例如未知分布或未见实体)上的泛化性能是机器学习的一个核心问题。我们知道一般的学习问题都是在一个训练集上完成模型训练,而后模型需要在一个新的测试集上给出结果。机器学习问题的误差可以被大致分解为两部分:


其中表征误差(反映了模型拟合训练数据的能力)是由模型的表达能力/容量决定的,而泛化误差则由在训练集与测试集模型表现的差异决定。当我们采用较为复杂的模型结构(例如神经网络)与有效的优化算法,可以大大降低表征误差。但是当测试数据分布与训练分布呈现明显不同时,模型的泛化误差则很难被控制。

这样的场景在实际中也很常见,比如在线下数据进行训练的推荐模型需要泛化到线上的真实场景,在模拟场景下训练的驾驶器要泛化到具有真实交互的环境中。这就是分布外泛化要解决的核心问题:如何利用有限观测的数据,学习一个稳健的模型,能够泛化到与训练分布有明显差异的测试数据上。


▲ 图数据分布偏移的典型场景:1)节点来自于不同domain(如社交网络中的用户对应不同的社区),2)图中的节点/连边动态变化(如随时间不断扩张的引用网络)




图上的节点级分布外泛化的挑战

目前大部分关于分布外泛化问题的研究集中在欧式数据(如图片),而对于图结构数据的相关研究还较少。与普通欧式数据不同的是,图结构数据上节点级预测任务的分布偏移问题需要解决两个核心的技术挑战。

1. 样本互连性:由于节点的互连特性,数据样本通常是非独立同分布的,这就为数据分布的建模带来了困难。下图给出了一个简单示意,对于图片数据我们可以把生成每张图片的分布看作相同且独立的;然而对于图结构数据,每个节点的生成依赖于邻居节点,数据分布不能被看作独立的。

2. 图结构信息:除了节点特征外,图的结构也蕴含了重要的信息,会影响到表示学习和预测任务。因此,在考虑数据分布建模与模型泛化的时候,也需要挖掘结构信息的特征并兼顾其影响。

▲ 图片数据与图结构数据的不同:图片可以看作独立产生的样本,而图中的节点(即样本)是非独立的。



从问题定义出发


“磨刀不误砍柴工“,在解决这一问题前我们先对其形式化定义,主要对图数据是如何生成过程作一个描述,然后将图上的分布外泛化问题用数学语言表达出来。假设输入数据是一个图 ,它包含了两部分信息:输入邻接矩阵 和节点特征 这里 表示节点的集合。
此外,每个节点对应一个标签,所有节点的标签组成了一个向量 我们定义 表示输入图的随机变量( 是它对应的一个具体采样),而 是标签向量的随机变量(同理 是对应的一个具体采样)。此外,我们引入一个环境变量 (称为 environment),它表示与数据生成相关联的某种上下文信息(例如时间、地点、平台、主题等)。于是,图数据的生成过程可以由联合分布的展开进行描述:

然而上述的定义方式不方便对每个节点进行分析,因此下面我们考虑一种以节点为单位的定义。

以节点邻居子图为单位的建模:将输入的图以节点为单位(在节点级任务中每个节点就是一个训练/测试样本)分解为一系列子图。具体的,假设 表示节点的随机变量,定义节点 阶邻居内的节点集合为 (这里是任意的正整数)。 中的节点形成了一个子图 ,它包含了一个(局部)节点特征矩阵 和一个(局部)邻接矩阵 。同样定义 为子图的随机变量而 是其具体采样。定义 是节点标签的随机变量, 对应具体的采样。
由此,我们将输入图分解为一系列(有重叠)子图的集合 ,这里我们可以将 视为模型(例如图神经网络)的输入, 是输出。当 给定后 与图中其他节点可以视为独立的,因此 可以被分解为 个独立相同的分布的乘积,即
基于上述定义,我们可以把观测数据 从数据生成分布 的采样生成过程看成两步:1) 首先采样一个完整的输入图 ,而它可以被视作一系列(有重叠)子图的集合 ;2) 接着对图上的每一个单一节点,采样其标签 。下面我们给出图上的分布外泛化问题的数学定义。
分布外泛化问题的形式化定义:给定训练数据 (其数据分布为 ),模型的目标是最终泛化到新的测试数据 (其数据分布为 )。更一般的,我们定义 表示环境变量 的取值集合, 是模型的预测函数即 是损失函数。于是,最终目标可以写为:
它表示我们希望模型 能够在最坏数据分布(对应损失最大的环境/domain)下依然给出稳健的结果。然而,上式定义的是一种理想情况,由于训练数据的有限无法遍历 中所有可能出现的环境,实际中只能通过设计有效的学习算法不断逼近 (1) 式的目标。下面则介绍一种本文探索的基于不变性原理的分布外泛化方法。



基于不变性原理的分布外泛化

直接解决上述的问题是非常困难的,因为模型在没有结构性假设和对学习任务的先验知识的情况下往往是不可能实现分布外泛化的(没有免费的午餐)。为此,本文从数据生成的角度,通过利用数据背后的因果不变性 [1,2,3],来引导模型学习到可以实现泛化的映射关系。
在进入技术细节之前,我们首先考虑一个具体的例子作为前序铺垫。我们考虑一个引用网络,每个节点表示一篇论文,每条连边表示论文之间的引用关系。每个节点有两个特征——论文发表的会议 与论文的影响力 ,标签 是论文的主题,环境 是论文发表的时间。我们可以将上述变量的因果关系表示为下图:


▲ 引用网络的例子。图中的三个因果依赖关系可以作如下理解:1)x1→y:论文发表的会议会决定论文研究的主题;2)y→x2:论文的影响力往往与论文的主题有关;3)e→x2:论文的影响力还与论文发表的时间有关(研究方向的流行度会随时间变化)。
在这个例子中, 会同时与 有关。也就是说,当环境发生变化时(对应于数据采样的分布发生了变化), 之间的关系也会发生变化。因此,如果模型在训练集上学习到了这部分关联 性,当迁移到测试集后就不能获得令人满意的结果(因为环境的改变导致了 关系的改变)。
相反的,如果模型在训练集中学习到了 的关系, 就能够成功迁移到测试集 (因为就算环境发生了改变, 之间的关系是稳定不变的)。这个例子提供的启发是:我们可以引导模型学习与 环境无关的关系(具体表现为当环境发生变化时,从 x 到 的关系保持不变),就能够帮助其泛化到新的环境中。这就是所谓的不变性原理 (Invariance Principle)
具体的,我们希望学习一个分类器 , 它能够从 中学到相对环境不变的表征 ,即 (表示 对应的随机变量)需要满足以下两点要求:
环境不变性 (Invariance):对于任意的环境 ,分类器给出的预测分布保持不变,即 (给定表征 ,环境与预测标签独立)。
预测充分性 (Sufficiency):表征包含的信息足够预测标签,即存在一个从 的映射 ,使得 其中 是一个随机噪声。
受以上思路的启发,我们把学习目标定义为在不同环境上对应风险损失的均值和方差:

这里定义 是一个权重超参数。这一目标的直观考虑是如果模型在不同的环境下能够给出相近的结果(即 loss 方差最小化),其学到的从 的映射就是相对环境不变的。这也有别于传统的监督学习方法 Empirical Risk Minimization(ERM),即只对每个样本的 loss 的均值进行优化,这种情况下模型就很容易学到与环境相关的映射,在训练数据上发生过拟合(对于 ERM 的局限性分析感兴趣的读者,可以进一步参考我们的论文)。
然而,上式则要求训练数据中包含来自多个环境的观测数据,并且每个数据样本对应的环境 id 也是已知的。对于图结构数据,尤其是节点级任务,这两个要求都是不满足的。通常情况下,训练数据只包含了一整张大图,也没有足够的每个节点对应哪个环境的信息。
为了解决这一困难,我们引入 个额外的数据生成器 ,基于输入图生成 份不同的图数据 来探索环境,模拟来自不同环境的观测数据。基于此,我们考虑如下的双层优化学习目标:

这里我们定义每个图数据所对应的损失函数 。针对数据生成器 ,我们将其参数化为一个图结构编辑器 (graph editor),即将每一条连边假设为自由参数,对输入图进行局部改变(删除或增加连边)。

具体的,我们将每一个改变视为动作 (action),最终使用基于策略梯度的 REINFORCE 算法进行优化,以解决离散动作空间采样不可导的问题。我们将本文提出的方法称为 Explore-to-Extrapolation Risk Minimization(EERM),下图给出了训练过程的数据流图。


▲ 提出的环境探索-风险外推最小化(Explore-to-Extrapolate Risk Minimization,简称EERM)


理论分析

为了证明提出方法的有效性,该工作给出了几点理论分析。这里将主要结论整理如下(对此部分感兴趣的读者可以阅读论文):


1. 提出的方法 EERM 可以引导模型产生的预测分布学习到稳定的从输入特征到标签的映射关系,从而在理论上保证取得理想的分布外泛化问题的最优解(由 (1) 式给出)。


2. 当模型给出的节点表示在训练集和测试集上具有相同的表达能力(具体量化为输入与输出包含在表示向量中的信息),本文提出的 EERM 可以降低测试分布上的泛化误差上界。




实验结果


为了进一步验证提出的方法,我们需要设计实验,测试模型在不同数据分布上的性能。真实的图数据中可能包含多种不同的分布偏移,这里我们考虑三种情况:人造混淆噪声(Artificial Transformation)、跨图领域迁移(Cross-Domain Transfer)、动态图时序泛化(Temporal Evolution)。下表展示了本文使用的 6 个数据集以及对应的分布偏移的形式。



处理人造混淆噪声:我们首先考虑 Cora 和 Amazon-Photo 数据集,对其引入噪声,方法如下:采用两个随机初始化的 GCN,第一个 GCN 基于原始节点特征生成节点真实标签,第二个 GCN 基于节点标签和环境 id 生成冗余特征,于是节点的特征为原始特征和冗余特征的拼接。对每个数据集,我们将环境 id 设为 1-10,总共生成 10 张图,第一张用于训练,第二张验证,其余的作为测试。如此下来,训练集与测试集之间就被引入了分布偏移,原始特征与标签的关系是对于环境不变的,而冗余特征与标签的关系则是环境敏感的。


我们考虑使用 GCN 作为预测模型主干,下图分别显示了使用传统方法(Empirical Risk Minimization,ERM,即直接优化训练数据的损失)与本文提出方法(EERM)在 Cora 和 Amazon 数据集上 8 个测试图的准确率(Accuracy)对比。这里,我们重复了 20 次实验(使用不同网络初始化),展示了准确率的分布情况。可以看到,EERM 在绝大多数情况下好于 ERM。


▲ Cora数据集的8张人造OOD测试图

▲ Amazon数据集的8张人造OOD测试图


跨图领域泛化:一种典型的分布外泛化场景是图数据上的领域泛化(Domain Generalization)。这里我们考虑 Twitch-Explicit 和 Facebook-100 数据集,它们都是社交网络,分别包含了 7 张和 100 张子图。我们使用一部分图作为训练集,另一部分作为测试。由于每一张子图都是来自不同地区的社交网络,而且大小、密度、标签分布都不尽相同,因此训练数据与测试数据就天然存在分布偏移。 


对于 Twitch 数据集,我们使用子图 DE 作为训练集,ENGB 作为验证集,其余作为测试集。由于是二分类问题且类别标签不均衡,所以我们使用 ROC-AUC 作为评测指标。下图显示了分别使用 GCN、GAT、GCNII 作为网络主干,ERM 与 EERM 在 5 个测试图上的性能对比。可以看到,EERM 在大部分情况下都超越了 ERM。


对于 Facebook 数据集,我们考虑使用多个图进行训练。具体的,我们考虑三种训练子图的组合。下表显示了使用不同训练子图的组合,在三个测试图(Penn,Brown,Texas)上的准确率对比。同样,EERM 在绝大部分情况下超越了 ERM。


动态图时序外推:另一种典型的分布偏移来源于时序动态图,训练数据往往是历史某个阶段收集的片段,测试数据则来源于未来。随着时间的推移,图数据可能发生变化。这里我们进一步考虑两种不同的情况。


第一种情况对应动态的时序 snapshot,我们考虑 Elliptic 数据集,它一共包含 49 个 graph snapshot,每一个记录了在一段时间内的金融交易,任务是识别网络中的非法节点。我们把 snapshot 按时间顺序排列,使用前 5 个作为训练,第 6-10 个作验证,其余的作为测试集(把每相邻的 4 个合并为一组)。


我们使用 F1 分数作为评测指标,下图显示了使用 GraphSAGE 和 GPRGNN 作为主干模型的效果对比。可以看到,EERM 显著好于 ERM,取得了平均 9.6%/10.0% 的提升。



接着我们考虑第二种情况,随着时间的推移,图中的节点和连边会发生变化。这里我们考虑 OGBN-Arxiv 数据集,其中每个节点是论文。我们按论文的发表时间将节点分为训练集和测试集。为了引入分布偏移,我们扩大训练节点和测试节点的时间间隔:使用 2011 前发表的论文作为训练集,2011-2014 年发表的论文作为验证集,2014 年之后的为测试集。


下表展示了时间在 2014-2016/2016-2018/2018-2020 年的测试节点上的测试准确率。可以看到,随着时间的推移(分布偏移进一步扩大),模型的性能都呈现下降趋势,但 ERM 的下降趋势更为明显。这也说明,EERM 能够有效提升模型对分布偏移的鲁棒性。




讨论与展望


图级别预测与节点级预测的联系与区别:近期也有不少工作关注神经网络在图结构数据上的分布外泛化/外推问题,例如 [4, 5, 6]。然而,他们主要专注于整图级别(graph-level)任务,有别于本文主要关注的节点级(node-level)任务。整图级别任务与节点级任务所关注的重点与技术难点是不同的。


▲ 节点级任务(本文主要关注)每个节点是一个样本(x, y),样本具有互连特性,不能视为独立采样(本文的思路是将连接样本的一张大图看作以每个节点邻居子图为单位的新的“独立”样本)。图级别任务每张图是一个样本(x, y),此时样本可以看作独立产生的。


对于图级别任务的分布外泛化问题,可以采用如下定义(将式 (1) 修改为):



这里 表示一个图样本(如分子图),分类器以每张图作为输入预测它的标签(例如分子的性质)。希望进一步了解如何利用不变性原理求解整图级别任务(分子图预测)下的分布外泛化问题,可以阅读笔者参与的一篇刚被 NeurIPS 2022 接收的论文 [7],相关方法也在 OGB 和 DrugOOD 标准 benchmark 上取得了 SOTA 效果。

更深层的图数据生成过程:图数据本身包含的结构拓扑信息是其不同于一般欧式数据的特性之一,因此在考虑图上的分布偏移时也需要对观测数据背后隐含拓扑进行考虑和建模。比如一个常见的场景是训练和测试在不同的图数据上,模型训练的图是一个完整的观测图,测试时的图拓扑发生了改变(例如节点、连边变化)。

笔者参与的另一个 NeurIPS 22 论文 [8] 就对此类图拓扑发生偏移的情形下如何提升模型泛化能力进行了探索,主要思路是从热传导过程与图神经网络的等价关系出发,挖掘拓扑背后的几何特性,引导模型学习对图的观测结构变化保持不变的映射关系。


分布外数据的判别:另一个与本文高度相关的问题是如何对分布外数据进行识别或检测。在本文所讨论的问题设定下,分布外数据只出现在了测试阶段。而现实中分布外数据也可能存在于训练集中,一个需要解决的问题就是如何识别与训练主体数据(分布内数据)有明显差异的分布外数据,帮助提升模型可靠性。


针对这种情形,另一个 NeurIPS 22 的工作 [9] 从数据生成过程出发提出了一个统一框架处理两个问题:1)如何识别训练集中的分布外数据;2)如何判别测试阶段模型未见的分布外数据。OOD 判别与 OOD 泛化本身存在很多交集,也期待后续更多的工作对其进行补充和探索。


参考文献

[1] Mateo Rojas-Carulla, et al. Invariant models for causal transfer learning. In Journal of Machine Learning Research (JMLR), 2018.

[2] Martín Arjovsky, et al. Invariant risk minimization. CoRR, abs/1907.02893, 2019.

[3] Peter Bühlmann. Invariance, causality and robustness. CoRR, abs/1812.08233, 2018.

[4] Keyulu Xu, et al. How neural networks extrapolate: From feedforward to graph neural networks. In International Conference on Learning Representations (ICLR), 2021.

[5] Beatrice Bevilacqua, et al. Size-invariant graph representations for graph classification extrapolations. In International Conference on Machine Learning (ICML), 2021.

[6] Haoyang Li et al. OOD-GNN: Out-of-Distribution Generalized Graph Neural Network. In Transactions on Knowledge and Data Engineering (TKDE), 2022.

[7] Nianzu Yang, et al, Learning Substructure Invariance for Out-of-Distribution Molecular Representations. In Advances in Neural Information Processing Systems (NeurIPS), 2022. 

[8] Chenxiao Yang, et al. Geometric Knowledge Distillation: Topology Compression for Graph Neural Networks. In Advances in Neural Information Processing Systems (NeurIPS), 2022. 

[9] Zenan Li, et al. GraphDE: A Generative Framework for Debiased Learning and Out-of-Distribution Detection on Graphs. In Advances in Neural Information Processing Systems (NeurIPS), 2022. 


更多阅读



#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编



🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧


·

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存