tensorflow风格迁移网络训练与使用
微信公众号:OpenCV学堂
关注获取更多计算机视觉与深度学习知识
觉得文章对你有用,请戳底部广告支持
风格迁移原理解释
卷积神经网络实现图像风格迁移在2015的一篇论文中最早出现。实现了一张从一张图像中提取分割,从另外一张图像中提取内容,叠加生成一张全新的图像。早前风靡一时的风格迁移APP – Prisma其背后就是图像各种风格迁移、让人耳目一新。其主要的思想是对于训练好的卷积神经网络,其内部一些feature map跟最终识别的对象是特征独立的,这些特征当中有一些是关于内容特征的,另外一些是关于风格特征的,于是我们可以输入两张图像,从其中一张图像上提取其内容特征,另外一张图像上提取其风格特征,然后把它们叠加在一起形成一张新的图像,这个就风格迁移卷积网络。最常见的我们是用一个预先训练好的卷积神经网络,常见的就是VGG-19,其结构如下:
其包含16个卷积层、5个池化层、3个全链接层。其中:
表示内容层为:relu4-2
表示风格层为:relu1_1, relu2_1, relu3_1, relu4_1, relu5_1
越高阶的层图像内容越抽象,我们损失的像素信息越多,所有选用relu4-2层作为内容层而忽略低阶的内容损失,对于风格来说,它是从低阶到高阶的层组合。所以选用从低到高不同层组合作为风格[relu1_1, relu2_1, relu3_1, relu4_1, relu5_1]
迁移损失
风格迁移生成图像Y,
要求它的内容来自图像C,
要求它的风格来自图像S。
Y是随机初始化的一张图像,带入到预训练的网络中会得到内容层与风格层的输出结果
C是内容图像,带入到预训练的网络中得到内容层Target标签
S是风格图像,带入到预训练的网络中得到风格层Target标签
这样总的损失函数就是内容与标签的损失,此外我们希望最终生成的图像是一张光滑图像,所有还有一个像素方差损失,这三个损失分别表示为 :
Loss(content)、 Loss(style) 、 Loss(var)
最终总的损失函数为:
Total Loss = alpha * Loss (content) + beta * Loss (Style) + Loss (var)
其中alpha与beta分别是内容损失与风格损失的权重大小
代码实现:
获取内容图像C与风格图像S的标签
# Get network parameters
image = tf.placeholder('float', shape=shape)
vgg_net = vgg_network(network_weights, image)
# Normalize original image
original_minus_mean = content_image - normalization_mean
original_norm = np.array([original_minus_mean])
original_features[content_layers] = sess.run(vgg_net[content_layers],
feed_dict={image: original_norm})
# Get style image network
image = tf.placeholder('float', shape=style_shape)
vgg_net = vgg_network(network_weights, image)
style_minus_mean = style_image - normalization_mean
style_norm = np.array([style_minus_mean])
for layer in style_layers:
layer_output = sess.run(vgg_net[layer], feed_dict={image: style_norm})
layer_output = np.reshape(layer_output, (-1, layer_output.shape[3]))
style_gram_matrix = np.matmul(layer_output.T, layer_output) / layer_output.size
style_features[layer] = style_gram_matrix
随机初始化Y图像
# 随机初始化目标图像
initial = tf.random_normal(shape) * 0.256
image = tf.Variable(initial)
vgg_net = vgg_network(network_weights, image)
计算内容损失
# 计算目标图像内容与内容图像之间的差异, 内容损失
original_loss = original_image_weight * (
2 * tf.nn.l2_loss(vgg_net[content_layers] - original_features[content_layers]) /
original_features[content_layers].size)
计算风格损失
# 风格损失
style_loss = 0
style_losses = []
for style_layer in style_layers:
layer = vgg_net[style_layer]
feats, height, width, channels = [x.value for x in layer.get_shape()]
size = height * width * channels
features = tf.reshape(layer, (-1, channels))
style_gram_matrix = tf.matmul(tf.transpose(features), features) / size
style_expected = style_features[style_layer]
style_losses.append(2 * tf.nn.l2_loss(style_gram_matrix - style_expected) / style_expected.size)
style_loss += style_image_weight * tf.reduce_sum(style_losses)
添加smooth损失
# To Smooth the resuts, we add in total variation loss
total_var_x = sess.run(tf.reduce_prod(image[:, 1:, :, :].get_shape()))
total_var_y = sess.run(tf.reduce_prod(image[:, :, 1:, :].get_shape()))
first_term = regularization_weight * 2
second_term = (tf.nn.l2_loss(image[:, 1:, :, :] - image[:, :shape[1] - 1, :, :]) / total_var_y)
third_term = (tf.nn.l2_loss(image[:, :, 1:, :] - image[:, :, :shape[2] - 1, :]) / total_var_x)
total_variation_loss = first_term * (second_term + third_term)
训练风格迁移
# 总的损失
loss = original_loss + style_loss + total_variation_loss
# 优化器
optimizer = tf.train.AdamOptimizer(learning_rate, beta1, beta2)
train_step = optimizer.minimize(loss)
# 初始化参数与训练
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
for i in range(generations):
sess.run(train_step)
# Print update and save temporary output
if (i + 1) % output_generations == 0:
print('Generation {} out of {}, loss: {}'.format(i + 1, generations, sess.run(loss)))
image_eval = sess.run(image)
best_image = image_eval.reshape(shape[1:]) + normalization_mean
temp_img = np.clip(best_image, 0, 255).astype(np.uint8)
output_file = 'D:/pet_data/temp_output_{}.jpg'.format(i)
Image.fromarray(temp_img).save(output_file, quality=95)
saver.save(sess, "./neural_style.model", global_step=2500)
运行结果
输入图像(右下角为风格图像),输出图像
欢迎大家扫码加入【OpenCV研习社】
使用tensorflow layers相关API快速构建卷积神经网络
我们是
OpenCV学堂
长按二维码
关注我们