らんだむな記憶

blogというものを体験してみようか!的なー

Repeat ノードの勾配

ゼロつく 2 p.28 を掘り下げる。

関数 $f: \R^m \to \R^n $ を成分で表示すると、$x = (x_1, \cdots, x_m)$ として

\begin{align*}
f(x) = (f_1(x_1, \cdots, x_m), \cdots, f_n(x_1, \cdots, x_m))
\end{align*}

と書ける。更に $g: \R^n \to \R^1$

\begin{align*}
g(y) = g(y_1, \cdots, y_n)
\end{align*}

がある時、合成$ g \circ f$ について

\begin{align*}
(g \circ f)(x) = g(f_1(x_1, \cdots, x_m), \cdots, f_n(x_1, \cdots, x_m)),
\end{align*}

と書ける。これを $x_j$ について $x = a$ で偏微分すると

\begin{align*}
\frac{\del (g \circ f)}{\del x_j}(a) = \sum_{k=1}^{m} \frac{\del g}{\del y_k}\Bigg|_{y=f(a)} \frac{\del f_k}{\del x_j}\Bigg|_{x=a}
\tag{1}
\end{align*}

となる。

例えば、$m=3, n=6$ として、

\begin{align*}
f(x_1, x_2, x_3) = \begin{pmatrix} x_1 & x_2 & x_3 \\ x_1 & x_2 & x_3 \end{pmatrix}
\end{align*}

と、

\begin{align*}
g \left(\begin{pmatrix}y_1 & y_2 & y_3 \\ y_4 & y_5 & y_6\end{pmatrix} \right)
\end{align*}

のケースを考えることにする。$g$ は一般には複数の関数の積 $g_\ell \circ \cdots \circ g_1$ を考えても良い。ここで $\R^6 \simeq \mathrm{Mat}(2, 3; \R)$ という同一視をしている。すると、

\begin{align*}
(g \circ f)(x_1, x_2, x_3) = g \left(\begin{pmatrix} x_1 & x_2 & x_3 \\ x_1 & x_2 & x_3 \end{pmatrix} \right)
\end{align*}

となる。簡単のため、$\alpha_k = \frac{\del g}{\del y_k}\Big|_{y=f(a)}$ と置いて、$x = a$ で偏微分すると (1) 式より

\begin{align*}
\frac{\del (g \circ f)}{\del x_j}(a) = \sum_{k=1}^{6} \alpha_k \frac{\del f_k}{\del x_j}\Bigg|_{x=a}
\end{align*}

となる。ところで、

\begin{align*}
f_1 = f_4, f_2 = f_5, f_3 = f_6
\tag{2}
\end{align*}

であり、また、

\begin{align*}
\frac{\del f_k}{\del x_j} = \delta_{kj}, \quad 1 \leq k, j \leq 3
\tag{3}
\end{align*}

であるので、実は上式は $\sum_{k=1}^{3} (\alpha_k + \alpha_{k+3}) \delta_{kj}$ となり、すべての $j$ について表示すると

\begin{align*}
\left( \frac{\del (g \circ f)}{\del x_1}(a), \frac{\del (g \circ f)}{\del x_2}(a), \frac{\del (g \circ f)}{\del x_3}(a) \right) = (\alpha_1 + \alpha_4, \alpha_2 + \alpha_5, \alpha_3 + \alpha_6)
\end{align*}

となる。これが Repeat 或は BroadcastTobackward の実装であり、繰り返しにより複製されたインデックスの箇所同士で上位の導関数の値を足し合わせることになる。

重要なのは (2) 式と (3) 式の関係性である。前にも何か引用?したような気もするが、記事が埋もれた気がするので計算してみた・・・。