[TensorFlow]
# Load MNIST with TensorFlow Datasets and inspect one shuffled batch of 64.
# `tf` was used below (tf.cast / tf.float32) but never imported, which made
# the script fail with NameError; import it explicitly.
import tensorflow as tf
import tensorflow_datasets as tfds

# as_supervised=True yields (image, label) tuples; with_info=True also returns
# the DatasetInfo used below to read the size of the training split.
dataset, metadata = tfds.load('mnist', as_supervised=True, with_info=True)
train_dataset = dataset['train']
num_train_examples = metadata.splits['train'].num_examples


def normalize(images, labels):
    """Cast images to float32 and divide by 255; labels pass through unchanged.

    MNIST pixels from tfds are uint8, so the result lies in [0, 1].
    """
    images = tf.cast(images, tf.float32)
    images /= 255
    return images, labels


train_dataset = train_dataset.map(normalize)
# Repeat indefinitely, shuffle with a buffer the size of the whole training
# split (a full shuffle), then group into batches of 64.
train_dataset = train_dataset.repeat().shuffle(num_train_examples).batch(64)

# Take exactly one batch for inspection.
images, labels = next(iter(train_dataset.take(1)))

print(len(images), len(labels))
# TensorFlow is channels-last: a batch has shape (64, 28, 28, 1), so [..., 0]
# drops the trailing channel axis before [0] selects the first image.
print(images[..., 0][0].shape, labels.shape)
print(labels)
[PyTorch]
# Load MNIST through torchvision and inspect one shuffled batch of 64.
import torch
from torchvision import datasets, transforms

# ToTensor() turns each image into a float tensor of shape (1, 28, 28) with
# values in [0, 1]; Normalize with mean 0.5 / std 0.5 then rescales to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Training split, downloaded on first use into the given directory.
trainset = datasets.MNIST(
    root='~/.pytorch/MNIST_data/',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    shuffle=True,
    batch_size=64,
)

# Pull only the first batch out of the loader.
for images, labels in trainloader:
    break

print(len(images), len(labels))
# PyTorch is channels-first: a batch has shape (64, 1, 28, 28), so [0] picks
# the first image and the next [0] its single channel.
print(images[0][0].shape, labels.shape)
print(labels)
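「PyTorch(TorchVision)の便利なデータや変形機能をTensorFlow/Kerasでも使う - Qiita」にあるように、TensorFlowとPyTorchではチャンネル軸の位置が異なるのでややこしい。TensorFlowはchannels-last（バッチの形状は (64, 28, 28, 1)）なので、1枚目の画像を取り出すには `images[..., 0][0]` とする。一方、PyTorchはchannels-first（形状は (64, 1, 28, 28)）なので `images[0][0]` とする。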