0
点赞
收藏
分享

微信扫一扫

1.7.2 torch 的sum 和 cumsum

import torch

a = torch.linspace(0, 10, 6).view(2, 3)

b = a.sum(dim=0)
c = torch.cumsum(a, dim=0)
print(a)
print(b)
print(c)
# tensor([[ 0.,  2.,  4.],
#         [ 6.,  8., 10.]])
# 
# tensor([ 6., 10., 14.])
# 
# tensor([[ 0.,  2.,  4.],
#         [ 6., 10., 14.]])


d = a.sum(dim=1)
e = torch.cumsum(a, dim=1)
print(d)
print(e)
# tensor([ 6., 24.])
# 
# tensor([[ 0.,  2.,  6.],
#         [ 6., 14., 24.]])

keepdim 参数,说明输出结果是否保留维度

import torch

a = torch.linspace(0, 10, 6).view(2, 3)
b = a.sum(dim=0)
c = a.sum(dim=0, keepdim=True)

print(a)
# tensor([[ 0.,  2.,  4.],
#         [ 6.,  8., 10.]])

print(b)
# tensor([ 6., 10., 14.])

print(b.shape)
# torch.Size([3])

print(c)
# tensor([[ 6., 10., 14.]])

print(c.shape)
# torch.Size([1, 3])

举报

相关推荐

0 条评论