hans

hans

【Others】记一个Pytorch显存溢出的问题


背景:

在 NLP 相关代码中,有一种 N2N 的 decoder 形式,需要上一轮预测的结果作为当前轮的输入,和当前轮的特征一起预测当前轮的结果。这里我看大部分代码都是用了 teacher
forcing,然后通过矩阵相乘的方式快速得到结果。测试阶段再通过 for 循环每一次预测一个结果。

问题:

测试阶段使用了 torch.no_grad (),显存不会出现问题。但是如果训练阶段我们不使用 teacher
forcing,也就是不用矩阵相乘一次出结果,就会面临 out of memory 的报错情况。

解决方法:

在所有涉及 相加 或者 列表 append 的地方,使用 .data [0] 这种形势来不继承计算图。

举例:

在这个项目中: GitHub - gongliym/data2text-transformer: Enhanced Transformer Model
for Data-to-Text Generation

最终是生成文本,所以涉及到了 N2N 的形式。测试的时候,如果不用 torch.no_grad (),显存是会爆掉的。如果在.model/src/model/transformer.py 这个文件中,修改有自加的位置,比如 380 行中:

tensor = tensor + attn

改成

tensor = tensor + attn.data[0]

就不会出现显存爆掉的情况了。

另外:

我们上面的情况并不适用 torch.cuda.empty_cache () 这个命令。我发现使用了这个命令,并不会释放缓存,还会极大的提升训练时间。

加载中...
此文章数据所有权由区块链加密技术和智能合约保障仅归创作者所有。