一、语法格式
格式一(只针对argmax函数):
torch.argmax(input) → LongTensor
功能:
Returns the indices of the maximum value of all elements in the input tensor。
即:返回输入张量中所有元素中最大值对应的索引(按行搜索);如果有多个相同的值,则返回第一次遇到的那个值对应的索引。
举例:
In [28]: r=torch.tensor([[1,2,3,4,5],[6,7,8,9,10],[11,12,13,14,15]])
In [29]: torch.argmax(r)
Out[29]: tensor(14)
格式二:
[1]torch.argmax(input, dim=None, keepdim=False)
功能:
Returns the indices of the maximum values of a tensor across a dimension.
- input(Tensor) – the input tensor.即:输出张量。
- dim(int) – the dimension to reduce. If
None
, the argmax of the flattened input is returned.即:要减少的维数。
- keepdim(bool) – whether the output tensor has
dim
retained or not. Ignored if dim=None
.即:
举例:
In [30]: a = torch.randn(4, 4)
In [31]: a
Out[31]:
tensor([[ 1.4360, 0.6342, -0.5233, 0.4902],
[ 1.1998, -0.8644, 0.5244, 0.2690],
[ 0.0998, -1.5043, 0.1619, -1.4634],
[ 0.0992, -1.0843, -1.3829, 0.5790]])
In [32]: torch.argmax(a)
Out[32]: tensor(0)
In [33]: torch.argmax(a,dim=0)
Out[33]: tensor([0, 0, 1, 3])
In [34]: torch.argmax(a,dim=1)
Out[34]: tensor([0, 0, 2, 3])
- 对于tensor(0)输出,意义如下:
第0个: 1.4360 | 第1个: 0.6342 | 第2个: -0.5233 | 第3个: 0.4902 | 第4个: 1.1998 | 第5个: -0.8644 | 第6个: 0.5244 | 第7个: 0.2690 | 第8个: 0.0998 | 第9个: -1.5043 |
第10个: 0.1619 | 第11个: -1.4634 | 第12个: 0.0992 | 第13个: -1.0843 | 第14个: -1.3829 | 第15个: 0.5790 |
- 对于tensor([0, 0, 1, 3])输出,意义如下:
这时,每一列视为下标从0到3的一个数组。易见,从左到右每一列(数组)中最大值分别为:1.4360、0.6342、0.5244、0.5790,它们对应的一维数组中的下标分别为0、0、1、3,于是得到张量tensor([0, 0, 1, 3])。
- 对于tensor([0, 0, 2, 3])输出:
意义就容易理解了。沿水平方向从左向右从上到下看,每一行对应一个数组,下标向左向右依次为0、1、2、3。于是,这4个数组中最大值分别为1.4360、1.1998、0.1619、1.3829,它们对应的一维数组中的下标分别为0、0、2、3,于是得到张量tensor([0, 0, 2, 3])。
功能:
[2]torch.argmin(input, dim=None, keepdim=False) → LongTensor
argmin功能:Returns the indices of the minimum value(s) of the flattened tensor or along a dimension。
理解类似上面argmax函数的第二种格式,相应于dim=0和dim=1,依次返回由最小值对应下标组成的列方向数组与行方向数组组成的张量。