import torchvision.transforms as transforms
import numpy as np
import cv2

import random
from PIL import ImageFilter

def _imagenet_transform(size):
    """Build the standard supervised ImageNet training augmentation."""
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.RandomResizedCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])


def _simclr_transform(size):
    """Build the SimCLR augmentation (no normalization is applied).

    Adapted from
    https://github.com/PyTorchLightning/Lightning-Bolts/blob/master/pl_bolts/models/self_supervised/simclr/transforms.py
    """
    # Blur kernel is ~10% of the crop side, forced odd because cv2.GaussianBlur
    # requires an odd kernel size.
    kernel_size = int(0.1 * size)
    if kernel_size % 2 == 0:
        kernel_size += 1
    return transforms.Compose([
        transforms.RandomResizedCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([
            transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        GaussianBlur(kernel_size=kernel_size, p=0.5),
        transforms.ToTensor(),
    ])


def _swav_simclr_transform(size):
    """Build the SimCLR-style augmentation as configured in the SwAV code."""
    min_scale = 0.14
    max_scale = 1.0
    # Statistics kept exactly as in the SwAV code this was taken from
    # (note the first std is 0.228, not the 0.229 used by 'imagenet').
    means = [0.485, 0.456, 0.406]
    stds = [0.228, 0.224, 0.225]
    return transforms.Compose([
        transforms.RandomResizedCrop(size, scale=(min_scale, max_scale)),
        # p=0.5 is torchvision's default; spelled out to mirror the SwAV code.
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply([
            transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        PILRandomGaussianBlur(),
        transforms.ToTensor(),
        # Unlike the 'simclr' scheme, this pipeline normalizes its output.
        transforms.Normalize(mean=means, std=stds),
    ])


def get_transforms(transform_scheme, size=224):
    """Return the training augmentation pipeline for ``transform_scheme``.

    Args:
        transform_scheme: one of ``'imagenet'``, ``'simclr'`` or
            ``'swav-simclr'``.
        size: side length of the random resized crop (and the basis of the
            SimCLR blur kernel size). Defaults to 224, the previously
            hard-coded value.

    Returns:
        A ``torchvision.transforms.Compose`` pipeline.

    Raises:
        ValueError: if ``transform_scheme`` is not a supported scheme.
    """
    builders = {
        'imagenet': _imagenet_transform,
        'simclr': _simclr_transform,
        'swav-simclr': _swav_simclr_transform,
    }
    try:
        builder = builders[transform_scheme]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which the original if/elif chain
        # also rejected with ValueError.
        raise ValueError(
            f'Transformation scheme not supported: {transform_scheme!r}'
        ) from None
    return builder(size)

class GaussianBlur(object):
    """Randomly apply an OpenCV Gaussian blur, as described in the SimCLR paper.

    Adapted from
    https://github.com/PyTorchLightning/Lightning-Bolts/blob/master/pl_bolts/models/self_supervised/simclr/transforms.py
    and used by the 'simclr' transform scheme.

    Args:
        kernel_size: side length of the square kernel, passed straight to
            ``cv2.GaussianBlur`` (OpenCV requires it to be odd and positive,
            or 0 to derive it from sigma).
        p: probability that a given sample is blurred.
        min: lower bound of the uniformly drawn blur sigma.
        max: upper bound of the uniformly drawn blur sigma.

    The input is always converted with ``np.array``, so the output is a NumPy
    array whether or not the blur was applied.
    """

    def __init__(self, kernel_size, p=0.5, min=0.1, max=2.0):
        self.min = min
        self.max = max
        # Supplied by the caller (get_transforms uses ~10% of the crop size).
        self.kernel_size = kernel_size
        self.p = p

    def __call__(self, sample):
        sample = np.array(sample)

        # Draw from Python's `random`, which PyTorch reseeds in every
        # DataLoader worker; NumPy's global RNG is not reseeded by older
        # PyTorch releases, which can duplicate augmentations across workers.
        if random.random() < self.p:
            sigma = random.uniform(self.min, self.max)
            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)

        return sample

class PILRandomGaussianBlur(object):
    """Randomly blur a PIL image with ``ImageFilter.GaussianBlur``.

    Used by the 'swav-simclr' transform scheme.

    Args:
        p: probability that a given image is blurred.
        radius_min: lower bound of the uniformly drawn blur radius.
        radius_max: upper bound of the uniformly drawn blur radius.

    Returns the input image unchanged when the blur is not applied.
    """

    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        # Strict `<` so that p=0 never blurs (random() may return exactly 0.0);
        # Python's `random` is used for both draws to keep a single RNG source.
        if not random.random() < self.prob:
            return img

        radius = random.uniform(self.radius_min, self.radius_max)
        return img.filter(ImageFilter.GaussianBlur(radius=radius))
