import torch
import torch.nn.functional as F
from torch import Tensor, nn

import smlm
from smlm.models.base.affine_norm import AffineNorm
from smlm.models.base.unet_model import UNet
from smlm.models.base.unet_parts import Down, OutConv


class Decode(nn.Module):
    """Multi-frame localisation network that also predicts per-emitter "u" terms.

    Each of the ``n_frames`` input frames is encoded separately by
    ``frame_network``. The per-frame features are stacked along the channel
    axis and decoded by ``core_network`` into 10 per-pixel channels, laid out
    as ``p | xy (2) | z | n | uxy (2) | uz | un | bg``. The ``u*`` channels are
    presumably uncertainty estimates (TODO confirm against the training loss).

    In training mode :meth:`forward` returns ``(emitters, bg)``. In eval mode
    it returns whatever :meth:`post_processing` produces.
    """

    _PHOTONS_CST = 1e3

    def __init__(
        self, inner_dim: int, n_frames: int, pixel_size: Tensor, z_extent: Tensor
    ):
        """Build the per-frame encoder and the multi-frame core network.

        Args:
            inner_dim: Feature channels produced per frame by ``frame_network``.
            n_frames: Number of frames per input sample; must be odd.
            pixel_size: Pixel size; ``[0]`` is used as the cell width and
                ``[1]`` as the cell height.
            z_extent: ``(z_min, z_max)``; predicted z is squashed into this range.
        """
        super().__init__()
        assert n_frames % 2 == 1
        self.dim = inner_dim
        self.n_frames = n_frames

        # Registration order is kept identical to the original. It fixes the
        # parameter/state-dict order and, presumably, the RNG draws used for
        # weight initialisation.
        self.normalize = AffineNorm(mu=0, sigma=self._PHOTONS_CST)
        self.frame_network = UNet(1, inner_dim, depth=2, init_features=48)
        self.core_network = UNet(
            n_frames * inner_dim, 10, depth=2, init_features=48
        )
        self.register_buffer("z_extent", z_extent)
        self.register_buffer("pixel_size", pixel_size)

    def forward(self, y: Tensor) -> Tensor:
        """Decode a stack of frames into per-pixel emitter predictions.

        Args:
            y: Input of shape ``(batch, n_frames, h, w)``. The second axis must
                hold ``n_frames`` frames for the reshape below to succeed.

        Returns:
            In training mode, a tuple ``(emitters, bg)``:

            - ``emitters`` is the 9-channel map ``p, xy, z, n, uxy, uz, un``,
              flattened by ``smlm.utils.map2list.map2list``.
            - ``bg`` is the ``(batch, h, w)`` background, scaled by
              ``_PHOTONS_CST``.

            In eval mode, the result of :meth:`post_processing`.
        """
        batch, _, height, width = y.shape

        # Encode each frame on its own, then stack frame features as channels.
        feats = self.normalize(y).view(batch * self.n_frames, 1, height, width)
        feats = self.frame_network(feats).view(
            batch, self.n_frames * self.dim, height, width
        )
        raw = self.core_network(feats)

        p, xy, z, n, uxy, uz, un, bg = torch.split(
            raw, [1, 2, 1, 1, 2, 1, 1, 1], dim=1
        )

        ps = self.pixel_size
        ps_map = ps[:, None, None]  # broadcast the per-axis size over (h, w)
        z_lo, z_hi = self.z_extent[0], self.z_extent[1]
        cst = self._PHOTONS_CST
        centres = smlm.utils.coordinates.map_coordinates_cell_center(
            h=height,
            w=width,
            cell_width=ps[0],
            cell_height=ps[1],
            device=y.device,
        )

        emitters = torch.cat(
            [
                torch.sigmoid(p),
                # Offset from the pixel centre, bounded to +/-1.5 pixel sizes.
                centres + 1.5 * ps_map * torch.tanh(xy),
                # Squash z into (z_extent[0], z_extent[1]).
                (z_hi - z_lo) * torch.sigmoid(z) + z_lo,
                cst * F.softplus(n),
                # Positive "u" terms with an additive floor of 5.0.
                ps_map * F.softplus(uxy) + 5.0,
                z_hi * F.softplus(uz) + 5.0,
                cst * F.softplus(un) + 5.0,
            ],
            dim=1,
        )
        emitters = smlm.utils.map2list.map2list(emitters)

        if self.training:
            return emitters, cst * F.softplus(bg[:, 0])
        return self.post_processing(emitters, h=height, w=width)

    def post_processing(self, x: Tensor, h: int, w: int):
        """Cluster detections with NMS and keep the surviving emitters.

        Args:
            x: Flattened predictions from :meth:`forward`; channel 0 is the
                detection probability. The layout is assumed to be
                ``(batch, h*w, C)`` as produced by ``map2list`` (TODO confirm).
            h: Height used to rebuild the probability map.
            w: Width used to rebuild the probability map.

        Returns:
            A jagged nested tensor with one entry per batch element. Each
            entry holds every channel except the probability, for pixels
            whose NMS-clustered probability is ``>= 0.4``.
        """
        prob, attrs = x[..., 0], x[..., 1:]
        prob_map = smlm.utils.map2list.list2map(prob, size_h=h, size_w=w)
        clustered = non_maximum_suppression(prob_map, raw_th=0.1, split_th=0.6)
        keep = smlm.utils.map2list.map2list(clustered) >= 0.4
        return torch.nested.nested_tensor(
            [rows[mask] for rows, mask in zip(attrs, keep)], layout=torch.jagged
        )


def norm_sum(*args):
    """Sum tensors element-wise and clamp the result to ``[0, 1]``.

    This is the default ``p_aggregation`` of ``non_maximum_suppression``,
    where it merges candidate probability maps into one.

    Args:
        *args: One or more tensors with broadcast-compatible shapes.

    Returns:
        Tensor: ``clamp(args[0] + args[1] + ..., 0, 1)``.

    Raises:
        TypeError: If called with no arguments.
    """
    if not args:
        raise TypeError("norm_sum() requires at least one tensor")
    # torch.add takes exactly two operands, so fold it over the inputs to
    # honour the variadic signature. Two inputs give the same single add.
    total = args[0]
    for term in args[1:]:
        total = torch.add(total, term)
    return torch.clamp(total, 0.0, 1.0)


def non_maximum_suppression(
    p: Tensor,
    raw_th: float,
    split_th: float,
    p_aggregation=norm_sum,
    diag: float = 0.0,
) -> Tensor:
    """Cluster a per-pixel probability map into discrete emitter candidates.

    Two masks are built and then merged:

    1. Pixels equal to the maximum of their 3x3 neighbourhood. Only
       probabilities above ``raw_th`` take part in the max-pool.
    2. Pixels that are *not* in mask 1 but whose own probability exceeds
       ``split_th``. This allows two emitters in adjacent pixels to both be
       kept.

    Each masked pixel receives the probability mass of its neighbourhood:
    itself plus its four edge neighbours, plus ``diag`` times the diagonal
    neighbours. The two masked maps are then combined with ``p_aggregation``.

    Args:
        p: Probability map of shape ``(batch, h, w)``.
        raw_th: Probabilities at or below this value are ignored when
            searching for local maxima.
        split_th: Minimum probability for a non-maximal pixel to be kept
            as a second emitter.
        p_aggregation: Callable merging the two masked maps. The default
            (``norm_sum``) clamps their sum to ``[0, 1]``.
        diag: Weight of the diagonal neighbours in the aggregation kernel.
            The default ``0.0`` reproduces the cross-shaped kernel.

    Returns:
        Tensor of shape ``(batch, h, w)`` holding the aggregated
        probabilities. Pixels rejected by both masks are zero.
    """
    # Pixels at or below raw_th cannot become local maxima.
    p_clip = torch.where(p > raw_th, p, torch.zeros_like(p))[:, None]

    # Mask 1: pixels equal to the maximum of their (thresholded) 3x3 patch.
    pool = F.max_pool2d(p_clip, 3, 1, padding=1)
    max_mask1 = torch.eq(p[:, None], pool).to(p.dtype)

    # Aggregation kernel. It is built in the input's dtype/device so that
    # conv2d does not fail on non-float32 probability maps.
    filt = torch.tensor(
        [[diag, 1.0, diag], [1.0, 1.0, 1.0], [diag, 1.0, diag]],
        dtype=p.dtype,
        device=p.device,
    )[None, None]
    conv = F.conv2d(p[:, None], filt, padding=1)
    p_ps1 = max_mask1 * conv

    # Mask 2: non-maximal pixels above split_th, so that two emitters in
    # adjacent pixels can both be detected.
    p_rest = p * (1 - max_mask1[:, 0])
    max_mask2 = (p_rest > split_th).to(p.dtype)[:, None]
    p_ps2 = max_mask2 * conv

    # Merge both candidate maps; callers threshold this clustered map.
    p_ps = p_aggregation(p_ps1, p_ps2)
    assert p_ps.size(1) == 1

    return p_ps.squeeze(1)


class DecodeNoNMS(Decode):
    """``Decode`` variant that filters detections by a probability threshold
    instead of running non-maximum suppression.

    The threshold is stored in the non-persistent buffer ``t``, initialised
    to ``tensor([0.0])``, so it is excluded from the state dict. Setting it
    to ``None`` disables filtering: :meth:`post_processing` then returns its
    input unchanged, probability column included.
    """

    def __init__(
        self, inner_dim: int, n_frames: int, pixel_size: Tensor, z_extent: Tensor
    ):
        super().__init__(
            inner_dim=inner_dim,
            n_frames=n_frames,
            pixel_size=pixel_size,
            z_extent=z_extent,
        )
        self.register_buffer("t", torch.tensor([0.0]), persistent=False)

    def supports_fast_calibration(self):
        """Return ``True``; thresholds can be changed via :meth:`set_thresholds`."""
        return True

    def set_thresholds(self, thresholds: Tensor):
        """Store ``thresholds`` on the device of the module's first parameter.

        Passing ``None`` disables threshold filtering.
        """
        if thresholds is None:
            self.t = None
            return
        target = next(self.parameters()).device
        self.t = thresholds.to(device=target)

    def get_thresholds(self) -> Tensor:
        """Return the current threshold buffer (may be ``None``)."""
        return self.t

    def post_processing(self, x: Tensor, h: int, w: int):
        """Filter ``x`` with the stored threshold; ``h`` and ``w`` are unused."""
        return self.apply_thresholds(x, thresholds=self.t)

    def apply_thresholds(self, x: Tensor, thresholds: Tensor):
        """Keep rows whose probability (column 0) is strictly above ``thresholds``.

        Returns ``x`` unchanged when ``thresholds`` is ``None``. Otherwise it
        returns a jagged nested tensor with one entry per sample, holding the
        kept rows without the probability column.
        """
        if thresholds is None:
            return x

        def _above(rows: Tensor) -> Tensor:
            hit = rows[:, 0] > thresholds
            return rows[hit, 1:]

        return torch.nested.nested_tensor(list(map(_above, x)), layout=torch.jagged)


class DecodeNoUncert(nn.Module):
    _PHOTONS_CST = 1e3

    def __init__(
        self, inner_dim: int, n_frames: int, pixel_size: Tensor, z_extent: Tensor
    ):
        super().__init__()
        assert n_frames % 2 == 1
        self.dim = inner_dim
        self.n_frames = n_frames

        self.normalize = AffineNorm(mu=0, sigma=self._PHOTONS_CST)
        self.frame_network = UNet(1, inner_dim, depth=2, init_features=48)
        self.core_network = UNet(n_frames * inner_dim, 6, depth=2, init_features=48)
        self.register_buffer("z_extent", z_extent)
        self.register_buffer("pixel_size", pixel_size)

        t = torch.tensor([0.0])
        self.register_buffer("t", t, persistent=False)

    def forward(self, y: Tensor) -> Tensor:
        bs, _, h, w = y.shape
        device = y.device

        x = self.normalize(y)
        x = x.view(bs * self.n_frames, 1, h, w)
        x = self.frame_network(x)
        x = x.view(bs, self.n_frames * self.dim, h, w)
        x = self.core_network(x)

        p = x[:, 0, None]
        xy = x[:, 1:3]
        z = x[:, 3, None]
        n = x[:, 4, None]
        bg = x[:, 5]

        p = torch.sigmoid(p)

        xy_ref = smlm.utils.coordinates.map_coordinates_cell_center(
            h=h,
            w=w,
            cell_width=self.pixel_size[0],
            cell_height=self.pixel_size[1],
            device=device,
        )
        xy = 1.5 * self.pixel_size[:, None, None] * torch.tanh(xy)
        xy = xy_ref + xy
        z = (self.z_extent[1] - self.z_extent[0]) * torch.sigmoid(z) + self.z_extent[0]
        n = self._PHOTONS_CST * F.softplus(n)
        x = torch.cat([p, xy, z, n], dim=1)
        x = smlm.utils.map2list.map2list(x)

        bg = self._PHOTONS_CST * F.softplus(bg)

        if self.training:
            return x, bg

        x = self.post_processing(x)
        return x

    def supports_fast_calibration(self):
        return True

    def set_thresholds(self, thresholds: Tensor):
        if thresholds is not None:
            device = next(self.parameters()).device
            thresholds = thresholds.to(device=device)
        self.t = thresholds

    def get_thresholds(self) -> Tensor:
        return self.t

    def post_processing(self, x: Tensor):
        x = self.apply_thresholds(x, thresholds=self.t)
        return x

    def apply_thresholds(self, x: Tensor, thresholds: Tensor):
        if thresholds is None:
            return x
        x = [x_[x_[:, 0] > thresholds, 1:] for x_ in x]
        x = torch.nested.nested_tensor(x, layout=torch.jagged)
        return x


class DecodeNoUncertHalf(DecodeNoUncert):
    """``DecodeNoUncert`` variant that predicts emitters on a coarser grid.

    ``core_network`` outputs ``inner_dim`` feature channels. ``bg_net`` reads
    the background from them at input resolution. ``decoder_net`` (a ``Down``
    block followed by a 5-channel ``OutConv``) produces ``p, xy, z, n``,
    presumably at half resolution, which is why ``pixel_size`` is doubled.
    """

    def __init__(
        self, inner_dim: int, n_frames: int, pixel_size: Tensor, z_extent: Tensor
    ):
        super().__init__(
            inner_dim=inner_dim,
            n_frames=n_frames,
            pixel_size=pixel_size,
            z_extent=z_extent,
        )
        # Replace the parent's 6-channel head with a feature-producing UNet.
        # Construction order below is unchanged from the original.
        self.core_network = UNet(
            n_frames * inner_dim, inner_dim, depth=2, init_features=48
        )
        self.decoder_net = nn.Sequential(
            Down(inner_dim, 2 * inner_dim),
            OutConv(2 * inner_dim, 5),
        )
        self.bg_net = OutConv(inner_dim, 1)
        # Overwrite the buffer registered by the parent with the doubled size.
        self.register_buffer("pixel_size", 2 * pixel_size)

    def forward(self, y: Tensor) -> Tensor:
        """Decode a ``(batch, n_frames, h, w)`` frame stack.

        Returns:
            In training mode, ``(emitters, bg)``. ``emitters`` is the
            ``p, xy, z, n`` map on the decoder grid, flattened by
            ``map2list``; ``p`` is left un-squashed (no sigmoid) in this mode.
            ``bg`` is the input-resolution background, scaled by
            ``_PHOTONS_CST``.

            In eval mode, the thresholded result of ``post_processing``, with
            ``p`` passed through a sigmoid.
        """
        batch, _, height, width = y.shape

        feats = self.normalize(y).view(batch * self.n_frames, 1, height, width)
        feats = self.frame_network(feats).view(
            batch, self.n_frames * self.dim, height, width
        )
        feats = self.core_network(feats)

        raw_bg = self.bg_net(feats)[:, 0]
        p, xy, z, n = torch.split(self.decoder_net(feats), [1, 2, 1, 1], dim=1)

        # Only squash p at inference; training receives the raw value
        # (presumably for a logits-based loss).
        if not self.training:
            p = torch.sigmoid(p)

        ps = self.pixel_size
        grid_h, grid_w = xy.shape[-2:]
        centres = smlm.utils.coordinates.map_coordinates_cell_center(
            h=grid_h,
            w=grid_w,
            cell_width=ps[0],
            cell_height=ps[1],
            device=y.device,
        )
        z_lo, z_hi = self.z_extent[0], self.z_extent[1]

        emitters = torch.cat(
            [
                p,
                # Offset from the cell centre, bounded to +/-1.5 cell sizes.
                centres + 1.5 * ps[:, None, None] * torch.tanh(xy),
                (z_hi - z_lo) * torch.sigmoid(z) + z_lo,
                self._PHOTONS_CST * F.softplus(n),
            ],
            dim=1,
        )
        emitters = smlm.utils.map2list.map2list(emitters)

        if self.training:
            return emitters, self._PHOTONS_CST * F.softplus(raw_bg)
        return self.post_processing(emitters)
