import torch
from torchvision import models
def mytranslate_pth_onnx(save_weight_pth, num=1, channel=3, height=512, width=672, onnx_path=None):
    """Export a ``fasterrcnn_resnet50_fpn`` model with weights from a ``.pth`` file to ONNX.

    Originally written against Torch 1.8.2+cu101.

    :param save_weight_pth: path to the weight file; passed to
        ``model.load_state_dict`` after ``torch.load``, so it must hold a
        state dict matching ``fasterrcnn_resnet50_fpn``.
    :param num: batch size N of the random dummy input used for tracing (default 1).
    :param channel: channel count C of the dummy input (default 3).
    :param height: height H of the dummy input (default 512).
    :param width: width W of the dummy input (default 672).
    :param onnx_path: destination of the ONNX file; when None, the output is
        written to ``save_weight_pth + ".onnx"``.
    :return: None. The ONNX file is written to disk as a side effect.
    """
    print('Torch Version', torch.__version__)
    print('格式转换中...')
    if onnx_path is None:
        onnx_path = save_weight_pth + ".onnx"
    # Random dummy input in N C H W order; only used to trace the graph.
    example = torch.rand(num, channel, height, width)
    # Network architecture.
    # NOTE(review): the weights fetched by pretrained=True are fully replaced by
    # the checkpoint loaded below. It is kept so the model is built with the
    # same (pretrained) configuration as before; confirm before switching it off.
    model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    # Load the checkpoint onto the CPU: the model and the dummy input are on the
    # CPU, so mapping to a GPU only wasted memory and failed without CUDA.
    checkpoint = torch.load(save_weight_pth, map_location="cpu")
    model.load_state_dict(checkpoint)
    # Switch to inference mode before exporting to ONNX.
    model.eval()
    torch.onnx.export(model, example, onnx_path, verbose=True, opset_version=11)
    print("格式转换完成!")
if __name__ == "__main__":
    # Weight file to convert.
    save_weight_pth = r'C:\Users\admin\.cache\torch\hub\checkpoints\fasterrcnn_resnet50_fpn_coco-258fb6c6.pth'
    # Shape (N, C, H, W) of the dummy input used to trace the model.
    num = 1
    channel = 3
    height = 512
    width = 672
    # Pass the variables above through so that editing them changes the export.
    mytranslate_pth_onnx(save_weight_pth, num=num, channel=channel, height=height, width=width)