import torch
import torchvision.transforms.functional as F

class GaussianBlur:
    """Callable transform that blurs its input with a randomly chosen sigma.

    Args:
        kernel_size: Requested side length of the Gaussian kernel. Values
            below 3 are raised to 3. The value is converted to ``int`` and
            even sizes are bumped to the next odd number, because
            ``torchvision.transforms.functional.gaussian_blur`` only accepts
            odd, positive integer kernel sizes.
        sigma_min: Lower bound of the sigma sampling range.
        sigma_max: Upper bound of the sigma sampling range.
    """

    def __init__(self, kernel_size, sigma_min=0.1, sigma_max=2.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

        kernel_size = max(3, int(kernel_size))
        if kernel_size % 2 == 0:
            # An even size would make every __call__ raise inside torchvision.
            kernel_size += 1
        self.kernel_size = kernel_size

    def __call__(self, x):
        """Blur ``x`` using one sigma drawn for this call.

        ``torch.rand`` yields a value in [0, 1), which is rescaled to
        [sigma_min, sigma_max); the same sigma is used for both kernel axes.

        Args:
            x: Image passed straight to torchvision's ``gaussian_blur``
                (NOTE(review): presumably a ``[..., H, W]`` tensor — confirm
                against callers).

        Returns:
            The blurred image as returned by ``gaussian_blur``.
        """
        sigma = torch.rand(1).item() * (self.sigma_max - self.sigma_min) + self.sigma_min
        x = F.gaussian_blur(x, kernel_size=self.kernel_size, sigma=[sigma, sigma])
        return x
"""
class SobelFilter(torch.nn.Module):
    def __init__(self):
        super(SobelFilter, self).__init__()
        self.sobel_kernel_x = torch.tensor([[2.0, 0.0, -2.0], [4.0, 0.0, -4.0], [2.0, 0.0, -2.0]], dtype=torch.float32,
                                           requires_grad=False).unsqueeze(0).unsqueeze(0)
        self.sobel_kernel_y = torch.tensor([[2.0, 4.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -4.0, -2.0]], dtype=torch.float32,
                                           requires_grad=False).unsqueeze(0).unsqueeze(0)

    def forward(self, x):
        grad_x = torch.nn.functional.conv2d(x, self.sobel_kernel_x.to(x.device))
        grad_y = torch.nn.functional.conv2d(x, self.sobel_kernel_y.to(x.device))
        magnitude = torch.sqrt(grad_x ** 2 + grad_y ** 2)
        return magnitude.squeeze(1)"""

class SobelFilter(torch.nn.Module):
    """Gradient-magnitude filter backed by a fixed 3x3 convolution.

    A single-input-channel ``Conv2d`` with two output channels holds the
    horizontal and vertical kernels (the classic Sobel kernels scaled by 2).
    The weights are frozen (``requires_grad=False``) and there is no bias.

    The module is placed on the current CUDA device when CUDA is available
    and on the CPU otherwise, so it can also be constructed on machines
    without a GPU.
    """

    def __init__(self):
        super().__init__()
        # Previously this was an unconditional .cuda(), which raises on
        # CPU-only hosts; behaviour on CUDA hosts is unchanged.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.filter = torch.nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, stride=1,
                                      padding=1, bias=False).to(device)

        kernel_x = torch.tensor([[2.0, 0.0, -2.0], [4.0, 0.0, -4.0], [2.0, 0.0, -2.0]], device=device)
        kernel_y = torch.tensor([[2.0, 4.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -4.0, -2.0]], device=device)
        # Stack to (2, 3, 3), then insert the in-channel axis -> (2, 1, 3, 3),
        # the weight layout Conv2d expects for in_channels=1, out_channels=2.
        kernels = torch.cat([kernel_x.unsqueeze(0), kernel_y.unsqueeze(0)], 0)
        kernels = kernels.unsqueeze(1)
        self.filter.weight = torch.nn.Parameter(kernels, requires_grad=False)

    def forward(self, img):
        """Return the per-pixel gradient magnitude of ``img``.

        Args:
            img: Batched single-channel input of shape ``(N, 1, H, W)`` on
                the same device as the filter weights.

        Returns:
            Tensor of shape ``(N, 1, H, W)`` holding ``sqrt(gx**2 + gy**2)``;
            ``padding=1`` keeps the spatial size equal to the input's.
        """
        x = self.filter(img)  # (N, 2, H, W): channel 0 = x-response, 1 = y-response
        x = torch.mul(x, x)
        x = torch.sum(x, dim=1, keepdim=True)
        x = torch.sqrt(x)
        return x


def augment_images(x, augmentation, num_augmentations=2):
    """Build augmented views of every image in ``x`` (e.g. for NT-Xent loss).

    Each item obtained by iterating ``x`` gets two leading singleton axes
    prepended and is then passed through ``augmentation``
    ``num_augmentations`` times. All views of one image are adjacent in the
    output, in input order.

    Args:
        x: Iterable of images (iterating a tensor walks its first axis).
        augmentation: Callable applied to each expanded image; its results
            must be stackable tensors.
        num_augmentations: Views produced per image; 2 is the usual pair.

    Returns:
        ``torch.stack`` of all augmented views.
    """
    # Expand each image once, then reuse that expanded view for all its augmentations.
    expanded = (sample[None, None, ...] for sample in x)
    views = [
        augmentation(item)
        for item in expanded
        for _ in range(num_augmentations)
    ]
    return torch.stack(views)
