GoでDeepLearning

GoでMnistを使ってDeep Learning

昔(2018年10月27日)Qiitaに書いた記事の転載です

現在、Deep Leanig向けのFWとしてはTensorflow、Pytorchなどが有名です. これらのFWのインターフェースはPython、内部実装はC++になっていて,ユーザーは使いやすくかつ高速にというのを実現しています (演算はGPUがメインですよね). 最近Go言語に興味を持ち、趣味でGoのアプリケーションの実装をしてみたい&Deep Learningの勉強をしたい,Goでの実装はなさそう??ということでGoでやってみました.

概要

Deep Learningの始めの一歩としてまずは、Mnist(手書き文字のデータ)だろう!ということで,Mnistで学習して、自分の書いた数字を推論できるようになることを目標に進めていきます. 現在はまだGo100%です(まだと書いた理由はGPU演算に関してに書きました) Screenshot from 2018-08-10 13-01-35.png

今回実装したもの

実装したプロジェクトです.

以下の順番で実装を行っていきました.

  • Mnistのデータを1次元配列として読み込む
  • 行列演算の実装
  • Goで学習できる枠組みを作る(Forward,Backward演算,重みParameterのUpdate)
  • 学習したデータをSave,Restoreできるような仕組み
  • 保存した学習済みデータを使って自分の書いた文字を推論

Mnistのデータを1次元配列として読み込む(GoMnist)

GoMnistをForkしてfloat64の1次元の配列として読み込めるようにしたり,正規化できるように改造しました. 改造量は多くないです.Deep Learningは演算精度はそんなにいらなく,工夫すれば16bitで学習できたりもします. なのでfloat64もいらなくてfloat32にしようかなと思ったんですが,Goのライブラリはfloat64前提にされているものが多く,対応がめんどくさかったのでfloat64でデータも読み込んでいます.

行列演算の実装(gmat)

ここは愚直にネットワークに必要な行列演算の実装を行っていきました.実装はGPUとCPUに分かれていて, build tagでGPU,CPU使い分けるようにしています.下記に書きましたがGPUの方はまだ不完全です. あとCPU演算のpprofを見たところ圧倒的にdot演算に時間がかかってたのでこの部分だけCPU数分だけ並列処理するようにしたところ,だいーぶ早くなりました.syncやchannelのいい勉強になりました.

GPU演算に関して

GPUでの演算にも取り組みましたが,GoでGPU演算を行うのが少し難しく現状はDot演算のみ実装してあります.CUDAのkernel関数を定義してそれをcgoを使ってリンクすれば,やりたい演算を実現できるのですがそうするとGoでDeep Learningを行うというのに反してしまうと思い,踏みとどまりました. でもそうするとCUDAのライブラリ(cuda runtime, cublas,cudnnなど)を使うしかなく柔軟な計算ができません.そこをどうやって解決するかはまだ答えがでてません.. 一応すべての関数をcuda kernel関数を定義することで実装しました。

Goで学習できる枠組みを作る(gdeep)

Interfaceに,各演算Layer(Dropout,Denseなど)のメソッド(Forward,Backwardなど)を登録します.そしてユーザーの定義した演算レイヤーに対応するForward,Backward演算が実行できるようにしました. ユーザーが定義する演算レイヤーはこんな感じ

layer := []gdeep.LayerInterface{}
gdeep.LayerAdd(&layer, &gdeep.Dense{}, []int{inputSize, hiddenSize})
gdeep.LayerAdd(&layer, &gdeep.Relu{})

そして各IterationごとにRunしてあげれば学習が進みます.

loss := gdeep.Run(layer, momentum, x, t)

gdeep.Runの中でForward,Backwardを行い,Backwardの演算結果を使って重みParameterのアップデートも行っています. これによって学習が進んでいきます.フルのサンプルはこちらをご参照ください. で実際に上記のサンプルを実行してMnistのTrainデータを使って学習し,Testデータを使ってどれくらいの精度で学習できているか評価を行いました.そうしたところAccuracy:92%となり,ちゃんと学習できていることがわかりました! これでとりあえず学習はできていると一安心です.

学習したデータをSave,Restoreできるような仕組み(gdeep)

gobというライブラリを使って学習したParameterの保存を行いました. こんな風に書けば保存できます.

gdeep.Saver(layer, "./sample.gob")

ライブラリのおかげでわりと簡単にSaveとRestoreは実現できました.決め打ちになってしまっている箇所があるので修正しないとなーと思ってます.

保存した学習済みデータを使って自分の書いた文字を推論

はい.上記保存した学習モデルを使って自分の書いた文字が推論できるか試していきます. 推論用に使うように自分が書いた数字です(笑)

これをMnistと同じ28x28のサイズに加工して,GrayScaleにします.サイズの加工は外部ライブラリを利用して,GrayScaleにする部分は実装しました. これでデータの準備はできたので,推論していきましょう. 書いたサンプルです(現在は決め打ちなんでこんなに短いけど,本当は演算レイヤーの定義を書かないといけなくなる..)

package main                                                                                         
                                                                                                     
import (                                                                                             
    "fmt"                                                                                            
    "github.com/kuroko1t/gdeep"                                                                      
)                                                                                                    
                                                                                                     
func main() {                                                                                        
    img := gdeep.ImageRead2GrayNorm("data/5.jpg")                                                    
    layer := []gdeep.LayerInterface{}                                                                
    gdeep.Restore("./sample.gob", &layer)                                                            
    predictnum := gdeep.Predict(layer, img)                                                          
    fmt.Println("predict:", predictnum)                                                              
}       

実行してみたところ,

predict: 5

ちゃんと,推論できました!!めでたしめでたし.

追記

mpiを利用したデータ並列処理を追加しました.コマンドはこんな感じ

mpirun  -np 2 -H host1:2,host2:1 go run example/mlpMnist_allreduce.go

Todo

  • Convolutionの実装
  • MLPGPU実装

まとめ

現在自分が実装したことに関して書いてみました.アドバイスやコメントなどありましたらよろしくお願いします.