import mnist_manipulator
import model_figure5a
import torch
import matplotlib.pyplot as plt
import numpy as np
import glob


# Produce Figure 5(a):

# Dataset indices of the source digits, one per figure row. They were hand-picked
# to show a variety of handwriting styles; any other images in the dataset work
# just as well.
image_indices = [163, 134, 18, 168, 219, 60]

# Everything in this script runs on the CPU.
_cpu = torch.device('cpu')

# load_mnist_dataset returns six values; only the first (the images) is used here.
# NOTE(review): the 64 is passed straight through — presumably a batch size; confirm.
images, _, _, _, _, _ = mnist_manipulator.load_mnist_dataset(64, _cpu)


def _restore(module, checkpoint_path):
    """Load the state dict stored at *checkpoint_path* (mapped to CPU) into *module* and return it."""
    # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
    state = torch.load(checkpoint_path, map_location=_cpu)
    module.load_state_dict(state)
    return module


# Build the trained networks and restore their weights.
putter = _restore(
    model_figure5a.Putter(model_figure5a.Encoder(), model_figure5a.Decoder()),
    "outputs/models/figure5a_putter.pt",
)
getter = _restore(
    model_figure5a.Getter(model_figure5a.Encoder()),
    "outputs/models/figure5a_getter.pt",
)

# Build the image grid for Figure 5(a): one row per entry of ``image_indices``,
# holding the original digit followed by the putter's output for each label 0-9.

# Inference mode: without this, any dropout / batch-norm layers in Putter (the
# architecture is not visible from this script) would run in training mode.
putter.eval()

arrimages = []
with torch.no_grad():
    for idx in image_indices:
        # images is indexed as 4-D here; assumes (N, C, H, W) — TODO confirm.
        # images[[idx]] keeps the batch dimension, giving a batch of one.
        image = images[[idx], :, :, :]
        row = [images[idx, 0, :, :].numpy()]
        # For each label, put that label onto the image.
        for label in range(10):
            # One-hot label vector of shape (1, 10).
            label_vector = torch.nn.functional.one_hot(torch.tensor([label]), 10)
            # Call the module (not .forward) so any registered hooks still run.
            put_label = putter(image, label_vector)
            row.append(put_label[0, 0, :, :].numpy())
        arrimages.append(row)
# Axis 0 = figure row, axis 1 = column (0 is the original, 1..10 are labels 0-9).
# Assumes every putter output has the same spatial size as the input image.
arrimages = np.array(arrimages)

# Lay out Figure 5(a). Column 0 is the original image, column 1 is a narrow
# spacer, and columns 2.. show the putter output for labels 0, 1, ...
# The grid size follows arrimages, so changing image_indices just works.
n_rows = arrimages.shape[0]
n_labels = arrimages.shape[1] - 1


def _show_tile(ax, img):
    """Draw one 2-D image *img* on *ax* with the shared Figure 5(a) styling and hide the axis."""
    # +0.07 shifts every pixel so the zero-valued background renders light grey
    # rather than pure white under 'Greys'. [:, 4:-2] trims 4 columns on the
    # left and 2 on the right.
    ax.imshow(0.07 + img[:, 4:-2], aspect='equal', vmin=0.0, vmax=1.0,
              cmap='Greys', interpolation='nearest')
    # axis('off') also hides ticks and tick labels, so no separate tick calls are needed.
    ax.axis('off')


fig, axes = plt.subplots(
    n_rows,
    n_labels + 2,
    width_ratios=[21/28, 0.05] + [21/28] * n_labels,
    # Height scales with the row count; the original 6-row layout is (8, 5).
    figsize=(8, 5 * n_rows / 6),
    # Always return a 2-D array of axes, even for a single row.
    squeeze=False,
)
axes[0, 0].set_title('Original')
for label in range(n_labels):
    axes[0, label + 2].set_title(f'{label}')
for row in range(n_rows):
    _show_tile(axes[row, 0], arrimages[row, 0])
    axes[row, 1].axis('off')  # spacer column between the original and the outputs
    for label in range(n_labels):
        _show_tile(axes[row, label + 2], arrimages[row, label + 1])
fig.savefig('figure5a.pdf', dpi=300)


# Produce Figure 5 (b):

# Load the saved accuracy curves from outputs/data. If you generate your own data
# with run_decay_test.py, name your runs putter_..., transfer_... and vae_... so
# that this script picks them up.


def _load_runs(pattern):
    """Return every ``.npy`` file matching *pattern*, stacked with one row per run.

    Raises:
        FileNotFoundError: if no file matches *pattern*. Otherwise the mean/std
            below would silently become NaN and produce an empty plot.
        ValueError: from ``np.stack`` if the runs have different lengths.
    """
    # Sorted so aggregation order (and thus float rounding) is reproducible.
    fnames = sorted(glob.glob(pattern))
    if not fnames:
        raise FileNotFoundError(f"No accuracy runs found matching {pattern!r}")
    return np.stack([np.load(fname) for fname in fnames])


putters = _load_runs('outputs/data/putter*.npy')
transfers = _load_runs('outputs/data/transfer*.npy')
vaes = _load_runs('outputs/data/vae*.npy')

# Per-step mean and (population) standard deviation across runs.
put_m, put_s = np.mean(putters, axis=0), np.std(putters, axis=0)
tra_m, tra_s = np.mean(transfers, axis=0), np.std(transfers, axis=0)  # tra_s is not drawn
vae_m, vae_s = np.mean(vaes, axis=0), np.std(vaes, axis=0)

# One x position per recorded step; the shipped figure uses 6 (steps 0-5).
n_steps = len(put_m)
steps = np.arange(n_steps)

plt.figure()
# Empty C2 patch: it exists only to provide the legend entry for curve (a),
# whose spread band is not drawn.
plt.fill_between([], [], [], color='C2', alpha=0.5, zorder=5, edgecolor='black')
plt.fill_between(steps, vae_m - vae_s, vae_m + vae_s, color='C1', alpha=0.5, zorder=3, edgecolor='black')
plt.fill_between(steps, put_m - put_s, put_m + put_s, color='C0', alpha=0.5, zorder=1, edgecolor='black')
# legend() is called with labels only, before the lines are plotted, so the labels
# attach to the three patches above in creation order: transfer, VAE, putter.
plt.legend(["(a) $\\mathtt{get}\\to\\mathtt{get'}$", "(b) $\\mathtt{get}\\to\\mathtt{VAE}\\to\\mathtt{get'}$", "(c) $\\mathtt{get}\\to\\mathtt{put}\\to\\mathtt{get'}$"], frameon=False)
plt.plot(put_m, color='black', zorder=2, linestyle='--', linewidth=1.0)
plt.plot(vae_m, color='black', zorder=4, linestyle='dotted', linewidth=1.0)
plt.plot(tra_m, color='C2', zorder=6, linestyle='-.', linewidth=1.0)
plt.xlabel("Steps", labelpad=-2)
plt.ylabel("Accuracy", labelpad=-2)
plt.xticks(list(range(n_steps)))
plt.yticks([0.4, 0.6, 0.8, 1.0], ["40%", "60%", "80%", "100%"])
plt.ylim(0.3, 1.0)
plt.xlim(0, n_steps - 1)
plt.savefig('figure5b.pdf', dpi=300)
