hans

hans

【文本转录员】在训练CRNN时,关于ctc_loss的几点注意事项


這個 ctc_loss 非常神奇,訓練 CRNN 折磨了我好幾次。

我的資料集中的圖片大小不一,我先將它們等比例縮小到固定的高度為 32,寬度不確定。

常見的三個問題:

  1. CTC Loss Error: invalidArgumentError: Not Enough time for target transition sequence. (label_length 大於圖片長度除以 4 了)

  2. CTC Loss Error: InvalidArgumentError: sequence_length (b) <= time (圖片長度除以 4 小於等於 input_length 了)

  3. ctc_loss error “No valid path found.”(導致這個錯誤有兩種情況,一是 label_length 大於 input_length,對模型收斂沒有很大影響,只是出錯的那一個 batch 參數沒有更新優化。如果這個錯誤很少,可以忽略。如果這個錯誤很多的話就建議用下面方法優化一下訓練集。二是使用 SGD 優化時,初始學習率過大。)

導致這三個問題的原因,就是 label_length 和 input_length 的取值問題。

  1. CRNN 的一個主要優點就是可以識別任意長度的圖片。在訓練的時候,先統一將圖片 padding 到一個固定的很長的寬度。然後 input_length 設置為你等比例縮小後,padding 之前的圖片的寬除以四。部分程式碼如下:

    Img = Image.open (imagepath).convert ('L') #原始圖片
    ResizedImg = cv2.resize (Img, (int (Img.shape [1] * (32 / Img.shape [0])), 32)) # 等比縮小
    input_length [i] = ResizedImg.shape [1] // 4 # 取等比縮小後圖片長度除以 4 的值

  2. label_length 很簡單理解,就是 ground truth 的長度。

  3. 如果你以為這樣就完事大吉可以訓練你就錯了。因為你的圖片可能有不合格的存在。導致問題 3 出現,loss 變為 inf。

  4. 所以在訓練前,應該過濾一遍所有訓練集和驗證集的圖片。ctc_loss 在計算預測結果和真值的 loss 的時候,會在你真值 label 中重複的字符之間插入空符,所以必須將 label_length 加上空符個數大於 input_length 的圖片刪除掉。而程式碼中的 2,是我考慮有可能在 label 的開頭和末尾存在空符。(我並沒有驗證這個想法,只是為了保險起見。)舉個例子,你等比縮小後圖片高度為 32,寬度為 160,那麼 input_length=40。ground truth label='abbbccddddcccaa',label_length=15,經過計算 repreat_number 為 2 (bbb)+1 (cc)+3 (dddd)+2 (ccc)+1 (aa) = 9,然後再加上開頭結果的空符數 2,最終等於 11。也就是說必須滿足 label_length (15)+repreat_number (9)+2<=input_length (40) 的圖片才是合格的圖片。部分程式碼如下:

    Img = np.array(Image.open(ImgRootPath + '/' + imgName).convert('L'))
    ResizedImg = cv2.resize(Img, (int(Img.shape[1] * (32 / Img.shape[0])), 32))
    l = [len(list(g)) for k, g in itertools.groupby(Label)]
    repeat_number = 0
    for n in l:
    if n > 1:
    repeat_number += (n - 1)
    input_length = ResizedImg.shape[1] // 4
    if len(Label)+repeat_number+2 > input_length:
    continue

載入中......
此文章數據所有權由區塊鏈加密技術和智能合約保障僅歸創作者所有。