らんだむな記憶

blogというものを体験してみようか!的なー

RNN (3)

何回めになるか分からないけど、RNN にまたチャレンジ。https://github.com/oreilly-japan/deep-learning-from-scratch-3/blob/master/steps/step59.pyによる正弦波の予測を真似る。

class SimpleRNN(nn.Module):
    def __init__(self, hidden_size, out_size):
        super(SimpleRNN, self).__init__()
        self.rnn = nn.RNN(out_size, hidden_size, nonlinearity="tanh")
        self.fc = nn.Linear(hidden_size, out_size)
        self.reset_state()

    def reset_state(self):
        self.hidden = None

    def __call__(self, x):
        h, self.hidden = self.rnn(x, self.hidden)
        y = self.fc(h)
        return y

のような感じで作って学習で使う。学習データは

def to_torch(x):
    return torch.from_numpy(x.astype(np.float32)).clone()

のような感じの処理をかましPyTorch の意味でのテンソルに変換しておく。また、入力データが 1 つしかないなら長さ 1 のバッチにしないと怒られることになるので、 x.unsqueeze_(dim=0) とかで次元を増やしてバッチの体裁にする。しかしこれでも問題が起きて、2 回めの loss.backward() でエラーが出る。

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

さっぱり意味が分からないので適当に gg るだけ・・・。[Solved] Training a simple RNN - autograd - PyTorch Forumsという感じのページが見つかる。

The solution is to save hidden.detach()

というのが結論らしい。勾配の計算が終わっても hidden が計算グラフにぶら下がりっぱなしの状態になってしまうので、一旦勾配を計算したなら hidden を計算グラフからデタッチしなさいということのようだ。PyTorch のコードとしてはhttps://github.com/pytorch/pytorch/blob/v1.5.1/torch/tensor.py#L267-L268が該当する。これを使って上記の SimpleRNN クラスに

    def unchain(self):
        if self.hidden is not None:
            self.hidden.detach_()

を追加して、勾配計算の際に

model.zero_grad()
loss.backward()
model.unchain()
loss.detach_()
optimizer.step()

とすればエラーは出なくなったし、それっぽい結果が得られた。正直よく分からんことはまだまだあるが一歩前に進めただろうか?

フォーラムにある

I must have misunderstood how RNN’s work in pytorch as they were pretty much “plug and play” in keras and you didnt have to hold on to the hidden state.

がまさにその通りでud187のLesson 8 - らんだむな記憶に書いたように Keras を使うと全部裏でよろしくやってくれるので SimpleRNN がどういうものであるか?とか、隠れ状態とは?といったことを知ることもなく動いて、そして求める結果が得られる。一体どういうことが裏にあるのだろう?とソースコードを読むと、その便利機能が盛りだくさんなためか全然本質が読み取れないという・・・。

*****

DeZero の勉強をしつつ再度この記事を読むと

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

は本当に loss.backward(retain_graph=True) しなさいということだな。毎回 detach_ してたら RNN の恩恵がないな・・・。計算グラフを再度使うのであればそれを維持 (retain) するようにしなさいと。O'Reilly Japan - ゼロから作るDeep Learning ❸ p.497 で言う「メモリ使用の効率化」によりグラフに関わるデータが破棄されているのだろう。DeZero では実質 retain_graph=True で動作しているのだと思う。