らんだむな記憶

blogというものを体験してみようか!的なー

nn.Conv2d の動きを探る

たぶん nn.Conv2d の処理は “2 次元” とは言いつつも、チャネル方向に幅をもったカーネルをデータに畳み込む、つまりテンソルとして要素ごとの積をとって得たテンソルの和をとる処理をしていると思われる。これを検証しよう。

tensor = torch.tensor(...) # torch.Size([1, 3, 4, 4]) のテンソル
conv = nn.Conv2d(3, 1, kernel_size = 3, stride=1, padding=0, bias=False)
output = conv(tensor)

tensor_ul = tensor[:,:,:3,:3] # 4x4 テンソルの左上 3x3 部分
tensor_ur = tensor[:,:,:3,1:] # 4x4 テンソルの右上 3x3 部分
tensor_ll = tensor[:,:,1:,:3] # 4x4 テンソルの左下 3x3 部分
tensor_lr = tensor[:,:,1:,1:] # 4x4 テンソルの右下 3x3 部分
kernel = conv.weight
expected = torch.tensor([[[
    [torch.sum(kernel * tensor_ul), torch.sum(kernel * tensor_ur)],
    [torch.sum(kernel * tensor_ll), torch.sum(kernel * tensor_lr)]
]]])
print(torch.equal(output, expected))

で結果は一致する。3 次元の畳み込みのような気持ちになるのでクラス名に違和感はあるのだが仕方ない・・・。