import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize
from tqdm import tqdm

import mnist_config


if __name__ == '__main__':
    device = 'cpu'
    model_types = ['mlp_map']
    test_data_sets = ['MNIST_train', 'MNIST']
    for model_type in model_types:
        for test_data_set in test_data_sets:
            disable_tqdm = False

            print("Model Type: {}".format(model_type))
            alphas_filepath = "./result_dicts/alphas/MNIST_alphas_{}_{}.npy".format(model_type, test_data_set)

            n_seeds = 10
            n_classes = 10
            n_channels = 1
            n_height = 28
            n_width = 28
            n_features = n_channels * n_height * n_width
            max_precision = 60000
            args = (n_classes, n_features, max_precision)
            config = getattr(mnist_config, model_type)
            kwargs = config['kwargs']
            model = config['model'](*args, **kwargs)

            if test_data_set == 'MNIST_train':
                transform = Compose([ToTensor(), Normalize(mean=.5, std=.5)])
                test_data = MNIST("./data/", train=True, transform=transform, download=True)
                print("Using MNIST train data with {} images.".format(len(test_data)))
            elif test_data_set == 'MNIST':
                transform = Compose([ToTensor(), Normalize(mean=.5, std=.5)])
                test_data = MNIST("./data/", train=False, transform=transform, download=True)
                print("Using MNIST test data with {} images.".format(len(test_data)))
            else:
                raise Exception("Unknown test data set: {}".format(test_data_set))

            initial_seed = 12345
            alphas = np.zeros([n_seeds, len(test_data), n_classes])
            for seed in range(n_seeds):
                print("Seed: {}/{}".format(seed + 1, n_seeds))
                torch.manual_seed(initial_seed + seed)
                model.load_state_dict(torch.load("./state_dicts/MNIST_{}_{}.pt".format(model_type, seed)))
                model.to(device)
                data_loader = DataLoader(test_data, batch_size=config['n_batch'], shuffle=False)
                alpha = torch.zeros(0, n_classes).to(device)
                if disable_tqdm:
                    iterator = data_loader
                else:
                    iterator = tqdm(data_loader)
                with torch.no_grad():
                    for x, _ in iterator:
                        alpha = torch.cat([alpha, model.predict(x.to(device))])
                alpha = alpha.to('cpu')
                alphas[seed] = alpha
            np.save(alphas_filepath, alphas)
