import warnings
from typing import Optional

import scipy.io as sio
import torch
from torch import Tensor


def crop_z(psf: Tensor, z0: int, z1: int):
    """Return planes ``z0:z1`` of the first (z) axis as a contiguous tensor.

    Standard slice semantics apply: ``z1`` is exclusive and negative indices
    count from the end. ``.contiguous()`` guarantees a dense memory layout
    for the returned tensor.
    """
    cropped = psf[z0:z1]
    return cropped.contiguous()


def get_auto_center(psf: Tensor):
    """Return the midpoint of the first three dimensions of ``psf``.

    The result is a float32 tensor on ``psf.device`` holding half of each
    size, in reversed dimension order: ``(size(2), size(1), size(0)) / 2``.
    Dimension 0 is z in this module, so this is presumably ``(x, y, z)``
    ordering. NOTE(review): confirm the x/y meaning against callers.
    """
    reversed_sizes = [psf.size(dim) for dim in (2, 1, 0)]
    extent = torch.tensor(reversed_sizes, dtype=torch.float32, device=psf.device)
    return extent * 0.5


def depth_extent(
    psf_z_size: int,
    z_center: float,
    z_voxel_size: Optional[float] = None,
    inv_z_voxel_size: Optional[float] = None,
):
    """
    Return the extent of the PSF on the Z axis as ``tensor([z_min, z_max])``.

    The z axis has ``psf_z_size`` planes spaced ``z_voxel_size`` apart, with
    plane index ``z_center`` mapped to z = 0. Therefore
    ``z_min = -z_center * z_voxel_size`` and
    ``z_max = (psf_z_size - z_center) * z_voxel_size``.

    Either ``z_voxel_size`` or ``inv_z_voxel_size`` (its reciprocal) can be
    specified. If both are given, ``z_voxel_size`` is used,
    ``inv_z_voxel_size`` is discarded, and a ``UserWarning`` is emitted.

    Args:
        psf_z_size: Number of planes along z.
        z_center: Plane index (may be fractional) that corresponds to z = 0.
        z_voxel_size: Spacing between consecutive z planes.
        inv_z_voxel_size: Reciprocal of the z spacing. Used only when
            ``z_voxel_size`` is None.

    Returns:
        A 1-D tensor ``[z_min, z_max]``, in the same units as the voxel size.

    Raises:
        AttributeError: If neither voxel size is given. This type is kept for
            backward compatibility with existing callers.
    """
    if z_voxel_size is None:
        if inv_z_voxel_size is None:
            raise AttributeError("Either z_voxel_size or inv_z_voxel_size should be set")
        z_voxel_size = 1.0 / inv_z_voxel_size
    elif inv_z_voxel_size is not None:
        # Use the warnings machinery (filterable, capturable) rather than
        # printing to stdout.
        warnings.warn(
            "Both z_voxel_size and inv_z_voxel_size are set; discarding inv_z_voxel_size",
            stacklevel=2,
        )

    z_min = -z_center * z_voxel_size
    z_max = (psf_z_size - z_center) * z_voxel_size
    return torch.tensor([z_min, z_max])


def load(
    path: str,
    device=torch.get_default_device(),
    dtype=torch.get_default_dtype(),
):
    """Load a .mat csplines PSF calibrated with SMAP"""
    data = sio.loadmat(path, struct_as_record=False, squeeze_me=True)["SXY"]
    psf = data.cspline.coeff
    psf = psf.transpose([2, 0, 1, 3])  # smlmsim takes z in first dim
    psf = torch.from_numpy(psf)
    psf = psf.to(dtype=dtype, device=device)
    psf = psf.contiguous()
    return psf
