查看原文
其他

Seq2Seq、SeqGAN、Transformer…你都掌握了吗?一文总结文本生成必备经典模型(一)

机器之心SOTA模型 计算机视觉研究院 2023-01-18

 

2023

点击蓝字 关注我们

关注并星标

从此不迷路

计算机视觉研究院

计算机视觉研究院专栏

作者:Edison_G

本专栏将逐一盘点自然语言处理、计算机视觉等领域下的常见任务,并对在这些任务上取得过 SOTA 的经典模型逐一详解。前往 SOTA!模型资源站(sota.jiqizhixin.com)即可获取本文中包含的模型实现代码、预训练模型及 API 等资源。
公众号ID|ComputerVisionGzq学习群|扫码在主页获取加入方式转自《机器之心》本期收录模型速览


图1. 模型读取一个输入句子 "ABC "并生成 "WXYZ "作为输出句子。该模型在输出句末标记后停止预测。请注意,LSTM是反向读取输入句子的,因为这样做在数据中引入了许多短期的依赖关系,使优化问题更加容易


项目SOTA!平台项目详情页
Seq2Seq(RNN)前往 SOTA!模型平台获取实现资源:https://sota.jiqizhixin.com/project/seq2seq


 


图2.  RNN Encoder–Decoder 架构


h_j的实际激活计算为:



在这种表述中,当复位门接近0时,隐藏状态被强制忽略之前的隐藏状态,只用当前的输入进行复位。这有效地允许隐藏状态放弃任何在未来发现不相关的信息,因此,允许一个更紧凑的表述。


当前 SOTA!平台收录 Seq2Seq(LSTM) 共 2 个模型实现资源,支持的主流框架包含 PyTorch等。

项目SOTA!平台项目详情页
Seq2Seq(LSTM)前往 SOTA!模型平台获取实现资源:https://sota.jiqizhixin.com/project/seq2seq-lstm




图4. 模型在给定的源句(x_1, x_2, ..., x_T)中生成第t个目标词y_t

项目SOTA!平台项目详情页
Seq2Seq+Attention前往 SOTA!模型平台获取实现资源:https://sota.jiqizhixin.com/project/seq2seq-attention




随机初始化G网络和D网络参数;通过MLE预训练G网络,目的是提高G网络的搜索效率;通过G网络生成部分负样预训练D网络;通过G网络生成sequence用D网络去评判,得到reward:




根据上式(4)计算得到每个action选择得到的奖励并求得累积奖励的期望,以此为loss function,并求导对网络进行梯度更新。其中,下式是标准的D网络误差函数,训练目标是最大化识别真实样本的概率,最小化误识别伪造样本的概率:



最后,GAN网络的误差函数如上,循环以上过程直至收敛。


当前 SOTA!平台收录 SeqGAN 共 22 个模型实现资源,支持的主流框架包含 PyTorch、TensorFlow 等。

项目SOTA!平台项目详情页
SeqGAN前往 SOTA!模型平台获取实现资源:https://sota.jiqizhixin.com/project/seqgan

Attention is all you need 


2017 年,Google 机器翻译团队发表的《Attention is All You Need》完全抛弃了RNN和CNN等网络结构,而仅仅采用Attention机制来完成机器翻译任务,并且取得了很好的效果,注意力机制也成为了研究热点。大多数竞争性神经序列转导模型都有一个编码器-解码器结构。编码器将输入的符号表示序列(x1, ..., xn)映射到连续表示的序列z=(z1, ..., zn)。给定z后,解码器每次生成一个元素的符号输出序列(y1, ..., ym)。在每个步骤中,该模型是自动回归的,在生成下一个符号时,将先前生成的符号作为额外的输入。Transformer遵循这一整体架构,在编码器和解码器中都使用了堆叠式自注意力和点式全连接层,分别在图6的左半部和右半部显示。



图6. Transformer架构


编码器。编码器是由N=6个相同的层堆叠而成。每层有两个子层。第一层是一个多头自注意力机制,第二层是一个简单的、按位置排列的全连接前馈网络。在两个子层的每一个周围采用了一个残差连接,然后进行层的归一化。也就是说,每个子层的输出是LayerNorm(x + Sublayer(x)),其中,Sublayer(x)是子层本身实现的函数。为了方便这些残差连接,模型中的所有子层以及嵌入层都会生成尺寸为dmodel=512的输出。

解码器。解码器也是由N=6个相同的层组成的堆栈。除了每个编码器层的两个子层之外,解码器还插入了第三个子层,它对编码器堆栈的输出进行多头注意力。与编码器类似,在每个子层周围采用残差连接,然后进行层归一化。进一步修改了解码器堆栈中的自注意力子层,以防止位置关注后续位置。这种masking,再加上输出嵌入偏移一个位置的事实,确保对位置i的预测只取决于小于i的位置的已知输出。

Attention。注意力函数可以描述为将一个查询和一组键值对映射到一个输出,其中,查询、键、值和输出都是向量。输出被计算为值的加权和,其中分配给每个值的权重是由查询与相应的键的兼容性函数计算的。在Transformer中使用的Attention是Scaled Dot-Product Attention, 是归一化的点乘Attention,假设输入的query q 、key维度为dk,value维度为dv , 那么就计算query和每个key的点乘操作,并除以dk ,然后应用Softmax函数计算权重。Scaled Dot-Product Attention的示意图如图7(左)。


图7. (左)按比例的点乘法注意力。(右)多头注意力由几个平行运行的注意力层组成


如果只对Q、K、V做一次这样的权重操作是不够的,这里提出了Multi-Head Attention,如图7(右)。具体操作包括:

  1. 首先对Q、K、V做一次线性映射,将输入维度均为dmodel 的Q、K、V 矩阵映射到Q∈Rm×dk,K∈Rm×dk,V∈Rm×dv;

  2. 然后在采用Scaled Dot-Product Attention计算出结果;

  3. 多次进行上述两步操作,然后将得到的结果进行合并;

  4. 将合并的结果进行线性变换。

在完整的架构中,有三处Multi-head Attention模块,分别是:

  1. Encoder模块的Self-Attention,在Encoder中,每层的Self-Attention的输入Q=K=V , 都是上一层的输出。Encoder中的每个位置都能够获取到前一层的所有位置的输出。

  2. Decoder模块的Mask Self-Attention,在Decoder中,每个位置只能获取到之前位置的信息,因此需要做mask,其设置为−∞。

  3. Encoder-Decoder之间的Attention,其中Q 来自于之前的Decoder层输出,K、V 来自于encoder的输出,这样decoder的每个位置都能够获取到输入序列的所有位置信息。

在进行了Attention操作之后,encoder和decoder中的每一层都包含了一个全连接前向网络,对每个位置的向量分别进行相同的操作,包括两个线性变换和一个ReLU激活输出:



因为模型不包括recurrence/convolution,因此是无法捕捉到序列顺序信息的,例如将K、V按行进行打乱,那么Attention之后的结果是一样的。但是序列信息非常重要,代表着全局的结构,因此必须将序列的token相对或者绝对位置信息利用起来。这里每个token的position embedding 向量维度也是dmodel=512, 然后将原本的input embedding和position embedding加起来组成最终的embedding作为encoder/decoder的输入。其中,position embedding计算公式如下:



其中,pos表征位置,i表征维度。也就是说,位置编码的每个维度对应于一个正弦波。波长形成一个从2π到10000-2π的几何级数。选择这个函数是因为假设它可以让模型很容易地学会通过相对位置来参加,因为对于任何固定的偏移量k,PE_pos+k可以表示为PE_pos的线性函数。


当前 SOTA!平台收录 Transformer 共 9 个模型实现资源,支持的主流框架包含 TensorFlow、PyTorch等。

项目SOTA!平台项目详情页
Transformer前往 SOTA!模型平台获取实现资源:https://sota.jiqizhixin.com/project/transformer-2

前往 SOTA!模型资源站(sota.jiqizhixin.com)即可获取本文中包含的模型实现代码、预训练模型及API等资源。 

网页端访问:在浏览器地址栏输入新版站点地址 sota.jiqizhixin.com ,即可前往「SOTA!模型」平台,查看关注的模型是否有新资源收录。 

移动端访问:在微信移动端中搜索服务号名称「机器之心SOTA模型」或 ID 「sotaai」,关注 SOTA!模型服务号,即可通过服务号底部菜单栏使用平台功能,更有最新AI技术、开发资源及社区动态定期推送。

© THE END 

转载请联系本公众号获得授权


计算机视觉研究院学习群等你加入!


计算机视觉研究院主要涉及深度学习领域,主要致力于人脸检测、人脸识别,多目标检测、目标跟踪、图像分割等研究方向。研究院接下来会不断分享最新的论文算法新框架,我们这次改革不同点就是,我们要着重”研究“。之后我们会针对相应领域分享实践过程,让大家真正体会摆脱理论的真实场景,培养爱动手编程爱动脑思考的习惯!

扫码关注

计算机视觉研究院

公众号ID|ComputerVisionGzq

学习群|扫码在主页获取加入方式


 往期推荐 

🔗

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存