Attention原理及TensorFlow AttentionWrapper源码解析
Python3网络爬虫精华实战视频教程
点击上图立即了解学习限时优惠价308元
作者:崔庆才,Python技术控,爬虫博文访问量已过百万。喜欢钻研,热爱生活,乐于分享。《Python3网络爬虫开发实战》书籍作者。
个人博客:静觅 | http://cuiqingcai.com
本节来详细说明一下 Seq2Seq 模型中一个非常有用的 Attention 的机制,并结合 TensorFlow 中的 AttentionWrapper 来剖析一下其代码实现。
Seq2Seq
首先来简单说明一下 Seq2Seq 模型,如果搞过深度学习,想必一定听说过 Seq2Seq 模型,Seq2Seq 其实就是 Sequence to Sequence,也简称 S2S,也可以称之为 Encoder-Decoder 模型,这个模型的核心就是编码器(Encoder)和解码器(Decoder)组成的,架构雏形是在 2014 年由论文 Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation, Cho et al 提出的,后来 Sequence to Sequence Learning with Neural Networks, Sutskever et al 算是比较正式地提出了 Sequence to Sequence 的架构,后来 Neural Machine Translation by Jointly Learning to Align and Translate, Bahdanau et al 又提出了 Attention 机制,将 Seq2Seq 模型推上神坛,并横扫了非常多的任务,现在也非常广泛地用于机器翻译、对话生成、文本摘要生成等各种任务上,并取得了非常好的效果。
下面的图示意了 Seq2Seq 模型的基本架构:
可以看到图中有一个中间状态c向量,在c向量左侧的我们可以称之为编码器(Encoder),编码器这里示意的是 RNN 序列,另外 RNN 单元还可以使用 LSTM、GRU 等变体, 在编码器下方输入了
另外还有一种变体,c向量在每次解码的时候都会作为解码器的输入,其实原理都是类似的,如图所示:
这种模型架构是通用的,所以它的适用场景也非常广泛。如机器翻译、对话生成、文本摘要、阅读理解、语音识别,也可以用在一些趣味场景中,如诗词生成、对联生成、代码生成、评论生成等等,效果都很不错。
Attention
通过上图我们可以发现,Encoder 把所有的输入序列编码成了一个c向量,然后使用c向量来进行解码,因此,c向量中必须包含了原始序列中的所有信息,所以它的压力其实是很大的,而且由于 RNN 容易把前面的信息“忘记”掉,所以基本的 Seq2Seq 模型,对于较短的输入来说,效果还是可以接受的,但是在输入序列比较长的时候,c向量存不下那么多信息,就会导致生成效果大大折扣。
Attention 机制解决了这个问题,它可以使得在输入文本长的时候精确率也不会有明显下降,它是怎么做的呢?既然一个c向量存不了,那么就引入多个c向量,称之为
这里的每个
还是上面的例子,例如输入信息是“我爱中国”,输出的的理想结果应该是“I love China”,在解码的时候,应该首先需要解码出 “I” 这个字符,这时候会用到
下面我们以 Bahdanau 提出的 Attention 为例来详细剖析一下 Attention 机制。
在没有引入 Attention 之前,Decoder 在某个时刻解码的时候实际上是依赖于三个部分的,首先我们知道 RNN 中,每次输出结果会依赖于隐层和输入,在 Seq2Seq 模型中,还需要依赖于c向量,所以这里我们设在i时刻,解码器解码的内容是
同时
即每次的隐层输出是上一个隐层和上一个输出结果和c向量共同计算得出的。
但是刚才说了,这样会带来一些问题,c 向量不足以包含输入内容的所有信息,尤其是在输入序列特别长的情况下,所以这里我们不再使用一个c向量,而是每一个解码过程对应一个
同时
所以,这里每次解码得出
编码器输出的结果中,
那么
同时
这也就是说,这个权重就是
因此
以上便是整个 Attention 机制的推导过程。
TensorFlow AttentionWrapper
我们了解了基本原理,但真正离程序实现出来其实还是有很大差距的,接下来我们就结合 TensorFlow 框架来了解一下 Attention 的实现机制。
在 TensorFlow 中,Attention 的相关实现代码是在 tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py 文件中,这里面实现了两种 Attention 机制,分别是 BahdanauAttention 和 LuongAttention,其实现论文分别如下:
Neural Machine Translation by Jointly Learning to Align and Translate, Bahdanau, et al
Effective Approaches to Attention-based Neural Machine Translation, Luong, et al
整个 attention_wrapper.py 文件中主要包含几个类,我们主要关注其中几个:
AttentionMechanism、_BaseAttentionMechanism、LuongAttention、BahdanauAttention 实现了 Attention 机制的逻辑。
AttentionMechanism 是 Attention 类的父类,继承了 object 类,内部没有任何实现。
_BaseAttentionMechanism 继承自 AttentionMechanism 类,定义了 Attention 机制的一些公共方法实现和属性。
LuongAttention、BahdanauAttention 均继承 _BaseAttentionMechanism 类,分别实现了上面两篇论文的 Attention 机制。
AttentionWrapperState 用来存储整个计算过程中的 state,和 RNN 中的 state 类似,只不过这里额外还存储了 attention、time 等信息。
AttentionWrapper 主要用于对封装 RNNCell,继承自 RNNCell,封装后依然是 RNNCell 的实例,可以构建一个带有 Attention 机制的 Decoder。
另外还有一些公共方法,例如 hardmax、safe_cumpord 等。
下面我们以 BahdanauAttention 为例来说明 Attention 机制及 AttentionWrapper 的实现。
BathdanauAttention
首先我们来介绍 BahdanauAttention 类的具体原理。
首先我们来看下它的初始化方法:
def __init__(self,
num_units,
memory,
memory_sequence_length=None,
normalize=False,
probability_fn=None,
score_mask_value=None,
dtype=None,
name="BahdanauAttention"):
这里一共接受八个参数,下面一一进行说明:
numunits:神经元节点数,我们知道在计算
的时候,需要使用 和 来进行计算,而二者的维度可能并不是统一的,需要进行变换和统一,所以这里就有了Wa和Ua这两个系数,所以在代码中就是用 num_units 来声明了一个全连接 Dense 网络,用于统一二者的维度,以便于下一步的计算:
query_layer=layers_core.Dense(num_units, name="query_layer", use_bias=False, dtype=dtype)
memory_layer=layers_core.Dense(num_units, name="memory_layer", use_bias=False, dtype=dtype)
这里我们可以看到声明了一个 querylayer 和 memory_layer,分别和
memory:The memory to query; usually the output of an RNN encoder. 即解码时用到的上文信息,维度需要是 [batch_size, max_time, context_dim]。这时我们观察一下父类 _BaseAttentionMechanism 的初始化方法,实现如下:
with ops.name_scope(
name, "BaseAttentionMechanismInit", nest.flatten(memory)):
self._values = _prepare_memory(
memory, memory_sequence_length,
check_inner_dims_defined=check_inner_dims_defined)
self._keys = (
self.memory_layer(self._values) if self.memory_layer
else self._values)
这里通过 _prepare_memory() 方法对 memory 进行处理,然后调用 memory_layer 对 memory 进行全连接维度变换,变换成 [batch_size, max_time, num_units]。
memory_sequence_length:Sequence lengths for the batch entries in memory. 即 memory 变量的长度信息,类似于 dynamic_rnn 中的 sequence_length,被 _prepare_memory() 方法调用处理 memory 变量,进行 mask 操作:
seq_len_mask = array_ops.sequence_mask(
memory_sequence_length,
maxlen=array_ops.shape(nest.flatten(memory)[0])[1],
dtype=nest.flatten(memory)[0].dtype)
seq_len_batch_size = (
memory_sequence_length.shape[0].value
or array_ops.shape(memory_sequence_length)[0])
normalize:Whether to normalize the energy term. 即是否要实现标准化,方法出自论文:Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks, Salimans, et al。
probability_fn:A callable function which converts the score to probabilities. 计算概率时的函数,必须是一个可调用的函数,默认使用 softmax(),还可以指定 hardmax() 等函数。
score_mask_value:The mask value for score before passing into probability_fn. The default is -inf. Only used if memory_sequence_length is not None. 在使用 probability_fn 计算概率之前,对 score 预先进行 mask 使用的值,默认是负无穷。但这个只有在 memory_sequence_length 参数定义的时候有效。
dtype:The data type for the query and memory layers of the attention mechanism. 数据类型,默认是 float32。
name:Name to use when creating ops,自定义名称。
接下来类里面定义了一个 __call__() 方法:
def __call__(self, query, previous_alignments):
with variable_scope.variable_scope(None, "bahdanau_attention", [query]):
processed_query = self.query_layer(query) if self.query_layer else query
score = _bahdanau_score(processed_query, self._keys, self._normalize)
alignments = self._probability_fn(score, previous_alignments)
return alignments
这里首先定义了 processedquery,这里也是通过 query_layer 过了一个全连接网络,将最后一维统一成 num_units,然后调用了 _bahdanau_score() 方法,这个方法是比较重要的,主要用来计算公式中的
def _bahdanau_score(processed_query, keys, normalize):
dtype = processed_query.dtype
# Get the number of hidden units from the trailing dimension of keys
num_units = keys.shape[2].value or array_ops.shape(keys)[2]
# Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting.
processed_query = array_ops.expand_dims(processed_query, 1)
v = variable_scope.get_variable(
"attention_v", [num_units], dtype=dtype)
if normalize:
# Scalar used in weight normalization
g = variable_scope.get_variable(
"attention_g", dtype=dtype,
initializer=math.sqrt((1. / num_units)))
# Bias added prior to the nonlinearity
b = variable_scope.get_variable(
"attention_b", [num_units], dtype=dtype,
initializer=init_ops.zeros_initializer())
# normed_v = g * v / ||v||
normed_v = g * v * math_ops.rsqrt(
math_ops.reduce_sum(math_ops.square(v)))
return math_ops.reduce_sum(normed_v * math_ops.tanh(keys + processed_query + b), [2])
else:
return math_ops.reduce_sum(v * math_ops.tanh(keys + processed_query), [2])
这里其实就是实现了 keys 和 processedquery 的加和,如果指定了 normalize 的话还需要进行额外的 normalize,结果就是公式中的
接下来再回到 __call__() 方法中,这里得到了 score 变量,接下来可以对齐求 softmax() 操作,得到
alignments = self._probability_fn(score, previous_alignments)
这就代表了在i时刻,Decoder 的时候对 Encoder 得到的每个
所以综上所述,BahdanauAttention 就是初始化时传入 num_units 以及 Encoder Outputs,然后调时传入 query 用即可得到权重变量 alignments。
AttentionWrapperState
接下来我们再看下 AttentionWrapperState 这个类,这个类其实比较简单,就是定义了 Attention 过程中可能需要保存的变量,如 cell_state、attention、time、alignments 等内容,同时也便于后期的可视化呈现,代码实现如下:
class AttentionWrapperState(
collections.namedtuple("AttentionWrapperState",
("cell_state", "attention", "time", "alignments",
"alignment_history"))):
可见它就是继承了 namedtuple 这个数据结构,其实整个 AttentionWrapperState 就像声明了一个结构体,可以传入需要的字段生成这个对象。
AttentionWrapper
了解了 Attention 机制及 BahdanauAttention 的原理之后,最后我们再来了解一下 AttentionWrapper,可能你用过很多其他的 Wrapper,如 DropoutWrapper、ResidualWrapper 等等,它们其实都是 RNNCell 的实例,其实 AttentionWrapper 也不例外,它对 RNNCell 进行了封装,封装后依然还是 RNNCell 的实例。一个普通的 RNN 模型,你要加入 Attention,只需要在 RNNCell 外面套一层 AttentionWrapper 并指定 AttentionMechanism 的实例就好了。而且如果要更换 AttentionMechanism,只需要改变 AttentionWrapper 的参数就好了,这可谓对 Attention 的实现架构完全解耦,配置非常灵活,TF 大法好!
接下来我们首先来看下它的初始化方法,其参数是这样的:
def __init__(self,
cell,
attention_mechanism,
attention_layer_size=None,
alignment_history=False,
cell_input_fn=None,
output_attention=True,
initial_cell_state=None,
name=None):
下面对参数进行一一说明:
cell:An instance of RNNCell. RNNCell 的实例,这里可以是单个的 RNNCell,也可以是多个 RNNCell 组成的 MultiRNNCell。
attention_mechanism:即 AttentionMechanism 的实例,如 BahdanauAttention 对象,另外可以是多个 AttentionMechanism 组成的列表。
attention_layer_size:是数字或者数字做成的列表,如果是 None(默认),直接使用加权计算后得到的 Attention 作为输出,如果不是 None,那么 Attention 结果还会和 Output 进行拼接并做线性变换再输出。其代码实现如下:
if attention_layer_size is not None:
attention_layer_sizes = tuple(attention_layer_size if isinstance(attention_layer_size, (list, tuple)) else (attention_layer_size,))
if len(attention_layer_sizes) != len(attention_mechanisms):
raise ValueError("If provided, attention_layer_size must contain exactly one integer per attention_mechanism, saw: %d vs %d" % (len(attention_layer_sizes), len(attention_mechanisms)))
self._attention_layers = tuple(layers_core.Dense(attention_layer_size, name="attention_layer", use_bias=False, dtype=attention_mechanisms[i].dtype) for i, attention_layer_size in enumerate(attention_layer_sizes))
self._attention_layer_size = sum(attention_layer_sizes)
else:
self._attention_layers = None
self._attention_layer_size = sum(attention_mechanism.values.get_shape()[-1].value for attention_mechanism in attention_mechanisms)
for i, attention_mechanism in enumerate(self._attention_mechanisms):
attention, alignments = _compute_attention(attention_mechanism, cell_output, previous_alignments[i], self._attention_layers[i] if self._attention_layers else None)
alignment_history = previous_alignment_history[i].write(state.time, alignments) if self._alignment_history else ()
alignment_history:即是否将之前的 alignments 存储到 state 中,以便于后期进行可视化展示。
cell_input_fn:将 Input 进行处理的方式,默认会将上一步的 Attention 进行 拼接操作,以免造成重复关注同样的内容。代码调用如下:
cell_inputs = self._cell_input_fn(inputs, state.attention)
output_attention:是否将 Attention 返回,如果是 False 则返回 Output,否则返回 Attention,默认是 True。
initial_cell_state:计算时的初始状态。
name:自定义名称。
AttentionWrapper 的核心方法在它的 call() 方法,即类似于 RNNCell 的 call() 方法,AttentionWrapper 类对其进行了重载,代码实现如下:
def call(self, inputs, state):
# Step 1
cell_inputs = self._cell_input_fn(inputs, state.attention)
# Step 2
cell_state = state.cell_state
cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
# Step 3
if self._is_multi:
previous_alignments = state.alignments
previous_alignment_history = state.alignment_history
else:
previous_alignments = [state.alignments]
previous_alignment_history = [state.alignment_history]
all_alignments = []
all_attentions = []
all_histories = []
for i, attention_mechanism in enumerate(self._attention_mechanisms):
attention, alignments = _compute_attention(attention_mechanism, cell_output, previous_alignments[i], self._attention_layers[i] if self._attention_layers else None)
alignment_history = previous_alignment_history[i].write(state.time, alignments) if self._alignment_history else ()
all_alignments.append(alignments)
all_histories.append(alignment_history)
all_attentions.append(attention)
# Step 4
attention = array_ops.concat(all_attentions, 1)
# Step 5
next_state = AttentionWrapperState(
time=state.time + 1,
cell_state=next_cell_state,
attention=attention,
alignments=self._item_or_tuple(all_alignments),
alignment_history=self._item_or_tuple(all_histories))
# Step 6
if self._output_attention:
return attention, next_state
else:
return cell_output, next_state
在这里将一些异常判断代码去除了,以便于结构看得更清晰。
首先在第一步中,调用了 _cell_input_fn() 方法,对 inputs 和 state.attention 变量进行处理,默认是使用 concat() 函数拼接,作为当前时间步的输入。因为可能前一步的 Attention 可能对当前 Attention 有帮助,以免让模型连续两次将注意力放在同一个地方。
在第二步中,其实就是调用了普通的 RNNCell 的 call() 方法,得到输出和下一步的状态。
第三步中,这时得到的输出其实并没有用上 AttentionMechanism 中的 alignments 信息,所以当前的输出信息中我们并没有跟 Encoder 的信息做 Attention,所以这里还需要调用 _compute_attention() 方法进行权重的计算,其方法实现如下:
def _compute_attention(attention_mechanism, cell_output, previous_alignments, attention_layer):
alignments = attention_mechanism(cell_output, previous_alignments=previous_alignments)
expanded_alignments = array_ops.expand_dims(alignments, 1)
context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
context = array_ops.squeeze(context, [1])
if attention_layer is not None:
attention = attention_layer(array_ops.concat([cell_output, context], 1))
else:
attention = context
return attention, alignments
这个方法接收四个参数,其中 attentionmechanism 就是 AttentionMechanism 的实例,cell_output 就是当前 Output,previous_alignments 是上步的 alignments 信息,调用 attention_mechanism 计算之后就会得到当前步的 alignments 信息了,即
在第四步中,就是将 attention 结果每个时间步进行 concat,得到 attention vector。
第五步中,声明 AttentionWrapperState 作为下一步的状态。
第六步,判断是否要输出 Attention,如果是,输出 Attention 及下一步状态,否则输出 Outputs 及下一步状态。
好,以上便是整个 AttentionWrapper 源码解析过程,了解了源码之后,再做模型优化的话就非常得心应手了。
参考来源
Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation, Cho et al
Sequence to Sequence Learning with Neural Networks, Sutskever et al
Neural Machine Translation by Jointly Learning to Align and Translate, Bahdanau et al
Effective Approaches to Attention-based Neural Machine Translation, Luong, et al
Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks, Salimans, et al
http://news.ifeng.com/a/20170901/51842411_0.shtml
https://blog.csdn.net/qsczse943062710/article/details/79539005
https://zhuanlan.zhihu.com/p/34393028
下图扫码或点击阅读原文
报名学习崔老师的网络爬虫课程
已经1800人加入学习
限时优惠价308元
点击“
阅读原文
”,立即加速爬虫技能修炼!