import torch
import torchvision.datasets
import torchvision.transforms as transforms
import numpy as np
import os
import torch.nn as nn
import matplotlib.pyplot as plt
from pgd_attack import forward, perturb, perturb_random_restarts
from sklearn.metrics import roc_curve, auc as auc_fn
from dataset import *
from utils import set_eval, set_train
import seaborn as sns
import torchvision
from torchvision.utils import make_grid
import pandas as pd
from PIL import ImageFilter
from PIL import Image


def roc_auc(pos_data, neg_data, model):
    """Score positive and negative samples with ``model`` and return the ROC-AUC.

    :param pos_data: tensor of positive (in-distribution) images; the last dim is the
        resolution (assumed (N, 3, H, W) — TODO confirm)
    :param neg_data: tensor of negative (OOD / adversarial) images, same layout
    :param model: discriminative model; must already be in eval mode
    :return: area under the ROC curve with ``pos_data`` labelled 1 and ``neg_data`` 0
    """
    device = next(model.parameters()).device
    assert not model.training
    # Same rule as compute_adv: only 32x32 inputs use CIFAR statistics; every larger
    # resolution (128/224/256/512) uses ImageNet statistics.
    normalization = 'imagenet' if pos_data.shape[-1] != 32 else 'cifar10'

    def _scores(data):
        # Forward in chunks of 100 to bound device memory, then release cached blocks.
        out = torch.cat([forward(model, batch.to(device), normalization)
                         for batch in torch.split(data, 100)])
        torch.cuda.empty_cache()
        return out

    with torch.no_grad():
        pos_data_out = _scores(pos_data)
        neg_data_out = _scores(neg_data)
    return auto_auc(pos_data_out, neg_data_out)


def auto_auc(x1, x0):
    """Return the ROC-AUC for separating scores ``x1`` (label 1) from ``x0`` (label 0).

    :param x1: tensor of scores for the positive class (any device)
    :param x0: tensor of scores for the negative class (any device)
    :return: area under the ROC curve computed with sklearn's ``roc_curve``/``auc``
    """
    scores = torch.cat([x1.cpu(), x0.cpu()])
    labels = torch.cat([
        torch.ones(x1.shape[0], dtype=torch.float32),
        torch.zeros(x0.shape[0], dtype=torch.float32),
    ])
    false_pos_rate, true_pos_rate, _ = roc_curve(labels, scores)
    return auc_fn(false_pos_rate, true_pos_rate)


def compute_adv(x, model, attack_config, num_random_restarts=1):
    """Run the PGD attack on ``x`` in batches of 100 and return the result on CPU.

    :param x: image tensor whose last dim is the resolution (32/128/224/256/512)
    :param model: discriminative model; must already be in eval mode
    :param attack_config: keyword arguments forwarded to ``perturb`` /
        ``perturb_random_restarts``; must contain ``eps``
    :param num_random_restarts: when > 1, ``perturb_random_restarts`` is used
    :return: CPU tensor of perturbed images (a clone of ``x`` when ``eps`` is ~0)
    """
    device = next(model.parameters()).device
    assert not model.training
    assert x.shape[-1] in [32, 128, 256, 224, 512]
    normalization = 'imagenet' if x.shape[-1] != 32 else 'cifar10'

    # A zero budget means no perturbation at all.
    if attack_config['eps'] < 1e-8:
        return x.clone()

    adv = []
    for batch in torch.split(x, 100):
        batch = batch.to(device)
        if num_random_restarts > 1:
            batch_adv = perturb_random_restarts(model, batch, normalization=normalization,
                                                num_random_restarts=num_random_restarts,
                                                **attack_config)
        else:
            batch_adv = perturb(model, batch, normalization=normalization, **attack_config)
        # Move each finished batch off the device immediately, so device memory holds
        # one batch at a time instead of the whole adversarial set.
        adv.append(batch_adv.detach().cpu())
        del batch, batch_adv
    torch.cuda.empty_cache()
    return torch.cat(adv)


def ood_adv(model, indist_data, ood_data, epsilon, steps, step_size):
    """Attack ``ood_data`` with an L2 PGD attack and return the detection ROC-AUC.

    :param model: discriminative model; must be in eval mode
    :param indist_data: in-distribution images (positives)
    :param ood_data: out-of-distribution images to perturb (negatives)
    :param epsilon: L2 perturbation budget
    :param steps: number of attack steps
    :param step_size: attack step size
    :return: ROC-AUC of ``indist_data`` vs. the perturbed ``ood_data``
    """
    attack_config = dict(norm='L2', eps=epsilon, steps=steps, step_size=step_size)
    ood_data_adv = compute_adv(ood_data, model, attack_config)
    # Return the score so callers can actually use the evaluation result.
    return roc_auc(indist_data, ood_data_adv, model)


# Names of the datasets built by load_datasets(); also used to validate
# dataset-name arguments. Kept as a list because callers concatenate it with lists.
dataset_names = [
    'Gaussian noise',
    'Uniform noise',
    'CIFAR10',
    'ImageNet',
    'Bedroom',
    'SVHN',
    'CelebAHQ',
    'CIFAR100',
]


def load_datasets(data_size, num_samples=None):
    """Build the evaluation datasets at resolution ``data_size``.

    The Gaussian and uniform noise datasets are generated after
    ``torch.manual_seed(0)`` (which also re-seeds the global torch RNG). The image
    datasets are read from ``./datasets/`` (CIFAR10, CIFAR100 and SVHN are downloaded
    if missing) and resized with ``transforms.Resize(data_size)``.

    :param data_size: image resolution, one of 32, 128, 256, 224, 512
    :param num_samples: number of noise images to generate; ``None`` or -1 means
        10000 for 32x32 and 3000 otherwise
    :return: dict mapping dataset name to a torch ``Dataset``. If the environment
        variable ``cifar10_class`` is set to a non-empty value, ``'CIFAR10'`` is
        restricted to that class via ``CIFAR10Unsupervised``.
    """
    assert data_size in [32, 128, 256, 224, 512]
    if num_samples is None or num_samples == -1:
        num_samples = 10000 if data_size == 32 else 3000
    data_shape = torch.Size([num_samples, 3, data_size, data_size])

    torch.manual_seed(0)
    gaussian_noise = torch.randn(data_shape)
    # Rescale to [0, 1]: after subtracting the min, the max equals the full range.
    gaussian_noise -= gaussian_noise.min()
    gaussian_noise /= gaussian_noise.max()
    uniform_noise = torch.rand(data_shape)
    datadir = './datasets/'
    transform = transforms.Compose([transforms.Resize(data_size), transforms.ToTensor()])
    datasets = {
        'CIFAR10': torchvision.datasets.CIFAR10(datadir, train=False, transform=transform, download=True),
        'Gaussian noise': torch.utils.data.TensorDataset(gaussian_noise),
        'Uniform noise': torch.utils.data.TensorDataset(uniform_noise),
        'ImageNet': get_imagenet128_val_dataset(datadir, transform),
        'Bedroom': get_bedroom128_val_dataset(datadir, transform),
        'SVHN': torchvision.datasets.SVHN(datadir, split='test', transform=transform, download=True),
        'CelebAHQ': get_celebahq128_val_dataset(datadir, transform),
        'CIFAR100': torchvision.datasets.CIFAR100(datadir, train=False, transform=transform, download=True),
    }

    # os.environ values are always str (never None), so test for a non-empty value;
    # an empty string is treated as "unset" rather than crashing in int('').
    cifar10_class = os.environ.get('cifar10_class')
    if cifar10_class:
        datasets['CIFAR10'] = CIFAR10Unsupervised(target_class=int(cifar10_class),
                                                  mode='include', train=False,
                                                  root=datadir, transform=transform)

    return datasets


def load_data(dataset, test_samples):
    """Load images (dropping any labels) from ``dataset`` into a single tensor.

    :param dataset: torch ``Dataset`` whose items are either a tensor or a
        list/tuple whose first element is the image tensor
    :param test_samples: number of samples to draw, or -1 to load the whole dataset.
        For a positive count, ``torch.manual_seed(0)`` is set first, so the same
        shuffled subset is returned on every call.
    :return: tensor of stacked images
    """
    def _images_only(batch):
        # Collated (image, label) batches come back as a sequence; keep the images.
        return batch[0] if isinstance(batch, (list, tuple)) else batch

    if test_samples == -1:
        loader = torch.utils.data.DataLoader(dataset, batch_size=100, shuffle=True)
        return torch.cat([_images_only(batch) for batch in loader])

    torch.manual_seed(0)  # Make the shuffled subset reproducible.
    loader = torch.utils.data.DataLoader(dataset, batch_size=test_samples, shuffle=True)
    # Python 3 iterators expose __next__, not .next(); use the builtin next().
    return _images_only(next(iter(loader)))


def eval_ood_detection(model, indist_dataset_name, epsilon, test_samples, norm='L2', plot=False, ood_datasets=None, num_random_restarts=1, steps=None, step_size=None):
    """
    Evaluate out-of-distribution detection performance against PGD-perturbed OOD samples
    :param model: discriminative model (must be in eval mode)
    :param indist_dataset_name: name of the in-distribution dataset: 'CIFAR10', 'svhn',
        or another key returned by load_datasets()
    :param epsilon: the maximum allowable perturbation
    :param test_samples: number of samples used for this test (-1 loads every sample)
    :param norm: 'L2' selects an L2 attack; any other value selects an Linf attack
    :param plot: whether or not to plot the perturbed OOD samples
    :param ood_datasets: optional iterable of OOD dataset names to evaluate
        (required when indist_dataset_name is 'svhn')
    :param num_random_restarts: number of attack restarts passed to compute_adv
    :param steps: number of attack steps; 200 when None
    :param step_size: attack step size; a resolution/norm-specific default when None
    :return: a dataframe containing the test results, one row per call indexed by
        '{epsilon:.1f}', one column per OOD dataset plus 'mean'
    dataframe format:
               CIFAR10   Gaussian noise   ...   mean
        0.0    *         *                *     *
        5.0    *         *                *     *
        10.0   *         *                *     *
    """
    data_size = 32 if indist_dataset_name in ['CIFAR10', 'svhn'] else 128
    data_shape = torch.Size([3, data_size, data_size])

    if indist_dataset_name == 'CIFAR10':
        extended_datasets = {'iSUN': get_iSUN_dataset('./datasets'),
                             'LSUN (resize)': get_LSUN_dataset('./datasets'),
                             'TinyImageNet (resize)': get_TinyImageNet_dataset('./datasets')}
        datasets = load_datasets(data_size, test_samples)
        datasets.update(extended_datasets)
        indist_dataset = datasets[indist_dataset_name]
        if ood_datasets is not None:
            assert 'CIFAR10' not in ood_datasets
            ood_datasets = {k: v for k, v in datasets.items() if k in ood_datasets and k != indist_dataset_name}
        else:
            ood_datasets = {k: v for k, v in datasets.items() if k != indist_dataset_name}
    elif indist_dataset_name == 'svhn':
        datasets = load_datasets(data_size, test_samples)
        indist_dataset = torchvision.datasets.SVHN(root='./datasets', split='test', download=True, transform=transforms.ToTensor())
        assert ood_datasets is not None
        ood_datasets = {k: v for k, v in datasets.items() if k in ood_datasets and k != indist_dataset_name}
    else:
        datasets = load_datasets(data_size, test_samples)
        indist_dataset = datasets[indist_dataset_name]
        ood_datasets = {k: v for k, v in datasets.items() if k != indist_dataset_name}

    indist_data = load_data(indist_dataset, test_samples)
    assert indist_data.shape[1:] == data_shape

    if plot:
        # Row 0 shows in-distribution samples; row i + 1 shows OOD dataset i (raw, adv).
        # squeeze=False keeps `axes` 2-D even when there is only one row.
        _, axes = plt.subplots(nrows=len(ood_datasets) + 1, ncols=2, dpi=150, squeeze=False)
        for ax in axes.ravel():
            ax.set_axis_off()
        grid_img = transforms.ToPILImage()(make_grid(indist_data[:8], nrow=8))
        axes[0, 0].imshow(grid_img)

    # Default step sizes per (resolution, norm); explicit steps/step_size override them.
    attack_norm = 'L2' if norm == 'L2' else 'Linf'
    default_step_size = {
        (32, 'L2'): 0.5,
        (32, 'Linf'): 0.002,
        (128, 'L2'): 2.0,
        (128, 'Linf'): 0.01,
    }[(data_size, attack_norm)]
    attack_config = dict(norm=attack_norm, eps=epsilon,
                         steps=200 if steps is None else steps,
                         step_size=default_step_size if step_size is None else step_size)
    print(attack_config)

    auc_scores = {}
    for i, (dataset_name, ood_dataset) in enumerate(ood_datasets.items()):
        ood_data = load_data(ood_dataset, test_samples)
        ood_data_adv = compute_adv(ood_data, model, attack_config, num_random_restarts)
        auc_scores[dataset_name] = roc_auc(indist_data, ood_data_adv, model)
        if plot:
            grid_img_raw = transforms.ToPILImage()(make_grid(ood_data[:8], nrow=8))
            grid_img_adv = transforms.ToPILImage()(make_grid(ood_data_adv[:8], nrow=8))
            axes[i + 1, 0].imshow(grid_img_raw)
            axes[i + 1, 1].imshow(grid_img_adv)
    if plot:
        plt.show()
    df = pd.DataFrame(data=auc_scores, index=['{:.1f}'.format(epsilon)])
    df['mean'] = df.mean(axis=1)
    return df


def eval_robustness(model, indist_dataset_name, ood_dataset_name, test_samples):
    """
    Sweep L2 PGD attack settings (number of steps x step size) and report OOD-detection ROC-AUC
    :param model: the discriminative model
    :param indist_dataset_name: string, name of the in-distribution dataset (in dataset_names)
    :param ood_dataset_name: string, name of the out-of-distribution dataset (in dataset_names)
    :param test_samples: int, number of samples drawn from each dataset
    :return: a data frame of ROC-AUC scores, rows indexed by number of steps and
        columns by step size. For 32x32 data the budget is eps=2.0 with steps
        200/500/1000 and step sizes 2.0/1.0/0.5/0.1/0.05; for 128x128 data it is
        eps=10 with steps 100/200/500 and step sizes 8.0/4.0/2.0/1.0/0.5.
    result format (128x128)
            8.0  4.0  2.0  1.0  0.5
        100  *    *    *    *    *
        200  *    *    *    *    *
        500  *    *    *    *    *
    """
    assert indist_dataset_name in dataset_names
    assert ood_dataset_name in dataset_names
    data_size = 128 if indist_dataset_name != 'CIFAR10' else 32
    expected_shape = torch.Size([test_samples, 3, data_size, data_size])

    # Draw the in-distribution and OOD samples from the shared dataset table.
    datasets = load_datasets(data_size, test_samples)
    indist_source = datasets[indist_dataset_name]
    ood_source = datasets[ood_dataset_name]

    indist_data = load_data(indist_source, test_samples)
    ood_data = load_data(ood_source, test_samples)
    assert indist_data.shape == expected_shape
    assert ood_data.shape == expected_shape

    # (epsilon, step counts, step sizes) to sweep for each supported resolution.
    sweep = {
        32: (2.0, [200, 500, 1000], [2.0, 1.0, 0.5, 0.1, 0.05]),
        128: (10, [100, 200, 500], [8.0, 4.0, 2.0, 1.0, 0.5]),
    }
    epsilon, steps_seq, step_size_seq = sweep[indist_data.shape[-1]]

    def score(steps, step_size):
        config = dict(norm='L2', eps=epsilon, steps=steps, step_size=step_size)
        perturbed = compute_adv(ood_data, model, config)
        return roc_auc(indist_data, perturbed, model)

    auc_scores = {
        step_size: [score(steps, step_size) for steps in steps_seq]
        for step_size in step_size_seq
    }
    return pd.DataFrame(data=auc_scores, index=steps_seq)


def generate(model, indist_dataset_name, ood_dataset_name, resolution=None):
    """Attack seed images with a large-budget L2 PGD attack and return before/after grids.

    :param model: discriminative model; must be in eval mode (compute_adv asserts this)
    :param indist_dataset_name: in-distribution dataset name (a member of dataset_names)
    :param ood_dataset_name: seed source: an OOD dataset name; 'precomputed' to load
        seeds from 'img_seed_celebahq128_betavae.npy' (CelebAHQ only); or None to draw
        a few images from every other dataset. Non-noise, non-precomputed seeds are
        Gaussian-blurred before the attack.
    :param resolution: image resolution; defaults to 32 for CIFAR10 and 128 otherwise
    :return: ``(grid_img_raw, grid_img_adv)`` PIL images of the seeds and attacked seeds
    :raises ValueError: if no attack configuration exists for
        ``(resolution, indist_dataset_name)``
    """
    assert indist_dataset_name in dataset_names
    assert ood_dataset_name is None or ood_dataset_name in dataset_names + ['precomputed']
    if resolution is None:
        resolution = 32 if indist_dataset_name == 'CIFAR10' else 128
    # Images per grid. 224 is supported by the attack table below; it uses 25
    # images like 128 (NOTE(review): confirm the intended count for 224).
    samples = {32: 64, 128: 25, 224: 25, 256: 64, 512: 4}[resolution]

    # Attack settings per (resolution, in-distribution dataset). 32x32 uses a single
    # setting for any dataset. Resolve this before loading data, so unsupported
    # combinations fail immediately.
    attack_configs = {
        (128, 'Bedroom'): dict(norm='L2', eps=70, steps=400, step_size=0.8),
        (128, 'CelebAHQ'): dict(norm='L2', eps=40, steps=100, step_size=1.2),
        (256, 'CelebAHQ'): dict(norm='L2', eps=100, steps=400, step_size=1.2),
        (256, 'Bedroom'): dict(norm='L2', eps=500, steps=200, step_size=1.0),
        (128, 'ImageNet'): dict(norm='L2', eps=70, steps=400, step_size=0.8),
        (256, 'ImageNet'): dict(norm='L2', eps=500, steps=350, step_size=0.8),
        (512, 'ImageNet'): dict(norm='L2', eps=1000, steps=600, step_size=1.6),
        (224, 'ImageNet'): dict(norm='L2', eps=80, steps=100, step_size=1.2),
    }
    if resolution == 32:
        attack_config = dict(norm='L2', eps=15, steps=200, step_size=0.15)
    elif (resolution, indist_dataset_name) in attack_configs:
        attack_config = attack_configs[(resolution, indist_dataset_name)]
    else:
        raise ValueError('No attack configuration for resolution {} and dataset {!r}'.format(
            resolution, indist_dataset_name))

    datasets = load_datasets(resolution, samples)
    ood_datasets = {k: v for k, v in datasets.items() if k != indist_dataset_name}

    if ood_dataset_name == 'precomputed':
        assert indist_dataset_name == 'CelebAHQ'
        seed_file = {'CelebAHQ': 'img_seed_celebahq128_betavae.npy'}[indist_dataset_name]
        seed = np.load(seed_file)[:samples]
        if resolution != seed.shape[1]:
            # Seeds are stored as HWC uint8 arrays; resize each one through PIL.
            seed = np.stack([np.asarray(Image.fromarray(im).resize((resolution, resolution)))
                             for im in seed], axis=0)
        seed = seed.transpose([0, 3, 1, 2]).astype(np.float32) / 255
        ood_data = torch.from_numpy(seed)
    elif ood_dataset_name in dataset_names:
        ood_data = load_data(ood_datasets[ood_dataset_name], samples)
    else:
        # No specific source: take sqrt(samples) images from every other dataset.
        ood_data = torch.cat([load_data(v, int(np.sqrt(samples))) for v in ood_datasets.values()])

    if ood_dataset_name is not None and ood_dataset_name != 'precomputed' and 'noise' not in ood_dataset_name:
        to_pil = transforms.ToPILImage()
        to_tensor = transforms.ToTensor()
        if resolution == 32:
            radius = 3
        else:
            radius = {'Bedroom': 13, 'CelebAHQ': 13, 'ImageNet': 15}[indist_dataset_name]
        blur = ImageFilter.GaussianBlur(radius=radius)
        ood_data = torch.stack([to_tensor(to_pil(x).filter(blur)) for x in ood_data], dim=0)

    print(attack_config)
    ood_data_adv = compute_adv(ood_data, model, attack_config)
    nrow = int(np.sqrt(samples))
    grid_img_raw = transforms.ToPILImage()(make_grid(ood_data, nrow=nrow, padding=1))
    grid_img_adv = transforms.ToPILImage()(make_grid(ood_data_adv, nrow=nrow, padding=1))
    return grid_img_raw, grid_img_adv


