0
点赞
收藏
分享

微信扫一扫

PyTorch的torch.cat、squeeze()、unsqueeze()和size()函数

目录

一、sequeeze()函数

二、unsequeeze()函数 

三、size()函数

四、torch.cat函数


在Pytorch做深度学习过程中,CNN的卷积和池化过程中会用到torch.cat、squeeze()、unsqueeze()和size()函数,下面分别做讲解:

一、sequeeze()函数

x.squeeze(dim)

用途:进行维度压缩,去掉tensor中维数为1的维度

参数设置:如果设置dim=a,就是去掉指定维度中维数为1的

示例: 

import torch
x = torch.tensor([[[1],[2]],[[3],[4]]])
print('x:',x)
print(x.shape)
x1 = x.squeeze()
print('x1:',x1)
print(x1.shape)
x2 = x.squeeze(2)
print('x2:',x2)
print(x2.shape)

输出:

x: tensor([[[1],
         [2]],
 
        [[3],
         [4]]])
torch.Size([2, 2, 1])
x1: tensor([[1, 2],
        [3, 4]])
torch.Size([2, 2])
x2: tensor([[1, 2],
        [3, 4]])
 torch.Size([2, 2, 1])

可以看出:

(1) x.squeeze(),shape 由(2,2,1)变为(2,2),说明维度为1时被去掉。

(2) x.squeeze(2),shape仍然为(2,2,1),这是因为只有维度为1时才会去掉。

二、unsequeeze()函数 

x.unsqueeze(dim=a)

用途:进行维度扩充,在指定位置加上维数为1的维度

参数设置:如果设置dim=a,就是在维度为a的位置进行扩充

示例:

import torch
x = torch.tensor([1,2,3,4])
print(x)
print(x.shape)
x1 = x.unsqueeze(0)
print(x1)
print(x1.shape)
x2 = x.unsqueeze(1)
print(x2)
print(x2.shape)
 
y = torch.tensor([[1,2,3,4],[9,8,7,6]])
print(y)
print(y.shape)
y1 = y.unsqueeze(0)
print(y1.shape)
print(y1)
print(y1.shape)
y2 = y.unsqueeze(1)
print(y2)
print(y2.shape)

输出:

x: tensor([1, 2, 3, 4])
torch.Size([4])
x1: tensor([[1, 2, 3, 4]])
torch.Size([1, 4])
x2: tensor([[1],
        [2],
        [3],
        [4]])
torch.Size([4, 1])
y: tensor([[1, 2, 3, 4],
        [9, 8, 7, 6]])
torch.Size([2, 4])
y1: tensor([[[1, 2, 3, 4],
         [9, 8, 7, 6]]])
torch.Size([1, 2, 4])
y2: tensor([[[1, 2, 3, 4]],

        [[9, 8, 7, 6]]])
torch.Size([2, 1, 4])

可以看出:

(1) x.unsqueeze(0) ,shape 由(4)变为(1,4),证明在第一个位置增加一个维度。

(2) x.unsqueeze(1) ,shape 由(4)变为(4,1),证明在第二个位置增加一个维度。

三、size()函数

介绍
size()函数主要是用来统计矩阵元素个数,或矩阵某一维上的元素个数的函数。

参数
numpy.size(a, axis=None)
a:输入的矩阵
axis:int型的可选参数,指定返回哪一维的元素个数。当没有指定时,返回整个矩阵的元素个数。

示例: 

a = np.array([[1,2,3],[4,5,6]])
print(a.shape)
print(np.size(a,0))
print(np.size(a,1))
print(np.size(a))

输出:

(2,3)
2
3
6

示例:

b = tensor([[[1, 2, 3, 4]],

        [[9, 8, 7, 6]]])
print(b.shape)
print(b.size(0))
print(b.size(1))
print(b.size(2))

输出:

torch.Size([2, 1, 4])
2
1
4

可以看出:

axis的值没有设定,返回矩阵的元素个数
axis = 0,返回该二维矩阵的行数
axis = 1,返回该二维矩阵的列数

四、torch.cat函数

torch.cat(inputs, dimension=0) → Tensor,在给定维度上对输入的张量序列seq 进行连接操作。
torch.cat()可以看做 torch.split() 和 torch.chunk()的反操作。 cat() 函数面例子更好的理解。
参数:
 inputs (sequence of Tensors) – 可以是任意相同Tensor 类型的python 序列
 dimension (int, optional) – 沿着此维连接张量序列。
示例:

x = torch.randn(2, 3)
print(x)
print(x.shape)
x1 = torch.cat((x,x,x,),0)
print(x1)
print(x1.shape)
y = torch.cat((x,x,x,),1)
print(y1)
print(y1.shape)

输出:

tensor([[-1.1883,  0.5793, -0.2716],
        [-0.8177,  0.0659,  0.8393]])
torch.Size([2, 3])
tensor([[-1.1883,  0.5793, -0.2716],
        [-0.8177,  0.0659,  0.8393],
        [-1.1883,  0.5793, -0.2716],
        [-0.8177,  0.0659,  0.8393],
        [-1.1883,  0.5793, -0.2716],
        [-0.8177,  0.0659,  0.8393]])
torch.Size([6, 3])
tensor([[-1.1883,  0.5793, -0.2716, -1.1883,  0.5793, -0.2716, -1.1883,  0.5793,
         -0.2716],
        [-0.8177,  0.0659,  0.8393, -0.8177,  0.0659,  0.8393, -0.8177,  0.0659,
          0.8393]])
torch.Size([2, 9])

举报

相关推荐

0 条评论