import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F
from .data_path import global_data_path


def create_reader_one(dataset, batch_size, shuffle=True, drop_last=False, num_worker=1):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader`` with pinned memory.

    Args:
        dataset: any map-style dataset accepted by ``DataLoader``.
        batch_size: number of samples per batch.
        shuffle: reshuffle the samples every epoch.
        drop_last: drop the final incomplete batch.
        num_worker: number of worker processes used for loading.

    Returns:
        The configured ``DataLoader`` (``pin_memory=True``).
    """
    from torch.utils.data import DataLoader

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_worker,
        pin_memory=True,
    )


def get_cifar_datasets(data_augment, dataset):
    """Build the CIFAR train/test datasets with the standard 32x32 preprocessing.

    Args:
        data_augment: if true, training images are reflect-padded by 4 pixels,
            randomly cropped back to 32x32 and randomly flipped horizontally
            before normalization; otherwise they are only normalized.
        dataset: dataset name such as ``'cifar10'`` or ``'cifar100'``; it is
            upper-cased and looked up in ``torchvision.datasets``.

    Returns:
        ``(train_set, val_set)`` where ``val_set`` is built with ``train=False``
        and the non-augmenting transform.
    """
    normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                     std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
    if data_augment:
        transform_train = transforms.Compose([
            # Reflect-pad by 4 and crop back to 32 directly on the PIL image.
            # Doing the padding in a ``transforms.Lambda(lambda ...)`` makes the
            # transform unpicklable, which breaks DataLoader workers started
            # with the 'spawn' method, and needs a tensor->PIL round trip.
            transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])
    dataset_cls = datasets.__dict__[dataset.upper()]
    train_set = dataset_cls(global_data_path,
                            train=True,
                            download=True,
                            transform=transform_train)
    val_set = dataset_cls(global_data_path,
                          train=False,
                          transform=transform_test)
    return train_set, val_set


def cifar_reader(data_augment, dataset, batch_size, test_batch_size, num_worker=1):
    """Create train/validation DataLoaders for CIFAR-10 or CIFAR-100.

    The training loader shuffles and drops the last incomplete batch; the
    validation loader does neither.

    Args:
        data_augment: forwarded to ``get_cifar_datasets``.
        dataset: ``'cifar10'`` or ``'cifar100'`` (anything else fails the assert).
        batch_size: training batch size.
        test_batch_size: validation batch size.
        num_worker: number of loader worker processes for both loaders.

    Returns:
        ``(train_loader, val_loader)``.
    """
    assert dataset in ('cifar10', 'cifar100')
    train_set, val_set = get_cifar_datasets(data_augment, dataset)
    return (
        create_reader_one(train_set, batch_size, shuffle=True, drop_last=True,
                          num_worker=num_worker),
        create_reader_one(val_set, test_batch_size, shuffle=False, drop_last=False,
                          num_worker=num_worker),
    )


def cifar_reader_simple(dataset, batch_size, test_batch_size, num_worker=1):
    """Create CIFAR train/validation DataLoaders with a simple [-1, 1] normalization.

    No augmentation is applied. Both splits use ``ToTensor`` followed by
    ``Normalize((0.5,)*3, (0.5,)*3)``.

    Args:
        dataset: dataset name such as ``'cifar10'``; upper-cased and looked up
            in ``torchvision.datasets``.
        batch_size: training batch size (shuffled, last incomplete batch dropped).
        test_batch_size: validation batch size (not shuffled).
        num_worker: number of loader worker processes, used by both loaders.

    Returns:
        ``(train_loader, val_loader)``.
    """
    from torch.utils.data import DataLoader

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    dataset_cls = datasets.__dict__[dataset.upper()]
    kwargs = {'num_workers': num_worker, 'pin_memory': True}
    train_loader = DataLoader(
        dataset_cls(global_data_path, train=True, download=True, transform=transform),
        batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)
    # Same worker/pinning options as the training loader so that
    # ``num_worker`` is honoured for validation too.
    val_loader = DataLoader(
        dataset_cls(global_data_path, train=False, transform=transform),
        batch_size=test_batch_size, shuffle=False, **kwargs)

    return train_loader, val_loader


class TwinsLabelCIFAR10(datasets.CIFAR10):
    """CIFAR-10 dataset that also yields two auxiliary labels per sample.

    ``targets_s`` and ``targets_t`` are initialised as copies of ``targets``.
    Callers can overwrite them to substitute alternative labels. Because they
    are independent lists, each can be modified without affecting the others.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):

        super(TwinsLabelCIFAR10, self).__init__(root, train, transform=transform,
                                                target_transform=target_transform,
                                                download=download)

        # Copy instead of aliasing: assigning ``self.targets`` directly would
        # make all three attributes the same list object, so an in-place write
        # to one label list (e.g. ``targets_s[i] = k``) would silently change
        # the other two, including the original ``targets``.
        self.targets_t = list(self.targets)
        self.targets_s = list(self.targets)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: ``(image, target, target_s, target_t)``. ``image`` and
            ``target`` are what the parent ``CIFAR10.__getitem__`` returns for
            ``index``. ``target_s`` and ``target_t`` are
            ``self.targets_s[index]`` and ``self.targets_t[index]``, each
            passed through ``target_transform`` when one is set.
        """
        img, target = super(TwinsLabelCIFAR10, self).__getitem__(index)

        target_s = self.targets_s[index]
        target_t = self.targets_t[index]
        if self.target_transform is not None:
            target_s = self.target_transform(target_s)
            target_t = self.target_transform(target_t)

        return img, target, target_s, target_t


def get_twinscifar10_datasets(data_augment):
    """Build ``TwinsLabelCIFAR10`` train/test datasets with 32x32 preprocessing.

    Args:
        data_augment: if true, training images are reflect-padded by 4 pixels,
            randomly cropped back to 32x32 and randomly flipped horizontally
            before normalization; otherwise they are only normalized.

    Returns:
        ``(train_set, val_set)`` where ``val_set`` is built with ``train=False``
        and the non-augmenting transform.
    """
    normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                     std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
    if data_augment:
        transform_train = transforms.Compose([
            # Reflect-pad by 4 and crop back to 32 directly on the PIL image.
            # Doing the padding in a ``transforms.Lambda(lambda ...)`` makes the
            # transform unpicklable, which breaks DataLoader workers started
            # with the 'spawn' method, and needs a tensor->PIL round trip.
            transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])
    train_set = TwinsLabelCIFAR10(global_data_path,
                                  train=True,
                                  download=True,
                                  transform=transform_train)
    val_set = TwinsLabelCIFAR10(global_data_path,
                                train=False,
                                transform=transform_test)
    return train_set, val_set


def get_twinscifar10_datasets_resize_to_256(data_augment):
    """Build ``TwinsLabelCIFAR10`` train/test datasets upscaled to 224x224 crops.

    The transforms are taken from https://github.com/yukimasano/self-label.
    Images are resized to 256 and then cropped to 224: randomly (with color
    jitter, random grayscale and horizontal flip) when ``data_augment`` is true,
    otherwise with a center crop. The test split always uses the center crop.

    Returns:
        ``(train_set, val_set)``.
    """
    def _normalize():
        return transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    def _resize_center_crop():
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            _normalize(),
        ])

    if not data_augment:
        transform_train = _resize_center_crop()
    else:
        transform_train = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomResizedCrop(size=224, scale=(0.2, 1.)),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            _normalize(),
        ])
    transform_test = _resize_center_crop()

    train_set = TwinsLabelCIFAR10(global_data_path, train=True, download=True,
                                  transform=transform_train)
    val_set = TwinsLabelCIFAR10(global_data_path, train=False,
                                transform=transform_test)
    return train_set, val_set


def get_cifar_datasets_resize_to_256(data_augment, dataset='cifar10'):
    """Build CIFAR train/test datasets upscaled to 224x224 crops.

    The transforms are taken from https://github.com/yukimasano/self-label.
    Images are resized to 256 and then cropped to 224: randomly (with color
    jitter, random grayscale and horizontal flip) when ``data_augment`` is true,
    otherwise with a center crop. The test split always uses the center crop.

    Args:
        data_augment: enable the random training augmentation.
        dataset: dataset name such as ``'cifar10'``; upper-cased and looked up
            in ``torchvision.datasets``.

    Returns:
        ``(train_set, val_set)``.
    """
    def _normalize():
        return transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    def _resize_center_crop():
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            _normalize(),
        ])

    if not data_augment:
        transform_train = _resize_center_crop()
    else:
        transform_train = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomResizedCrop(size=224, scale=(0.2, 1.)),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            _normalize(),
        ])
    transform_test = _resize_center_crop()

    dataset_cls = datasets.__dict__[dataset.upper()]
    train_set = dataset_cls(global_data_path, train=True, download=True,
                            transform=transform_train)
    val_set = dataset_cls(global_data_path, train=False,
                          transform=transform_test)
    return train_set, val_set


def get_cifar_datasets_with_validation(data_augment, dataset, val_ratio=0.5):
    """Split the CIFAR training set into train/validation subsets, plus the test set.

    The training samples are split in their stored order (no shuffling). The
    first ``len - int(val_ratio * len)`` samples form the training subset and
    the remaining ``int(val_ratio * len)`` samples form the validation subset.
    Both subsets wrap the same underlying dataset object, so they share its
    (possibly augmenting) training transform.

    Args:
        data_augment: forwarded to ``get_cifar_datasets``.
        dataset: dataset name such as ``'cifar10'``, forwarded to
            ``get_cifar_datasets``.
        val_ratio: fraction of the training samples held out for validation.

    Returns:
        ``(train_set, val_set, test_set)``; the first two are
        ``torch.utils.data.Subset`` objects.

    Raises:
        ValueError: if ``val_ratio`` is not in ``[0, 1]``.
    """
    from torch.utils.data import Subset

    if not 0.0 <= val_ratio <= 1.0:
        raise ValueError('val_ratio must be in [0, 1], got {}'.format(val_ratio))

    origin_train_set, test_set = get_cifar_datasets(data_augment, dataset)

    num_samples = len(origin_train_set)
    # ``val_ratio`` is the held-out share; training gets the remainder. For the
    # default 0.5 on an even-sized set the two halves are the same as before.
    num_val = int(val_ratio * num_samples)
    num_train = num_samples - num_val

    train_set = Subset(origin_train_set, list(range(num_train)))
    val_set = Subset(origin_train_set, list(range(num_train, num_samples)))

    return train_set, val_set, test_set


def split_dataset(origin_set, set1_indices, set2_indices):
    """Return two ``torch.utils.data.Subset`` views of ``origin_set``.

    Args:
        origin_set: the dataset to index into.
        set1_indices: indices selecting the first subset.
        set2_indices: indices selecting the second subset.

    Returns:
        ``(set1, set2)``, both sharing ``origin_set`` without copying data.
    """
    from torch.utils.data import Subset

    return Subset(origin_set, set1_indices), Subset(origin_set, set2_indices)