らんだむな記憶

blogというものを体験してみようか!的なー

詳解ディープラーニング 第2版 (4)

x,t = next(iter(train_dataloader))
print('t of train_dataloader:', t.is_contiguous())
x,t = next(iter(val_dataloader))
print('t of val_dataloader:', t.is_contiguous())
x,t = next(iter(test_dataloader))
print('t of test_dataloader:', t.is_contiguous())
t of train_dataloader: False
t of val_dataloader: False
t of test_dataloader: True

によって、実際に前回で触れたように、

train_dataloaderval_dataloader から得た t がメモリ上で不連続

が確認できた。

train_dataloader2 = DataLoader((x_train, t_train),
                              batch_first=False,
                              device=device)
x,t = next(iter(train_dataloader2))
print('t of train_dataloader2:', t.is_contiguous())
train_dataloader3 = DataLoader((x_train, t_train),
                              batch_size=1,
                              batch_first=False,
                              device=device)
x,t = next(iter(train_dataloader3))
print('t of train_dataloader3:', t.is_contiguous())
t of train_dataloader2: False
t of train_dataloader3: True

とすれば分かるように batch_size=1 で変化している。

6/utils/torch/DataLoader.py#L41-L42 でパディングを追加する時に、パディングの実体が不連続なメモリ領域に乗っていると思われる。torch.t() も不連続になるようなので、__next__ の最後で .contiguous() を呼んで連続性を保証するようにしてみた。

diff --git a/6/utils/torch/DataLoader.py b/6/utils/torch/DataLoader.py
index 25da8fb..3a16d93 100644
--- a/6/utils/torch/DataLoader.py
+++ b/6/utils/torch/DataLoader.py
@@ -48,6 +48,8 @@ class DataLoader(object):
             x = x.t()
             t = t.t()
 
+        t = t.contiguous()
+
         self._idx += self.batch_size
 
         return x.to(self.device), t.to(self.device)

p.355 まで完了ということで。