"""
Gaussian Blur Data Augmentation

Implements Gaussian blur as used in SimCLR for data augmentation.
"""

import numpy as np
import torch
from torch import nn
from torchvision.transforms import transforms

# Seed numpy's global RNG at import time so the per-image sigma values drawn
# with np.random.uniform in GaussianBlur.__call__ are reproducible across runs.
# NOTE(review): this reseeds the global RNG for every module that imports this
# file. Also, if the transform runs in forked worker processes (e.g. a
# DataLoader with num_workers > 0), each worker presumably starts from the same
# RNG state and draws the same sigma sequence -- confirm this is intended.
np.random.seed(0)


class GaussianBlur(object):
    """
    Apply Gaussian blur to a PIL Image with random sigma.

    The kernel is computed dynamically for each image, with sigma sampled
    uniformly from [0.1, 2.0) using numpy's global RNG. The blur is
    separable. A (kernel_size x 1) depthwise convolution runs along the
    height axis, then a (1 x kernel_size) one runs along the width axis.
    Reflection padding keeps the output the same size as the input.

    Args:
        kernel_size: Size of the Gaussian kernel. An even value is bumped
            to the next odd size (``2 * (kernel_size // 2) + 1``) so the
            kernel is centered on each pixel.

    Note:
        - The convolutions are fixed at 3 channels, so images must convert
          to a 3-channel tensor (e.g. RGB).
        - PyTorch's reflection padding requires the radius
          (``kernel_size // 2``) to be smaller than both the image height
          and width.
        - The shared convolution weights are overwritten on every call.
    """

    def __init__(self, kernel_size):
        radius = kernel_size // 2
        kernel_size = radius * 2 + 1  # force an odd, centered kernel
        # Depthwise (groups=3) separable convolutions. Their weights are
        # filled in on each call with a freshly sampled Gaussian.
        self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1),
                                stride=1, padding=0, bias=False, groups=3)
        self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size),
                                stride=1, padding=0, bias=False, groups=3)
        self.k = kernel_size
        self.r = radius

        # Pad by the radius on every side so the two unpadded convolutions
        # bring the tensor back to its original spatial size.
        self.blur = nn.Sequential(
            nn.ReflectionPad2d(radius),
            self.blur_h,
            self.blur_v
        )

        self.pil_to_tensor = transforms.ToTensor()
        self.tensor_to_pil = transforms.ToPILImage()

    def __call__(self, img):
        """
        Apply Gaussian blur to image.

        Args:
            img: PIL Image that converts to a 3-channel tensor.

        Returns:
            Blurred PIL Image with the same size as the input.
        """
        img = self.pil_to_tensor(img).unsqueeze(0)  # add a batch dimension

        # Sample random sigma for blur strength
        sigma = np.random.uniform(0.1, 2.0)
        kernel = self._gaussian_kernel(sigma)

        with torch.no_grad():
            # Same 1-D kernel for each of the 3 channels (depthwise convs).
            weights = kernel.view(1, -1).repeat(3, 1)
            self.blur_h.weight.copy_(weights.view(3, 1, self.k, 1))
            self.blur_v.weight.copy_(weights.view(3, 1, 1, self.k))

            img = self.blur(img)
            # Drop only the batch dimension. A bare squeeze() would also drop
            # a spatial dimension of size 1 and hand ToPILImage a mis-shaped
            # tensor.
            img = img.squeeze(0)

        return self.tensor_to_pil(img)

    def _gaussian_kernel(self, sigma):
        """Return the normalized 1-D Gaussian kernel of length ``self.k`` as float32."""
        offsets = np.arange(-self.r, self.r + 1)
        kernel = np.exp(-np.power(offsets, 2) / (2 * sigma * sigma))
        kernel = kernel / kernel.sum()  # weights sum to 1, preserving brightness
        return torch.from_numpy(kernel).float()
