studylog/北の雲

chainer/python/nlp

Having an RNN read and compute equations, then decode equations back out

It has become clear that getting an RNN to generate text with quick hacks is hard, so I went back to study the basics a bit more.

  1. Feed the string (str) "1 + 8" to an RNN,
  2. train it to compute the result and output the correct answer (9), and
  3. then, in reverse, have it output an expression whose answer is 9.

That is what I tried. In other words:

[Figure: f:id:kitanokumo:20150903052229p:plain]

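I think this makes a good RNN tutorial. The interesting part is that the input is a string of characters rather than numbers, and yet the calculation still works. Incidentally, I left out division and subtraction because handling them was a hassle.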
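The results are at the bottom. ▲ marks a failure to generate an expression, ● marks a failure in the calculation, and a line with no leading mark is a complete success. Failures dominate early on, and the accuracy rises as training proceeds. (Judging from the log itself, the two marks may be the other way around: every ▲ line has a wrong computed value, and the ● lines are malformed expressions.)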
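
The setup described above (strings built from single digits joined by + or *, paired with their numeric answers) can be sketched roughly as follows. This is an illustrative reconstruction, not the code used for this post: `make_example`, `encode`, the `<eos>` end marker, and the choice to emit the answer as a digit string are all assumptions.

```python
import random

# Character vocabulary: ten digits plus the two operators used in the post.
# The <eos> end marker is an assumption added for illustration.
VOCAB = list("0123456789+*") + ["<eos>"]
CHAR_TO_ID = {ch: i for i, ch in enumerate(VOCAB)}


def make_example(rng):
    """Return (expression string, integer answer) for one 'a+b' or 'a*b' problem."""
    a, b = rng.randint(0, 9), rng.randint(0, 9)
    op = rng.choice("+*")
    answer = a + b if op == "+" else a * b
    return f"{a}{op}{b}", answer


def encode(text):
    """Map a string to a list of character IDs, ending with <eos>."""
    return [CHAR_TO_ID[ch] for ch in text] + [CHAR_TO_ID["<eos>"]]


rng = random.Random(0)
expr, answer = make_example(rng)
source_ids = encode(expr)         # what the encoder side would read
target_ids = encode(str(answer))  # what it should learn to output
print(expr, answer, source_ids, target_ids)
```

Because both the expression and the answer are just sequences of character IDs, swapping source and target gives the training pairs for the reverse direction (answer in, expression out).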

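However, the expressions are so short that, rather than actually computing from rules, the network may simply be memorizing the relationship between strings and results. The strength of an RNN is that it handles variable-length input and output, so in theory it should cope even when the length of the expression changes. Something like "2 + 3 * 4 = 2 * 7" should work too. I want to try it someday.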
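
To illustrate the variable-length point: a recurrent net applies the same update at every character, so the state it ends up with has the same size however long the expression is. Below is a minimal pure-Python sketch with untrained random weights. It is not the chainer model from this post, and `HIDDEN`, `embed`, `w_h`, `step` and `read_expression` are hypothetical names chosen for illustration.

```python
import math
import random

VOCAB = "0123456789+*="
HIDDEN = 8  # hypothetical hidden size chosen for illustration

rng = random.Random(0)
# Untrained random weights. embed[ch] plays the role of W_x applied to a
# one-hot character vector; w_h is the recurrent weight matrix W_h.
embed = {ch: [rng.uniform(-0.5, 0.5) for _ in range(HIDDEN)] for ch in VOCAB}
w_h = [[rng.uniform(-0.5, 0.5) for _ in range(HIDDEN)] for _ in range(HIDDEN)]


def step(h, ch):
    """One Elman-style update: h_new = tanh(W_x x + W_h h)."""
    return [
        math.tanh(embed[ch][i] + sum(w_h[i][j] * h[j] for j in range(HIDDEN)))
        for i in range(HIDDEN)
    ]


def read_expression(expr):
    """Fold the same step over every character, whatever the length."""
    h = [0.0] * HIDDEN
    for ch in expr.replace(" ", ""):
        h = step(h, ch)
    return h


for expr in ["1+8", "2+3*4", "2 + 3 * 4 = 2 * 7"]:
    n_chars = len(expr.replace(" ", ""))
    print(n_chars, "chars ->", len(read_expression(expr)), "state values")
```

Whether a trained network actually generalizes to longer expressions is a separate question; the sketch only shows that nothing in the architecture fixes the input length.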

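Applying this apparently makes tasks such as translation and paraphrasing possible, for example with Japanese on the left-hand side and English on the right. As a more advanced topic, I am curious about skip thoughts, but I don't quite follow the theory yet. The authors of the paper plan to release Python code during September, so I'm looking forward to it. (ryankiros/skip-thoughts · GitHub)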

loss 6.87466335297 accurate 0.0
▲ 5+8 = 548 = 799
● +**
● 8++
● 5**
▲ 7+1 = 11 = 0+9
▲ 3*1 = 8 = 0+3
▲ 4*2 = 12 = 7+4
▲ 5+5 = 8 = 3+7
▲ 4*2 = 0 = 1+7
loss 519.03705442 accurate 0.07
▲ 4*1 = 0 = 3+1
▲ 6+9 = 8 = 9+6
▲ 4+0 = 10 = 2*2
▲ 9+9 = 11 = 9*2
▲ 6*8 = 30 = 8*6
loss 326.475861747 accurate 0.21
▲ 7*2 = 8 = 7+7
▲ 9*6 = 0 = 6*9
▲ 8*3 = 18 = 4*6
▲ 1+2 = 4 = 0+3
▲ 5*1 = 7 = 3+2
▲ 9+5 = 9 = 6+8
loss 278.558683787 accurate 0.21
▲ 7*9 = 24 = 9*7
▲ 8*4 = 24 = 4*8
▲ 4*2 = 12 = 5+3
▲ 2+7 = 10 = 0+9
▲ 4+4 = 10 = 5+3
loss 240.761846579 accurate 0.38
▲ 1*2 = 8 = 0+2
▲ 9+1 = 9 = 2*5
▲ 5+2 = 10 = 3+4
▲ 9+5 = 11 = 5+9
7+5 = 12 = 6+6
5*6 = 30 = 6*5
loss 240.465828435 accurate 0.31
▲ 8*5 = 0 = 5*8
8+1 = 9 = 3*3
▲ 9+4 = 11 = 4+9
0+9 = 9 = 9+0
7+7 = 14 = 7*2
▲ 2+8 = 12 = 2*5
loss 214.708754739 accurate 0.35
▲ 3*2 = 9 = 0+6
▲ 9+6 = 13 = 8+7
▲ 4+5 = 10 = 5+4
7*1 = 7 = 1+6
9+0 = 9 = 0+9
▲ 0+7 = 9 = 4+3
2+2 = 4 = 0+4
▲ 5*2 = 15 = 8+2
loss 166.554507551 accurate 0.53
8*1 = 8 = 6+2
▲ 1+3 = 6 = 0+4
2*3 = 6 = 6*1
0+8 = 8 = 2+6
▲ 5+3 = 7 = 3+5
loss 180.613642615 accurate 0.48
2+8 = 10 = 3+7
2*7 = 14 = 7+7
▲ 4+8 = 10 = 2*6
8*1 = 8 = 2*4
▲ 5+7 = 10 = 9+3
7+6 = 13 = 4+9
▲ 4+4 = 12 = 6+2
1+7 = 8 = 5+3
loss 124.338632597 accurate 0.61
5*2 = 10 = 9+1
▲ 3*9 = 18 = 9*3
5*1 = 5 = 2+3
3*2 = 6 = 2+4
▲ 6+7 = 14 = 4+9
2+6 = 8 = 6+2
7*1 = 7 = 1*7
1+8 = 9 = 9+0
loss 130.908571795 accurate 0.6
1+8 = 9 = 0+9
7+7 = 14 = 2*7
3*4 = 12 = 6*2
▲ 8+8 = 14 = 9+7
▲ 9+3 = 11 = 3+9
9*4 = 36 = 6*6
1*5 = 5 = 5+0
7+7 = 14 = 6+8
▲ 3+2 = 6 = 5+0
loss 121.06500636 accurate 0.64
▲ 1+0 = 2 = 1*1
▲ 1+5 = 9 = 0+6
▲ 5+8 = 10 = 6+7
5*7 = 35 = 7*5
2+6 = 8 = 4+4
loss 116.942061755 accurate 0.69
8+4 = 12 = 6*2
2*7 = 14 = 5+9
7*3 = 21 = 3*7
▲ 8+5 = 10 = 5+8
▲ 8+7 = 17 = 3*5
5+4 = 9 = 6+3
▲ 7+4 = 13 = 4+7
4*5 = 20 = 5*4
▲ 3+6 = 8 = 9*1
loss 104.598396208 accurate 0.69
1*9 = 9 = 4+5
4*6 = 24 = 8*3
3*9 = 27 = 9*3
0+9 = 9 = 4+5
2*5 = 10 = 5+5
3+3 = 6 = 6*1
1+9 = 10 = 6+4
loss 92.0016268954 accurate 0.69
▲ 5+9 = 12 = 6+8
1*5 = 5 = 0+5
8*4 = 32 = 4*8
1+5 = 6 = 3+3
▲ 2+7 = 10 = 0+9
5*2 = 10 = 5+5
2*1 = 2 = 0+2
loss 78.1875240455 accurate 0.79
▲ 8+3 = 10 = 9+2
7*2 = 14 = 5+9
2*9 = 18 = 3*6
4+7 = 11 = 6+5
8*9 = 72 = 9*8
2+2 = 4 = 0+4
▲ 2+1 = 4 = 3+0
3*2 = 6 = 4+2
loss 75.3174798711 accurate 0.78
8+1 = 9 = 7+2
0+4 = 4 = 4*1
8+1 = 9 = 3*3
8*2 = 16 = 9+7
1*6 = 6 = 4+2
0+8 = 8 = 2+6
9*7 = 63 = 7*9
9*1 = 9 = 6+3
loss 57.52815645 accurate 0.85
▲ 7+2 = 8 = 1+8
▲ 2+4 = 8 = 5+1
2+4 = 6 = 1+5
▲ 8+7 = 16 = 3*5
8*7 = 56 = 7*8
2*4 = 8 = 4+4
loss 56.4752816258 accurate 0.84
9*1 = 9 = 6+3
▲ 1+5 = 7 = 6*1
7*2 = 14 = 7+7
3+0 = 3 = 0+3
▲ 6+4 = 9 = 8+2
7*2 = 14 = 8+6
5+5 = 10 = 7+3
loss 54.1888957516 accurate 0.84
▲ 8+6 = 15 = 5+9
6*1 = 6 = 3*2
3*3 = 9 = 9*1
9+3 = 12 = 5+7
6+9 = 15 = 7+8
5*6 = 30 = 6*5
loss 42.4718404189 accurate 0.9
4+4 = 8 = 3+5
7+1 = 8 = 5+3
1*1 = 1 = 1+0
2*3 = 6 = 4+2
5*2 = 10 = 7+3
9*2 = 18 = 2*9
9+5 = 14 = 6+8
1*1 = 1 = 0+1
2*5 = 10 = 9+1
loss 39.666808875 accurate 0.9
4+4 = 8 = 6+2
6*1 = 6 = 3*2
3+3 = 6 = 2+4
6*8 = 48 = 8*6
3*5 = 15 = 6+9
4*2 = 8 = 6+2
2+5 = 7 = 1*7
6+0 = 6 = 3*2
6+2 = 8 = 2+6
loss 35.5719798488 accurate 0.93
4*5 = 20 = 5*4
4+8 = 12 = 3+9
6*6 = 36 = 4*9
1+7 = 8 = 4+4
2+5 = 7 = 4+3
4+9 = 13 = 7+6
1*5 = 5 = 0+5
loss 34.7921494932 accurate 0.92
8+1 = 9 = 6+3
9*2 = 18 = 2*9
7*8 = 56 = 8*7
7*3 = 21 = 3*7
7+5 = 12 = 8+4
6+6 = 12 = 3+9
loss 24.544334271 accurate 0.97
1+2 = 3 = 1*3
2*2 = 4 = 4+0
2+4 = 6 = 3+3
7+5 = 12 = 4*3
4*2 = 8 = 4+4
loss 22.6480830276 accurate 0.99
2+7 = 9 = 7+2
4*6 = 24 = 6*4
9*4 = 36 = 4*9
8+2 = 10 = 7+3
5+3 = 8 = 2*4
loss 17.4081512354 accurate 0.99
6+0 = 6 = 6*1
6*1 = 6 = 1+5
4+2 = 6 = 6+0
4+3 = 7 = 5+2
4+2 = 6 = 3+3
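
Each log line above has the form "expression = computed value = generated expression", so it can be re-checked mechanically. The following is a minimal sketch, assuming that the generated expression is judged against the true value of the left-hand expression (the post does not say which value was used). `evaluate` and `check_line` are hypothetical helpers, and `math.prod` requires Python 3.8 or later.

```python
import math


def evaluate(expr):
    """Evaluate digits joined by + and * with the usual precedence.

    Returns None when the string is not a well-formed expression.
    """
    try:
        return sum(
            math.prod(int(factor) for factor in term.split("*"))
            for term in expr.split("+")
        )
    except ValueError:
        return None


def check_line(line):
    """Split 'expr = predicted = generated' and test both halves against the truth."""
    parts = [p.strip() for p in line.lstrip("▲● ").split("=")]
    if len(parts) != 3:
        return {"well_formed": False, "calc_ok": False, "generation_ok": False}
    expr, predicted, generated = parts
    truth = evaluate(expr)
    return {
        "well_formed": truth is not None,
        "calc_ok": truth is not None and predicted == str(truth),
        "generation_ok": truth is not None and evaluate(generated) == truth,
    }


for line in ["7+5 = 12 = 6+6", "▲ 6+9 = 8 = 9+6", "● +**"]:
    print(line, "->", check_line(line))
```

The function reports the two checks separately instead of assigning ▲ or ●, so it does not depend on how the marks are defined.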

Addendum: with three numbers, training is also nearly complete.

4*3*6 = 72 = 3*4*6
8+0*1 = 8 = 1+6+1
1+5*9 = 46 = 4+7*6
5*6*9 = 270 = 5*9*6
5*5+9 = 34 = 8*4+2
6+2+9 = 17 = 8*1+9
8+8+2 = 18 = 9+1*9
2*5*6 = 60 = 6*9+6
▲ 2*5*2 = 16 = 8+8+4
9+5*4 = 29 = 5*4+9
2+9+0 = 11 = 1*8+3
6+4+6 = 16 = 7+1*9
3*3*5 = 45 = 5+5*8
1*9+0 = 9 = 2+1+6
4+3*0 = 4 = 0*4+4
3*7+5 = 26 = 8*3+2
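
The three-operand results rely on the usual operator precedence (for example 1+5*9 = 46), and for any given answer there are many valid expressions the decoder could produce. As a rough sketch of that target space, the code below enumerates every expression with single-digit operands and groups them by value. `all_expressions` and `group_by_value` are hypothetical helpers and have nothing to do with how the network was trained.

```python
import itertools
import math
from collections import defaultdict


def evaluate(expr):
    """Evaluate digits joined by + and *, multiplying before adding."""
    return sum(
        math.prod(int(factor) for factor in term.split("*"))
        for term in expr.split("+")
    )


def all_expressions(n_operands):
    """Yield every expression with n single-digit operands joined by + or *."""
    for digits in itertools.product("0123456789", repeat=n_operands):
        for ops in itertools.product("+*", repeat=n_operands - 1):
            # Pad ops with an empty string so the last digit has no operator after it.
            yield "".join(d + o for d, o in zip(digits, ops + ("",)))


def group_by_value(n_operands):
    """Map each possible answer to the list of expressions that produce it."""
    groups = defaultdict(list)
    for expr in all_expressions(n_operands):
        groups[evaluate(expr)].append(expr)
    return groups


groups = group_by_value(3)
print(len(groups[72]), "three-operand expressions evaluate to 72, e.g.", groups[72][:3])
```

Any member of `groups[72]` would count as a correct decoding of 72, which is why lines such as "4*3*6 = 72 = 3*4*6" are successes even though the two sides differ.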