Semantic Segmentation: Building the Model (PyTorch)

MaxWen 2022-06-27

Preface

Semantic segmentation is the task of classifying an image pixel by pixel. This post walks through building a model based on the classic encoder-decoder structure. Before reading on, it helps to be familiar with the basic building blocks of convolutional neural networks, such as convolution, transposed convolution (deconvolution), batch normalization, activation functions, dropout, and image interpolation.

U-Net Model Architecture

[Figure: the U-Net architecture]
The figure above shows U-Net, the network proposed by Ronneberger et al. in 2015. This post builds a Res_U-Net, which uses a ResNet as its encoder. Using an existing image classification network as the encoder of a segmentation model is now the mainstream approach; common choices include resnet, vggnet, AlexNet, and so on.
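
As an illustration of this practice, here is a minimal sketch (the specific backbones are only examples, not part of the original code) of how torchvision classification networks can be instantiated as encoder backbones:

import torch
from torch import nn
import torchvision

# any ImageNet classifier can serve as the encoder backbone, for example:
resnet_backbone = torchvision.models.resnet34(pretrained=True)
vgg_backbone = torchvision.models.vgg16(pretrained=True).features        # convolutional part only
alexnet_backbone = torchvision.models.alexnet(pretrained=True).features

# dropping the classification head leaves a fully convolutional feature extractor
resnet_features = nn.Sequential(*list(resnet_backbone.children())[:-2])  # remove avgpool + fc
print(resnet_features(torch.randn(1, 3, 224, 224)).shape)                # [1, 512, 7, 7]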

Encoder

The encoder's main job is feature extraction: the deeper the network, the richer the semantic information it captures. The figure below shows the basic unit of the classic residual network.
[Figure: the basic residual unit of ResNet]
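
To make the figure concrete, here is a minimal sketch of such a residual unit (a hypothetical BasicResBlock, assuming the identity-shortcut case with matching channels and stride 1; ResNet's real blocks also include strided/projection variants):

import torch
from torch import nn

class BasicResBlock(nn.Module):
    """Minimal residual unit: two 3x3 convs plus an identity shortcut."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + identity)  # add the shortcut, then ReLU
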
Here we adapt resnet by removing the fully connected layer at its end and attaching a decoder in its place (introduced in the next section). The code follows.

# inside UNet.__init__ : build the encoder from resnet18
net = torchvision.models.resnet18(pretrained=pretrained)
net.conv1 = torch.nn.Conv2d(3, 64, 7, 2, 3, bias=False)  # note: this re-initializes the first conv
self.encoder = net
decoder_channels = (256, 128, 64, 32, 16)
encoder_channels = (512, 256, 128, 64, 64)
in_channels = self.compute_channels(encoder_channels, decoder_channels)
out_channels = decoder_channels

# optionally freeze the encoder weights
for layer in self.encoder.parameters():
    layer.requires_grad = not freeze_encoder

self.relu = nn.ReLU(inplace=True)

# expose the encoder stages so the decoder can reuse their outputs as skip connections
self.conv0 = nn.Sequential(self.encoder.conv1,
                           self.encoder.bn1,
                           self.encoder.relu,
                           self.pool)
self.conv1 = self.encoder.layer1
self.conv2 = self.encoder.layer2
self.conv3 = self.encoder.layer3
self.conv4 = self.encoder.layer4

In this code we only need to collect the feature maps produced by the encoder. Because U-Net uses skip connections, feature maps from every stage are required; here resnet's four stages (layer1 to layer4) produce feature maps at four different resolutions, and the stem output conv0 provides one more, higher-resolution skip.
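
To verify these resolutions, a small sketch (assuming a 3x256x256 input; pretrained weights are not needed just to check shapes) can push a dummy tensor through the resnet18 stages:

import torch
import torchvision

net = torchvision.models.resnet18(pretrained=False)
x = torch.randn(1, 3, 256, 256)

f0 = net.relu(net.bn1(net.conv1(x)))   # 64  x 128 x 128  (stem, stride-2 conv)
f1 = net.layer1(net.maxpool(f0))       # 64  x 64  x 64
f2 = net.layer2(f1)                    # 128 x 32  x 32
f3 = net.layer3(f2)                    # 256 x 16  x 16
f4 = net.layer4(f3)                    # 512 x 8   x 8
for f in (f0, f1, f2, f3, f4):
    print(f.shape)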

Decoder

The decoder's main job is to recover spatial information. In the U-shaped structure, the lower-resolution feature map is first upsampled (bilinear interpolation is the usual choice), then concatenated via a skip connection with the resnet feature map of the same resolution, and finally fused by a basic block consisting of two convolutional layers, as shown in figure (a) below. This repeats until the original image size is restored. At the end of the decoder, a 1×1 convolution maps the channel count to the number of classes, and a sigmoid squashes the values to between 0 and 1, as shown in figure (b) below.
[Figure: (a) decoder block with skip connection; (b) 1×1 convolution and sigmoid output]
The decoder module code is as follows:

import torch
from torch import nn
from torch.nn import functional as F


class Conv2dReLU(nn.Module):
    """Conv -> [BatchNorm] -> ReLU."""
    def __init__(self, in_channels, out_channels, kernel_size, padding=0,
                 stride=1, use_batchnorm=True, **batchnorm_params):
        super().__init__()

        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size,
                      stride=stride, padding=padding, bias=not use_batchnorm),
            nn.ReLU(inplace=True),
        ]

        if use_batchnorm:
            layers.insert(1, nn.BatchNorm2d(out_channels, **batchnorm_params))

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)


class DecoderBlock(nn.Module):
    """Upsample by 2, concatenate the skip feature, then fuse with two conv layers."""
    def __init__(self, in_channels, out_channels,
                 use_batchnorm=True,
                 attention_kernel_size=3,
                 reduction=8):
        super().__init__()
        self.block = nn.Sequential(
            Conv2dReLU(in_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm),
            Conv2dReLU(out_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm),
        )

    def forward(self, x):
        x, skip = x
        x = F.interpolate(x, scale_factor=2, mode='nearest')  # upsample by a factor of 2
        if skip is not None:
            x = torch.cat([x, skip], dim=1)  # concatenate the skip connection along channels
        x = self.block(x)
        return x
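
As a quick shape check for DecoderBlock (a sketch assuming the classes above are in scope; the shapes correspond to the deepest encoder feature and its skip connection):

x = torch.randn(1, 512, 8, 8)       # deepest encoder feature
skip = torch.randn(1, 256, 16, 16)  # skip connection from the encoder
block = DecoderBlock(in_channels=512 + 256, out_channels=256)
out = block([x, skip])
print(out.shape)                    # torch.Size([1, 256, 16, 16])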

The complete model code:

from torch import nn
from torch.nn import functional as F
import torch
import torchvision


class Conv2dReLU(nn.Module):
    """Conv -> [BatchNorm] -> ReLU."""
    def __init__(self, in_channels, out_channels, kernel_size, padding=0,
                 stride=1, use_batchnorm=True, **batchnorm_params):
        super().__init__()

        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size,
                      stride=stride, padding=padding, bias=not use_batchnorm),
            nn.ReLU(inplace=True),
        ]

        if use_batchnorm:
            layers.insert(1, nn.BatchNorm2d(out_channels, **batchnorm_params))

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)


class DecoderBlock(nn.Module):
    """Upsample by 2, concatenate the skip feature, then fuse with two conv layers."""
    def __init__(self, in_channels, out_channels,
                 use_batchnorm=True,
                 attention_kernel_size=3,
                 reduction=8):
        super().__init__()
        self.block = nn.Sequential(
            Conv2dReLU(in_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm),
            Conv2dReLU(out_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=use_batchnorm),
        )

    def forward(self, x):
        x, skip = x
        x = F.interpolate(x, scale_factor=2, mode='nearest')  # upsample by a factor of 2
        if skip is not None:
            x = torch.cat([x, skip], dim=1)  # concatenate the skip connection along channels
        x = self.block(x)
        return x


class UNet(nn.Module):
    """
    UNet (https://arxiv.org/abs/1505.04597) with a Resnet18 (https://arxiv.org/abs/1512.03385) encoder.
    """
    def __init__(self, num_classes=1, pretrained=True, use_batchnorm=True, freeze_encoder=False):
        """
        :param num_classes: number of output channels (classes)
        :param pretrained:
            False - the encoder is randomly initialized
            True  - the encoder is initialized with ImageNet-pretrained resnet18 weights
        :param use_batchnorm: whether the decoder blocks use BatchNorm
        :param freeze_encoder: if True, encoder parameters are not updated during training
        """
        super().__init__()
        self.num_classes = num_classes
        self.pool = nn.MaxPool2d(2, 2)

        # encoder: resnet18 backbone
        net = torchvision.models.resnet18(pretrained=pretrained)
        net.conv1 = torch.nn.Conv2d(3, 64, 7, 2, 3, bias=False)  # note: re-initializes the first conv
        self.encoder = net
        decoder_channels = (256, 128, 64, 32, 16)
        encoder_channels = (512, 256, 128, 64, 64)
        in_channels = self.compute_channels(encoder_channels, decoder_channels)
        out_channels = decoder_channels

        # optionally freeze the encoder weights
        for layer in self.encoder.parameters():
            layer.requires_grad = not freeze_encoder

        self.relu = nn.ReLU(inplace=True)

        self.conv0 = nn.Sequential(self.encoder.conv1,
                                   self.encoder.bn1,
                                   self.encoder.relu,
                                   self.pool)
        self.conv1 = self.encoder.layer1
        self.conv2 = self.encoder.layer2
        self.conv3 = self.encoder.layer3
        self.conv4 = self.encoder.layer4

        # decoder: five upsampling blocks followed by a 1x1 classification conv
        self.layer1 = DecoderBlock(in_channels[0], out_channels[0], use_batchnorm=use_batchnorm)
        self.layer2 = DecoderBlock(in_channels[1], out_channels[1], use_batchnorm=use_batchnorm)
        self.layer3 = DecoderBlock(in_channels[2], out_channels[2], use_batchnorm=use_batchnorm)
        self.layer4 = DecoderBlock(in_channels[3], out_channels[3], use_batchnorm=use_batchnorm)
        self.layer5 = DecoderBlock(in_channels[4], out_channels[4], use_batchnorm=use_batchnorm)
        self.final = nn.Conv2d(out_channels[4], num_classes, kernel_size=1)

    def compute_channels(self, encoder_channels, decoder_channels):
        # input channels of each decoder block = channels of the upsampled feature
        # + channels of the corresponding skip connection (0 for the last block)
        channels = [
            encoder_channels[0] + encoder_channels[1],
            encoder_channels[2] + decoder_channels[0],
            encoder_channels[3] + decoder_channels[1],
            encoder_channels[4] + decoder_channels[2],
            0 + decoder_channels[3],
        ]
        return channels

    def forward(self, x):
        # encoder
        conv0 = self.encoder.conv1(x)
        conv0 = self.encoder.bn1(conv0)
        conv0 = self.encoder.relu(conv0)

        conv1 = self.pool(conv0)
        conv1 = self.conv1(conv1)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        # feature map sizes for a 3x256x256 input (batch size 32):
        # conv0: [32, 64, 128, 128]
        # conv1: [32, 64, 64, 64]
        # conv2: [32, 128, 32, 32]
        # conv3: [32, 256, 16, 16]
        # conv4: [32, 512, 8, 8]

        # decoder with skip connections
        x = self.layer1([conv4, conv3])
        x = self.layer2([x, conv2])
        x = self.layer3([x, conv1])
        x = self.layer4([x, conv0])
        x = self.layer5([x, None])
        x = self.final(x)

        return x
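
Finally, a minimal usage sketch (assuming an input side length divisible by 32, e.g. 256×256; the sigmoid step matches the description in the decoder section):

model = UNet(num_classes=1, pretrained=False)
imgs = torch.randn(4, 3, 256, 256)
logits = model(imgs)            # [4, 1, 256, 256], same spatial size as the input
probs = torch.sigmoid(logits)   # squash the values to (0, 1)
print(logits.shape, probs.min().item(), probs.max().item())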

