import os
import numpy as np
from torchvision import transforms
from torchvision.datasets import CIFAR10, CIFAR100, SVHN, MNIST, ImageNet, GTSRB
from torchvision.datasets.folder import ImageFolder
from .cifar_custom import CustomCIFAR10
from .cifar_badnet import BadNetCIFAR10
from .cifar_sig import SIGCIFAR10
from .cifar_trojan import TrojanCIFAR10
from .cifar_blend import BlendCIFAR10
from .cifar_cl import CLCIFAR10
from .cifar_dynamic import DynamicCIFAR10
from .cifar_fc import FCCIFAR10
from .cifar_dfst import DFSTCIFAR10
from .cifar_wanet import WaNetCIFAR10
from .cifar_smooth import SmoothCIFAR10
from .cifar_nashville import NashvilleCIFAR10
from .cifar_badnet_adaptive import BadNetAdaptiveCIFAR10
from .gtsrb_badnet import BadNetGTSRB
from .gtsrb_smooth import SmoothGTSRB
from .issba import ISSBAImageNet, ISSBAImageNetClean
from .mixed_mnist import MIXED_MNIST
from .celeba import CustomCelebA
from .imagenet_badnet import BadNetImageNet, ImageNetSubset

transform_options = {
    "None": {
        "train_transform": None,
        "test_transform": None},
    "NoAug": {
        "train_transform": [transforms.ToTensor()],
        "test_transform": [transforms.ToTensor()]},
    "CMNIST": {
        "train_transform": [
            transforms.ToPILImage(),
            transforms.RandomCrop(28, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ],
        "test_transform": None},

    "CIFAR10": {
        "train_transform": [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ],
        "test_transform": [transforms.ToTensor()]},
    "CIFAR100": {
        "train_transform": [transforms.RandomCrop(32, padding=4),
                            transforms.RandomHorizontalFlip(),
                            transforms.RandomRotation(15),
                            transforms.ToTensor()],
        "test_transform": [transforms.ToTensor()]},
    "GTSRB": {
        "train_transform": [
            transforms.Resize((32, 32)),
            transforms.ToTensor()
        ],
        "test_transform": [
            transforms.Resize((32, 32)),
            transforms.ToTensor()
        ],
    },
    "ImageNet": {
        "train_transform": [transforms.RandomResizedCrop(224),
                            transforms.RandomHorizontalFlip(),
                            transforms.ColorJitter(brightness=0.4,
                                                   contrast=0.4,
                                                   saturation=0.4,
                                                   hue=0.2),
                            transforms.ToTensor()],
        "test_transform": [transforms.Resize((224, 224)),
                           transforms.ToTensor()]},
    "ImageNetNoAug": {
        "train_transform": [transforms.Resize((224, 224)),
                            transforms.ToTensor()],
        "test_transform": [transforms.Resize((224, 224)),
                           transforms.ToTensor()]},
    "ISSBAImageNet": {
        "train_transform": [
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ],
        "test_transform": [
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ]},
    "BadNetImageNet": {
        "train_transform": [
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ],
        "test_transform": [
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ]},
    "CelebA": {
            "train_transform": [transforms.Resize((224, 224)),
                                transforms.ToTensor()],
            "test_transform": [transforms.Resize((224, 224)),
                               transforms.ToTensor()]
        },
    }

dataset_options = {
        "CIFAR10": lambda path, transform, is_test, kwargs:
        CIFAR10(root=path, train=not is_test, download=True,
                transform=transform),
        "CIFAR10NoAug": lambda path, transform, is_test, kwargs:
        CIFAR10(root=path, train=not is_test, download=True,
                transform=transform),
        "CustomCIFAR10": lambda path, transform, is_test, kwargs:
        CustomCIFAR10(root=path, train=not is_test, download=True,
                      transform=transform, **kwargs),
        "BadNetCIFAR10": lambda path, transform, is_test, kwargs:
        BadNetCIFAR10(root=path, train=not is_test, download=True,
                      transform=transform, **kwargs),
        "BadNetAdaptiveCIFAR10": lambda path, transform, is_test, kwargs:
        BadNetAdaptiveCIFAR10(root=path, train=not is_test, download=True,
                              transform=transform, **kwargs),
        "SIGCIFAR10": lambda path, transform, is_test, kwargs:
        SIGCIFAR10(root=path, train=not is_test, download=True,
                   transform=transform, **kwargs),
        "TrojanCIFAR10": lambda path, transform, is_test, kwargs:
        TrojanCIFAR10(root=path, train=not is_test, download=True,
                      transform=transform, **kwargs),
        "BlendCIFAR10": lambda path, transform, is_test, kwargs:
        BlendCIFAR10(root=path, train=not is_test, download=True,
                     transform=transform, **kwargs),
        "CLCIFAR10": lambda path, transform, is_test, kwargs:
        CLCIFAR10(root=path, train=not is_test, download=True,
                  transform=transform, **kwargs),
        "DynamicCIFAR10": lambda path, transform, is_test, kwargs:
        DynamicCIFAR10(root=path, train=not is_test, download=True,
                       transform=transform, **kwargs),
        "FCCIFAR10": lambda path, transform, is_test, kwargs:
        FCCIFAR10(root=path, train=not is_test, download=True,
                  transform=transform, **kwargs),
        "DFSTCIFAR10": lambda path, transform, is_test, kwargs:
        DFSTCIFAR10(root=path, train=not is_test, download=True,
                    transform=transform, **kwargs),
        "WaNetCIFAR10": lambda path, transform, is_test, kwargs:
        WaNetCIFAR10(root=path, train=not is_test, download=True,
                     transform=transform, **kwargs),
        "NashvilleCIFAR10": lambda path, transform, is_test, kwargs:
        NashvilleCIFAR10(root=path, train=not is_test, download=True,
                         transform=transform, **kwargs),
        "SmoothCIFAR10": lambda path, transform, is_test, kwargs:
        SmoothCIFAR10(root=path, train=not is_test, download=True,
                      transform=transform, **kwargs),
        "CIFAR100": lambda path, transform, is_test, kwargs:
        CIFAR100(root=path, train=not is_test, download=True,
                 transform=transform),
        "GTSRB": lambda path, transform, is_test, kwargs:
        GTSRB(root=path, split='test' if is_test else 'train', download=True,
              transform=transform),
        "SmoothGTSRB": lambda path, transform, is_test, kwargs:
        SmoothGTSRB(root=path, split='test' if is_test else 'train', download=True,
                    transform=transform),
        "BadNetGTSRB": lambda path, transform, is_test, kwargs:
        BadNetGTSRB(root=path, split='test' if is_test else 'train', download=True,
                    transform=transform, **kwargs),
        "BadNetImageNet": lambda path, transform, is_test, kwargs:
        BadNetImageNet(root=path, split='val' if is_test else 'train', transform=transform, download=True, ** kwargs),
        "ImageNetSubset": lambda path, transform, is_test, kwargs:
        ImageNetSubset(root=path, split='val' if is_test else 'train', transform=transform, download=True, ** kwargs),
        "SVHN": lambda path, transform, is_test, kwargs:
        SVHN(root=path, split='test' if is_test else 'train', download=True,
             transform=transform),
        "MNIST": lambda path, transform, is_test, kwargs:
        MNIST(root=path, train=not is_test, download=True,
              transform=transform),
        "ImageNet": lambda path, transform, is_test, kwargs:
        ImageNet(root=path, split='val' if is_test else 'train',
                 transform=transform),
        "ImageFolder": lambda path, transform, is_test, kwargs:
        ImageFolder(root=os.path.join(path, 'train') if not is_test else
                    os.path.join(path, 'val'),
                    transform=transform),
        "ISSBAImageNet": lambda path, transform, is_test, kwargs:
        ISSBAImageNet(root=os.path.join(path, 'train') if not is_test else
                      os.path.join(path, 'val'),
                      mode='val' if is_test else 'train',
                      transform=transform,  **kwargs),
        "ISSBAImageNetClean": lambda path, transform, is_test, kwargs:
        ISSBAImageNetClean(root=os.path.join(path, 'train') if not is_test else
                           os.path.join(path, 'val'),
                           mode='val' if is_test else 'train',
                           transform=transform,  **kwargs),
        "MIXED_MNIST": lambda path, transform, is_test, kwargs:
        MIXED_MNIST(root=path, train=not is_test, download=True,
                    transform=transform, **kwargs),
        "CustomCelebA": lambda path, transform, is_test, kwargs:
        CustomCelebA(root=path, split='valid' if is_test else 'train', download=True,
                     transform=transform, **kwargs),
}


def get_classidx(dataset_type, dataset):
    if 'CIFAR100' in dataset_type:
        return [
            np.where(np.array(dataset.targets) == i)[0] for i in range(100)
        ]
    elif 'CIFAR10' in dataset_type:
        return [np.where(np.array(dataset.targets) == i)[0] for i in range(10)]
    elif 'SVHN' in dataset_type:
        return [np.where(np.array(dataset.labels) == i)[0] for i in range(10)]
    else:
        error_msg = 'dataset_type %s not supported' % dataset_type
        raise(error_msg)
