查看原文
其他

【源头活水】视觉Transformer中的位置嵌入

“问渠那得清如许,为有源头活水来”,通过前沿领域知识的学习,从其他研究领域得到启发,对研究问题的本质有更清晰的认识和理解,是自我提高的不竭源泉。为此,我们特别精选论文阅读笔记,开辟“源头活水”专栏,帮助你广泛而深入的阅读科研文献,敬请关注。

作者:知乎—Ganso

地址:https://zhuanlan.zhihu.com/p/368188333

Transformer对于每个序列中的每个token生成键值对(query,key,value),相当于对序列不同位置做了Self-attention。但是这个过程是没有有效利用到位置信息的。


01

位置信息

改变输入顺序

如上图,如果改变x1、x2的位置。那么由于  ,所以左边x1对应的输出b1应该是和右边b1是一样的。也就是x1、x2的相对位置关系没有影响到对应位置输出结果。也就是输出对输入的位置无关性(permutation-invarian)。
而在nlp等任务中,我们应该考虑到token本身的位置信息,比如
从 上海 到 南京
从 南京 到 上海
也就是词汇本身的相对位置会影响到词汇的释意。
Transformer通过对每个位置设定一个位置编码来解决这个问题。


02

位置编码
那么为了解决这一问题,我们需要对于输入的每个向量加上独立的位置编码。

李宏毅 Transformer课件

位置编码的一个比较经典做法就是用one-hot向量。也就是在原始数据后面append一个基于位置的one-hot向量来使得同一数据在不同位置出现差异化。而在Transfomer中,作者选择在ai加上一个位置编码ei来做到这一点。
其实这两种方法是等效的,如上图所示,对于每个xi append 一个pi等价于对于原始数据矩阵append了一个单位矩阵E,而embedding矩阵w append了一个矩阵wp。在计算矩阵乘法的时候可以拆开,上述wi与xi的积就是ai,而wp与pi的结果是ei。


03

Sinusoidal位置编码
《Attention is All You Need》中使用了Sinusoidal位置编码,由于三角函数的性质:
所以这种位置编码可以很好地表达相对位置关系。


04

实现
参考博客:http://jalammar.github.io/illustrated-transformer/
这篇博客最早的可视化位置信号图片是这样的,也是流传的比较广的一个版本
参考了Tensor2Tensor(https://github.com/tensorflow/tensor2tensor)这个库里面的函数get_timing_signal_1d()。不过在Transformer论文里面有所改变,不是直接concatenate两个信号(sin、cos),而是去交错两个信号来得到最终的位置编码。
代码如下:
import numpy as npimport matplotlib.pyplot as plt
# Code from https://www.tensorflow.org/tutorials/text/transformerdef get_angles(pos, i, d_model): angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model)) return pos * angle_rates
def positional_encoding(position, d_model): angle_rads = get_angles(np.arange(position)[:, np.newaxis], np.arange(d_model)[np.newaxis, :], d_model) # apply sin to even indices in the array; 2i angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2]) # apply cos to odd indices in the array; 2i+1 angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2]) pos_encoding = angle_rads[np.newaxis, ...] return pos_encoding
tokens = 10dimensions = 64
pos_encoding = positional_encoding(tokens, dimensions)print (pos_encoding.shape)
plt.figure(figsize=(12,8))plt.pcolormesh(pos_encoding[0], cmap='viridis')plt.xlabel('Embedding Dimensions')plt.xlim((0, dimensions))plt.ylim((tokens,0))plt.ylabel('Token Position')plt.colorbar()plt.show()
画出图片


05

CNN中的位置编码
手工设计的位置编码不好处理变长数据(图片)
《How Much Position Information Do Convolutional Neural Networks Encode?》这篇文章的观点表面CNN中的zero padding操作可以获得图片中的位置信息。目前部分Transformer 也有使用这一点来实现位置编码(CvT CPVT CeiT)。

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。


“源头活水”历史文章


更多源头活水专栏文章,

请点击文章底部“阅读原文”查看



分享、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存