from math import ceil

import smlmsim
import torch
from torch import Generator, Tensor
from torch.utils.data import Dataset

import smlm


class PerSampleSyntheticSceneDataset(Dataset):
    """
    Generates one synthetic scene per requested sample.

    Each item is simulated deterministically from ``(self.seed, idx)``:
    - a random number of fluorophores is simulated over
      ``n_warmup_frames + n_frames`` frames;
    - the warm-up frames are discarded;
    - the remaining ``n_frames``-frame window is returned together with a
      sampled background.

    Calling :meth:`increment_seed` changes the base seed, so the same index
    produces a different scene afterwards. Batches should be built with
    :meth:`collate_fn`, which pads the variable-length per-scene tensors.
    """

    def __init__(
        self,
        bg_photon_mean: float,
        bg_photon_std: float,
        length: int,
        n_acts_per_frame: float,
        n_frames: int,
        n_pixels: int,
        photon_flux_mean: float,
        photon_flux_std: float,
        pixel_size: Tensor,
        seed: int,
        time_bleach: float,
        time_off: float,
        time_on: float,
        z_extent: Tensor,
    ):
        """
        Args:
            bg_photon_mean: ``mean`` forwarded to
                ``smlm.activations.sample.sample_background``.
            bg_photon_std: ``std`` forwarded to ``sample_background``.
            length: Number of samples reported by ``__len__``.
            n_acts_per_frame: Target used to size the fluorophore pool; see
                the ``n_fluos`` computation below.
            n_frames: Length of each returned window. Must be a positive odd
                integer so there is a single central (target) frame.
            n_pixels: Image height and width, in pixels.
            photon_flux_mean: Forwarded to ``sample_fluorophores``.
            photon_flux_std: Forwarded to ``sample_fluorophores``.
            pixel_size: Forwarded to ``smlm.utils.extent.get_vol_extent``.
            seed: Base seed. A derived seed is stored in ``self.seed``.
            time_bleach: Fluorophore dynamics parameter, forwarded to the
                simulators.
            time_off: Fluorophore dynamics parameter, forwarded to the
                simulators.
            time_on: Fluorophore dynamics parameter, forwarded to the
                simulators.
            z_extent: Forwarded to ``smlm.utils.extent.get_vol_extent``.

        Raises:
            ValueError: If ``n_frames`` is not a positive odd integer.
        """
        super().__init__()
        # Validate before any simulation work. The odd check alone is not
        # sufficient: in Python, -1 % 2 == 1.
        if n_frames % 2 != 1:
            raise ValueError("n_frames must be odd")
        if n_frames < 1:
            raise ValueError("n_frames must be positive")

        self.length = length
        self.seed = smlm.utils.random.derive_new_seed(seed)

        self.n_frames = n_frames
        # Central frame of the window; "x" and "s" in __getitem__ are
        # defined with respect to this frame.
        self.tg_frame_idx = n_frames // 2
        self.n_pixels = n_pixels
        self.bg_photon_mean = bg_photon_mean
        self.bg_photon_std = bg_photon_std
        self.vol_extent = smlm.utils.extent.get_vol_extent(
            h=n_pixels, w=n_pixels, pixel_size=pixel_size, z_extent=z_extent
        )

        self.photon_flux_mean = photon_flux_mean
        self.photon_flux_std = photon_flux_std
        self.time_on = time_on
        self.time_off = time_off
        self.time_bleach = time_bleach
        # NOTE(review): presumably the number of activation periods that is
        # exceeded with probability 1e-3 -- confirm against
        # smlmsim.dynamics.icdf_n_activations.
        self.n_periods_max = smlmsim.dynamics.icdf_n_activations(
            1e-3, time_bleach=time_bleach, time_off=time_off
        )

        # Fluorophores whose target-frame value is below this are not
        # included in "x".
        self.min_intensity = 0.01 * self.photon_flux_mean

        # Estimate dataset-level constants from a fixed reference population.
        # The generator is seeded with 0, so these constants do not depend
        # on `seed`.
        gen = smlm.utils.random.get_generator(0)
        ref = smlm.activations.sample.sample_fluorophores(
            N=5000,
            vol_extent=self.vol_extent,
            time_on=self.time_on,
            time_off=self.time_off,
            time_bleach=self.time_bleach,
            n_periods_max=self.n_periods_max,
            # NOTE(review): a period count is used as the simulation length
            # here; confirm it is long enough that lifetimes are not
            # truncated.
            n_frames=self.n_periods_max,
            photon_flux_mean=self.photon_flux_mean,
            photon_flux_std=self.photon_flux_std,
            gen=gen,
        )
        ref_values = ref[:, 3:]  # drop the xyz columns

        # Average number of frames in which a fluorophore has a positive value.
        mean_n_frames_lifetime = (ref_values > 0.0).sum(dim=-1).float().mean(dim=0)
        # Warm-up frames are simulated and then discarded in __getitem__.
        # Presumably this is so the returned window does not start with
        # every fluorophore in its initial state.
        self.n_warmup_frames = ceil(mean_n_frames_lifetime)
        self.total_n_frames = self.n_warmup_frames + self.n_frames
        # Pool size chosen so that n_fluos * mean_lifetime is approximately
        # n_acts_per_frame * total_n_frames. It is also the padding length
        # used by collate_fn.
        self.n_fluos = ceil(
            n_acts_per_frame * self.total_n_frames / mean_n_frames_lifetime
        )

        # Lower quartile of all positive per-frame values. Target-frame
        # values at or above it are flagged in "s".
        positive_values = ref_values[ref_values > 0.0]
        self.significant_threshold = torch.quantile(positive_values, q=0.25)

    def collate_fn(self, batch) -> dict:
        """
        Collate a list of ``__getitem__`` outputs into one batch dict.

        ``"x_all"``, ``"x"`` and ``"s"`` have a variable number of rows per
        scene. Each is padded to ``self.n_fluos`` via
        ``smlm.utils.nested.pad_sequence(..., returns_lengths=True)``.
        ``"bg"`` is collated with ``torch.utils.data.default_collate``.
        """
        batch_dict = {key: [d[key] for d in batch] for key in batch[0]}
        collated = {
            key: smlm.utils.nested.pad_sequence(
                batch_dict[key],
                target_len=self.n_fluos,
                returns_lengths=True,
            )
            for key in ("x_all", "x", "s")
        }
        collated["bg"] = torch.utils.data.default_collate(batch_dict["bg"])
        return collated

    def increment_seed(self) -> None:
        """
        Replace the base seed with one derived from it.

        Every index then maps to a different scene. Presumably this is
        called between epochs; confirm with callers.
        """
        self.seed = smlm.utils.random.derive_new_seed(self.seed)

    def __len__(self):
        """Return the configured number of samples."""
        return self.length

    def __getitem__(self, idx) -> dict:
        """
        Simulate the scene for sample ``idx``.

        The result depends only on ``self.seed`` and ``idx``.

        Returns:
            A dict with keys:
              - ``"x_all"``: rows ``[x, y, z, v_0, ..., v_{n_frames-1}]``
                for every fluorophore with a positive value in at least one
                frame of the window.
              - ``"x"``: rows with the same layout for the fluorophores whose
                target-frame value is at least ``self.min_intensity``.
              - ``"s"``: boolean tensor with one entry per row of ``"x"``,
                true where the target-frame value is at least
                ``self.significant_threshold``.
              - ``"bg"``: output of
                ``smlm.activations.sample.sample_background`` for an
                ``n_pixels`` x ``n_pixels`` image.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``.
        """
        n_items = self.__len__()
        if idx < 0 or idx >= n_items:
            # IndexError is the standard out-of-range signal for
            # __getitem__. A StopIteration escaping into a generator would be
            # turned into a RuntimeError (PEP 479).
            raise IndexError(
                f"index {idx} is out of range for dataset of length {n_items}"
            )

        seed = smlm.utils.random.derive_new_seed(self.seed + idx + 1)
        gen = Generator().manual_seed(seed)

        N = self.n_fluos
        if N > 10:
            # torch.randint's `high` is exclusive, so N lies in
            # [10, n_fluos - 1]. Convert to int so N has the same type in
            # both branches.
            N = int(torch.randint(low=10, high=N, size=(), generator=gen))

        x_all = smlm.activations.sample.sample_fluorophores(
            N=N,
            vol_extent=self.vol_extent,
            time_on=self.time_on,
            time_off=self.time_off,
            time_bleach=self.time_bleach,
            n_periods_max=self.n_periods_max,
            n_frames=self.total_n_frames,
            photon_flux_mean=self.photon_flux_mean,
            photon_flux_std=self.photon_flux_std,
            gen=gen,
        )
        xyz = x_all[:, :3]
        # Drop the warm-up frames, keeping only the n_frames-long window.
        n_all = x_all[:, 3 + self.n_warmup_frames :]
        x_all = torch.cat([xyz, n_all], dim=-1)

        mask = n_all[:, self.tg_frame_idx] >= self.min_intensity
        x = x_all[mask]
        s = x[:, 3 + self.tg_frame_idx] >= self.significant_threshold
        # Keep only fluorophores that are visible somewhere in the window.
        x_all = x_all[n_all.max(dim=-1).values > 0.0]

        bg = smlm.activations.sample.sample_background(
            h=self.n_pixels,
            w=self.n_pixels,
            mean=self.bg_photon_mean,
            std=self.bg_photon_std,
            gen=gen,
        )

        return {"x_all": x_all, "x": x, "s": s, "bg": bg}
