查看原文
其他

【源头活水】源码解析GraphSAGE原理与面试题汇总



“问渠那得清如许,为有源头活水来”,通过前沿领域知识的学习,从其他研究领域得到启发,对研究问题的本质有更清晰的认识和理解,是自我提高的不竭源泉。为此,我们特别精选论文阅读笔记,开辟“源头活水”专栏,帮助你广泛而深入的阅读科研文献,敬请关注。

来源:知乎—黎明程序员
地址:https://zhuanlan.zhihu.com/p/415905997


前言

本文主要以《Inductive Representation Learning on Large Graphs》为主(4000+引用),阐述GraphSAGE算法。

摘要

在大规模的图网络中,低维度的顶点嵌入(Low-dimensional embeddings)在各式各样的任务中越来越重要,从内容推荐到确定蛋白质功能。然而,现有的方法训练embedding的时候需要用到所有的结点,之前的方法是直推式的(transductive),不能自然地广义化到不可见的结点。现在我们提出GraphSAGE算法,一种归纳式的(inductive)技术可以利用结点特征信息(比如文本属性)来高效地生成不可见的顶点的嵌入。替代了为每个结点训练各自的embedding的方法,我们现在是学习一个函数,他可以通过采样和聚合结点局部的邻居生成embedding。我们的算法在实践中取得了很好的成绩。


01

介绍

生成顶点嵌入的思想是一种降维技术,从高维度的关于结点的邻居信息中提取出一个稠密向量。顶点嵌入能被当做下游机器学习任务的输入,然后可以进行结点分类、聚类和链接预测任务。

然而,之前的工作致力于从单一的固定的图中抽取顶点嵌入,现实中的应用需要embedding能被快速地从不可见的结点或全新的(子)图中生成。这种归纳能力对于高质量的机器学习系统是非常重要的,尤其是当数据处在一个不断进化的图中,并且不断加入新节点的情况下(比如Reddit上的帖子、Youtube上的用户和视频)。生成顶点嵌入的归纳式方法也可以促进泛化能力,比如:我们可以训练一个embedding生成模型(基于一种生物模型生成的蛋白质之间的交互的图数据),然后可以很容易地使用训练好的模型在新的生物模型生成的数据上,产生新的顶点嵌入。

归纳式顶点嵌入问题比直推式困难很多,因为推广到看不见的节点需要将新观察到的子图与算法已经优化过的节点嵌入“对齐”(aligning)。归纳式学习框架必须学会识别一个节点的邻域的结构属性,它揭示了节点在图中的局部角色以及它的全局位置。

大部分现有的生成顶点嵌入的方法都是直推式的。这些方法中的大多数使用基于矩阵分解的目标直接优化每个节点的嵌入,而不会自然地推广到看不见的数据,因为它们是对单个固定图中的节点进行预测。这些方法可以修改为归纳式的(比如DeepWalk就可以),但是这些修改在计算上很昂贵,在做出新的预测之前需要额外的梯度下降。到目前为止,GCN仅被应用在固定图直推式的任务上。本文中,我们将GCN扩展成归纳式无监督学习,并提出一种GCN的推广算法(训练聚合函数,而不是简单的卷积)。

我们的工作。我们提出了一种通用的框架,叫做 GraphSAGE (SAmple and aggreGatE),目的是进行归纳式的节点嵌入。不像基于矩阵分解的嵌入方法,我们的方法利用了节点特征(例如文本属性、节点描述信息、节点度信息),目的是学习一个嵌入方法以广义化到不可见的节点。通过协同节点特征,我们同时学习每个节点邻居的拓扑结构和邻居的节点特征。当我们致力于特征丰富的图(比如,引用数据还带有文本属性,生物信息还带有功能/分子标注)时,我们的算法还可以利用到图中展示出的结构特征。因此,我们的算法也可以应用在没有节点特征的图上。

代替训练每一个节点的唯一的嵌入向量,我们训练一组聚合函数(aggregator functions )来从节点的近邻中学习聚合特征信息(如图一)。给定义一个节点,每一个聚合函数从一个不同跳或搜索深度上聚合信息。在测试或者推理的时候,我们使用训练好的系统通过应用学习聚合函数,生成整个不可见节点的嵌入。顺着之前生成节点嵌入的思路,我们设计了一个无监督的损失函数,允许GraphSAGE能做任务无关的监督式训练。我们也展示了GraphSAGE可以以一种完全监督学习的方式训练。

我们在三个顶点分类的基准任务上评估了我们的模型,我们测试GraphSAGE在生成不可见节点的嵌入的能力。我们使用了两个动态文档图(一个是文献引用数据和一个是Reddit帖子数据,一个是预测论文分类,一个是预测帖子分类),并且还有一个多图广义化实验(蛋白质之间的交互作用,是预测蛋白质功能的任务)。使用这些基准任务,我们展示了我们的方法有能力生成不可见节点的表示,并且效果远超基线模型。


02

算法流程

下面,我们追溯源码,对每个步骤进行解析。


03

源码解析

3.1 数据集介绍

Cora dataset
http://www.research.whizbang.com/data

Cora 数据集包含了很多机器学习的论文,一共有7种分类(Case_Based、Genetic_Algorithms、Neural_Networks、Probabilistic_Methods、Reinforcement_Learning、Rule_Learning、Theory)。

每一篇论文至少被一年论文引用。一共有2708篇论文。

移除了关键词和频率小于10的词之后,还有1433个词。

数据集有两个文件,cora.content和cora.cites

  • cora.cites的格式为:

<paper_id> <word_attributes>+ <class_label>

每一行的第一个值是paper_id,最后一个是class_label,中间word_attributes是二值化词袋表示(出现=1,不出现=0)

  • cora.content的格式为:

<ID of cited paper> <ID of citing paper>

每一行的右边的引用了左边的,即"paper1 paper2" then the link is "paper2->paper1".

3.2. 无监督训练过程源码解析

3.2.1. 加载数据

注意:邻接矩阵,是当做无向边处理的,这也无可厚非,因为A引用了B,那应该就是A和B有关联。

3.2.2. 构造mini-batch数据

从训练集中选出20个结点作为一批数据源,构建mini-batch数据。

  • 第一步——构建正例样本对(1阶邻居对):

获取每个结点的1阶邻居,获取的方式为从某结点的邻居集合中随机抽取N_WALKS=6次,这里的邻居会有重复,也会有空(因为切分数据集时产生了孤立结点)。

  • 第二步——构建负例样本对(5阶以外样本对):

还是以上一步选中的20个结点作为数据源,先获取每一个结点的N_WALK_LEN=5阶邻居的集合,然后用训练集减去邻居的集合,再随机抽取100个结点和当前结点构成负例

3.2.3. GraphSAGE前向传播

这一部分主要是对前一节中正例和负例的样本对中涉及到的结点更新embedding,然后再使用正负例样本对的“监督”信息(其实这个监督信息是我们人为规定的,即:相近节点的向量应该相似,后面损失函数部分会具体化这个描述)来训练网络。

  • 第一步——整理数据:先拿到上一节中样本对涉及到的所有结点。

可以看到一共涉及到1022个结点。
针对这1022个结点,我们在整张图(注意,在上一节中我们是对结点做的划分,目的是最终在test集合中的结点上验证效果;这里是在整张图上寻找邻居,这并不冲突,这里需要利用例邻居信息去更新这1022个结点的embedding)上获取每个结点的邻居结点(如果邻居超过10个就随机采用10个)。
这次就涉及到2144个结点了。
因为我们GraphSAGE的采样邻居距离K=2(也是聚合函数的个数),所以还要以这2144为起点,再采集一次,得到:

这次就涉及到2555个结点了。这个过程可以用下图描述:

刚开始的1022是红色中心结点的个数,2144是k=1的圈内有颜色的点的个数,2555是k=2的圈内有颜色的点的个数。
  • 第二步——聚合:
这2555个结点都有各自的特征向量(1433维的二值化词袋向量),因为我们预设了两个聚合函数,所以执行过程如下图所示:
蓝色圈内的所有的结点,在绿圈范围内聚合各自的邻居,这里的聚合函数选的是MEAN,因此经过point-wise的mean聚合之后的蓝色结点还是1433维度。
从上图可以看到,聚合之后的蓝色结点个数就是2144,维度就是1433。
下一步就需要拼接自身向量,然后经过单层全连接网络。
拼接后的向量变成1433*2=2866,转换之后的维度我们预设的是128,因此FC的W的参数如下图所示:
最终得到每个蓝色结点的128维的向量
以上就完成了k=1,即绿色圈向蓝色圈的聚合,同理还要做从蓝色圈向红色圈的聚合。
上图可以看到,1022个红色结点都有各自的128维的向量。

3.2.4. 计算损失

上面这段话中提到几个关键的点:
第一个关键点是损失函数:
Graph-based的loss function鼓励相近的结点有相似的嵌入,同时强制不相近的邻居的嵌入有大的差异。因此损失函数中的第一部分刻画的是相近结点的相似度,第二部分是不相近结点的相似度的期望,Q可以认为是一个平衡因子(代码里Q=10)。
第二个关键点是无监督学习和有监督学习:
如果使用GraphSAGE进行无监督学习,一般是为了下游的机器学习任务提供embedding,比如把这些embedding存入数据库做相似推荐。如果想使用GraphSAGE在一个具体的下游任务进行监督学习也是可以的,只需要简单的替换一下损失函数(比如换为交叉熵损失)。
我们先演示无监督学习的代码:
前面的程序我们已经取得了1022个红色结点各自的128维的向量,接下来我们就要计算损失函数了。
回忆一下我们之前选过的20个结点,每个结点有几个相近的邻居,也有100个很远的邻居。先看相远的100个邻居的分数:
对每个分数求sigmod,然后log,再然后mean求期望,然后乘以Q=10,最后neg_score=-10.24.
再看相近的几个邻居,注意这里正例并没有做mean:
最终这20个结点中的第一个结点的损失的计算如下图:

这20个结点(有的结点没近邻舍弃了)的平均损失如下图所示:

3.2.5. GraphSAGE得到的模型

我们的程序里是两层的,即K=2.
因此,GraphSage会有两个SageLayer
SageLayer的定义为:
那么我们第一个SageLayer的W是(1433+1433)*128,第二个的W是128*128.

3.2.6. 模型评估

经过前面的迭代更新网络,我们就抽取出了每一个结点的embedding,根据这个embedding就可以做一些结点分类的任务。
损失函数为交叉熵损失:

3.3. 监督训练过程源码解析

监督学习的前几步和前面无监督的过程一样,只是得到每个结点的embedding之后,不再用相近应该相似,相远应该相异的方式刻画损失了。而是直接连接上全连接网络做监督学习(交叉熵损失),这样也可以反向更新embedding,达到end-to-end的效果。
另外,GraphSAGE还可以把监督和无监督结合起来训练。见上图的plus_unsup部分。


04

实验效果


05

面试题

5.1. 介绍GraphSAGE的思想

参考:
和GCN同年发表(2017),属于归纳式模型,是一种真正能让GNN应用在大规模动态图上的一个算法。
可以监督式训练、半监督式训练、无监督式训练。
用于生成图中结点的embedding(既包含结点信息,还包含拓扑信息)。

5.2. 介绍GraphSAGE的采样过程

参考:

5.3. 介绍GraphSAGE的聚合方法

参考:

5.4. GraphSAGE中的归一化问题

参考:
mean、pool聚合的时候不需要归一化,sum聚合的时候需要归一化
https://github.com/williamleif/GraphSAGE/issues/89

5.5. GraphSAGE在监督学习和非监督学习中的差异

参考:
无监督:获取一批结点的embedding,构建损失函数的原则是“相近则相似,相远则相异”,反向更新embedding。
监督:直接根据结点的embedding,去预测对应的label,反向更新embedding。

5.6. GraphSAGE和GCN的区别

参考:
GraphSAGE是归纳式模型,GCN是直推式模型。更多可以参考 5.7和 5.9

5.7. 直推式模型和归纳式模型的差异

参考:
Transduction is reasoning from obeserved, specific (training) cases to specific (test) cases. In contrast, induction is reasoning from obeserved training cases to gerneral rules, which are then applied to the test cases.
归纳式和直推式都是属于半监督学习。
归纳式模型就是从数据中总结一般化规律,可以应用在未来没见过的数据上。比如GraphSAGE,他学习到的是,根据自身信息和采样的邻居信息计算自身embedding的方法。
直推式模型所有的数据必须都见到过。比如GCN,我们回忆一下训练GCN的过程,有一张图,有的结点有标签,还有一些结点没有标签,GCN不停的向结点汇聚邻居信息,转换,然后在有标签的数据上构建损失函数,更新每个结点的词嵌入和模型的参数W,训练好的模型可以对没有标签的结点的类别进行预测。

5.8. GraphSAGE是如何结合用上结点信息和拓扑信息生成embedding的

参考:
在第一次聚合的时候,h0是xv,就是结点的信息(cora为例的话就是二值化词袋编码),这就考虑到了结点信息了;而整个聚合过程,就利用上了拓扑信息。

5.9. GraphSAGE在生成embedding上和SDNE的区别

参考:
从直推式模型和归纳式模型来说:GraphSAGE是归纳式的,适用于动态图。SDNE是直推式的,只能用于静态图。
从捕获信息角度来说:GraphSAGE可以利用上结点和拓扑信息;而SDNE仅利用了拓扑信息,没用到结点信息。
从采样角度来说:GraphSAGE有采样过程,更适合在大规模图上应用

参考资料

https://github.com/twjiang/graphSAGE-pytorch
https://github.com/williamleif/GraphSAGE/issues/89
https://zhuanlan.zhihu.com/p/413648055

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。


“源头活水”历史文章


更多源头活水专栏文章,

请点击文章底部“阅读原文”查看



分享、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存