import torch
import random
import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from PIL import Image, ImageEnhance, ImageOps


def load_data(batch_size, workers, dataset, data_target_dir, data_aug, cutout, autoaug):
    """Build the train and test DataLoaders for CIFAR-10, CIFAR-100 or SVHN.

    Args:
        batch_size: Number of samples per batch (same for both loaders).
        workers: Number of DataLoader worker processes.
        dataset: One of ``'cifar10'``, ``'cifar100'`` or ``'svhn'``.
        data_target_dir: Root directory of the dataset; it is downloaded
            there if missing.
        data_aug: If False, both splits only get ``ToTensor`` + ``Normalize``
            and ``cutout`` / ``autoaug`` are ignored.
        cutout: Insert ``Cutout`` into the training pipeline.
        autoaug: Insert the AutoAugment policy (``SVHNPolicy`` for SVHN,
            ``CIFAR10Policy`` for both CIFAR variants) into the training
            pipeline; this also enables ``Cutout``.

    Returns:
        ``(train_dataloader, test_dataloader)``. Only the training loader
        shuffles; both use ``pin_memory=True``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Per-channel (R, G, B) mean and std on the 0-255 scale.
    channel_stats = {
        'cifar10': ([125.3, 123.0, 113.9], [63.0, 62.1, 66.7]),
        'cifar100': ([129.3, 124.1, 112.4], [68.2, 65.4, 70.4]),
        'svhn': ([109.9, 109.7, 113.8], [50.1, 50.6, 50.8]),
    }
    # Side length of the Cutout hole for each dataset.
    cutout_lengths = {'cifar10': 16, 'cifar100': 8, 'svhn': 20}

    if dataset not in channel_stats:
        # An explicit raise (unlike ``assert``) still fires under ``python -O``.
        raise ValueError(f"Unknown dataset: {dataset}")
    mean = [x / 255 for x in channel_stats[dataset][0]]
    std = [x / 255 for x in channel_stats[dataset][1]]

    # The test split is never augmented.
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # PIL-level augmentations, applied before conversion to a tensor.
    train_steps = []
    if data_aug:
        if dataset == 'svhn':
            # For SVHN the AutoAugment policy replaces the random crop.
            if autoaug:
                train_steps.append(SVHNPolicy())
            else:
                train_steps.append(transforms.RandomCrop(32, padding=2, fill=128))
        else:
            train_steps.append(transforms.RandomCrop(32, padding=4, fill=128))
            train_steps.append(transforms.RandomHorizontalFlip())
            if autoaug:
                train_steps.append(CIFAR10Policy())
    train_steps.append(transforms.ToTensor())
    if data_aug and (cutout or autoaug):
        # Cutout operates on tensors, so it sits between ToTensor and Normalize.
        train_steps.append(Cutout(n_holes=1, length=cutout_lengths[dataset]))
    train_steps.append(transforms.Normalize(mean, std))
    train_transform = transforms.Compose(train_steps)

    if dataset == 'svhn':
        train_data = datasets.SVHN(data_target_dir, split='train', transform=train_transform, download=True)
        test_data = datasets.SVHN(data_target_dir, split='test', transform=test_transform, download=True)
    else:
        dataset_cls = datasets.CIFAR10 if dataset == 'cifar10' else datasets.CIFAR100
        train_data = dataset_cls(data_target_dir, train=True, transform=train_transform, download=True)
        test_data = dataset_cls(data_target_dir, train=False, transform=test_transform, download=True)

    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=True)

    return train_dataloader, test_dataloader


class Cutout:
    """Zero out ``n_holes`` random square patches of a (C, H, W) tensor image.

    Each hole is centred on a uniformly random pixel and is ``length`` pixels
    on a side, truncated where it overlaps the image border. The same mask is
    applied to every channel.

    Args:
        n_holes: Number of patches to cut out per image.
        length: Side length of each square patch, in pixels.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """Return a new tensor equal to ``img`` (shape (C, H, W)) with the holes zeroed."""
        _, h, w = img.size()
        mask = torch.ones((h, w), dtype=torch.float32)
        half = self.length // 2
        for _ in range(self.n_holes):
            # Centres are drawn from NumPy's global RNG, y before x.
            y = int(np.random.randint(h))
            x = int(np.random.randint(w))
            # Start half a side before the centre and span a full ``length``,
            # so odd lengths yield length x length holes rather than
            # (length - 1) x (length - 1); even lengths are unaffected.
            y1 = min(max(y - half, 0), h)
            y2 = min(max(y - half + self.length, 0), h)
            x1 = min(max(x - half, 0), w)
            x2 = min(max(x - half + self.length, 0), w)
            mask[y1:y2, x1:x2] = 0.
        # Broadcast the (H, W) mask across channels; ``img`` itself is not modified.
        return img * mask.expand_as(img)


class ImageNetPolicy:
    """AutoAugment policy learned on ImageNet.

    Every call applies one of the 25 stored sub-policies, picked uniformly at
    random. ``fillcolor`` is forwarded to each ``SubPolicy``.
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        # Each row: (p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2).
        sub_policy_specs = (
            (0.4, 'posterize', 8, 0.6, 'rotate', 9),
            (0.6, 'solarize', 5, 0.6, 'autocontrast', 5),
            (0.8, 'equalize', 8, 0.6, 'equalize', 3),
            (0.6, 'posterize', 7, 0.6, 'posterize', 6),
            (0.4, 'equalize', 7, 0.2, 'solarize', 4),

            (0.4, 'equalize', 4, 0.8, 'rotate', 8),
            (0.6, 'solarize', 3, 0.6, 'equalize', 7),
            (0.8, 'posterize', 5, 1.0, 'equalize', 2),
            (0.2, 'rotate', 3, 0.6, 'solarize', 8),
            (0.6, 'equalize', 8, 0.4, 'posterize', 6),

            (0.8, 'rotate', 8, 0.4, 'color', 0),
            (0.4, 'rotate', 9, 0.6, 'equalize', 2),
            (0.0, 'equalize', 7, 0.8, 'equalize', 8),
            (0.6, 'invert', 4, 1.0, 'equalize', 8),
            (0.6, 'color', 4, 1.0, 'contrast', 8),

            (0.8, 'rotate', 8, 1.0, 'color', 2),
            (0.8, 'color', 8, 0.8, 'solarize', 7),
            (0.4, 'sharpness', 7, 0.6, 'invert', 8),
            (0.6, 'shearX', 5, 1.0, 'equalize', 9),
            (0.4, 'color', 0, 0.6, 'equalize', 3),

            (0.4, 'equalize', 7, 0.2, 'solarize', 4),
            (0.6, 'solarize', 5, 0.6, 'autocontrast', 5),
            (0.6, 'invert', 4, 1.0, 'equalize', 8),
            (0.6, 'color', 4, 1.0, 'contrast', 8),
            (0.8, 'equalize', 8, 0.6, 'equalize', 3),
        )
        self.policies = [SubPolicy(*spec, fillcolor) for spec in sub_policy_specs]

    def __call__(self, img):
        chosen = random.choice(self.policies)
        return chosen(img)

    def __repr__(self):
        return 'AutoAugment ImageNet Policy'


class CIFAR10Policy:
    """AutoAugment policy learned on CIFAR-10.

    Every call applies one of the 25 stored sub-policies, picked uniformly at
    random. ``fillcolor`` is forwarded to each ``SubPolicy``.
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        # Each row: (p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2).
        sub_policy_specs = (
            (0.1, 'invert', 7, 0.2, 'contrast', 6),
            (0.7, 'rotate', 2, 0.3, 'translateX', 9),
            (0.8, 'sharpness', 1, 0.9, 'sharpness', 3),
            (0.5, 'shearY', 8, 0.7, 'translateY', 9),
            (0.5, 'autocontrast', 8, 0.9, 'equalize', 2),

            (0.2, 'shearY', 7, 0.3, 'posterize', 7),
            (0.4, 'color', 3, 0.6, 'brightness', 7),
            (0.3, 'sharpness', 9, 0.7, 'brightness', 9),
            (0.6, 'equalize', 5, 0.5, 'equalize', 1),
            (0.6, 'contrast', 7, 0.6, 'sharpness', 5),

            (0.7, 'color', 7, 0.5, 'translateX', 8),
            (0.3, 'equalize', 7, 0.4, 'autocontrast', 8),
            (0.4, 'translateY', 3, 0.2, 'sharpness', 6),
            (0.9, 'brightness', 6, 0.2, 'color', 8),
            (0.5, 'solarize', 2, 0.0, 'invert', 3),

            (0.2, 'equalize', 0, 0.6, 'autocontrast', 0),
            (0.2, 'equalize', 8, 0.6, 'equalize', 4),
            (0.9, 'color', 9, 0.6, 'equalize', 6),
            (0.8, 'autocontrast', 4, 0.2, 'solarize', 8),
            (0.1, 'brightness', 3, 0.7, 'color', 0),

            (0.4, 'solarize', 5, 0.9, 'autocontrast', 3),
            (0.9, 'translateY', 9, 0.7, 'translateY', 9),
            (0.9, 'autocontrast', 2, 0.8, 'solarize', 3),
            (0.8, 'equalize', 8, 0.1, 'invert', 3),
            (0.7, 'translateY', 9, 0.9, 'autocontrast', 1),
        )
        self.policies = [SubPolicy(*spec, fillcolor) for spec in sub_policy_specs]

    def __call__(self, img):
        chosen = random.choice(self.policies)
        return chosen(img)

    def __repr__(self):
        return 'AutoAugment CIFAR10 Policy'


class SVHNPolicy:
    """AutoAugment policy learned on SVHN.

    Every call applies one of the 25 stored sub-policies, picked uniformly at
    random. ``fillcolor`` is forwarded to each ``SubPolicy``.
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        # Each row: (p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2).
        sub_policy_specs = (
            (0.9, 'shearX', 4, 0.2, 'invert', 3),
            (0.9, 'shearY', 8, 0.7, 'invert', 5),
            (0.6, 'equalize', 5, 0.6, 'solarize', 6),
            (0.9, 'invert', 3, 0.6, 'equalize', 3),
            (0.6, 'equalize', 1, 0.9, 'rotate', 3),

            (0.9, 'shearX', 4, 0.8, 'autocontrast', 3),
            (0.9, 'shearY', 8, 0.4, 'invert', 5),
            (0.9, 'shearY', 5, 0.2, 'solarize', 6),
            (0.9, 'invert', 6, 0.8, 'autocontrast', 1),
            (0.6, 'equalize', 3, 0.9, 'rotate', 3),

            (0.9, 'shearX', 4, 0.3, 'solarize', 3),
            (0.8, 'shearY', 8, 0.7, 'invert', 4),
            (0.9, 'equalize', 5, 0.6, 'translateY', 6),
            (0.9, 'invert', 4, 0.6, 'equalize', 7),
            (0.3, 'contrast', 3, 0.8, 'rotate', 4),

            (0.8, 'invert', 5, 0.0, 'translateY', 2),
            (0.7, 'shearY', 6, 0.4, 'solarize', 8),
            (0.6, 'invert', 4, 0.8, 'rotate', 4),
            (0.3, 'shearY', 7, 0.9, 'translateX', 3),
            (0.1, 'shearX', 6, 0.6, 'invert', 5),

            (0.7, 'solarize', 2, 0.6, 'translateY', 7),
            (0.8, 'shearY', 4, 0.8, 'invert', 8),
            (0.7, 'shearX', 9, 0.8, 'translateY', 3),
            (0.8, 'shearY', 5, 0.7, 'autocontrast', 3),
            (0.7, 'shearX', 2, 0.1, 'invert', 5),
        )
        self.policies = [SubPolicy(*spec, fillcolor) for spec in sub_policy_specs]

    def __call__(self, img):
        chosen = random.choice(self.policies)
        return chosen(img)

    def __repr__(self):
        return 'AutoAugment SVHN Policy'


class SubPolicy:
    """One AutoAugment sub-policy: two PIL image operations, each applied with its own probability.

    Args:
        p1, p2: Probability of applying the first / second operation.
        operation1, operation2: Operation names, i.e. keys of the tables in
            ``__init__`` (``'shearX'``, ``'rotate'``, ``'posterize'``, ...).
        magnitude_idx1, magnitude_idx2: Index into the operation's
            10-entry magnitude table.
        fillcolor: Colour for pixels uncovered by shear, translate and rotate.

    The shear, translate, rotate, color, contrast, sharpness and brightness
    operations pick a random sign (direction) every time they are applied.

    Raises:
        KeyError: If an operation name is unknown.
        IndexError: If a magnitude index is out of range for the 10-entry table.
    """

    def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
        ranges = {
            'shearX': np.linspace(0, 0.3, 10),
            'shearY': np.linspace(0, 0.3, 10),
            'translateX': np.linspace(0, 150 / 331, 10),
            'translateY': np.linspace(0, 150 / 331, 10),
            'rotate': np.linspace(0, 30, 10),
            'color': np.linspace(0.0, 0.9, 10),
            # Bits kept, 8 down to 4. ``np.int`` was removed in NumPy 1.24;
            # the builtin ``int`` yields the same integer dtype.
            'posterize': np.round(np.linspace(8, 4, 10), 0).astype(int),
            'solarize': np.linspace(256, 0, 10),
            'contrast': np.linspace(0.0, 0.9, 10),
            'sharpness': np.linspace(0.0, 0.9, 10),
            'brightness': np.linspace(0.0, 0.9, 10),
            # Parameter-free operations; the magnitude is ignored.
            'autocontrast': [0] * 10,
            'equalize': [0] * 10,
            'invert': [0] * 10
        }

        # RGBA background for rotate, built from ``fillcolor`` so rotation
        # fills with the same colour as shear/translate. The fourth (alpha)
        # entry defaults to 128 when ``fillcolor`` has only three channels.
        fill = (fillcolor,) * 3 if isinstance(fillcolor, int) else tuple(fillcolor)
        rgba_fill = (fill + (128,) * 4)[:4]

        def rotate_with_fill(img, magnitude):
            # Rotating in RGBA leaves alpha 0 in the uncovered corners; using
            # the rotated image as its own composite mask lets the background
            # show through exactly there.
            rot = img.convert('RGBA').rotate(magnitude)
            background = Image.new('RGBA', rot.size, rgba_fill)
            return Image.composite(rot, background, rot).convert(img.mode)

        func = {
            'shearX': lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            'shearY': lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            'translateX': lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
                fillcolor=fillcolor),
            'translateY': lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
                fillcolor=fillcolor),
            # Random direction, like the other geometric operations above.
            'rotate': lambda img, magnitude: rotate_with_fill(img, magnitude * random.choice([-1, 1])),
            'color': lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
            'posterize': lambda img, magnitude: ImageOps.posterize(img, magnitude),
            'solarize': lambda img, magnitude: ImageOps.solarize(img, magnitude),
            'contrast': lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            'sharpness': lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            'brightness': lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            'autocontrast': lambda img, magnitude: ImageOps.autocontrast(img),
            'equalize': lambda img, magnitude: ImageOps.equalize(img),
            'invert': lambda img, magnitude: ImageOps.invert(img)
        }

        self.p1 = p1
        self.operation1 = func[operation1]
        self.magnitude1 = ranges[operation1][magnitude_idx1]
        self.p2 = p2
        self.operation2 = func[operation2]
        self.magnitude2 = ranges[operation2][magnitude_idx2]

    def __call__(self, img):
        """Apply each operation independently with its probability; return the (possibly new) image."""
        if random.random() < self.p1:
            img = self.operation1(img, self.magnitude1)
        if random.random() < self.p2:
            img = self.operation2(img, self.magnitude2)
        return img
