import torch
import torchvision
from torch.utils.data import SubsetRandomSampler, Sampler, Subset, ConcatDataset
from sklearn.model_selection import StratifiedShuffleSplit
from torchvision import transforms
from PIL import Image
from operation import Lighting
import os
import numpy as np
import argparse
import pandas as pd
from ast import literal_eval


def to_categorical(y, num_classes=None):
    """One-hot encode a vector (or array) of integer class labels.

    # Arguments
        y: integer class labels; any array-like accepted by `np.array`.
            A trailing axis of size 1 (on arrays with more than one
            dimension) is dropped before encoding.
        num_classes: width of the one-hot axis. When falsy (None or 0)
            it is inferred as `max(y) + 1`.
    # Returns
        A float32 array whose shape is the (possibly squeezed) shape of
        `y` with an extra last axis of length `num_classes`.
    """
    labels = np.array(y, dtype='int')

    # Work out the leading shape of the result, dropping a trailing
    # singleton axis such as (n, 1) -> (n,).
    leading_shape = labels.shape
    if len(leading_shape) > 1 and leading_shape[-1] == 1:
        leading_shape = tuple(leading_shape[:-1])

    flat_labels = labels.ravel()
    if not num_classes:
        num_classes = np.max(flat_labels) + 1

    count = flat_labels.shape[0]
    one_hot = np.zeros((count, num_classes), dtype=np.float32)
    # One fancy-indexed write sets the hot entry of every row.
    one_hot[np.arange(count), flat_labels] = 1

    return np.reshape(one_hot, leading_shape + (num_classes,))


class CutoutDefault(object):
    """Cutout transform: zero out one randomly placed square of an image tensor.

    The square is centred on a uniformly drawn pixel and clipped to the
    image borders, so the erased area may be smaller than
    `length x length` near an edge.
    """

    def __init__(self, length):
        # Nominal side length (in pixels) of the erased square.
        self.length = length

    def __call__(self, img):
        """Erase a random patch from `img` in place and return `img`.

        `img` is indexed as `img.size(1)` x `img.size(2)` for height and
        width, i.e. it is expected to be a tensor with at least three
        dimensions.
        """
        height, width = img.size(1), img.size(2)
        half = self.length // 2

        # Row is drawn before column; keep this order so seeded runs
        # reproduce the same patches.
        centre_row = np.random.randint(height)
        centre_col = np.random.randint(width)

        # Clamp the patch bounds to the image.
        top = min(max(centre_row - half, 0), height)
        bottom = min(max(centre_row + half, 0), height)
        left = min(max(centre_col - half, 0), width)
        right = min(max(centre_col + half, 0), width)

        keep = np.ones((height, width), np.float32)
        keep[top:bottom, left:right] = 0.

        # Broadcast the 2-D mask over the leading axis and apply in place.
        img *= torch.from_numpy(keep).expand_as(img)
        return img


class SubsetSampler(Sampler):
    """Sampler that yields a fixed collection of indices in their given order."""

    def __init__(self, indices):
        # Indices are stored as given and replayed unchanged on every pass.
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


class GenDataset(torch.utils.data.Dataset):
    """Image dataset described by a `train.csv` file inside a directory.

    The CSV must provide an `id` column (image file stem) and a `label`
    column; each label cell is parsed with `ast.literal_eval`. Images are
    read from `<dir>/<id>.JPG`.
    """

    def __init__(self, dir):
        self.dir = dir
        csv_path = os.path.join(dir, 'train.csv')
        self.df = pd.read_csv(csv_path, sep=',',
                              converters={'label': literal_eval})

    def __getitem__(self, index):
        """Return `(PIL RGB image, float32 label array)` for row `index`."""
        row = self.df.iloc[index]
        image_path = os.path.join(self.dir, f'{row.id}.JPG')
        img = Image.open(image_path).convert('RGB')
        return img, np.array(row.label, dtype=np.float32)

    def __len__(self):
        return len(self.df)


class OnehotDatasetWarper(torch.utils.data.Dataset):
    """Wrap a dataset so that each target is returned one-hot encoded.

    Samples come from `dataset` unchanged except that the target is passed
    through `to_categorical` with `n_class` classes.
    """

    def __init__(self, dataset, n_class):
        super().__init__()
        self.dataset = dataset
        self.n_class = n_class

    def __getitem__(self, index):
        img, label = self.dataset[index]
        return img, to_categorical(label, self.n_class)

    def __len__(self):
        return len(self.dataset)


class AugmentDataset(torch.utils.data.Dataset):
    """Dataset wrapper that applies train-time or eval-time transforms.

    In training mode every image goes through `pre_transforms`; it is then
    either converted with `transforms.ToTensor()` (when `use_autoDA` is
    set) or passed through `after_transforms`. Outside training mode only
    `valid_transforms` is applied, and only if it is not None.
    """

    def __init__(self, dataset, pre_transforms, after_transforms, valid_transforms, training=False, use_autoDA=False):
        super().__init__()
        self.dataset = dataset
        self.pre_transforms = pre_transforms
        self.after_transforms = after_transforms
        self.valid_transforms = valid_transforms
        self.training = training
        self.use_autoDA = use_autoDA

    def __getitem__(self, index):
        img, target = self.dataset[index]

        # Evaluation path: optional single transform, no augmentation.
        if not self.training:
            if self.valid_transforms is not None:
                img = self.valid_transforms(img)
            return img, target

        # Training path: pre-transforms always run first.
        img = self.pre_transforms(img)
        if self.use_autoDA:
            # Only tensor conversion here; after_transforms is skipped.
            return transforms.ToTensor()(img), target
        return self.after_transforms(img), target

    def __len__(self):
        return len(self.dataset)


# Per-channel mean / standard deviation triples passed to
# transforms.Normalize for each supported dataset family.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2023, 0.1994, 0.2010)

_CIFAR100_MEAN = (0.5071, 0.4867, 0.4408)
_CIFAR100_STD = (0.2675, 0.2565, 0.2761)

# Used for tiny_imagenet.
_TINT_IMGNET_MEAN = (0.485, 0.456, 0.406)
_TINT_IMGNET_STD = (0.229, 0.224, 0.225)


def get_num_class(dataset):
    """Return the number of classes for the named dataset.

    Raises KeyError when `dataset` is not one of the known names.
    """
    class_counts = {
        # CIFAR family
        'cifar10': 10,
        'aug_reduced_cifar10': 10,
        'reduced_cifar10': 10,
        'cifar10.1': 10,
        'cifar100': 100,
        'reduced_cifar100': 100,
        # SVHN
        'svhn': 10,
        'reduced_svhn': 10,
        # ImageNet variants
        'imagenet': 1000,
        'reduced_imagenet': 120,
        'tiny_imagenet': 200,
        'imagenet_search': 1000,
        # Other datasets
        'pet': 37,
        'car': 196,
        'flower': 102,
        'caltech': 101,
        'aircraft': 102,
    }
    return class_counts[dataset]


def get_img_size(dataset):
    """Return the image side length (in pixels) registered for the named dataset.

    Raises KeyError when `dataset` is not one of the known names.
    """
    side_lengths = {
        # 32 px datasets
        'cifar10': 32,
        'aug_reduced_cifar10': 32,
        'reduced_cifar10': 32,
        'cifar100': 32,
        'reduced_cifar100': 32,
        'svhn': 32,
        'reduced_svhn': 32,
        # ImageNet variants
        'tiny_imagenet': 64,
        'imagenet_search': 128,
        'imagenet': 224,
        # Other 224 px datasets
        'pet': 224,
        'car': 224,
        'flower': 224,
        'caltech': 224,
        'aircraft': 224,
    }
    return side_lengths[dataset]


def unpickle(file):
    """Load and return the pickled object stored at path `file`.

    The payload is decoded with `encoding='bytes'`, so `str` objects
    written by Python 2 come back as `bytes` (including dict keys).
    """
    import pickle

    # NOTE: pickle can execute arbitrary code while loading; only call this
    # on files from a trusted source.
    with open(file, 'rb') as stream:
        return pickle.load(stream, encoding='bytes')


def get_label_name(dset, dataroot):
    """Return the list of human-readable class names for dataset `dset`.

    CIFAR names are read from the metadata pickles under `dataroot`;
    digit and EMNIST names are built in; any other dataset gets the
    stringified class indices `'0' .. str(num_classes - 1)`.
    """
    digits = [str(d) for d in range(10)]

    # 'cifar100' must be tested before 'cifar10', which is a substring of it.
    if 'cifar100' in dset:
        meta = unpickle(f'{dataroot}/cifar-100-python/meta')
        return [name.decode('utf8') for name in meta[b'fine_label_names']]

    # The combined CIFAR/SVHN sets use the CIFAR-10 label names; they are
    # matched here so they never reach the plain 'svhn' branch below.
    if dset in ['cifar_svhn', 'reduced_cifar_svhn'] or 'cifar10' in dset:
        meta = unpickle(f'{dataroot}/cifar-10-batches-py/batches.meta')
        return [name.decode('utf8') for name in meta[b'label_names']]

    if 'svhn' in dset:
        return digits

    # 'emnist' must be tested before 'mnist', which is a substring of it.
    if 'emnist' in dset:
        upper = list('ABCDEFGHIJKLMNOPQRSTUVWXYZ')
        lower = list('abdefghnqrt')
        return digits + upper + lower

    if 'mnist' in dset:
        return digits

    return [str(idx) for idx in range(get_num_class(dset))]


def _build_transforms(dataset):
    """Return the (train_pre, train_after, test) transform pipelines for `dataset`."""
    if 'cifar100' in dataset:
        geometric = [
            transforms.RandomCrop(32, padding=4, padding_mode='edge'),
            transforms.RandomHorizontalFlip(),
        ]
        mean, std = _CIFAR100_MEAN, _CIFAR100_STD
    elif 'cifar10' in dataset:
        geometric = [
            transforms.RandomCrop(32, padding=4, padding_mode='edge'),
            transforms.RandomHorizontalFlip(),
        ]
        mean, std = _CIFAR10_MEAN, _CIFAR10_STD
    elif 'tiny_imagenet' == dataset:
        geometric = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(64, padding=4, padding_mode='edge'),
        ]
        mean, std = _TINT_IMGNET_MEAN, _TINT_IMGNET_STD
    else:
        raise ValueError('dataset=%s' % dataset)

    def tensor_pipeline():
        # A fresh Compose per call: the train pipeline may be extended with
        # Cutout later and must not share its transform list with the test one.
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    return transforms.Compose(geometric), tensor_pipeline(), tensor_pipeline()


def _load_datasets(dataset, dataroot):
    """Return the untransformed (train, test) datasets for `dataset`."""
    if dataset == 'cifar10':
        total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=None)
        testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=None)
    elif dataset == 'cifar100':
        total_trainset = torchvision.datasets.CIFAR100(root=dataroot, train=True, download=True, transform=None)
        testset = torchvision.datasets.CIFAR100(root=dataroot, train=False, download=True, transform=None)
    elif dataset == 'tiny_imagenet':
        # Expects <dataroot>/tiny-imagenet-200/{train,val} laid out as image folders.
        train_root = os.path.join(dataroot, 'tiny-imagenet-200', 'train')
        validation_root = os.path.join(dataroot, 'tiny-imagenet-200', 'val')
        total_trainset = torchvision.datasets.ImageFolder(train_root, transform=None)
        testset = torchvision.datasets.ImageFolder(validation_root, transform=None)
    else:
        raise ValueError('invalid dataset name=%s' % dataset)
    return total_trainset, testset


def get_dataloaders(dataset, batch, num_workers, dataroot, cutout, cutout_length, train_ratio=1, split_idx=0, target_lb=-1, use_autoDA=False, crop_size=224):
    """Build the train / validation / test DataLoaders for a dataset.

    Args:
        dataset: dataset name; 'cifar10', 'cifar100' and 'tiny_imagenet'
            can be loaded.
        batch: batch size used by all three loaders.
        num_workers: number of worker processes used by all three loaders.
        dataroot: root directory holding (or receiving) the dataset files.
        cutout: when truthy and `cutout_length != 0`, `CutoutDefault` is
            appended to the post-augmentation train pipeline.
        cutout_length: side length handed to `CutoutDefault`.
        train_ratio: fraction of the training set kept for training. When
            below 1.0 the remainder becomes the validation split (stratified
            by label, fixed seed); otherwise the validation loader is empty.
        split_idx: which of the 5 stratified splits to use (0..4). Only
            consulted when `train_ratio < 1.0`.
        target_lb: when >= 0, the train and validation splits are restricted
            to samples with this label. Only consulted when
            `train_ratio < 1.0`.
        use_autoDA: forwarded to the training `AugmentDataset`.
        crop_size: currently unused; kept for interface compatibility.

    Returns:
        `(trainloader, validloader, testloader)`. The validation loader
        serves training-set samples through the test-time transforms.

    Raises:
        ValueError: for an unsupported dataset name, or a `split_idx`
            outside `[0, 5)` when a split is requested.
    """
    transform_train_pre, transform_train_after, transform_test = _build_transforms(dataset)

    if cutout and cutout_length != 0:
        transform_train_after.transforms.append(CutoutDefault(cutout_length))

    total_trainset, testset = _load_datasets(dataset, dataroot)

    if train_ratio < 1.0:
        n_splits = 5
        # Fail clearly instead of exhausting the split generator (or never
        # assigning the indices for a negative split_idx).
        if not 0 <= split_idx < n_splits:
            raise ValueError('split_idx=%s must be in [0, %d)' % (split_idx, n_splits))

        shuffle = False
        splitter = StratifiedShuffleSplit(n_splits=n_splits, test_size=1 - train_ratio, random_state=0)
        splits = splitter.split(list(range(len(total_trainset))), total_trainset.targets)
        # Advance to the requested split; the fixed seed makes it repeatable.
        for _ in range(split_idx + 1):
            train_idx, valid_idx = next(splits)

        if target_lb >= 0:
            targets = total_trainset.targets
            train_idx = [i for i in train_idx if targets[i] == target_lb]
            valid_idx = [i for i in valid_idx if targets[i] == target_lb]

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetSampler(valid_idx)
        n_train, n_valid = len(train_idx), len(valid_idx)
    else:
        # No hold-out: train on everything and leave the validation loader empty.
        train_sampler = None
        shuffle = True
        valid_sampler = SubsetSampler([])
        n_train, n_valid = len(total_trainset), 0

    train_data = AugmentDataset(total_trainset, transform_train_pre, transform_train_after, transform_test, training=True, use_autoDA=use_autoDA)
    valid_data = AugmentDataset(total_trainset, transform_train_pre, transform_train_after, transform_test, False)
    test_data = AugmentDataset(testset, transform_train_pre, transform_train_after, transform_test, False)

    trainloader = torch.utils.data.DataLoader(
        train_data, batch_size=batch, shuffle=shuffle,
        sampler=train_sampler, drop_last=False,
        pin_memory=True, num_workers=num_workers)

    validloader = torch.utils.data.DataLoader(
        valid_data, batch_size=batch, shuffle=False,
        sampler=valid_sampler, drop_last=False,
        pin_memory=True, num_workers=num_workers)

    # Honour the caller's num_workers here too (e.g. 0 to disable worker
    # processes) rather than a fixed worker count.
    testloader = torch.utils.data.DataLoader(
        test_data, batch_size=batch, shuffle=False,
        drop_last=False,
        pin_memory=True, num_workers=num_workers)

    # Report actual sample counts, not batches * batch size, which would
    # over-count whenever the last batch is partial.
    print(f'Dataset: {dataset}')
    print(f'  |total: {len(train_data)}')
    print(f'  |train: {n_train}')
    print(f'  |valid: {n_valid}')
    print(f'  |test: {len(test_data)}')

    return trainloader, validloader, testloader


def main():
    """Script entry point; currently performs no work."""
    return None


# Run main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
