import numpy as np

import torch
from torch import Tensor
from torch.nn.functional import conv2d, pad as torch_pad
from PIL import Image, ImageFilter, ImageEnhance


class GaussianBlur(torch.nn.Module):
    """Blur an image with a Gaussian kernel whose sigma is sampled per call.

    The input may be a ``torch.Tensor`` (channels in dimension ``-3``) or a
    PIL image. PIL images are converted to a CHW tensor, blurred with
    ``gaussian_blur`` and converted back to a PIL image of the same mode.

    Args:
        kernel_size: Pair of kernel sizes forwarded unchanged to
            ``gaussian_blur``.
        sigma: ``(min, max)`` range from which the standard deviation is drawn
            uniformly on every call; the same value is used for both axes.
    """

    def __init__(self, kernel_size=(5, 5), sigma=(0.1, 5)):
        super().__init__()
        self.kernel_size = kernel_size
        self.sigma = sigma

    @staticmethod
    def get_params(sigma_min: float, sigma_max: float) -> float:
        """Choose sigma for random gaussian blurring.

        Args:
            sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel.
            sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel.

        Returns:
            float: Standard deviation to be passed to calculate kernel for gaussian blurring.
        """
        return torch.empty(1).uniform_(sigma_min, sigma_max).item()

    @staticmethod
    def _to_pil_array(output, mode):
        """Turn a CHW tensor into an array laid out for ``Image.fromarray``.

        Args:
            output: Blurred image tensor in CHW layout.
            mode: PIL mode of the image the tensor was created from.

        Returns:
            numpy.ndarray: HWC array, or a 2-D HW array when there is exactly
            one channel.
        """
        if output.is_floating_point() and mode != "F":
            # Floating point data is rescaled to bytes unless the target mode
            # is the float mode "F".
            output = output.mul(255).byte()
        array = np.transpose(output.cpu().numpy(), (1, 2, 0))
        if array.shape[2] == 1:
            # Single-band modes need a 2-D array: Image.fromarray rejects a
            # trailing channel axis of length one for them.
            array = array[:, :, 0]
        return array

    def forward(self, img):
        """Blur ``img`` with a freshly sampled sigma.

        Args:
            img: ``torch.Tensor`` or PIL image to blur.

        Returns:
            A blurred tensor when a tensor was given, otherwise a blurred PIL
            image with the same mode as the input.
        """
        sigma = self.get_params(self.sigma[0], self.sigma[1])

        if isinstance(img, torch.Tensor):
            return gaussian_blur(img, self.kernel_size, [sigma, sigma])

        t_img = torch.as_tensor(np.array(img, copy=True))
        # Give single-band images an explicit channel axis as well.
        t_img = t_img.view(img.size[1], img.size[0], len(img.getbands()))
        # put it from HWC to CHW format
        t_img = t_img.permute((2, 0, 1))

        output = gaussian_blur(t_img, self.kernel_size, [sigma, sigma])
        return Image.fromarray(self._to_pil_array(output, img.mode), mode=img.mode)

    def __repr__(self):
        s = f"{self.__class__.__name__}(kernel_size={self.kernel_size}, sigma={self.sigma})"
        return s
    
class RandomAdjustSharpness(torch.nn.Module):
    """Adjust the sharpness of an image with probability ``p``.

    PIL images are processed with ``ImageEnhance.Sharpness``; tensors are
    processed with ``adjust_sharpness``. When the random draw does not fall
    below ``p`` the input is returned untouched.

    Args:
        sharpness_factor: Strength passed to the sharpness adjustment.
        p: Probability of applying the adjustment.
    """

    def __init__(self, sharpness_factor=1.5, p=0.1):
        super().__init__()
        self.p = p
        self.sharpness_factor = sharpness_factor

    def forward(self, img):
        """Return ``img``, sharpened with probability ``p``."""
        apply = torch.rand(1).item() < self.p
        if not apply:
            return img

        if isinstance(img, torch.Tensor):
            return adjust_sharpness(img, self.sharpness_factor)
        return ImageEnhance.Sharpness(img).enhance(self.sharpness_factor)

    def __repr__(self) -> str:
        name = type(self).__name__
        return f"{name}(sharpness_factor={self.sharpness_factor},p={self.p})"
    
def _cast_squeeze_in(img, req_dtypes):
    need_squeeze = False
    # make image NCHW
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    out_dtype = img.dtype
    need_cast = False
    if out_dtype not in req_dtypes:
        need_cast = True
        req_dtype = req_dtypes[0]
        img = img.to(req_dtype)
    return img, need_cast, need_squeeze, out_dtype


def _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype):
    if need_squeeze:
        img = img.squeeze(dim=0)

    if need_cast:
        if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
            # it is better to round before cast
            img = torch.round(img)
        img = img.to(out_dtype)

    return img

def _blend(img1, img2, ratio):
    ratio = float(ratio)
    bound = _max_value(img1.dtype)
    return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)

def _get_gaussian_kernel1d(kernel_size, sigma):
    ksize_half = (kernel_size - 1) * 0.5

    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    kernel1d = pdf / pdf.sum()

    return kernel1d


def _get_gaussian_kernel2d(kernel_size, sigma, dtype, device):
    """Build a 2-D Gaussian kernel as the outer product of two 1-D kernels.

    Args:
        kernel_size: Pair of kernel sizes; index 0 sets the number of columns
            and index 1 the number of rows of the result.
        sigma: Pair of standard deviations, indexed like ``kernel_size``.
        dtype: Dtype of the returned kernel.
        device: Device of the returned kernel.

    Returns:
        2-D tensor of shape ``(kernel_size[1], kernel_size[0])``.
    """
    row = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
    col = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
    # (ky, 1) x (1, kx) -> (ky, kx)
    return torch.mm(col.unsqueeze(1), row.unsqueeze(0))


def gaussian_blur(img, kernel_size, sigma):
    """Blur a tensor image with a Gaussian kernel.

    Args:
        img: Image tensor with channels in dimension ``-3``.
        kernel_size: Pair of kernel sizes; index 0 is used for the left/right
            padding and index 1 for the top/bottom padding.
        sigma: Pair of standard deviations, indexed like ``kernel_size``.

    Returns:
        The blurred tensor, with the dtype of the input.

    Raises:
        TypeError: If ``img`` is not a ``torch.Tensor``.
    """
    if not isinstance(img, torch.Tensor):
        raise TypeError(f"img should be Tensor. Got {type(img)}")

    # Integer images are filtered in float32; floating images keep their dtype.
    work_dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    weights = _get_gaussian_kernel2d(kernel_size, sigma, dtype=work_dtype, device=img.device)
    rows, cols = weights.shape[0], weights.shape[1]
    # One copy of the kernel per channel, for the grouped convolution.
    weights = weights.expand(img.shape[-3], 1, rows, cols)

    batched, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [weights.dtype])

    # padding = (left, right, top, bottom)
    pad_x = kernel_size[0] // 2
    pad_y = kernel_size[1] // 2
    batched = torch_pad(batched, [pad_x, pad_x, pad_y, pad_y], mode="reflect")
    blurred = conv2d(batched, weights, groups=batched.shape[-3])

    return _cast_squeeze_out(blurred, need_cast, need_squeeze, out_dtype)

def _blurred_degenerate_image(img):
    """Return a smoothed copy of ``img`` with the border pixels left as-is.

    A 3x3 kernel with centre weight 5 and weight 1 elsewhere, normalised to
    sum to one, is applied per channel. The convolution is unpadded, so only
    the interior of the result is overwritten.

    Args:
        img: Image tensor with channels in dimension ``-3``.

    Returns:
        A new tensor of the same shape and dtype as ``img``.
    """
    # Integer images are filtered in float32; floating images keep their dtype.
    work_dtype = img.dtype if torch.is_floating_point(img) else torch.float32

    weights = torch.ones((3, 3), dtype=work_dtype, device=img.device)
    weights[1, 1] = 5.0
    weights /= weights.sum()
    # One copy of the kernel per channel, for the grouped convolution.
    weights = weights.expand(img.shape[-3], 1, weights.shape[0], weights.shape[1])

    interior, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [weights.dtype])
    interior = conv2d(interior, weights, groups=interior.shape[-3])
    interior = _cast_squeeze_out(interior, need_cast, need_squeeze, out_dtype)

    blurred = img.clone()
    blurred[..., 1:-1, 1:-1] = interior
    return blurred


def adjust_sharpness(img, sharpness_factor):
    """Adjust the sharpness of a tensor image.

    The image is blended with a smoothed copy of itself, using
    ``sharpness_factor`` as the weight of the original image.

    Args:
        img: Image tensor with height and width in the last two dimensions.
        sharpness_factor: Non-negative blend weight of the original image.

    Returns:
        The adjusted tensor, or ``img`` itself when either of its last two
        dimensions is 2 or smaller.

    Raises:
        ValueError: If ``sharpness_factor`` is negative.
    """
    if sharpness_factor < 0:
        raise ValueError(f"sharpness_factor ({sharpness_factor}) is not non-negative.")

    # The 3x3 smoothing kernel needs at least one interior pixel to act on.
    large_enough = img.size(-1) > 2 and img.size(-2) > 2
    if not large_enough:
        return img

    smoothed = _blurred_degenerate_image(img)
    return _blend(img, smoothed, sharpness_factor)

