chainerのwhereでbroadcastingやりたいけどtype_checkに弾かれる
numpyのbroadcastingは自分には難しくてあまり理解していなかったのだけど最近ようやく少しだけわかってきた。
Pythonによるデータ分析入門のP410からの説明、英語版だと嬉しい事にPython for Data Analysis - Free Download eBook - pdfより無料で読める。P362より。シンプルな動作なら英語版Figure 12-4の図だけを見ればだいたいわかると思います。この本は大半がpandasの使い方で占められているのだけどnumpyのところだけ切り取って加筆してもっと薄く安くして売って欲しかった。
以下はxpにcuda.cupyを入れているけどnumpyでも同じ動作。
where( condition , Trueだったら代入する行列 , Falseだったら代入する行列 )
xp = cuda.cupy #numpyでも同じ動作 A = xp.asarray([[1,1,1] , [2,2,2]] , dtype=np.float32) #=> shape(2,3) Zero = xp.zeros((2,3) , dtype=np.float32) #Aと同じ形で中身が0の行列を作る #Aの2行目(indexは1、つまり[2,2,2]のところ)を0にしたい cond = xp.asarray( [[True , True , True],[False ,False , False]] ) B = xp.where( cond , A , Zero ) #=>[[ 1. 1. 1.] [ 0. 0. 0.]]になる
ここでcondを以下のようにしても同じ動作になる。
cond = xp.asarray( [ [True] , [False] ] ) #shapeが(2,1) B = xp.where( cond , A , Zero ) #=>[[ 1. 1. 1.] [ 0. 0. 0.]] 上と同じ結果になる
上は全部(2,3)の行列だけど、下は(2,1)と(2,3)の行列が混じってる。
これでうまくいくのはnumpyのbroadcastingのおかげらしい。
本題:chainerのwhereの場合
同じような動作をするchainerのfunctionもある。入出力はもちろんVariableになる。
import chainer.functions as F def to_variable(asarray): return Variable(asarray) #AとZeroは上と同じ条件なので省略 cond = xp.asarray( [[True , True , True],[False ,False , False]] ) #全部Variableに変換 A = to_variable(A) ; Zero = to_variable(Zero) ;cond = to_variable(cond) B = F.where(valid , A , Zero ) #=>B.dataが[[ 1. 1. 1.] [ 0. 0. 0.]]になる
ところがcondを(2,1)にしてbroadcastingを試みるとこうなってしまう。
cond = xp.asarray( [ [True] , [False] ] ) #shapeが(2,1) B = xp.where( cond , A , Zero ) #以下のエラー #chainer.utils.type_check.InvalidType: #Invalid operation is performed in: Where (Forward) #Expect: in_types[1].shape == in_types[0].shape #Actual: (2, 3) != (2, 1)
type_checkさんがwhereの前に仕事をしてしまい、shape不一致で弾かれてしまう。ただ中のコードを読むとtype_checkさえスルーできればいけそうなので、
import os os.environ["CHAINER_TYPE_CHECK"] = "0" #type_checkしない ~省略~ cond = xp.asarray( [ [True] , [False] ] ) B = F.where(valid , A , Zero ) #=>B.dataが[[ 1. 1. 1.] [ 0. 0. 0.]]になる
としたらうまくいった。
ちなみにos.environ["CHAINER_TYPE_CHECK"] = "0" はtype_checkしないぶん学習も速くなります。自分はいつもオフにしてます。
F.whereはメモリをそこそこ食う
隠れ層同士の結合(例 512 * 512とか)にwhere使ってもそこまでサイズは大きくならないけれど、最終出力層がクラス分類で単語ボキャブラリー数なんかだと(512 * 60000)みたいになって前者より100倍サイズが大きいのでかなりメモリを食ってしまう。where使いだして異常にメモリを食うようになったら注意です。
おしまい。