Table of contents
- Case 1
- Case 2
- Fix 1
- Fix 2
Case 1
Inside model.forward(), too many named intermediate variables are kept alive at once, which drives up GPU memory usage, as shown below:
def forward(self, x):
    batch_size = x.shape[0]
    x0 = self.base_model(x)
    # Add positional info
    x1 = self.up1(x0)     # 512, 32, 32
    x2 = self.up2(x1)     # 256, 64, 64
    x3 = self.up3(x2)     # 256, 128, 128
    outc = self.outc(x3)  # 1, 128, 128
    outr = self.outr(x3)  # 2, 128, 128
    return outc, outr
Reuse a single variable name x for the intermediate results, so each earlier tensor can be released as soon as it is no longer referenced:
def forward(self, x):
    batch_size = x.shape[0]
    x = self.base_model(x)
    # Add positional info
    x = self.up1(x)      # 512, 32, 32
    x = self.up2(x)      # 256, 64, 64
    x = self.up3(x)      # 256, 128, 128
    outc = self.outc(x)  # 1, 128, 128
    outr = self.outr(x)  # 2, 128, 128
    return outc, outr
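The pattern above can be sketched as a minimal CPU-runnable module. The layer names (base_model, up1, outc, ...) are stand-ins for the author's real modules, and the shapes are illustrative only:

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Toy model illustrating the single-variable forward() pattern."""
    def __init__(self):
        super().__init__()
        self.base_model = nn.Linear(8, 16)
        self.up1 = nn.Linear(16, 16)
        self.up2 = nn.Linear(16, 16)
        self.outc = nn.Linear(16, 1)
        self.outr = nn.Linear(16, 2)

    def forward(self, x):
        # Rebinding the name x drops the Python reference to each
        # earlier intermediate, allowing it to be garbage-collected.
        x = self.base_model(x)
        x = self.up1(x)
        x = self.up2(x)
        return self.outc(x), self.outr(x)

outc, outr = TinyNet()(torch.randn(4, 8))
```

One caveat worth noting: while gradients are being tracked, autograd still holds the activations it needs for backward, so the memory saving from name reuse is largest under torch.no_grad() (e.g. during inference).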
Case 2
As the program runs, it accumulates many intermediate variables; PyTorch does not clean these up on its own, so GPU memory eventually blows up.
Fix 1
loss = self.criteration(output, label)
loss_sum += loss          # keeps the whole computation graph alive every iteration

#### Change to

loss = self.criteration(output, label)
loss_sum += loss.item()   # keeps only a plain Python float; the graph can be freed
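In context, the fix looks like the loop below. This is a minimal sketch, not the author's training code; the model, criterion, and data are placeholders:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
criterion = nn.MSELoss()

loss_sum = 0.0
for _ in range(3):
    output = model(torch.randn(2, 4))
    loss = criterion(output, torch.randn(2, 1))
    loss.backward()          # backward frees the graph for this step ...
    loss_sum += loss.item()  # ... and .item() keeps only a Python float
```

Accumulating the tensor `loss` instead would chain every iteration's graph onto `loss_sum`, preventing any of them from being released.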
Fix 2
torch.cuda.empty_cache() releases the GPU memory that PyTorch's caching allocator holds but no longer uses, returning it to the driver. It is probably the quickest and most convenient fix, though it cannot free tensors that are still referenced in Python.