import torch
from torch import Tensor, nn

import smlm


class Simulator(nn.Module):
    """Module wrapper around ``smlm.simulation.batched_simulate``.

    The PSF, its center and the inverse voxel size are held as
    non-persistent buffers: they follow the module across devices but are
    left out of ``state_dict``. The scalar camera parameters are kept as
    plain Python attributes and are jittered anew on every ``forward`` call.

    Args:
        adu_baseline: Nominal ADU baseline; jittered per call.
        camera_type: Camera model identifier, forwarded unchanged to
            ``batched_simulate``.
        e_adu: Stored as its reciprocal ``inv_e_adu``; must be non-zero.
        em_gain: Nominal EM gain; jittered per call.
        inv_voxel_size: Tensor registered as the ``inv_voxel_size`` buffer.
        jitter_std: ``std`` passed to ``smlm.utils.random.jitter`` for every
            jittered parameter.
        psf_center: Tensor registered as the ``psf_center`` buffer.
        psf: Tensor registered as the ``psf`` buffer.
        quantum_efficiency: Forwarded unchanged (not jittered).
        readout_noise: Nominal readout noise; jittered per call.
        spurious_charge: Nominal spurious charge; jittered per call.
    """

    def __init__(
        self,
        adu_baseline: float,
        camera_type: str,
        e_adu: float,
        em_gain: float,
        inv_voxel_size: Tensor,
        jitter_std: float,
        psf_center: Tensor,
        psf: Tensor,
        quantum_efficiency: float,
        readout_noise: float,
        spurious_charge: float,
    ):
        super().__init__()
        # Register the PSF tensors as non-persistent buffers so they move with
        # the module (.to/.cuda) without being written to checkpoints.
        self.register_buffer("psf", psf, persistent=False)
        self.register_buffer("psf_center", psf_center, persistent=False)
        self.register_buffer("inv_voxel_size", inv_voxel_size, persistent=False)

        self.quantum_efficiency = quantum_efficiency
        self.spurious_charge = spurious_charge
        self.em_gain = em_gain
        self.readout_noise = readout_noise
        self.adu_baseline = adu_baseline
        # Only the reciprocal is kept; e_adu itself is not stored.
        self.inv_e_adu = 1.0 / e_adu
        self.camera_type = camera_type
        self.jitter_std = jitter_std

    def _jitter(self, value: float, size: list, device: torch.device, gen):
        """Jitter one nominal scalar with the module-wide ``jitter_std``."""
        return smlm.utils.random.jitter(
            value, std=self.jitter_std, size=size, device=device, gen=gen
        )

    def forward(self, x: Tensor, bg: Tensor, seed: int) -> Tensor:
        """Run ``batched_simulate`` on ``x`` with per-call jittered parameters.

        Args:
            x: Input tensor; its device and first-dimension size determine
                where and with what shape the jittered parameters are drawn.
            bg: Background, forwarded unchanged as ``bg``.
            seed: Seeds the generator used for the jitter draws and is also
                forwarded to ``batched_simulate``.

        Returns:
            Whatever ``smlm.simulation.batched_simulate`` returns.
        """
        device = x.device
        # One value per entry along x's first dimension, with three trailing
        # singleton dims (presumably for broadcasting against the simulated
        # images -- confirm against batched_simulate).
        size = [x.size(0), 1, 1, 1]
        gen = smlm.utils.random.get_generator(seed, device=device)

        # All draws consume from the same generator, so this order is part of
        # the seeded behaviour; do not reorder.
        em_gain = self._jitter(self.em_gain, size, device, gen)
        readout_noise = self._jitter(self.readout_noise, size, device, gen)
        spurious_charge = self._jitter(self.spurious_charge, size, device, gen)
        adu_baseline = self._jitter(self.adu_baseline, size, device, gen)
        inv_e_adu = self._jitter(self.inv_e_adu, size, device, gen)

        return smlm.simulation.batched_simulate(
            x,
            bg=bg,
            adu_baseline=adu_baseline,
            camera_type=self.camera_type,
            em_gain=em_gain,
            inv_e_adu=inv_e_adu,
            inv_voxel_size=self.inv_voxel_size,
            psf_center=self.psf_center,
            psf=self.psf,
            quantum_efficiency=self.quantum_efficiency,
            readout_noise=readout_noise,
            spurious_charge=spurious_charge,
            seed=seed,
        )


class Renderer(nn.Module):
    """Module wrapper around ``smlm.simulation.batched_render``.

    Holds the PSF, its center and the reciprocal of ``voxel_size`` as
    non-persistent buffers (excluded from ``state_dict``), and keeps the
    scalar camera parameters as plain attributes that are forwarded to
    ``batched_render`` unchanged.

    Args:
        adu_baseline: Forwarded as ``adu_baseline``.
        camera_type: Forwarded as ``camera_type``.
        e_adu: Stored as its reciprocal ``inv_e_adu``; must be non-zero.
        em_gain: Forwarded as ``em_gain``.
        voxel_size: Tensor whose element-wise reciprocal is registered as the
            ``inv_voxel_size`` buffer.
        psf_center: Tensor registered as the ``psf_center`` buffer.
        psf: Tensor registered as the ``psf`` buffer.
        quantum_efficiency: Forwarded as ``quantum_efficiency``.
        chunk_size: Stored on the instance.
            NOTE(review): not read by ``forward`` in this class -- confirm
            whether it is consumed elsewhere.
    """

    def __init__(
        self,
        adu_baseline: float,
        camera_type: str,
        e_adu: float,
        em_gain: float,
        voxel_size: Tensor,
        psf_center: Tensor,
        psf: Tensor,
        quantum_efficiency: float,
        chunk_size: int = 0,
    ):
        super().__init__()

        # Plain (non-tensor) configuration.
        self.chunk_size = chunk_size
        self.adu_baseline = adu_baseline
        self.camera_type = camera_type
        self.em_gain = em_gain
        self.quantum_efficiency = quantum_efficiency
        # Only the reciprocal is kept; e_adu itself is not stored.
        self.inv_e_adu = 1.0 / e_adu

        # Tensor state: non-persistent buffers, registered in a fixed order.
        inv_voxel_size = voxel_size.reciprocal()
        buffers = (
            ("psf", psf),
            ("psf_center", psf_center),
            ("inv_voxel_size", inv_voxel_size),
        )
        for buffer_name, tensor in buffers:
            self.register_buffer(buffer_name, tensor, persistent=False)

    def forward(self, x: Tensor, bg: Tensor) -> Tensor:
        """Call ``batched_render`` on ``x`` and ``bg`` with the stored settings.

        Returns:
            Whatever ``smlm.simulation.batched_render`` returns.
        """
        settings = {
            "adu_baseline": self.adu_baseline,
            "camera_type": self.camera_type,
            "em_gain": self.em_gain,
            "inv_e_adu": self.inv_e_adu,
            "inv_voxel_size": self.inv_voxel_size,
            "psf_center": self.psf_center,
            "psf": self.psf,
            "quantum_efficiency": self.quantum_efficiency,
        }
        return smlm.simulation.batched_render(x, bg=bg, **settings)
