深度学习:invalid index of a 0-dim tensor. Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to

阅读 79

2022-04-05

版本问题

源代码:

test_loss += F.nll_loss(out_tgt.log(), target_label, size_average=False).data[0] # sum up batch loss

修改后代码:

test_loss += F.nll_loss(out_tgt.log(), target_label, size_average=False).item()  # sum up batch loss

精彩评论(0)

0 0 举报