らんだむな記憶

blogというものを体験してみようか!的なー

VAEの再構成確認

import matplotlib.pyplot as plt

def show_reconstruction(test_loader):
    fig = plt.figure(figsize=(15, 10))
    
    vae.eval()

    x, t = next(iter(test_loader))
    x, t = x.to("cuda"), t.to("cuda")
    
    z_mean, z_log_var = vae.encoder(x)
    z = vae.encoder.sampling(z_mean, z_log_var)
    y = vae.decoder(z)
    
    row = 5
    col = 6
    
    for i, (im, im2) in enumerate(zip(x.view(-1, 28, 28).cpu().detach().numpy()[:6], y.view(-1, 28, 28).cpu().detach().numpy()[:6])):
        plt.subplot(row, col, i+1)
        plt.imshow(im, cmap="gray")
        plt.axis("off")
        plt.subplot(row, col, i+1+col)
        plt.imshow(im2, cmap="gray")
        plt.axis("off")
    plt.show()
        
    vae.train()

くらいで確認できるかなぁ。