import smlmsim
from torch import Tensor

import smlm


def simulate(
    x: Tensor,
    bg: Tensor,
    adu_baseline: float,
    camera_type: float,
    em_gain: float,
    inv_e_adu: float,
    inv_voxel_size: Tensor,
    psf_center: Tensor,
    psf: Tensor,
    quantum_efficiency: float,
    readout_noise: float,
    spurious_charge: float,
    seed: int,
):
    """Simulate noisy camera frames for a single (unbatched) set of fluorophores.

    Thin wrapper around ``batched_simulate``. It prepends a batch axis of size
    one to ``x`` and ``bg``, runs the batched simulation, then squeezes that
    axis away again.

    Args:
        x: Fluorophores, shape ``[N, d>=4]``. Each row is
            ``[x, y, z, photons_frame_1, ..., photons_frame_f]``.
        bg: Background, batched along a new leading axis. Presumably
            ``[h, w]`` — TODO confirm against ``_batched_img_photons``.
        adu_baseline, camera_type, em_gain, inv_e_adu, quantum_efficiency,
        readout_noise, spurious_charge: Camera parameters forwarded to
            ``batched_simulate``.
        inv_voxel_size, psf_center, psf: PSF parameters forwarded to
            ``batched_simulate``.
        seed: Seed forwarded to ``batched_simulate``'s random generator.

    Returns:
        The single-sample result of ``batched_simulate`` with dim 0 squeezed.

    Raises:
        ValueError: If ``x`` is not a ``[N, d>=4]`` tensor.
    """
    _check_x(x)
    # Lift the lone sample into a batch of one so the batched path is reused.
    frames = batched_simulate(
        x=x[None],
        bg=bg[None],
        inv_voxel_size=inv_voxel_size,
        psf_center=psf_center,
        psf=psf,
        adu_baseline=adu_baseline,
        camera_type=camera_type,
        em_gain=em_gain,
        inv_e_adu=inv_e_adu,
        quantum_efficiency=quantum_efficiency,
        readout_noise=readout_noise,
        spurious_charge=spurious_charge,
        seed=seed,
    )
    return frames.squeeze(0)


def batched_simulate(
    x: Tensor,
    bg: Tensor,
    adu_baseline: float,
    camera_type: float,
    em_gain: float,
    inv_e_adu: float,
    inv_voxel_size: Tensor,
    psf_center: Tensor,
    psf: Tensor,
    quantum_efficiency: float,
    readout_noise: float,
    spurious_charge: float,
    seed: int,
):
    """Simulate noisy camera frames for a batch of fluorophore sets.

    The simulation runs in two steps:

    1. ``_batched_img_photons`` computes the expected photons per pixel
       (PSF images scaled by photon counts, plus background).
    2. ``smlmsim.camera.apply_camera`` converts them using the camera
       parameters and a random generator built from ``seed`` on ``x``'s
       device.

    Args:
        x: Fluorophores, shape ``[bs, N, d>=4]``.
        bg: Background forwarded to ``_batched_img_photons``.
        adu_baseline, camera_type, em_gain, inv_e_adu, quantum_efficiency,
        readout_noise, spurious_charge: Forwarded to
            ``smlmsim.camera.apply_camera``.
        inv_voxel_size, psf_center, psf: Forwarded to the PSF renderer.
        seed: Passed to ``smlm.utils.random.get_generator``. Presumably this
            makes the noise reproducible — confirm.

    Returns:
        Whatever ``smlmsim.camera.apply_camera`` returns for the photon image.
        Presumably ``[bs, f, h, w]`` — TODO confirm.

    Raises:
        ValueError: If ``x`` is not a ``[bs, N, d>=4]`` tensor.
    """
    _check_x_batch(x)
    expected_photons = _batched_img_photons(
        x=x,
        bg=bg,
        psf=psf,
        psf_center=psf_center,
        inv_voxel_size=inv_voxel_size,
    )
    # Built after the photon image, matching the original call order.
    rng = smlm.utils.random.get_generator(seed, device=x.device)
    return smlmsim.camera.apply_camera(
        y=expected_photons,
        gen=rng,
        camera_type=camera_type,
        adu_baseline=adu_baseline,
        em_gain=em_gain,
        inv_e_adu=inv_e_adu,
        quantum_efficiency=quantum_efficiency,
        readout_noise=readout_noise,
        spurious_charge=spurious_charge,
    )


def batched_render(
    x: Tensor,
    bg: Tensor,
    adu_baseline: float,
    camera_type: float,
    em_gain: float,
    inv_e_adu: float,
    inv_voxel_size: Tensor,
    psf_center: Tensor,
    psf: Tensor,
    quantum_efficiency: float,
):
    """Render camera images for a batch of fluorophore sets via camera gain only.

    Same photon computation as ``batched_simulate``. The camera step differs:
    it calls ``smlmsim.camera.apply_camera_gain``, with no random generator
    and no readout-noise or spurious-charge parameters. This is presumably
    the deterministic counterpart of ``batched_simulate`` — confirm.

    Args:
        x: Fluorophores, shape ``[bs, N, d>=4]``.
        bg: Background forwarded to ``_batched_img_photons``.
        adu_baseline, camera_type, em_gain, inv_e_adu, quantum_efficiency:
            Forwarded to ``smlmsim.camera.apply_camera_gain``.
        inv_voxel_size, psf_center, psf: Forwarded to the PSF renderer.

    Returns:
        Whatever ``smlmsim.camera.apply_camera_gain`` returns for the photon
        image.

    Raises:
        ValueError: If ``x`` is not a ``[bs, N, d>=4]`` tensor.
    """
    _check_x_batch(x)
    expected_photons = _batched_img_photons(
        x=x,
        bg=bg,
        psf=psf,
        psf_center=psf_center,
        inv_voxel_size=inv_voxel_size,
    )
    return smlmsim.camera.apply_camera_gain(
        y=expected_photons,
        camera_type=camera_type,
        adu_baseline=adu_baseline,
        em_gain=em_gain,
        inv_e_adu=inv_e_adu,
        quantum_efficiency=quantum_efficiency,
    )


def _batched_img_photons(
    x: Tensor,
    bg: Tensor,
    inv_voxel_size: Tensor,
    psf_center: Tensor,
    psf: Tensor,
):
    """Compute the expected photon count per pixel for a batch of fluorophores.

    Args:
        x: Fluorophores, shape ``[bs, N, d]`` with ``d >= 4``.
            - Columns ``0:3`` are the coordinates passed to the PSF renderer.
            - Columns ``3:`` are photon counts for each of ``f = d - 3`` frames.
            ``x`` may be non-contiguous (e.g. a single set ``expand``-ed over
            the batch).
        bg: Background. Its last two dims define the rendered image size
            ``(h, w)``. A singleton frame axis is inserted at dim 1 before
            adding, so it is presumably ``[bs, h, w]``: one background per
            batch item, shared by all frames. TODO confirm.
        inv_voxel_size, psf_center, psf: Forwarded to
            ``smlmsim.psf.batched_render_coordinates``.

    Returns:
        Tensor ``[bs, f, h, w]``. For each frame, it is the sum over
        fluorophores of the rendered PSF image times that fluorophore's photon
        count, plus ``bg``.

    Raises:
        ValueError: If ``x`` is not a ``[bs, N, d>=4]`` tensor.
    """
    _check_x_batch(x)
    bs, n, d = x.size(0), x.size(1), x.size(2)
    h, w = bg.size(-2), bg.size(-1)
    f = d - 3

    # reshape rather than view: view() raises on non-contiguous inputs
    # (expanded/transposed x). reshape() returns the same view when possible.
    x_flat = x.reshape(bs * n, d)
    out = smlmsim.psf.batched_render_coordinates(
        x_flat[:, :3],
        center=psf_center,
        img_size=(h, w),
        inv_voxel_size=inv_voxel_size,
        psf=psf,
    )  # [N, h, w]: one PSF image per flattened fluorophore
    n_photons = x_flat[:, 3:]  # [N, f]
    out = out[:, None] * n_photons[:, :, None, None]  # [N, f, h, w]
    out = out.reshape(bs, n, f, h, w)
    y = out.sum(1)  # [bs, f, h, w]: sum fluorophores within each batch item
    y = y + bg[:, None]  # broadcast background over the frame axis
    return y


def _check_x_batch(x: Tensor):
    """Validate that ``x`` is a batched fluorophore tensor ``[bs, N, d>=4]``.

    Raises:
        ValueError: If ``x`` is not 3-D, or its last dim is smaller than 4.
    """
    # `and` short-circuits, so size(-1) is only queried on 3-D tensors.
    has_expected_layout = x.ndim == 3 and x.size(-1) >= 4
    if not has_expected_layout:
        raise ValueError(
            """Fluorophores formalisme: x must be a 3d-batched tensor with shape [bs, N, d>=4].
            Each row is a fluorophore defined by the following constants:
            [x,y,z,num_photons_frame_1, ..., num_photons_frame_n]."""
        )


def _check_x(x: Tensor):
    """Validate that ``x`` is an unbatched fluorophore tensor ``[N, d>=4]``.

    Raises:
        ValueError: If ``x`` is not 2-D, or its last dim is smaller than 4.
    """
    # `and` short-circuits, so size(-1) is only queried on 2-D tensors.
    has_expected_layout = x.ndim == 2 and x.size(-1) >= 4
    if not has_expected_layout:
        raise ValueError(
            """Fluorophores formalisme: x must be a 2d tensor with shape [N, d>=4].
            Each row is a fluorophore defined by the following constants:
            [x,y,z,num_photons_frame_1, ..., num_photons_frame_n]."""
        )
