studylog/北の雲

chainer/python/nlp

Chainerのcleargradsと旧zerograds

1.15.0よりzerogradsが非推奨になりcleargradsというものが導入されたらしい。

github.com

0埋めはメモリいっぱい使うし意味が無いからNone埋めにする!みたいな感じ。
変更されたコードはこちら。上がclearで下がzero。

    def cleargrad(self):
        """Clears the gradient array."""
        self._grad = None

    def zerograd(self):
        """Initializes the gradient array by zeros.

        .. deprecated:: v1.15
           Use :meth:`cleargrad` instead.

        """
        warnings.warn(
            'Variable.zerograd is deprecated. Use Variable.cleargard instead.',
            DeprecationWarning)
        with cuda.get_device(self.data) as dev:
            if self._grad is None:
                xp = numpy if int(dev) == -1 else cuda.cupy
                self._grad = xp.zeros_like(self.data)
            else:
                self._grad.fill(0)

確かに0ではなくNoneを入れてる。

早速使ってみるとエラーが出る。
エラー箇所を見てみると、どうも0(あるいは数値)が入ってる事を想定されたコードなのに、cleargradsのせいでNoneが入っててエラーになってるよう。

https://github.com/pfnet/chainer/blob/master/chainer/optimizer.py#L487

    def __call__(self, opt):
        if cuda.available:
            kernel = cuda.elementwise(
                'T p, T decay', 'T g', 'g += decay * p', 'weight_decay')

        rate = self.rate
        for param in opt.target.params():
            p, g = param.data, param.grad #このparam.gradがNoneになってる
            with cuda.get_device(p) as dev:
                if int(dev) == -1:
                    g += rate * p #だからここで落ちる(CPUの場合)
                else:
                    kernel(p, rate, g) #GPUだとこっちで落ちる

https://github.com/pfnet/chainer/blob/master/chainer/optimizers/adam.py#L44

    def update_one_gpu(self, param, state):
        cuda.elementwise(
            'T grad, T lr, T one_minus_beta1, T one_minus_beta2, T eps',
            'T param, T m, T v',
            '''m += one_minus_beta1 * (grad - m);
               v += one_minus_beta2 * (grad * grad - v);
               param -= lr * m / (sqrt(v) + eps);''',
            'adam')(param.grad, self.lr, 1 - self.beta1, 1 - self.beta2, #このparam.gradがNone
                    self.eps, param.data, state['m'], state['v'])

他にもいっぱいある。
1.8.0から一気に1.16.0に上げたので、その間に書き方が色々と変わってて自分の書いたコードで何か必要なのが抜け落ちているのかもしれない。trainerは使ってないです。NstepLSTM(cuDNN RNN)を使っています。

よくわからないので今はとりあえずzerogradsのまま動かしています。

この方も関連する感じ。
hiho-developer.hatenablog.com

1.15.0リリースからはだいぶ経ってるのにエラー報告が全く無いという事はやはりこっち側が原因なのかな。
わからない。