查看原文
其他

【综述专栏】图神经网络及其在视觉/医学图像中的应用

在科学研究中,从方法论上来讲,都应“先见森林,再见树木”。当前,人工智能学术研究方兴未艾,技术迅猛发展,可谓万木争荣,日新月异。对于AI从业者来说,在广袤的知识森林中,系统梳理脉络,才能更好地把握趋势。为此,我们精选国内外优秀的综述文章,开辟“综述专栏”,敬请关注。

来源:知乎—摸鱼家
地址:https://zhuanlan.zhihu.com/p/427533727

写在前面

之前的工作主要是基于CNN和RNN在做,前段时间因为项目需要,相对系统的了解一下图神经网络,包括理论基础,代表性的GNN(GCN, GraphSAGE和GAT)以及它的一些应用。GNN主要是应用在一些存在复杂关系的场景中,比如推荐系统,社交网络,分子结构等,在CV中并不主流。但它在CV和医学图像分析中也有被用到,本文主要是针对GNN的原理及其在这两方面的应用简单做个分享,主要是由之前在组会上分享过的PPT内容整理而成。

01

Graph基础
先简单介绍一下graph一些基础的东西,为了衔接后续的论文,这里就以Cora数据集为例进行介绍。Cora可以简单理解成CV中的手写字符MNIST数据集,是个入门级的graph数据集。
图的构成: 图由顶点 (Vertex) 和边 (Edge) 组成。在Cora数据集中,包含了2708篇论文,每篇论文作为一个vertex,相互引用关系构成图节点的连接关系edge。

一个有向图

图的表示:那么图如何像图像一样用矩阵来表示呢。这里涉及到两个部分:顶点特征,邻接矩阵。
顶点特征:在Cora数据集中,每个顶点有1433维的特征,表示是否包含某个单词(paper的内容)。
顶点特征矩阵
邻接矩阵:顶点之间的连接关系通过邻接矩阵来表示。
一个简单的无向图及其邻接矩阵


02

图神经网络GNN
2013年首次提出图上的基于频域(Spectra)和基于空域(Spatial)的卷积神经网络。2016,2017有比较大的突破,开始成为研究热点。GCN,GAT,GraphSAGE是比较经典也是最常用的几个基础算法。下面我会分别介绍这3个工作,为了方便分析比较,就以图节点分类任务为例进行介绍。

2.1. 节点分类任务

以Cora数据集为例,它包含2708篇论文,每篇论文属于8个类别中的一类。其中一部分节点有类别标签,一部分没有,现在就是要通过训练GNN,来对没有类别标签的节点进行分类。
顶点分类任务,蓝色有标签,白色没标签

2.2. GCN

ICLR 2017
Paper: Semi-Supervised Classification with Graph Convolutional Networkshttps://arxiv.org/abs/1609.02907Code: https://github.com/tkipf/gcn
GCN全称Graph Convolutional Networks [1],刚开始看论文会觉得有点不好懂,但结合这代码来看,还是比较容易理解的,代码的实现也很简单。假设输入是顶点特征X (2708,1433)和邻接矩阵A (2708,2708),下面是一个GCN layer的示意图及公式。
Step 1: 首先进行特征变换,将每个节点的1433维特征变换为16维(可选),变换是通过权重W进行的,下面是示意图,很好理解。
Step 2: 特征变换之后根据图的连接关系更新顶点特征,这一步就需要用到邻接矩阵的,当前节点的特征更新为其相连节点特征之和,同样用一个简单的示意图。
上面两步其实就是一层GCN的操作,先对每个节点单独进行特征变换,再通过聚合相连节点的特征更新每个节点特征。当多叠加几层GCN的时候,感受野会变大,也就是每个节点的特征不止与一级节点有关,还会考虑多级节点特征。

多层GCN操作

2.3. GAT

ICLR 2018
Paper: Graph Attention Networkshttps://arxiv.org/abs/1710.10903Code: https://github.com/PetarV-/GAT
GAT [2] 和GCN其实很相似,不同点在于GCN在第二步更新节点特征时,相连节点取一样的权重,这一步没有考虑到不同节点的不同重要性。GAT主要就是根据这一点来改进的。所以最关键的一步就是获得不同相邻节点对中心节点的重要性权重。
Step 1: 第一步和GCN一样,还是进行特征变换,将节点特征从1433维变为16维, 即Wh。
Step 2:第二步就是最关键获取权重的过程。首先对于一个节点a,求每个其他节点和该节点的相关性,下面是具体的公式:
这个公式的意思是把经过变换后的当前节点的特征分别和其他所有节点的特征进行concat,经过一个FC层,得到的就是每个节点和当前节点的相关性。公式中的a()可以理解成是FC操作。一共2708个节点(包括自身),所以对于该节点可以得到2708个相关系数。每个节点都求和其他节点的相关系数,所以可以得到一个[2708, 2708]的矩阵,即相关性矩阵,将这个矩阵和邻接矩阵逐点相乘,相当于忽略掉不相邻节点的影响,最后将这个矩阵进行归一化得到最终的权重矩阵:
Step 3:最后就用新得到的权重矩阵更新每个节点的特征。可以看到就是比GCN多了上面这一步。

2.4. GraphSAGE

NIPS 2017
Paper: Inductive Representation Learning in Large Attributed Graphshttps://arxiv.org/abs/1710.09471Code: https://github.com/KimMeen/GraphSage
GraphSAGE主打的是归纳式学习,而GCN是直推式学习,这两者的区别这里就不展开了。主要的应用场景是对于工业场景中非常大的图,GCN和GAT都是整图训练,GraphSAGE是通过采样在子图上训练,对于内存要求更小。原理上和GCN并没有太大差别,唯一的区别就是GraphSAGE里面增加了一个采样的过程。具体流程是:先进行图采样(包含一阶和二阶),然后在子图上进行特征聚合(更新中心节点的embedding),最后对中心节点进行分类。还是比较好理解的,具体的实现结合代码看会比较容易理解。


03

GNN在图像处理领域的应用
GNN主要是引用在一些存在复杂关系的场景中,比如推荐系统,社交网络,分子结构等,在CV中并不主流。原因在于GNN的优势是关系建模和学习,而图像这种规则的东西天然的并不适合GNN。但CV/医学图像分析中还是围绕GNN做了一些工作。就像上面提到的,在CV场景中使用GNN,关键在于graph如何构建:顶点及顶点特征是什么?顶点的连接关系怎么定义?根据图的构建方式,下面要介绍的工作大致可分为两大类:
  • GNN在图像分类中的应用
  • GNN在分割/重建中的应用

3.1. 在分类中的应用

3.1.1 用于3D医学图像的分类(UG-GAT)

MedIA 2021
Paper: Uncertainty-guided graph attention network for parapneumonic effusion diagnosis - ScienceDirecthttps://www.sciencedirect.com/science/article/abs/pii/S1361841521002620?dgcid=coauthorCode: https://github.com/iMED-Lab/UG-GAT
这一篇是基于医学图像进行疾病分类的,和3.1.2不同的地方在于graph的构建方式。上一篇是一个病人作为一个节点,所有病人构成一张大图,这样的缺点是每次有新的数据进来都要重新训练。在这篇文章中,每个病人构成一个graph,每个节点是一个slice,算是代替3D CNN的一种方式。
这篇文章的另一个创新点是在GAT中引入了不确定性,因为每个slice的信息量及其对整个volume分类的贡献度是不一样的,在GAT学习的过程中引入不确定性有助于提升分类精度。

文章采用(B)来构建图

3.1.2 医学图像-AD预测

MedIA 2018
Paper: Disease prediction using graph convolutional networks: Application to Autism Spectrum Disorder and Alzheimer’s disease - ScienceDirecthttps://www.sciencedirect.com/science/article/abs/pii/S1361841518303554Code: 未公开
这篇文章是GNN在医学图像中的应用,和前面例子中提到的半监督任务很相似。每个病人是一个节点,节点特征是CNN得到的图像特征,连接关系根据非图像。
之后利用GCN来训练图,其中一部分病人(节点)是有疾病label的,一部分没有,通过labeled data约束预测无labeled data的疾病类别。

3.1.3 Multi-Label Image Recognition

CVPR 2019
Paper: CVPR 2019 Open Access Repository (thecvf.com)https://openaccess.thecvf.com/content_CVPR_2019/html/Chen_Multi-Label_Image_Recognition_With_Graph_Convolutional_Networks_CVPR_2019_paper.htmlCode: 未公开
这篇文章的任务是图片多标签预测,一个图片中可能含有n个类比,传统的做法是使用CNN后面接n个二分类器,预测是否含有该类。在这篇文章,作者将图卷积用到了该任务中,除了CNN之外,作者希望图的结构能够建模不同标签之间的关系,从而对分类起到辅助作用。
在这篇文章中,graph是基于label来构建的,每一类是图的一个节点,节点特征是该label的词向量embedding,类别间存在连接关系,也是这篇文章中重点介绍了的。构建好图之后,送入GCN,每个节点最终的输出是2048*1的向量。CNN用来提取图像特征,最后一层是2048*1的feature。将graph每个节点(代表一个类别)的向量和CNN特征相乘,得到该类的概率。
至于构建graph时节点的连接关系,是根据每两个label同时出现的概率得到的。

3.1.4 其他疾病分类

还有两篇医学图像分类的文章就和第一篇很相似了,就不做具体介绍了。

3.2. 在分割/重建中的应用

3.2.1 Curve-GCN

CVPR 2019
Paper: CVPR 2019 Open Access Repository (thecvf.com)https://openaccess.thecvf.com/content_CVPR_2019/html/Ling_Fast_Interactive_Object_Annotation_With_Curve-GCN_CVPR_2019_paper.htmlCode: 未公开
这篇文章的任务是基于GCN做交互式标注,将mask轮廓用边缘来表示,边缘就是graph。

本文的任务是使用GNN来进行交互式标注

之前也提到了,在图像中使用graph结构,关键就是graph怎么构建。这篇文章中初始化节点的连接关系就是一个椭圆形状,节点的特征由CNN提供。将构建好的初始化图给到GCN,GCN的作用是预测每个节点的坐标值,即每个node最终有两个坐标,代表其在图像中的位置,这些点连接起来的形状就是物体的轮廓。
loss函数方面包含两部分,一个是每个点的坐标约束,为了避免过度平滑,还是用了一个基于mask的约束,就是将坐标转换为mask,对mask进行L1约束。

3.2.2 Pixel2Mesh

ECCV 2018
Paper: ECCV 2018 Open Access Repository (thecvf.com)https://openaccess.thecvf.com/content_ECCV_2018/html/Nanyang_Wang_Pixel2Mesh_Generating_3D_ECCV_2018_paper.htmlCode: https://github.com/nywang16/Pixel2Mesh
这篇文章也是比较有名的,还有一个续作Pixel2Mesh++,思路差不多。这篇文章的思路是通过2D图像直接生成3D Mesh。
关键还是graph,和上一个工作思路类似,将一个椭圆作为初始化的图,CNN提取特征作为节点特征,送入GCN预测每个节点的坐标。生成的过程是coarse-to-fine,经过3次预测,每个节点的数量慢慢增加。

3.2.3 分割-One shot learning

MICCAI/TMI 2020;TMI 2020
Paper: Learning to Segment Anatomical Structures Accurately from One Exemplar | SpringerLinkhttps://link.springer.com/chapter/10.1007/978-3-030-59710-8_66Code: 未公开
这篇文章是做医学图像的分割,具体来说是One shot learning,即只有一张有标注的label。实际做法和上一篇文章很相似,先给一个初始graph, CNN提取特征作为节点特征,唯一的label作为初始轮廓。GCN来预测每个点相对初始轮廓的偏移量,本质也是坐标回归。

3.2.4 语义分割

IJCAI 2018
Paper: View-volume Network for Semantic Scene Completion from a Single Depth Image https://arxiv.org/abs/1806.05361Code: 未公开
这篇文章的任务是利用深度图来进行2D图像的语义分割。


04

小结
GNN是针对graph来进行学习,所以关键的在于graph是什么来自哪里。Graph的优势是对关系的建模,在图像处理中要应用GNN,最关键的还是graph怎么构建。
图像分割因为是pixel-level的分类,CNN明显是更适合的,上面的一些工作也可以看出,用GNN做分割或重建还是一些辅助性的或对精度要求不那么高的场景,实际的实验也发现,GNN很难得到非常精确边缘,只能得到大致的轮廓。
相比而言,分类问题会更适合GNN发挥所长,特别是存在多模态输入的时候,graph对于关系的建模和GNN的学习能力会起到很好的作用。

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。


“综述专栏”历史文章


更多综述专栏文章,

请点击文章底部“阅读原文”查看



分享、点赞、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存