查看原文
其他

对抗神经网络初探

2018-01-15 Peter 混沌巡洋舰

GAN,也就是对抗神经网络背后的道理,具有普适的应用场景,毕竟,在金庸小说中就有左右互搏,而GAN说到底不过是一个生成器一个判别器,让俩者“自我对弈”,从而互相进步。然而,神经网络在复杂搜索空间上能够找出其他的机器学习方法所不能的路径,就如同如果你只会一俩种拳法,那么即使教会你双手互搏的心法,你也无法武功大进,而若是你本身就会顶级武功,那这时左右互搏就能让你取长补短。


回到GAN,该领域的进步使得对抗神经网络的思路可以完成很多传统的深度学习方法所不能完成的任务,例如训练一个能够写新闻稿的机器人,能够模仿知名画家风格的艺术家(参考 怎么样用深度学习取悦你的女朋友(有代码))。大神Yann LeCun在他的Quora回答中说到,“(GANs), and the variations that are now being proposed is the most interesting idea in the last 10 years in ML, in my opinion.

不必担心GAN太难,在这篇文章中, 我将向您介绍GAN的基础概念, 并解释它们是如何与挑工作的。我也会让你知道人们已经做了使用GAN做的一些很酷的事, 并给你一些可以深入了解这些技术细节的链接。我们先从GAN的名字说的Generative指的是生成式的,Adversarial是对手的意思,而N则是Network。


让我们先通过一个类比来解释GAN:


如果你想在国际象棋中表现的更好;你会怎么做?你多半会分析你做错了什么, 他做对了什么, 并考虑你可以做什么来击败他在下一场比赛。你持续的重复这样的思考, 直到你击败对手。这个例子可以用来指导该如何建立更好的模型。简单地说, 为了获得一个强大的生成器 , 我们需要一个更强大的对手 ,即鉴别器。


另一个来自生活中的例子是名画的伪造者和文物专家之间的关系。



伪造者的任务是创造足以以假乱真的著名艺术家的杰作的模仿品。如果这个模仿的画作被判定为原作, 伪造者就会得到许多金钱上的报酬。另一方面, 一个文物专家的任务是抓住这些伪造者的模仿品。他是怎么做到的?他知道什么真品的特有属性, 他用头脑中积累的这些知识, 来检查画作是否真实。这样的猫鼠游戏使得艺术鉴定家和仿照者都提升了姿势水平。对应到神经网络中,就是训练一个神经网络来生成数据,另一个判定网络来判定生成器生成的数据是真的(来自于现实世界的训练数据)还是假的。


此图中的生成器使用随机数以及先验的知识(编程者告诉神经网络数据预期会是怎样,非必需)来生成一些伪造的数据,而图中的手写数字则是真实的图片,之后这俩部分数据被放进了判别器中,判别器要做的判定哪一个是生成器生成的,输出是0到1之间的一个数,代表每个数据点是真实数据的概率。


我们来做一些数学上的严谨定义:

P(x) ->真实数据的分布函数
X ->真实数据p(x)的取样
P(z) -> 生成函数的概率分布
Z ->  生成数据p(z)的取样
G(z) -> 生成网络
D(x) -> 判别网络


而训练的过程可以看成是如下的优化任务:



这里的V(D,G)的第一项是真实数据通过判别网络鉴定为真的期望,第二项为生成的数据通过判别网络的鉴定的期望,我们希望通过训练的判别器能使第一项为1,第二项为0,总的来说,训练的目标是使V(D,G)最大。而对于生成器,网络的训练目标在于使V(D,G)最小,也就是让判别网络犯迷糊。在实际的训练中,可以先训练生成网络,让判别网络冬眠,之后在训练判别网络,同时保持生成网络的参数不变,通过这样的方式,来实现上述目标。



训练GAN的7个步骤

步骤 1: 定义问题。你想生成图像或文本。在这里, 您应该完全定义问题并为其收集足够的真实的训练数据。

步骤 2: 定义 GAN 的体系结构。定义你的 GAN 应该是什么样子。你的生成器和鉴别器应该是多层感知机, 还是卷积神经网络?这一步将取决于您试图解决的问题。

步骤 3: 先用真实数据训练鉴别器。获取已有的(或随机生成的)伪造的数据,目标是训练鉴别器使其可以正确地预测数据那些是真的。

步骤4:生成器生成假数据, 并对假数据进行鉴别。获取生成的数据, 目标是让让鉴别器正确地预测它们是假的。

步骤 5: 用鉴别器输出的训练生成器。现在, 鉴别器已经训练好了, 你可以得到它的判定, 并把它作为一个训练生成器的目标。目标是训练生成器来愚弄鉴别器。

步骤 6: 迭代式的重复步骤3到步骤5

7步: 如果假数据看似没问题 请手动检查。如果它看起来合适, 停止训练, 否则去步骤3。这是一个有点不够自动的任务, 手工评估数据是用来检测GAN训练成果的最好方式。


理论上,GAN可以看成是强化学习的变种,是最容易实现强人工智能的一种方式,你可以用GAN训练可以完成任意任务的目标,不管是写作权力的游戏的续集,驾驶汽车,或者是整理法律文书。只要你能够说清楚你想要模仿的任务是什么就可以了。但是,在现实中,GAN面临着诸多的挑战。


关于这个话题,可以参考下面的论文 


https://arxiv.org/pdf/1606.03498.pdf


首要的问题是GAN的稳定性,即生成器和判别器是相克相生的,如果一个表现的不好,另一个也无法表现好。如果训练中判别器比生成器好的太多了,那生成器会需要特别长的时间才能稍微进步一点点,而若是你将判别器调整的过于“是非不分”,那么生成器也没什么可以学习的,同样无法进步。这对于生活在“别人家的孩子”阴影下的孩子应该很熟悉,你的父母就是那个判别器,如果你和别人家的孩子差的太远,那么你会难以进步,若是你的父母对你过度溺爱,你也无法成长。


下面说说具体的问题,下面的这些图片是GAN生成的


我们看到,神经网络无法计数,因此会生成这样荒谬的照片,虽然图中的眼睛看起来确实是对应动物的,但神经网络却不知道该将其放在那里,放几个。


上面的图片指出,生成器难以正确的的将三维的物体投影到2维上,也无法区分向前看和向后看的区别。

生成器不能生成理解局部和整体之间的关系,因此会生成看起来荒谬的图片。


下面让我们看一个具体的案例,我们训练一个网络,来判定一幅28×28的图片是否是数字,用到的训练数据集可以从这里下载https://datahack.analyticsvidhya.com/contest/practice-problem-identify-the-digits/


下面看一看训练GAN的伪代码:


来源:http://papers.nips.cc/paper/5423-generative-adversarial


接着看看用到的python 包

import os import numpy as np import pandas as pd from scipy.misc import imread import keras from keras.models import Sequential from keras.layers import Dense, Flatten, Reshape, InputLayer from keras.regularizers import L1L2

接着给生成器设定一个随机数种子

# to stop potential randomness seed = 128 rng = np.random.RandomState(seed)

导入数据

# set path root_dir = os.path.abspath('.') data_dir = os.path.join(root_dir, 'Data')

# load data train = pd.read_csv(os.path.join(data_dir, 'Train', 'train.csv')) test = pd.read_csv(os.path.join(data_dir, 'test.csv')) temp = [] for img_name in train.filename:     image_path = os.path.join(data_dir, 'Train', 'Images', 'train', img_name)     img = imread(image_path, flatten=True)     img = img.astype('float32')     temp.append(img)      train_x = np.stack(temp) train_x = train_x / 255.

看一看导入的数据的具体状况

img_name = rng.choice(train.filename) filepath = os.path.join(data_dir, 'Train', 'Images', 'train', img_name) img = imread(filepath, flatten=True) pylab.imshow(img, cmap='gray') pylab.axis('off') pylab.show()

下面的图片来自于训练数据集,也是这段代码会展示的:


接着我们定义生成器,这是一个三层的神经网络,前俩层是500个神经员,激活函数是常用的Relu,同时每一层进行了L2正则化,最后一层是28×28个神经元,使用Sigmoid函数来确定输出的图像中每一个点是黑,白还是灰色。

# generator model_1 = Sequential([     Dense(units=hidden_1_num_units, input_dim=g_input_shape, activation='relu', kernel_regularizer=L1L2(1e-5, 1e-5)),     Dense(units=hidden_2_num_units, activation='relu', kernel_regularizer=L1L2(1e-5, 1e-5)),              Dense(units=g_output_num_units, activation='sigmoid', kernel_regularizer=L1L2(1e-5, 1e-5)),          Reshape(d_input_shape), ])

判别网络有着类似的构成

# discriminator model_2 = Sequential([     InputLayer(input_shape=d_input_shape),          Flatten(),              Dense(units=hidden_1_num_units, activation='relu', kernel_regularizer=L1L2(1e-5, 1e-5)),     Dense(units=hidden_2_num_units, activation='relu', kernel_regularizer=L1L2(1e-5, 1e-5)),              Dense(units=d_output_num_units, activation='sigmoid', kernel_regularizer=L1L2(1e-5, 1e-5)), ])

接着可以看一看网络的整体情况

接着倒入Keras中和GAN相关的包:

from keras_adversarial import AdversarialModel, simple_gan, gan_targets from keras_adversarial import AdversarialOptimizerSimultaneous, normal_latent_sampling

最后我们开始GAN的训练

gan = simple_gan(model_1, model_2, normal_latent_sampling((100,))) model = AdversarialModel(base_model=gan,player_params=[model_1.trainable_weights, model_2.trainable_weights]) model.adversarial_compile(adversarial_optimizer=AdversarialOptimizerSimultaneous(), player_optimizers=['adam', 'adam'], loss='binary_crossentropy') history = model.fit(x=train_x, y=gan_targets(train_x.shape[0]), epochs=10, batch_size=batch_size)

让看看训练10轮后的结果

plt.plot(history.history['player_0_loss']) plt.plot(history.history['player_1_loss']) plt.plot(history.history['loss'])

这里我们看到生成器一开始表现是很差的,判别器表现也不好,但是不要急,经过100轮的训练,得到了下面的图片,看起来就挺像数字的了。


下面介绍一下GAN的应用场景:

预测视频中下一帧会发生什么? : https://arxiv.org/pdf/1511.06380.pdf

用较低像素的图片中生成较高像素的图片:https://arxiv.org/pdf/1609.04802.pdf

创意图片的动态生成,艺术家只需要画出草稿,由GAN来填充细节

https://github.com/junyanz/iGAN

利用原始图片生成新图片:https://arxiv.org/pdf/1611.07004.pdf

根据图片生成描述的文字


总结:GAN是神经网络研究中很热门的一个子领域,也有很多变种,进展很快,工业界的应用也很广。GAN的思路会让人想起无监督学习中的自编码器,其思路虽然说起来很简单,但要想训练好相应的GAN,却需要很多和问题相关的技巧。


最后列出对GAN学习有关的链接:

https://github.com/zhangqianhui/AdversarialNetsPapers

http://www.deeplearningbook.org/contents/generative_models.html


本文参考:https://www.analyticsvidhya.com/blog/2017/06/introductory-generative-adversarial-networks-gans/

http://www.iangoodfellow.com/slides/2016-12-04-NIPS.pdf


扩展阅读

用R语言实现深度学习情感分析例子

用深度学习玩图像的七重关卡

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存