from typing import List, Union
from torchvision import transforms
# from pl_bolts.transforms.self_supervised import RandomTranslateWithReflect


class Transforms:
    """Transforms applied for training, validation and test.

    Pipeline, in order::

        transforms.RandomResizedCrop(crop_size)   # only if crop_size[0] > 0
        transforms.RandomHorizontalFlip(p=0.5)    # only if augmentation=True
        transforms.ToTensor(),
        normalize                                 # fixed per-channel mean / std

    When ``augmentation`` is True, calling the object returns a pair
    ``(view1, view2)``; otherwise it returns a single tensor.

    Example::

        from PIL import Image

        img = Image.open("example.jpg").convert("RGB")

        transform = Transforms(crop_size=[224], augmentation=True)
        (view1, view2) = transform(img)
    """

    def __init__(self, crop_size: List[int], augmentation: bool):
        """Build the transform pipeline.

        Args:
            crop_size: size passed to ``transforms.RandomResizedCrop``.
                Cropping is disabled when the list is empty or its first
                element is <= 0.
            augmentation: if True, apply a random horizontal flip and make
                ``__call__`` return two views instead of one.
        """
        self.augmentation = augmentation

        # Cropping is optional; a missing or non-positive size disables it.
        self.crop = (
            transforms.RandomResizedCrop(crop_size)
            if crop_size and crop_size[0] > 0
            else None
        )

        # Must exist whenever augmentation is enabled, because __call__
        # applies it unconditionally on the augmented path.
        self.flip_lr = (
            transforms.RandomHorizontalFlip(p=0.5) if augmentation else None
        )

        # Commonly used ImageNet per-channel mean / std.
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

        transform_list = []

        # Extra augmentations, currently disabled. img_jitter needs pl_bolts'
        # RandomTranslateWithReflect, whose import is commented out at the
        # top of the file.
        # if augmentation:
        #     img_jitter = transforms.RandomApply([RandomTranslateWithReflect(4)], p=0.8)
        #     col_jitter = transforms.RandomApply(
        #         [transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8
        #     )
        #     rnd_gray = transforms.RandomGrayscale(p=0.25)
        #     transform_list += [img_jitter, col_jitter, rnd_gray]

        transform_list += [transforms.ToTensor(), normalize]

        self.transforms = transforms.Compose(transform_list)

    def __call__(self, inp):
        """Apply the pipeline to ``inp``.

        Args:
            inp: an image accepted by ``transforms.ToTensor`` (PIL image or
                numpy array), optionally cropped/flipped first.

        Returns:
            A single normalized tensor, or a ``(view1, view2)`` tuple when
            augmentation is enabled.
        """
        if self.crop is not None:
            inp = self.crop(inp)

        if not self.augmentation:
            return self.transforms(inp)

        inp = self.flip_lr(inp)
        # NOTE(review): crop and flip are sampled once, before the split, and
        # self.transforms (ToTensor + Normalize) is deterministic, so both
        # views are identical. Apply the random ops per view if distinct
        # views are intended.
        return self.transforms(inp), self.transforms(inp)
