CUDA out of memory (PyTorch)


2022-06-27


Contents

  • Case 1
  • Case 2
  • Fix 1
  • Fix 2

Case 1

Too many intermediate variables are kept alive during model.forward(), which drives up GPU memory usage. For example:

```python
def forward(self, x):
    batch_size = x.shape[0]

    x0 = self.base_model(x)

    # Add positional info
    x1 = self.up1(x0)  # 512, 32, 32
    x2 = self.up2(x1)  # 256, 64, 64
    x3 = self.up3(x2)  # 256, 128, 128
    outc = self.outc(x3)  # 1, 128, 128
    outr = self.outr(x3)  # 2, 128, 128
    # x0 to x3 all stay referenced until forward() returns
    return outc, outr
```

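Reuse a single name, x, for the intermediate results. Rebinding x drops the last reference to the previous activation, so that tensor can be freed once the next layer has consumed it. This mainly lowers peak memory during inference (for example under torch.no_grad()). During training, autograd still keeps whatever activations backward() needs, regardless of variable names: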

```python
def forward(self, x):
    batch_size = x.shape[0]

    x = self.base_model(x)

    # Add positional info
    x = self.up1(x)  # 512, 32, 32
    x = self.up2(x)  # 256, 64, 64
    x = self.up3(x)  # 256, 128, 128
    outc = self.outc(x)  # 1, 128, 128
    outr = self.outr(x)  # 2, 128, 128
    return outc, outr
```
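Below is a minimal sketch of the rewritten forward() running inference under torch.no_grad(). This is the situation in which rebinding x actually lets earlier activations be freed. TinyHead and all of its layer sizes are hypothetical stand-ins for the article's model (base_model, up1 to up3, outc, outr), kept small enough to run on a CPU:

```python
import torch
import torch.nn as nn


class TinyHead(nn.Module):
    """Hypothetical stand-in for the article's model; sizes are made up."""

    def __init__(self):
        super().__init__()
        self.base_model = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.up1 = nn.ConvTranspose2d(16, 8, kernel_size=2, stride=2)  # doubles H, W
        self.up2 = nn.ConvTranspose2d(8, 8, kernel_size=2, stride=2)
        self.up3 = nn.ConvTranspose2d(8, 4, kernel_size=2, stride=2)
        self.outc = nn.Conv2d(4, 1, kernel_size=1)
        self.outr = nn.Conv2d(4, 2, kernel_size=1)

    def forward(self, x):
        # Each rebinding drops the only reference to the previous activation
        x = self.base_model(x)
        x = self.up1(x)
        x = self.up2(x)
        x = self.up3(x)
        return self.outc(x), self.outr(x)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyHead().to(device).eval()
with torch.no_grad():  # no autograd graph, so nothing else holds old activations
    outc, outr = model(torch.randn(2, 3, 4, 4, device=device))
print(outc.shape, outr.shape)
```

With gradients enabled, the convolution layers would save their inputs for backward(), so renaming alone would not release them.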

Case 2

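Memory can also run out gradually while the program runs. PyTorch cannot free intermediate results that are still referenced, either directly or through the autograd graph. They accumulate until GPU memory overflows.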

Fix 1

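loss is a tensor that still carries its autograd history, so loss_sum += loss keeps that history reachable from one iteration to the next. Accumulate a plain Python number with .item() instead:

```python
loss = self.criteration(output, label)
loss_sum += loss  # keeps every iteration's autograd history alive
# change to:
loss = self.criteration(output, label)
loss_sum += loss.item()  # Python float, no graph attached
```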
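Here is a small sketch of why the change matters. It assumes a toy nn.Linear model, nn.MSELoss, and random data, all hypothetical. The running tensor sum ends up with a grad_fn, meaning it links back to the autograd history of every iteration. The .item() sum is just a float:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)    # hypothetical toy model
criterion = nn.MSELoss()   # plays the role of self.criteration
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

loss_sum_tensor = 0  # accumulated the problematic way
loss_sum = 0.0       # accumulated the recommended way
for _ in range(3):
    output = model(torch.randn(8, 4))
    label = torch.randn(8, 1)
    loss = criterion(output, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    loss_sum_tensor += loss   # tensor with grad_fn: history stays reachable
    loss_sum += loss.item()   # Python float: history can be released

print(type(loss_sum_tensor).__name__, loss_sum_tensor.grad_fn is not None)
print(type(loss_sum).__name__, loss_sum)
```

The PyTorch FAQ describes this same total_loss += loss pattern; float(loss) works as well as .item() there.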

Fix 2

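torch.cuda.empty_cache() releases the unoccupied memory held by PyTorch's caching allocator, so other GPU applications can use it and nvidia-smi reports it as free. It is the most convenient thing to try, but it cannot free memory that live tensors still occupy. The PyTorch documentation also notes that it does not increase the amount of GPU memory available to PyTorch itself. Dropping references to tensors that are no longer needed is usually what actually helps.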
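To see what empty_cache() does and does not free, compare two counters:

- torch.cuda.memory_allocated(): memory occupied by live tensors.
- torch.cuda.memory_reserved(): memory held by the caching allocator.

The sketch below assumes a CUDA device and returns None without one. The helper free_and_report and its 64 MB default are hypothetical:

```python
import torch


def free_and_report(size_mb=64):
    """Allocate a throwaway tensor, drop it, then empty the cache.
    Returns allocator statistics in MB, or None without a CUDA device."""
    if not torch.cuda.is_available():
        return None
    mb = 1024 ** 2
    x = torch.empty(size_mb * mb, dtype=torch.uint8, device="cuda")
    stats = {"allocated_live": torch.cuda.memory_allocated() / mb}
    del x  # the block goes back to PyTorch's cache, not to the driver
    stats["allocated_after_del"] = torch.cuda.memory_allocated() / mb
    stats["reserved_after_del"] = torch.cuda.memory_reserved() / mb
    torch.cuda.empty_cache()  # hand unoccupied cached blocks back to the driver
    stats["reserved_after_empty_cache"] = torch.cuda.memory_reserved() / mb
    return stats


print(free_and_report())
```

Typically, "allocated" drops as soon as the tensor is deleted, while "reserved" only drops after empty_cache(). A tensor that is still referenced is never freed by this call.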

