0
点赞
收藏
分享

微信扫一扫

Pytorch:Tensor基本运算【add/sub/mul/div:加减乘除】【mm/matmul:矩阵相乘】【Pow/Sqrt/rsqrt:次方】【近似:floor...】【裁剪:clamp】

caoxingyu 2022-03-12 阅读 60

一、基本运算:加减乘除

1、乘法

1.1 a * b:element-wise 对应元素相乘

a * b:要求两个矩阵维度完全一致,即两个矩阵对应元素相乘,输出的维度也和原矩阵维度相同

1.2 torch.mul(a, b):element-wise 对应元素相乘

torch.mul(a, b):是矩阵a和b对应位相乘,a和b的维度必须相等,比如a的维度是(1, 2),b的维度是(1, 2),返回的仍是(1, 2)的矩阵

1.3 torch.mm(a,b):只能用于二维矩阵

torch.mm(a,b),要求两个矩阵维度是(n×m)和(m×p),即普通二维矩阵乘法

1.4 torch.matmul(a,b):可用于多维矩阵(只取最后2个维度做矩阵乘法)(适用broadcast机制)

torch.matul(a,b),matmul可以进行张量乘法,输入可以是高维,当输入是多维时,把多出的一维作为batch提出来,其他部分做矩阵乘法

重载运算符:@

其实就是支持多个矩阵对的并行相乘

在这里插入图片描述

二、Pow

在这里插入图片描述

三、exp

在这里插入图片描述

四、log

在这里插入图片描述

五、近似

在这里插入图片描述

六、裁剪:clamp

梯度的模一般小于10合适,如果是>100了,则太大了,需要梯度裁剪。

梯度的模:W.grad.norm(2)

grad.clamp(10):表示将元素种小于10的元素都置为10
在这里插入图片描述




参考资料:
pytorch中tensor的加减乘除和常见操作
Pytorch Tensor基本数学运算

举报
0 条评论