import torch
from torch import Generator, Tensor
from torch.utils.data import Dataset, default_collate

import smlm


class SimulateSceneDataset(Dataset):
    """Wrap a scene dataset and attach a simulated measurement ``y`` to each item.

    Each item is the dict returned by ``ds[idx]``. Its ``"x_all"`` and ``"bg"``
    entries are consumed by ``smlm.simulation.simulate`` and the result is stored
    under ``"y"``. Several noise parameters are randomly jittered per item. The
    generator used for this is seeded from ``self.seed`` and the item index, so
    results are presumably reproducible for a fixed seed. That assumes ``ds``
    and ``simulate`` are themselves deterministic given their seeds.
    """

    def __init__(
        self,
        ds: Dataset,
        adu_baseline: float,
        camera_type: float,
        e_adu: float,
        em_gain: float,
        jitter_std: float,
        psf_center: Tensor,
        psf: Tensor,
        quantum_efficiency: float,
        readout_noise: float,
        seed: int,
        spurious_charge: float,
        voxel_size: Tensor,
    ):
        """Store simulation parameters.

        Args:
            ds: Underlying scene dataset. It must provide ``get_max_n_acts()``,
                ``increment_seed()``, ``__len__`` and ``__getitem__``.
                ``__getitem__`` must return a mutable dict with ``"x_all"`` and
                ``"bg"`` (consumed here) and ``"x"``/``"s"`` (used by
                ``collate_fn``).
            adu_baseline: Nominal value, jittered per item.
            camera_type: Forwarded unchanged to ``simulate``.
                NOTE(review): annotated as ``float``; confirm the type that
                ``simulate`` actually expects.
            e_adu: Stored as its reciprocal ``inv_e_adu``, which is jittered per
                item.
            em_gain: Nominal value, jittered per item.
            jitter_std: ``std`` passed to ``smlm.utils.random.jitter`` for every
                jittered parameter.
            psf_center: Forwarded unchanged to ``simulate``.
            psf: Forwarded unchanged to ``simulate``.
            quantum_efficiency: Forwarded unchanged (not jittered).
            readout_noise: Nominal value, jittered per item.
            seed: Base seed; it is passed through ``derive_new_seed`` before
                being stored.
            spurious_charge: Nominal value, jittered per item.
            voxel_size: Stored as its reciprocal ``inv_voxel_size`` and forwarded.
        """
        super().__init__()
        self.seed = smlm.utils.random.derive_new_seed(seed)
        self.ds = ds
        # Padding length used by collate_fn for "x" and "s".
        self.max_n_acts = self.ds.get_max_n_acts()

        self.adu_baseline = adu_baseline
        self.camera_type = camera_type
        self.em_gain = em_gain
        self.inv_e_adu = 1.0 / e_adu
        self.inv_voxel_size = 1.0 / voxel_size
        self.psf = psf
        self.psf_center = psf_center
        self.quantum_efficiency = quantum_efficiency
        self.readout_noise = readout_noise
        self.spurious_charge = spurious_charge
        self.jitter_std = jitter_std

    def collate_fn(self, batch):
        """Collate items produced by ``__getitem__`` into a batch dict.

        Only the ``"y"``, ``"x"`` and ``"s"`` entries are kept; any other keys
        are dropped. ``"y"`` is collated with ``default_collate``. ``"x"`` and
        ``"s"`` go through ``smlm.utils.nested.pad_sequence`` with
        ``target_len=self.max_n_acts`` and ``returns_lengths=True``. That
        presumably yields a (padded, lengths) pair; confirm against smlm.
        """
        ys = [item["y"] for item in batch]
        xs = [item["x"] for item in batch]
        ss = [item["s"] for item in batch]
        return {
            "y": default_collate(ys),
            "x": smlm.utils.nested.pad_sequence(
                xs,
                target_len=self.max_n_acts,
                returns_lengths=True,
            ),
            "s": smlm.utils.nested.pad_sequence(
                ss,
                target_len=self.max_n_acts,
                returns_lengths=True,
            ),
        }

    def increment_seed(self):
        """Advance this dataset's seed and the wrapped dataset's seed.

        After this call, ``__getitem__`` draws different jitter values and
        simulation seeds for the same index. It is presumably called between
        epochs.
        """
        self.seed = smlm.utils.random.derive_new_seed(self.seed)
        self.ds.increment_seed()

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return scene ``idx`` with a freshly simulated measurement under ``"y"``.

        The dict returned by ``self.ds[idx]`` is modified in place: ``"x_all"``
        and ``"bg"`` are popped and ``"y"`` is added.
        NOTE(review): if ``ds`` caches and reuses its item dicts, this mutation
        would corrupt the cache; confirm ``ds`` returns fresh dicts.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``. Negative
                indices are rejected, not wrapped around.
        """
        n_items = len(self)
        if not 0 <= idx < n_items:
            # IndexError (not StopIteration) is the Sequence/Dataset contract.
            # A StopIteration leaking into a generator body is turned into a
            # RuntimeError by PEP 479. Legacy ``iter(dataset)`` still stops
            # cleanly on IndexError.
            raise IndexError(
                f"index {idx} out of range for dataset of length {n_items}"
            )

        # Per-item seed that depends only on (self.seed, idx).
        seed = smlm.utils.random.derive_new_seed(self.seed + idx + 1)
        gen = smlm.utils.random.get_generator(seed)

        def _jitter(value):
            # Scalar draw on CPU from the per-item generator.
            return smlm.utils.random.jitter(
                value, std=self.jitter_std, size=(), device="cpu", gen=gen
            )

        # The call order fixes which generator draws each parameter receives;
        # keep it stable so results stay reproducible.
        em_gain = _jitter(self.em_gain)
        readout_noise = _jitter(self.readout_noise)
        spurious_charge = _jitter(self.spurious_charge)
        adu_baseline = _jitter(self.adu_baseline)
        inv_e_adu = _jitter(self.inv_e_adu)

        scene = self.ds[idx]
        # NOTE(review): the same ``seed`` seeds both the jitter generator above
        # and ``simulate`` below. If ``simulate`` builds its generator the same
        # way, its noise draws may be correlated with the jitter draws.
        # Confirm against smlm.simulation.simulate.
        y = smlm.simulation.simulate(
            scene.pop("x_all"),
            bg=scene.pop("bg"),
            adu_baseline=adu_baseline,
            camera_type=self.camera_type,
            em_gain=em_gain,
            inv_e_adu=inv_e_adu,
            inv_voxel_size=self.inv_voxel_size,
            psf_center=self.psf_center,
            psf=self.psf,
            quantum_efficiency=self.quantum_efficiency,
            readout_noise=readout_noise,
            spurious_charge=spurious_charge,
            seed=seed,
        )
        scene["y"] = y
        return scene
