0
点赞
收藏
分享

微信扫一扫

【Pytorch】torch.gather用法详解


torch.gather

torch.gather(input, dim, index, *, sparse_grad=False, out=None)

沿指定的维收集值。

参数:

  • ​input​​ (Tensor) –输入张量
  • ​dim​​ (int) – 要索引的维
  • ​index​​ (LongTensor) – 要收集的元素的索引
  • ​sparse_grad​​ (bool, optional) – 如果为​​True​​,关于​​input ​​的梯度将是稀疏张量。
  • ​out​​ (Tensor, optional) –输出张量

对于一维张量,输出由以下公式指定:

out[i] = input[index[i]]  # dim= 0

例如:

input_tensor= torch.tensor([1, 2])
index = torch.tensor([0, 0])
input[0]=1
input[1]=2

index[0]=0
index[0]=0
out = torch.gather(input, 0, index)
out[0]=input[index[0]]=input[0]=1
out[1]=input[index[1]]=input[0]=1

对于二维张量,输出由以下公式指定:

out[i][j] = input[index[i][j]][j]  # if dim == 0
out[i][j] = input[i][index[i][j]] # if dim == 1

举个栗子:

input_tensor= torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]])
input[0][0]=1
input[0][1]=2
input[1][0]=3
input[1][1]=4

index[0][0]=0
index[0][1]=0
index[1][0]=1
index[1][1]=0

dim=0:

out = torch.gather(input, 0, torch.tensor([[0, 0], [1, 0]]))
print(out)
out[0][0]=input[index[0][0]][0]=input[0][0]=1
out[0][1]=input[index[0][1]][1]=input[0][1]=2
out[1][0]=input[index[1][0]][0]=input[1][0]=3
out[1][1]=input[index[1][1]][1]=input[0][1]=2

dim=1:

out = torch.gather(input, 1, torch.tensor([[0, 0], [1, 0]]))
print(out)
out[0][0]=input[0][index[0][0]]=input[0][0]=1
out[0][1]=input[0][index[0][1]]=input[0][0]=1
out[1][0]=input[1][index[1][0]]=input[1][1]=4
out[1][1]=input[1][index[1][1]]=input[1][0]=3

对于三维张量,同理:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

注意:​​input​​​和​​index​​​必须有相同的维度。​​out​​​尺寸和​​index​​​相同;​​input​​​和​​index​​之间不会广播。

对于​​d=dim​​​,可以有​​index.size(d)< input.size(d)​​:

input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
index = torch.tensor([[1, 0],[2, 0]])
print('input_tensor.size:', input_tensor.size())
print('index.size:', index.size())
out = torch.gather(input_tensor, 1, index)
print(out)
input_tensor.size: torch.Size([2, 3])
index.size: torch.Size([2, 2])
tensor([[2, 1],
[6, 4]])

或​​index.size(d)> input.size(d)​​:

input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
index = torch.tensor([[1, 0, 1, 0], [2, 0, 2, 0]])
print('input_tensor.size:', input_tensor.size())
print('index.size:', index.size())
out = torch.gather(input_tensor, 1, index)
print(out)
input_tensor.size: torch.Size([2, 3])
index.size: torch.Size([2, 4])
tensor([[2, 1, 2, 1],
[6, 4, 6, 4]])



举报

相关推荐

0 条评论