ChainerのcuDNN-RNN(NStepLSTM)のとっかかり
16.0の新機能NstepLSTMはcuDNN5.0以降で最適化されたcuDNN-RNNを利用できます。速くなるらしいです。
Optimizing Recurrent Neural Networks in cuDNN 5 | Parallel Forall
これの良い所は次元数が合わないデータでもミニバッチ処理が簡単にできる点です。
再掲しますが以前はこんな風にやっていました。
可変長データのミニバッチをchainerのwhereでやる - studylog/北の雲
以前のやり方
手順1(次元が合ってない)
データA 1 2
データB 1 2 3
手順2(0で末尾を埋めて次元を合わせる)
データA 1 2 0 0
データB 1 2 3 0
手順3(転置)
A B
1 1
2 2
0 3
0 0
手順4-1(入力する)
1 1 ←最初はここが入力(x)
2 2 ←最初はここが正解データ(t)
0 3
0 0
手順4-2(一つずらして入力)
1 1
2 2 ← x
0 3 ← t
0 0
以下繰り返す
次元を合わせるために末尾に-1や0などを追加して、転置して…というめんどくさい方法。これNStepLSTMでは全部Chainerがやってくれます。
NStepLSTMのやり方
手順1(次元が合ってないけどそのまま入力する)
データA 1 2
データB 1 2 3
おわり。
めちゃくちゃ楽だ!
15.0で導入された可変長対応LSTMでは次元数を小さい順に並び替える処理が必要でしたがNStepLSTMでは必要無いです。
こんなのもそのまま放り込むだけ。
データA 1 2
データB 1 2 3 4
データC 1 2 3
lossの計算
以前のLSTMは入力する時に時刻tを一つづつずらしてその都度lossを計算していたと思うんですが、NStepLSTMは入力時に時刻ずらしは考えなくて済むようになってます。全ての時刻での出力結果がまとまって返ってくるのでそれを時刻ごとにずらしてloss計算するイメージ。入力時じゃなくて出力時に時刻ずらし。文章で伝わりますかね…。そのうちなるべくシンプルなサンプルコードをあげたいと思います。
速度
以前のLSTMとの比較はまだやれていません。
cuDNNの有無では明確に有る方が速いです。条件によるんで具体的な数字は何とも言えないですがEmbedサイズが大きくなればなるほど差がつくようです。ここは後で追記するかもしれません。
その他
dropoutを有効にするとエラーが出ちゃいます。何故だ。cuDNN5.1,Ubuntu14.04です。
(追記:17.0で治った模様)
File "/usr/local/lib/python3.4/dist-packages/chainer/links/connection/n_step_lstm.py", line 96, in __call__
train=train, use_cudnn=self.use_cudnn)
File "/usr/local/lib/python3.4/dist-packages/chainer/functions/connection/n_step_lstm.py", line 468, in n_step_lstm
ret = rnn(*inputs)
File "/usr/local/lib/python3.4/dist-packages/chainer/function.py", line 198, in __call__
outputs = self.forward(in_data)
File "/usr/local/lib/python3.4/dist-packages/chainer/functions/connection/n_step_lstm.py", line 267, in forward
self.reserve_space.data.ptr, reserve_size)
File "cupy/cuda/cudnn.pyx", line 960, in cupy.cuda.cudnn.RNNForwardTraining (cupy/cuda/cudnn.cpp:12578)
File "cupy/cuda/cudnn.pyx", line 978, in cupy.cuda.cudnn.RNNForwardTraining (cupy/cuda/cudnn.cpp:12318)
File "cupy/cuda/cudnn.pyx", line 321, in cupy.cuda.cudnn.check_status (cupy/cuda/cudnn.cpp:1721)
cupy.cuda.cudnn.CuDNNError: CUDNN_STATUS_BAD_PARAM: b'CUDNN_STATUS_BAD_PARAM'
追記
実サンプルを上げてくれている方がいました。DropoutのバグもGithubに報告してくれたおかげで治ったみたいです。感謝。
www.monthly-hack.com