0
点赞
收藏
分享

微信扫一扫

【百战GAN】SRGAN人脸低分辨率老照片修复代码实战


大家好,欢迎来到专栏《百战GAN》,在这个专栏里,我们会进行算法的核心思想讲解,代码的详解,模型的训练和测试等内容。

作者&编辑 | 言有三

【百战GAN】SRGAN人脸低分辨率老照片修复代码实战_python

本文资源与生成结果展示

本文篇幅:7000字

背景要求:会使用Python和Pytorch

附带资料:参考论文和项目

1 项目背景

了解GAN的同学都知道,GAN擅长于捕捉概率分布,因此非常适合图像生成类任务。我们在图片视频拍摄以及传输过程中,经常会进行图像的压缩,导致图像分辨率过低,另外早些年的设备拍摄出来的照片也存在分辨率过低的问题,比如10年前的320*240分辨率。

【百战GAN】SRGAN人脸低分辨率老照片修复代码实战_人工智能_02

要解决此问题,需要使用到图像超分辩技术。

2 原理简介

图像超分辩任务输入是一张低分辨率的图像,输出是一张对它进行分辨率增大的图片,下面是一个常用的框架示意图[1]:

【百战GAN】SRGAN人脸低分辨率老照片修复代码实战_计算机视觉_03

该框架首先对输入图使用插值方法进行上采样,然后使用卷积层对输入进行学习,这种框架的劣势是计算代价比较大,因为整个网络是对高分辨率图操作。

随后研究者提出在网络的后端进行分辨率放大,通过扩充通道数,然后将其重新分布来获得高分辨率图,这套操作被称为(PixShuffle)[2],这样整个网络大部分计算量是对低分辨率图操作,如下图:

【百战GAN】SRGAN人脸低分辨率老照片修复代码实战_人工智能_04

对于维度为H×W×C的图像,标准反卷积操作输出的特征图维度为rH×rW×C,其中r就是需要放大的倍数,而从图中可以看出,亚像素卷积层的输出特征图维度为H×W×C×r×r,即特征图与输入图片的尺寸保持一致,但是通道数被扩充为原来的r×r倍,然后再进行重新排列得到高分辨率的结果。

整个流程因为使用了更小的图像输入,从而可以使用更小的卷积核获取较大的感受野,这既使得输入图片中邻域像素点的信息得到有效利用,还避免了计算复杂度的增加,是一种将空间上采样问题转换为通道上采样问题的思路,被大多数主流超分模型采用为上采样模块。

以上构成了图像超分辨的基本思路,之后Twitter的研究者们使用ResNet作为生成器结构,使用VGG作为判别器结构,提出了SRGAN[3]模型,模型结构示意图如下图:

【百战GAN】SRGAN人脸低分辨率老照片修复代码实战_机器学习_05

图中生成器结构包含了若干个不改变特征分辨率的残差模块和多个基于亚像素卷积的后上采样模块。

判别器结构则包含了若干个通道数不断增加的卷积层,每次特征通道数增加一倍时,特征分辨率降低为原来的一半。

SRGAN模型基于VGG网络特征构建内容损失函数(content loss),代替了之前的MSE损失函数,通过生成器和判别器的对抗学习取得了视觉感知上更好的重建结果。

3 模型训练

大多数超分重建任务的数据集都是通过从高分辨率图像进行下采样获得,论文中往往选择ImageNet数据集,由于我们这里打算专门对人脸进行清晰度恢复,因此选择了一个常用的高清人脸数据集,CelebA-HQ,它发布于 2019 年,包含30000张不同属性的高清人脸图,其中图像大小均为1024×1024,预览如下。

【百战GAN】SRGAN人脸低分辨率老照片修复代码实战_人工智能_06

接下来我们对代码进行解读:

3.1 数据预处理

图像超分辨数据集往往都是从高分辨率图进行采样得到低分辨率图,然后组成训练用的图像对,下面是对训练集和验证集中数据处理的核心代码:

from os import listdir

from os.path import join

import numpy as np

from PIL import Image

from torch.utils.data.dataset import Dataset

from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize

import imgaug.augmenters as iaa

aug = iaa.JpegCompression(compression=(0, 50))

## 基于上采样因子对裁剪尺寸进行调整,使其为upscale_factor的整数倍

def calculate_valid_crop_size(crop_size, upscale_factor):

return crop_size - (crop_size % upscale_factor)

## 训练集高分辨率图预处理函数

def train_hr_transform(crop_size):

    return Compose([

        RandomCrop(crop_size),

        ToTensor(),

])

## 训练集低分辨率图预处理函数

def train_lr_transform(crop_size, upscale_factor):

    return Compose([

        ToPILImage(),

        Resize(crop_size // upscale_factor, interpolation=Image.BICUBIC),

        ToTensor()

    ])

## 训练数据集类

class TrainDatasetFromFolder(Dataset):

    def __init__(self, dataset_dir, crop_size, upscale_factor):

        super(TrainDatasetFromFolder, self).__init__()

        self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)] ##获得所有图像

crop_size = calculate_valid_crop_size(crop_size, upscale_factor)##获得裁剪尺寸

        self.hr_transform = train_hr_transform(crop_size) ##高分辨率图预处理函数

        self.lr_transform = train_lr_transform(crop_size, upscale_factor) ##低分辨率图预处理函数

##数据集迭代指针

    def __getitem__(self, index): 

        hr_image = self.hr_transform(Image.open(self.image_filenames[index])) ##随机裁剪获得高分辨率图

        lr_image = self.lr_transform(hr_image) ##获得低分辨率图

img = np.array(lr_image)

image_aug = aug(image=img)

lr_image = Image.fromarray(image_aug.astype('uint8')).convert('RGB') 

        return ToTensor()(lr_image), hr_image

    def __len__(self):

        return len(self.image_filenames)

## 验证数据集类

class ValDatasetFromFolder(Dataset):

    def __init__(self, dataset_dir, upscale_factor):

        super(ValDatasetFromFolder, self).__init__()

        self.upscale_factor = upscale_factor

        self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]

    def __getitem__(self, index):

        hr_image = Image.open(self.image_filenames[index])

##获得图像窄边获得裁剪尺寸

        w, h = hr_image.size

        crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)

        lr_scale = Resize(crop_size // self.upscale_factor, interpolation=Image.BICUBIC)

        hr_scale = Resize(crop_size, interpolation=Image.BICUBIC)

        hr_image = CenterCrop(crop_size)(hr_image) ##中心裁剪获得高分辨率图

        lr_image = lr_scale(hr_image) ##获得低分辨率图

img = np.array(lr_image)

image_aug = aug(image=img)

lr_image = Image.fromarray(image_aug.astype('uint8')).convert('RGB') 

        hr_restore_img = hr_scale(lr_image)

        return ToTensor()(lr_image), ToTensor()(hr_image)

    def __len__(self):

        return len(self.image_filenames)

从上述代码可以看出,包含了两个预处理函数接口,分别是train_hr_transform,train_lr_transform。train_hr_transform包含的操作主要是随机裁剪,而train_lr_transform包含的操作主要是缩放。

另外还有一个函数calculate_valid_crop_size,对于训练集来说,它用于当配置的图像尺寸crop_size不能整除上采样因子upscale_factor时对crop_size进行调整,我们在使用的时候应该避免这一点,即配置crop_size让它等于upscale_factor的整数倍。对于验证集,图像的窄边min(w, h)会被用于crop_size的初始化,所以该函数的作用是当图像的窄边不能整除上采样因子upscale_factor时对crop_size进行调整。

训练集类TrainDatasetFromFolder包含了若干操作,它使用train_hr_transform从原图像中随机裁剪大小为裁剪尺寸的正方形的图像,使用train_lr_transform获得对应的低分辨率图。而验证集类ValDatasetFromFolder则将图像按照调整后的crop_size进行中心裁剪,然后使用train_lr_transform获得对应的低分辨率图。

在这里我们使用了随机裁剪和JPEG噪声压缩作为训练时的数据增强操作,JPEG噪声的添加使用了imaaug库,其项目地址为https://github.com/aleju/imgaug,下图展示了对一些样本添加不同幅度的JPEG噪声的图像。

【百战GAN】SRGAN人脸低分辨率老照片修复代码实战_python_07

第1行为分辨率512×512大小的原图,第2行是缩放为128×128大小,不添加JPEG压缩噪声的图,第3行和第4行分别是缩放为128×128大小,并且用imgaug库添加幅度为30%和90%的JPEG压缩噪声的图像。可以看出JPEG噪声对图像质量影响很大,尤其是当噪声幅度很大时,斑块效应非常明显。我们在后面会比较添加不同幅度的JPEG噪声与不添加JPEG噪声的模型结果对比,验证对于真实的图像超分辨任务,更接近真实退化过程的数据增强操作是必要的。

3.2 生成器网络

生成器是一个基于残差模块的上采样模型,它的定义包括残差模块,上采样模块以及主干模型,如下。

## 残差模块

class ResidualBlock(nn.Module):

    def __init__(self, channels):

        super(ResidualBlock, self).__init__()

        ## 两个卷积层,卷积核大小为3×3,通道数不变

        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        self.bn1 = nn.BatchNorm2d(channels)

        self.prelu = nn.PReLU()

        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):

        residual = self.conv1(x)

        residual = self.bn1(residual)

        residual = self.prelu(residual)

        residual = self.conv2(residual)

        residual = self.bn2(residual)

        return x + residual

## 上采样模块,每一个恢复分辨率为2

class UpsampleBLock(nn.Module):

    def __init__(self, in_channels, up_scale):

        super(UpsampleBLock, self).__init__()

        ## 卷积层,输入通道数为in_channels,输出通道数为in_channels * up_scale ** 2

        self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)

        ## PixelShuffle上采样层,来自于后上采样结构

        self.pixel_shuffle = nn.PixelShuffle(up_scale)

        self.prelu = nn.PReLU()

    def forward(self, x):

        x = self.conv(x)

        x = self.pixel_shuffle(x)

        x = self.prelu(x)

        return x

## 生成模型

class Generator(nn.Module):

    def __init__(self, scale_factor):

        upsample_block_num = int(math.log(scale_factor, 2))

        super(Generator, self).__init__()

        ## 第一个卷积层,卷积核大小为9×9,输入通道数为3,输出通道数为64

        self.block1 = nn.Sequential(

            nn.Conv2d(3, 64, kernel_size=9, padding=4),

            nn.PReLU()

        )

        ## 6个残差模块

        self.block2 = ResidualBlock(64)

        self.block3 = ResidualBlock(64)

        self.block4 = ResidualBlock(64)

        self.block5 = ResidualBlock(64)

        self.block6 = ResidualBlock(64)

        self.block7 = nn.Sequential(

            nn.Conv2d(64, 64, kernel_size=3, padding=1),

            nn.BatchNorm2d(64)

        )

        ## upsample_block_num个上采样模块,每一个上采样模块恢复2倍的上采样倍率

        block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]

        ## 最后一个卷积层,卷积核大小为9×9,输入通道数为64,输出通道数为3

        block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))

        self.block8 = nn.Sequential(*block8)

    def forward(self, x):

        block1 = self.block1(x)

        block2 = self.block2(block1)

        block3 = self.block3(block2)

        block4 = self.block4(block3)

        block5 = self.block5(block4)

        block6 = self.block6(block5)

        block7 = self.block7(block6)

        block8 = self.block8(block1 + block7)

        return (torch.tanh(block8) + 1) / 2

在上述的生成器定义中,调用了nn.PixelShuffle模块来实现上采样,它的具体原理就是上面提到的pixelshuffle。

3.3 判别器定义

判别器是一个普通的类似于VGG的CNN模型:

## 残差模块

class Discriminator(nn.Module):

    def __init__(self):

        super(Discriminator, self).__init__()

        self.net = nn.Sequential(

            nn.Conv2d(3, 64, kernel_size=3, padding=1),

            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),

            nn.BatchNorm2d(64),

            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),

            nn.BatchNorm2d(128),

            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),

            nn.BatchNorm2d(128),

            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),

            nn.BatchNorm2d(256),

            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),

            nn.BatchNorm2d(256),

            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),

            nn.BatchNorm2d(512),

            nn.LeakyReLU(0.2),

            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),

            nn.BatchNorm2d(512),

            nn.LeakyReLU(0.2),

            nn.AdaptiveAvgPool2d(1),

            ## 两个全连接层,使用卷积实现

            nn.Conv2d(512, 1024, kernel_size=1),

            nn.LeakyReLU(0.2),

            nn.Conv2d(1024, 1, kernel_size=1)

        )

    def forward(self, x):

        batch_size = x.size(0)

        return torch.sigmoid(self.net(x).view(batch_size))

3.4 损失函数定义

损失函数的定义是框架中的重点,下面主要看生成器的损失,因为判别器就是一个分类损失。

## 生成器损失定义

class GeneratorLoss(nn.Module):

    def __init__(self):

        super(GeneratorLoss, self).__init__()

        vgg = vgg16(pretrained=True)

        loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()

        for param in loss_network.parameters():

            param.requires_grad = False

        self.loss_network = loss_network

        self.mse_loss = nn.MSELoss() ##MSE损失

        self.tv_loss = TVLoss() ##TV平滑损失

    def forward(self, out_labels, out_images, target_images):

        # 对抗损失

        adversarial_loss = torch.mean(1 - out_labels)

        # 感知损失

        perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))

        # 图像MSE损失

        image_loss = self.mse_loss(out_images, target_images)

        # TV平滑损失

        tv_loss = self.tv_loss(out_images)

        return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss

## TV平滑损失

class TVLoss(nn.Module):

    def __init__(self, tv_loss_weight=1):

        super(TVLoss, self).__init__()

        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):

        batch_size = x.size()[0]

        h_x = x.size()[2]

        w_x = x.size()[3]

        count_h = self.tensor_size(x[:, :, 1:, :])

        count_w = self.tensor_size(x[:, :, :, 1:])

        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()

        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()

        return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    @staticmethod

    def tensor_size(t):

        return t.size()[1] * t.size()[2] * t.size()[3]

上述代码中定义了对抗网络损失,逐像素的图像MSE损失,基于VGG模型的感知损失,用于约束图像平滑的TV平滑损失。

至此就完成了工程中核心代码的解读。

4 模型训练与测试

接下来我们对模型进行训练和测试。

4.1 模型训练

模型训练就是完成模型定义,数据载入,可视化以及存储等工作,核心代码如下:

## 参数解释器

parser = argparse.ArgumentParser(description='Train Super Resolution Models')

## 裁剪尺寸,即训练尺度

parser.add_argument('--crop_size', default=240, type=int, help='training images crop size')

## 超分上采样倍率

parser.add_argument('--upscale_factor', default=4, type=int, choices=[2, 4, 8],

                    help='super resolution upscale factor')

##训练主代码

if __name__ == '__main__':

    opt = parser.parse_args()

    CROP_SIZE = opt.crop_size

    UPSCALE_FACTOR = opt.upscale_factor

    NUM_EPOCHS = opt.num_epochs

    ## 获取训练集/验证集

    train_set = TrainDatasetFromFolder('data/train', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)

    val_set = ValDatasetFromFolder('data/val', upscale_factor=UPSCALE_FACTOR)

    train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True)

    val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)

    netG = Generator(UPSCALE_FACTOR) ##生成器定义

    netD = Discriminator() ##判别器定义

    generator_criterion = GeneratorLoss() ##生成器优化目标

    ## 是否使用GPU

    if torch.cuda.is_available():

        netG.cuda()

        netD.cuda()

        generator_criterion.cuda()

    ##生成器和判别器优化器

    optimizerG = optim.Adam(netG.parameters())

    optimizerD = optim.Adam(netD.parameters())

    ## epoch迭代

    for epoch in range(1, NUM_EPOCHS + 1):

        train_bar = tqdm(train_loader)

        running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0} ##结果变量

        netG.train() ##生成器训练

        netD.train() ##判别器训练

        ## 每一个epoch的数据迭代

        for data, target in train_bar:

            g_update_first = True

            batch_size = data.size(0)

            running_results['batch_sizes'] += batch_size

            ## 优化判别器,最大化D(x)-1-D(G(z))

            real_img = Variable(target)

            if torch.cuda.is_available():

                real_img = real_img.cuda()

            z = Variable(data)

            if torch.cuda.is_available():

                z = z.cuda()

            fake_img = netG(z) ##获取生成结果

            netD.zero_grad()

            real_out = netD(real_img).mean()

            fake_out = netD(fake_img).mean()

            d_loss = 1 - real_out + fake_out

            d_loss.backward(retain_graph=True)

            optimizerD.step() ##优化判别器

            ## 优化生成器 最小化1-D(G(z)) + Perception Loss + Image Loss + TV Loss

            netG.zero_grad()

            g_loss = generator_criterion(fake_out, fake_img, real_img)

            g_loss.backward()

            fake_img = netG(z)

            fake_out = netD(fake_img).mean()

            optimizerG.step()

以上就是训练中的代码,训练时采用的crop_size为240×240,训练时我们将所有图缩放为320×320,使用的优化器为Adam,添加了变量缓存的完整代码请看下文的完整工程项目。

图中jpeg0对应的曲线表示不添加JPEG压缩噪声,jpeg0-50,jpeg30-70对应的曲线分别表示添加0~50%幅度的压缩噪声,添加30~70%幅度的压缩噪声的训练结果。可以看出,模型已经基本收敛,添加噪声幅度越大,则最终的PSNR指标和SSIM指标也会越低。

【百战GAN】SRGAN人脸低分辨率老照片修复代码实战_人工智能_08

下图从左至右分别是有噪声原图,无噪声原图,去噪声图。

【百战GAN】SRGAN人脸低分辨率老照片修复代码实战_计算机视觉_09

4.2 模型推理

实际用的时候,我们希望使用自己的图片来完成推理,完整的代码如下:

import torch

from PIL import Image

from torch.autograd import Variable

from torchvision.transforms import ToTensor, ToPILImage

from model import Generator

UPSCALE_FACTOR = 4 ##上采样倍率

TEST_MODE = True ## 使用GPU进行测试

IMAGE_NAME = sys.argv[1] ##图像路径

RESULT_NAME = sys.argv[1] ##结果图路径

MODEL_NAME = 'netG.pth' ##模型路径

model = Generator(UPSCALE_FACTOR).eval() ##设置验证模式

if TEST_MODE:

    model.cuda()

    model.load_state_dict(torch.load(MODEL_NAME))

else:

    model.load_state_dict(torch.load(MODEL_NAME, map_location=lambda storage, loc: storage))

image = Image.open(IMAGE_NAME) ##读取图片

image = Variable(ToTensor()(image), volatile=True).unsqueeze(0) ##图像预处理

if TEST_MODE:

    image = image.cuda()

out = model(image)

out_img = ToPILImage()(out[0].data.cpu())

out_img.save(RESULT_NAME)

下图展示了一张真人图的超分辨结果,输入图是从512×512的大小缩放为128×128大小,然后分别使用opencv的imwrite函数存储为JPEG和PNG格式,前者使用opencv库默认的JPEG压缩率。

【百战GAN】SRGAN人脸低分辨率老照片修复代码实战_深度学习_10

真人图像的SRGAN超分结果

图中第1列为原图,其中两行分别是JPEG格式和PNG格式,显示时使用双线性插值进行上采样。第2列为不添加JPEG噪声数据增强进行训练后的4倍超分结果,第3列为添加0~50%随机幅度的JPEG噪声数据增强进行训练后的超分结果,第4列为添加30~70%随机幅度的JPEG噪声数据增强进行训练后的超分结果。

比较第1行和第2行,可以看出,对于JPEG压缩的图像,如果没有添加噪声数据增强,则结果图会放大原图中的噪声,其中第2列结果非常明显。不过,噪声的幅度也不能过大,否则重建结果会失真,比较第4列结果和第3列结果,虽然第4列有更强的噪声抑制能力,但是人脸图已经开始出现失真,如皮肤过于平滑,眼睛失真明显,因此不能一味得增加噪声幅度。

虽然我们的训练数据集是真人图,但是模型也可以泛化到其他域的人脸图像。


举报

相关推荐

0 条评论