import scipy
import torch
from lightning_fabric import Fabric
from torch import Tensor, nn
from torch.utils.data import DataLoader
from tqdm import tqdm

import smlm
from smlm.metrics.maching_metrics import MatchingMetrics


def _as_threshold_tensor(t) -> Tensor:
    """Return an optimizer point as a 1-D float32 CPU tensor.

    SciPy may hand back a NumPy array, a NumPy scalar or a plain Python float;
    ``torch.as_tensor`` accepts all of them and ``reshape(-1)`` turns a scalar
    into a one-element vector.
    """
    return torch.as_tensor(t, dtype=torch.float32).reshape(-1)


def _broadcast_thresholds(fabric: Fabric, t) -> Tensor:
    """Take rank 0's thresholds and place them on this rank's ``fabric.device``."""
    # Broadcast the CPU tensor and move it afterwards, so the final placement is
    # always this rank's device regardless of how the backend deserializes it.
    t = fabric.broadcast(_as_threshold_tensor(t), src=0)
    return t.to(device=fabric.device)


def _unpack_batch(batch):
    """Split a calibration batch into ``(y, x_gt, s)``.

    ``batch["x"]`` and ``batch["s"]`` are (values, lengths) pairs; both are
    expanded into per-sample lists using the lengths of ``batch["x"]``.
    """
    x_gt, x_gt_lengths = batch["x"]
    s, _ = batch["s"]
    x_gt = smlm.utils.nested.expand_to_list(x_gt, lengths=x_gt_lengths)
    s = smlm.utils.nested.expand_to_list(s, lengths=x_gt_lengths)
    return batch["y"], x_gt, s


def _search_thresholds(fun, t0: Tensor, maxfev: int, scalar_xatol: float):
    """Minimize ``fun`` over thresholds in [0, 1] and return the optimizer's ``x``.

    A non-scalar ``t0`` selects Powell's method started at ``t0``. A 0-d ``t0``
    selects the bounded scalar search, which does not take a starting point.
    """
    t0 = t0.cpu().numpy()
    if t0.ndim != 0:
        res = scipy.optimize.minimize(
            fun,
            x0=t0,
            bounds=[(0.0, 1.0)] * t0.shape[0],
            method="Powell",
            options={"maxfev": maxfev, "xtol": 1e-2},
        )
    else:
        # Select the bounded solver explicitly instead of relying on SciPy's
        # default-method selection; `xatol` is an option of that solver.
        res = scipy.optimize.minimize_scalar(
            fun,
            bounds=(0.0, 1.0),
            method="bounded",
            options={"maxiter": maxfev, "xatol": scalar_xatol},
        )
    return res.x


def calibration_loop(
    fabric: Fabric, dl: DataLoader, model: nn.Module, metric: str, maxfev: int = 500
):
    """Search the model thresholds that maximize a matching metric.

    Dispatches to :func:`calibration_loop_fast` when
    ``model.supports_fast_calibration()`` is true, otherwise to
    :func:`calibration_loop_regular`.

    Args:
        fabric: Fabric instance used for device placement, rank checks and
            broadcasting.
        dl: calibration data loader.
        model: model exposing ``get_thresholds``/``set_thresholds`` and
            ``supports_fast_calibration`` (plus ``apply_thresholds`` for the
            fast path).
        metric: key of the ``MatchingMetrics.compute()`` result to maximize.
        maxfev: evaluation/iteration budget handed to the SciPy optimizer.

    Returns:
        ``model.get_thresholds()`` unchanged when ``dl`` is empty, otherwise the
        calibrated thresholds as a 1-D float32 tensor on ``fabric.device``.
    """
    if len(dl) == 0:
        return model.get_thresholds()
    if model.supports_fast_calibration():
        return calibration_loop_fast(
            fabric=fabric, dl=dl, model=model, watched_metric=metric, maxfev=maxfev
        )
    # The regular path needs the metric name as well.
    return calibration_loop_regular(
        fabric=fabric, dl=dl, model=model, maxfev=maxfev, watched_metric=metric
    )


@torch.no_grad()
def calibration_loop_regular(
    fabric: Fabric,
    dl: DataLoader,
    model: nn.Module,
    maxfev: int = 500,
    watched_metric: "str | None" = None,
):
    """Threshold search that re-runs the full model for every candidate.

    Each objective evaluation sets the candidate thresholds on the model, runs a
    forward pass over all of ``dl`` and scores the outputs with
    ``MatchingMetrics``.

    Args:
        fabric: Fabric instance (device, rank, broadcast).
        dl: calibration data loader yielding dict batches with keys "y", "x"
            and "s" (same format as :func:`calibration_loop_fast`).
        model: model exposing ``get_thresholds``/``set_thresholds``.
        maxfev: evaluation/iteration budget handed to the SciPy optimizer.
        watched_metric: key of the ``MatchingMetrics.compute()`` result to
            maximize. Required; it defaults to ``None`` only so that the
            original positional signature stays valid.

    Returns:
        The calibrated thresholds as a 1-D float32 tensor on ``fabric.device``.
        They are also set on ``model``.

    Raises:
        ValueError: if ``watched_metric`` is not given.
    """
    if watched_metric is None:
        raise ValueError("calibration_loop_regular requires `watched_metric`")
    model.eval()

    with tqdm(
        total=maxfev,
        leave=False,
        disable=not fabric.is_global_zero,
        desc="calibration",
    ) as pbar:

        def fun(t):
            # Use rank 0's candidate so every rank evaluates identical thresholds.
            model.set_thresholds(_broadcast_thresholds(fabric, t))
            matching = MatchingMetrics(sync_on_compute=True).to(device=fabric.device)
            for batch in tqdm(dl, leave=False, disable=not fabric.is_global_zero):
                y, x_gt, s = _unpack_batch(batch)
                matching.update(model(y), x_gt, s=s)
            r = matching.compute()
            pbar.update(1)
            # SciPy minimizes: negate the score to maximize it.
            return -r[watched_metric].item()

        best = _search_thresholds(
            fun, model.get_thresholds(), maxfev=maxfev, scalar_xatol=1e-3
        )

    tf = _broadcast_thresholds(fabric, best)  # every rank returns rank 0's result
    # Leave the model on the optimum rather than on the last probed candidate.
    model.set_thresholds(tf)
    return tf


@torch.no_grad()
def calibration_loop_fast(
    fabric: Fabric,
    dl: DataLoader,
    model: nn.Module,
    watched_metric: str,
    maxfev: int = 500,
):
    """Threshold search on precomputed raw model outputs.

    Runs the model once over ``dl`` with thresholds disabled
    (``set_thresholds(None)``) and caches the outputs. Each candidate is then
    scored with ``model.apply_thresholds`` alone, so no forward pass happens per
    evaluation.

    Args:
        fabric: Fabric instance (device, rank, broadcast).
        dl: calibration data loader yielding dict batches with keys "y", "x"
            and "s"; "x" and "s" are (values, lengths) pairs.
        model: model exposing ``get_thresholds``/``set_thresholds`` and
            ``apply_thresholds``.
        watched_metric: key of the ``MatchingMetrics.compute()`` result to
            maximize.
        maxfev: evaluation/iteration budget handed to the SciPy optimizer.

    Returns:
        The calibrated thresholds as a 1-D float32 tensor on ``fabric.device``.
        They are also set on ``model``, which otherwise would be left in
        raw-output mode.
    """
    model.eval()
    t0 = model.get_thresholds()
    model.set_thresholds(None)  # model now exports raw, un-thresholded outputs

    all_x, all_x_gt, all_s = [], [], []
    for batch in tqdm(
        dl, leave=False, desc="precompute outputs", disable=not fabric.is_global_zero
    ):
        y, x_gt, s = _unpack_batch(batch)
        all_x.extend(model(y).unbind())
        all_x_gt.extend(x_gt)
        all_s.extend(s)

    all_x = torch.nested.nested_tensor(all_x, layout=torch.jagged)

    with tqdm(
        total=maxfev,
        leave=False,
        disable=not fabric.is_global_zero,
        desc="searching thresholds",
    ) as pbar:

        def fun(t):
            matching = MatchingMetrics(sync_on_compute=True).to(device=fabric.device)
            thresholds = _as_threshold_tensor(t).to(device=fabric.device)
            x = model.apply_thresholds(all_x, thresholds=thresholds)
            matching.update(x, all_x_gt, s=all_s)
            r = matching.compute()
            pbar.update(1)
            # SciPy minimizes: negate the score to maximize it.
            return -r[watched_metric].item()

        best = _search_thresholds(fun, t0, maxfev=maxfev, scalar_xatol=1e-2)

    # Convert the *broadcast* value (the scalar branch used to re-read res.x).
    tf = _broadcast_thresholds(fabric, best)
    model.set_thresholds(tf)
    return tf
