x,t = next(iter(train_dataloader)) print('t of train_dataloader:', t.is_contiguous()) x,t = next(iter(val_dataloader)) print('t of val_dataloader:', t.is_contiguous()) x,t = next(iter(test_dataloader)) print('t of test_dataloader:', t.is_contiguous())
t of train_dataloader: False t of val_dataloader: False t of test_dataloader: True
によって、実際に前回で触れたように、
train_dataloader
とval_dataloader
から得たt
がメモリ上で不連続
が確認できた。
train_dataloader2 = DataLoader((x_train, t_train), batch_first=False, device=device) x,t = next(iter(train_dataloader2)) print('t of train_dataloader2:', t.is_contiguous()) train_dataloader3 = DataLoader((x_train, t_train), batch_size=1, batch_first=False, device=device) x,t = next(iter(train_dataloader3)) print('t of train_dataloader3:', t.is_contiguous())
t of train_dataloader2: False t of train_dataloader3: True
とすれば分かるように batch_size=1
で変化している。
6/utils/torch/DataLoader.py#L41-L42 でパディングを追加する時に、パディングの実体が不連続なメモリ領域に乗っていると思われる。torch.t()
も不連続になるようなので、__next__
の最後で .contiguous()
を呼んで連続性を保証するようにしてみた。
diff --git a/6/utils/torch/DataLoader.py b/6/utils/torch/DataLoader.py index 25da8fb..3a16d93 100644 --- a/6/utils/torch/DataLoader.py +++ b/6/utils/torch/DataLoader.py @@ -48,6 +48,8 @@ class DataLoader(object): x = x.t() t = t.t() + t = t.contiguous() + self._idx += self.batch_size return x.to(self.device), t.to(self.device)
p.355 まで完了ということで。