DeZero に勾配クリッピングを実装する場合どうすれば良いのだろう?と思ったので PyTorch
の API を見てみる。
torch.nn.utils.clip_grad_norm_ — PyTorch 1.10.0 documentation
なるほど、loss.backward()
後にクリッピングするのか。
https://github.com/pytorch/pytorch/blob/v1.10.0/torch/nn/utils/clip_grad.py#L9-L56 を見ると勾配を持っているパラメータをかき集めている。なるほど。
https://github.com/pytorch/pytorch/blob/v1.10.0/torch/nn/modules/module.py#L1499-L1521 に相当するメソッドも基底クラスに実装する必要があると。これについては https://github.com/oreilly-japan/deep-learning-from-scratch-3/blob/06419d7fb2e7ea19aa3719efc27795edbdc41a1f/dezero/layers.py#L33-L40 で実装済みのようだ。PyTorch の場合、活性関数も nn.Module
を継承しているので、DeZero の実装だとカバーしている範囲が狭そうではあるが、簡易実装としては十分に使えそうに思える。