らんだむな記憶

blogというものを体験してみようか!的なー

MNISTの手書き数字データの取得

[TensorFlow]

import tensorflow as tf
import tensorflow_datasets as tfds

dataset, metadata = tfds.load(
    'mnist',
    with_info=True,       # also return the DatasetInfo object (split sizes etc.)
    as_supervised=True,   # yield (image, label) tuples instead of feature dicts
)
train_dataset = dataset['train']
# Number of training examples; used as the shuffle buffer size further down.
num_train_examples = metadata.splits['train'].num_examples

def normalize(images, labels):
    """Scale image pixel values to float32 and divide them by 255.

    Intended for ``Dataset.map``; the label is returned unchanged.
    Needs ``tf`` bound at module level (``import tensorflow as tf``);
    importing ``tensorflow_datasets as tfds`` alone does not provide it.
    """
    # Cast first so the division below is done in floating point.
    # NOTE(review): assumes integer pixels in 0..255 (e.g. uint8), giving
    # outputs in [0, 1] -- confirm against the dataset's feature spec.
    images = tf.cast(images, tf.float32)
    images /= 255
    return images, labels

train_dataset = train_dataset.map(normalize)
# Shuffle *before* repeat so each epoch visits every example exactly once;
# repeat-then-shuffle lets examples from adjacent epochs mix in the buffer.
train_dataset = train_dataset.shuffle(num_train_examples).repeat().batch(64)

# Grab one batch for inspection (same pattern as next(iter(trainloader))
# in the PyTorch version).
images, labels = next(iter(train_dataset))
print(len(images), len(labels))
# Drop the last (channel) axis, then take the first image of the batch.
print(images[..., 0][0].shape, labels.shape)
print(labels)

[PyTorch]

import torch
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),                  # image -> float tensor, channel-first
    transforms.Normalize((0.5,), (0.5,)),   # (x - 0.5) / 0.5 per channel
])
trainset = datasets.MNIST(
    '~/.pytorch/MNIST_data/',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Pull a single shuffled batch to inspect its shapes and labels.
images, labels = next(iter(trainloader))
print(len(images), len(labels))
# First image of the batch, first (only) channel.
print(images[0][0].shape, labels.shape)
print(labels)

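「PyTorch(TorchVision)の便利なデータや変形機能をTensorFlow/Kerasでも使う」(Qiita)にあるように、TensorFlowとPyTorchではchannelの位置が違うのでややこしい・・・。TensorFlowはchannel-lastなのでバッチの形は(64, 28, 28, 1)となり`images[..., 0][0]`で、PyTorchはchannel-firstなので(64, 1, 28, 28)となり`images[0][0]`で、それぞれ1枚目の28×28画像を取り出している。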