PyTorch: building a mask matrix


2022-07-27


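import torch

# diagonal=1: keep only the entries strictly above the main diagonal,
# i.e. entry (i, j) is 1 when j - i >= 1, otherwise 0.
mask = torch.triu(torch.ones(5, 5), diagonal=1).byte()
print(mask)

# diagonal=2: shift the boundary one more step to the right,
# i.e. entry (i, j) is 1 when j - i >= 2, otherwise 0.
mask = torch.triu(torch.ones(5, 5), diagonal=2).byte()
print(mask)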

Output:

tensor([[0, 1, 1, 1, 1],
        [0, 0, 1, 1, 1],
        [0, 0, 0, 1, 1],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0]], dtype=torch.uint8)

tensor([[0, 0, 1, 1, 1],
        [0, 0, 0, 1, 1],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0]], dtype=torch.uint8)
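To make the role of the diagonal argument concrete, here is a minimal pure-Python sketch (no PyTorch required) of the rule that the two outputs follow: an entry (i, j) survives when j - i >= diagonal. The helper name triu_mask is hypothetical and only mirrors torch.triu(torch.ones(n, n), diagonal=...) for illustration; it is not part of any library.

```python
def triu_mask(n: int, diagonal: int = 0) -> list[list[int]]:
    """Hypothetical helper mirroring torch.triu(torch.ones(n, n), diagonal=diagonal).

    Entry (i, j) is kept (1) when j - i >= diagonal, otherwise zeroed (0).
    """
    return [[1 if j - i >= diagonal else 0 for j in range(n)] for i in range(n)]


# Reproduce the two 5x5 masks printed above, row by row.
for k in (1, 2):
    print(f"diagonal={k}")
    for row in triu_mask(5, diagonal=k):
        print(row)
```

With diagonal=0 the main diagonal is kept as well, and negative values also keep sub-diagonals (diagonal=-1 keeps one band below the main diagonal). In recent PyTorch versions, masking ops such as masked_fill expect a torch.bool mask (uint8 masks are deprecated), so .bool() is usually preferable to .byte() when the mask is fed into such ops.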

