hans

hans

【Text Transcriptor】训练CRNN时,关于ctc_loss的几点注意事项


这个 ctc_loss 很魔性,训练 CRNN 虐了我几个来回。

我的数据集图片大小不一,我是先等比例缩小到固定高度为 32,宽度不定。

常见三个问题:

1.CTC Loss Error: invalidArgumentError: Not Enough time for target transition
sequence. (label_length 大于图片长度除以 4 了)

2.CTC Loss Error: InvalidArgumentError: sequence_length(b) <= time
(图片长度除以 4 小于等于 input_length 了)

3.ctc_loss error “No valid path found.”
(导致这个错误有两种情况,一是 label_length 大于 input_length,对模型收敛没有很大影响,只是出错的那一个 batch 参数没有更新优化。如果这个错误很少,可以忽略。如果这个错误很多的话就建议用下面方法优化一下训练集。二是使用 SGD 优化时,初始学习率过大。)

导致这三个问题的原因,就是 label_length 和 input_length 的取值问题。

  1. CRNN 一个主要优点就是可以识别任意长度的图片。在训练的时候,先统一将图片 padding 到一个固定的很长的宽度。然后 input_length 设置为你等比例缩小后,padding 之前的图片的宽除以四。部分代码如下:

    Img = Image.open (imagepath).convert ('L') #原始图片
    ResizedImg = cv2.resize (Img, (int (Img.shape [1] * (32 / Img.shape [0])), 32)) # 等比缩小
    input_length [i] = ResizedImg.shape [1] // 4 # 取等比缩小后图片长度除以 4 的值

  2. label_length 很简单理解,就是 ground truth 的长度。

  3. 如果你以为这样就完事大吉可以训练你就错了。因为你的图片可能有不合格的存在。导致问题 3 出现,loss 变为 inf。

  4. 所以在训练前,应该过滤一遍所有训练集和验证集的图片。ctc_loss 在计算预测结果和真值的 loss 的时候,会在你真值 label 中重复的字符之间插入空符,所以必须将 label_length 加上空符个数大于 input_length 的图片删除掉。而代码中的 2,是我考虑有可能在 label 的开头和末尾存在空符。(我并没有验证这个想法,只是为了保险起见。)举个例子,你等比缩小后图片高度为 32,宽度为 160,那么 input_length=40。ground truth label='abbbccddddcccaa',label_length=15,经过计算 repreat_number 为 2 (bbb)+1 (cc)+3 (dddd)+2 (ccc)+1 (aa) = 9,然后再加上开头结果的空符数 2,最终等于 11。也就是说必须满足 label_length (15)+repreat_number (9)+2<=input_length (40) 的图片才是合格的图片。部分代码如下:

    Img = np.array(Image.open(ImgRootPath + '/' + imgName).convert('L'))
    ResizedImg = cv2.resize(Img, (int(Img.shape[1] * (32 / Img.shape[0])), 32))
    l = [len(list(g)) for k, g in itertools.groupby(Label)]
    repeat_number = 0
    for n in l:
    if n > 1:
    repeat_number += (n - 1)
    input_length = ResizedImg.shape[1] // 4
    if len(Label)+repeat_number+2 > input_length:
    continue

Loading...
Ownership of this post data is guaranteed by blockchain and smart contracts to the creator alone.