from math import ceil

import smlmsim
import torch
from torch import Generator, Tensor
from torch.utils.data import Dataset

import smlm


class SingleSyntheticSceneDataset(Dataset):
    """
    Sliding-window dataset over a single synthetic scene.

    One scene (fluorophore coordinates, blinking dynamics, photon fluxes and a
    background) is sampled at construction time, and again whenever
    :meth:`sample_scene` is called. Item ``idx`` is the stack of ``n_frames``
    frames obtained by shifting the scene's dynamics by ``idx`` (presumably
    frames ``idx .. idx + n_frames - 1``). The middle frame of the window is
    the target frame.
    """

    def __init__(
        self,
        bg_photon_mean: float,
        bg_photon_std: float,
        length: int,
        n_acts_per_frame: float,
        n_frames: int,
        n_pixels: int,
        photon_flux_mean: float,
        photon_flux_std: float,
        pixel_size: Tensor,
        seed: int,
        time_bleach: float,
        time_off: float,
        time_on: float,
        z_extent: Tensor,
    ):
        """
        Args:
            bg_photon_mean: Mean passed to the background sampler.
            bg_photon_std: Standard deviation passed to the background sampler.
            length: Number of frames in the scene. The dataset yields
                ``length + 1 - n_frames`` windows.
            n_acts_per_frame: Target average number of activations per frame.
                Used to derive the number of fluorophores in the scene.
            n_frames: Window size in frames. Must be odd so that the target
                frame is the exact middle of the window.
            n_pixels: Height and width of the field of view, in pixels.
            photon_flux_mean: Mean per-fluorophore photon flux.
            photon_flux_std: Standard deviation of the per-fluorophore flux.
            pixel_size: Forwarded to ``smlm.utils.extent.get_vol_extent``.
            seed: Base seed. The scene is sampled from a seed derived from it.
            time_bleach: Forwarded to the dynamics samplers.
            time_off: Forwarded to the dynamics samplers.
            time_on: Forwarded to the dynamics samplers.
            z_extent: Forwarded to ``smlm.utils.extent.get_vol_extent``.

        Raises:
            ValueError: If ``n_frames`` is even.
        """
        super().__init__()
        self.length = length
        self.seed = smlm.utils.random.derive_new_seed(seed)

        if n_frames % 2 != 1:
            raise ValueError("n_frames must be odd")
        self.n_frames = n_frames
        self.tg_frame_idx = n_frames // 2  # middle frame of the window
        self.n_pixels = n_pixels
        self.bg_photon_mean = bg_photon_mean
        self.bg_photon_std = bg_photon_std
        self.vol_extent = smlm.utils.extent.get_vol_extent(
            h=n_pixels, w=n_pixels, pixel_size=pixel_size, z_extent=z_extent
        )

        self.photon_flux_mean = photon_flux_mean
        self.photon_flux_std = photon_flux_std
        self.time_on = time_on
        self.time_off = time_off
        self.time_bleach = time_bleach
        # NOTE(review): presumably the number of on/off periods exceeded with
        # probability 1e-3 (an inverse-CDF cut-off); confirm in smlmsim.
        self.n_periods_max = smlmsim.dynamics.icdf_n_activations(
            1e-3, time_bleach=time_bleach, time_off=time_off
        )

        # Per-frame intensities at or below this are treated as "inactive"
        # by __getitem__.
        self.min_intensity = 0.01 * self.photon_flux_mean

        # We estimate a few constants based on some generated fluorophores.
        # A fixed seed (0) keeps these constants independent of `seed`.
        gen = smlm.utils.random.get_generator(0)
        x = smlm.activations.sample.sample_fluorophores(
            N=5000,
            vol_extent=self.vol_extent,
            time_on=self.time_on,
            time_off=self.time_off,
            time_bleach=self.time_bleach,
            n_periods_max=self.n_periods_max,
            # NOTE(review): a period count is passed as the frame count here;
            # confirm this is intended.
            n_frames=self.n_periods_max,
            photon_flux_mean=self.photon_flux_mean,
            photon_flux_std=self.photon_flux_std,
            gen=gen,
        )
        x = x[:, 3:]  # remove xyz, keep per-frame intensities

        # Average number of frames in which a fluorophore is on. Used to pick
        # the fluorophore count that yields ~n_acts_per_frame activations per
        # frame over `length` frames.
        mean_n_frames_lifetime = (x > 0.0).sum(dim=-1).float().mean(dim=0)
        self.n_fluos = ceil(n_acts_per_frame * self.length / mean_n_frames_lifetime)

        # 25th percentile of the positive per-frame intensities. Target-frame
        # intensities at or above it are flagged as significant (`s`).
        x = x.flatten()
        x = x[x > 0.0]
        self.significant_threshold = torch.quantile(x, q=0.25)

        self.sample_scene(seed=self.seed)

    def increment_seed(self):
        """Replace ``self.seed`` with a seed derived from it.

        This does not resample the scene; call :meth:`sample_scene` for that.
        """
        self.seed = smlm.utils.random.derive_new_seed(self.seed)

    @staticmethod
    def _count_active_fluorophores(n: Tensor) -> Tensor:
        """Count rows of ``n`` that are positive in at least one column.

        ``n`` holds one row per fluorophore and one column per frame. The
        result is a 0-dim integer tensor.
        """
        return (n > 0).any(dim=-1).sum()

    def sample_scene(self, seed: int):
        """
        Sample a new scene from ``seed``, overwriting the current one.

        Sets ``self.xyz`` (coordinates) and ``self.d`` (dynamics shifted by
        random starting times). Also sets ``self.photon_flux`` (one flux per
        fluorophore) and ``self.bg`` (background). Finally it recomputes
        ``self.max_n_acts``, the padding length used by :meth:`collate_fn`.
        """
        gen = Generator().manual_seed(seed)

        self.xyz = smlm.activations.sample.sample_coordinates(
            self.n_fluos, vol_extent=self.vol_extent, gen=gen
        )
        self.d = smlm.activations.sample.sample_dynamics(
            self.n_fluos,
            time_bleach=self.time_bleach,
            time_off=self.time_off,
            time_on=self.time_on,
            n_periods_max=self.n_periods_max,
            gen=gen,
        )
        t_start = smlm.activations.sample.sample_starting_times(
            self.n_fluos, length=self.length, gen=gen
        )
        self.d += t_start[:, None, None]
        self.photon_flux = smlm.activations.sample.sample_photon_flux(
            self.n_fluos, mean=self.photon_flux_mean, std=self.photon_flux_std, gen=gen
        )

        self.bg = smlm.activations.sample.sample_background(
            h=self.n_pixels,
            w=self.n_pixels,
            mean=self.bg_photon_mean,
            std=self.bg_photon_std,
            gen=gen,
        )

        # The max number of active fluorophores per window is not easy to get
        # analytically, so brute-force it over every time shift.
        # __getitem__ returns one row per fluorophore that is on in at least
        # one frame of the window, so count rows. Summing the whole boolean
        # matrix would count (fluorophore, frame) pairs and over-pad batches
        # by up to a factor of n_frames.
        n_max_acts = [
            self._count_active_fluorophores(
                smlmsim.dynamics.batch_discretize_periods(
                    self.d - i, n_frames=self.n_frames
                )
            )
            for i in range(self.length)
        ]
        self.max_n_acts = max(n_max_acts)

    def get_max_n_acts(self):
        """Return the padding length (``self.max_n_acts``) used by :meth:`collate_fn`."""
        return self.max_n_acts

    def collate_fn(self, batch):
        """
        Collate a list of items returned by ``__getitem__`` into one batch.

        ``x_all``, ``x`` and ``s`` have a variable number of rows per item.
        They are padded to ``self.max_n_acts`` with
        ``smlm.utils.nested.pad_sequence(..., returns_lengths=True)``. That
        presumably yields a ``(padded, lengths)`` pair; confirm against the
        helper. ``bg`` is collated with PyTorch's ``default_collate``.
        """
        batch_dict = {key: [item[key] for item in batch] for key in batch[0]}

        collated = {
            key: smlm.utils.nested.pad_sequence(
                batch_dict[key],
                target_len=self.max_n_acts,
                returns_lengths=True,
            )
            for key in ("x_all", "x", "s")
        }
        collated["bg"] = torch.utils.data.default_collate(batch_dict["bg"])
        return collated

    def __len__(self):
        """Number of ``n_frames``-long windows that fit in ``length`` frames."""
        return self.length + 1 - self.n_frames

    def __getitem__(self, idx):
        """
        Return the window at time shift ``idx``.

        Returns a dict with these entries:

        - ``x_all``: one row per fluorophore whose peak intensity in the window
          exceeds ``min_intensity``. The row holds its xyz followed by its
          ``n_frames`` per-frame intensities (discretized "on" amount times
          the fluorophore's photon flux).
        - ``x``: the subset of ``x_all`` rows that are above ``min_intensity``
          in the target (middle) frame.
        - ``s``: boolean per row of ``x``, true where the target-frame
          intensity is at least ``significant_threshold``.
        - ``bg``: the scene background.

        Raises:
            StopIteration: If ``idx`` is negative or ``>= len(self)``.
        """
        # NOTE(review): the Python sequence convention is IndexError. It is kept
        # as StopIteration because callers may rely on it; note that it becomes
        # a RuntimeError if raised inside a generator (PEP 479).
        if idx < 0 or idx >= self.__len__():
            raise StopIteration()

        d = self.d - idx  # remove current time
        n = smlmsim.dynamics.batch_discretize_periods(d, n_frames=self.n_frames)
        n = self.photon_flux[:, None] * n

        mask_all = n.max(dim=-1).values > self.min_intensity
        xyz = self.xyz[mask_all]
        n = n[mask_all]
        x_all = torch.cat([xyz, n], dim=-1)

        mask = n[:, self.tg_frame_idx] > self.min_intensity
        x = x_all[mask]
        # Column 3 + tg_frame_idx is the target frame (first 3 columns are xyz).
        s = x[:, 3 + self.tg_frame_idx] >= self.significant_threshold

        return {"x_all": x_all, "x": x, "s": s, "bg": self.bg}
