import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from PIL import Image

from paper import alfr, alfr_ds


def transform_pil_img(pil_img, e, d, to_tensor=None):
    """Reconstruct a 28x28 image by passing it through an encoder/decoder pair.

    Args:
        pil_img: Image to reconstruct.
        e: Encoder; called on a float tensor of shape (N, 784) (N == 1 for a
            single 28x28 image).
        d: Decoder; called on ``e``'s output. Its output is reshaped to
            (28, 28), so it must contain exactly 784 values.
        to_tensor: Callable converting ``pil_img`` into a tensor whose element
            count is a multiple of 784. Defaults to the module-level
            ``transform``, which is only defined when this file is run as a
            script; pass it explicitly when importing this function elsewhere.

    Returns:
        An 8-bit greyscale ``PIL.Image`` of the reconstruction. Decoder outputs
        are expected in [0, 1] (the models in this script are trained with a
        sigmoid output activation); anything outside that range is clipped.
    """
    if to_tensor is None:
        # Fall back to the global built in the ``__main__`` block.
        to_tensor = transform
    # Inference only: no autograd graph is needed for a visualisation.
    with torch.no_grad():
        x = to_tensor(pil_img)
        x_flattened = x.view(-1, 784)
        x_transformed = d(e(x_flattened))
    x_numpy = x_transformed.reshape(28, 28).detach().cpu().numpy()
    # Clip before scaling: casting out-of-range floats to uint8 is not
    # well-defined (it typically wraps, e.g. 256 -> 0, turning bright pixels black).
    x_numpy = np.clip(x_numpy, 0.0, 1.0)
    return Image.fromarray(np.rint(x_numpy * 255).astype(np.uint8))


if __name__ == "__main__":
    def _str2bool(value):
        """Parse a command-line boolean such as "True", "false", "1" or "no".

        ``type=bool`` cannot be used for this: ``bool("False")`` is True, so any
        non-empty value would enable the flag.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    def _load_model(path):
        """Load a whole model object previously written with ``torch.save``.

        This unpickles arbitrary Python objects, so only use it on trusted files
        (here: models produced by this script).
        """
        try:
            # torch >= 2.6 defaults to weights_only=True, which refuses to
            # unpickle complete model objects such as the ones saved below.
            return torch.load(path, weights_only=False)
        except TypeError:
            # torch < 1.13 does not accept the weights_only keyword.
            return torch.load(path)

    # Create argument parser
    parser = argparse.ArgumentParser(description='Run MNIST results.')
    parser.add_argument(
        '--train_new', nargs='?', const=True, default=False, type=_str2bool,
        help='Whether or not to train a new model, or load from disk. '
             'Accepts a bare flag or an explicit value (e.g. --train_new False).')
    args = parser.parse_args()

    # Load dataset. Each image becomes a flattened tensor (ToTensor + flatten);
    # each label becomes FloatTensor([1]) for the digit 8 and FloatTensor([0]) otherwise.
    transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(), torchvision.transforms.Lambda(lambda x: torch.flatten(x))])
    target_transform = torchvision.transforms.Lambda(
        lambda x: torch.FloatTensor([1]) if x == 8 else torch.FloatTensor([0]))
    mnist_dataset = torchvision.datasets.MNIST(root="~/torch_datasets", train=True, download=True, transform=transform,
                                               target_transform=target_transform)
    train_loader = torch.utils.data.DataLoader(mnist_dataset, batch_size=128, shuffle=True)

    # Train new models, or load previously trained ones from disk.
    # Disclaimer: Training a new model may alter results slightly
    if args.train_new:
        # torch.save does not create missing directories.
        os.makedirs('models', exist_ok=True)

        # Train a new encoder / decoder pair using our method
        encoder_alfrds, decoder_alfrds, _ = alfr_ds(
            train_loader, adversary_hidden_size=64, epochs=[10, 10, 10], output_activation='sigmoid'
        )
        # Save to disk
        torch.save(encoder_alfrds, 'models/mnist-encoder-alfr-ds.pkl')
        torch.save(decoder_alfrds, 'models/mnist-decoder-alfr-ds.pkl')

        # Train a new encoder / decoder pair using old method
        encoder_alfr, decoder_alfr, _ = alfr(
            train_loader, adversary_hidden_size=64, epochs=30, output_activation='sigmoid', clip_gradients=True
        )
        # Save to disk
        torch.save(encoder_alfr, 'models/mnist-encoder-alfr.pkl')
        torch.save(decoder_alfr, 'models/mnist-decoder-alfr.pkl')
    else:
        # Load from disk
        encoder_alfrds = _load_model('models/mnist-encoder-alfr-ds.pkl')
        decoder_alfrds = _load_model('models/mnist-decoder-alfr-ds.pkl')

        encoder_alfr = _load_model('models/mnist-encoder-alfr.pkl')
        decoder_alfr = _load_model('models/mnist-decoder-alfr.pkl')

    # Untransformed dataset (items are PIL images) used for plotting.
    example_data = torchvision.datasets.MNIST(root="~/torch_datasets", download=True)

    # Example data to show. We show 1598 because it looks interesting
    data = [402, 403, 404, 405, 1598]
    plt.rcParams["figure.figsize"] = (1.5 * len(data), 4.5)
    # Same effect as plt.set_cmap('Greys'), without creating a stray empty figure.
    plt.rcParams["image.cmap"] = "Greys"

    # Rows: original image, alfr_ds reconstruction, alfr reconstruction.
    # Each row gets its own frame colour.
    row_colors = ('black', 'green', 'red')
    # squeeze=False keeps axarrs 2-D even if only one example is shown.
    fig, axarrs = plt.subplots(3, len(data), squeeze=False)
    for col, img_index in enumerate(data):
        img, _ = example_data[img_index]
        row_images = (
            img,
            transform_pil_img(img, encoder_alfrds, decoder_alfrds),
            transform_pil_img(img, encoder_alfr, decoder_alfr),
        )
        for row, (shown, color) in enumerate(zip(row_images, row_colors)):
            ax = axarrs[row][col]
            ax.imshow(shown)
            for spine in ax.spines.values():
                spine.set_edgecolor(color)
            ax.set_xticks([])
            ax.set_yticks([])

    output_path = 'output/mnist_digits.pdf'
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    fig.savefig(output_path, bbox_inches='tight')
    print(f'Saved to "{output_path}"')
