Attention を可視化したいなと思ったので、ゼロつく 2 の visualize_attention.py を流用してみようと思う。
そのために、6/layers/torch/Attention.py のほうも改造が必要である。
class Attention(nn.Module): def __init__(self, output_dim, hidden_dim, device='cpu'): ... # 追加 self.reset_attention_map() def forward(self, ht, hs, source=None): ... a = score / torch.sum(score, dim=-1, keepdim=True) # 追加 self.attention_map = torch.concat([self.attention_map, a.view(1, -1)], 0) if self.attention_map is not None else a.view(1, -1) ... # 追加 def reset_attention_map(self): self.attention_map = None
のように途中で計算された attention を保存するようにする。
ゼロつく 2 から流用する関数についてはラベルを簡潔する方法がパッと思いつかないので後回しにして、以下のようにする
import matplotlib.pyplot as plt def visualize(attention_map, row_labels=None, column_labels=None): fig, ax = plt.subplots() ax.pcolor(attention_map, cmap=plt.cm.Greys_r, vmin=0.0, vmax=1.0) ax.patch.set_facecolor('black') ax.set_yticks(np.arange(attention_map.shape[0])+0.5, minor=False) ax.set_xticks(np.arange(attention_map.shape[1])+0.5, minor=False) ax.invert_yaxis() if row_labels: # XXX: 後で考える ax.set_xticklabels(row_labels, minor=False) if column_labels: # XXX: 後で考える ax.set_yticklabels(column_labels, minor=False) plt.show()
そして 30 エポックくらいの訓練をテキストの通りに回して、
torch.save(model.state_dict(), '06_attention_torch.pth')
とかで保存しておく。これで色々ミスっても何度もやり直せる。そして、
model.load_state_dict(torch.load('06_attention_torch.pth')) model.eval() with torch.no_grad(): x, t = next(iter(test_dataloader)) model(x) source = x.view(-1).tolist() target = t.view(-1).tolist() source = ' '.join(en_vocab.decode(source)) target = ' '.join(ja_vocab.decode(target)) print('>', source) print('=', target) print(model.decoder.attn.attention_map.shape) visualize(model.decoder.attn.attention_map.cpu()) model.decoder.attn.reset_attention_map()
とかすると
という感じになったのだが・・・果たしてこれで正しいのだろうか・・・。
色々とスッキリしないところは残っているが、ここであまり悶々とせずにとりあえず先に進んで薄い理解を重ねていこう。ということで、p.376 まで完了とする。