Deep learning: invalid index of a 0-dim tensor. Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to
2022-04-05
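This is a version issue: on recent PyTorch releases the loss returned by F.nll_loss here is a 0-dim tensor, so indexing it with [0] raises this error. Call .item() to read the value as a Python number instead.
Original code:
test_loss += F.nll_loss(out_tgt.log(), target_label, size_average=False).data[0] # sum up batch loss
Modified code:
test_loss += F.nll_loss(out_tgt.log(), target_label, size_average=False).item() # sum up batch loss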
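The fix can be checked in isolation with a small sketch. The tensors below are hypothetical stand-ins for out_tgt and target_label (a batch of 4 samples over 3 classes), not the data from the original script. The sketch also assumes a PyTorch version that accepts reduction="sum", the current replacement for the deprecated size_average=False.

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-ins for the post's variables: class probabilities for
# a batch of 4 samples over 3 classes, and the matching integer labels.
torch.manual_seed(0)
out_tgt = torch.softmax(torch.randn(4, 3), dim=1)
target_label = torch.tensor([0, 2, 1, 1])

# reduction="sum" replaces the deprecated size_average=False.
loss = F.nll_loss(out_tgt.log(), target_label, reduction="sum")

# The summed loss has no dimensions, so there is nothing to index into.
try:
    loss.data[0]
    indexing_failed = False
except IndexError:
    indexing_failed = True

# .item() converts the 0-dim tensor to a plain Python float.
test_loss = 0.0
test_loss += loss.item()  # sum up batch loss
```

Because .item() returns a plain float, test_loss stays an ordinary Python number rather than a tensor as batches are accumulated.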