from torchvision.transforms import transforms
from PIL import ImageFilter,ImageOps
import random
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
from moco.augmentations import RandAugment2
import torch

class GaussianBlur_imagenet(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    Each call draws a blur radius uniformly from ``[sigma[0], sigma[1]]`` and
    applies PIL's ``ImageFilter.GaussianBlur`` with that radius.

    Args:
        sigma: ``(low, high)`` bounds for the sampled radius. Defaults to
            ``(0.1, 2.0)``. The default is a tuple so that it is not a
            mutable object shared by every instance of the class.
    """

    def __init__(self, sigma=(.1, 2.)):
        self.sigma = sigma

    def __call__(self, x):
        """Return ``x`` blurred with a freshly sampled radius.

        ``x`` must provide a PIL-style ``filter`` method (i.e. a PIL image).
        """
        radius = random.uniform(self.sigma[0], self.sigma[1])
        return x.filter(ImageFilter.GaussianBlur(radius=radius))

class solarize2(object):
    """Solarize a PIL image with a fixed threshold.

    Thin wrapper around ``PIL.ImageOps.solarize``, which inverts pixel values
    above ``threshold`` and leaves darker pixels unchanged.

    Args:
        threshold: Pixel-value threshold forwarded to ``ImageOps.solarize``.
            Defaults to 128.
    """

    def __init__(self, threshold=128):
        self.threshold = threshold

    def __call__(self, x):
        """Return a solarized version of the PIL image ``x``."""
        return ImageOps.solarize(x, threshold=self.threshold)


def get_moco2_cifar_transforms(input_shape, mean, std):
    """Return the two uncomposed augmentation lists for the query/key views.

    Both lists contain the same MoCo v2-style chain: random resized crop
    (scale 0.2-1.0), colour jitter (applied with p=0.8), random grayscale
    (p=0.2), Gaussian blur (p=0.5), horizontal flip, ``ToTensor`` and
    normalization with ``mean``/``std``. The first list additionally starts
    with ``RandAugment2()``.

    Args:
        input_shape: Output size passed to ``RandomResizedCrop``.
        mean: Per-channel means for ``Normalize``.
        std: Per-channel standard deviations for ``Normalize``.

    Returns:
        ``(augmentation1, augmentation2)`` as plain lists; callers are
        expected to wrap them in ``transforms.Compose``.
    """
    normalize = transforms.Normalize(mean=mean, std=std)

    def _common_chain():
        # New transform objects on every call; only ``normalize`` is shared
        # between the two lists.
        return [
            transforms.RandomResizedCrop(input_shape, scale=(0.2, 1.)),
            transforms.RandomApply(
                [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)],  # not strengthened
                p=0.8,
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([GaussianBlur_imagenet([.1, 2.])], p=0.5),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]

    # RandAugment2() is built before the shared chain, as in the original order.
    query_augs = [RandAugment2()] + _common_chain()
    key_augs = _common_chain()
    return query_augs, key_augs


def get_test_transforms(mean, std):
    """Return deterministic ``(train, test)`` transforms: ``ToTensor`` + ``Normalize``.

    No random augmentation is applied. The two returned ``Compose`` objects
    are distinct but perform identical processing and share one
    ``Normalize`` instance.

    Args:
        mean: Per-channel means for ``Normalize``.
        std: Per-channel standard deviations for ``Normalize``.
    """
    normalize = transforms.Normalize(mean=mean, std=std)

    def _to_normalized_tensor():
        return transforms.Compose([transforms.ToTensor(), normalize])

    return _to_normalized_tensor(), _to_normalized_tensor()

class TwoCropsTransform:
    """Produce a ``[query, key]`` pair from one input using two transforms.

    ``base_transform1`` builds the query view and ``base_transform2`` the key
    view. Both are applied to the same input, in that order.
    """

    def __init__(self, base_transform1, base_transform2):
        self.base_transform1 = base_transform1
        self.base_transform2 = base_transform2

    def __call__(self, x):
        """Return ``[base_transform1(x), base_transform2(x)]``."""
        pipelines = (self.base_transform1, self.base_transform2)
        return [pipeline(x) for pipeline in pipelines]

def load_data(data_path, batch_size, num_workers=16):
    """Build CIFAR-10 loaders for two-view contrastive pre-training.

    Args:
        data_path: Root directory for ``datasets.CIFAR10``; the dataset is
            downloaded there if it is missing.
        batch_size: Batch size used by all three loaders.
        num_workers: Worker processes per ``DataLoader``. Defaults to 16, the
            value that used to be hard-coded; lower it on machines with
            fewer CPU cores.

    Returns:
        ``(train_loader, memory_loader, test_loader, input_shape_all)``:

        - ``train_loader``: training split, shuffled, ``drop_last=True``.
          Each sample's image is the ``[query, key]`` list produced by
          ``TwoCropsTransform``.
        - ``memory_loader``: training split, unshuffled, using the
          deterministic test transform.
        - ``test_loader``: test split, unshuffled, deterministic transform.
        - ``input_shape_all``: ``[32, 32, 3]``.
    """
    # Per-channel normalization statistics used for every split below.
    mean, std = [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
    input_shape, input_shape_all = 32, [32, 32, 3]

    augmentation1, augmentation2 = get_moco2_cifar_transforms(input_shape, mean, std)
    _, augmentation_test = get_test_transforms(mean, std)
    pair_transform = TwoCropsTransform(transforms.Compose(augmentation1),
                                       transforms.Compose(augmentation2))

    train_dataset = datasets.CIFAR10(data_path, train=True, download=True,
                                     transform=pair_transform)
    memory_dataset = datasets.CIFAR10(data_path, train=True, download=True, transform=augmentation_test)
    test_dataset = datasets.CIFAR10(data_path, train=False, download=True, transform=augmentation_test)
    print('-----------------------------------------------------------------------')
    print('data size %d' % (len(train_dataset)))
    print('memory data size %d' % (len(memory_dataset)))
    print('test datas size %d' % (len(test_dataset)))
    print('-----------------------------------------------------------------------')

    # No custom sampler is used, so the training loader does its own shuffling.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True, drop_last=True)

    memory_loader = torch.utils.data.DataLoader(
        memory_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True)

    return train_loader, memory_loader, test_loader, input_shape_all


def load_linear_train_data(data_path, batch_size, num_workers=16):
    """Build CIFAR-10 train/test loaders with deterministic preprocessing.

    Both splits use only ``ToTensor`` + ``Normalize`` (no random
    augmentation).

    Args:
        data_path: Root directory for ``datasets.CIFAR10``; the dataset is
            downloaded there if it is missing.
        batch_size: Batch size for both loaders.
        num_workers: Worker processes per ``DataLoader``. Defaults to 16, the
            value that used to be hard-coded.

    Returns:
        ``(train_loader, test_loader)``. The train loader shuffles and drops
        the last incomplete batch; the test loader does neither.
    """
    mean, std = [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]

    augmentation_train, augmentation_test = get_test_transforms(mean, std)

    train_dataset = datasets.CIFAR10(data_path, train=True, download=True,
                                     transform=augmentation_train)
    test_dataset = datasets.CIFAR10(data_path, train=False, download=True,
                                    transform=augmentation_test)
    # No custom sampler is used, so the training loader does its own shuffling.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True, drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True)
    return train_loader, test_loader
