精度更高,速度更快!锚点 DETR:基于 transformer 目标检测的查询设计(AAAI 2022)
The following article is from 旷视研究院 Author R
关注公众号,发现CV技术之美
本文转自旷视研究院。
● 简介 ●
近年来,以 DETR[1]为代表的基于 transformer 的端到端目标检测算法开始广受大家的关注。这类方法通过一组目标查询来推理物体与图像上下文的关系从而得到最终预测结果,且不需要 NMS 后处理,成为了一种目标检测的新范式。
首先,DETR 解码器的目标查询是一组可学习的向量。这组向量人类难以解释,没有显式的物理意义。同时,目标查询对应的预测结果的分布也没有明显的规律,这也导致模型较难优化。
● Attention 回顾 ●
这里 Q、K 和 V 分别为查询、键和值,下标 f 和 P 分别表示特征和位置编码向量,标量 为特征的维度。实际上,Q、K 和 V 还会分别经过一个全连接层,这里为了简洁省略了这部分。
DETR 的解码器包含两种 attention,一种是 self-attention,另一种是 cross-attention。
在 self-attention 中, 和 与 一样, 与 一样。其中 由上一个解码器层的输出得到,第一个解码器层的 被初始化为一个常数向量,如零向量;而 设置为一组可学的向量,为解码器中所有的 共享:
在 cross-attention 中, 由之前的 self-attention 的输出得到;而 和 是编码器的输出特征; 是编码器输出特征对应的位置编码向量,DETR 采用了正余弦函数来作为位置编码函数,我们将该位置编码函数记作 ,若编码器特征对应的位置记作 ,那么: 在此解释一下,H, W, C 分别是特征的高、宽和通道数目,而 是预设的目标查询数目。
● 查询设计 ●
通常我们把解码器中的 认作是目标查询,这是因为它负责分辨不同的物体(解码器中的初始 为零向量没有分辨能力)。
如前文所述,DETR 中的目标查询 是一组可学向量,其难以解释且没有显式的物理意义。观察这些目标查询对应的预测结果的分布,如图 1 所示,每个方格中的点表示一个目标查询对应的所有图像预测结果的中心点,可以看到,每个查询都负责非常大的范围,且导致负责的区域有很大的重叠,这种模糊性也使得网络很难优化。
图 1
在基于 CNN 的检测算法中,锚点通常都是特征网格点的坐标。而在本文中,锚点可以更加灵活。可以使用预设的网格位置的锚点,也可以是一组可以随网络学习的位置点。如图 2 所示,我们发现最终学习到的锚点分布与网格点较为相似,都是趋于均匀分布在整个图像上。这可能是因为在整个图像集中,图像的各个位置都会出现物体。
图 2
记锚点为 ,其表示有 个锚点,每个锚点记录点的(x,y)坐标。那么,基于锚点的目标查询则是: 即目标查询为锚点坐标的编码。那么如何选择位置编码函数呢?最自然地,本文选择与键特征共享一样的位置编码函数: 其中,g 为位置编码函数,它可以是前述的 ,也可以是其它的形式。在本文中我们对启发式的 额外加入了两个全连接层以更好地调整它。
更进一步考虑,有时一个位置可能会出现多个物体。显然,若一个锚点仅能预测一个物体的话,那么该位置的其它物体则需要其它位置的锚点来一同预测。这导致每个锚点负责的区域扩大,增加了其位置模糊性。为了解决这个问题,本文对每个锚点加入多种模式,使其可以有多个预测。
回顾 DETR,其中初始的查询特征为 ,对于 个目标查询来说,每个都只有一种模式 ,其中
因此,本文为每个目标查询设置多种模式 ,其中 为模式的数目,是一个较小的值,如 =3。具体而言,本文使用一组可学向量 作为目标查询的多种模式。考虑移动不变性,我们希望这些模式与位置无关,因此让各个锚点共享多种模式。如此,我们便可得到增广的初始查询特征 和查询位置编码 。
观察改进后的目标查询对应的预测结果的分布,如图 3 所示,其中最后一行为锚点,前三行是对应锚点的三种模式的预测,可以看到,基于锚点的查询将关注锚点附近的区域,查询对应的预测框中心点都分布在锚点周围。此时查询不需要预测离对应锚点很远的物体,因此其具体特定的语义,从而模型将更容易优化。
图 3
● Attention 变种 ●
目前许多的 attention 变种,如 Deformable DETR[2]、Efficient Attention[3]等,都可以大幅度降低 transformer 占用的显存。然而,也许是由于 DETR 类方法中 transformer 解码器的 cross attention 较难,若使用同样的特征,这些方法将会导致一定程度上的精度降低。
本文提出了一种行列特征解耦的 attention 变种(Row-Column Decoupled Attention, RCDA),将键特征解耦为列特征和行特征,再依次进行列 attention 和行 attention。该方法不仅可以降低显存消耗,还可以得到和原先的标准 attention 相似或者更高的精度。
首先,对于键特征 ,先将其解耦为行特征 和列特征 ,本文采用的解耦方式为分别沿着列和行做均值。
在之前的表述中,我们不失一般的假设 Attention 头的数目为 1 以更加简洁,现在我们设其为 M。在标准的 Attention 中,注意力图 为主要的显存占用瓶颈,而在行列解耦的attention中,行列注意力图 和 的显存远小于标准 attention 中的注意力图。
由于特征的通道数目 C 通常大于 M,RCDA 的中间结果 Z 的显存占用要大于行列注意力图,因此我们主要比较 RCDA 的中间结果 与标准Attention 中注意力图 之间的关系。显然,随着图像特征分辨率的增大(H 与 W 增大),标准 attention 的显存占用增长得更快。
行列解耦 attention 较标准 attention 可以节省显存的倍数为:
在默认的设置中,M=8,C=256,因此当特征长边 H 大于 32 时,RCDA 可以节省显存。在目标检测任务中,特征边长 32 是 C5 特征的一个典型值,因此使用 C5 特征显存占用相差不大,使用更大的 C4 特征显存可省 2 倍,依次类推。
● 总体流程 ●
算法的总体流程如图 5 所示,首先通过 CNN 网络提取图像特征,然后再经过transformer 编码器通过 self attention 处理图像特征,输出的图像特征将作为解码器的键和值特征。解码器的查询为前文所述基于锚点的多模式查询,在解码器中,各个查询分别根据注意力图聚合感兴趣的图像特征,最后输出最终的预测结果。预测框的中心点预测相对锚点的偏移量,而框的大小则预测其相对图像的大小。编码器和解码器中的 attention 可以采用标准的 attention,也可以采用本文所述的行列解耦 attention。对于 attention 中各特征的位置编码,则依据其位置使用共享的位置编码函数得到。
图 5
如表 2 所示,我们比较了本文算法与其它一些算法的性能比较,默认的骨干网络为 ResNet50。可以看到本文算法可以到达较好的性能,且继承了 DETR 无需手工设计锚框、无需 NMS 后处理,且不涉及随机内存访问的优秀性质。
表 2
举个例子,假设有个人(专用计算芯片)力气很大(算力很强),他可以轻松地把一叠共 1000 张纸搬到指定的地方(计算处理某个张量)。而假如让他取出其中的第 123 张和第 234 张纸搬到指定的地方,需要搬的纸虽然少了很多(计算量大幅度降低),但是由于需要找到这些指定的纸(随机内存访问),可能会更加费时(访存代价增加)。
通常来说,两阶段的检测算法由于感兴趣的区域(RoI)的坐标对硬件来说随机的,提取感兴趣区域的特征会涉及到随机内存访问。而 Deformable DETR 也涉及到提取特定坐标的特征的情况,因此也非 RAM-free。
表 3
如表 4 所示,我们比较了使用不同数目的锚点(anchor points)和模式(patterns)。100 个锚点数目过少性能较低,而 900 个锚点性能与 300 个锚点相差仅 0.3,因此我们默认使用 300 个锚点。可以看到,为每个锚点设置多种模式,性能会有明显的提升。另外,当预测结果的数目一致时,即保持锚点数目乘以模式数目的值不变时,多种模式的性能也比一种模式效果更好,这说明了多种模式的提升并非是因为预测的数目增加,而是本质更好。
表 4
[1] Carion N, Massa F, Synnaeve G, et al. End-to-end object detection with transformers[C]//European conference on computer vision. Springer, Cham, 2020: 213-229.
[2] Zhu X, Su W, Lu L, et al. Deformable detr: Deformable transformers for end-to-end object detection[J]. arXiv preprint arXiv:2010.04159, 2020.
[3] Shen Z, Zhang M, Zhao H, et al. Efficient attention: Attention with linear complexities[C]//Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision. 2021: 3531-3539.
END
欢迎加入「目标检测」交流群👇备注:OD