加载网络权重,去除全连接层的权重

阅读 110

2022-04-14

仅作为记录,大佬请跳过。

感谢老师的示范。

fc_keys = [k for k in state_dict.keys() if "fc" in k]
for k in fc_keys:
    del state_dict[k]

查看设计的网络加载的网络权重的有没有不同的层

def load_from_pretrained(self, ckpt_path):
    print(f"==============> Loading weight {ckpt_path} for fine-tuning......")
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state_dict = ckpt

    fc_keys = [k for k in state_dict.keys() if "fc" in k]
    for k in fc_keys:
        del state_dict[k]

    from pprint import pprint
    missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)
    print('missing_keys = ')
    pprint(missing_keys)
    print('unexpected_keys = ')
    pprint(unexpected_keys)
    print(f"=> loaded successfully '{ckpt_path}'")
    print('ok')

其中,self指设计的网络

精彩评论(0)

0 0 举报