Efficient Neural Architecture Search via Parameter Sharing (ENAS)
はじめに
最近DeepLearningのモデル構築もDeepLearningで行う手法(AutoML)が流行っています。 AutoMLの一つであるNAS(Neural Architechture Search)は、2016年にGoogleから発表されました。強化学習を利用してアーキテクチャ、ハイパラメータの最適化をDeeepLearningでできてしまうとして注目されてましたが、実行時間がネックとなっていました。2018年にNASより計算量を削減したあるENAS(Effective NAS)が発表され(こちらも強化学習を利用),1GPUでも演算可能な計算量とNASと変わらない精度を保っているということで注目されました.今回はこのENAS論文を読んでみます。(会社で論文紹介するのでメモ用です) AutoMLについて知らなかったので,調べてみると他にも複数の探索アルゴリズムがあるようです。
- MNAS
モバイル向けの最適モデル探索アルゴリズム。実行時間も目的関数に取り込んで精度、実行時間を考慮したモデル探索ができる。 Googleが公式モデルも公開していて,Google Cloudからも利用できます. - Darts
- FBNet
NASとENASの違い
ENASとNASの大きな違いは、実行時間です。NASと比較してGPUを利用した学習を1000倍以上高速にすることが可能になりました. 高速化の大きな要因として探索するDNNモデル間の重みデータの共有にあります。探索対象の新規モデルを毎回0から訓練するのではなく,探索済みのネットワークの重みデータを利用することで高速化しています。精度もNASと同等です。 ENASの構造的な特徴は有向非循環グラフ(DAG)としてNASの探索空間を表現したことです。各Nodeは局所的な演算(convolutionなどの演算)を表し、エッジは重みデータを表します。下図はすべての探索空間を表している例で,赤いエッジはコントローラによって選ばれた探索空間になります.
RNNモデルの探索の例
ENASのコントローラ(RNN)の役割は2つ
- どのエッジをアクティブにするか
- DAGの各ノードで実行する演算の種類
下図の左図、DAGとアクティブエッジ(赤矢印)を表す。 下図の右図、DAGに対応するRNNを表します。
単純なRNN探索のENASのメカニズム。
- ノード1で利用するActivation関数(例:tanh)を選択
- ノード2では前の接続Node(ノード1)とActivation関数(例:relu)を選択
- ノード3ではノード2とActivation関数(例:tanh)を選択
- ノード4ではノード1とActivation関数(例:tanh)を選択
- 出力ノードを決定する。出力ノードは他のノードへの入力になっていないノードとします。この例ではノード3,4です。リカレントセルは平均である(h3+h4) / 2 を出力として使用する。
各エッジはノード間の情報(重み)を表しています。ENASではこのエッジ情報を探索中のすべての繰り返しセルで共有します。なので、繰り返しで選ぶActivation関数が変わったとしても選んだノードペアが同一なら重みも共有されるということだと。
探索空間は4つのActivation関数(tanh, ReLU, identity, sigmoid)とする場合、4N x Nです。この論文ではN = 12で、約1015の空間です。
ENASのトレーニング
コントローラのネットワークは100個のhidden-unitを持つLSTMで、2つの学習パラメータθ(探索ポリシーの変数?)と子モデルの共有パラメータwがあります。トレーニングはまず、トレーニング全体の学習で子モデルのwをトレーニングします。そして一定のStep数ごとにθをトレーニングします。この論文では2000stepごとにθをトレーニングしている。
子モデルのwのトレーニング
コントローラのポリシーを固定して、SGDを利用してloss関数(下図左)を最小化します。勾配はモンテカルロ推定を利用します。
M=1でうまく機能します。つまりπ(m;θ)からサンプリングした任意の1つのモデルmの勾配を利用してwを更新できます。
θのトレーニング
wとθの報酬の最大化を目的とし強化学習を利用します。報酬は検証用データで計算される。
CNNモデルの探索の例
コントローラーはRNNの時と同様にノードにおけるローカル演算と一つ前に接続するノードを決定する操作をします.1つ前の接続によってはスキップ構造(残差接続)を表現できます. ノードは6つのオペレーションから選択します.
- convolutions with filter sizes 3x3
- convolutions with filter sizes 5x5
- depthwise-separable convolutions 3x3
- depthwise-separable convolutions 5x5
- max pooling
- average pooling
L層のネットワークでは,6L x 2 ^ L(L-1)/2の探索空間で,L=12では1.6 x 1029の探索空間になります.
メカニズム
- ノード1,2は入力ノードなのでなにもしない.
- ノード3は2つの入力ノード(ノード2,ノード2)を決定する.入力ノードはそれぞれidentityとseparable_conv_5x5と接続される.
- ノード3は2つの入力ノード(ノード3,ノード1)を決定する.入力ノードはそれぞれaverage poolingとsep_conv_3x3と接続される. 4. ノード4はどのノードの入力でもないので,出力として扱われる.
公開モデル
pytorchモデル
2019/9/1現在、CNNには対応しておらずRNNに対応しています。tensorflowモデル
python2系のみ対応しています。nni
Microsoftが作成しているAutoML向けのツールでその中にENASもあります. 主要なFWはカバーしていたり,Web UIもあったりしてnniは多機能そうです.