import torch
import torchvision.transforms as T

try:
    from torchvision.transforms import GaussianBlur
except ImportError:
    from .gaussian_blur import GaussianBlur

    T.GaussianBlur = GaussianBlur

# ImageNet per-channel normalization statistics, stored as [mean, std] so the
# pair can be unpacked straight into ``T.Normalize(*mean_std)``.
imagenet_mean_std = [
    [0.485, 0.456, 0.406],  # mean
    [0.229, 0.224, 0.225],  # std
]


class SimCLRTransform:
    """SimCLR-style random augmentation pipeline producing one augmented view.

    Each call applies, in order: a random resized crop (scale 0.2-1.0), a
    random horizontal flip, colour jitter (applied with p=0.8), random
    grayscale (p=0.2), Gaussian blur (p=0.5), conversion to a tensor when the
    input is not one already, and per-channel normalization.

    Args:
        image_size: Crop size passed to ``T.RandomResizedCrop``. An int (or
            other scalar) gives a square crop; a ``(height, width)`` sequence
            is also accepted. ``None`` falls back to 224.
        mean_std: ``[mean, std]`` pair unpacked into ``T.Normalize``.
            Defaults to ``imagenet_mean_std``.
        s: Colour-jitter strength multiplier (brightness/contrast/saturation
            use ``0.8 * s``, hue uses ``0.2 * s``).
    """

    def __init__(self, image_size, mean_std=imagenet_mean_std, s=1.0):
        image_size = 224 if image_size is None else image_size
        self.transform = T.Compose([
            T.RandomResizedCrop(image_size, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(),
            T.RandomApply([T.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)], p=0.8),
            T.RandomGrayscale(p=0.2),
            # Blur 50% of the time with sigma sampled from [0.1, 2.0] and an odd
            # kernel of roughly 10% of the image height/width.
            T.RandomApply(
                [T.GaussianBlur(kernel_size=self._blur_kernel_size(image_size), sigma=(0.1, 2.0))],
                p=0.5,
            ),
            # T.Normalize only accepts tensors, so PIL images / ndarrays must be
            # converted first; tensor inputs are passed through unchanged.
            T.Lambda(self._ensure_tensor),
            T.Normalize(*mean_std),
        ])

    @staticmethod
    def _blur_kernel_size(image_size):
        """Return an odd blur kernel size of roughly 10% of ``image_size``.

        A scalar size yields a single odd int (square kernel), computed exactly
        as ``size // 20 * 2 + 1``. A ``(height, width)`` sequence yields
        ``(kx, ky)``, i.e. width first, which is the order ``GaussianBlur``
        expects. A one-element sequence is treated as a square size, matching
        ``RandomResizedCrop``.
        """
        if not isinstance(image_size, (list, tuple)):
            return image_size // 20 * 2 + 1
        if len(image_size) == 1:
            return image_size[0] // 20 * 2 + 1
        height, width = image_size
        return (width // 20 * 2 + 1, height // 20 * 2 + 1)

    @staticmethod
    def _ensure_tensor(img):
        """Return ``img`` unchanged if it is a tensor, else convert it via
        ``T.functional.to_tensor`` (PIL image / ndarray -> tensor)."""
        if isinstance(img, torch.Tensor):
            return img
        return T.functional.to_tensor(img)

    def __call__(self, x):
        """Return one randomly augmented, normalized tensor view of ``x``."""
        return self.transform(x)