import torch
from torchvision.datasets import CIFAR10
from datasets import CorruptedCIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data import DataLoader

import cifar10_config
import numpy as np
from tqdm import tqdm


if __name__ == '__main__':
    # Script entry point: for every (model type, test data set) pair, load each
    # of the `n_seeds` trained checkpoints, run `model.predict` over the whole
    # data set, and save the stacked outputs as one .npy file of shape
    # [n_seeds, n_images, n_classes].

    # Local import: only this script body needs it (used to create the output
    # directory before any expensive inference is done).
    import os

    # Prefer the GPU, but fall back to the CPU so the script still runs on
    # hosts without CUDA instead of crashing on the first `.to(device)`.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_types = ['cnn_map']
    test_data_sets = ['CIFAR10', 'CIFAR10-C1', 'CIFAR10-C2', 'CIFAR10-C3', 'CIFAR10-C4', 'CIFAR10-C5']
    disable_tqdm = False

    # Loop-invariant experiment constants.
    n_seeds = 10
    n_classes = 10
    n_channels = 3
    n_height = 32
    n_width = 32
    n_features = n_channels * n_height * n_width
    max_precision = 50000
    args = (n_classes, n_features, max_precision)
    initial_seed = 12345

    # Every supported data set goes through the same preprocessing, so the
    # transform is built once instead of once per branch.
    transform = Compose([ToTensor(), Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616))])

    # Maps 'CIFAR10-C1' ... 'CIFAR10-C5' to the severity passed to
    # CorruptedCIFAR10, replacing five copy-pasted branches.
    corruption_severities = {'CIFAR10-C{}'.format(level): level for level in range(1, 6)}

    for model_type in model_types:
        for test_data_set in test_data_sets:
            print("Model Type: {}".format(model_type))
            alphas_filepath = "./result_dicts/alphas/CIFAR10_alphas_{}_{}.npy".format(model_type, test_data_set)

            config = getattr(cifar10_config, model_type)
            kwargs = config['kwargs']
            model = config['model'](*args, **kwargs)

            if test_data_set == 'CIFAR10_train':
                test_data = CIFAR10("./data/", train=True, transform=transform, download=True)
                print("Using CIFAR10 train data with {} images.".format(len(test_data)))
            elif test_data_set == 'CIFAR10':
                test_data = CIFAR10("./data/", train=False, transform=transform, download=True)
                print("Using CIFAR10 test data with {} images.".format(len(test_data)))
            elif test_data_set in corruption_severities:
                test_data = CorruptedCIFAR10("./data/", severity=corruption_severities[test_data_set],
                                             transform=transform)
                print("Using {} test data with {} images.".format(test_data_set, len(test_data)))
            else:
                raise Exception("Unknown test data set: {}".format(test_data_set))

            # Make sure the output directory exists *before* running inference,
            # so a missing folder cannot discard the results at np.save time.
            os.makedirs(os.path.dirname(alphas_filepath), exist_ok=True)

            # The loader does not depend on the seed (shuffle=False), so it is
            # created once per data set and re-iterated for every checkpoint.
            data_loader = DataLoader(test_data, batch_size=config['n_batch'], shuffle=False)

            alphas = np.zeros([n_seeds, len(test_data), n_classes])
            for seed in range(n_seeds):
                print("Seed: {}/{}".format(seed + 1, n_seeds))
                torch.manual_seed(initial_seed + seed)

                model.to(device)
                # map_location lets checkpoints saved on a GPU be restored on
                # whatever device is actually in use.
                state_dict = torch.load("./state_dicts/CIFAR10_{}_{}.pt".format(model_type, seed),
                                        map_location=device)
                model.load_state_dict(state_dict)
                model.eval()

                iterator = data_loader if disable_tqdm else tqdm(data_loader)

                # Collect per-batch outputs and concatenate once; calling
                # torch.cat inside the loop re-copies the growing result on
                # every batch. The leading empty tensor keeps the result
                # well-defined (shape [0, n_classes]) for an empty data set.
                batch_outputs = [torch.zeros(0, n_classes, device=device)]
                with torch.no_grad():
                    for x, _ in iterator:
                        batch_outputs.append(model.predict(x.to(device)))
                    alpha = torch.cat(batch_outputs)

                # Explicit tensor -> ndarray conversion before storing.
                alphas[seed] = alpha.detach().cpu().numpy()
            np.save(alphas_filepath, alphas)
