import torch
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data import DataLoader
from evaluation import fgsm_attack

import cifar10_config
import numpy as np
from tqdm import tqdm


if __name__ == '__main__':
    # Evaluate pretrained CIFAR-10 models (one checkpoint per seed) under FGSM attacks
    # of increasing strength. For every seed / epsilon / test image, store the output
    # of ``model.predict`` on the perturbed input and save the whole array to
    # ./result_dicts/alphas/CIFAR10_alphas_<model_type>_adversarial.npy with shape
    # (n_seeds, len(epsilons), len(test_data), n_classes).
    import os

    # Fall back to CPU so the script does not crash on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_types = ['cnn_map']
    disable_tqdm = False

    # Model constructor arguments for CIFAR-10 (32x32 RGB images, 10 classes).
    n_seeds = 10
    n_classes = 10
    n_channels = 3
    n_height = 32
    n_width = 32
    n_features = n_channels * n_height * n_width
    max_precision = 50000
    args = (n_classes, n_features, max_precision)

    initial_seed = 12345
    # FGSM perturbation strengths; eps=0.00 is the unperturbed baseline.
    epsilons = [0.00, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30]
    # Number of forward passes averaged for 'dropout' / 'gaussian' model types.
    n_samples = 10

    # The test set and its loader do not depend on model type or seed: build them once.
    transform = Compose([ToTensor(), Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616))])
    test_data = CIFAR10("./data/", train=False, transform=transform, download=True)
    print("Using CIFAR10 test data with {} images.".format(len(test_data)))
    data_loader = DataLoader(test_data, batch_size=64, shuffle=False)

    for model_type in model_types:
        print("Model Type: {}".format(model_type))
        alphas_filepath = "./result_dicts/alphas/CIFAR10_alphas_{}_adversarial.npy".format(model_type)
        # Create the output folder before the long evaluation, so a missing directory
        # fails immediately instead of discarding all results at np.save time.
        os.makedirs(os.path.dirname(alphas_filepath), exist_ok=True)

        config = getattr(cifar10_config, model_type)
        model = config['model'](*args, **config['kwargs'])
        is_stochastic = ('dropout' in model_type) or ('gaussian' in model_type)

        alphas = np.zeros([n_seeds, len(epsilons), len(test_data), n_classes])
        for seed in range(n_seeds):
            print("Seed: {}/{}".format(seed + 1, n_seeds))
            torch.manual_seed(initial_seed + seed)
            # map_location lets checkpoints saved on GPU be loaded on a CPU-only machine.
            state_dict = torch.load("./state_dicts/CIFAR10_{}_{}.pt".format(model_type, seed), map_location=device)
            model.load_state_dict(state_dict)
            model.eval()
            model.to(device)

            # One list of per-batch predictions per epsilon; concatenated once after the
            # loop instead of re-concatenating a growing tensor on every batch.
            batch_alphas = [[] for _ in epsilons]
            for data, target in tqdm(data_loader, disable=disable_tqdm):
                data, target = data.to(device), target.to(device)

                # Track gradients w.r.t. the input: FGSM perturbs along data.grad.
                data.requires_grad = True
                if is_stochastic:
                    # Average several forward passes.
                    # NOTE(review): model.eval() was called above; confirm these model
                    # types still sample in eval mode, otherwise the passes are identical.
                    output = torch.stack([model(data) for _ in range(n_samples)]).mean(dim=0)
                else:
                    output = model(data)
                loss = model.nll_loss(output, target)
                model.zero_grad()
                loss.backward()
                data_grad = data.grad.data

                with torch.no_grad():
                    for i, eps in enumerate(epsilons):
                        perturbed_data = fgsm_attack(data, data_grad, eps=eps)
                        batch_alphas[i].append(model.predict(perturbed_data))

            # presumably model.predict returns (batch, n_classes) — TODO confirm; the
            # concatenated result must match alphas[seed, i] of shape (len(test_data), n_classes).
            for i, batches in enumerate(batch_alphas):
                alphas[seed, i] = torch.cat(batches).cpu().numpy()
        np.save(alphas_filepath, alphas)
