0
点赞
收藏
分享

微信扫一扫

在pytorch中,如何对标准的预训练模型进行修改以适应三通道以上的输入

net = resnet50(pretrained=pretrained)
with torch.no_grad():
pretrained_conv1 = net.conv1.weight.clone()
# Assign new conv layer with 4 input channels
net.conv1 = torch.nn.Conv2d(4, 64, 7, 2, 3, bias=False)
# Use same initialization as vanilla ResNet (Don't know if good idea)
torch.nn.init.kaiming_normal_(
net.conv1.weight, mode='fan_out', nonlinearity='relu')
# Re-assign pretraiend weights to first 3 channels
# (assuming alpha channel is last in your input data)
net.conv1.weight[:, :3] = pretrained_conv1

此代码的作用是修改标准预训练模型来适应四通道的输入,前三个通道保持原来的参数,最后一个通道kaiming初始化


举报

相关推荐

0 条评论