studylog/北の雲

chainer/python/nlp

可変長データのミニバッチをchainerのwhereでやる

以前にもちょこっと書いたのだけど自然言語はたいてい次元(単語数、文字数など)が合わない。

データA 1 2
データB 1 2 3

こんな感じで次元が合っていないデータをミニバッチで学習したい。

末尾をEOSで埋めてミニバッチ

可変長のミニバッチの時に末尾をEOSで埋めてそのまま学習するとまずいのでwhereを使うといいらしい。以前から挑戦していたのだけれどもなかなかうまくいかず、先日ようやく中の人にtwitterで教えてもらって出来ました。

まず末尾を0や-1など特殊なもので埋めて次元を固定する。仮に0をEOSに該当するIDだとしたらこのように埋められる。これでAもBも4次元で揃った。

データA 1 2 0 0
データB 1 2 3 0

転置(numpyの.T)したらそのまま入力に使える形式になってくれる。

A B
1 1
2 2
0 3
0 0

実際のコードだとこう。

a = numpy.asarray( [[1,2,0,0],[1,2,3,0]] )

a
#array([[1, 2, 0, 0],
       [1, 2, 3, 0]])

a.T
#array([[1, 1],
       [2, 2],
       [0, 3],
       [0, 0]])


ここから学習を開始する
step1
1 1 ←最初はここが入力(x)
2 2 ←最初はここが正解データ(t)
0 3
0 0

step2
次は一つずれる
データAのtに0が来る。この場合、正解がEOSになってもいいならば問題ない。

1 1
2 2 ← x
0 3 ← t
0 0

step3
ここでxに0が入ってしまう。
1 1
2 2
0 3 ← x
0 0 ← t

EOSを入力して正解にEOSを出すように学習することになる。これがマズいかどうかはタスク次第かもしれない。次の文字を予測するだけのタスクに使うならEOSを出力した時点で打ち切ると思うのでさほど問題は無いかもしれない。

でも文の終わりの時点での状態(シグナル?chainerのlstmだとhに入ってるやつ)を正確に取り出したいなら、
1→2→0
1→2→0→0
この両者は違うベクトルを出力するので困る。きちんと1→2→0で終わらせたい。

前置きが長くなったけど、そういう時にwhereを使うといいらしい。教えて頂いたコードをそのまま引用すると

c, h = lstm(c_prev, lstm_in) #ここでlstm
enable = (x != 0) #where用のcondition 実際はこれをVariableにする
c_next = where(enable, c , c_prev) #x!=0ならcを、x=0ならc_prevをc_nextに入れる
h_next = where(enable, h  , h_prev) #hも同じ

自分はLinkのlstmを使いたいので__call__を上のように少し書き換えて使ってます。

本当は下のように書ければベストだけどwhereの入力はVariableに限られているしlstm自体もc,hのタプルを返してきて直接whereに使えないので上のように一度lstmしてから戻り値をwhereで振り分けている。

c_next , h_next = where( conditon , lstm(c_prev, lstm_in) , ( c , h ) )

さらに末尾埋めに-1を使う場合(自分がやってる方法)

-1で埋めるやり方もある。1.4よりsoftmax_cross_entropyが正解t=-1の場合はlossに0を返してくれるようになったので出来れば末尾は-1で埋めたい。もし末尾0埋めだとlossやacc周りが少し面倒な気がする。EOSは0にして、末尾-1で埋める場合、

データA 1 2 0 -1 -1 -1
データB 1 2 3 4 0 -1

これを転置して学習開始するのだけど、以下のstepになると問題発生する。xに-1が来ちゃう。

A B
1 1
2 2
0 3
-1 4 ←x
-1 0 ←t
-1 0

ただし、ここでも書いたようにEmbedIDに-1を投げると危険。

#これはマズい nanが返ってきたりしていつかlossが破綻する
h0 = self.embed(x)

#自分はこうやってx=-1のときはZEROベクトルを返すようにしてる
h0 = F.where( x != -1 , self.embed(x) , ZERO )

そのうちEmbedIDにもignore flag=-1が出来ると思います。
Ignore labels for EmbedID · Issue #832 · pfnet/chainer · GitHub

2016/4/30追記 バージョン1.8でignore_flagが実装されました
デフォルトでは無効になっているので手動で設定します。-1じゃなくても0などintであればおk。若干パフォーマンス落ちるようです。

super(RNNLM, self).__init__(
    embed=L.EmbedID(n_vocab, n_units,ignore_label=-1), #ここで設定
    l1 = L.LSTM(n_units, n_units),
    l2 = L.LSTM(n_units, n_units),
)

追記終わり


次元が同じもの同士だけでミニバッチする

以下の例のように末尾埋めしちゃうと一番長い文に次元を合わせないといけないので、そのぶん無駄もできる。

1 2 -1 -1 ................................-1 ←短文
1 2 3 -1...................................-1 ←短文
1 2 3 4 5 6 7.......100 101 102 -1 ←長文 こいつのせいで無駄な処理が発生

長さが同じ文だけを集めてミニバッチすれば上の末尾埋めで発生する色々な煩わしさが無くなるし、何より速い。
ただ常に次元が同じものだけで学習するとコーパスが偏って学習に影響が出そうなのでできればシャッフル性を高めたい。

両者いいとこ取り

1.完全ランダム(一番遅い)
2.次元一致したのだけで学習(超高速)
3.なるべく末尾埋めが少なくなるようにうまいことシャッフル(20~30だけ、30~40だけ、みたいなグループにわける)

これをローテーションすれば学習に悪影響を与えずに速度も速くできるのかなと思いました。


以上です。
こんなやり方でシンプルなRNNだと簡単にミニバッチできるようになった。
でも木構造のrecursiveとかattentionなどの場合はどうするんだろう…。