0
点赞
收藏
分享

微信扫一扫

220315-MultilinearMap的NumPy与PyTorch实现


import numpy as np
import torch


def multilinear_map_numpy(F, P):
    """Return the row-wise flattened outer product of ``F`` and ``P`` (NumPy).

    For every row ``b`` the result holds ``F[b, i] * P[b, j]`` at column
    ``i * P.shape[1] + j``, i.e. ``np.outer(F[b], P[b]).ravel()``.

    Args:
        F: 2-D array of shape ``(batch, f_dim)``.
        P: 2-D array of shape ``(batch, p_dim)``. A ``P`` with a single row
            is broadcast against every row of ``F``.

    Returns:
        Array of shape ``(batch, f_dim * p_dim)`` whose dtype follows NumPy's
        rules for ``F * P``.

    Raises:
        ValueError: if the batch sizes are incompatible. NumPy raises it from
            either the broadcast or the final reshape.
    """
    # Broadcasting (batch, f, 1) * (batch, 1, p) yields the per-row outer
    # product directly, without materialising two repeated copies first.
    outer = F[:, :, np.newaxis] * P[:, np.newaxis, :]
    # Spell out the column count instead of using -1, for two reasons:
    # reshape(0, -1) is rejected for an empty batch, and an explicit size
    # turns a batch mismatch (e.g. F with 1 row, P with several) into an
    # error rather than a silently mis-shaped result.
    return outer.reshape(F.shape[0], F.shape[1] * P.shape[1])

def multilinear_map_torch(F, P):
    """Return the row-wise flattened outer product of ``F`` and ``P`` (PyTorch).

    Tensor counterpart of ``multilinear_map_numpy``. Row ``b`` of the result
    holds ``F[b, i] * P[b, j]`` at column ``i * P.shape[1] + j``.

    Args:
        F: 2-D tensor of shape ``(batch, f_dim)``.
        P: 2-D tensor of shape ``(batch, p_dim)``. A ``P`` with a single row
            is broadcast against every row of ``F``.

    Returns:
        Tensor of shape ``(batch, f_dim * p_dim)``.

    Raises:
        RuntimeError: if the batch sizes are incompatible. torch raises it
            from either the broadcast or the final reshape.
    """
    # Broadcasting (batch, f, 1) * (batch, 1, p) gives the outer product
    # without the two intermediate repeat() copies.
    outer = F[:, :, None] * P[:, None, :]
    # reshape (not view) does not depend on the product's memory layout.
    # The explicit column count keeps an empty batch working, since -1 cannot
    # be inferred from zero elements, and it makes a batch mismatch raise.
    return outer.reshape(F.shape[0], F.shape[1] * P.shape[1])


# Smoke-run both implementations on the same 10-row inputs. F has 6 columns
# and P has 3, so each printed shape should report 6 * 3 = 18 columns.
F = np.arange(10 * 6).reshape(10, 6)
P = np.arange(10 * 3).reshape(10, 3)
M = multilinear_map_numpy(F, P)
print(M.shape)

# Feed the very same integer arrays to the PyTorch version. torch.Tensor(...)
# builds a tensor of the default floating dtype from them.
F = torch.Tensor(F)
P = torch.Tensor(P)
M = multilinear_map_torch(F, P)
print(M.shape)


举报

相关推荐

0 条评论