import math
import torch
import numpy as np
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass


@dataclass
class GaussianCFG:
    """Settings for the Gaussian smoothing stage.

    The values are forwarded unchanged to ``GaussianFilter2D``.
    """

    # Side length of the square kernel; GaussianFilter2D asserts it is odd.
    kernel_size: int
    # Standard deviation of the Gaussian, measured in integer pixel offsets.
    sgm: float


@dataclass
class BilateralCFG:
    """Settings for the bilateral filtering stage.

    The values are forwarded unchanged to ``BilateralFilter2D``.
    """

    # Side length of the square window; BilateralFilter2D asserts it is odd.
    kernel_size: int
    # Standard deviation of the spatial (pixel-offset) Gaussian weight.
    sgm_spatial: float
    # Standard deviation of the range (value-difference) Gaussian weight.
    sgm_range: float


@dataclass
class PostprocCFG:
    """Top-level configuration consumed by ``PostProcessing``.

    ``gauss`` and ``bilat`` may each be ``None``; ``PostProcessing`` then
    replaces the corresponding stage with an identity mapping.
    """

    # The annotations are written as strings so that the optional form does
    # not require a ``typing`` import or Python 3.10 union syntax at runtime.

    # Gaussian stage settings, or None to skip Gaussian smoothing.
    gauss: "GaussianCFG | None"
    # Bilateral stage settings, or None to skip bilateral filtering.
    bilat: "BilateralCFG | None"
    # Small constant forwarded to BilateralFilter2D as its ``eps`` argument.
    eps: float


class GaussianFilter2D(nn.Module):
    """Single-channel 2-D Gaussian blur applied as a fixed convolution.

    The normalised kernel is stored as the non-trainable buffer ``kernel``
    with shape ``(1, 1, kernel_size, kernel_size)``.

    Args:
        kernel_size: Side length of the square kernel; must be odd.
        sgm: Standard deviation of the Gaussian.
    """

    def __init__(self, kernel_size: int = 11, sgm: float = 5.0):
        super().__init__()
        assert kernel_size % 2 == 1

        self.kernel_size = kernel_size
        self.sgm = sgm

        # Scale the weights to sum to one, then add the (out, in) channel
        # axes that conv2d expects.
        weights = gaussian_kernel_2d(kernel_size, sgm)
        weights /= weights.sum()
        self.register_buffer('kernel', weights[None, None])

    def forward(self, x: Tensor) -> Tensor:
        """Blur ``x`` of shape ``(N, 1, H, W)``; the output has the same shape."""
        assert x.dim() == 4
        assert x.size(1) == 1
        # Padding by half the kernel size keeps H and W unchanged.
        half = self.kernel_size // 2
        return F.conv2d(x, self.kernel, padding=half)


class BilateralFilter2D(nn.Module):
    """Single-channel 2-D bilateral filter.

    Each output pixel is a weighted mean of its ``kernel_size x kernel_size``
    neighbourhood. A neighbour's weight is the product of a fixed spatial
    Gaussian (by pixel offset) and a range Gaussian (by difference from the
    centre pixel's value).

    Args:
        kernel_size: Side length of the square window; must be odd.
        sgm_spatial: Standard deviation of the spatial Gaussian.
        sgm_range: Standard deviation of the range Gaussian.
        eps: Added to the weight sum before dividing.
    """

    def __init__(
        self,
        kernel_size: int = 3,
        sgm_spatial: float = 0.5,
        sgm_range: float = 0.05,
        eps: float = 1e-8,
    ):
        super().__init__()
        assert kernel_size % 2 == 1

        self.kernel_size = kernel_size
        self.pad = kernel_size // 2
        self.sgm_spatial = sgm_spatial
        self.sgm_range = sgm_range
        self.eps = eps

        # Flattened to (1, K*K, 1) so it broadcasts over the batch and pixel
        # axes of the unfolded patches.
        k_spatial = gaussian_kernel_2d(kernel_size, sgm_spatial).reshape(1, -1, 1)
        self.register_buffer("kernel_spatial", k_spatial)

    def forward(self, x: Tensor) -> Tensor:
        """Filter ``x`` of shape ``(N, 1, H, W)``; the output has the same shape."""
        N, C, H, W = x.shape
        assert C == 1

        # (N, K*K, H*W): one column of neighbourhood values per pixel.
        # NOTE: unfold pads with zeros, and those padded entries take part in
        # the weighting like any other neighbour.
        x_p = F.unfold(x, kernel_size=self.kernel_size, padding=self.pad)
        # reshape (not view) so non-contiguous inputs, e.g. transposed
        # tensors, are accepted instead of raising.
        x_c = x.reshape(N, 1, H * W)

        kernel_range = gaussian(x_p - x_c, self.sgm_range)
        weights = kernel_range * self.kernel_spatial

        out = (weights * x_p).sum(dim=1)
        norm = weights.sum(dim=1)
        out = out / (norm + self.eps)

        return out.reshape(N, 1, H, W)


class PostProcessing(nn.Module):
    """Turn a multi-channel map into a smoothed, peak-normalised 2-D map.

    Steps applied in ``forward``: sum over channels, ReLU, blend 50/50 with
    the Gaussian-smoothed map, bilateral filter, then divide each sample by
    its own maximum.

    Args:
        cfg: Configuration object. ``cfg.gauss`` / ``cfg.bilat`` set to
            ``None`` replace the corresponding stage with ``nn.Identity``.
            ``cfg.eps`` is forwarded to the bilateral filter.
    """

    # The annotation is quoted so this class does not need the config type
    # to be resolvable at definition time.
    def __init__(self, cfg: "PostprocCFG"):
        super().__init__()

        if cfg.gauss is None:
            self.gaussian = nn.Identity()
        else:
            self.gaussian = GaussianFilter2D(
                kernel_size=cfg.gauss.kernel_size,
                sgm=cfg.gauss.sgm,
            )

        if cfg.bilat is None:
            self.bilateral = nn.Identity()
        else:
            self.bilateral = BilateralFilter2D(
                kernel_size=cfg.bilat.kernel_size,
                sgm_spatial=cfg.bilat.sgm_spatial,
                sgm_range=cfg.bilat.sgm_range,
                eps=cfg.eps,
            )

    def lerp(self, x0: Tensor, x1: Tensor, alpha: float = 0.5) -> Tensor:
        """Linearly interpolate: ``alpha=0`` gives ``x0``, ``alpha=1`` gives ``x1``."""
        return (1-alpha) * x0 + alpha * x1

    def forward(self, x: Tensor) -> np.ndarray:
        """Post-process ``x`` of shape ``(N, C, H, W)``.

        Returns:
            A NumPy array of shape ``(N, H, W)``. Each sample is divided by
            its own maximum; a sample whose maximum is not positive is
            returned undivided (all zeros) instead of NaN.
        """
        assert len(x.shape) == 4
        N = x.shape[0]

        x = x.sum(dim=1, keepdim=True)
        x = F.relu(x)
        x_smooth = self.gaussian(x)
        x = self.lerp(x, x_smooth)
        x = self.bilateral(x)

        # Per-sample peak. Where it is not strictly positive, divide by one
        # so an all-zero map stays zero rather than becoming 0/0 = NaN.
        # Any positive peak is used as-is.
        peak = x.reshape(N, -1).max(dim=-1).values.reshape(N, 1, 1, 1)
        safe_peak = torch.where(peak > 0, peak, torch.ones_like(peak))
        x = x / safe_peak

        # Drop the singleton channel axis: (N, 1, H, W) -> (N, H, W).
        return x.squeeze(1).detach().cpu().numpy()



def quantile_clamp(
    x: Tensor,
    q_low: float = 0.0,
    q_high: float = 0.99,
) -> Tensor:
    """Clamp each sample of ``x`` to its own ``[q_low, q_high]`` quantile range.

    Quantiles are computed separately for every index along dimension 0,
    over all remaining elements of that sample.

    Args:
        x: Floating-point tensor with the batch on dimension 0. Any number
            of trailing dimensions is accepted, and it need not be
            contiguous.
        q_low: Lower quantile in ``[0, 1]``.
        q_high: Upper quantile in ``[0, 1]``.

    Returns:
        A tensor of the same shape as ``x`` with values limited to the
        per-sample quantile bounds.
    """
    N = x.size(0)
    q_bounds = torch.tensor(
        [q_low, q_high],
        dtype=x.dtype,
        device=x.device,
    )
    # reshape (not view) so non-contiguous inputs are accepted.
    # The result has shape (2, N): row 0 is the lower bound, row 1 the upper.
    bounds = torch.quantile(
        x.reshape(N, -1),
        q_bounds,
        dim=1,
    )
    # One singleton axis per trailing dimension, so the bounds broadcast
    # against x whatever its rank.
    stat_shape = (N,) + (1,) * (x.dim() - 1)
    lower = bounds[0].reshape(stat_shape)
    upper = bounds[1].reshape(stat_shape)
    return torch.clamp(x, min=lower, max=upper)


def make_centering(n: int, device=None, dtype=None) -> Tensor:
    """Build the ``n x n`` centering matrix ``I - (1/n) * ones``.

    Args:
        n: Size of the square matrix.
        device: Passed through to the tensor constructors.
        dtype: Passed through to the tensor constructors.

    Returns:
        An ``(n, n)`` tensor.
    """
    factory = {"device": device, "dtype": dtype}
    # Every entry of the averaging matrix equals 1/n.
    averaging = torch.ones((n, n), **factory) / float(n)
    return torch.eye(n, **factory) - averaging


def gaussian(x: Tensor, sgm: float) -> Tensor:
    """Evaluate the unnormalised Gaussian ``exp(-x^2 / (2 * sgm^2))`` elementwise.

    The value at ``x == 0`` is 1; no ``1 / (sgm * sqrt(2 * pi))`` factor is
    applied.
    """
    two_var = 2 * (sgm ** 2)
    exponent = x.pow(2).neg() / two_var
    return exponent.exp()


def gaussian_kernel_2d(kernel_size: int, sgm: float):
    """Return an unnormalised ``kernel_size x kernel_size`` Gaussian kernel.

    The kernel is the product of two 1-D Gaussians evaluated at integer
    offsets starting from ``-(kernel_size // 2)``. For odd ``kernel_size``
    the grid is centred and the centre entry is 1. The result is not scaled
    to sum to one.
    """
    half = kernel_size // 2
    offsets = torch.arange(kernel_size) - half
    # 'ij' indexing: `rows` varies along dim 0, `cols` along dim 1.
    rows, cols = torch.meshgrid(offsets, offsets, indexing='ij')
    return gaussian(rows, sgm) * gaussian(cols, sgm)

