hans

hans

【Python】【Caffe】二、訓練輸出可視化《python調用caffe模塊》


GitHub 程式碼地址: https://github.com/HansRen1024/Use-Python-to-call-Caffe-module

前言#

細心的同學會發現上一篇文章中最後完整程式碼已經有訓練程式碼了,四句話就能搞定。當然如果你要單機多卡訓練的話,建議你使用 caffe/python/ 目錄下的 train.py。

關於訓練結果可視化,我在之前文章中寫到過。通過記錄終端輸出,然後 matlab 實現。地址:
【Caffe】快速上手訓練自己的數據 《很認真的講講 Caffe》

看第六部分可視化內容

今天這裡我寫一下如何通過 python 調用 caffe 模塊,在訓練期間,提取輸出結果放到列表中,最後訓練結束將列表內容可視化。

完整程式碼:#

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Sun Jul 30 21:56:32 2017

author: hans

"""

import matplotlib.pyplot as plt
import caffe
import numpy as np

caffe.set_device(0)
caffe.set_mode_gpu()
solver = caffe.SGDSolver('doc/solver_lenet.prototxt')

# 下面參數參照solver
max_iter = 10000
display= 100
test_iter = 100
test_interval =500

#初始化
train_loss = np.zeros(max_iter/ display)
test_loss = np.zeros(max_iter/ test_interval)
test_acc = np.zeros(max_iter / test_interval)

_train_loss = 0
_test_loss = 0
_accuracy = 0
for it in range(max_iter):
    solver.step(1) # 這裡是運行一次迭代
    _train_loss += solver.net.blobs['loss'].data # 'loss' or 'Softmax1'
    if it % display == 0:
        # 計算平均train loss
        train_loss[it / display] = _train_loss / display
        _train_loss = 0

    if it % test_interval == 0:
        for test_it in range(test_iter):
            solver.test_nets[0].forward()
            _test_loss += solver.test_nets[0].blobs['loss'].data # 'loss' or 'Softmax1'
            _accuracy += solver.test_nets[0].blobs['acc'].data # 'acc' or 'Accuracy1'
        # 計算平均test loss
        test_loss[it / test_interval] = _test_loss / test_iter
        # 計算平均test accuracy
        test_acc[it / test_interval] = _accuracy / test_iter
        _test_loss = 0
        _accuracy = 0

_, ax1 = plt.subplots()
ax2 = ax1.twinx()

ax1.plot(display * np.arange(len(train_loss)), train_loss, 'g')
ax1.plot(test_interval * np.arange(len(test_loss)), test_loss, 'y')
ax2.plot(test_interval * np.arange(len(test_acc)), test_acc, 'r')

ax1.set_xlabel('iteration')
ax1.set_ylabel('loss')
ax2.set_ylabel('accuracy')
plt.show()

以上內容參考自: http://www.cnblogs.com/denny402/p/5686067.html

載入中......
此文章數據所有權由區塊鏈加密技術和智能合約保障僅歸創作者所有。