import random
import torch
import os
from PIL import Image, ImageOps, ImageFilter
from torchvision import transforms
from torchvision.datasets import CIFAR10, CIFAR100
import torchvision.datasets as datasets
import torch.utils.data as data
from glob import glob

class GaussianBlur:
    """Blur a PIL image with a Gaussian kernel whose radius is sampled per call.

    Args:
        sigma: ``(low, high)`` bounds; each call draws the blur radius with
            ``random.uniform(low, high)``.
    """

    def __init__(self, sigma=(0.1, 2.0)):
        # Immutable tuple default: a list default would be one object shared
        # by every instance built without an explicit ``sigma``.
        self.sigma = sigma

    def __call__(self, x: "Image.Image") -> "Image.Image":
        """Return ``x`` blurred with a radius drawn from ``self.sigma``.

        ``x`` must provide PIL's ``Image.filter`` method (e.g. a
        ``PIL.Image.Image``); torch tensors have no such method.
        """
        radius = random.uniform(self.sigma[0], self.sigma[1])
        return x.filter(ImageFilter.GaussianBlur(radius=radius))

class Solarization:
    """Solarize a PIL image with ``PIL.ImageOps.solarize``.

    No threshold is passed, so ``ImageOps.solarize``'s default threshold is
    used.
    """

    def __call__(self, img: "Image.Image") -> "Image.Image":
        """Return the solarized copy of ``img`` produced by Pillow."""
        return ImageOps.solarize(img)

## For CIFAR-10 and CIFAR-100 (32x32 images)
class CIFARPairTransform:
    """Return two independently augmented views of one CIFAR-sized image.

    Args:
        train_transform: When truthy, each view comes from a random
            augmentation pipeline:
            - random resized crop to 32x32;
            - horizontal flip, color jitter and grayscale;
            - solarization, for the second view only.
            When falsy, both views are only converted to a tensor and
            normalized.

    Calling an instance on an image returns the tuple ``(view_1, view_2)``.
    """

    # Per-channel normalization statistics shared by every pipeline here.
    _MEAN = (0.4914, 0.4822, 0.4465)
    _STD = (0.247, 0.243, 0.261)

    def __init__(self, train_transform=True):
        # Plain truthiness rather than ``is True``, so values such as ``1`` or
        # ``numpy.bool_(True)`` also select the training pipelines instead of
        # silently falling through to the evaluation ones.
        if train_transform:
            # The two training views differ only in solarization probability.
            self.transform_1 = self._train_pipeline(solarization_p=0.0)
            self.transform_2 = self._train_pipeline(solarization_p=0.2)
        else:
            self.transform_1 = self._eval_pipeline()
            self.transform_2 = self._eval_pipeline()

    @classmethod
    def _train_pipeline(cls, solarization_p):
        """Build one random augmentation pipeline ending in normalization."""
        return transforms.Compose([
            transforms.RandomResizedCrop(
                (32, 32),
                scale=(0.08, 1.0),
                # Enum equivalent of PIL's ``Image.BICUBIC``.
                interpolation=transforms.InterpolationMode.BICUBIC,
            ),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                        saturation=0.2, hue=0.1)],
                p=0.8,
            ),
            transforms.RandomGrayscale(p=0.2),
            # p=0.0: blur is effectively disabled for both views.
            transforms.RandomApply([GaussianBlur()], p=0.0),
            transforms.RandomApply([Solarization()], p=solarization_p),
            transforms.ToTensor(),
            transforms.Normalize(cls._MEAN, cls._STD),
        ])

    @classmethod
    def _eval_pipeline(cls):
        """Build the deterministic tensor-conversion + normalization pipeline."""
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cls._MEAN, cls._STD),
        ])

    def __call__(self, x):
        """Apply both pipelines to ``x`` and return ``(view_1, view_2)``."""
        return self.transform_1(x), self.transform_2(x)


### Adapted from Zero-CL (https://github.com/Sherrylone/Zero-CL); performs slightly better than CIFARSingleTransform2 for our proposed model during linear evaluation.
class CIFARSingleTransform:
    """Single-view CIFAR transform: pad-and-crop plus flip when training.

    Args:
        train_transform: When truthy, prepend ``RandomCrop(32, padding=4)``
            and ``RandomHorizontalFlip()``; both modes then convert to a
            tensor and normalize.
    """

    # Per-channel normalization statistics.
    _MEAN = (0.4914, 0.4822, 0.4465)
    _STD = (0.247, 0.243, 0.261)

    def __init__(self, train_transform=True):
        steps = []
        # Plain truthiness rather than ``is True`` so e.g. ``1`` or
        # ``numpy.bool_(True)`` also enable augmentation.
        if train_transform:
            steps += [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
            ]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize(self._MEAN, self._STD),
        ]
        self.transform = transforms.Compose(steps)

    def __call__(self, x):
        """Return ``x`` passed through the configured pipeline."""
        return self.transform(x)

class CIFARSingleTransform2:
    """Single-view CIFAR transform: random resized crop plus flip when training.

    Args:
        train_transform: When truthy, prepend
            ``RandomResizedCrop((32, 32), scale=(0.08, 1.0))`` (default
            interpolation) and ``RandomHorizontalFlip(p=0.5)``; both modes
            then convert to a tensor and normalize.
    """

    # Per-channel normalization statistics.
    _MEAN = (0.4914, 0.4822, 0.4465)
    _STD = (0.247, 0.243, 0.261)

    def __init__(self, train_transform=True):
        steps = []
        # Plain truthiness rather than ``is True`` so e.g. ``1`` or
        # ``numpy.bool_(True)`` also enable augmentation.
        if train_transform:
            steps += [
                transforms.RandomResizedCrop((32, 32), scale=(0.08, 1.0)),
                transforms.RandomHorizontalFlip(p=0.5),
            ]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize(self._MEAN, self._STD),
        ]
        self.transform = transforms.Compose(steps)

    def __call__(self, x):
        """Return ``x`` passed through the configured pipeline."""
        return self.transform(x)
