0
点赞
收藏
分享

微信扫一扫

pytorch使用LSTMCell层定义LSTM网络结构


pytorch中目前已经实现好了3中循环神经网络,分别是​​RNN​​​、​​GRU​​​、​​LSTM​​​,但是发现在nn模块中还存在​​RNNCell()​​​、​​LSTMCell()​​这个模块。

对于循环神经网络常用来处理​​序列数据​​​,可以理解为依次处理每个时间片的数据,但是对于Cell层只能够处理序列数据中的​​一个时间片​​的数据,所以要想使用Cell层达到RNN的目的,就需要不断循环处理每个时间片的数据。

下面使用LSTM和LSTMCell这两个模块来示例:

nn.LSTM()

input = torch.randn(10, 32, 100)

lstm = nn.LSTM(100, 8, 1)

output, _ = lstm(input)
print(output.shape)

该段代码定义了输入数据维度为【10,32,100】,批次大小为32,序列长度为10,每个时间片对应的维度为100。

定义了一个LSTM层,输入维度为100,隐藏状态维度为8,只有1层,经过LSTM后得到所有时间片的输出结果,维度为【10,32,8】。

torch.Size([10, 32, 8])

nn.LSTMCell()

要想使用LSTMCell来达到同样效果就需要不断使用这个Cell循环处理每个时间片的数据,然后将每次循环得到的输出结果进行堆叠即可。

input = torch.randn(10, 32, 100)

lstm = nn.LSTMCell(100, 8)

output = []

for time_data in input:
out, _ = lstm(time_data)
output.append(out)

output = torch.stack(output)
print(output.shape)

torch.Size([10, 32, 8])

那么pytorch中已经有了LSTM层,为什么要定义这个LSTMCell层呢?

原因很简单,就是能够提高定义模型的灵活性,可以根据自定义的网络模块来组合调用LSTMCell层。


举报

相关推荐

0 条评论