import torch
from torch import Tensor
import smlmsim


def visualize_xy(psf: Tensor, center: Tensor = None, img_size: tuple[int, int] = None):
    """Render a lateral (xy) image of ``psf`` with one emitter at the image centre.

    A single emitter is placed halfway along the image width and height, at
    depth 0.0. It is rendered with a unit inverse voxel size on every axis.

    Args:
        psf: PSF tensor. Axes 1 and 2 determine the default image height and
            width (presumably the PSF is laid out as (z, y, x); TODO confirm).
        center: PSF centre forwarded to the renderer. Defaults to
            ``smlmsim.psf.get_auto_center(psf)``.
        img_size: ``(h, w)`` of the output image. Defaults to 1.5x the PSF's
            extent along axes 1 and 2, truncated to ``int``.

    Returns:
        The result of ``smlmsim.psf.batched_render_coordinates`` with its
        leading (single-emitter) axis squeezed out; presumably an ``(h, w)``
        image, TODO confirm against the renderer.
    """
    tensor_opts = dict(device=psf.device, dtype=psf.dtype)

    if center is None:
        center = smlmsim.psf.get_auto_center(psf)

    if img_size is None:
        img_size = (int(1.5 * psf.size(1)), int(1.5 * psf.size(2)))
    height, width = img_size

    # One emitter, centred laterally and at depth 0.0; the coordinate columns
    # are presumably (x, y, z), so width comes first.
    emitter = torch.tensor([[0.5 * width, 0.5 * height, 0.0]], **tensor_opts)
    unit_scale = torch.ones(3, **tensor_opts)

    rendered = smlmsim.psf.batched_render_coordinates(
        emitter,
        img_size=(height, width),
        psf=psf,
        center=center,
        inv_voxel_size=unit_scale,
    )
    # Drop the batch axis that holds the single emitter.
    return rendered.squeeze(0)


def visualize_zx(psf: Tensor, center: Tensor = None, img_size: tuple[int, int] = None):
    """Render an axial (zx) cross-section of ``psf``; z runs along the image height.

    Each output row is a 1-pixel-high rendering of an emitter placed at the
    horizontal image centre and at one depth. Depths are ``h`` evenly spaced
    samples strictly inside the range returned by
    ``smlmsim.psf.depth_extent`` (its two end points are excluded).

    Args:
        psf: PSF tensor. Axis 0 is treated as the depth axis and axis 2 sets
            the default image width (presumably the PSF is laid out as
            (z, y, x); TODO confirm).
        center: PSF centre forwarded to the renderer; its last element is used
            as the depth centre. Defaults to ``smlmsim.psf.get_auto_center(psf)``.
        img_size: ``(h, w)`` of the output, where ``h`` is the number of depth
            samples. Defaults to ``(psf.size(0), int(1.5 * psf.size(2)))``.

    Returns:
        The stacked per-depth renderings with the singleton row axis squeezed
        out; presumably an ``(h, w)`` image, TODO confirm against
        ``smlmsim.psf.batched_render_coordinates``.
    """
    device, dtype = psf.device, psf.dtype
    if img_size is None:
        w = int(1.5 * psf.size(2))
        h = psf.size(0)
        img_size = (h, w)
    h, w = img_size

    center = smlmsim.psf.get_auto_center(psf) if center is None else center
    z_range = smlmsim.psf.depth_extent(
        psf_z_size=psf.size(0), z_center=center[-1], z_voxel_size=1
    )
    # Take h + 2 evenly spaced depths and drop both bounds, leaving h interior
    # samples. Build them on psf's device/dtype: the default (CPU, default
    # float dtype) would make the torch.cat below fail for a CUDA psf and
    # would lose precision for float64 PSFs.
    z = torch.linspace(
        z_range[0], z_range[1], steps=h + 2, device=device, dtype=dtype
    )[1:-1]

    # Every emitter sits at x = w / 2 and y = 0.5, the middle of a 1-pixel-high row.
    xy = torch.tensor([0.5 * w, 0.5], device=device, dtype=dtype)
    xy = xy[None].expand((z.size(0), 2))
    x = torch.cat([xy, z[:, None]], dim=-1)

    inv_voxel_size = torch.ones((3,), device=device, dtype=dtype)
    y = smlmsim.psf.batched_render_coordinates(
        x,
        img_size=(1, w),
        psf=psf,
        center=center,
        inv_voxel_size=inv_voxel_size,
    )
    # Remove the singleton image-height axis so each depth becomes one row.
    y = y.squeeze(1)
    return y
