らんだむな記憶

blogというものを体験してみようか!的なー

neuralnet_mnist.pyのTensorflow版

deep-learning-from-scratch/neuralnet_mnist.py at master · oreilly-japan/deep-learning-from-scratch · GitHubのch03/neuralnet_mnist.pyを学習を含めてTensorflowでやってみる。
一部必要に応じて少し書き換えた。

from __future__ import absolute_import, division, print_function, unicode_literals

import sys, os
import numpy as np
from dataset.mnist import load_mnist

import tensorflow as tf
print(tf.__version__)
tf.enable_eager_execution()

# 訓練データも返すように変更
def get_data():
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return x_train, t_train, x_test, t_test

def visualize_predictions(x_test, t_test):
    # Udacityのud187の関数を適当に改造して使う。
    from PIL import Image
    import matplotlib.pyplot as plt

    predictions = model.predict(x_test)
    test_images = [Image.fromarray(im.reshape(28, 28)) for im in x_test]

    num_rows = 10
    num_cols = 5
    num_images = num_rows*num_cols
    plt.figure(figsize=(2*2*num_cols, 2*num_rows))
    for i in range(num_images):
        plt.subplot(num_rows, 2*num_cols, 2*i+1)
        plot_image(i, predictions, t_test, test_images)
        plt.subplot(num_rows, 2*num_cols, 2*i+2)
        plot_value_array(i, predictions, t_test)

def main():
    # x_trainは既にshape=(60000, 784)のndarrayである。
    x_train, t_train, x_test, t_test = get_data()

    # オリジナルのpickleデータからすると中間層はそれぞれ50個と100個のニューロンからなるらしい。
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(50, activation=tf.nn.sigmoid, input_shape=(784,)),
        tf.keras.layers.Dense(100, activation=tf.nn.sigmoid),
        tf.keras.layers.Dense(10,  activation=tf.nn.softmax)
    ])

    model.compile(optimizer='adam', 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])

    # 10回くらい回せば損失関数が十分に収束していく。
    history = model.fit(x=x_train, y=t_train, epochs=10)

    # 大体97%の正確さ
    test_loss, test_accuracy = model.evaluate(x=x_test, y=t_test)
    print('Accuracy on test dataset:', test_accuracy)

    # オマケの可視化。
    visualize_predictions(x_test, t_test)

if __name__ == "__main__":
    main()