基于飞桨复现 GLCLC 模型,对残次图片实现图像补全

点击左上方蓝字关注我们

【飞桨开发者说】侯继旭,海南师范大学自动化本科在读,PPDE飞桨开发者技术专家,研究方向为目标检测、对抗生成网络等

 

本次复现使用的数据集是CelebA人脸数据集,这是一个大规模的人脸属性数据集,是由香港中文大学汤晓鸥教授实验室公布的大型人脸识别数据集,拥有超过20万张名人图像,已下载放置在此项目的数据集中,人脸属性有40多种。

本文项目代码github地址:

https://github.com/Eric-Hjx/PaddlePaddle_Image_Completion

模型摘要

在此篇论文中,作者们提出了Globally and Locally Consistent Image Completion方法,可以使得图像的缺失部分自动补全,局部和整图保持一致。作者通过全卷积网络,可以补全图片中任何形状的缺失,为了保持补全后的图像与原图的一致性,作者使用全局(整张图片)和局部(缺失补全部分)两种鉴别器来训练。全局鉴别器查看整个图像以评估它是否作为整体是连贯的,而局部鉴别器仅查看以完成区域为中心的小区域来确保所生成的补丁的局部一致性。

接着对图像补全网络训练以欺骗两个内容鉴别器网络,这要求它生成总体以及细节上与真实无法区分的图像。我们证明了我们的方法可以用来完成各种各样的场景。此外,与PatchMatch等基于补丁的方法相比,我们的方法可以生成图像中未出现的碎片,这使我们能够自然地完成具有熟悉且高度特定的结构(如面部)的对象的图像。

该论文的方法,完全以卷积网络作为基础,使用了GAN网络的思路,设计了两部分(三个网络),一部分用于生成图像,即补全网络,一部分用于鉴别生成图像是否与原图像一致,即全局鉴别器和局部鉴别器。网络结构图如下所示:

网络介绍:

  1. 补全网络:补全网络是完全卷积的,目的是用来修复图像。

  2. 全局鉴别器:以完整的图像作为输入,识别场景的全局一致性。

  3. 局部鉴别器:只关注完成区域周围的一个小区域,以判断更详细的外观质量。

 

基于飞桨实现GLCLC算法

 

下面我们基于飞桨开源深度学习框架动手实现 GLCLC 算法,介绍神经网络代码实现内容,主要使用了卷积、反卷积、空洞卷积、正则、激活函数等方法搭建了补全网络及鉴别网络。

1. 补全网络结构

补全网络部分,作者采用12层卷积网络对输入图像进行encoding,得到一张原图16分之一大小的网格。然后再对该网格采用4层卷积网络进行decoding。为了保证生成区域尽量不模糊,文中降低分辨率的操作是使用strided convolution 的方式进行的,而且只用了两次,将图片的size 变为原来的四分之一。同时在中间层还使用了空洞卷积来增大感受野,在尽量获取更大范围内的图像信息的同时不损失额外的信息,从而得到复原图像。下表为补全网络各层参数分布情况。

输入为RGB图像与二进制掩码(需要填充的区域以1填充)的组合图像;输出为RGB图像。

# 搭建补全网络
def generator(x):
    # conv1
    conv1 = fluid.layers.conv2d(input=x,num_filters=64,filter_size=5,dilation=1,stride=1,padding='SAME',name='generator_conv1',data_format='NHWC')
    conv1 = fluid.layers.batch_norm(conv1, momentum=0.99, epsilon=0.001)
    conv1 = fluid.layers.relu(conv1, name=None)
    # conv2
    conv2 = fluid.layers.conv2d(input=conv1,num_filters=128,filter_size=3,dilation=1,stride=2,padding='SAME',name='generator_conv2',data_format='NHWC')
    conv2 = fluid.layers.batch_norm(conv2, momentum=0.99, epsilon=0.001)
    conv2 = fluid.layers.relu(conv2, name=None)
    # conv3
    conv3 = fluid.layers.conv2d(input=conv2,num_filters=128,filter_size=3,dilation=1,stride=1,padding='SAME',name='generator_conv3',data_format='NHWC')
    conv3 = fluid.layers.batch_norm(conv3, momentum=0.99, epsilon=0.001)
    conv3 = fluid.layers.relu(conv3, name=None)
    # conv4
    conv4 = fluid.layers.conv2d(input=conv3,num_filters=256,filter_size=3,dilation=1,stride=2,padding='SAME',name='generator_conv4',data_format='NHWC')
    conv4 = fluid.layers.batch_norm(conv4, momentum=0.99, epsilon=0.001)
    conv4 = fluid.layers.relu(conv4, name=None)
    # conv5
    conv5 = fluid.layers.conv2d(input=conv4,num_filters=256,filter_size=3,dilation=1,stride=1,padding='SAME',name='generator_conv5',data_format='NHWC')
    conv5 = fluid.layers.batch_norm(conv5, momentum=0.99, epsilon=0.001)
    conv5 = fluid.layers.relu(conv5, name=None)
    # conv6
    conv6 = fluid.layers.conv2d(input=conv5,num_filters=256,filter_size=3,dilation=1,stride=1,padding='SAME',name='generator_conv6',data_format='NHWC')
    conv6 = fluid.layers.batch_norm(conv6, momentum=0.99, epsilon=0.001)
    conv6 = fluid.layers.relu(conv6, name=None)
    # 空洞卷积
    # dilated1
    dilated1 = fluid.layers.conv2d(input=conv6,num_filters=256,filter_size=3,dilation=2,padding='SAME',name='generator_dilated1',data_format='NHWC')
    dilated1 = fluid.layers.batch_norm(dilated1, momentum=0.99, epsilon=0.001)
    dilated1 = fluid.layers.relu(dilated1, name=None)
    # dilated2
    dilated2 = fluid.layers.conv2d(input=dilated1,num_filters=256,filter_size=3,dilation=4,padding='SAME',name='generator_dilated2',data_format='NHWC') #stride=1
    dilated2 = fluid.layers.batch_norm(dilated2, momentum=0.99, epsilon=0.001)
    dilated2 = fluid.layers.relu(dilated2, name=None)
    # dilated3
    dilated3 = fluid.layers.conv2d(input=dilated2,num_filters=256,filter_size=3,dilation=8,padding='SAME',name='generator_dilated3',data_format='NHWC')
    dilated3 = fluid.layers.batch_norm(dilated3, momentum=0.99, epsilon=0.001)
    dilated3 = fluid.layers.relu(dilated3, name=None)
    # dilated4
    dilated4 = fluid.layers.conv2d(input=dilated3,num_filters=256,filter_size=3,dilation=16,padding='SAME',name='generator_dilated4',data_format='NHWC')
    dilated4 = fluid.layers.batch_norm(dilated4, momentum=0.99, epsilon=0.001)
    dilated4 = fluid.layers.relu(dilated4, name=None)
    # conv7
    conv7 = fluid.layers.conv2d(input=dilated4,num_filters=256,filter_size=3,dilation=1,name='generator_conv7',data_format='NHWC')
    conv7 = fluid.layers.batch_norm(conv7, momentum=0.99, epsilon=0.001)
    conv7 = fluid.layers.relu(conv7, name=None)
    # conv8
    conv8 = fluid.layers.conv2d(input=conv7,num_filters=256,filter_size=3,dilation=1,stride=1,padding='SAME',name='generator_conv8',data_format='NHWC')
    conv8 = fluid.layers.batch_norm(conv8, momentum=0.99, epsilon=0.001)
    conv8 = fluid.layers.relu(conv8, name=None)
    # deconv1
    deconv1 = fluid.layers.conv2d_transpose(input=conv8, num_filters=128, output_size=[64,64],stride = 2,name='generator_deconv1',data_format='NHWC')
    deconv1 = fluid.layers.batch_norm(deconv1, momentum=0.99, epsilon=0.001)
    deconv1 = fluid.layers.relu(deconv1, name=None)
    # conv9
    conv9 = fluid.layers.conv2d(input=deconv1,num_filters=128,filter_size=3,dilation=1,stride=1,padding='SAME',name='generator_conv9',data_format='NHWC')
    conv9 = fluid.layers.batch_norm(conv9, momentum=0.99, epsilon=0.001)
    conv9 = fluid.layers.relu(conv9, name=None)
    # deconv2
    deconv2 = fluid.layers.conv2d_transpose(input=conv9, num_filters=64, output_size=[128,128],stride = 2,name='generator_deconv2',data_format='NHWC')
    deconv2 = fluid.layers.batch_norm(deconv2, momentum=0.99, epsilon=0.001)
    deconv2 = fluid.layers.relu(deconv2, name=None)
    # conv10
    conv10 = fluid.layers.conv2d(input=deconv2,num_filters=32,filter_size=3,dilation=1,stride=1,padding='SAME',name='generator_conv10',data_format='NHWC')
    conv10 = fluid.layers.batch_norm(conv10, momentum=0.99, epsilon=0.001)
    conv10 = fluid.layers.relu(conv10, name=None)
    # conv11
    x = fluid.layers.conv2d(input=conv10,num_filters=3,filter_size=3,dilation=1,stride=1,padding='SAME',name='generator_conv11',data_format='NHWC')
    x = fluid.layers.tanh(x)
    return x

2. 内容鉴别器

内容鉴别器分为了两个部分,一个全局鉴别器(Global Discriminator)以及一个局部鉴别器(Local Discriminator)。全局鉴别器是将一张完整的图像作为输入数据,对图像的全局一致性做出判断;局部鉴别器仅在以填充区域为中心的原图像四分之一大小区域上观测,对此部分图像的一致性做出判断。通过采用上述两个不同的鉴别器,可以使得最终的网络,不但可以对图像全局一致性做判断,并且能够通过局部鉴别方法,优化生成图的细节,最终能产生更好的图片填充效果。

在原文中,作者设定的全局鉴别网络输入是256X256X3的图片,局部网络输入是128X128X3的图片。原始论文中,全局网络和局部网络都会通过使用5X5的卷积层、2X2的stride降低图像分辨率,通过全连接,分别得到一个1024维的向量。然后,作者将全局和局部两个鉴别器的输出连接成一个2048维向量,再通过一个全连接,然后用sigmoid函数对整体的图像的一致性进行打分判别。但在本次实验,为了能降低训练难度,设定全局鉴别网络输入是128X128X3的图片,局部网络输入是64X64X3的图片。

# 搭建内容鉴别器
def discriminator(global_x, local_x):
    def global_discriminator(x):
        # conv1
        conv1 = fluid.layers.conv2d(input=x,num_filters=64,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_global_conv1',data_format='NHWC')
        conv1 = fluid.layers.batch_norm(conv1, momentum=0.99, epsilon=0.001)
        conv1 = fluid.layers.relu(conv1, name=None)
        # conv2
        conv2 = fluid.layers.conv2d(input=conv1,num_filters=128,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_global_conv2',data_format='NHWC')
        conv2 = fluid.layers.batch_norm(conv2, momentum=0.99, epsilon=0.001)
        conv2 = fluid.layers.relu(conv2, name=None)
        # conv3
        conv3 = fluid.layers.conv2d(input=conv2,num_filters=256,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_global_conv3',data_format='NHWC')
        conv3 = fluid.layers.batch_norm(conv3, momentum=0.99, epsilon=0.001)
        conv3 = fluid.layers.relu(conv3, name=None)
        # conv4
        conv4 = fluid.layers.conv2d(input=conv3,num_filters=512,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_global_conv4',data_format='NHWC')
        conv4 = fluid.layers.batch_norm(conv4, momentum=0.99, epsilon=0.001)
        conv4 = fluid.layers.relu(conv4, name=None)
        # conv5
        conv5 = fluid.layers.conv2d(input=conv4,num_filters=512,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_global_conv5',data_format='NHWC')
        conv5 = fluid.layers.batch_norm(conv5, momentum=0.99, epsilon=0.001)
        conv5 = fluid.layers.relu(conv5, name=None)
        # conv6
        conv6 = fluid.layers.conv2d(input=conv5,num_filters=512,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_global_conv6',data_format='NHWC')
        conv6 = fluid.layers.batch_norm(conv6, momentum=0.99, epsilon=0.001)
        conv6 = fluid.layers.relu(conv6, name=None)
        # fc
        x = fluid.layers.fc(input=conv6, size=1024,name='discriminator_global_fc1')
        return x

    def local_discriminator(x):
        # conv1
        conv1 = fluid.layers.conv2d(input=x,num_filters=64,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_lobal_conv1',data_format='NHWC')
        conv1 = fluid.layers.batch_norm(conv1, momentum=0.99, epsilon=0.001)
        conv1 = fluid.layers.relu(conv1, name=None)
        # conv2
        conv2 = fluid.layers.conv2d(input=conv1,num_filters=128,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_lobal_conv2',data_format='NHWC')
        conv2 = fluid.layers.batch_norm(conv2, momentum=0.99, epsilon=0.001)
        conv2 = fluid.layers.relu(conv2, name=None)
        # conv3
        conv3 = fluid.layers.conv2d(input=conv2,num_filters=256,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_lobal_conv3',data_format='NHWC')
        conv3 = fluid.layers.batch_norm(conv3, momentum=0.99, epsilon=0.001)
        conv3 = fluid.layers.relu(conv3, name=None)
        # conv4
        conv4 = fluid.layers.conv2d(input=conv3,num_filters=512,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_lobal_conv4',data_format='NHWC')
        conv4 = fluid.layers.batch_norm(conv4, momentum=0.99, epsilon=0.001)
        conv4 = fluid.layers.relu(conv4, name=None)
        # conv5
        conv5 = fluid.layers.conv2d(input=conv4,num_filters=512,filter_size=5,dilation=1,stride=2,padding='SAME',name='discriminator_lobal_conv5',data_format='NHWC')
        conv5 = fluid.layers.batch_norm(conv5, momentum=0.99, epsilon=0.001)
        conv5 = fluid.layers.relu(conv5, name=None)
        # fc
        x = fluid.layers.fc(input=conv5, size=1024,name='discriminator_lobal_fc1')
        return x

    global_output = global_discriminator(global_x)
    local_output = local_discriminator(local_x)
    print('global_output',global_output.shape)
    print('local_output',local_output.shape)
    output = fluid.layers.concat([global_output, local_output], axis=1)
    output = fluid.layers.fc(output, size=1,name='discriminator_concatenation_fc1')

    return output

3. 损失函数

生成网络使用weighted Mean Squared Error (MSE)作为损失函数,计算原图与生成图像像素之间的差异,表达式如下所示:

鉴别器网络使用GAN损失函数,其目标是最大化生成图像和原始图像的相似概率,表达式如下所示:

最后结合两者损失,形成下式:

网络训练

 

原文作者使用4个K80 GPU,使用的输入图像大小是256*256,训练了2个月才训练完成。

本项目为了缩短训练时间,仅采用了此论文核心思想、网络结构、优化目标等,并对训练方式及部分细节做了简化。使用的输入图像大小:128*128,训练方式设定为:先训练生成器再将生成器和判别器一起训练。

# 生成器优先迭代次数
NUM_TRAIN_TIMES_OF_DG = 100
# 总迭代轮次
epoch = 200

step_num = int(len(x_train) / BATCH_SIZE)

np.random.shuffle(x_train)

for pass_id in range(epoch):
    # 训练生成器
    if pass_id <= NUM_TRAIN_TIMES_OF_DG:
        g_loss_value = 0
        for i in tqdm.tqdm(range(step_num)):
            x_batch = x_train[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
            points_batch, mask_batch = get_points()
            # print(x_batch.shape)
            # print(mask_batch.shape)
            dg_loss_n = exe.run(dg_program,
                                 feed={'x': x_batch, 
                                        'mask':mask_batch,},
                                 fetch_list=[dg_loss])[0]
            g_loss_value += dg_loss_n
        print('Pass_id:{}, Completion loss: {}'.format(pass_id, g_loss_value))

        np.random.shuffle(x_test)
        x_batch = x_test[:BATCH_SIZE]

        completion_n = exe.run(dg_program, 
                        feed={'x': x_batch, 
                                'mask': mask_batch,},
                        fetch_list=[completion])[0][0]
        # 修复图片
        sample = np.array((completion_n + 1) * 127.5, dtype=np.uint8)
        # 原图
        x_im = np.array((x_batch[0] + 1) * 127.5, dtype=np.uint8)
        # 挖空洞输入图
        input_im_data = x_im * (1 - mask_batch[0])
        input_im = np.array(input_im_data + np.ones_like(x_im) * mask_batch[0] * 255, dtype=np.uint8)
        output_im = np.concatenate((x_im,input_im,sample),axis=1)
        #print(output_im.shape)
        cv2.imwrite('./output/pass_id:{}.jpg'.format(pass_id), cv2.cvtColor(output_im, cv2.COLOR_RGB2BGR))
        # 保存模型
        save_pretrain_model_path = 'models/'
        # 创建保持模型文件目录
        #os.makedirs(save_pretrain_model_path)
        fluid.io.save_params(executor=exe, dirname=save_pretrain_model_path, main_program=dg_program)

    # 生成器判断器一起训练
    else:
        g_loss_value = 0
        d_loss_value = 0
        for i in tqdm.tqdm(range(step_num)):
            x_batch = x_train[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
            points_batch, mask_batch = get_points()
            dg_loss_n = exe.run(dg_program,
                                 feed={'x': x_batch, 
                                        'mask':mask_batch,},
                                 fetch_list=[dg_loss])[0]
            g_loss_value += dg_loss_n

            completion_n = exe.run(dg_program, 
                                feed={'x': x_batch, 
                                        'mask': mask_batch,},
                                fetch_list=[completion])[0]
            local_x_batch = []
            local_completion_batch = []
            for i in range(BATCH_SIZE):
                x1, y1, x2, y2 = points_batch[i]
                local_x_batch.append(x_batch[i][y1:y2, x1:x2, :])
                local_completion_batch.append(completion_n[i][y1:y2, x1:x2, :])
            local_x_batch = np.array(local_x_batch)
            local_completion_batch = np.array(local_completion_batch)
            d_loss_n  = exe.run(d_program,
                                feed={'x': x_batch, 'mask': mask_batch, 'local_x': local_x_batch, 'global_completion': completion_n, 'local_completion': local_completion_batch},
                                fetch_list=[d_loss])[0]
            d_loss_value += d_loss_n
        print('Pass_id:{}, Completion loss: {}'.format(pass_id, g_loss_value))
        print('Pass_id:{}, Discriminator loss: {}'.format(pass_id, d_loss_value))

        np.random.shuffle(x_test)
        x_batch = x_test[:BATCH_SIZE]
        completion_n = exe.run(dg_program, 
                        feed={'x': x_batch, 
                                'mask': mask_batch,},
                        fetch_list=[completion])[0][0]
        # 修复图片
        sample = np.array((completion_n + 1) * 127.5, dtype=np.uint8)
        # 原图
        x_im = np.array((x_batch[0] + 1) * 127.5, dtype=np.uint8)
        # 挖空洞输入图
        input_im_data = x_im * (1 - mask_batch[0])
        input_im = np.array(input_im_data + np.ones_like(x_im) * mask_batch[0] * 255, dtype=np.uint8)
        output_im = np.concatenate((x_im,input_im,sample),axis=1)
        #print(output_im.shape)
        cv2.imwrite('./output/pass_id:{}.jpg'.format(pass_id), cv2.cvtColor(output_im, cv2.COLOR_RGB2BGR))
        # 保存模型
        save_pretrain_model_path = 'models/'
        # 创建保持模型文件目录
        #os.makedirs(save_pretrain_model_path)
        fluid.io.save_params(executor=exe, dirname=save_pretrain_model_path, main_program = dg_program)

结果展示

 

项目总结

 

整个训练过程,花了9小时左右,共训练了100次补全网络+45次补全网络和鉴别网络。

Image Completion Result 中的 Input 是挖洞后输入补全网络的图像,在 Output 看到, Input 图像上挖的洞已经被补上了,这说明现在的训练结果已经能在一定程度上补全图像的缺失部分了。由于本项目实现时在硬件及时间方面受限,因此对原文中的方法进行了简化,训练方法和数据样本处理较原论文有所调整做了调整,无法达到原论文效果,但相较于原作者两个月的训练时间对比,这样的训练方式也是可取的。

如想到达到原论文的精准的小伙伴,可以在本项目基础上修改训练策略~在此附上原论文训练程序图

本项目使用了飞桨开源深度学习框架,在AI Studio上完成了数据处理、模型训练、效果预测等整个工作过程,非常感谢AI Studio给我们提供的GPU在线训练环境,对于在深度学习道路上硬件条件上不足的学生来说简直是非常大的帮助。

如果你对这个小实验感兴趣,也可以自己来尝试一下,整个项目包括数据集与相关代码已公开在AI Studio上,欢迎小伙伴们Fork。

https://aistudio.baidu.com/aistudio/projectdetail/632313

如在使用过程中有问题,可加入飞桨官方QQ群进行交流:1108045677。

如果您想详细了解更多飞桨的相关内容,请参阅以下文档。

·飞桨开源框架项目地址·

GitHub: https://github.com/PaddlePaddle/Paddle 

Gitee: https://gitee.com/paddlepaddle/Paddle

·飞桨官网地址·

https://www.paddlepaddle.org.cn/

 

扫描二维码 | 关注我们

微信号 : PaddleOpenSource

END

精彩活动