查看原文
其他

生成对抗网络详解与代码演示

gloomyfish OpenCV学堂 2020-02-04

点击上方蓝字关注我们

星标或者置顶【OpenCV学堂】

干货教程第一时间送达!

生成对抗网络(GAN)

生成对抗网络(Generative Adversarial Nets)在图像生成、音乐与文本生成方面都有着很多神奇效果,生成对抗网络产生受到都来自博弈论与对战游戏的启发,生成对抗网络,需要三个输入

输入数据– 一组样本数据x-P(data).
生成器G – 随机初始化生成数P(g),终极目标是生成跟样本数据分布一致的数据.
判别器D – 判别数据是来自输入数据x还是来自生成者P(g).

生成器通过学习目的是生成一个假数分布数据让判别器无法正确判断数据是来自样本x还 ,而判别器通过学习正确判别数据是来自x还是P(g),两者通过这种博弈游戏不断提升自己能力, 最终达到一种稳定平衡状态,又称纳什均衡。

网络与损失
生成器是一个多层感知器网络,生成器的输入是随机初始化分布的样本数据,通过学习生成一个跟输入数据x分布一致的数据。判别器也是一个多层感知器网络,其输出单个标量表示数据是来自x或者P(g)的可能性。
训练最大化判别器输-D(x)
同时训练生成器最小-log(1-D(G(z)))
也就是说生成器与判别器在进行一个最大最小的对抗游戏,表达如下:

图示如下:

蓝色虚线表示判别者、黑色虚线表示生成者、绿色实线表示输入数据X、Z表示随机采样数据,(a)表示刚开始时候分布、(d)表示最终达到平衡状态的分布,很显然要想让生成器与判别器都达到稳定状态,只有D(x) = 1/2时候才满足。

全局优化与收敛
对于给定的生成器网络,G是固定情况下,判别器优化如下:

对于给定G(生成器)条件下,训练D(判别器)最大期望,就变成了对数似然,条件概率估算问题,表示如下:

其中KL散度是衡量两个分布的相似的度量,JS散度与KL散度之间的关系如下:

可以看出KL散度跟顺序有关系,是非对称的计算,JS散度是对称的,跟顺序无关。根据上述推导,同时训练生成器与判别器,最终生成器会收敛到输入数据,有如下关系成立:

代码实现

定义生成器网络-多层感知器

def generator(Z):
    hidden_ 
= tf.nn.sigmoid(tf.add(tf.matmul(Z, G_w1), G_b1))
    logist = tf.add(tf.matmul(hidden_, G_w2), G_b2)
    prob_ = tf.nn.sigmoid(logist)
    return prob_

定义判别器网络-多层感知器

def discriminator(X):
    D_Layer1 
= tf.add(tf.matmul(X, D_w1), D_b1)
    D_h1 = tf.nn.sigmoid(D_Layer1)
    D_logist = tf.add(tf.matmul(D_h1, D_w2), D_b2)
    D_prob = tf.nn.sigmoid(D_logist)
    return D_prob, D_logist

定义损失函数与优化器

# 生成者生成与判别者判别
G_sample = generator(input_z)
D_real, D_logit_real = discriminator(input_x)
D_fake, D_logit_fake = discriminator(G_sample)

# 损失函数
# D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
# G_loss = -tf.reduce_mean(tf.log(D_fake))

D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))

基于mnist数据集训练与保存检查点

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(100000):
        batch_xs, batch_ys = mnist.train.next_batch(min_batch)
        _, D_loss_curr = sess.run([D_step, D_loss], feed_dict={input_x: batch_xs, input_z: sample_z(min_batch, 100)})
        _, G_loss_curr = sess.run([G_step, G_loss], feed_dict={input_z: sample_z(min_batch, 100)})
        if i % 100 == 0:
            print("D_loss_curr : %.4f, G_loss_curr : %.4f"%(D_loss_curr, G_loss_curr))
    saver.save(sess, "./my_gan_mnist.model", global_step=100000)

演示效果

运行模型生成mnist数据集

with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint("."))
    num_images = 26
    for i in range(1000):
        result_ = sess.run(G_sample, feed_dict={input_z: sample_z(num_images, 100)})
        images = np.resize(result_, (-12828)) * 255.0
        images = np.uint8(images)
        for i in range(25):
            plt.subplot(55, i + 1)
            plt.xticks([])
            plt.yticks([])
            plt.grid(False)
            plt.imshow(images[i], cmap=plt.cm.gray)
            plt.xlabel(str(i+1))
        plt.show()

运行效果如下

远飞者当换其新羽

欢迎扫码加入【OpenCV研习社】

- 学习OpenCV+tensorflow开发技术
- 与更多伙伴相互交流、一起学习进步
- 每周一到每周五分享知识点学习(音频+文字+源码)
- 系统化学习知识点,从易到难、由浅入深
- 直接向老师提问、每天答疑辅导



推荐阅读

OpenCV学堂-原创精华文章

《tensorflow零基础入门视频教程》

基于OpenCV与tensorflow实现实时手势识别

tensorflow风格迁移网络训练与使用

使用tensorflow layers相关API快速构建卷积神经网络

基于OpenCV Python实现二维码检测与识别

OpenCV+Tensorflow实现实时人脸识别演示

教程 | Tensorflow keras 极简神经网络构建与使用


关注【OpenCV学堂】

长按或者扫码即可关注


    您可能也对以下帖子感兴趣

    文章有问题?点此查看未经处理的缓存