import numpy as np
import torch
from lightning_fabric import Fabric
from scipy.optimize import least_squares
from torch import Tensor, nn
from torch.utils.data import DataLoader
from tqdm import tqdm

import smlm


def split_params(x: Tensor, psf_center):
    """Unpack an optimiser parameter vector into a PSF centre and a background.

    Args:
        x: Parameter vector. Its first two entries replace the first two
            coordinates of ``psf_center``; its last entry is the background.
        psf_center: Current PSF centre. Only its last coordinate is kept.

    Returns:
        Tuple ``(psf_center, bg_value)``:
        - the new centre ``[x[0], x[1], psf_center[-1]]``;
        - ``x[-1]`` as a 0-d tensor.
    """
    kept_coordinate = psf_center[-1:]
    optimised_coordinates = x[:2]
    new_center = torch.cat((optimised_coordinates, kept_coordinate))
    return new_center, x[-1]


def merge_params(psf_center, bg_photons):
    """Pack a PSF centre and background into one optimiser parameter vector.

    Only the first two coordinates of ``psf_center`` are included; they are
    followed by the entries of ``bg_photons``.

    Args:
        psf_center: PSF centre tensor.
        bg_photons: Background tensor (1-element in this module's usage).

    Returns:
        Tensor ``[psf_center[0], psf_center[1], *bg_photons]``.
    """
    parts = (psf_center[:2], bg_photons)
    return torch.cat(parts, dim=0)


def psf_calibration_loop(fabric: Fabric, dl: DataLoader, renderer: nn.Module):
    """Fit the PSF centre and a constant background to calibration data.

    The first two coordinates of ``renderer.psf_center`` and a single
    background value are optimised with ``scipy.optimize.least_squares``.
    The objective is that ``renderer`` reproduces the images in ``dl``. The
    last coordinate of the renderer's current ``psf_center`` is held fixed.
    The Jacobian is computed with ``torch.func.jacfwd``.

    Args:
        fabric: Lightning Fabric instance that supplies the device. Only a
            single process (``world_size == 1``) is supported.
        dl: DataLoader yielding ``(x_gt, s, y)`` batches. It is fully
            materialised in memory before optimisation.
        renderer: Module exposing ``psf``, ``psf_center`` and
            ``inv_voxel_size``, called as ``renderer(x_gt, bg=bg)``.
            Its ``psf_center`` is overwritten during the fit. On success it
            is set to the fitted centre; on failure it is restored.

    Returns:
        dict with ``"psf_center"`` (fitted centre) and ``"bg_photons"``
        (fitted background, 0-d tensor).

    Raises:
        ValueError: if more than one process is used or the optimiser does
            not report convergence.
    """
    if fabric.world_size > 1:
        raise ValueError("PSF calibration only supports one GPU.")

    device, dtype = fabric.device, torch.get_default_dtype()

    # pre-fetch dataloader (requirement of jacfwd)
    ds = list(dl)

    # fun_torch overwrites renderer.psf_center on every call, including with
    # the dual tensors created inside jacfwd. Read the fixed coordinate from
    # this captured tensor rather than from the renderer attribute, so a
    # tensor leaked from the transform (or a stale trial point) is never reused.
    initial_center = renderer.psf_center

    # objective and jac functions
    def fun_torch(x: Tensor) -> Tensor:
        """Return the flattened residuals ``y_pred - y`` over the whole dataset."""
        renderer.psf_center, bg_photons = split_params(
            x, psf_center=initial_center
        )
        z_extent = smlm.utils.psf.get_z_extent(
            psf=renderer.psf,
            psf_center=renderer.psf_center,
            voxel_size=renderer.inv_voxel_size.reciprocal(),
            shrink_factor=1e-4,
        )
        z_extent = z_extent.to(device=device)
        residuals = []
        for x_gt, _, y in ds:
            bg = bg_photons.expand_as(y[:, 0])
            x_gt = torch.stack(x_gt, dim=0)
            # clamp the third coordinate of the ground-truth positions into z_extent
            x_gt[..., 2].clamp_(min=z_extent[0], max=z_extent[1])
            y_pred = renderer(x_gt, bg=bg)
            # least_squares squares the residuals itself; returning squared
            # errors here would minimise a quartic cost instead.
            residuals.extend((y1 - y2).flatten() for y1, y2 in zip(y_pred, y))
        pbar.update(1)
        return torch.cat(residuals)

    jac_torch = torch.func.jacfwd(fun_torch)

    def fun(x_np: np.ndarray) -> np.ndarray:
        x = torch.from_numpy(x_np).to(device=device, dtype=dtype)
        return fun_torch(x).cpu().numpy()

    def jac(x_np: np.ndarray) -> np.ndarray:
        x = torch.from_numpy(x_np).to(device=device, dtype=dtype)
        return jac_torch(x).cpu().numpy()

    # initialisation: start at the geometric centre of the PSF volume and zero
    # background (only the first two centre coordinates enter x0)
    psf_shape = renderer.psf.shape
    start_center = 0.5 * torch.tensor(
        [psf_shape[2], psf_shape[1], psf_shape[0]], device=device, dtype=dtype
    )
    start_bg = torch.zeros((1,), device=device, dtype=dtype)
    x0 = merge_params(psf_center=start_center, bg_photons=start_bg).cpu().numpy()

    # least squares
    try:
        with tqdm(desc="calibration") as pbar:
            res = least_squares(fun=fun, jac=jac, x0=x0)
    finally:
        # don't leave a trial point (or a jacfwd dual tensor) on the renderer
        renderer.psf_center = initial_center
    if not res.success:
        raise ValueError("Optimisation did not converge.")

    x = torch.from_numpy(res.x).to(device=device, dtype=dtype)
    psf_center, bg_photons = split_params(x, psf_center=initial_center)
    renderer.psf_center = psf_center

    return {"psf_center": psf_center, "bg_photons": bg_photons}
