らんだむな記憶

blogというものを体験してみようか!的なー

勾配クリッピング

DeZero に勾配クリッピングを実装する場合どうすれば良いのだろう?と思ったので PyTorchAPI を見てみる。

torch.nn.utils.clip_grad_norm_ — PyTorch 1.10.0 documentation
なるほど、loss.backward() 後にクリッピングするのか。

https://github.com/pytorch/pytorch/blob/v1.10.0/torch/nn/utils/clip_grad.py#L9-L56 を見ると勾配を持っているパラメータをかき集めている。なるほど。

https://github.com/pytorch/pytorch/blob/v1.10.0/torch/nn/modules/module.py#L1499-L1521 に相当するメソッドも基底クラスに実装する必要があると。これについては https://github.com/oreilly-japan/deep-learning-from-scratch-3/blob/06419d7fb2e7ea19aa3719efc27795edbdc41a1f/dezero/layers.py#L33-L40実装済みのようだ。PyTorch の場合、活性関数も nn.Module を継承しているので、DeZero の実装だとカバーしている範囲が狭そうではあるが、簡易実装としては十分に使えそうに思える。