import torch
from torch import Tensor, nn

import smlm
from smlm.losses.ot import OptimalTransportWithUncertaintiesLoss
from smlm.models.simulator import Simulator


class OTTrainingModule(nn.Module):
    """Simulate frames from ground truth, run ``model`` and score it with an OT loss.

    ``forward`` renders frames from the ground-truth emitters with a
    :class:`Simulator`, feeds them to ``model`` and combines an
    :class:`OptimalTransportWithUncertaintiesLoss` on the emitter predictions
    with a down-weighted MSE on the predicted background.

    Args:
        adu_baseline, camera_type, e_adu, em_gain, jitter_std, psf_center,
        psf, quantum_efficiency, readout_noise, spurious_charge:
            Forwarded unchanged to :class:`Simulator`.
        background_photon_flux: Accepted but not used by this class
            (NOTE(review): presumably kept for config compatibility — confirm).
        eps: Accepted but not used by this class
            (NOTE(review): presumably kept for config compatibility — confirm).
        model: Network applied to the simulated frames. It must return either
            a single ``(pred, bg)`` pair or a list of such pairs.
        n_frames: Number of simulated frames per sample. It must be a positive
            odd number so that ``n_frames // 2`` is the single centre (target)
            frame.
        photon_flux_mean: Passed to the OT loss as ``photon_cst``.
        reg: Passed to the OT loss as ``reg``.
        voxel_size: Its reciprocal is passed to the simulator as
            ``inv_voxel_size``. Its first two entries are the OT loss
            ``pixel_size``.
        bg_loss_weight: Weight of the background MSE term in the total loss.
            Defaults to ``1e-6``, the value that used to be hard-coded.

    Raises:
        ValueError: If ``n_frames`` is not positive or not odd.
    """

    def __init__(
        self,
        adu_baseline: float,
        background_photon_flux: float,
        camera_type: str,
        e_adu: float,
        em_gain: float,
        eps: float,
        jitter_std: float,
        model: nn.Module,
        n_frames: int,
        photon_flux_mean: float,
        psf_center: Tensor,
        psf: Tensor,
        quantum_efficiency: float,
        readout_noise: float,
        reg: float,
        spurious_charge: float,
        voxel_size: Tensor,
        bg_loss_weight: float = 1e-6,
    ):
        super().__init__()
        # Checked first: a negative odd value (e.g. -1, since -1 % 2 == 1)
        # would pass the parity test. It would then give a negative target
        # index that silently picks the wrong ground-truth column in forward().
        if n_frames < 1:
            raise ValueError(f"n_frames must be positive, got {n_frames}")
        if n_frames % 2 != 1:
            raise ValueError("n_frames must be odd")
        self.n_frames = n_frames
        # Index of the centre frame; its ground-truth column is the loss target.
        self.tg_frame_idx = n_frames // 2
        self.bg_loss_weight = bg_loss_weight

        self.model = model
        self.simulator = Simulator(
            adu_baseline=adu_baseline,
            e_adu=e_adu,
            em_gain=em_gain,
            inv_voxel_size=voxel_size.reciprocal(),
            psf_center=psf_center,
            psf=psf,
            quantum_efficiency=quantum_efficiency,
            readout_noise=readout_noise,
            spurious_charge=spurious_charge,
            camera_type=camera_type,
            jitter_std=jitter_std,
        )
        # Advanced with derive_new_seed on every forward() call and handed to
        # the simulator, so each call uses a different seed.
        self.seed = 0
        self.ot_loss = OptimalTransportWithUncertaintiesLoss(
            reg=reg,
            pixel_size=voxel_size[:2],
            photon_cst=photon_flux_mean,
        )

    def forward(self, batch):
        """Simulate frames for ``batch``, run the model and return the losses.

        Args:
            batch: Mapping with these keys:

                - ``"x_all"``: a pair whose first item is given to the simulator
                  (the second item is ignored).
                - ``"x"``: ``(x_gt, x_gt_lengths)``, a padded ground-truth tensor
                  with the per-sample number of valid entries along dim 1.
                - ``"bg"``: the ground-truth background. It is passed to the
                  simulator and used as the MSE target.

        Returns:
            Dict with two entries:

            - ``"loss"``: the OT loss plus ``bg_loss_weight`` times the
              background MSE, each averaged over all model outputs.
            - ``"ot"``: the averaged OT loss alone.
        """
        x_all_frames, _ = batch["x_all"]
        x_gt, x_gt_lengths = batch["x"]
        bg_gt = batch["bg"]

        self.seed = smlm.utils.random.derive_new_seed(self.seed)
        frames = self.simulator(x_all_frames, bg=bg_gt, seed=self.seed)

        # mask_gt[b, i] is True for the first x_gt_lengths[b] entries of sample b,
        # i.e. the non-padding ground-truth rows.
        positions = torch.arange(x_gt.size(1), device=x_gt.device)
        mask_gt = positions < x_gt_lengths[:, None]
        # Keep columns 0-2 and the column of the centre frame. Per-frame columns
        # start at index 3 (NOTE(review): presumably xyz + per-frame photons — confirm).
        x_gt = x_gt[..., [0, 1, 2, 3 + self.tg_frame_idx]]

        out = self.model(frames)
        # A list means several (pred, bg) pairs, each of which is scored. Any
        # other return value (including a tuple) is one pair.
        outputs = out if isinstance(out, list) else [out]

        losses_ot = []
        losses_bg = []
        for pred, bg_pred in outputs:
            # Column 0 is passed to the OT loss as `p`, the remaining columns as `x`.
            p, coords = pred[..., 0], pred[..., 1:]
            losses_ot.append(self.ot_loss(x=coords, p=p, x_gt=x_gt, mask_gt=mask_gt))
            losses_bg.append(torch.nn.functional.mse_loss(bg_pred, bg_gt))
        loss_ot = torch.stack(losses_ot).mean()
        loss_bg = torch.stack(losses_bg).mean()
        loss = loss_ot + self.bg_loss_weight * loss_bg
        return {"loss": loss, "ot": loss_ot}
