import torch
import torch.nn.functional as F
from torch import Tensor, nn

import smlm
from smlm.models.base.affine_norm import AffineNorm
from smlm.models.base.unet_model import UNet
from smlm.models.base.unet_parts import Down, OutConv
from smlm.models.decode import non_maximum_suppression
from smlm.models.simulator import Renderer


def LayerNorm(n):
    """Build the normalization layer used as ``norm_module`` in this file.

    Returns ``nn.GroupNorm`` with a single group over ``n`` channels. With
    one group, each sample is normalized jointly over all of its channels
    and spatial positions, followed by a learnable per-channel affine
    transform.
    """
    single_group = 1
    return nn.GroupNorm(num_groups=single_group, num_channels=n)


class SHOT(nn.Module):
    """Iterative "render and refine" emitter localization network.

    ``forward`` encodes every input frame with a shared per-frame UNet and
    decodes the centre frame's features into per-cell emitter parameters
    ``[p, x, y, z, n]`` plus a background map. It then runs ``n_iters``
    refinement steps. Each step re-renders the current estimate with
    ``Renderer`` and re-encodes the rendering. A residual UNet then updates
    the centre-frame features before they are decoded again.
    """

    # Scale shared by the input normalization and the photon-count /
    # background output heads.
    _PHOTONS_CST = 1e3

    def __init__(
        self,
        adu_baseline: float,
        camera_type: str,
        e_adu: float,
        em_gain: float,
        inner_dim: int,
        n_frames: int,
        n_iters: int,
        psf_center: Tensor,
        psf: Tensor,
        quantum_efficiency: float,
        voxel_size: Tensor,
        z_extent: Tensor,
    ):
        super().__init__()
        # An odd number of frames gives a well-defined centre frame.
        assert n_frames % 2 == 1
        self.x0 = None
        self.dim = inner_dim
        self.n_frames = n_frames
        self.center_frame_idx = n_frames // 2
        self.n_iters = n_iters
        # Decoder output channels: [p (logit), x, y, z, n].
        self.out_dim = 5

        self.register_buffer("pixel_size", voxel_size[:2])
        self.register_buffer("z_extent", z_extent)

        self.normalize = AffineNorm(mu=0, sigma=self._PHOTONS_CST)
        # Single-channel encoder applied to each frame independently.
        self.encoder_net = UNet(
            1,
            inner_dim,
            depth=2,
            init_features=48,
            norm_module=LayerNorm,
        )
        # Input = features of the n_frames input frames + features of the
        # re-rendered frame + current centre-frame state (see ``residual``).
        self.residual_net = UNet(
            inner_dim * (n_frames + 2),
            inner_dim,
            depth=2,
            init_features=48,
            norm_module=LayerNorm,
        )
        # ``Down`` halves the resolution, hence ``2 * pixel_size`` in decode.
        self.decoder_net = nn.Sequential(
            Down(inner_dim, 2 * inner_dim, norm_module=LayerNorm),
            OutConv(2 * inner_dim, self.out_dim),
        )
        self.bg_net = OutConv(inner_dim, 1)

        # Detection threshold used by ``post_processing``. The buffer is
        # non-persistent, so it is not saved in the state dict.
        t = torch.tensor([0.0])
        self.register_buffer("t", t, persistent=False)

        self.psf = psf
        self.renderer = Renderer(
            psf=self.psf,
            psf_center=psf_center,
            voxel_size=voxel_size,
            quantum_efficiency=quantum_efficiency,
            em_gain=em_gain,
            adu_baseline=adu_baseline,
            e_adu=e_adu,
            camera_type=camera_type,
        )

    def supports_fast_calibration(self):
        """Return ``True``.

        NOTE(review): presumably queried by calibration code to know that the
        detection threshold can be tuned through ``set_thresholds`` — confirm
        against callers.
        """
        return True

    def set_thresholds(self, thresholds: Tensor):
        """Set the detection threshold applied by ``post_processing``.

        Args:
            thresholds: tensor compared against the probability column of
                each detection, or ``None`` to disable thresholding. A tensor
                is moved to the device of the module's parameters.
        """
        if thresholds is not None:
            device = next(self.parameters()).device
            thresholds = thresholds.to(device=device)
        self.t = thresholds

    def get_thresholds(self) -> Tensor:
        """Return the current detection threshold (a tensor or ``None``)."""
        return self.t

    def forward(self, y0: Tensor) -> Tensor:
        """Localize emitters in the frame stack ``y0``.

        Args:
            y0: 4-D input, split per frame by ``encode`` as
                ``(bs, n_frames, h, w)``; the centre frame is localized.

        Returns:
            Training mode: a list of ``(x_hat, bg_hat)`` pairs, one per
            decoding step (``n_iters + 1`` entries). In each pair, ``p`` is
            still a logit.
            Eval mode: the final estimate with ``p`` mapped to (0, 1),
            filtered by ``post_processing``.
        """
        z0 = self.encode(y0)
        idx = self.center_frame_idx
        # The centre frame's features seed the iterative state.
        z = z0[:, idx * self.dim : (idx + 1) * self.dim]
        x_hat, bg_hat = self.decode(z)
        if self.training:
            history = [(x_hat, bg_hat)]
        for _ in range(self.n_iters):
            y_hat = self.render(x_hat, bg=bg_hat)
            z_hat = self.encode(y_hat)
            # ``residual`` already returns ``z + dz``. Adding ``z`` again here
            # would double the state at every iteration.
            z = self.residual(z0=z0, z_hat=z_hat, z=z)
            x_hat, bg_hat = self.decode(z)
            if self.training:
                history.append((x_hat, bg_hat))
        if self.training:
            return history
        x_hat[..., 0].sigmoid_()  # in place: p logits -> probabilities (0, 1)
        return self.post_processing(x_hat)

    def encode(self, y: Tensor) -> Tensor:
        """Encode each frame of ``y`` independently with the shared encoder.

        Args:
            y: ``(bs, f, h, w)`` frames.

        Returns:
            ``(bs, dim * f, h, w)``: per-frame features stacked along the
            channel axis, in frame order.
        """
        assert y.ndim == 4
        bs, f, h, w = y.shape

        y = self.normalize(y)
        y = y.view(bs * f, 1, h, w)
        z = self.encoder_net(y)
        z = z.view(bs, self.dim * f, h, w)
        return z

    def decode(self, z: Tensor) -> "tuple[Tensor, Tensor]":
        """Decode centre-frame features into emitter parameters and background.

        Args:
            z: ``(bs, dim, h, w)`` centre-frame features.

        Returns:
            x: per-cell parameters ``[p, x, y, z, n]``, flattened by
               ``smlm.utils.map2list.map2list``. ``p`` is a logit; ``x, y``
               are the cell centre plus a tanh offset of at most 1.5 cells;
               ``z`` lies inside ``z_extent``; ``n`` is
               ``_PHOTONS_CST * softplus``.
            bg: ``_PHOTONS_CST * softplus(bg_net(z))`` with channel 1
               squeezed out.
        """
        device = z.device

        bg = self.bg_net(z)
        bg = bg.squeeze(1)
        bg = self._PHOTONS_CST * F.softplus(bg)

        x = self.decoder_net(z)
        # decoder_net downsamples once, so each output cell spans two pixels.
        cell_size = 2 * self.pixel_size

        p = x[:, 0, None]  # kept as a logit; the sigmoid is applied later
        xy = x[:, 1:3]
        z_pos = x[:, 3, None]
        n = x[:, 4, None]

        xy_ref = smlm.utils.coordinates.map_coordinates_cell_center(
            h=x.size(-2),
            w=x.size(-1),
            cell_width=cell_size[0],
            cell_height=cell_size[1],
            device=device,
        )
        xy = xy_ref + 1.5 * cell_size[:, None, None] * torch.tanh(xy)

        z_lo, z_hi = self.z_extent[0], self.z_extent[1]
        z_pos = (z_hi - z_lo) * torch.sigmoid(z_pos) + z_lo

        n = self._PHOTONS_CST * F.softplus(n)

        x = torch.cat([p, xy, z_pos, n], dim=1)
        x = smlm.utils.map2list.map2list(x)

        return x, bg

    def residual(self, z0: Tensor, z_hat: Tensor, z: Tensor) -> Tensor:
        """Return the updated state ``z + residual_net(cat([z0, z_hat, z]))``.

        Args:
            z0: features of all input frames (``encode(y0)``).
            z_hat: features of the re-rendered estimate.
            z: current centre-frame state.
        """
        input = torch.cat([z0, z_hat, z], dim=1)
        dz = self.residual_net(input)
        z = z + dz
        return z

    def render(self, x: Tensor, bg: Tensor) -> Tensor:
        """Render the estimate ``x`` on top of background ``bg``.

        Each emitter's intensity ``n`` is weighted by its probability
        ``sigmoid(p)``. The renderer receives columns ``[x, y, z, p * n]``.
        """
        xyz = x[..., 1:4]
        p = torch.sigmoid(x[..., 0, None])  # p to (0,1)
        n = p * x[..., 4, None]  # n <- p * n
        x = torch.cat([xyz, n], dim=-1)
        y = self.renderer(x, bg=bg)
        return y

    def post_processing(self, x: Tensor):
        """Filter detections with the stored threshold ``self.t``."""
        x = self.apply_thresholds(x, thresholds=self.t)
        return x

    def apply_thresholds(self, x: Tensor, thresholds: Tensor):
        """Keep detections whose probability column exceeds ``thresholds``.

        For each batch item, keeps the rows with ``x_[:, 0] > thresholds``
        and drops the probability column. The results are packed into a
        jagged nested tensor. If ``thresholds`` is ``None``, ``x`` is
        returned unchanged.
        """
        if thresholds is None:
            return x
        x = [x_[x_[:, 0] > thresholds, 1:] for x_ in x]
        x = torch.nested.nested_tensor(x, layout=torch.jagged)
        return x


class SHOTWithUncert(nn.Module):
    """Variant of the iterative render-and-refine network with uncertainties.

    The architecture mirrors ``SHOT``, with two differences. First, the
    decoder is a single ``OutConv`` (no downsampling stage). Second, it
    predicts nine channels: ``[p, x, y, z, n, ux, uy, uz, un]``, where the
    last four are positive uncertainty terms. In eval mode, detections are
    selected by non-maximum suppression on the probability map.
    """

    # Scale shared by the input normalization and the photon-count /
    # background output heads.
    _PHOTONS_CST = 1e3

    def __init__(
        self,
        adu_baseline: float,
        camera_type: str,
        e_adu: float,
        em_gain: float,
        inner_dim: int,
        n_frames: int,
        n_iters: int,
        psf_center: Tensor,
        psf: Tensor,
        quantum_efficiency: float,
        voxel_size: Tensor,
        z_extent: Tensor,
    ):
        super().__init__()
        # An odd number of frames gives a well-defined centre frame.
        assert n_frames % 2 == 1
        self.x0 = None
        self.dim = inner_dim
        self.n_frames = n_frames
        self.center_frame_idx = n_frames // 2
        self.n_iters = n_iters
        # Decoder output channels: [p, x, y, z, n, ux, uy, uz, un].
        self.out_dim = 9

        self.register_buffer("pixel_size", voxel_size[:2])
        self.register_buffer("z_extent", z_extent)

        self.normalize = AffineNorm(mu=0, sigma=self._PHOTONS_CST)
        # Single-channel encoder applied to each frame independently.
        self.encoder_net = UNet(
            1,
            inner_dim,
            depth=2,
            init_features=48,
            norm_module=LayerNorm,
        )
        # Input = features of the n_frames input frames + features of the
        # re-rendered frame + current centre-frame state (see ``residual``).
        self.residual_net = UNet(
            inner_dim * (n_frames + 2),
            inner_dim,
            depth=2,
            init_features=48,
            norm_module=LayerNorm,
        )
        self.decoder_net = OutConv(inner_dim, self.out_dim)
        self.bg_net = OutConv(inner_dim, 1)

        self.psf = psf
        self.renderer = Renderer(
            psf=self.psf,
            psf_center=psf_center,
            voxel_size=voxel_size,
            quantum_efficiency=quantum_efficiency,
            em_gain=em_gain,
            adu_baseline=adu_baseline,
            e_adu=e_adu,
            camera_type=camera_type,
        )

    def forward(self, y0: Tensor) -> Tensor:
        """Localize emitters (with uncertainties) in the frame stack ``y0``.

        Args:
            y0: 4-D input, split per frame by ``encode`` as
                ``(bs, n_frames, h, w)``; the centre frame is localized.

        Returns:
            Training mode: a list of ``(x_hat, bg_hat)`` pairs, one per
            decoding step (``n_iters + 1`` entries).
            Eval mode: the final estimate filtered by ``post_processing``.
        """
        z0 = self.encode(y0)
        idx = self.center_frame_idx
        # The centre frame's features seed the iterative state.
        z = z0[:, idx * self.dim : (idx + 1) * self.dim]
        x_hat, bg_hat = self.decode(z)
        if self.training:
            history = [(x_hat, bg_hat)]
        for _ in range(self.n_iters):
            y_hat = self.render(x_hat, bg=bg_hat)
            z_hat = self.encode(y_hat)
            # ``residual`` already returns ``z + dz``. Adding ``z`` again here
            # would double the state at every iteration.
            z = self.residual(z0=z0, z_hat=z_hat, z=z)
            x_hat, bg_hat = self.decode(z)
            if self.training:
                history.append((x_hat, bg_hat))
        if self.training:
            return history
        return self.post_processing(x_hat, h=y0.size(-2), w=y0.size(-1))

    def encode(self, y: Tensor) -> Tensor:
        """Encode each frame of ``y`` independently with the shared encoder.

        Args:
            y: ``(bs, f, h, w)`` frames.

        Returns:
            ``(bs, dim * f, h, w)``: per-frame features stacked along the
            channel axis, in frame order.
        """
        assert y.ndim == 4
        bs, f, h, w = y.shape

        y = self.normalize(y)
        y = y.view(bs * f, 1, h, w)
        z = self.encoder_net(y)
        z = z.view(bs, self.dim * f, h, w)
        return z

    def decode(self, z: Tensor) -> "tuple[Tensor, Tensor]":
        """Decode centre-frame features into parameters, uncertainties and bg.

        Args:
            z: ``(bs, dim, h, w)`` centre-frame features.

        Returns:
            x: per-cell ``[p, x, y, z, n, ux, uy, uz, un]``, flattened by
               ``smlm.utils.map2list.map2list``. ``p`` is a probability
               (sigmoid). Each uncertainty is a scaled softplus plus 5.0.
            bg: ``_PHOTONS_CST * softplus(bg_net(z))`` with channel 1
               squeezed out.
        """
        device = z.device

        bg = self.bg_net(z)
        bg = bg.squeeze(1)
        bg = self._PHOTONS_CST * F.softplus(bg)

        x = self.decoder_net(z)

        p = x[:, 0, None]
        xy = x[:, 1:3]
        z_pos = x[:, 3, None]
        n = x[:, 4, None]
        uxy = x[:, 5:7]
        uz = x[:, 7, None]
        un = x[:, 8, None]

        p = torch.sigmoid(p)

        xy_ref = smlm.utils.coordinates.map_coordinates_cell_center(
            h=x.size(-2),
            w=x.size(-1),
            cell_width=self.pixel_size[0],
            cell_height=self.pixel_size[1],
            device=device,
        )
        # Offset from the cell centre, bounded to +/-1.5 pixels.
        xy = xy_ref + 1.5 * self.pixel_size[:, None, None] * torch.tanh(xy)
        # The constant 5.0 keeps every predicted uncertainty strictly positive.
        uxy = self.pixel_size[:, None, None] * F.softplus(uxy) + 5.0

        z_lo, z_hi = self.z_extent[0], self.z_extent[1]
        z_pos = (z_hi - z_lo) * torch.sigmoid(z_pos) + z_lo
        uz = z_hi * F.softplus(uz) + 5.0

        n = self._PHOTONS_CST * F.softplus(n)
        un = self._PHOTONS_CST * F.softplus(un) + 5.0

        x = torch.cat([p, xy, z_pos, n, uxy, uz, un], dim=1)
        x = smlm.utils.map2list.map2list(x)

        return x, bg

    def residual(self, z0: Tensor, z_hat: Tensor, z: Tensor) -> Tensor:
        """Return the updated state ``z + residual_net(cat([z0, z_hat, z]))``.

        Args:
            z0: features of all input frames (``encode(y0)``).
            z_hat: features of the re-rendered estimate.
            z: current centre-frame state.
        """
        input = torch.cat([z0, z_hat, z], dim=1)
        dz = self.residual_net(input)
        z = z + dz
        return z

    def render(self, x: Tensor, bg: Tensor) -> Tensor:
        """Render the estimate ``x`` on top of background ``bg``.

        ``p`` is already a probability here, so the renderer receives
        columns ``[x, y, z, p * n]``.
        """
        xyz = x[..., 1:4]
        n = x[..., 0, None] * x[..., 4, None]  # n <- p * n
        x = torch.cat([xyz, n], dim=-1)
        y = self.renderer(x, bg=bg)
        return y

    def post_processing(
        self,
        x: Tensor,
        h: int,
        w: int,
        *,
        raw_th: float = 0.1,
        split_th: float = 0.6,
        det_th: float = 0.4,
    ):
        """Select detections by non-maximum suppression on the probability map.

        The probability column is reshaped into an ``(h, w)`` map with
        ``list2map``. The map is passed through ``non_maximum_suppression``
        with ``raw_th`` / ``split_th`` and flattened back. Cells whose
        resulting value is ``>= det_th`` are kept. This assumes the decoder
        output has the input's ``h x w`` resolution.

        Args:
            x: decoded estimate whose column 0 is the probability.
            h, w: spatial size of the probability map.
            raw_th, split_th: thresholds forwarded to
                ``non_maximum_suppression``.
            det_th: final detection threshold (inclusive).

        Returns:
            A jagged nested tensor holding, per batch item, the remaining
            8 columns of the kept detections.
        """
        p, x = x[..., 0], x[..., 1:]
        p = smlm.utils.map2list.list2map(p, size_h=h, size_w=w)
        p = non_maximum_suppression(p, raw_th=raw_th, split_th=split_th)
        p = smlm.utils.map2list.map2list(p)
        p = p >= det_th
        x = [x_[p_] for x_, p_ in zip(x, p)]
        x = torch.nested.nested_tensor(x, layout=torch.jagged)
        return x
