在网络中实现尺寸减半或尺寸不变的多种方式

ivy吖

关注

阅读 71

2022-01-27

概述

网络尺寸的改变,尤其是下采样时,比较经典的方式就是卷积加池化。然而,随着transformer等创新的工作的实现,似乎可以有一些更有趣的尺寸改变方式。某种意义上来说,就是一种数据重组的方式。

卷积+池化

显然,可以通过设计卷积结构的方式实现尺寸减半(卷积+池化)或尺寸不变。
对于尺寸减半来说,常用的参数有:

Conv2d(kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

对于尺寸不变而言,通常使用1*1的卷积,参数可以是:

Conv2d(kernel_size=(1, 1), stride=(1, 1), bias=False)

patch embedding & patch merging

在swin transformer中,使用了patch merging的方式完全替代了CNN中的下采样,具体为:

class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)
        return x

而大名鼎鼎的patch embedding在与NLP中的position embedding本极其相似的同时,也在某种程度上实现了降采样的工作,其python 代码实现为:

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super(PatchEmbed, self).__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patchs_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] #每个patch的大小为56 * 56
        self.img_size = img_size
        self.patch_size = patch_size
        self.patchs_resolution = patchs_resolution
        self.num_patches = patchs_resolution[0] * patchs_resolution[1]

        self.in_chans = in_chans # 输入为三通道的RGB向量
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) # 将kernel大小还有stride的大小都设置为与patch_size一样,假设为1/4分辨率,那么则将kernel大小和stride大小都设置为4
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None # 默认不做归一化,如果做的话,可以设置为layer_norm

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model({self.img_size[0]}*{self.img_size[1]})"
        x = self.proj(x).flatten(2).transpose(1, 2)
        # x的维度变化
        # (B,96,224/4,224/4)
        # (B,96,56*56)
        # (B,56*56,96)
        if self.norm is not None:
            x = self.norm(x)
        return x

参考文献

  • swin transformer:
    Liu Z, Lin Y, Cao Y, et al. Swin transformer: Hierarchical vision transformer using shifted windows[J]. arXiv preprint arXiv:2103.14030, 2021.

精彩评论(0)

0 0 举报