6/layers/torch/Attention.py の実装を見てみたい。p.364 に
Attention 層は式 (6.12) (6.13) (6.14) で見たスコア関数のどれを用いるかによってパラメータが変わってきますが、ここでは式 (6.13) を用いてみることにしましょう。
とあるので、
\begin{align*}
g(\bm{h}_s, \bm{h}_t) = \bm{h}_t^T W_a \bm{h}_s
\end{align*}
が使われることになることに注意しよう。なお、ゼロつく 2 では (6.14) の式が使われていた。
以下の流れは基本的に p.360 をそのままコードに落とした形になっている。
コメントとして次元についても付記したが、16
の部分は、x
における英語の文章のトークンの数なので、サンプルごとに値が変動する。
6/layers/torch/Attention.py#L33-L34
# [16, 1, 128],[128, 128]->[16, 1, 128] score = torch.einsum('jik,kl->jil', (hs, self.W_a)) # [1, 1, 128],[16, 1, 128]->[1, 1, 16] score = torch.einsum('jik,lik->jil', (ht, score))
は式 (6.10):
\begin{align*}
w(\tau, t) := g(\bm{h}_s(\tau), \bm{h}_t(t-1)) = \bm{h}_t^T(t-1) W_a \bm{h}_s(\tau)
\end{align*}
に対応する。
6/layers/torch/Attention.py#L36-L41
# [1, 1, 16]->[1, 1, 16] score = score - torch.max(score, dim=-1, keepdim=True)[0] # [1, 1, 16]->[1, 1, 16] score = torch.exp(score) # [1, 1, 16]->[1, 1, 16] a = score / torch.sum(score, dim=-1, keepdim=True)
は p.365 のオーバーフロー対策を施した softmax
の計算で式 (6.16)
\begin{align*}
a(\tau, t) = \mathrm{softmax}(g(\bm{h}_s(\tau), \bm{h}_t(t-1)))
\end{align*}
に対応する。なお、マスクの処理は割愛した。
6/layers/torch/Attention.py#L43
# [1, 1, 16],[16, 1, 128])->[1, 1, 128] c = torch.einsum('jik,kil->jil', (a, hs))
は式 (6.17)
\begin{align*}
\bm{c}(t) = \sum_{\tau=1}^T a(\tau, t) \bm{h}_s(\tau)
\end{align*}
に対応し、k
が数式中の $\tau$ に対応している。
6/layers/torch/Attention.py#L44-L45
# [1, 1, 128],[1, 1, 128]->[1, 1, 256] h = torch.cat((c, ht), -1) # [1, 1, 256],[256, 128],[128])->[1, 1, 128] return torch.tanh(torch.einsum('jik,kl->jil', (h, self.W_c)) + self.b)
は式 (6.18)
\begin{align*}
\tilde{\bm{h}}_t(t) = \mathrm{tanh} \left( W_c \begin{pmatrix} \bm{c}(t) \\ \bm{h}_t(t) \end{pmatrix} + \bm{b} \right)
\end{align*}
に対応している。
そしてこうやって実装された Attention 層が、6/06_attention_torch.py#L103 でデコーダ内の LSTM の出力の後に適用されることになる。