0
点赞
收藏
分享

微信扫一扫

动态输入输出神经网络(自适应)

罗蓁蓁 2022-02-12 阅读 44


import torch
class Speech_Detect(torch.nn.Module):
def __init__(self,min_l=3,max_l=20,x_l=10):
super(Speech_Detect,self).__init__()
self.in_dict=dict()
self.out_dict=dict()
for i in range(min_l,max_l):
self.in_dict[i]=torch.nn.Linear(i,x_l)
self.out_dict[i]=torch.nn.Linear(x_l,i)



def forward(self,x,label_len):
x=self.in_dict[x.shape[0]](x.T)
x=self.out_dict[label_len](x)
return x.T


net = Speech_Detect()

for _ in range(100):
data=torch.randn([12,12])
out=net(data,11)
print(out.shape)


if __name__ == '__main__':
pass



举报

相关推荐

0 条评论