hans

hans

【テキストトランスクリプター】CRNNのトレーニング時におけるctc_lossに関する注意事項


この ctc_loss は非常に厄介で、CRNN のトレーニングは私を何度も苦しめました。

私のデータセットの画像サイズはバラバラで、まず固定の高さ 32 に比例縮小し、幅は可変です。

よくある 3 つの問題:

  1. CTC Loss Error: invalidArgumentError: Not Enough time for target transition sequence. (label_length が画像の長さを 4 で割った値より大きいです)

  2. CTC Loss Error: InvalidArgumentError: sequence_length (b) <= time (画像の長さを 4 で割った値が input_length 以下です)

  3. ctc_loss error “No valid path found.” (このエラーの原因は 2 つあります。1 つは label_length が input_length より大きい場合で、モデルの収束には大きな影響はありませんが、エラーが発生したバッチのパラメータは最適化されません。このエラーがまれであれば無視できますが、頻繁に発生する場合は以下の方法でトレーニングセットを最適化することをお勧めします。2 つ目は、SGD 最適化を使用している場合、初期学習率が大きすぎることです。)

これらの 3 つの問題の原因は、label_length と input_length の値の取り方です。

  1. CRNN の主な利点の 1 つは、任意の長さの画像を認識できることです。トレーニング時には、まず画像を固定の非常に長い幅にパディングします。そして、input_length を縮小比率を適用した後、パディングする前の画像の幅を 4 で割った値に設定します。以下は一部のコードです:

    Img = Image.open (imagepath).convert ('L') #元の画像
    ResizedImg = cv2.resize (Img, (int (Img.shape [1] * (32 / Img.shape [0])), 32)) # 等比縮小
    input_length [i] = ResizedImg.shape [1] // 4 # 縮小後の画像の長さを 4 で割った値を取得

  2. label_length は非常に簡単に理解できるもので、正解のラベルの長さです。

  3. これで終わりだと思っているなら、それは間違いです。なぜなら、画像には不適切なものが含まれている可能性があるからです。これが問題 3 の原因で、損失が inf になります。

  4. したがって、トレーニングの前に、すべてのトレーニングセットと検証セットの画像をフィルタリングする必要があります。ctc_loss は、予測結果と真の値の損失を計算する際に、真の値のラベルの重複する文字の間に空の記号を挿入します。そのため、label_length に空の記号の数を加えた値が input_length より大きい画像は削除する必要があります。コード中の 2 は、label の先頭と末尾に空の記号が存在する可能性を考慮しています(この考えを検証していませんが、念のために追加しました)。例えば、縮小後の画像の高さが 32 で幅が 160 の場合、input_length=40 となります。ground truth label='abbbccddddcccaa' の場合、label_length=15 となります。計算結果として repreat_number は 2 (bbb)+1 (cc)+3 (dddd)+2 (ccc)+1 (aa) = 9 となります。そして、先頭の空の記号の数 2 を加えると、合計 11 となります。つまり、label_length (15)+repreat_number (9)+2<=input_length (40) を満たす画像が有効な画像となります。以下は一部のコードです:

    Img = np.array(Image.open(ImgRootPath + '/' + imgName).convert('L'))
    ResizedImg = cv2.resize(Img, (int(Img.shape[1] * (32 / Img.shape[0])), 32))
    l = [len(list(g)) for k, g in itertools.groupby(Label)]
    repeat_number = 0
    for n in l:
    if n > 1:
    repeat_number += (n - 1)
    input_length = ResizedImg.shape[1] // 4
    if len(Label)+repeat_number+2 > input_length:
    continue

読み込み中...
文章は、創作者によって署名され、ブロックチェーンに安全に保存されています。