#@save
def masked_softmax(X, valid_lens):
"""通过在最后一个轴上掩蔽元素来执行softmax操作"""
# X:3D张量,valid_lens:1D或2D张量
if valid_lens is None:
return nn.functional.softmax(X, dim=-1)
else:
shape = X.shape
if valid_lens.dim() == 1:
valid_lens = torch.repeat_interleave(valid_lens, shape[1])
else:
valid_lens = valid_lens.reshape(-1)
# 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
value=-1e6)
return nn.functional.softmax(X.reshape(shape), dim=-1)
每一行为一个样本,tensor里面的值是代表了特征的有效性(也就是列数),举例来讲:
masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))
给定一个2*2*4的随机矩阵,第一个矩阵,前面两列有效,在softmax后,后面两列权重为0;第二个矩阵前面三列有效,最后一列为0.
同时也可以指定每一个矩阵里的每一个样本。例如,第一个矩阵的第一个样本第一列有效,后三列无效。
masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))