import os
import torch
import platform
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, CIFAR100, MNIST, SVHN, CelebA, ImageFolder, FashionMNIST
# for testing the code
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))


# Dataset key -> torchvision dataset class used to instantiate it.
# 'cifar10H' reuses the CIFAR10 class; 'imagenet' is read through ImageFolder.
cDT = dict(
    cifar10=CIFAR10,
    cifar10H=CIFAR10,
    cifar100=CIFAR100,
    svhn=SVHN,
    mnist=MNIST,
    celeba=CelebA,
    imagenet=ImageFolder,
    fmnist=FashionMNIST,
)
# Dataset keys known to this module. 'fmnist' is listed because cDT, data_path,
# nDT and n_cls all carry an entry for it.
# NOTE(review): 'stl10' is listed here but has no loader class in cDT.
DATASETS = ['cifar10', 'cifar10H', 'cifar100', 'svhn', 'stl10', 'mnist', 'celeba', 'imagenet', 'fmnist']

# Data lives under ./data on Windows and under ~/data everywhere else.
_ON_WINDOWS = platform.system() == "Windows"
_DATA_ROOT = "./data/" if _ON_WINDOWS else "~/data/"

# Root directory per dataset key. Every path ends with '/' because the loader
# appends sub-directories ('train/', 'val/') by plain string concatenation.
data_path = {
    'cifar10': _DATA_ROOT + 'cifar10/',
    'cifar10H': _DATA_ROOT + 'cifar10/',  # shares the CIFAR-10 directory
    'cifar100': _DATA_ROOT + 'cifar100/',
    'svhn': _DATA_ROOT + 'svhn/',
    'mnist': _DATA_ROOT + 'mnist/',
    'stl10': _DATA_ROOT + 'stl10/',
    # The directory name differs per platform.
    'imagenet': _DATA_ROOT + ('imagenet/' if _ON_WINDOWS else 'tiny-imagenet-200/'),
    'celeba': _DATA_ROOT + 'celeba/',
    'fmnist': _DATA_ROOT + 'fmnist/',
}
# Per-dataset counts keyed by dataset name.
# NOTE(review): presumably training-set sizes, with 0 standing in for values
# that were never filled in — confirm before relying on them.
nDT = dict(
    cifar10=50000,
    cifar10H=0,
    cifar100=0,
    svhn=73257,
    mnist=60000,
    fmnist=60000,
    celeba=0,
    imagenet=0,
)

# Number of classes per dataset key.
n_cls = dict(
    cifar10=10,
    cifar10H=10,
    cifar100=100,
    svhn=10,
    mnist=10,
    fmnist=10,
    stl10=10,
    imagenet=200,
    celeba=40,  # 40 attributes, so this is incorrect, but temporarily used
)
# Normalisation statistics per dataset key, stored as [MEAN, STD].
# MEAN and STD are always tuples with one value per channel; single-channel
# datasets use 1-tuples so every entry can be indexed and iterated alike.
# NOTE(review): there is no 'imagenet' entry, so normalize('imagenet') raises
# KeyError — add statistics before using that dataset.
normalization_infos = {
    'cifar10': [(0.4914, 0.4822, 0.4465),
                (0.2023, 0.1994, 0.2010)],  # MEAN, STD
    'cifar100': [(0.5071, 0.4867, 0.4408),
                 (0.2675, 0.2565, 0.2761)],  # MEAN, STD
    'cifar10H': [(0.4914, 0.4822, 0.4465),
                 (0.2023, 0.1994, 0.2010)],  # MEAN, STD
    'svhn': [(0.4376821, 0.4437697, 0.47280442),
             (0.19803012, 0.20101562, 0.19703614)],
    'mnist': [(0.1307,), (0.3081,)],
    # Trailing commas matter: (0.2860) is a plain float, (0.2860,) is a tuple.
    'fmnist': [(0.2860,), (0.3530,)],
    'celeba': [(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)],  # for uncertain dataset?
}
# Input tensor shape per dataset key.
# NOTE(review): presumably (channels, height, width) — confirm against callers.
sh = {
    'cifar10': (3, 32, 32),
    'cifar10H': (3, 32, 32),
    'cifar100': (3, 32, 32),
    'svhn': (3, 32, 32),
    'mnist': (1, 28, 28),
    'fmnist': (1, 28, 28),  # Fashion-MNIST images have the same size as MNIST
    'imagenet': (3, 64, 64),
    'celeba': (3, 64, 64),  # (3, 178, 218) https://paperswithcode.com/dataset/celeba, but resized to 64x64
}

# (low, high) pair per dataset key.
# NOTE(review): presumably the value range model inputs are expected to span — confirm.
input_range = {
    'cifar10': (-1, 1),
    'cifar10H': (-1, 1),
    'cifar100': (-1, 1),
    'svhn': (-1, 1),
    'mnist': (-1, 1),
    'fmnist': (-1, 1),
    'imagenet': (-1, 1),
    'celeba': (-1, 1),
}


def normalize(dt):
    """Build the ``transforms.Normalize`` step for dataset key ``dt``.

    The ``[mean, std]`` pair is looked up in ``normalization_infos``; a key
    without an entry raises ``KeyError``.
    """
    channel_mean, channel_std = normalization_infos[dt]
    return transforms.Normalize(mean=channel_mean, std=channel_std)


def get_transform(args, dt, sigma=0.):
    """Compose a tensor-conversion pipeline that ends with additive Gaussian noise.

    Steps, in order: ``ToTensor``; the normalisation for ``dt`` unless
    ``args.at`` is truthy; then ``x + sigma * randn_like(x)``. The noise step
    is always present, so with the default ``sigma=0.`` it adds zeros.
    """

    def add_noise(x):
        return x + sigma * torch.randn_like(x)

    steps = [transforms.ToTensor()]
    if not args.at:
        steps.append(normalize(dt))
    steps.append(add_noise)
    return transforms.Compose(steps)


def transform(args, dt, type):
    """Build the preprocessing pipeline for one dataset and split.

    Args:
        args: Run configuration. Only read for ``type == 'vl'``:
            ``args.vl_corrupt_ns`` switches on additive Gaussian noise, and
            ``args.x_sigma`` (read each time the pipeline runs) scales it.
        dt: Dataset key; selects the augmentation recipe and the
            normalisation statistics.
        type: ``'tr'`` for the augmented training pipeline, ``'vl'`` for the
            evaluation pipeline.

    Returns:
        A ``transforms.Compose`` ending in the normalisation for ``dt``.

    Raises:
        ValueError: If ``type`` is neither ``'tr'`` nor ``'vl'``.
        KeyError: If ``dt`` has no entry in ``normalization_infos``.
    """
    if type == 'tr':
        if dt in ('mnist', 'fmnist'):
            # No augmentation for the MNIST-style datasets.
            tr_lst = []
        elif dt == 'svhn':
            tr_lst = [transforms.Pad(4, padding_mode="reflect"),
                      transforms.ColorJitter(brightness=63. / 255., saturation=[0.5, 1.5], contrast=[0.2, 1.8]),
                      transforms.RandomCrop(32)]
        elif dt == 'imagenet':
            tr_lst = [
                transforms.RandomResizedCrop(size=(64, 64)),
                transforms.RandomHorizontalFlip(),
            ]
        # TODO: add more augmentation
        elif dt == 'celeba':
            tr_lst = [transforms.Resize(size=(64, 64))]
        else:
            # Default recipe: reflect-pad, random 32x32 crop, random flip.
            tr_lst = [transforms.Pad(4, padding_mode="reflect"),
                      transforms.RandomCrop(32),
                      transforms.RandomHorizontalFlip()]
        tr_lst.append(transforms.ToTensor())
        # Normalisation is applied unconditionally here (no args.at switch).
        tr_lst.append(normalize(dt))
    elif type == 'vl':
        if dt in ('imagenet', 'celeba'):
            tr_lst = [
                transforms.Resize(size=(64, 64)),
                transforms.ToTensor(),
            ]
        else:
            tr_lst = [transforms.ToTensor()]
        # TODO: corrupt with another distribution of noise
        if args.vl_corrupt_ns:
            # Noise is injected after ToTensor and before normalisation.
            tr_lst.append(lambda x: x + args.x_sigma * torch.randn_like(x))
        tr_lst.append(normalize(dt))
    else:
        raise ValueError(
            f"unknown transform type {type!r}; expected 'tr' or 'vl'")
    return transforms.Compose(tr_lst)


def get_dl_tr(args, shuffle=True, no_aug=False):
    """Return ``(dataset size, DataLoader)`` for the training split.

    Shuffles by default. With ``no_aug=True`` the evaluation transform is
    used in place of the training augmentation.
    """
    return _get_dl(args, is_train=True, shuffle=shuffle, no_aug=no_aug)


def get_dl_vl(args, shuffle=False, no_aug=False):
    """Return ``(dataset size, DataLoader)`` for the evaluation split.

    Does not shuffle by default. ``no_aug`` is forwarded unchanged; the
    evaluation transform is selected for this split regardless of its value.
    """
    return _get_dl(args, is_train=False, shuffle=shuffle, no_aug=no_aug)


'''
def get_dl_robust(args, no_aug=False):
    from data.ds import Robust_ds
    ds = Robust_ds(args)
    return DataLoader(
        ds, batch_size=args.bsz, shuffle=False)
'''


def _get_dl(args, is_train, shuffle, no_aug, is_uc=False):
    sdt = args.dataset
    bsz = args.bsz if is_train else args.bsz_vl
    tr_vl = 'tr' if is_train and not no_aug else 'vl'
    # TODO : download=False if the dataset is already downloaded
    # download = False if os.path.exists(data_path[sdt]) else True
    if sdt == 'svhn':
        s = ['test', 'train'][is_train]
        dt = cDT[sdt](root=data_path[sdt], download=True,
                      transform=transform(args, sdt, tr_vl),
                      split=s)
    elif sdt == 'imagenet':
        pth = 'train/' if is_train else 'val/'
        dt = cDT[sdt](root=data_path[sdt] + pth,
                      transform=transform(args, sdt, tr_vl))
    elif sdt == 'mnist' or sdt == 'cifar10' or sdt == 'fmnist':
        pth = 'train/' if is_train else 'val/'
        if args.stability and tr_vl == 'tr' and args.z_init == 'use_db':
            dt = dataset_with_indices(cDT[sdt])(
                root=data_path[sdt]+pth, train=is_train,
                download=False, transform=transform(args, sdt, tr_vl))
        else:
            dt = cDT[sdt](root=data_path[sdt]+pth, train=is_train,
                download=True, transform=transform(args, sdt, tr_vl))
    elif sdt == 'celeba':
        pth = 'train/' if is_train else 'val/'
        split = 'train' if is_train else 'valid'
        # can't download if Google Drive quota is exceeded, need to try again later
        # https://github.com/pytorch/vision/pull/1920 (2262)
        dt = cDT[sdt](root=data_path[sdt], split=split, download=False,
                      transform=transform(args, sdt, tr_vl))
    if args.debug:
        dt.data = dt.data[:64]
    if (sdt == 'cifar10h' or sdt == 'cifar10H') and not is_train:
        human_labels = np.load(os.path.expanduser(os.path.join(data_path[sdt],
                             'cifar10h-probs.npy'))).argmax(-1);
        dt.targets = human_labels
    dl = torch.utils.data.DataLoader(
        dt,
        drop_last=True,
        sampler=None,
        batch_size=bsz,
        num_workers=0,
        pin_memory=True,
        shuffle=shuffle,
    )
    print(f"Data loaded with {len(dt)} {tr_vl} imgs.")
    return len(dt), dl


def cycle(loader):
    """Yield items from ``loader`` endlessly, restarting it after each pass.

    ``loader`` is re-iterated on every pass, so an iterable that reshuffles
    per iteration yields a fresh order each time. If a pass produces no items
    the generator stops instead of spinning forever without yielding.
    """
    while True:
        produced_any = False
        for data in loader:
            produced_any = True
            yield data
        if not produced_any:
            # An empty loader can never yield; looping again would hang.
            return


def dataset_with_indices(cls):
    """
    Derive a subclass of ``cls`` whose items also carry their index.

    The returned class keeps the name of ``cls``. Indexing an instance yields
    ``(data, target, index)`` where ``cls`` itself yields ``(data, target)``.
    """

    def __getitem__(self, index):
        sample, label = cls.__getitem__(self, index)
        return sample, label, index

    overrides = {'__getitem__': __getitem__}
    return type(cls.__name__, (cls,), overrides)


if __name__ == '__main__':
    # Smoke test: build the Fashion-MNIST training and evaluation loaders.
    from collections import namedtuple
    # attack, x_noise, x_sigma, dataset, batch size, validation batch size, debug, balanced sampling
    # 'z_init' is required because _get_dl reads args.z_init whenever
    # 'stability' is true on an augmented training split.
    Args = namedtuple('Args', ['at', 'x_noise', 'x_sigma', 'dataset',
                               'bsz', 'bsz_vl', 'debug', 'bs', 'stability',
                               'z_init', 'vl_corrupt_ns'])
    # z_init=None keeps the regular (downloading) dataset path.
    args = Args(at=False, x_noise=False, x_sigma=0, dataset='fmnist',
                bsz=64, bsz_vl=256, debug=False, bs=False, stability=True,
                z_init=None, vl_corrupt_ns=False)
    # Both helpers return (dataset size, DataLoader).
    n_tr, tr_dl = get_dl_tr(args)
    n_vl, vl_dl = get_dl_vl(args)
