import copy
from typing import List

import numpy as np
import torch
from torchvision import transforms

from convexrobust.data.cifar_select import CIFAR10SelectDataModule
from convexrobust.data.mnist_select import MNISTSelectDataModule
from convexrobust.data.kaggle_catsdogs import KaggleCatsDogsDataModule
from convexrobust.data.malimg import MalimgDataModule
from convexrobust.data.circles import CirclesDataModule
from convexrobust.utils import dirs

# Dataset identifiers handled by get_datamodule.
names = [
    'cifar10_catsdogs',
    'cifar10_dogscats',
    'mnist_38',
    'kaggle_catsdogs',
    'malimg',
    'circles',
]

# Deterministic (non-augmenting) pipelines, used for evaluation splits.
null_transform_small_image = transforms.Compose([transforms.ToTensor()])
null_transform_large_image = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
null_transform_mal_image = transforms.Compose([
    transforms.CenterCrop(512),
    transforms.ToTensor(),
])


def _splits(train, evaluation):
    """Return the split dict passed to the datamodules.

    ``train`` is used for the training split; the single ``evaluation`` object is reused
    (same instance) for both the validation and the test split.
    """
    return {
        'train_transforms': train,
        'val_transforms': evaluation,
        'test_transforms': evaluation,
    }


# Augmenting pipelines for training, paired with the matching deterministic pipeline for
# val/test.
# NOTE: each null_transform_* Compose object is shared by several of these dicts (and by
# both val and test within a dict), so these pipelines must never be mutated in place.
transforms_mnist = _splits(
    transforms.Compose([
        transforms.RandomCrop(28, padding=1, padding_mode='edge'),
        transforms.ToTensor(),
    ]),
    null_transform_small_image,
)

transforms_small_image = _splits(
    transforms.Compose([
        transforms.RandomCrop(32, padding=3, padding_mode='edge'),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]),
    null_transform_small_image,
)

transforms_large_image = _splits(
    transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]),
    null_transform_large_image,
)

transforms_mal_image = _splits(
    transforms.Compose([
        transforms.CenterCrop(512),
        transforms.RandomCrop(512, padding=20),
        transforms.ToTensor(),
    ]),
    null_transform_mal_image,
)

# Deterministic pipelines for every split (no augmentation at all).
transforms_null_small_image = _splits(null_transform_small_image, null_transform_small_image)
transforms_null_large_image = _splits(null_transform_large_image, null_transform_large_image)
transforms_null_mal_image = _splits(null_transform_mal_image, null_transform_mal_image)


# Per-channel dataset statistics (one list entry per image channel), consumed by
# scalar_std_normalizer in get_datamodule and by get_normalize_layer.
# If any of these change, also update them in the l-infinity nets' dataset.py.
# NOTE(review): presumably measured on the training images after ToTensor (pixel values
# in [0, 1]); confirm against whatever script produced them.

# Three channels. get_datamodule uses these for both 'cifar10_catsdogs' and
# 'cifar10_dogscats'.
_CIFAR10_CATSDOGS_MEAN = [0.4986, 0.4610, 0.4165]
_CIFAR10_CATSDOGS_STDDEV = [0.2542, 0.2482, 0.2534]

# Single channel.
_MNIST_38_MEAN = [0.1457]
_MNIST_38_STDDEV = [0.3215]

# Three channels.
_KAGGLE_CATSDOGS_MEAN = [0.4874, 0.4499, 0.4109]
_KAGGLE_CATSDOGS_STDDEV = [0.2588, 0.2510, 0.2517]

# Single channel.
_MALIMG_MEAN = [0.1857]
_MALIMG_STDDEV = [0.3029]


def scalar_std_normalizer(mean, stddev):
    """Build a ``transforms.Normalize`` with per-channel ``mean`` but one shared std.

    Every channel is divided by the same value, the average of ``stddev``.

    :param mean: per-channel means.
    :param stddev: per-channel standard deviations; only their average is used.
    :returns: a ``transforms.Normalize`` instance.
    """
    shared_std = np.mean(stddev)
    per_channel_std = [shared_std for _ in mean]
    return transforms.Normalize(mean, per_channel_std)


def append_normalize(trans, normalize):
    """Append ``normalize`` to the end of every pipeline in ``trans``.

    The dict ``trans`` is updated in place. Each pipeline is replaced by a shallow copy
    with a new ``transforms`` list instead of being mutated. Several splits (and several
    module-level dicts) share the very same ``Compose`` instance. Appending to it in
    place would add ``normalize`` once per referencing split, and the extra layer would
    also leak into every other dict that shares that pipeline.

    :param trans: dict mapping split name (e.g. ``'train_transforms'``) to a
        ``Compose``-like object exposing a ``transforms`` list.
    :param normalize: the transform to append to each pipeline.
    """
    for key, pipeline in list(trans.items()):
        extended = copy.copy(pipeline)
        # New list: the copy must not share (and thus mutate) the original's list.
        extended.transforms = [*pipeline.transforms, normalize]
        trans[key] = extended


def get_datamodule(name, batch_size=None, no_transforms=False, normalize_scalar_std=False):
    """Construct, prepare and set up the datamodule for dataset ``name``.

    :param name: dataset identifier, one of ``names``.
    :param batch_size: batch size; defaults to 32 for ``'malimg'`` and 64 otherwise.
    :param no_transforms: if True, the training split uses the same deterministic
        pipeline as validation/testing (no augmentation).
    :param normalize_scalar_std: if True, append a ``Normalize`` built from the dataset
        mean and the channel-averaged standard deviation to every split's pipeline.
        Has no effect for ``'circles'``, which is built without transforms.
    :returns: the datamodule after ``prepare_data()`` and ``setup()`` have been called.
    :raises NotImplementedError: if ``name`` is not a known dataset.
    """
    if batch_size is None:
        batch_size = 32 if name == 'malimg' else 64
    fixed_params = {'batch_size': batch_size, 'num_workers': 0, 'shuffle': True}

    # name -> (augmenting transforms, deterministic transforms, mean, stddev, factory).
    # Factories are callables so dataset paths are only resolved for the selected dataset.
    image_datasets = {
        'cifar10_catsdogs': (
            transforms_small_image, transforms_null_small_image,
            _CIFAR10_CATSDOGS_MEAN, _CIFAR10_CATSDOGS_STDDEV,
            lambda **kwargs: CIFAR10SelectDataModule(
                data_dir=dirs.data_path('cifar10'), labels=[3, 5], **kwargs),
        ),
        # Reverse labels for ablation tests
        'cifar10_dogscats': (
            transforms_small_image, transforms_null_small_image,
            _CIFAR10_CATSDOGS_MEAN, _CIFAR10_CATSDOGS_STDDEV,
            lambda **kwargs: CIFAR10SelectDataModule(
                data_dir=dirs.data_path('cifar10'), labels=[5, 3], **kwargs),
        ),
        'mnist_38': (
            transforms_mnist, transforms_null_small_image,
            _MNIST_38_MEAN, _MNIST_38_STDDEV,
            lambda **kwargs: MNISTSelectDataModule(
                data_dir=dirs.data_path('mnist'), labels=[3, 8], **kwargs),
        ),
        'kaggle_catsdogs': (
            transforms_large_image, transforms_null_large_image,
            _KAGGLE_CATSDOGS_MEAN, _KAGGLE_CATSDOGS_STDDEV,
            KaggleCatsDogsDataModule,
        ),
        'malimg': (
            transforms_mal_image, transforms_null_mal_image,
            _MALIMG_MEAN, _MALIMG_STDDEV,
            MalimgDataModule,
        ),
    }

    if name == 'circles':
        datamodule = CirclesDataModule()
    elif name in image_datasets:
        augmented, deterministic, mean, stddev, factory = image_datasets[name]
        trans = deterministic if no_transforms else augmented
        if normalize_scalar_std:
            # Normalize fresh Compose copies, never the module-level pipelines. Those
            # are shared between splits (val and test are the same object) and between
            # datasets. Appending in place would stack several Normalize layers and
            # leak them into later calls and other datasets.
            trans = {split: transforms.Compose(list(pipeline.transforms))
                     for split, pipeline in trans.items()}
            append_normalize(trans, scalar_std_normalizer(mean, stddev))
        datamodule = factory(**fixed_params, **trans)
    else:
        raise NotImplementedError()

    datamodule.prepare_data()
    datamodule.setup()

    return datamodule


# Adapted from https://github.com/locuslab/smoothing

def get_normalize_layer(name: str, mean_only=False, average_stddev=False) -> torch.nn.Module:
    """
    Return the data's normalization layer.
    This is currently unused as the pretrained networks
    have their own normalization layers.

    :param name: dataset identifier; any entry of ``names`` except ``'circles'``.
    :param mean_only: forwarded to ``NormalizeLayer``.
    :param average_stddev: forwarded to ``NormalizeLayer``.
    :raises ValueError: if no statistics are recorded for ``name``.
    """
    stats = {
        'cifar10_catsdogs': (_CIFAR10_CATSDOGS_MEAN, _CIFAR10_CATSDOGS_STDDEV),
        # Same CIFAR-10 images as 'cifar10_catsdogs' with the labels swapped, so the
        # same statistics apply (previously this name was rejected here).
        'cifar10_dogscats': (_CIFAR10_CATSDOGS_MEAN, _CIFAR10_CATSDOGS_STDDEV),
        'mnist_38': (_MNIST_38_MEAN, _MNIST_38_STDDEV),
        'kaggle_catsdogs': (_KAGGLE_CATSDOGS_MEAN, _KAGGLE_CATSDOGS_STDDEV),
        'malimg': (_MALIMG_MEAN, _MALIMG_STDDEV),
    }
    if name not in stats:
        raise ValueError('Invalid dataset selected')
    mean, stddev = stats[name]
    return NormalizeLayer(mean, stddev, mean_only=mean_only, average_stddev=average_stddev)


class NormalizeLayer(torch.nn.Module):
    """Standardize the channels of a batch of images by subtracting the dataset mean
    and dividing by the dataset standard deviation.

    In order to certify radii in original coordinates rather than standardized coordinates, we
    add the Gaussian noise _before_ standardizing, which is why we have standardization be
    the first layer of the classifier rather than as a part of preprocessing as is typical.
    """

    def __init__(self, means: List[float], sds: List[float],
                 mean_only: bool = False, average_stddev: bool = False):
        """
        :param means: the channel means
        :param sds: the channel standard deviations
        :param mean_only: if True, only subtract the mean (every standard deviation is 1).
        :param average_stddev: if True, every channel uses the average of ``sds``.
            Overridden by ``mean_only``.
        """
        super().__init__()
        # Keep the historical behaviour of starting on the GPU, but fall back to the CPU
        # instead of crashing when CUDA is unavailable.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        means_t = torch.tensor(means, device=device)
        sds_t = torch.tensor(sds, device=device)
        if average_stddev:
            sds_t.fill_(sds_t.mean())
        if mean_only:
            sds_t.fill_(1.0)
        # Buffers follow the parent module through .to()/.cuda()/.cpu(); non-persistent so
        # the state_dict is unchanged and existing checkpoints still load.
        self.register_buffer('means', means_t, persistent=False)
        self.register_buffer('sds', sds_t, persistent=False)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``(input - means) / sds`` applied per channel.

        :param input: images of shape ``(batch, channels, height, width)``; an unbatched
            ``(channels, height, width)`` tensor also works.
        """
        # A (C, 1, 1) view broadcasts over the channel axis without materializing a
        # full-size copy of the statistics on every call.
        means = self.means.view(-1, 1, 1)
        sds = self.sds.view(-1, 1, 1)
        return (input - means) / sds
