import torch
from torch import Generator, Tensor


def camera_noise(
    y: Tensor,
    quantum_efficiency: float,
    spurious_charge: float,
    em_gain: float,
    readout_noise: float,
    camera_type: str,
    gen: Generator,
) -> Tensor:
    """
    Apply shot noise and sensor noise.

    The Poisson rate ``quantum_efficiency * y + spurious_charge`` (clamped at 0)
    is sampled to model shot noise. Sensor noise is then added using a Gaussian
    approximation:

    * ``"EMCCD"``: the sampled count ``n`` is multiplied by ``em_gain`` and
      perturbed with standard deviation ``sqrt(n * em_gain**2 + readout_noise**2)``.
    * ``"sCMOS"``: additive Gaussian noise with standard deviation ``readout_noise``.

    Args:
        y: Noise-free input signal. It must be a floating-point tensor because
            ``torch.poisson`` requires one. It is not modified.
        quantum_efficiency: Multiplicative factor applied to ``y`` to form the
            Poisson rate.
        spurious_charge: Additive offset to the Poisson rate.
        em_gain: Electron-multiplying gain (only used for ``"EMCCD"``).
        readout_noise: Standard deviation of the additive read noise.
        camera_type: Either ``"EMCCD"`` or ``"sCMOS"``.
        gen: Generator used for both the Poisson and the Gaussian draws
            (Poisson first). Torch requires it to be on the same device as ``y``.

    Returns:
        A new tensor with the same shape as ``y``, with noise applied.

    Raises:
        ValueError: If ``camera_type`` is not supported. The check runs before
            any sampling, so ``gen`` is left untouched on error.
    """
    # Validate first so an invalid call neither wastes work nor advances `gen`.
    if camera_type not in ("EMCCD", "sCMOS"):
        raise ValueError(f"Supported camera_type are EMCCD or sCMOS; found '{camera_type}'")

    # Shot noise. Poisson rates must be non-negative, so clamp away negative
    # inputs/offsets. Done out-of-place; the caller's `y` is never mutated.
    rate = (quantum_efficiency * y + spurious_charge).clamp(min=0.0)
    counts = torch.poisson(rate, generator=gen)

    # Sensor noise (gamma distribution is replaced by a gaussian approx).
    # The Gaussian draw must follow the Poisson draw to keep the RNG stream stable.
    eps = torch.randn(counts.shape, device=counts.device, dtype=counts.dtype, generator=gen)
    if camera_type == "EMCCD":
        amplified = counts * em_gain
        # amplified * em_gain == n * em_gain**2, the variance of
        # Gamma(shape=n, scale=em_gain), plus the read-noise variance.
        std = torch.sqrt(amplified * em_gain + readout_noise**2)
        return amplified + std * eps
    return counts + readout_noise * eps


def camera_noise_gain(
    y: Tensor, quantum_efficiency: float, em_gain: float, camera_type: str
) -> Tensor:
    """
    Scale ``y`` by the camera's overall conversion factor.

    The factor is ``em_gain * quantum_efficiency`` for ``"EMCCD"`` and
    ``quantum_efficiency`` for ``"sCMOS"`` (``em_gain`` is ignored there).

    Raises:
        ValueError: If ``camera_type`` is neither ``"EMCCD"`` nor ``"sCMOS"``.
    """
    if camera_type not in ("EMCCD", "sCMOS"):
        raise ValueError(f"Supported camera_type are EMCCD or sCMOS; found '{camera_type}'")
    factor = em_gain * quantum_efficiency if camera_type == "EMCCD" else quantum_efficiency
    return factor * y


def camera_noise_jac(
    y: Tensor, quantum_efficiency: float, em_gain: float, camera_type: str
) -> Tensor:
    """
    Apply the Jacobian of the noise-free camera response to ``y``.

    Ignoring the non-negativity clamp, the expected output of ``camera_noise`` is
    affine in its input, and its linear part is ``camera_noise_gain``. So this
    simply delegates to that function.

    NOTE(review): presumably used as a Jacobian-vector product; confirm with callers.
    """
    return camera_noise_gain(y, quantum_efficiency, em_gain, camera_type)


def reciprocal_camera_noise_gain(
    y: Tensor, quantum_efficiency: float, em_gain: float, camera_type: str
) -> Tensor:
    """
    Divide ``y`` by the camera's overall conversion factor (inverse of
    ``camera_noise_gain`` up to floating-point rounding).

    For ``"EMCCD"`` the division by ``em_gain`` happens first, followed by the
    division by ``quantum_efficiency``. For ``"sCMOS"`` only the latter is applied.

    Raises:
        ValueError: If ``camera_type`` is neither ``"EMCCD"`` nor ``"sCMOS"``.
    """
    if camera_type not in ("EMCCD", "sCMOS"):
        raise ValueError(f"Supported camera_type are EMCCD or sCMOS; found '{camera_type}'")
    # Two separate divisions (not one by the product) to keep rounding identical.
    partial = y / em_gain if camera_type == "EMCCD" else y
    return partial / quantum_efficiency
