import matplotlib.pyplot as plt


def show_reconstruction(test_loader, model=None, device="cuda", n_images=6):
    """Plot the first few test images above their VAE reconstructions.

    Takes the first batch from ``test_loader``, encodes it with
    ``model.encoder``, samples a latent vector via
    ``model.encoder.sampling``, decodes it with ``model.decoder`` and shows
    the first ``n_images`` inputs (top row) and reconstructions (bottom row)
    as grayscale images.

    Args:
        test_loader: Iterable yielding ``(images, labels)`` batches; only the
            images of the first batch are used.
        model: VAE-like object exposing ``encoder`` (returning
            ``(z_mean, z_log_var)`` and providing ``sampling``) and
            ``decoder``. Defaults to the module-level ``vae`` so existing
            calls keep working.
        device: Device the input batch is moved to before encoding; should
            match the model's device. Defaults to ``"cuda"`` as before.
        n_images: Number of input/reconstruction pairs to display.

    Raises:
        ValueError: If ``n_images`` is less than 1.

    If the model was in training mode on entry, it is switched back to
    training mode afterwards, even when plotting raises.
    """
    import torch  # local import: torch is not imported in the visible part of this file

    if n_images < 1:
        raise ValueError("n_images must be at least 1")
    if model is None:
        model = vae  # noqa: F821 -- module-level model defined elsewhere in the file

    # Remember the caller's mode so we only re-enable training if it was on.
    # Falls back to True (the original always called .train()) when the
    # object has no `training` attribute.
    was_training = getattr(model, "training", True)
    model.eval()
    try:
        # Visualisation needs no gradients; skipping autograd saves memory.
        with torch.no_grad():
            x, _ = next(iter(test_loader))
            x = x.to(device)
            z_mean, z_log_var = model.encoder(x)
            z = model.encoder.sampling(z_mean, z_log_var)
            y = model.decoder(z)

        # Assumes 28x28 single-channel images (e.g. MNIST) -- TODO confirm
        # for other datasets. Slice before .cpu() so only the shown images
        # are copied off the device.
        originals = x.reshape(-1, 28, 28)[:n_images].cpu().numpy()
        reconstructions = y.reshape(-1, 28, 28)[:n_images].cpu().numpy()

        rows, cols = 2, n_images  # top row: inputs, bottom row: reconstructions
        plt.figure(figsize=(2.5 * cols, 5))
        for i, (original, reconstruction) in enumerate(zip(originals, reconstructions)):
            plt.subplot(rows, cols, i + 1)
            plt.imshow(original, cmap="gray")
            plt.axis("off")
            plt.subplot(rows, cols, i + 1 + cols)
            plt.imshow(reconstruction, cmap="gray")
            plt.axis("off")
        plt.show()
    finally:
        if was_training:
            model.train()
このくらいのコードで再構成結果を確認できるかなぁ。