查看原文
其他

Pytorch 中如何处理 RNN 输入变长序列 padding

2018-03-18 忆臻 AI研习社

本文作者忆臻,原载于知乎专栏 —— 机器学习乱发与自然语言处理。

  一、为什么 RNN 需要处理变长输入

假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示:

思路比较简单,但是当我们进行 batch 个训练数据一起计算的时候,我们会遇到多个训练样例长度不同的情况,这样我们就会很自然的进行 padding,将短句子 padding 为跟最长的句子一样。

比如向下图这样:

但是这会有一个问题,什么问题呢?比如上图,句子 “Yes” 只有一个单词,但是 padding 了 5 的 pad 符号,这样会导致 LSTM 对它的表示通过了非常多无用的字符,这样得到的句子表示就会有误差,更直观的如下图:

那么我们正确的做法应该是怎么样呢?

这就引出 pytorch 中 RNN 需要处理变长输入的需求了。在上面这个例子,我们想要得到的表示仅仅是 LSTM 过完单词 "Yes" 之后的表示,而不是通过了多个无用的 “Pad” 得到的表示:如下图:

  二、pytorch 中 RNN 如何处理变长 padding

主要是用函数 torch.nn.utils.rnn.pack_padded_sequence() 以及 torch.nn.utils.rnn.pad_packed_sequence() 来进行的, 分别来看看这两个函数的用法。


这里的 pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)

输入的形状可以是 (T×B×*)。T 是最长序列长度,B 是 batch size,* 代表任意维度 (可以是 0)。如果 batch_first=True 的话,那么相应的 input size 就是 (B×T×*)。

Variable 中保存的序列,应该按序列长度的长短排序,长的在前,短的在后(特别注意需要进行排序)。即 input[:,0] 代表的是最长的序列,input[:, B-1] 保存的是最短的序列。

参数说明:

input (Variable) – 变长序列 被填充后的 batch

lengths (list[int]) – Variable 中 每个序列的长度。(知道了每个序列的长度,才能知道每个序列处理到多长停止)

batch_first (bool, optional) – 如果是 True,input 的形状应该是 B*T*size。

返回值:

一个 PackedSequence 对象。一个 PackedSequence 表示如下所示:

具体代码如下:

embed_input_x_packed = pack_padded_sequence(embed_input_x, sentence_lens, batch_first=True)
encoder_outputs_packed, (h_last, c_last) = self.lstm(embed_input_x_packed)

此时,返回的 h_last 和 c_last 就是剔除 padding 字符后的 hidden state 和 cell state,都是 Variable 类型的。代表的意思如下(各个句子的表示,lstm 只会作用到它实际长度的句子,而不是通过无用的 padding 字符,下图用红色的打钩来表示):

但是返回的 output 是 PackedSequence 类型的,可以使用:

encoder_outputs, _ = pad_packed_sequence(encoder_outputs_packed, batch_first=True)

将 encoderoutputs 在转换为 Variable 类型,得到的_代表各个句子的长度。

  三、总结

这样综上所述,RNN 在处理类似变长的句子序列的时候,我们就可以配套使用 torch.nn.utils.rnn.pack_padded_sequence() 以及 torch.nn.utils.rnn.pad_packed_sequence() 来避免 padding 对句子表示的影响

参考:

pytorch 对可变长度序列的处理 http://suo.im/2c3XOA 

pytorch RNN 变长输入 padding http://suo.im/2xkED8 

限时拼团

3 大模块,30 个课时

高校数学系教授带班

100% 学员好评

与 100 + 同学一起夯实数学基础,走稳机器学习入门第一步!

▼▼▼





新人福利




关注 AI 研习社(okweiwu),回复  1  领取

【超过 1000G 神经网络 / AI / 大数据,教程,论文】



PyTorch 合辑

▼▼▼

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存