其他
【他山之石】白话生成对抗网络GAN及代码实现
“他山之石,可以攻玉”,站在巨人的肩膀才能看得更高,走得更远。在科研的道路上,更需借助东风才能更快前行。为此,我们特别搜集整理了一些实用的代码链接,数据集,软件,编程技巧等,开辟“他山之石”专栏,助你乘风破浪,一路奋勇向前,敬请关注。
地址:https://www.zhihu.com/people/yilan-zhong-shan-xiao-29-98
Tutorial_HYLee_GAN Renu Khandelwal 的博客 Jason 的博客
01
神经网络一览
02
GAN的基本思想
03
GAN的工作原理
更新判别器的网络参数。即给定假图片以及假图片的标签(上图中的generated example)、真图片以及真图片的标签(上图中的real example),让判别器能够区别出真假图片,也就是训练一个尽可能准确的二分类器。 固定判别器网络参数, 更新生成器网络。即给定假图片以及假标签(让判别器以为假图片是真的),从而误差反向传播来更新生成器,使得生成器生成更加逼真的照片。
判别器想要最大化目标函数使得对于真实数据 D(x) 接近 1,对于假数据 D(G(z)) 接近 0 生成器想要最小化目标函数使得 D(G(z)) 接近 1,也就是欺骗判别器让它认为假数据为真
04
GAN的实现
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import keras
from keras.layers import Dense, Dropout, Input
from keras.models import Model,Sequential
from keras.datasets import mnist
from tqdm import tqdm
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam
def load_data():
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = (x_train.astype(np.float32) - 127.5)/127.5
# 将图片转为向量 x_train from (60000, 28, 28) to (60000, 784)
# 每一行 784 个元素
x_train = x_train.reshape(60000, 784)
return (x_train, y_train, x_test, y_test)
(X_train, y_train,X_test, y_test)=load_data()
print(X_train.shape)
def adam_optimizer():
return Adam(lr=0.0002, beta_1=0.5)
def create_generator():
generator=Sequential()
generator.add(Dense(units=256,input_dim=100))
generator.add(LeakyReLU(0.2))
generator.add(Dense(units=512))
generator.add(LeakyReLU(0.2))
generator.add(Dense(units=1024))
generator.add(LeakyReLU(0.2))
generator.add(Dense(units=784, activation='tanh'))
generator.compile(loss='binary_crossentropy', optimizer=adam_optimizer())
return generator
g=create_generator()
g.summary()
def create_discriminator():
discriminator=Sequential()
discriminator.add(Dense(units=1024,input_dim=784))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(units=512))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(units=256))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dense(units=1, activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy', optimizer=adam_optimizer())
return discriminator
d =create_discriminator()
d.summary()
def create_gan(discriminator, generator):
discriminator.trainable=False
# 这是一个链式模型:输入经过生成器、判别器得到输出
gan_input = Input(shape=(100,))
x = generator(gan_input)
gan_output= discriminator(x)
gan= Model(inputs=gan_input, outputs=gan_output)
gan.compile(loss='binary_crossentropy', optimizer='adam')
return gan
gan = create_gan(d,g)
gan.summary()
def plot_generated_images(epoch, generator, examples=100, dim=(10,10), figsize=(10,10)):
noise= np.random.normal(loc=0, scale=1, size=[examples, 100])
generated_images = generator.predict(noise)
generated_images = generated_images.reshape(100,28,28)
plt.figure(figsize=figsize)
for i in range(generated_images.shape[0]):
plt.subplot(dim[0], dim[1], i+1)
plt.imshow(generated_images[i], interpolation='nearest')
plt.axis('off')
plt.tight_layout()
plt.savefig('gan_generated_image %d.png' %epoch)
def training(epochs=1, batch_size=128):
#导入数据
(X_train, y_train, X_test, y_test) = load_data()
batch_count = X_train.shape[0] / batch_size
# 定义生成器、判别器和GAN网络
generator= create_generator()
discriminator= create_discriminator()
gan = create_gan(discriminator, generator)
for e in range(1,epochs+1 ):
print("Epoch %d" %e)
for _ in tqdm(range(int(batch_count))):
#产生噪声喂给生成器
noise= np.random.normal(0,1, [batch_size, 100])
# 产生假图片
generated_images = generator.predict(noise)
# 一组随机真图片
image_batch =X_train[np.random.randint(low=0,high=X_train.shape[0],size=batch_size)]
# 真假图片拼接
X= np.concatenate([image_batch, generated_images])
# 生成数据和真实数据的标签
y_dis=np.zeros(2*batch_size)
y_dis[:batch_size]=0.9
# 预训练,判别器区分真假
discriminator.trainable=True
discriminator.train_on_batch(X, y_dis)
# 欺骗判别器 生成的图片为真的图片
noise= np.random.normal(0,1, [batch_size, 100])
y_gen = np.ones(batch_size)
# GAN的训练过程中判别器的权重需要固定
discriminator.trainable=False
# GAN的训练过程为交替“训练判别器”和“固定判别器权重训练链式模型”
gan.train_on_batch(noise, y_gen)
if e == 1 or e % 50 == 0:
# 画图 看一下生成器能生成什么
plot_generated_images(e, generator)
training(400,256)
本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。
直播预告
“他山之石”历史文章
pytorch的余弦退火学习率
Pytorch转ONNX-实战篇(tracing机制)
联邦学习:FedAvg 的 Pytorch 实现
PyTorch实现ShuffleNet-v2亲身实践
训练时显存优化技术——OP合并与gradient checkpoint
浅谈数据标准化与Pytorch中NLLLoss和CrossEntropyLoss损失函数的区别
在C++平台上部署PyTorch模型流程+踩坑实录
libtorch使用经验
深度学习模型转换与部署那些事(含ONNX格式详细分析)
如何支撑上亿类别的人脸训练?显存均衡的模型并行(PyTorch实现)
PyTorch trick 集锦
分享、点赞、在看,给个三连击呗!