from paper import alfr, alfr_ds, test_against_adversaries, Run, Experiment
import torch
import torchvision
import argparse


def experiment_mnist(repeat=10, test_epochs_adversaries=10, device='cpu',
                     train_loader=None, val_loader=None, algorithms=None):
    """Train and evaluate each censoring algorithm on the MNIST task.

    Every algorithm is trained ``repeat`` times. After each training run the learned encoder
    is evaluated with ``test_against_adversaries``, and the training and evaluation results
    are recorded as a ``Run`` in a new ``Experiment``.

    Args:
        repeat: Number of independent training runs per algorithm.
        test_epochs_adversaries: Value forwarded as ``epochs`` to ``test_against_adversaries``.
        device: Device string forwarded to the training and evaluation helpers.
        train_loader: Loader passed to the training helpers and to
            ``test_against_adversaries``. When omitted, the module-level ``train_loader``
            is used (it is created by this file's ``__main__`` block), which keeps the
            original behavior.
        val_loader: Loader passed to ``test_against_adversaries``. Same fallback as
            ``train_loader``.
        algorithms: Optional iterable of algorithm names to run. Defaults to all supported
            algorithms, in their original order.

    Returns:
        The populated ``Experiment``.

    Raises:
        ValueError: If no loaders are available, or if an unknown algorithm name is given.
    """
    # The original code silently depended on globals that exist only when this file runs as a
    # script. Keep that as a fallback, but let callers pass the loaders explicitly.
    if train_loader is None:
        train_loader = globals().get('train_loader')
    if val_loader is None:
        val_loader = globals().get('val_loader')
    if train_loader is None or val_loader is None:
        raise ValueError('train_loader and val_loader must be provided (or defined at module level).')

    # One zero-argument factory per algorithm. Each call builds new argument objects
    # (e.g. the ``epochs`` lists), so nothing is shared between repeated runs.
    trainers = {
        'ALFR-DS(1)': lambda: alfr_ds(train_loader, adversary_hidden_size=64, epochs=[30],
                                      output_activation='sigmoid', device=device),
        'ALFR-DS(2)': lambda: alfr_ds(train_loader, adversary_hidden_size=64, epochs=[15, 15],
                                      output_activation='sigmoid', device=device),
        'ALFR-DS(3)': lambda: alfr_ds(train_loader, adversary_hidden_size=64, epochs=[10, 10, 10],
                                      output_activation='sigmoid', device=device),
        'ALFR (LR)': lambda: alfr(train_loader, adversary_hidden_size=None, epochs=30,
                                  output_activation='sigmoid', clip_gradients=True, device=device),
        'ALFR (MLP)': lambda: alfr(train_loader, adversary_hidden_size=64, epochs=30,
                                   output_activation='sigmoid', clip_gradients=True, device=device),
        # alpha=0 is passed to alfr for the baseline labelled 'Uncensored'.
        'Uncensored': lambda: alfr(train_loader, adversary_hidden_size=None, alpha=0, epochs=30,
                                   output_activation='sigmoid', device=device),
    }

    if algorithms is None:
        algorithms = list(trainers)
    else:
        algorithms = list(algorithms)
        unknown = [name for name in algorithms if name not in trainers]
        if unknown:
            # Fail loudly instead of evaluating a None encoder.
            raise ValueError(f'Unknown algorithm(s) {unknown}; expected names from {list(trainers)}.')

    my_experiment = Experiment()
    for algorithm in algorithms:
        for _ in range(repeat):
            # The decoder returned by the trainers is not needed for adversary evaluation.
            encoder, _, train_result = trainers[algorithm]()
            evaluation_results = test_against_adversaries(train_loader, val_loader, encoder=encoder,
                                                          epochs=test_epochs_adversaries, device=device)
            my_experiment.add_run(Run(train_result, evaluation_results, algorithm))
    return my_experiment


def parse_bool(value):
    """Convert a command-line string such as ``'true'``, ``'False'``, ``'1'`` or ``'no'`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every non-empty
    string (including ``'False'``). An explicit parser is therefore required.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}.')


def flatten_image(x):
    """Flatten an image tensor into a 1-D tensor.

    This is a module-level function rather than a lambda so the transform can be pickled
    when the DataLoader starts worker processes (``--num_workers > 0`` with spawn).
    """
    return torch.flatten(x)


def is_eight_label(label):
    """Map an MNIST class label to a 1-element float tensor: 1.0 for digit 8, else 0.0."""
    return torch.FloatTensor([1]) if label == 8 else torch.FloatTensor([0])


if __name__ == "__main__":
    import os

    # Create argument parser
    parser = argparse.ArgumentParser(description='Run MNIST results.')
    # nargs='?' with const=True accepts both a bare ``--train_new`` and ``--train_new true/false``.
    parser.add_argument('--train_new', default=False, type=parse_bool, nargs='?', const=True,
                        help='Whether or not to train a new model, or load from disk.')
    parser.add_argument('--device', default='cpu', type=str, help='Device to use for training (default "cpu").')
    parser.add_argument('--num_workers', default=0, type=int, help='How many parallel workers to use for data loading.')
    parser.add_argument('--batch_size', default=128, type=int, help='Batch size for training.')
    args = parser.parse_args()

    # Convert each image to a tensor and flatten it to 1-D.
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                                torchvision.transforms.Lambda(flatten_image)])

    # Binary target: 1.0 if the digit is an 8, otherwise 0.0.
    target_transform = torchvision.transforms.Lambda(is_eight_label)

    mnist_dataset = torchvision.datasets.MNIST(
        root="~/torch_datasets", train=True, download=True, transform=transform, target_transform=target_transform
    )

    # Hold out 5000 of the training images for validation.
    train_set, val_set = torch.utils.data.random_split(mnist_dataset, [55000, 5000])
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

    if args.train_new:
        # Create the directory up front so a missing folder does not fail after a long training run.
        os.makedirs('experiments', exist_ok=True)
        experiment = experiment_mnist(device=args.device)
        experiment.save('experiments/experiment_mnist.pkl')
    else:
        # Load a previously saved experiment instead of retraining.
        experiment = Experiment.load('experiments/experiment_mnist.pkl')

    # The output files below are written into "output/", which may not exist yet.
    os.makedirs('output', exist_ok=True)
    experiment.get_score_against_adversaries('val_accuracy').to_excel('output/mnist_val_accuracy.xlsx')
    experiment.plot_reconstruction_loss_per_trial(legend={'loc': 'upper left', 'fontsize': 9},
                                                  tofile='output/mnist_loss.pdf')

    print('Check folder "output" for output.')