DeZero で眺める自動微分 (3) - らんだむな記憶 の続き。
https://github.com/oreilly-japan/deep-learning-from-scratch-3/blob/master/steps/step33.py で高階微分を求めるという内容になるが、計算グラフを考えると意外と難しい。このサンプルでは難しいので、もっと簡単なケースでみる。
\begin{align*}
y = x^3
\end{align*}
を考えると、その導函数は
\begin{align*}
\frac{dy}{dx} = 3x^2
\end{align*}
である。つまり、backprop においての計算グラフは gx = Mul(3, Pow(2)(x))
の形のものが出て来ているはずである。
実際、DeZero
に構築させた順伝播の計算グラフは
であり、逆伝播の計算グラフは
である。高階微分、特にここでは 2 階微分を求めるにはこの計算グラフを作成する必要があり、そのために y.backward(create_graph=True)
で逆伝播でも計算グラフを作成させているのである。