基於 GAN 實現圖像銳化應用(附代碼)

【導讀】生成對抗網絡(GAN)是 Ian Goodfellow 在 2014 年在其論文 Generative Adversarial Nets 中提出來的,可以說是當前最炙手可熱的技術了。本文基於 Keras 框架構建 GAN 網絡,解決圖像銳化問題。首先介紹了 GAN 的基本網絡架構,然後從數據、模型、訓練等幾個方面介紹 GAN 在圖像銳化的應用。本文是一篇很好的 GAN 學習實例,並且給出了許多不錯的 GAN 學習鏈接,對 GAN 感興趣的讀者不容錯過!

基於 GAN 實現圖像銳化應用

2014 年,Ian Goodfellow 提出了生成對抗網絡(Generative Adversarial Networks,GAN),在這篇文章中我們介紹如何基於 Keras 框架構建 GAN 網絡,解決圖像銳化問題。

Keras 代碼可以在此處查看:

https://github.com/RaphaelMeudec/deblur-gan

原始論文見

https://arxiv.org/pdf/1711.07064.pdf。

Pytorch 版本見 https://github.com/KupynOrest/DeblurGAN/。

快速瞭解生成對抗網絡

在生成式對抗網絡中,兩個網絡互相對抗。其中,生成器通過創建僞造信號來誤導判別器,而判別器需要判斷輸入的信號是真實的還是假造的。

圖 GAN 訓練過程

其中,有三個主要的訓練步驟:

將生成器與判別器鏈接在一起,原因是我們沒有對於生成器輸出的反饋,唯一的衡量標準是判別器是否接受生成的樣本。

數據

Ian Goodfellow 首次應用 GAN 生成了 MNIST 數據,在本文,我們使用 GAN 進行圖像銳化,因此,發生器的輸入不是噪聲,而是模糊的圖像。

本次任務中,我們使用的數據集是 GOPRO 數據集,大家可以下載輕量級版本                                        (9GB):https://drive.google.com/file/d/1H0PIXvJH4c40pk7ou6nAwoxuR4Qh_Sa2/view?usp=sharing,

                                                             或者完整的版本(35GB):https://drive.google.com/file/d/1SlURvdQsokgsoyTosAaELc4zRjQz9T2U/view?usp=sharing,它包含了來自多個街景的人工模糊圖像。

我們首先將圖像分配到兩個不同文件夾中,A(模糊)和 B(銳利)。A&B 結果來自於這篇關於 pix2pix 的文章:https://phillipi.github.io/pix2pix/。我在倉庫中創建了一個自定義腳本來實現這個任務,請按照 README 的步驟使用。

模型

訓練過程保持不變,開始前,我們來看一下神經網絡的架構。

生成器

生成器的目標是重現銳化的圖像。該網絡基於 ResNet 構建,它會跟蹤原始模糊圖像的變化,這篇文章中也提到了一種基於 UNet 網絡的版本:https://arxiv.org/pdf/1505.04597.pdf。

圖:銳化 GAN 生成器網絡架構:https://arxiv.org/pdf/1711.07064.pdf

方法的核心是應用於原始圖像採樣的 9 個 ResNet 塊,下面我們來看看 Keras 的實現。

from keras.layers import Input, Conv2D, Activation, BatchNormalization
from keras.layers.merge import Add
from keras.layers.core import Dropout

def res_block(input, filters, kernel_size=(3,3), strides=(1,1), 
use_dropout=False):
    """
    Instanciate a Keras Resnet Block using sequential API.
    :param input: Input tensor
    :param filters: Number of filters to use
    :param kernel_size: Shape of the kernel for the convolution
    :param strides: Shape of the strides for the convolution
    :param use_dropout: Boolean value to determine the use of dropout
    :return: Keras Model
    """
    x = ReflectionPadding2D((1,1))(input)
    x = Conv2D(filters=filters,
               kernel_size=kernel_size,
               strides=strides,)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    if use_dropout:
        x = Dropout(0.5)(x)

    x = ReflectionPadding2D((1,1))(x)
    x = Conv2D(filters=filters,
                kernel_size=kernel_size,
                strides=strides,)(x)
    x = BatchNormalization()(x)

    # Two convolution layers followed by a direct connection between 
input and output
    merged = Add()([input, x])
    return merged

ResNet 層是典型的卷積層,添加輸入輸出信息以形成最終的結果,

from keras.layers import Input, Activation, Add
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.core import Lambda
from keras.layers.normalization import BatchNormalization
from keras.models import Model

from layer_utils import ReflectionPadding2D, res_block

ngf = 64
input_nc = 3
output_nc = 3
input_shape_generator = (256, 256, input_nc)
n_blocks_gen = 9


def generator_model():
    """Build generator architecture."""
    # Current version : ResNet block
    inputs = Input(shape=image_shape)

    x = ReflectionPadding2D((3, 3))(inputs)
    x = Conv2D(filters=ngf, kernel_size=(7,7), padding='valid')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Increase filter number
    n_downsampling = 2
    for i in range(n_downsampling):
        mult = 2**i
        x = Conv2D(filters=ngf*mult*2, kernel_size=(3,3), strides=2, 
padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    # Apply 9 ResNet blocks
    mult = 2**n_downsampling
    for i in range(n_blocks_gen):
        x = res_block(x, ngf*mult, use_dropout=True)

    # Decrease filter number to 3 (RGB)
    for i in range(n_downsampling):
        mult = 2**(n_downsampling - i)
        x = Conv2DTranspose(filters=int(ngf * mult / 2),
kernel_size=(3,3), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    x = ReflectionPadding2D((3,3))(x)
    x = Conv2D(filters=output_nc, kernel_size=(7,7), padding='valid')(x)
    x = Activation('tanh')(x)

    # Add direct connection from input to output and recenter to [-1, 1]
    outputs = Add()([x, inputs])
    outputs = Lambda(lambda z: z/2)(outputs)

    model = Model(inputs=inputs, outputs=outputs, name='Generator')
    return model

按計劃,9 個 ResNet 塊應用於輸入的 upsample 版本。我們增加了輸入到輸出的連接,併除以 2 以保持標準化的輸出。

這就是生成器的實現,下面我們來看一下判別器的架構。

判別器

目標是確定輸入圖像是真實圖片還是僞造的圖片。因此,判別器的結構是卷積層與輸出層,輸出結果是單個的值。

from keras.layers import Input
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D
from keras.layers.core import Dense, Flatten
from keras.layers.normalization import BatchNormalization
from keras.models import Model

ndf = 64
output_nc = 3
input_shape_discriminator = (256, 256, output_nc)


def discriminator_model():
    """Build discriminator architecture."""
    n_layers, use_sigmoid = 3, False
    inputs = Input(shape=input_shape_discriminator)

    x = Conv2D(filters=ndf, kernel_size=(4,4), strides=2,
padding='same')(inputs)
    x = LeakyReLU(0.2)(x)

    nf_mult, nf_mult_prev = 1, 1
    for n in range(n_layers):
        nf_mult_prev, nf_mult = nf_mult, min(2**n, 8)
        x = Conv2D(filters=ndf*nf_mult, kernel_size=(4,4), strides=2, 
padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)

    nf_mult_prev, nf_mult = nf_mult, min(2**n_layers, 8)
    x = Conv2D(filters=ndf*nf_mult, kernel_size=(4,4), strides=1,
padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x)

    x = Conv2D(filters=1, kernel_size=(4,4), strides=1, 
padding='same')(x)
    if use_sigmoid:
        x = Activation('sigmoid')(x)

    x = Flatten()(x)
    x = Dense(1024, activation='tanh')(x)
    x = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=inputs, outputs=x, name='Discriminator')
    return model

最後是構建完整的模型,這個 GAN 的特殊之處在於輸入是實際的圖像,而不是噪聲,因此,我們需要爲生成器的輸出引入直接反饋。

from keras.layers import Input
from keras.models import Model

def generator_containing_discriminator_multiple_outputs(generator,
discriminator):
    inputs = Input(shape=image_shape)
    generated_images = generator(inputs)
    outputs = discriminator(generated_images)
    model = Model(inputs=inputs, outputs=[generated_images, outputs])
    return model

接下來讓我們看看兩個特殊的損失函數。

訓練

Losses

我們分別從兩個級別提取 losses:生成器級別和全模型級別。

生成器級別:根據生成器的輸出計算損失函數,這個損失確保了 GAN 模型面向一個模糊的任務,它比較了 VGG 的第一個卷積的輸出。

import keras.backend as K
from keras.applications.vgg16 import VGG16
from keras.models import Model

image_shape = (256, 256, 3)

def perceptual_loss(y_true, y_pred):
    vgg = VGG16(include_top=False, weights='imagenet', 
input_shape=image_shape)
    loss_model = Model(inputs=vgg.input,
outputs=vgg.get_layer('block3_conv3').output)
    loss_model.trainable = False
    return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))

全模型級別使用 Wasserstein loss,用來計算整個模型的損失。它計算了兩個圖像間的平均偏差。可以改善 GAN 的收斂性。

import keras.backend as K

def wasserstein_loss(y_true, y_pred):
    return K.mean(y_true*y_pred)

訓練路線

第一步是載入數據並初始化模型,我們使用我們自定義的函數去載入數據集,併爲模型增加 Adam 優化器,最後設置 Keras 的訓練參數。

# Load dataset
data = load_images('./images/train', n_images)
y_train, x_train = data['B'], data['A']

# Initialize models
g = generator_model()
d = discriminator_model()
d_on_g = generator_containing_discriminator_multiple_outputs(g, d)

# Initialize optimizers
g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
d_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
d_on_g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

# Compile models
d.trainable = True
d.compile(optimizer=d_opt, loss=wasserstein_loss)
d.trainable = False
loss = [perceptual_loss, wasserstein_loss]
loss_weights = [100, 1]
d_on_g.compile(optimizer=d_on_g_opt, loss=loss, loss_weights=loss_weights)
d.trainable = True

之後,我們開始訓練。

for epoch in range(epoch_num):
  print('epoch: {}/{}'.format(epoch, epoch_num))
  print('batches: {}'.format(x_train.shape[0] / batch_size))

  # Randomize images into batches
  permutated_indexes = np.random.permutation(x_train.shape[0])

  for index in range(int(x_train.shape[0] / batch_size)):
      batch_indexes = permutated_indexes[index*batch_size:(index+1)*
      batch_size]
      image_blur_batch = x_train[batch_indexes]
      image_full_batch = y_train[batch_indexes]

最後,我們成功地基於兩個損失函數對生成器與判別器進行了訓練。我們使用生成器輸出了僞造圖片,進而使用僞造圖片與真實圖片訓練判別器對二者的評判區分。

for epoch in range(epoch_num):
  for index in range(batches):
    # [Batch Preparation]

    # Generate fake inputs
    generated_images = g.predict(x=image_blur_batch, 
batch_size=batch_size)
    
    # Train multiple times discriminator on real and fake inputs
    for _ in range(critic_updates):
        d_loss_real = d.train_on_batch(image_full_batch, 
output_true_batch)
        d_loss_fake = d.train_on_batch(generated_images, 
output_false_batch)
        d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

    d.trainable = False
    # Train generator only on discriminator's decision and
generated images
    d_on_g_loss = d_on_g.train_on_batch(image_blur_batch, 
[image_full_batch, output_true_batch])

    d.trainable = True

完整代碼請參見:https://www.github.com/raphaelmeudec/deblur-gan

訓練環境

在 AWS Instance 上使用 Deep Learning AMI(3.0 版本)。輕量級數據集,訓練時間大約 5 小時。

圖像銳化結果

上圖中是 Keras 銳化 GAN 的結果。即使在非常嚴重的模糊圖片上,這一網絡仍然可以給出更加銳利的圖片。圖中車燈更加銳利,樹枝也更加清晰。

存在問題是模型在圖像中引入了新的圖案,這可能是由於使用 VGG 作爲損失函數引起的。

如果你對計算機視覺感興趣,這裏有一篇基於內容的圖像檢索問題的介紹:

https://blog.sicara.com/keras-tutorial-content-based-image-retrieval-convolutional-denoising-autoencoder-dc91450cc511。

下面列出了一些 GAN 的優質資源。

GAN 資源

本文由 Readfog 進行 AMP 轉碼,版權歸原作者所有。
來源https://mp.weixin.qq.com/s/MFSLkmGOCV5maltlKAF2Wg