特征提取后的维度变化

GG_lyf

关注

阅读 86

2022-03-22

主要用于在使用一个现有的模型对数据进行特征提取时,进行维度降维,主要使用到线性变化,例如使用resnet50进行特征提取时,可能得到的是2048维特征,但是此时如果进行线性变化,就可得到512维或是128维等特征

class LinearNorm(nn.Module):
    def __init__(self, cfg):
        super(LinearNorm, self).__init__()
        self.fc = nn.Linear(cfg['IN_CHANNELS'], cfg['DIM'])
        self.fc.apply(weights_init_kaiming)

    def forward(self, x):
        x = self.fc(x)
        x = nn.functional.normalize(x, p=2, dim=1)
        return x

cfg['IN_CHANNELS'], cfg['DIM']分别表示输入维度和输出维度

使用方法

backbone为要使用的特征提取模型


head = LinearNorm(head_cfg)
self.model = Sequential(OrderedDict([("backbone", backbone), ("head", head)]))

精彩评论(0)

0 0 举报