from typing import Any, Callable, List, Optional, Sequence, Type, Union
import random
import torch
import torchvision
import numpy as np
from PIL import Image, ImageFilter, ImageOps
from torchvision import transforms

class GaussianBlur:
    """Random Gaussian blur for PIL images, as a callable object.

    The blur radius is re-sampled uniformly from ``sigma`` on every call, so the
    same instance produces differently blurred outputs across calls.
    """

    def __init__(self, sigma: Optional[Sequence[float]] = None):
        """Gaussian blur as a callable object.

        Args:
            sigma (Sequence[float], optional): ``(low, high)`` range from which the
                radius of the gaussian blur filter is sampled. Only the first two
                elements are read. Defaults to [0.1, 2.0].
        """

        if sigma is None:
            sigma = [0.1, 2.0]

        self.sigma = sigma

    def __call__(self, x: "Image.Image") -> "Image.Image":
        """Applies gaussian blur to an input image.

        Args:
            x (PIL.Image.Image): the input image. It must provide PIL's ``.filter``
                method, so tensors are NOT accepted; apply this before ``ToTensor``.

        Returns:
            PIL.Image.Image: the blurred image.
        """

        radius = random.uniform(self.sigma[0], self.sigma[1])
        return x.filter(ImageFilter.GaussianBlur(radius=radius))


class Solarization:
    """Solarization as a callable object."""

    def __init__(self, threshold: int = 128):
        """Store the solarization threshold.

        Args:
            threshold (int, optional): threshold forwarded to ``ImageOps.solarize``.
                Defaults to 128, which matches Pillow's own default, so
                ``Solarization()`` behaves exactly like a bare
                ``ImageOps.solarize(img)``.
        """

        self.threshold = threshold

    def __call__(self, img: "Image.Image") -> "Image.Image":
        """Applies solarization to an input image.

        Args:
            img (PIL.Image.Image): an image in the PIL.Image format.

        Returns:
            PIL.Image.Image: a solarized image.
        """

        return ImageOps.solarize(img, threshold=self.threshold)

class FullTransformPipeline:
    """Applies several multi-view transforms and concatenates their outputs."""

    def __init__(self, transforms: Sequence[Callable]) -> None:
        """Store the transforms to run.

        Args:
            transforms (Sequence[Callable]): transforms applied in order. Each one
                is called with the input and must return an iterable of views
                (for example the list returned by ``BaseTransform.__call__``).
        """
        self.transforms = transforms

    def __call__(self, x: "Image.Image") -> List[torch.Tensor]:
        """Applies every transform to ``x`` and flattens the produced views.

        Args:
            x (PIL.Image.Image): an image in the PIL.Image format.

        Returns:
            List[torch.Tensor]: all views from the first transform, followed by all
            views from the second one, and so on.
        """

        out: List[torch.Tensor] = []
        for transform in self.transforms:
            # Each transform yields several views; extend (not append) keeps the
            # result a flat list.
            out.extend(transform(x))
        return out

    def __repr__(self) -> str:
        return "\n".join(str(transform) for transform in self.transforms)


class BaseTransform:
    """Adds callable base class to implement different transformation pipelines.

    Subclasses must set ``self.selfsup_transform`` and ``self.selfsup_transform2``.
    Calling the instance applies both pipelines to the same input and returns
    the two results as a list.
    """

    def __call__(self, x: "Image.Image") -> List[torch.Tensor]:
        """Return ``[selfsup_transform(x), selfsup_transform2(x)]``.

        Args:
            x (PIL.Image.Image): an image in the PIL.Image format.

        Returns:
            List[torch.Tensor]: the two views, in pipeline order.
        """
        return [self.selfsup_transform(x), self.selfsup_transform2(x)]

    def __repr__(self) -> str:
        # NOTE: only the first view's pipeline is shown.
        return str(self.selfsup_transform)

class AblationTransform:
    """Adds callable base class to implement different transformation pipelines.

    Three-view variant of ``BaseTransform``. Subclasses must set
    ``self.selfsup_transform``, ``self.selfsup_transform2`` and
    ``self.simple_tranform`` (spelling kept as-is because subclasses assign that
    exact attribute name).
    """

    def __call__(self, x: "Image.Image") -> List[torch.Tensor]:
        """Return the two self-supervised views followed by the simple view.

        Args:
            x (PIL.Image.Image): an image in the PIL.Image format.

        Returns:
            List[torch.Tensor]: ``[selfsup_transform(x), selfsup_transform2(x),
            simple_tranform(x)]``.
        """
        return [self.selfsup_transform(x), self.selfsup_transform2(x), self.simple_tranform(x)]

    def __repr__(self) -> str:
        # NOTE: only the first view's pipeline is shown.
        return str(self.selfsup_transform)



class CifarTransform(BaseTransform):
    def __init__(
        self,
        cifar: str,
        brightness: float,
        contrast: float,
        saturation: float,
        hue: float,
        color_jitter_prob: float = 0.8,
        gray_scale_prob: float = 0.2,
        horizontal_flip_prob: float = 0.5,
        gaussian_prob: float = 0.0,
        solarization_prob: float = 0.0,
        min_scale: float = 0.08,
        max_scale: float = 1.0,
        crop_size: int = 32,
        adv_train: bool = False,
        solarization_prob2: float = 0.2,
    ):
        """Class that applies Cifar10/Cifar100 transformations.

        Builds two augmentation pipelines with identical crop, color-jitter,
        grayscale, blur and flip settings. They differ only in their solarization
        probability (``solarization_prob`` vs ``solarization_prob2``).

        Args:
            cifar (str): type of cifar, either cifar10 or cifar100. Accepted for
                interface compatibility; not used by this class.
            brightness (float): sampled uniformly in [max(0, 1 - brightness), 1 + brightness].
            contrast (float): sampled uniformly in [max(0, 1 - contrast), 1 + contrast].
            saturation (float): sampled uniformly in [max(0, 1 - saturation), 1 + saturation].
            hue (float): sampled uniformly in [-hue, hue].
            color_jitter_prob (float, optional): probability of applying color jitter.
                Defaults to 0.8.
            gray_scale_prob (float, optional): probability of converting to gray scale.
                Defaults to 0.2.
            horizontal_flip_prob (float, optional): probability of flipping horizontally.
                Defaults to 0.5.
            gaussian_prob (float, optional): probability of applying gaussian blur.
                Defaults to 0.0.
            solarization_prob (float, optional): probability of applying solarization
                to the first view. Defaults to 0.0.
            min_scale (float, optional): minimum scale of the crops. Defaults to 0.08.
            max_scale (float, optional): maximum scale of the crops. Defaults to 1.0.
            crop_size (int, optional): size of the crop. Defaults to 32.
            adv_train (bool, optional): accepted for interface compatibility; not used
                by this class.
            solarization_prob2 (float, optional): probability of applying solarization
                to the second view. Defaults to 0.2, the value that used to be
                hard-coded.
        """

        super().__init__()

        print('CIFAR train and val dataset transform except normalize')

        def _make_view(solarize_p: float):
            # Fresh transform instances per view, so the two pipelines share no objects.
            return transforms.Compose(
                [
                    transforms.RandomResizedCrop(
                        (crop_size, crop_size),
                        scale=(min_scale, max_scale),
                        interpolation=transforms.InterpolationMode.BICUBIC,
                    ),
                    transforms.RandomApply(
                        [transforms.ColorJitter(brightness, contrast, saturation, hue)],
                        p=color_jitter_prob,
                    ),
                    transforms.RandomGrayscale(p=gray_scale_prob),
                    transforms.RandomApply([GaussianBlur()], p=gaussian_prob),
                    transforms.RandomApply([Solarization()], p=solarize_p),
                    transforms.RandomHorizontalFlip(p=horizontal_flip_prob),
                    transforms.ToTensor(),
                ]
            )

        self.selfsup_transform = _make_view(solarization_prob)
        self.selfsup_transform2 = _make_view(solarization_prob2)

class SimpleTransform(BaseTransform):
    """Two-view transform using only random resized crop and horizontal flip."""

    def __init__(
        self,
        adv_train: bool = False,
        min_scale: float = 0.08,
        max_scale: float = 1.0,
        crop_size: int = 32,
        horizontal_flip_prob: float = 0.5,
    ):
        """Build the two (identically configured) augmentation pipelines.

        Args:
            adv_train (bool, optional): accepted for interface compatibility; not used
                by this class.
            min_scale (float, optional): minimum scale of the crops. Defaults to 0.08.
            max_scale (float, optional): maximum scale of the crops. Defaults to 1.0.
            crop_size (int, optional): size of the crop. Defaults to 32.
            horizontal_flip_prob (float, optional): probability of flipping
                horizontally. Defaults to 0.5. Previously this name was referenced
                without being defined, which made construction raise NameError.
        """
        super().__init__()

        def _make_view():
            return transforms.Compose(
                [
                    transforms.RandomResizedCrop(
                        (crop_size, crop_size),
                        scale=(min_scale, max_scale),
                        interpolation=transforms.InterpolationMode.BICUBIC,
                    ),
                    transforms.RandomHorizontalFlip(p=horizontal_flip_prob),
                    transforms.ToTensor(),
                ])

        self.selfsup_transform = _make_view()
        self.selfsup_transform2 = _make_view()


class NoneTransform(BaseTransform):
    """Two-view transform with no random augmentation: resize, then convert to tensor."""

    def __init__(
        self,
        adv_train: bool = False,
        crop_size: int = 32,
    ):
        # adv_train is accepted for interface compatibility and not used here.
        super().__init__()
        first, second = (
            transforms.Compose([transforms.Resize(crop_size), transforms.ToTensor()])
            for _ in range(2)
        )
        self.selfsup_transform = first
        self.selfsup_transform2 = second

class RandAugTransform(BaseTransform):
    """Two-view transform: ``RandAugment(2, 9)`` followed by ``ToTensor`` in each view."""

    def __init__(
        self,
        adv_train: bool = False,
    ):
        # adv_train is accepted for interface compatibility and not used here.
        super().__init__()
        # Imported locally, so the `randaug` module is only needed when this
        # class is instantiated.
        from randaug import RandAugment

        def _pipeline():
            # NOTE(review): (2, 9) are presumably (number of ops, magnitude) —
            # confirm against the randaug module.
            return transforms.Compose([RandAugment(2, 9), transforms.ToTensor()])

        self.selfsup_transform = _pipeline()
        self.selfsup_transform2 = _pipeline()


class CifarFSTrasnform(BaseTransform):
    def __init__(
        self,
        brightness: float,
        contrast: float,
        saturation: float,
        hue: float,
        color_jitter_prob: float = 0.8,
        gray_scale_prob: float = 0.2,
        horizontal_flip_prob: float = 0.5,
        gaussian_prob: float = 0.0,
        solarization_prob: float = 0.0,
        min_scale: float = 0.08,
        max_scale: float = 1.0,
        crop_size: int = 32,
        padding: int = 4,
        adv_train: bool = False,
        solarization_prob2: float = 0.2,
    ):
        """Class that applies CIFAR-FS transformations.

        Builds two augmentation pipelines with identical crop, color-jitter,
        grayscale, blur and flip settings. They differ only in their solarization
        probability (``solarization_prob`` vs ``solarization_prob2``).

        Args:
            brightness (float): sampled uniformly in [max(0, 1 - brightness), 1 + brightness].
            contrast (float): sampled uniformly in [max(0, 1 - contrast), 1 + contrast].
            saturation (float): sampled uniformly in [max(0, 1 - saturation), 1 + saturation].
            hue (float): sampled uniformly in [-hue, hue].
            color_jitter_prob (float, optional): probability of applying color jitter.
                Defaults to 0.8.
            gray_scale_prob (float, optional): probability of converting to gray scale.
                Defaults to 0.2.
            horizontal_flip_prob (float, optional): probability of flipping horizontally.
                Defaults to 0.5.
            gaussian_prob (float, optional): probability of applying gaussian blur.
                Defaults to 0.0.
            solarization_prob (float, optional): probability of applying solarization
                to the first view. Defaults to 0.0.
            min_scale (float, optional): minimum scale of the crops. Defaults to 0.08.
            max_scale (float, optional): maximum scale of the crops. Defaults to 1.0.
            crop_size (int, optional): size of the crop. Defaults to 32.
            padding (int, optional): accepted for interface compatibility; not used by
                this class.
            adv_train (bool, optional): accepted for interface compatibility; not used
                by this class.
            solarization_prob2 (float, optional): probability of applying solarization
                to the second view. Defaults to 0.2, the value that used to be
                hard-coded.
        """

        super().__init__()

        print('CIFAR_FS train and val dataset transform except normalize')

        def _make_view(solarize_p: float):
            # Fresh transform instances per view, so the two pipelines share no objects.
            return transforms.Compose(
                [
                    transforms.RandomResizedCrop(
                        (crop_size, crop_size),
                        scale=(min_scale, max_scale),
                        interpolation=transforms.InterpolationMode.BICUBIC,
                    ),
                    transforms.RandomApply(
                        [transforms.ColorJitter(brightness, contrast, saturation, hue)],
                        p=color_jitter_prob,
                    ),
                    transforms.RandomGrayscale(p=gray_scale_prob),
                    transforms.RandomApply([GaussianBlur()], p=gaussian_prob),
                    transforms.RandomApply([Solarization()], p=solarize_p),
                    transforms.RandomHorizontalFlip(p=horizontal_flip_prob),
                    transforms.ToTensor(),
                ]
            )

        self.selfsup_transform = _make_view(solarization_prob)
        self.selfsup_transform2 = _make_view(solarization_prob2)
        

class AblationCifarTransform(AblationTransform):
    def __init__(
        self,
        cifar: str,
        brightness: float,
        contrast: float,
        saturation: float,
        hue: float,
        color_jitter_prob: float = 0.8,
        gray_scale_prob: float = 0.2,
        horizontal_flip_prob: float = 0.5,
        gaussian_prob: float = 0.0,
        solarization_prob: float = 0.0,
        min_scale: float = 0.08,
        max_scale: float = 1.0,
        crop_size: int = 32,
        adv_train: bool = False,
        solarization_prob2: float = 0.2,
    ):
        """Class that applies Cifar10/Cifar100 transformations (three-view ablation).

        Builds two augmentation pipelines that differ only in solarization
        probability, plus a third "simple" pipeline that just resizes and converts
        to a tensor.

        Args:
            cifar (str): type of cifar, either cifar10 or cifar100. Accepted for
                interface compatibility; not used by this class.
            brightness (float): sampled uniformly in [max(0, 1 - brightness), 1 + brightness].
            contrast (float): sampled uniformly in [max(0, 1 - contrast), 1 + contrast].
            saturation (float): sampled uniformly in [max(0, 1 - saturation), 1 + saturation].
            hue (float): sampled uniformly in [-hue, hue].
            color_jitter_prob (float, optional): probability of applying color jitter.
                Defaults to 0.8.
            gray_scale_prob (float, optional): probability of converting to gray scale.
                Defaults to 0.2.
            horizontal_flip_prob (float, optional): probability of flipping horizontally.
                Defaults to 0.5.
            gaussian_prob (float, optional): probability of applying gaussian blur.
                Defaults to 0.0.
            solarization_prob (float, optional): probability of applying solarization
                to the first view. Defaults to 0.0.
            min_scale (float, optional): minimum scale of the crops. Defaults to 0.08.
            max_scale (float, optional): maximum scale of the crops. Defaults to 1.0.
            crop_size (int, optional): size of the crop. Defaults to 32.
            adv_train (bool, optional): accepted for interface compatibility; not used
                by this class.
            solarization_prob2 (float, optional): probability of applying solarization
                to the second view. Defaults to 0.2, the value that used to be
                hard-coded.
        """

        super().__init__()

        print('CIFAR train and val dataset transform except normalize')

        def _make_view(solarize_p: float):
            # Fresh transform instances per view, so the pipelines share no objects.
            return transforms.Compose(
                [
                    transforms.RandomResizedCrop(
                        (crop_size, crop_size),
                        scale=(min_scale, max_scale),
                        interpolation=transforms.InterpolationMode.BICUBIC,
                    ),
                    transforms.RandomApply(
                        [transforms.ColorJitter(brightness, contrast, saturation, hue)],
                        p=color_jitter_prob,
                    ),
                    transforms.RandomGrayscale(p=gray_scale_prob),
                    transforms.RandomApply([GaussianBlur()], p=gaussian_prob),
                    transforms.RandomApply([Solarization()], p=solarize_p),
                    transforms.RandomHorizontalFlip(p=horizontal_flip_prob),
                    transforms.ToTensor(),
                ]
            )

        self.selfsup_transform = _make_view(solarization_prob)
        self.selfsup_transform2 = _make_view(solarization_prob2)
        # Attribute name (with its typo) is what AblationTransform.__call__ reads.
        self.simple_tranform = transforms.Compose(
                [
                    transforms.Resize(crop_size),
                    transforms.ToTensor(),
                ])

class AblationSimpleTransform(AblationTransform):
    """Three-view ablation: two crop+flip views plus a resize-only view."""

    def __init__(
        self,
        adv_train: bool = False,
        min_scale: float = 0.08,
        max_scale: float = 1.0,
        crop_size: int = 32,
        horizontal_flip_prob: float = 0.5,
    ):
        """Build the two crop+flip pipelines and the resize-only pipeline.

        Args:
            adv_train (bool, optional): accepted for interface compatibility; not used
                by this class.
            min_scale (float, optional): minimum scale of the crops. Defaults to 0.08.
            max_scale (float, optional): maximum scale of the crops. Defaults to 1.0.
            crop_size (int, optional): size of the crop. Defaults to 32.
            horizontal_flip_prob (float, optional): probability of flipping
                horizontally. Defaults to 0.5. Previously this name was referenced
                without being defined, which made construction raise NameError.
        """
        super().__init__()

        def _make_view():
            return transforms.Compose(
                [
                    transforms.RandomResizedCrop(
                        (crop_size, crop_size),
                        scale=(min_scale, max_scale),
                        interpolation=transforms.InterpolationMode.BICUBIC,
                    ),
                    transforms.RandomHorizontalFlip(p=horizontal_flip_prob),
                    transforms.ToTensor(),
                ])

        self.selfsup_transform = _make_view()
        self.selfsup_transform2 = _make_view()
        # Attribute name (with its typo) is what AblationTransform.__call__ reads.
        self.simple_tranform = transforms.Compose(
                [
                    transforms.Resize(crop_size),
                    transforms.ToTensor(),
                ])


class AblationNoneTransform(AblationTransform):
    """Three-view ablation in which every view is just resize + ``ToTensor``."""

    def __init__(
        self,
        adv_train: bool = False,
        crop_size: int = 32,
    ):
        # adv_train is accepted for interface compatibility and not used here.
        super().__init__()
        views = [
            transforms.Compose([transforms.Resize(crop_size), transforms.ToTensor()])
            for _ in range(3)
        ]
        # `simple_tranform` (typo kept) is the name AblationTransform.__call__ reads.
        self.selfsup_transform, self.selfsup_transform2, self.simple_tranform = views

class AblationCifarFSTrasnform(AblationTransform):
    def __init__(
        self,
        brightness: float,
        contrast: float,
        saturation: float,
        hue: float,
        color_jitter_prob: float = 0.8,
        gray_scale_prob: float = 0.2,
        horizontal_flip_prob: float = 0.5,
        gaussian_prob: float = 0.0,
        solarization_prob: float = 0.0,
        min_scale: float = 0.08,
        max_scale: float = 1.0,
        crop_size: int = 32,
        padding: int = 4,
        adv_train: bool = False,
        solarization_prob2: float = 0.2,
    ):
        """Class that applies CIFAR-FS transformations (three-view ablation).

        Builds two augmentation pipelines that differ only in solarization
        probability, plus a third "simple" pipeline that just resizes and converts
        to a tensor.

        Args:
            brightness (float): sampled uniformly in [max(0, 1 - brightness), 1 + brightness].
            contrast (float): sampled uniformly in [max(0, 1 - contrast), 1 + contrast].
            saturation (float): sampled uniformly in [max(0, 1 - saturation), 1 + saturation].
            hue (float): sampled uniformly in [-hue, hue].
            color_jitter_prob (float, optional): probability of applying color jitter.
                Defaults to 0.8.
            gray_scale_prob (float, optional): probability of converting to gray scale.
                Defaults to 0.2.
            horizontal_flip_prob (float, optional): probability of flipping horizontally.
                Defaults to 0.5.
            gaussian_prob (float, optional): probability of applying gaussian blur.
                Defaults to 0.0.
            solarization_prob (float, optional): probability of applying solarization
                to the first view. Defaults to 0.0.
            min_scale (float, optional): minimum scale of the crops. Defaults to 0.08.
            max_scale (float, optional): maximum scale of the crops. Defaults to 1.0.
            crop_size (int, optional): size of the crop. Defaults to 32.
            padding (int, optional): accepted for interface compatibility; not used by
                this class.
            adv_train (bool, optional): accepted for interface compatibility; not used
                by this class.
            solarization_prob2 (float, optional): probability of applying solarization
                to the second view. Defaults to 0.2, the value that used to be
                hard-coded.
        """

        super().__init__()

        print('CIFAR_FS train and val dataset transform except normalize')

        def _make_view(solarize_p: float):
            # Fresh transform instances per view, so the pipelines share no objects.
            return transforms.Compose(
                [
                    transforms.RandomResizedCrop(
                        (crop_size, crop_size),
                        scale=(min_scale, max_scale),
                        interpolation=transforms.InterpolationMode.BICUBIC,
                    ),
                    transforms.RandomApply(
                        [transforms.ColorJitter(brightness, contrast, saturation, hue)],
                        p=color_jitter_prob,
                    ),
                    transforms.RandomGrayscale(p=gray_scale_prob),
                    transforms.RandomApply([GaussianBlur()], p=gaussian_prob),
                    transforms.RandomApply([Solarization()], p=solarize_p),
                    transforms.RandomHorizontalFlip(p=horizontal_flip_prob),
                    transforms.ToTensor(),
                ]
            )

        self.selfsup_transform = _make_view(solarization_prob)
        self.selfsup_transform2 = _make_view(solarization_prob2)
        # Attribute name (with its typo) is what AblationTransform.__call__ reads.
        self.simple_tranform = transforms.Compose(
                [
                    transforms.Resize(crop_size),
                    transforms.ToTensor(),
                ])
        
