らんだむな記憶

blogというものを体験してみようか!的なー

ミニバッチ

何かニューラルネットワーク $f(\cdot; w)$ があるとする。訓練データを $(x_j,t_j),\ j \in J$ とする。
後の都合のため、 $(x_j,t_j) \ in X \times Y$ とし、 $w \in W$ とする。 $f: X \to Y$ である。

入力 $x_j$ に対する出力 $y_j = f(x_j; w)$ と正解データを $t_j$ に関する損失関数 $L(w)$ を(簡単のため MSE として)
\begin{align}
L(w) = \sum_{j \in J} \left\| t_j - f(x_j; w)\right\|^2
\end{align}

とする。 $w= w_0$ から始めて、 $w_{k+1} = w_k - \varepsilon\;\mathrm{grad}\left(L(w_k)\right)$ という漸化式で $L(w)$ の極小値 $w_\infty$ を求めましょうというのが勾配法であった。

ところがこれだと計算量が大きいということで、添字集合 $J$ を適当に分割して $J = J_0 \coprod J_2 \coprod \cdots \coprod J_N$ というものを考えましょうと。次に部分和的な損失関数
\begin{align}
L_n(w) = \sum_{j \in J_n} \left\| t_j - f(x_j; w)\right\|^2, \quad n \in \{0,1\cdots,N\}
\end{align}

を考える。 $w = w_0$ から始めて、 $w_{k+1} = w_k - \varepsilon\;\mathrm{grad}\left(L_k(w_k)\right)$ という漸化式でパラメータを更新していけば完全な損失関数 $L(w)$ の最小値を与える $w_\infty$ が求まるんじゃないですかね?というのがミニバッチであった。

要は、 $J$ 全体の訓練データにフィッティングするように全体的に“係数” $w$ を更新していくのではなく、個々の $J_k$ に限定して部分的にフィッティングするように $L(w)$ の“部分和” $L_k(w)$ を使って $w$ を更新しましょうというお話。

$f$ を $X \times Y$ 空間上の関数として可視化したとすれば、 $L$ に対するフルの更新(フルバッチ)では $\{x_j;\ j \in J\}$ 全体にわたって改善される様子が見えることだろう。 $L_k$ に対するミニバッチでの更新では $\{x_j;\ j \in J_k\}$ にわたっては改善されるが、他のデータについては一部悪化するところもあるだろう。イテレーションを繰り返すうちにフルバッチでは全体的に徐々にフィッティングしていくのに対し、ミニバッチでは断片的にガタガタとフィッティングしていきながらももぐら叩きのように最後は全部概ねフィッティングしている感じに動いていく様子が見えることだろう。

・・・という備忘録。深層学習 | 書籍情報 | 株式会社 講談社サイエンティフィクの pp.25-27 では勿論もっと真面目に書かれている。