0
点赞
收藏
分享

微信扫一扫

pytorch,筛选出一定范围的值

Raow1 2022-07-27 阅读 85


import torch
input_tensor = torch.tensor([1,2,3,4,5])
print(input_tensor>3)
mask = (input_tensor>3).nonzero()
print(mask)
print(input_tensor.index_select(0,mask))

tensor([0, 0, 0, 1, 1], dtype=torch.uint8)
tensor([3, 4])
tensor([4, 5])


举报

相关推荐

0 条评论