らんだむな記憶

blogというものを体験してみようか!的なー

詳解ディープラーニング 第2版 (7)

6/layers/torch/Attention.py の実装を見てみたい。p.364 に

Attention 層は式 (6.12) (6.13) (6.14) で見たスコア関数のどれを用いるかによってパラメータが変わってきますが、ここでは式 (6.13) を用いてみることにしましょう。

とあるので、

\begin{align*}
g(\bm{h}_s, \bm{h}_t) = \bm{h}_t^T W_a \bm{h}_s
\end{align*}

が使われることになることに注意しよう。なお、ゼロつく 2 では (6.14) の式が使われていた。

以下の流れは基本的に p.360 をそのままコードに落とした形になっている。
コメントとして次元についても付記したが、16 の部分は、x における英語の文章のトークンの数なので、サンプルごとに値が変動する。

6/layers/torch/Attention.py#L33-L34

        # [16, 1, 128],[128, 128]->[16, 1, 128]
        score = torch.einsum('jik,kl->jil', (hs, self.W_a))
        # [1, 1, 128],[16, 1, 128]->[1, 1, 16]
        score = torch.einsum('jik,lik->jil', (ht, score))

は式 (6.10):

\begin{align*}
w(\tau, t) := g(\bm{h}_s(\tau), \bm{h}_t(t-1)) = \bm{h}_t^T(t-1) W_a \bm{h}_s(\tau)
\end{align*}

に対応する。

6/layers/torch/Attention.py#L36-L41

        # [1, 1, 16]->[1, 1, 16]
        score = score - torch.max(score, dim=-1, keepdim=True)[0]
        # [1, 1, 16]->[1, 1, 16]
        score = torch.exp(score)
        # [1, 1, 16]->[1, 1, 16]
        a = score / torch.sum(score, dim=-1, keepdim=True)

は p.365 のオーバーフロー対策を施した softmax の計算で式 (6.16)

\begin{align*}
a(\tau, t) = \mathrm{softmax}(g(\bm{h}_s(\tau), \bm{h}_t(t-1)))
\end{align*}

に対応する。なお、マスクの処理は割愛した。

6/layers/torch/Attention.py#L43

        # [1, 1, 16],[16, 1, 128])->[1, 1, 128]
        c = torch.einsum('jik,kil->jil', (a, hs))

は式 (6.17)

\begin{align*}
\bm{c}(t) = \sum_{\tau=1}^T a(\tau, t) \bm{h}_s(\tau)
\end{align*}

に対応し、k が数式中の $\tau$ に対応している。

6/layers/torch/Attention.py#L44-L45

        # [1, 1, 128],[1, 1, 128]->[1, 1, 256]
        h = torch.cat((c, ht), -1)
        # [1, 1, 256],[256, 128],[128])->[1, 1, 128]
        return torch.tanh(torch.einsum('jik,kl->jil', (h, self.W_c)) + self.b)

は式 (6.18)

\begin{align*}
\tilde{\bm{h}}_t(t) = \mathrm{tanh} \left( W_c \begin{pmatrix} \bm{c}(t) \\ \bm{h}_t(t) \end{pmatrix} + \bm{b} \right)
\end{align*}

に対応している。

そしてこうやって実装された Attention 層が、6/06_attention_torch.py#L103 でデコーダ内の LSTM の出力の後に適用されることになる。