import torch
from torch import Tensor

from smlmsim.utils.insert import batched_insert_2d
from smlmsim.utils.power_series import cubic_3d_power_series


def _batched_preprocess_coordinates(
    x: Tensor, inv_voxel_size: Tensor, center: Tensor, psf_h: int, psf_w: int
):
    """Convert a stack of coordinates into integer cell indices and remainders.

    The first three columns of ``x`` are multiplied by ``inv_voxel_size``.
    The two lateral components are then shifted by ``-center[:2]`` and the
    axial component by ``+center[2]``. Each shifted value is split into its
    floor (the cell index) and the fractional remainder. ``x`` itself is not
    modified.

    Returns:
        A tuple ``(x_idx, y_idx, z_idx, u)``: ``x_idx`` and ``y_idx`` are the
        int32 lateral cell indices offset by ``psf_w`` and ``psf_h``
        respectively, ``z_idx`` is the int32 axial cell index, and ``u`` holds
        the remainders, with the two lateral columns stored as
        ``1 - remainder``.
    """
    # The product allocates a new tensor, so the in-place updates below never
    # touch the caller's ``x``.
    frac = x[:, :3] * inv_voxel_size
    frac[:, :2].sub_(center[:2])
    frac[:, 2].add_(center[2])

    # Separate the integer cell from the position inside the cell.
    cell = torch.floor(frac)
    frac.sub_(cell)
    cell = cell.to(torch.int32)

    # Lateral remainders are flipped; the axial one follows a different
    # convention and is kept as is.
    flipped_lateral = 1.0 - frac[:, :2]
    frac[:, :2] = flipped_lateral

    col_idx = cell[:, 0] + psf_w
    row_idx = cell[:, 1] + psf_h
    depth_idx = cell[:, 2]
    return col_idx, row_idx, depth_idx, frac


def _check_arguments(x: Tensor, psf: Tensor, center: Tensor, inv_voxel_size: Tensor):
    """Validate the shapes of the rendering inputs and reject nan coordinates.

    Checks run in a fixed order (``x`` shape, nan in ``x``, ``psf``,
    ``inv_voxel_size``, ``center``); the first failing check raises.

    Raises:
        ValueError: if ``x`` is not 2D with a last dimension of 3, if ``x``
            contains nan, if ``psf`` is not 4D with a last dimension of 64,
            or if ``inv_voxel_size`` / ``center`` is not a 1D tensor of
            size 3.
    """

    def _is_vector_of_three(t: Tensor) -> bool:
        return t.dim() == 1 and t.size(0) == 3

    x_shape_ok = x.dim() == 2 and x.size(-1) == 3
    if not x_shape_ok:
        raise ValueError("x must be a 2d tensor of shape [N, 3].")

    if torch.isnan(x).any():
        raise ValueError("nan detected in x.")

    psf_shape_ok = psf.dim() == 4 and psf.size(-1) == 64
    if not psf_shape_ok:
        raise ValueError("psf must be a 4D tensor of shape [D, H, W, 64].")

    if not _is_vector_of_three(inv_voxel_size):
        raise ValueError("inv_voxel_size must be a 1D tensor of size 3.")

    if not _is_vector_of_three(center):
        raise ValueError("center must be a 1D tensor of size 3.")


def _check_index_range(idx: Tensor, lb: int, ub: int, name: str):
    """Ensure every entry of ``idx`` lies in the half-open range ``[lb, ub)``.

    Args:
        idx: tensor of indices to validate.
        lb: inclusive lower bound.
        ub: exclusive upper bound.
        name: label used in the error message to identify the index.

    Raises:
        ValueError: if any entry is below ``lb`` or at/above ``ub``. The
            message reports the first offending entry; the lower bound is
            checked before the upper bound.
    """
    below = idx < lb
    if torch.any(below):
        # ``.item()`` yields a plain Python number so the message shows the
        # value itself rather than a tensor repr.
        offending = idx[below][0].item()
        raise ValueError(f"Out of range {name}: Value {offending} < {lb}.")

    above = idx >= ub
    if torch.any(above):
        offending = idx[above][0].item()
        raise ValueError(f"Out of range {name}: Value {offending} >= {ub}.")


def _check_indices_range(
    x_idx: Tensor,
    y_idx: Tensor,
    z_idx: Tensor,
    psf_d: int,
    psf_h: int,
    psf_w: int,
    h: int,
    w: int,
):
    """Validate the x, y and z indices against their admissible ranges.

    Each index tensor is checked, in the order x, y, z, against a lower bound
    of 0 and an exclusive upper bound of ``w + psf_w``, ``h + psf_h`` and
    ``psf_d`` respectively. The first failing check propagates the error
    raised by ``_check_index_range``.
    """
    # (label, indices, exclusive upper bound) for each axis, in check order.
    axis_checks = (
        ("x", x_idx, w + psf_w),
        ("y", y_idx, h + psf_h),
        ("z", z_idx, psf_d),
    )
    for axis_name, axis_idx, upper in axis_checks:
        _check_index_range(axis_idx, lb=0, ub=upper, name=axis_name)


def batched_render_coordinates(
    x: Tensor,
    center: Tensor,
    img_size: tuple[int, int],
    inv_voxel_size: Tensor,
    psf: Tensor,
) -> Tensor:
    """Render each coordinate of a stack through the psf into its own image.

    Args:
        x: coordinates of shape [N, 3], ordered (x, y, z).
        center: 1D tensor of size 3 applied as an offset to the scaled
            coordinates.
        img_size: ``(h, w)`` of each output image.
        inv_voxel_size: 1D tensor of size 3 multiplying the coordinates.
        psf: 4D tensor of shape [D, H, W, 64]; the last axis is contracted
            with the 64 values produced by ``cubic_3d_power_series``.

    Returns:
        A tensor of shape [N, h, w] with the dtype and device of ``x``. When
        ``x`` holds no coordinates, an empty tensor of shape [0, h, w] is
        returned.

    Raises:
        ValueError: if the arguments have invalid shapes, ``x`` contains nan,
            or a coordinate maps to an index outside the admissible range.
    """
    _check_arguments(x=x, psf=psf, center=center, inv_voxel_size=inv_voxel_size)
    device, dtype = x.device, x.dtype
    N = x.size(0)
    h, w = img_size
    psf_d, psf_h, psf_w, _ = psf.size()
    if N < 1:
        # Nothing to render: honour the declared return type with an empty
        # stack instead of returning None.
        return torch.zeros((0, h, w), dtype=dtype, device=device)

    # Split coordinates in relevant indices and floating remainders
    x_idx, y_idx, z_idx, u = _batched_preprocess_coordinates(
        x, inv_voxel_size=inv_voxel_size, center=center, psf_h=psf_h, psf_w=psf_w
    )
    _check_indices_range(
        x_idx, y_idx, z_idx, psf_d=psf_d, psf_h=psf_h, psf_w=psf_w, h=h, w=w
    )

    # Render each emitter through the psf
    u = cubic_3d_power_series(u)  # [N, 64]
    psf_at_z = psf[z_idx].view(N, psf_h * psf_w, 64)  # [N, psf_h*psf_w, 64]
    u = u.view(N, 64, 1)  # [N, 64, 1]
    rendered_psfs = torch.bmm(psf_at_z, u)  # [N, psf_h*psf_w, 1]
    rendered_psfs = rendered_psfs.view(N, psf_h, psf_w)

    # Place each rendered psf at its location inside an output padded by one
    # psf size on every side, then crop the padding away.
    # NOTE(review): presumably batched_insert_2d writes rendered_psfs[n] into
    # output[n] with its top-left corner at (i_idx[n], j_idx[n]) - confirm.
    extended_output = torch.zeros(
        (N, h + 2 * psf_h, w + 2 * psf_w), dtype=dtype, device=device
    )
    batched_insert_2d(rendered_psfs, i_idx=y_idx, j_idx=x_idx, output=extended_output)
    # Explicit end bounds: a negative-end slice would be empty if a psf
    # dimension were 0.
    return extended_output[:, psf_h : psf_h + h, psf_w : psf_w + w]


# def _check_arguments_jac(
#     x: Tensor, psf: Tensor, center: Tensor, inv_voxel_size: Tensor, output: Tensor
# ):
#     if x.ndim != 2 or x.shape[-1] != 3:
#         raise ValueError("x must be a 2d tensor of shape [N, 3].")
#     if psf.ndim != 4 or psf.shape[-1] != 64:
#         raise ValueError("psf must be a 4D tensor of shape [D, H, W, 64].")
#     if inv_voxel_size.ndim != 1 or inv_voxel_size.shape[0] != 3:
#         raise ValueError("inv_voxel_size must be a 1D tensor of size 3.")
#     if center.ndim != 1 or center.shape[0] != 3:
#         raise ValueError("center must be a 1D tensor of size 3.")
#     if output.ndim != 4 or output.shape[0] != x.shape[0] or output.shape[-1] != 4:
#         raise ValueError("output must be a 4d tensor [N,H,W,4] with N matching x's")


# def batched_render_coordinates_jac_and_value(
#     x: Tensor, psf: Tensor, center: Tensor, inv_voxel_size: Tensor, output: Tensor
# ) -> Tensor:
#     """
#     Given a stack of coordinates Nx3 with (x,y,z), a psf and its parameters,
#     returns the value and the jacobian of the image rendered with those coordinates and psf,
#     with shape [N,h,w,4]: first 3 dims are gradients w.r.t (x,y,z), and last
#     stack of images is the value returned by batched_render_coordinates.
#     """
#     _check_arguments_jac(
#         x=x, psf=psf, center=center, inv_voxel_size=inv_voxel_size, output=output
#     )
#     N = x.size(0)
#     _, h, w, _ = output.shape
#     psf_d, psf_h, psf_w, _ = psf.size()
#     if N < 1:
#         return

#     # Split coordinates in relevant indices and floating remainders
#     x_idx, y_idx, z_idx, u = _batched_preprocess_coordinates(
#         x, inv_voxel_size=inv_voxel_size, center=center, psf_h=psf_h, psf_w=psf_w
#     )
#     _check_indices_range(
#         x_idx, y_idx, z_idx, psf_d=psf_d, psf_h=psf_h, psf_w=psf_w, h=h, w=w
#     )

#     # Render each emitter through the psf
#     u = cubic_3d_power_series_jac_and_value(u)  # [N, 64, 4]
#     psf_at_z = psf[z_idx].view(N, psf_h * psf_w, 64)  # [N, h*w, 64]
#     u = u.view(N, 64, 4)  # [N, 64, 4]
#     rendered_psfs = torch.bmm(psf_at_z, u)  # [N, h*w, 4]
#     rendered_psfs = rendered_psfs.view(N, psf_h, psf_w, 4)  # [N, h, w, 4]

#     # Chain rules with operations during preprocess
#     rendered_psfs[..., :3] *= inv_voxel_size  # Due to scaling
#     rendered_psfs[..., :2] *= -1  # Due to the -1

#     # Map rendered psfs to their respective locations in the padded output
#     extended_output = torch.zeros(
#         (N, h + 2 * psf_h, w + 2 * psf_w, 4),
#         dtype=output.dtype,
#         device=output.device,
#     )
#     batched_insert_2d(rendered_psfs, i_idx=y_idx, j_idx=x_idx, output=extended_output)
#     output += extended_output[:, psf_h:-psf_h, psf_w:-psf_w]
