hans

hans

【Python】【Shell】【Caffe】訓練集預處理 —— 數據增強 《很認真的講講Caffe》


----------【2017.09.29】更新包含 7 種數據增強方法的代碼 ----------------------------------------

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 29 15:42:18 2017

@author: hans

http://blog.csdn.net/renhanchi
"""

import skimage
import skimage.io
import numpy as np
import matplotlib.pyplot as plt
import os
import argparse

import sys
# Python 2 only: make implicit str <-> unicode conversions use UTF-8 instead of
# ASCII (presumably for the non-ASCII help text / file names this script
# handles). ``reload(sys)`` is needed because Python 2's ``site`` module
# removes ``sys.setdefaultencoding`` at interpreter startup.
reload(sys)
sys.setdefaultencoding('utf-8')

# Running count of processed images; incremented once per image and printed
# in the progress messages of the main loop.
num = 0

def flip(image):
    """Mirror an image horizontally.

    Args:
        image: array with at least two dimensions (rows, columns, ...).

    Returns:
        ``image`` with its column order reversed (left <-> right).
    """
    mirrored = np.fliplr(image)
    return mirrored

def channel_shift(x, limit=0.1, channel_axis=2):
    """Add an independent random offset to every channel of an image.

    Each channel receives its own offset drawn from U(-limit, limit). The
    shifted values are clipped to the [min, max] range of the *whole* input,
    so no value leaves the original intensity range.

    Args:
        x: image array; ``channel_axis`` selects its channel dimension.
        limit: maximum absolute offset added to a channel.
        channel_axis: index of the channel axis in ``x``.

    Returns:
        Array with the same shape as ``x`` holding the shifted channels.
    """
    # Move channels to the front so iterating yields one 2-D plane per channel.
    planes = np.rollaxis(x, channel_axis, 0)
    lo = np.min(planes)
    hi = np.max(planes)
    shifted = []
    for plane in planes:
        offset = np.random.uniform(-limit, limit)
        shifted.append(np.clip(plane + offset, lo, hi))
    stacked = np.stack(shifted, axis=0)
    # Put the channel axis back where it came from.
    return np.rollaxis(stacked, 0, channel_axis + 1)

def gray(img):
    """Convert an image to grayscale while keeping three channels.

    ``skimage.io.imread`` returns colour images in RGB channel order, so the
    BT.601 luma weights are applied as (R, G, B) = (0.299, 0.587, 0.114).
    (The weights were previously listed in OpenCV's BGR order, which swapped
    the red and blue contributions.)

    Args:
        img: image of shape (H, W) (already grayscale) or (H, W, C) with
            C >= 3; channels beyond the third (e.g. alpha) are ignored.

    Returns:
        (H, W, 3) array whose three channels all hold the luma value.
    """
    img = np.asarray(img)
    if img.ndim == 2:
        # Single-channel input: nothing to weight, just replicate it.
        luma = img
    else:
        coef = np.array([0.299, 0.587, 0.114])
        luma = np.sum(img[..., :3] * coef, axis=-1)
    return np.dstack((luma, luma, luma))

def contrast(img, limit=0.3):
    """Randomly scale image contrast around the image's mean luminance.

    A factor ``alpha`` is drawn from U(1 - limit, 1 + limit). Every pixel
    becomes ``alpha * img + (1 - alpha) * mean_luma``, and the result is
    clipped to [0, 1].

    Args:
        img: (H, W, 3) float RGB image, expected in [0, 1] (as produced by
            ``skimage.img_as_float``).
        limit: maximum deviation of ``alpha`` from 1.

    Returns:
        Contrast-adjusted image with the same shape as ``img``, in [0, 1].
    """
    alpha = 1.0 + np.random.uniform(-limit, limit)
    # BT.601 luma weights in RGB order (skimage loads images as RGB); the
    # previous BGR ordering weighted red and blue the wrong way round.
    coef = np.array([[[0.299, 0.587, 0.114]]])
    # Mean over pixels of the per-pixel luma.
    mean_luma = np.mean(np.sum(img * coef, axis=2))
    img = alpha * img + (1.0 - alpha) * mean_luma
    return np.clip(img, 0., 1.)

def lighter(img):
    """Brighten an image with gamma correction (gamma = 0.5).

    For a float image in [0, 1], gamma below 1 raises intensities
    (x ** 0.5 >= x).

    Args:
        img: image accepted by ``skimage.exposure.adjust_gamma``
            (non-negative values).

    Returns:
        The gamma-corrected image.
    """
    # Import explicitly: ``import skimage`` / ``import skimage.io`` at the top
    # of the script do not guarantee that the ``exposure`` submodule is loaded.
    import skimage.exposure
    return skimage.exposure.adjust_gamma(img, 0.5)

def darker(img):
    """Darken an image with gamma correction (gamma = 2).

    For a float image in [0, 1], gamma above 1 lowers intensities
    (x ** 2 <= x).

    Args:
        img: image accepted by ``skimage.exposure.adjust_gamma``
            (non-negative values).

    Returns:
        The gamma-corrected image.
    """
    # Import explicitly: ``import skimage`` / ``import skimage.io`` at the top
    # of the script do not guarantee that the ``exposure`` submodule is loaded.
    import skimage.exposure
    return skimage.exposure.adjust_gamma(img, 2)

def saturation(img, limit=0.3):
    """Randomly scale colour saturation by blending with the grayscale image.

    A factor ``alpha`` is drawn from U(1 - limit, 1 + limit). Each pixel
    becomes ``alpha * img + (1 - alpha) * luma`` (its own luma value), and the
    result is clipped to [0, 1]. ``alpha < 1`` desaturates and ``alpha > 1``
    boosts saturation.

    Args:
        img: (H, W, 3) float RGB image, expected in [0, 1].
        limit: maximum deviation of ``alpha`` from 1.

    Returns:
        Saturation-adjusted image with the same shape as ``img``, in [0, 1].
    """
    alpha = 1.0 + np.random.uniform(-limit, limit)
    # BT.601 luma weights in RGB order (skimage loads images as RGB); the
    # previous BGR ordering weighted red and blue the wrong way round.
    coef = np.array([[[0.299, 0.587, 0.114]]])
    # keepdims leaves an (H, W, 1) luma plane that broadcasts over channels.
    luma = np.sum(img * coef, axis=2, keepdims=True)
    img = alpha * img + (1. - alpha) * luma
    return np.clip(img, 0., 1.)

# Command line: <directory name> [<augmentation mode>]
parser = argparse.ArgumentParser()

# Positional 'n': name of the dataset directory (one sub-directory per class).
parser.add_argument(
        'n',
        type = str,
        help = """\
        目錄名稱
        """
)

# Positional 'm': augmentation mode.
# A positional argument only falls back to ``default`` when it is optional
# (``nargs='?'``); without it the 'flip' default was never used and the mode
# was mandatory. ``choices`` rejects unknown modes at parse time, before any
# output directory is created.
parser.add_argument(
        'm',
        type = str,
        nargs = '?',
        default = 'flip',
        choices = ['flip', 'channel_shift', 'gray', 'contrast',
                   'lighter', 'darker', 'saturation'],
        help = """\
        模式:
        flip(img),
        channel_shift(img, limit=0.1, channel_axis=2),
        gray(img),
        contrast(img, limit=0.3),
        lighter(img),
        darker(img),
        saturation(img, limit=0.3) 
        """
)

FLAGS = parser.parse_args()
mode = FLAGS.m
cla = FLAGS.n

# Mode name -> augmentation function.
transforms = {
    'flip': flip,
    'channel_shift': channel_shift,
    'gray': gray,
    'contrast': contrast,
    'lighter': lighter,
    'darker': darker,
    'saturation': saturation,
}
# Validate before touching the file system: an unknown mode used to create the
# output directory and then crash with a NameError on ``transform_image``.
if mode not in transforms:
    sys.exit('unknown mode: %s (expected one of: %s)'
             % (mode, ', '.join(sorted(transforms))))
transform = transforms[mode]

# Input layout: <cla>/<class>/<image>; output: <cla>_<mode>/<class>/<image>.
dirpath = r'%s/' %cla
for dirname in os.listdir(dirpath):
    srcdir = os.path.join(dirpath, dirname)
    if not os.path.isdir(srcdir):
        continue  # only class sub-directories are processed
    dstdir = r'%s_%s/%s/' % (cla, mode, dirname)
    if os.path.exists(dstdir):
        # A class whose output directory already exists is skipped entirely
        # (unchanged behaviour); delete a partially written one to redo it.
        continue
    os.makedirs(dstdir)
    for imagename in os.listdir(srcdir):
        num += 1
        print('%s saving %s_%s/%s/%s' % (num, cla, mode, dirname, imagename))
        src_image = skimage.img_as_float(
            skimage.io.imread(os.path.join(srcdir, imagename))).astype(np.float64)
        transform_image = transform(src_image)
        # cmap='gray' only affects 2-D (single-channel) results, which imsave
        # would otherwise render through the default colormap; it is ignored
        # for RGB(A) arrays.
        plt.imsave(os.path.join(dstdir, imagename), transform_image,
                   format='JPEG', cmap='gray')

第一個參數是數據目錄名(只寫目錄名本身,不要帶 `/` 等任何符號);第二個參數是數據增強模式名(flip、channel_shift、gray、contrast、lighter、darker、saturation 之一),同樣只寫模式名本身。結果保存在 `目錄名_模式名/` 目錄下。


訓練集太小時,模型很快就會過擬合。

因此擴充訓練集是很有必要的。

目錄結構:

當前目錄: flower, Mirror.py 等

flower/ : 各種類別花的目錄

flower/<類別名>/ :該類別花的所有圖片

執行後,會在當前目錄下創建新目錄(如 `flower_mirror/`)保存增強後的數據

1. 鏡像圖片

是否使用這個方法因人而異,因為 Caffe 數據層的預處理參數(transform_param)中本身就有一個 mirror 開關,打開後訓練時會對輸入圖片隨機做水平鏡像。

layer {
  name: "data"
  type: "Data"
  top: "data"
  top: "label"
  include {
    phase: TRAIN
  }
  transform_param {
    crop_size: 227
    mean_file: ".binaryproto" # placeholder: point this at your actual mean .binaryproto file
    mirror: true # switch that randomly mirrors the input images during preprocessing
  }
  data_param {
    source: "train_lmdb_227"
    batch_size: 64
    backend: LMDB
  }
}

代碼還是要放出來的。

#!/usr/bin/env python2
"""
Created on Fri Jul 21 11:08:23 2017

@author: hans
"""

import skimage
import skimage.io
import numpy as np
import matplotlib.pyplot as plt
import os

import sys
reload(sys)
sys.setdefaultencoding('utf-8') # on Ubuntu; use 'gbk' on Windows

# Dataset directory: <cla>/<class>/<image>; mirrored copies go to
# <cla>_mirror/<class>/<image>.
cla = 'vegetable'

dirpath = r'%s/' %cla
for dirname in os.listdir(dirpath):
    srcdir = os.path.join(dirpath, dirname)
    if not os.path.isdir(srcdir):  # only class sub-directories are processed
        continue
    dstdir = r'%s_mirror/%s/' % (cla, dirname)
    if not os.path.exists(dstdir):  # create the mirror directory if missing
        os.makedirs(dstdir)
        print("creat dir: %s" % dirname)

    for imagename in os.listdir(srcdir):
        oriMirror = skimage.img_as_float(
            skimage.io.imread(os.path.join(srcdir, imagename))).astype(np.float64)
        imgMirror = np.fliplr(oriMirror)
        # cmap='gray': a 2-D (grayscale) array would otherwise be rendered
        # through the default colormap; ignored for RGB(A) arrays.
        plt.imsave(os.path.join(dstdir, imagename), imgMirror,
                   format='JPEG', cmap='gray')

2. 修改亮度

#!/usr/bin/env python2
"""
Created on Fri Jul 21 11:08:23 2017

@author: hans
"""

import skimage
import skimage.io
import matplotlib.pyplot as plt
import numpy as np
import os

import sys
reload(sys)
sys.setdefaultencoding('utf-8')

# Dataset directory: <cla>/<class>/<image>; results go to <cla>_<mode>/<class>/.
cla = 'animal'
mode = 'lighter'  # or 'darker'

# Gamma per mode: < 1 brightens, > 1 darkens. Deriving it from ``mode`` keeps
# the output directory name and the applied transform in sync automatically.
gamma = {'lighter': 0.5, 'darker': 2}[mode]

# Not loaded by ``import skimage`` / ``import skimage.io`` alone.
import skimage.exposure

dirpath = r'%s/' %cla
for dirname in os.listdir(dirpath):
    srcdir = os.path.join(dirpath, dirname)
    if not os.path.isdir(srcdir):
        continue  # only class sub-directories are processed
    dstdir = r'%s_%s/%s/' % (cla, mode, dirname)
    if os.path.exists(dstdir):
        continue  # output for this class already exists: skip it (as before)
    os.makedirs(dstdir)
    for imagename in os.listdir(srcdir):
        print('saving %s_%s/%s/%s' % (cla, mode, dirname, imagename))
        ori = skimage.img_as_float(
            skimage.io.imread(os.path.join(srcdir, imagename))).astype(np.float32)
        img = skimage.exposure.adjust_gamma(ori, gamma)
        # cmap='gray': a 2-D (grayscale) array would otherwise be rendered
        # through the default colormap; ignored for RGB(A) arrays.
        plt.imsave(os.path.join(dstdir, imagename), img,
                   format='JPEG', cmap='gray')

未完待續...

路徑腳本

將每張圖片從當前目錄開始的路徑及其類別索引寫入 .txt 文件(每行一條:`路徑 類別索引`),並打亂順序。

#!/bin/bash
# Requires bash, not POSIX sh: the script uses arrays (dash, which is /bin/sh
# on Ubuntu, rejects them).
#
# Writes flower_train.txt: one "<path> <class index>" line per image found in
# each dataset directory below, in random order.

# Class directory names; a class's position in this list is its label index.
classes=(Anthurium asparagus_fern bamboo_palm Begonia cactus cape_jasmine Carnation Cherry_plum chrysanthemum)

rm -f temp.txt  # never append to leftovers of an interrupted earlier run

for cla in flower flower_mirror
do
	num=0 # class index, restarted for every dataset directory
	for class in "${classes[@]}"
	do
		# append the class index to every listed path
		ls "$cla/$class"/* | sed "s/$/ $num/" >> temp.txt
		num=$((num+1))
	done
done
# shuffle: prefix each line with a random key, sort on it, then drop the key
awk 'BEGIN{srand()}{print rand()"\t"$0}' temp.txt | sort -k1,1 -n | cut -f2- > flower_train.txt
rm temp.txt
載入中......
此文章數據所有權由區塊鏈加密技術和智能合約保障僅歸創作者所有。