
【源头活水】GCN vs. GraphSAGE: a source-code showdown to see who's the boss

GCN is one kind of GNN, used to extract features from graph-structured data so that tasks such as node classification, link prediction, and graph classification can be done better. This article gives a brief introduction to the principles behind GCN and GraphSAGE, and then walks through the GraphSAGE code in detail.

Author: 阿亮

Link: https://www.zhihu.com/people/a-liang-32-16


01

How GCN Works

Suppose we have a batch of graph data containing N nodes, each with its own features; together these features form an N×D matrix X. The relationships between nodes form an N×N adjacency matrix A. The feature matrix X and the adjacency matrix A are fed into the GCN together [1].

The layer-to-layer propagation rule of GCN is shown below, where H is the feature matrix at each layer, W is the weight matrix to be learned at each layer, and σ is a nonlinear activation function:

H^(l+1) = σ( D̃^(-1/2) Ã D̃^(-1/2) H^(l) W^(l) )

Wait a second, doesn't this formula look a bit familiar? If we throw away D̃^(-1/2) Ã D̃^(-1/2), isn't it just a traditional DNN layer, H^(l+1) = σ(H^(l) W^(l))? So GCN is essentially the traditional DNN adapted to graph structure. Now let's focus on this "improvement", D̃^(-1/2) Ã D̃^(-1/2), and see exactly what tricks GCN pulls on top of a plain DNN.

In D̃^(-1/2) Ã D̃^(-1/2):

  • Ã = A + I, where I is the identity matrix

  • D̃ is the degree matrix of Ã

You ask what a degree matrix is? Never mind that for now; all you need to know is that D̃ is there to transform Ã. The reason is that Ã is not normalized: multiplying it directly with the feature matrix H would change the distribution of the original features, so D̃ is used to normalize it, giving a symmetric, normalized matrix D̃^(-1/2) Ã D̃^(-1/2), which is then multiplied with H. So the introduction of D̃ is not GCN's core improvement over the DNN; the core is still the Ã matrix, i.e., the introduction of A + I.

(A + I)H is essentially a kind of "feature aggregation" over neighboring nodes: multiplying (A + I) with H sums each node's features together with those of its adjacent nodes, a simple additive aggregation. Too abstract? Then you may not be that familiar with matrix multiplication and what it means; I recommend reading this [2] (after opening the link, search for the keyword "彩蛋" to jump straight to the part about matrices).
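
To make this concrete, here is a minimal NumPy sketch (mine, not from the original article or repo) of one GCN propagation step on a toy 3-node graph; the variable names mirror the formula above.

import numpy as np

# Toy graph: 3 nodes, edges 0-1 and 1-2
A = np.array([[0., 1., 0.],
              [1., 0., 1.],
              [0., 1., 0.]])
X = np.array([[1., 0.],          # node features, N x D
              [0., 1.],
              [1., 1.]])

A_tilde = A + np.eye(3)                              # A + I: add self-loops
D_inv_sqrt = np.diag(1.0 / np.sqrt(A_tilde.sum(1)))  # D_tilde^(-1/2), from the degrees of A_tilde
A_hat = D_inv_sqrt @ A_tilde @ D_inv_sqrt            # symmetric normalized adjacency

W = np.random.randn(2, 4)                            # layer weights, D x D'
H_next = np.maximum(A_hat @ X @ W, 0)                # sigma = ReLU
print(H_next.shape)                                  # (3, 4)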


02

How GraphSAGE Works

The core idea of GraphSAGE is the same as GCN's: "feature aggregation" over neighboring nodes. But unlike GCN's simple additive aggregation, GraphSAGE gets creative with the aggregation step: it proposes several aggregators, such as max-pooling aggregation and LSTM aggregation, and of course you can define your own. These aggregators make GraphSAGE much more flexible to use. Graph structure is always changing, and new nodes join from time to time. When a new node appears, a learned aggregator lets GraphSAGE obtain its embedding without retraining: you simply aggregate the new node's neighbor features with the aggregation function that has already been learned. So the biggest difference between GraphSAGE and GCN is that GraphSAGE can always produce up-to-date embeddings, no matter how the graph structure changes.
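
As a rough illustration (my sketch, not the author's code), the snippet below shows how a trained mean-style aggregator could embed a brand-new node: the weight matrix W is assumed to have been learned already, so only the new node's sampled neighbor features are needed.

import numpy as np

def mean_aggregate(self_feat, neigh_feats, W):
    """One learned mean-aggregation step (illustrative only)."""
    stacked = np.vstack([self_feat[None, :], neigh_feats])  # node itself + sampled neighbors
    h = stacked.mean(axis=0) @ W                            # average, then apply the learned linear map
    return np.maximum(h, 0)                                 # ReLU nonlinearity

# Hypothetical trained weights and a node that joined the graph after training
W = np.random.randn(16, 8)                    # learned during training, reused as-is
new_node_feat = np.random.randn(16)
sampled_neighbors = np.random.randn(10, 16)   # features of 10 sampled neighbors

embedding = mean_aggregate(new_node_feat, sampled_neighbors, W)
print(embedding.shape)   # (8,) -- no retraining needed for the new node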


03

A Walkthrough of the GraphSAGE Code

  • Positive and negative samples in GraphSAGE

The positive samples are produced by the code below, where batch_edges is a set of randomly selected edges. Each edge consists of two nodes: one is stored in batch1 and the other in batch2, and together they form the positive node pairs.

def batch_feed_dict(self, batch_edges):
    batch1 = []
    batch2 = []
    for node1, node2 in batch_edges:
        batch1.append(self.id2idx[node1])
        batch2.append(self.id2idx[node2])

    feed_dict = dict()
    feed_dict.update({self.placeholders['batch_size']: len(batch_edges)})
    feed_dict.update({self.placeholders['batch1']: batch1})
    feed_dict.update({self.placeholders['batch2']: batch2})

    return feed_dict

The negative samples are produced by the code below. The core idea is to randomly sample nodes (biased towards high-degree nodes); the sampled nodes are then paired against the batch nodes in the loss function shown later to form the negative pairs.

labels = tf.reshape(
    tf.cast(self.placeholders['batch2'], dtype=tf.int64),
    [self.batch_size, 1])  # must be cast to this type before it can go into the sampler below
self.neg_samples, _, _ = (tf.nn.fixed_unigram_candidate_sampler(  # randomly sample a subset of classes (nodes)
    true_classes=labels,
    num_true=1,
    num_sampled=FLAGS.neg_sample_size,   # number of sampled classes = negative-sample batch size
    unique=False,                        # whether all sampled classes must be unique
    range_max=len(self.degrees),
    distortion=0.75,                     # distort the unigram distribution
    unigrams=self.degrees.tolist()))     # per-class (per-node) sampling weights
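
In effect, fixed_unigram_candidate_sampler draws negatives with probability proportional to node degree raised to the 0.75 power (the distortion argument), the same trick word2vec uses. A rough NumPy equivalent (purely illustrative, with made-up degrees) looks like this:

import numpy as np

degrees = np.array([1., 5., 20., 3., 7.])   # hypothetical node degrees
probs = degrees ** 0.75                     # distorted unigram distribution
probs /= probs.sum()

neg_sample_size = 4
neg_samples = np.random.choice(len(degrees), size=neg_sample_size, p=probs)
print(neg_samples)   # indices of sampled negatives; high-degree nodes are more likely
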
  • Aggregating neighbor nodes

First, get the node sets needed at each layer:

def sample(self, inputs, layer_infos, batch_size=None):
    """ Sample neighbors to be the supportive fields for multi-layer convolutions.

    Args:
        inputs: batch inputs
        batch_size: the number of inputs (different for batch inputs and negative samples).
    """
    if batch_size is None:
        batch_size = self.batch_size
    samples = [inputs]  # 1 * batch_size
    # size of convolution support at each layer per node
    support_size = 1
    support_sizes = [support_size]
    for k in range(len(layer_infos)):
        t = len(layer_infos) - k - 1
        support_size *= layer_infos[t].num_samples
        sampler = layer_infos[t].neigh_sampler
        node = sampler((samples[k], layer_infos[t].num_samples))  # randomly sample the neighbors at hop k
        samples.append(tf.reshape(node, [support_size * batch_size, ]))
        support_sizes.append(support_size)
    return samples, support_sizes  # node sets needed at each layer, and their receptive-field ("support") sizes
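
For intuition, with two layers and hypothetical fan-outs of num_samples = [25, 10], support_sizes grows as follows. This is just a sketch of the bookkeeping above, not the author's code:

# Hypothetical per-layer fan-outs; layers are walked from the outermost hop inward
num_samples = [25, 10]

support_size = 1
support_sizes = [support_size]
for k in range(len(num_samples)):
    t = len(num_samples) - k - 1        # index layer_infos from the last entry backwards
    support_size *= num_samples[t]
    support_sizes.append(support_size)

print(support_sizes)   # [1, 10, 250]: each batch node drags in 10, then 10*25 support nodes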

The aggregation itself is done by the following code:

def aggregate(self, samples, input_features, dims, num_samples, support_sizes,
              batch_size=None, aggregators=None, name=None, concat=False,
              model_size="small"):
    """ At each layer, aggregate hidden representations of neighbors to compute
        the hidden representations at the next layer.

    Args:
        samples: a list of samples of variable hops away for convolving at each layer of the
            network. Length is the number of layers + 1. Each is a vector of node indices.
        input_features: the input features for each sample of various hops away.
        dims: a list of dimensions of the hidden representations from the input layer to the
            final layer. Length is the number of layers + 1.
        num_samples: list of number of samples for each layer.
        support_sizes: the number of nodes to gather information from for each layer.
        batch_size: the number of inputs (different for batch inputs and negative samples).
    Returns:
        The hidden representation at the final layer for all nodes in batch
    """
    if batch_size is None:
        batch_size = self.batch_size

    # length: number of layers + 1
    hidden = [tf.nn.embedding_lookup(input_features, node_samples)
              for node_samples in samples]  # look up the embedding matrix of the nodes at each hop
    new_agg = aggregators is None
    if new_agg:
        aggregators = []
    for layer in range(len(num_samples)):
        if new_agg:
            dim_mult = 2 if concat and (layer != 0) else 1
            # aggregator at current layer
            if layer == len(num_samples) - 1:
                aggregator = self.aggregator_cls(dim_mult * dims[layer], dims[layer + 1],
                        act=lambda x: x,
                        dropout=self.placeholders['dropout'],
                        name=name, concat=concat, model_size=model_size)  # args: the layer's input dim, output dim, activation
            else:
                aggregator = self.aggregator_cls(dim_mult * dims[layer], dims[layer + 1],
                        dropout=self.placeholders['dropout'],
                        name=name, concat=concat, model_size=model_size)
            aggregators.append(aggregator)
        else:
            aggregator = aggregators[layer]
        # hidden representation at current layer for all support nodes that are various hops away
        next_hidden = []
        # as layer increases, the number of support nodes needed decreases
        for hop in range(len(num_samples) - layer):
            dim_mult = 2 if concat and (layer != 0) else 1
            neigh_dims = [batch_size * support_sizes[hop],
                          num_samples[len(num_samples) - hop - 1],
                          dim_mult * dims[layer]]
            h = aggregator((hidden[hop],
                            tf.reshape(hidden[hop + 1], neigh_dims)))  # h becomes the input to the next layer
            next_hidden.append(h)
        hidden = next_hidden
    return hidden[0], aggregators

The author's source code provides several aggregation functions, such as graphsage_mean, gcn, graphsage_seq, graphsage_maxpool, and graphsage_meanpool. Here we take graphsage_mean as the example:

def _call(self, inputs):
    self_vecs, neigh_vecs = inputs

    neigh_vecs = tf.nn.dropout(neigh_vecs, 1 - self.dropout)
    self_vecs = tf.nn.dropout(self_vecs, 1 - self.dropout)
    means = tf.reduce_mean(tf.concat([neigh_vecs,
            tf.expand_dims(self_vecs, axis=1)], axis=1), axis=1)  # add a dim to self_vecs, concat with the neighbors, then take the mean

    # [nodes] x [out_dim]
    output = tf.matmul(means, self.vars['weights'])

    # bias
    if self.bias:
        output += self.vars['bias']

    return self.act(output)
  • Loss function

The code is as follows; the total loss = regularization loss + the loss over positive and negative sample pairs.

def _loss(self):
    for aggregator in self.aggregators:
        for var in aggregator.vars.values():
            self.loss += FLAGS.weight_decay * tf.nn.l2_loss(var)

    self.loss += self.link_pred_layer.loss(self.outputs1, self.outputs2, self.neg_outputs)
The loss over positive and negative pairs is defined below: embeddings of positive node pairs should end up as close as possible, and embeddings of negative node pairs as far apart as possible.

def _skipgram_loss(self, inputs1, inputs2, neg_samples, hard_neg_samples=None):
    # the dot product is used to measure affinity (a cosine-style similarity)
    aff = self.affinity(inputs1, inputs2)                             # affinity between the two connected nodes
    neg_aff = self.neg_cost(inputs1, neg_samples, hard_neg_samples)   # affinity between a node and its negative samples
    neg_cost = tf.log(tf.reduce_sum(tf.exp(neg_aff), axis=1))
    loss = tf.reduce_sum(aff - neg_cost)
    return loss

The annotated code is available on GitHub:

https://github.com/a-bean-sprout/GraphSAGE_commit

References

1. Detailed explanation of GCN:

https://www.zhihu.com/question/54504471/answer/332657604

2. GraphSAGE reference:

https://zhuanlan.zhihu.com/p/74242097


