import numpy as np
import scipy.io as sio
import smlmsim
import torch
from torch import Tensor


def get_z_extent(
    psf: Tensor, psf_center: Tensor, voxel_size: Tensor, shrink_factor: float = 0.0
):
    """Compute the usable axial extent of a PSF, optionally shrunk.

    Delegates to ``smlmsim.psf.depth_extent`` using the size of the PSF's
    first axis, the last component of ``psf_center`` and the last component
    of ``voxel_size``. The result is then scaled by ``1 - shrink_factor``.

    NOTE(review): the return type/shape is whatever ``depth_extent`` returns
    (presumably a tensor or scalar that supports multiplication by a float)
    — confirm against smlmsim.

    Args:
        psf: PSF tensor whose first axis is treated as the z axis.
        psf_center: PSF center; its last entry is passed as the z center.
        voxel_size: Voxel size; its last entry is passed as the z voxel size.
        shrink_factor: Fraction by which to shrink the extent (0 keeps it as is).
    """
    keep_fraction = 1.0 - shrink_factor
    full_extent = smlmsim.psf.depth_extent(
        psf_z_size=psf.shape[0],
        z_center=psf_center[-1],
        z_voxel_size=voxel_size[-1],
    )
    return keep_fraction * full_extent


# def load(path: str):
#     psf = np.load(path).transpose([2, 0, 1, 3])
#     psf = torch.from_numpy(psf)
#     psf = psf.to(dtype=torch.get_default_dtype(), device=torch.get_default_device())
#     return psf


def load(
    path: str,
    device=torch.get_default_device(),
    dtype=torch.get_default_dtype(),
):
    """Load a .mat csplines PSF calibrated with SMAP"""
    mat = sio.loadmat(path, struct_as_record=False, squeeze_me=True)["SXY"]
    psf = mat.cspline.coeff
    psf = psf.transpose([2, 0, 1, 3])  # smlmsim takes z in first dim
    psf = torch.from_numpy(psf)
    psf = psf.to(dtype=dtype, device=device)
    psf = psf.contiguous()
    return psf


def load_center(
    path: str,
    device=torch.get_default_device(),
    dtype=torch.get_default_dtype(),
):
    """Extract the center from the mat file"""
    data = sio.loadmat(path, struct_as_record=False, squeeze_me=True)["SXY"]
    center = [data.cspline.x0 - 1, data.cspline.x0 - 1, data.cspline.z0 - 1]
    center = torch.tensor(center, device=device, dtype=dtype)
    return center


def load_dz(path: str):
    """Return the axial step ``SXY.cspline.dz`` from a SMAP cspline ``.mat`` file.

    The value is returned exactly as stored; no unit conversion is applied.
    NOTE(review): the unit is presumably nm — confirm against the calibration.
    """
    load_opts = dict(struct_as_record=False, squeeze_me=True)
    calibration = sio.loadmat(path, **load_opts)
    spline = calibration["SXY"].cspline
    return spline.dz
