【源头活水】源码解析GraphSAGE原理与面试题汇总
“问渠那得清如许,为有源头活水来”,通过前沿领域知识的学习,从其他研究领域得到启发,对研究问题的本质有更清晰的认识和理解,是自我提高的不竭源泉。为此,我们特别精选论文阅读笔记,开辟“源头活水”专栏,帮助你广泛而深入的阅读科研文献,敬请关注。
前言
本文主要以《Inductive Representation Learning on Large Graphs》为主(4000+引用),阐述GraphSAGE算法。
摘要
在大规模的图网络中,低维度的顶点嵌入(Low-dimensional embeddings)在各式各样的任务中越来越重要,从内容推荐到确定蛋白质功能。然而,现有的方法训练embedding的时候需要用到所有的结点,之前的方法是直推式的(transductive),不能自然地广义化到不可见的结点。现在我们提出GraphSAGE算法,一种归纳式的(inductive)技术可以利用结点特征信息(比如文本属性)来高效地生成不可见的顶点的嵌入。替代了为每个结点训练各自的embedding的方法,我们现在是学习一个函数,他可以通过采样和聚合结点局部的邻居生成embedding。我们的算法在实践中取得了很好的成绩。
01
生成顶点嵌入的思想是一种降维技术,从高维度的关于结点的邻居信息中提取出一个稠密向量。顶点嵌入能被当做下游机器学习任务的输入,然后可以进行结点分类、聚类和链接预测任务。
然而,之前的工作致力于从单一的固定的图中抽取顶点嵌入,现实中的应用需要embedding能被快速地从不可见的结点或全新的(子)图中生成。这种归纳能力对于高质量的机器学习系统是非常重要的,尤其是当数据处在一个不断进化的图中,并且不断加入新节点的情况下(比如Reddit上的帖子、Youtube上的用户和视频)。生成顶点嵌入的归纳式方法也可以促进泛化能力,比如:我们可以训练一个embedding生成模型(基于一种生物模型生成的蛋白质之间的交互的图数据),然后可以很容易地使用训练好的模型在新的生物模型生成的数据上,产生新的顶点嵌入。
归纳式顶点嵌入问题比直推式困难很多,因为推广到看不见的节点需要将新观察到的子图与算法已经优化过的节点嵌入“对齐”(aligning)。归纳式学习框架必须学会识别一个节点的邻域的结构属性,它揭示了节点在图中的局部角色以及它的全局位置。
大部分现有的生成顶点嵌入的方法都是直推式的。这些方法中的大多数使用基于矩阵分解的目标直接优化每个节点的嵌入,而不会自然地推广到看不见的数据,因为它们是对单个固定图中的节点进行预测。这些方法可以修改为归纳式的(比如DeepWalk就可以),但是这些修改在计算上很昂贵,在做出新的预测之前需要额外的梯度下降。到目前为止,GCN仅被应用在固定图直推式的任务上。本文中,我们将GCN扩展成归纳式无监督学习,并提出一种GCN的推广算法(训练聚合函数,而不是简单的卷积)。
我们的工作。我们提出了一种通用的框架,叫做 GraphSAGE (SAmple and aggreGatE),目的是进行归纳式的节点嵌入。不像基于矩阵分解的嵌入方法,我们的方法利用了节点特征(例如文本属性、节点描述信息、节点度信息),目的是学习一个嵌入方法以广义化到不可见的节点。通过协同节点特征,我们同时学习每个节点邻居的拓扑结构和邻居的节点特征。当我们致力于特征丰富的图(比如,引用数据还带有文本属性,生物信息还带有功能/分子标注)时,我们的算法还可以利用到图中展示出的结构特征。因此,我们的算法也可以应用在没有节点特征的图上。
代替训练每一个节点的唯一的嵌入向量,我们训练一组聚合函数(aggregator functions )来从节点的近邻中学习聚合特征信息(如图一)。给定义一个节点,每一个聚合函数从一个不同跳或搜索深度上聚合信息。在测试或者推理的时候,我们使用训练好的系统通过应用学习聚合函数,生成整个不可见节点的嵌入。顺着之前生成节点嵌入的思路,我们设计了一个无监督的损失函数,允许GraphSAGE能做任务无关的监督式训练。我们也展示了GraphSAGE可以以一种完全监督学习的方式训练。
我们在三个顶点分类的基准任务上评估了我们的模型,我们测试GraphSAGE在生成不可见节点的嵌入的能力。我们使用了两个动态文档图(一个是文献引用数据和一个是Reddit帖子数据,一个是预测论文分类,一个是预测帖子分类),并且还有一个多图广义化实验(蛋白质之间的交互作用,是预测蛋白质功能的任务)。使用这些基准任务,我们展示了我们的方法有能力生成不可见节点的表示,并且效果远超基线模型。
02
下面,我们追溯源码,对每个步骤进行解析。
03
3.1 数据集介绍
Cora 数据集包含了很多机器学习的论文,一共有7种分类(Case_Based、Genetic_Algorithms、Neural_Networks、Probabilistic_Methods、Reinforcement_Learning、Rule_Learning、Theory)。
每一篇论文至少被一年论文引用。一共有2708篇论文。
移除了关键词和频率小于10的词之后,还有1433个词。
数据集有两个文件,cora.content和cora.cites
cora.cites的格式为:
<paper_id> <word_attributes>+ <class_label>
每一行的第一个值是paper_id,最后一个是class_label,中间word_attributes是二值化词袋表示(出现=1,不出现=0)
cora.content的格式为:
<ID of cited paper> <ID of citing paper>
每一行的右边的引用了左边的,即"paper1 paper2" then the link is "paper2->paper1".
3.2. 无监督训练过程源码解析
3.2.1. 加载数据
注意:邻接矩阵,是当做无向边处理的,这也无可厚非,因为A引用了B,那应该就是A和B有关联。
3.2.2. 构造mini-batch数据
从训练集中选出20个结点作为一批数据源,构建mini-batch数据。
第一步——构建正例样本对(1阶邻居对):
获取每个结点的1阶邻居,获取的方式为从某结点的邻居集合中随机抽取N_WALKS=6次,这里的邻居会有重复,也会有空(因为切分数据集时产生了孤立结点)。
第二步——构建负例样本对(5阶以外样本对):
还是以上一步选中的20个结点作为数据源,先获取每一个结点的N_WALK_LEN=5阶邻居的集合,然后用训练集减去邻居的集合,再随机抽取100个结点和当前结点构成负例
3.2.3. GraphSAGE前向传播
这一部分主要是对前一节中正例和负例的样本对中涉及到的结点更新embedding,然后再使用正负例样本对的“监督”信息(其实这个监督信息是我们人为规定的,即:相近节点的向量应该相似,后面损失函数部分会具体化这个描述)来训练网络。
第一步——整理数据:先拿到上一节中样本对涉及到的所有结点。
第二步——聚合:
3.2.4. 计算损失
3.2.5. GraphSAGE得到的模型
3.2.6. 模型评估
3.3. 监督训练过程源码解析
04
05
5.1. 介绍GraphSAGE的思想
5.2. 介绍GraphSAGE的采样过程
5.3. 介绍GraphSAGE的聚合方法
5.4. GraphSAGE中的归一化问题
5.5. GraphSAGE在监督学习和非监督学习中的差异
5.6. GraphSAGE和GCN的区别
5.7. 直推式模型和归纳式模型的差异
5.8. GraphSAGE是如何结合用上结点信息和拓扑信息生成embedding的
5.9. GraphSAGE在生成embedding上和SDNE的区别
参考资料
本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。
“源头活水”历史文章
Copying Mechanism缓解未登录词问题的模型--CopyNet
旋转目标检测方法解读(KLD, NeurIPS2021)
NeurIPS-2021 | 图像未必值16x16词:可变序列长度的动态视觉Transformer来了
单目3D物体检测——基于不确定度的几何投影模型
ICCV-2021 Oral | AdaFocus:利用空间冗余性实现高效视频识别
CoSTA:用于空间转录组分析的无监督卷积神经网络学习方法
CARE-GNN论文理解
结构重参数化:利用参数转换解耦训练和推理结构
CodeVIO:紧耦合神经网络与视觉惯导里程计的稠密深度重建
RepVGG:极简架构,SOTA性能,让VGG式模型再次伟大(CVPR-2021)
降维打击!基于多模态框架的行为识别新范式
Reformer,一种高效的Transformer结构
RoadMap:一种用于自动驾驶视觉定位的轻质语义地图(ICRA2021)
ResRep:剪枝SOTA!用结构重参数化实现CNN无损压缩(ICCV)
更多源头活水专栏文章,
请点击文章底部“阅读原文”查看
分享、在看,给个三连击呗!