PyTorch nonzero: a usage example
2022-07-27
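import torch
input_tensor = torch.tensor([1, 2, 3, 4, 5])

# Element-wise comparison yields a boolean mask
mask = input_tensor > 3
print(mask)

# nonzero() returns an (N, 1) tensor holding the index of each True element;
# squeeze() drops the size-1 dimension, leaving a 1-D tensor of indices
indexes = mask.nonzero().squeeze()
print(indexes)

Output (PyTorch 1.2 and later; earlier versions print the mask as tensor([0, 0, 0, 1, 1], dtype=torch.uint8)):
tensor([False, False, False,  True,  True])
tensor([3, 4])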
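A bare squeeze() on the nonzero() result has one pitfall. When exactly one element matches, the (1, 1) result collapses to a 0-d tensor instead of a 1-D tensor of length 1. The sketch below shows that case and two alternatives: squeeze(1), and nonzero(as_tuple=True). It also shows what nonzero() returns for a 2-D mask. The helper name true_indices is hypothetical and used only for illustration.

```python
import torch


def true_indices(mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: 1-D indices of the True elements of a 1-D mask.

    squeeze(1) removes only the column dimension of the (N, 1) result,
    so the output stays 1-D even when zero or one element matches.
    """
    return mask.nonzero().squeeze(1)


x = torch.tensor([1, 2, 3, 4, 5])

# A single match: bare squeeze() collapses (1, 1) down to a 0-d tensor
single = (x == 4).nonzero().squeeze()
print(single.shape)  # torch.Size([])

# squeeze(1) keeps the result 1-D
print(true_indices(x == 4))  # tensor([3])

# as_tuple=True returns one index tensor per dimension; [0] is the 1-D case
print(torch.nonzero(x > 3, as_tuple=True)[0])  # tensor([3, 4])

# For a 2-D mask, nonzero() returns one (row, col) pair per True element
m = torch.tensor([[1, 0], [0, 1]]) > 0
print(m.nonzero())  # two rows: [0, 0] and [1, 1]
```

With as_tuple=True the result is already a tuple of 1-D tensors, so no squeeze is needed. That form can also be used directly for indexing, for example x[torch.nonzero(x > 3, as_tuple=True)].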