变分自编码器系列:VAE + BN = 更好的VAE
©PaperWeekly 原创 · 作者|苏剑林
单位|追一科技
研究方向|NLP、神经网络
值得一提的是,本文最后得到的方案还是颇为简洁的——只需往编码输出加入BN(Batch Normalization),然后加个简单的 scale——但确实很有效,因此值得正在研究相关问题的读者一试。同时,相关结论也适用于一般的 VAE 模型(包括 CV 的),如果按照笔者的看法,它甚至可以作为 VAE 模型的“标配”。
1.1 VAE的训练流程
VAE 的训练流程大概可以图示为:
1.2 NLP中的VAE
这种情况下的 VAE 模型并没有什么价值:KL 散度为 0 说明编码器输出的是 0 向量,而解码器则是一个普通的语言模型。而我们使用 VAE 通常来说是看中了它无监督构建编码向量的能力,所以要应用 VAE 的话还是得解决 KL 散度消失问题。
1.3 BN的巧与秒
KL 散度消失不就是 KL 散度项变成 0 吗?我调整一下编码器输出,让 KL 散度有一个大于零的下界,这样它不就肯定不会消失了吗?
1.4 推导过程简述
1.5 为什么不是LN?
善于推导的读者可能会想到,按照上述思路,如果只是为了让 KL 散度项有个正的下界,其实 LN(Layer Normalization)也可以,也就是在式(3)中按 j 那一维归一化。
那为什么用BN而不是LN呢?
进一步的结果
2.1 联系到先验分布
2.2 参考的实现方案
关键代码参考(Keras):
class Scaler(Layer):
"""特殊的scale层
"""
def __init__(self, tau=0.5, **kwargs):
super(Scaler, self).__init__(**kwargs)
self.tau = tau
def build(self, input_shape):
super(Scaler, self).build(input_shape)
self.scale = self.add_weight(
name='scale', shape=(input_shape[-1],), initializer='zeros'
)
def call(self, inputs, mode='positive'):
if mode == 'positive':
scale = self.tau + (1 - self.tau) * K.sigmoid(self.scale)
else:
scale = (1 - self.tau) * K.sigmoid(-self.scale)
return inputs * K.sqrt(scale)
def get_config(self):
config = {'tau': self.tau}
base_config = super(Scaler, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def sampling(inputs):
"""重参数采样
"""
z_mean, z_std = inputs
noise = K.random_normal(shape=K.shape(z_mean))
return z_mean + z_std * noise
e_outputs # 假设e_outputs是编码器的输出向量
scaler = Scaler()
z_mean = Dense(hidden_dims)(e_outputs)
z_mean = BatchNormalization(scale=False, center=False, epsilon=1e-8)(z_mean)
z_mean = scaler(z_mean, mode='positive')
z_std = Dense(hidden_dims)(e_outputs)
z_std = BatchNormalization(scale=False, center=False, epsilon=1e-8)(z_std)
z_std = scaler(z_std, mode='negative')
z = Lambda(sampling, name='Sampling')([z_mean, z_std])
文章内容小结
参考链接
[1] https://arxiv.org/abs/2004.12585
更多阅读
#投 稿 通 道#
让你的论文被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学习心得或技术干货。我们的目的只有一个,让知识真正流动起来。
📝 来稿标准:
• 稿件确系个人原创作品,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向)
• 如果文章并非首发,请在投稿时提醒并附上所有已发布链接
• PaperWeekly 默认每篇文章都是首发,均会添加“原创”标志
📬 投稿邮箱:
• 投稿邮箱:hr@paperweekly.site
• 所有文章配图,请单独在附件中发送
• 请留下即时联系方式(微信或手机),以便我们在编辑发布时和作者沟通
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧
关于PaperWeekly
PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。