hans

hans

【其他】記一個Pytorch顯存溢出的問題


背景:

在 NLP 相關程式碼中,有一種 N2N 的 decoder 形式,需要上一輪預測的結果作為當前輪的輸入,和當前輪的特徵一起預測當前輪的結果。這裡我看大部分程式碼都是用了 teacher
forcing,然後通過矩陣相乘的方式快速得到結果。測試階段再通過 for 迴圈每一次預測一個結果。

問題:

測試階段使用了 torch.no_grad (),顯存不會出現問題。但是如果訓練階段我們不使用 teacher
forcing,也就是不用矩陣相乘一次出結果,就會面臨 out of memory 的報錯情況。

解決方法:

在所有涉及 相加 或者 列表 append 的地方,使用 .data [0] 這種形式來不繼承計算圖。

舉例:

在這個專案中: GitHub - gongliym/data2text-transformer: Enhanced Transformer Model
for Data-to-Text Generation

最終是生成文本,所以涉及到了 N2N 的形式。測試的時候,如果不用 torch.no_grad (),顯存是會爆掉的。如果在.model/src/model/transformer.py 這個檔案中,修改有自加的位置,比如 380 行中:

tensor = tensor + attn

改成

tensor = tensor + attn.data[0]

就不會出現顯存爆掉的情況了。

另外:

我們上面的情況並不適用 torch.cuda.empty_cache () 這個命令。我發現使用了這個命令,並不會釋放緩存,還會極大地提升訓練時間。

載入中......
此文章數據所有權由區塊鏈加密技術和智能合約保障僅歸創作者所有。