查看原文
其他

矩阵视角下的Transformer详解(附代码)

孙裕道 PaperWeekly 2022-07-06


©PaperWeekly 原创 · 作者 | 孙裕道

单位 | 北京邮电大学博士生

研究方向 | GAN图像生成、情绪对抗样本生成




引言

Transformer 模型是 Google 团队在 2017 年 6 月由 Ashish Vaswani 等人在论文《Attention Is All You Need》所提出,当前它已经成为 NLP 领域中的首选模型。Transformer 抛弃了 RNN 的顺序结构,采用了 Self-Attention 机制,使得模型可以并行化训练,而且能够充分利用训练资料的全局信息,加入 Transformer 的 Seq2seq 模型在 NLP 的各个任务上都有了显著的提升。本文从矩阵视角下做了大量的图示目的是能够更加清晰地讲解 Transforme 的运行原理,以及相关组件的操作细节,文末还有完整可运行的代码示例。


注意力机制
Transformer 中的核心机制就是 Self-Attention。Self-Attention 机制的本质来自于人类视觉注意力机制。当人视觉在感知东西时候往往会更加关注某个场景中显著性的物体,为了合理利用有限的视觉信息处理资源,人需要选择视觉区域中的特定部分,然后集中关注它。注意力机制主要目的就是对输入进行注意力权重的分配,即决定需要关注输入的哪部分,并对其分配有限的信息处理资源给重要的部分。
2.1 Self-Attention
Self-Attention 工作原理如上图所示,给定输入 word embedding 向量 ,然后对于输入向量 通过矩阵 进行线性变换得到 向量 向量 ,以及 向量 ,即:

如果令矩阵 ,,,则此时则有:

接着再利用得到的 Query 向量和 Key 向量计算注意力得分,论文中采用的注意力计算公式为点积缩放公式:

论文中假定 向量 的元素和 Query 向量 的元素独立同分布,且令均值为 ,方差为 ,则此时注意力向量 的第 个分量  的均值为 ,方差 具体的计算公式如下:

令注意力分数矩阵 ,则有:

注意分数向量 经过 层得到归一化后的注意力分布 ,即为:
最后利用得到的注意力分布向量 矩阵 获得最后的输出 ,则有:

令输出矩阵 ,则有:

2.2 Multi-Head Attention

Multi-Head Attention 的工作原理与 Self-Attention 的工作原理非常类似。为了方便图解可视化将 Multi-Head 设置为 2-Head,如果 Multi-Head 设置为 8-Head,则上图的 的下一步的分支数为
给定输入 word embedding 向量 ,然后对于输入向量 通过矩阵 进行第一次线性变换得到 Query 向量 ,Key向量 ,以及 Value 向量
然后再对 Query 向量 通过矩阵 进行第二次线性变换得到 ,同理对 Key 向量 通过矩阵 进行第二次线性变换得到 ,对 Value 向量 通过矩阵 进行第二次线性变换得到 ,具体的计算公式如下所示:
令矩阵:
此时则有:

对于每个 Head 利用得到对于 Query 向量和 Key 向量计算对应的注意力得分,其中注意力向量 的第 个分量的计算公式为:

令注意力分数矩阵 ,则有:

注意分数向量 经过 softmax 层得到归一化后的注意力分布 ,即为:
对于每一个 Head 利用得到的注意力分布向量 和 Value 矩阵 获得最后的输出 ,则有:

两个 Head 的 的向量按照如下方式拼接在一起,则有:

给定参数矩阵 ,则输出矩阵为:

综上所述则有:

2.3 Mask Self-Attention

如下图左半部分所示,Self-Attention 的输出向量 综合了输入向量 的全部信息,由此可见,Self-Attention 在实际编程中支持并行运算。如下图右半部分所示,Mask Self-Attention 的输出向量 只利用了已知部分输入的向量 的信息。例如, 只是与 有关; 有关; 有关; 有关。Mask Self-Attention 在 Transformer 中被用到过两次。
  • Transformer 的 Encoder 中如果输入一句话的 word 长度小于指定的长度,为了能够让长度一致往往会用 0 进行填充,此时则需要用 Mask Self-Attention 来计算注意力分布。
  • Transformer 的 Decoder 的输出是有时序关系的,当前的输出只与之前的输入有关,所以此时算注意力分布时需要用到 Mask Self-Attention。




Transformer模型

以上对 Transformer 中的核心内容即自注意力机制进行了详细解剖,接下来会对 Transformer 模型架构进行介绍。Transformer 模型是由 Encoder 和 Decoder 两个模块组成,具体的示意图如下所示,为了能够对 Transformer 内部的操作细节进行更清晰的展示,下图以矩阵运算的视角对 Transformer 的原理进行讲解。

Encoder 模块操作的具体流程如下所示:

Encoder 的输入由两部分组成分别是词编码矩阵 和位置编码矩阵 ,其中 表示句子数目, 表示一句话单词的最大数目, 表示的是词向量的维度。位置编码矩阵 表示的是每个单词在一句里的所有位置信息,因为 Self-Attention 计算注意力分布的时候只能给出输出向量和输入向量之间的权重关系,但是不能给出词在一句话里的位置信息,所以需要在输入里引入位置编码矩阵 。位置编码向量生成方法有很多。一种比较简单粗暴的方式就是根据单词在句子中的位置生成一个 one-hot 的位置编码;还有的方法是将位置编码当成参数进行训练学习;在该论文里是利用三角函数对位置进行编码,具体的公式如下所示:

其中 表示的是位置编码向量, 表示词在句子中的位置, 表示编码向量的位置索引。
输入矩阵 通过线性变换生成矩阵 。在实际编程中是将输入 直接赋值给 。如果输入单词长度小于最大长度并 来填充的时候,还要相应引入 Mask 矩阵。
将矩阵 输入到 Multi-Head Attention 模块中进行注意分布的计算得到矩阵 ,计算公式为:

具体的计算细节参考上文关于 Multi-Head Attention 原理的讲解不在这里赘述。然后将原始输入 与注意力分布 进行残差计算得到输出矩阵
对矩阵 进行层归一化操作得到 ,具体的计算公式为:

输入到全连接神经网络中得到 ,然后再让全连接神经网络的输入 与输出 进行残差计算得到 ,接着对 进行层归一化操作。
以上是一个 Block 的操作原理,将 个 Block 进行堆叠就组成了 Encoder 的模块,得到的最后输出为 。这里需要注意的是 Encoder 模块中的各个组件的操作顺序并不是固定的,也可以先进行归一化操作,然后再计算注意力分布,再归一化,再预测等。

Decoder 模块操作的具体流程如下所示:

Decoder 的输入也由两部分组成分别是词编码矩阵 和位置编码矩阵 。因为 Decoder 的输入是具有时顺序关系的(即上一步的输出为当前步输入)所以还需要输入 Mask 矩阵 以便计算注意力分布。
输入矩阵 通过线性变换生成矩阵 。在实际编程中是将输入 直接赋值给 。如果输入单词长度小于最大长度并 0 来填充的时候,还要相应引入 Mask 矩阵。
将矩阵 以及 Mask 矩阵 输入到 Mask Multi-Head Attention 模块中进行注意分布的计算得到矩阵 ,计算公式为:

具体的计算细节参考上文关于 Mask Self-Attention 的讲解不在这里赘述。然后将原始输入 与注意力分布 进行残差计算得到输出矩阵
接着再对矩阵 进行层归一化操作得到
Encoder 的输出 通过线性变换得到 进行线性变换得到 ,利用矩阵 进行交叉注意力分布的计算得到 ,计算公式为:

这里的交叉注意力分布综合 Encoder 输出结果和 Decoder 中间结果的信息。实际编程编程中将 直接赋值给 直接赋值给 。然后将 与注意力分布 进行残差计算得到输出矩阵
接着对 进行层归一操作得到 ,再将 输入到全连接神经网络中得到 ,接着再做一步残差操作得到 ,最后再进行一层归一化操作。
以上是一个 Block 的操作原理,将 个 Block 进行堆叠就组成了 Decoder 的模块,得到的输出为 。然后在词汇字典中找到当前预测最大概率的单词,并将该单词词向量作为下一阶段的输入,重复以上步骤,直到输出“end”字符为止。


程序代码
Transformer 具体的代码示例如下所示。根据上文中 Multi-Head Attention 原理示例图可知,严格来看 Multi-Head Attention 在求注意分布的时候中间其实是有两步线性变换。给定输入向量 第一步线性变换直接让向量 赋值给 ,这一过程以下程序中有所体现,在这里并不会产生歧义。
第二步线性变换产生多 Head,假设 的时候,按理说 要与 个矩阵 进行线性变换得到 ,同理 要与 个矩阵 进行线性变换得到 要与 个矩阵 进行线性变换得到 ,如果按照这个方式在程序实现则需要定义 24 个权重矩阵,非常的麻烦。
以下程序中有一个简单的权重定义方法,通过该方法也可以实现以上多Head的线性变换,以向量 为例:
首先将向量 进行截断分成 个向量,即为:

其中 的第 个截断向量, 是单位矩阵, 是零矩阵。
然后对 用相同的权重矩阵 进行线性变换,此时可以发现,训练过程的时候只需要更新权重矩阵 即可,而且可以进行多 Head 线性变换, 个权重矩阵可以表示为:

其中权重矩阵 。

    import torch
    import torch.nn as nn
    import os

    class SelfAttention(nn.Module):
        def __init__(self, embed_size, heads):
            super(SelfAttention, self).__init__()
            self.embed_size = embed_size
            self.heads = heads
            self.head_dim = embed_size // heads

            assert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"

            self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

        def forward(self, values, keys, query, mask):
            N =query.shape[0]
            value_len , key_len , query_len = values.shape[1], keys.shape[1], query.shape[1]

            # split embedding into self.heads pieces
            values = values.reshape(N, value_len, self.heads, self.head_dim)
            keys = keys.reshape(N, key_len, self.heads, self.head_dim)
            queries = query.reshape(N, query_len, self.heads, self.head_dim)

            values = self.values(values)
            keys = self.keys(keys)
            queries = self.queries(queries)

            energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
            # queries shape: (N, query_len, heads, heads_dim)
            # keys shape : (N, key_len, heads, heads_dim)
            # energy shape: (N, heads, query_len, key_len)

            if mask is not None:
                energy = energy.masked_fill(mask == 0, float("-1e20"))

            attention = torch.softmax(energy/ (self.embed_size ** (1/2)), dim=3)

            out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)
            # attention shape: (N, heads, query_len, key_len)
            # values shape: (N, value_len, heads, heads_dim)
            # (N, query_len, heads, head_dim)

            out = self.fc_out(out)
            return out


    class TransformerBlock(nn.Module):
        def __init__(self, embed_size, heads, dropout, forward_expansion):
            super(TransformerBlock, self).__init__()
            self.attention = SelfAttention(embed_size, heads)
            self.norm1 = nn.LayerNorm(embed_size)
            self.norm2 = nn.LayerNorm(embed_size)

            self.feed_forward = nn.Sequential(
                nn.Linear(embed_size, forward_expansion*embed_size),
                nn.ReLU(),
                nn.Linear(forward_expansion*embed_size, embed_size)
            )
            self.dropout = nn.Dropout(dropout)

        def forward(self, value, key, query, mask):
            attention = self.attention(value, key, query, mask)

            x = self.dropout(self.norm1(attention + query))
            forward = self.feed_forward(x)
            out = self.dropout(self.norm2(forward + x))
            return out


    class Encoder(nn.Module):
        def __init__(
                self,
                src_vocab_size,
                embed_size,
                num_layers,
                heads,
                device,
                forward_expansion,
                dropout,
                max_length,
            ):

            super(Encoder, self).__init__()
            self.embed_size = embed_size
            self.device = device
            self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
            self.position_embedding = nn.Embedding(max_length, embed_size)

            self.layers = nn.ModuleList(
                [
                    TransformerBlock(
                        embed_size,
                        heads,
                        dropout=dropout,
                        forward_expansion=forward_expansion,
                        )
                    for _ in range(num_layers)]
            )
            self.dropout = nn.Dropout(dropout)


        def forward(self, x, mask):
            N, seq_length = x.shape
            positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
            out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
            for layer in self.layers:
                out = layer(out, out, out, mask)

            return out


    class DecoderBlock(nn.Module):
        def __init__(self, embed_size, heads, forward_expansion, dropout, device):
            super(DecoderBlock, self).__init__()
            self.attention = SelfAttention(embed_size, heads)
            self.norm = nn.LayerNorm(embed_size)
            self.transformer_block = TransformerBlock(
                embed_size, heads, dropout, forward_expansion
            )

            self.dropout = nn.Dropout(dropout)

        def forward(self, x, value, key, src_mask, trg_mask):
            attention = self.attention(x, x, x, trg_mask)
            query = self.dropout(self.norm(attention + x))
            out = self.transformer_block(value, key, query, src_mask)
            return out

    class Decoder(nn.Module):
        def __init__(
                self,
                trg_vocab_size,
                embed_size,
                num_layers,
                heads,
                forward_expansion,
                dropout,
                device,
                max_length,
        ):

            super(Decoder, self).__init__()
            self.device = device
            self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
            self.position_embedding = nn.Embedding(max_length, embed_size)
            self.layers = nn.ModuleList(
                [DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
                for _ in range(num_layers)]
                )
            self.fc_out = nn.Linear(embed_size, trg_vocab_size)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x ,enc_out , src_mask, trg_mask):
            N, seq_length = x.shape
            positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
            x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))

            for layer in self.layers:
                x = layer(x, enc_out, enc_out, src_mask, trg_mask)

            out =self.fc_out(x)
            return out


    class Transformer(nn.Module):
        def __init__(
                self,
                src_vocab_size,
                trg_vocab_size,
                src_pad_idx,
                trg_pad_idx,
                embed_size = 256,
                num_layers = 6,
                forward_expansion = 4,
                heads = 8,
                dropout = 0,
                device="cuda",
                max_length=100
            ):

            super(Transformer, self).__init__()
            self.encoder = Encoder(
                src_vocab_size,
                embed_size,
                num_layers,
                heads,
                device,
                forward_expansion,
                dropout,
                max_length
                )
            self.decoder = Decoder(
                trg_vocab_size,
                embed_size,
                num_layers,
                heads,
                forward_expansion,
                dropout,
                device,
                max_length
                )


            self.src_pad_idx = src_pad_idx
            self.trg_pad_idx = trg_pad_idx
            self.device = device


        def make_src_mask(self, src):
            src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
            # (N, 1, 1, src_len)
            return src_mask.to(self.device)

        def make_trg_mask(self, trg):
            N, trg_len = trg.shape
            trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
                N, 1, trg_len, trg_len
            )
            return trg_mask.to(self.device)

        def forward(self, src, trg):
            src_mask = self.make_src_mask(src)
            trg_mask = self.make_trg_mask(trg)
            enc_src = self.encoder(src, src_mask)
            out = self.decoder(trg, enc_src, src_mask, trg_mask)
            return out


    if __name__ == '__main__':
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(device)
        x = torch.tensor([[1,5,6,4,3,9,5,2,0],[1,8,7,3,4,5,6,7,2]]).to(device)
        trg = torch.tensor([[1,7,4,3,5,9,2,0],[1,5,6,2,4,7,6,2]]).to(device)

        src_pad_idx = 0
        trg_pad_idx = 0
        src_vocab_size = 10
        trg_vocab_size = 10
        model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(device)
        out = model(x, trg[:, : -1])
        print(out.shape)


更多阅读




#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编




🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧



·

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存