查看原文
其他

【他山之石】白话生成对抗网络GAN及代码实现

“他山之石,可以攻玉”,站在巨人的肩膀才能看得更高,走得更远。在科研的道路上,更需借助东风才能更快前行。为此,我们特别搜集整理了一些实用的代码链接,数据集,软件,编程技巧等,开辟“他山之石”专栏,助你乘风破浪,一路奋勇向前,敬请关注。

作者:养生的控制人

地址:https://www.zhihu.com/people/yilan-zhong-shan-xiao-29-98


本文主要是个简单的笔记,参考资料来自下面三部分
  1. Tutorial_HYLee_GAN
  2. Renu Khandelwal 的博客
  3. Jason 的博客


01

神经网络一览

各种神经网络(全连接前向网络、卷积神经网络、循环神经网络)的区别在于具有不同的输入/输出形式,比如可以是向量、矩阵或者是向量序列等。


02

GAN的基本思想

GAN由生成器和判别器组成:
生成器的本质也是一个神经网络,或者说是一个函数

如果给定一个向量可以生成一张漫画图片,向量的每一个维度具有不同含义

判别器的本质也是一个神经网络

如果给定一张图片,判别器就会告诉你这是不是真实图片

所以GAN的训练本质就是训练两个神经网络。

03

GAN的工作原理

生成器的目标是产生和训练数据相似的数据(以假乱真的图片),而判别器的目标是辨别真假。
生成器的输入通常为随机噪声,判别器有两个输入,一个来自训练数据中的真图片,一个来自生成器生成的假图片。
GAN的流程如下图所示

每一次迭代过程中:
  1. 更新判别器的网络参数。即给定假图片以及假图片的标签(上图中的generated example)、真图片以及真图片的标签(上图中的real example),让判别器能够区别出真假图片,也就是训练一个尽可能准确的二分类器。
  2. 固定判别器网络参数, 更新生成器网络。即给定假图片以及假标签(让判别器以为假图片是真的),从而误差反向传播来更新生成器,使得生成器生成更加逼真的照片。
GAN训练的目标函数如下所示

  • 判别器想要最大化目标函数使得对于真实数据 D(x) 接近 1,对于假数据 D(G(z)) 接近 0
  • 生成器想要最小化目标函数使得 D(G(z)) 接近 1,也就是欺骗判别器让它认为假数据为真


04

GAN的实现

这里采用 MNIST 数据集作为实验数据,最后我们会看到生成器能够产生看起来像真的数字!
导入需要用到的库
import numpy as npimport pandas as pdimport matplotlib.pyplot as plt%matplotlib inlineimport kerasfrom keras.layers import Dense, Dropout, Inputfrom keras.models import Model,Sequentialfrom keras.datasets import mnistfrom tqdm import tqdmfrom keras.layers.advanced_activations import LeakyReLUfrom keras.optimizers import Adam
导入数据
def load_data(): (x_train, y_train), (x_test, y_test) = mnist.load_data() x_train = (x_train.astype(np.float32) - 127.5)/127.5
# 将图片转为向量 x_train from (60000, 28, 28) to (60000, 784) # 每一行 784 个元素 x_train = x_train.reshape(60000, 784) return (x_train, y_train, x_test, y_test)(X_train, y_train,X_test, y_test)=load_data()print(X_train.shape)
定义优化器
def adam_optimizer(): return Adam(lr=0.0002, beta_1=0.5)
这里要采用的生成对抗网络的结构如下图所示
定义生成器:输入是 100 维,经过三层隐藏层,输出 784 维的向量(造假的图片)
def create_generator(): generator=Sequential() generator.add(Dense(units=256,input_dim=100)) generator.add(LeakyReLU(0.2))
generator.add(Dense(units=512)) generator.add(LeakyReLU(0.2))
generator.add(Dense(units=1024)) generator.add(LeakyReLU(0.2))
generator.add(Dense(units=784, activation='tanh'))
generator.compile(loss='binary_crossentropy', optimizer=adam_optimizer()) return generatorg=create_generator()g.summary()
定义判别器:判别器的输入为真实图片或者由生成器造出来的假图片(784维),经过三层隐藏层,输出类别(1 维)
def create_discriminator(): discriminator=Sequential() discriminator.add(Dense(units=1024,input_dim=784)) discriminator.add(LeakyReLU(0.2)) discriminator.add(Dropout(0.3))

discriminator.add(Dense(units=512)) discriminator.add(LeakyReLU(0.2)) discriminator.add(Dropout(0.3))
discriminator.add(Dense(units=256)) discriminator.add(LeakyReLU(0.2))
discriminator.add(Dense(units=1, activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy', optimizer=adam_optimizer()) return discriminatord =create_discriminator()d.summary()
定义生成对抗网络
def create_gan(discriminator, generator): discriminator.trainable=False # 这是一个链式模型:输入经过生成器、判别器得到输出 gan_input = Input(shape=(100,)) x = generator(gan_input) gan_output= discriminator(x) gan= Model(inputs=gan_input, outputs=gan_output) gan.compile(loss='binary_crossentropy', optimizer='adam') return gangan = create_gan(d,g)gan.summary()
定义画图函数来可视化图片的生成
def plot_generated_images(epoch, generator, examples=100, dim=(10,10), figsize=(10,10)): noise= np.random.normal(loc=0, scale=1, size=[examples, 100]) generated_images = generator.predict(noise) generated_images = generated_images.reshape(100,28,28) plt.figure(figsize=figsize) for i in range(generated_images.shape[0]): plt.subplot(dim[0], dim[1], i+1) plt.imshow(generated_images[i], interpolation='nearest') plt.axis('off') plt.tight_layout() plt.savefig('gan_generated_image %d.png' %epoch)
生成对抗网络的训练函数
def training(epochs=1, batch_size=128):
#导入数据 (X_train, y_train, X_test, y_test) = load_data() batch_count = X_train.shape[0] / batch_size
# 定义生成器、判别器和GAN网络 generator= create_generator() discriminator= create_discriminator() gan = create_gan(discriminator, generator)
for e in range(1,epochs+1 ): print("Epoch %d" %e) for _ in tqdm(range(int(batch_count))): #产生噪声喂给生成器 noise= np.random.normal(0,1, [batch_size, 100])
# 产生假图片 generated_images = generator.predict(noise)
# 一组随机真图片 image_batch =X_train[np.random.randint(low=0,high=X_train.shape[0],size=batch_size)]
# 真假图片拼接 X= np.concatenate([image_batch, generated_images])
# 生成数据和真实数据的标签 y_dis=np.zeros(2*batch_size) y_dis[:batch_size]=0.9
# 预训练,判别器区分真假 discriminator.trainable=True discriminator.train_on_batch(X, y_dis)
# 欺骗判别器 生成的图片为真的图片 noise= np.random.normal(0,1, [batch_size, 100]) y_gen = np.ones(batch_size)
# GAN的训练过程中判别器的权重需要固定 discriminator.trainable=False
# GAN的训练过程为交替“训练判别器”和“固定判别器权重训练链式模型” gan.train_on_batch(noise, y_gen)
if e == 1 or e % 50 == 0: # 画图 看一下生成器能生成什么 plot_generated_images(e, generator)training(400,256)
经过训练后生成的图片
一个epoch后生成器还是个小学生
100个epoch后生成器已经有点样子了
400个epoch后生成器可以出师了
是不是已经学得像模像样了,这样就能够利用噪声通过生成器来生成以假乱真的图片了。

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。


直播预告



“他山之石”历史文章




分享、点赞、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存