import torch
from lightning_fabric import Fabric
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from ..metrics.maching_metrics import MatchingMetrics


def evaluate_thresholds(fabric: Fabric, dl: DataLoader, model: nn.Module, num: int):
    """Sweep a grid of thresholds and score each one on ``dl``.

    Dispatches to ``evaluate_thresholds_fast`` when the model reports
    ``supports_fast_calibration()``, otherwise to
    ``evaluate_thresholds_regular``. Both receive the same keyword arguments
    and their result is returned unchanged.
    """
    sweep = (
        evaluate_thresholds_fast
        if model.supports_fast_calibration()
        else evaluate_thresholds_regular
    )
    return sweep(fabric=fabric, dl=dl, model=model, num=num)


@torch.no_grad()
def evaluate_thresholds_regular(
    fabric: Fabric, dl: DataLoader, model: nn.Module, num: int
):
    """Score the model on ``dl`` for every point of a threshold grid.

    Each threshold axis is ``num`` evenly spaced values in ``[0, 1]``. When the
    model's thresholds form a non-empty 1-D tensor of length ``d``, the grid is
    the Cartesian product of ``d`` such axes. For a scalar (0-d) threshold it
    is a single axis. For every grid point the thresholds are set on the model,
    a full pass over ``dl`` is made, and a fresh ``MatchingMetrics`` is
    computed.

    The model is switched to eval mode and left there. The thresholds returned
    by ``model.get_thresholds()`` before the sweep are restored afterwards,
    even if the sweep raises.

    Args:
        fabric: Provides the target device and the global-zero flag used to
            show progress bars on one process only.
        dl: Yields ``(x_gt, s, y)`` batches. ``y`` is fed to the model;
            ``x_gt`` and ``s`` are passed to the metric.
        model: Must expose ``get_thresholds`` / ``set_thresholds``.
        num: Number of grid values per threshold axis.

    Returns:
        ``(T, metrics)``. ``T`` has shape ``(num**d, d)`` for ``d`` thresholds,
        or ``(num,)`` for a scalar threshold. ``metrics[i]`` is the
        ``MatchingMetrics.compute()`` result for ``T[i]``.
    """
    model.eval()
    t0 = model.get_thresholds()

    # One grid axis per threshold; a scalar threshold still gets one axis.
    d = len(t0) if t0.ndim != 0 else 0
    axis = torch.linspace(0, 1, num, device=fabric.device)
    grids = torch.meshgrid(*([axis] * max(d, 1)), indexing="ij")
    T = torch.stack(grids, dim=-1)
    T = T.reshape(-1, d) if d != 0 else T.flatten()

    metrics = []
    try:
        for t in tqdm(T, disable=not fabric.is_global_zero, desc="applying thresholds"):
            # Fresh metric per grid point so accumulated state never mixes
            # results from different thresholds.
            metric = MatchingMetrics(sync_on_compute=True)
            metric = metric.to(device=fabric.device)
            model.set_thresholds(t)

            for x_gt, s, y in tqdm(dl, leave=False, disable=not fabric.is_global_zero):
                x = model(y)
                metric.update(x, x_gt, s=s)

            metrics.append(metric.compute())
    finally:
        # Always put the original thresholds back, even if the sweep failed
        # midway; otherwise the model would keep an arbitrary grid threshold.
        model.set_thresholds(t0)
    return T, metrics


@torch.no_grad()
def evaluate_thresholds_fast(
    fabric: Fabric, dl: DataLoader, model: nn.Module, num: int
):
    """Score a threshold grid using a single forward pass over ``dl``.

    The model is first switched to raw-output mode with
    ``set_thresholds(None)`` and run once over the whole loader. Its per-sample
    outputs are packed into a jagged nested tensor. Every grid point is then
    applied with ``model.apply_thresholds`` on those cached outputs instead of
    re-running the model.

    The grid is built as in ``evaluate_thresholds_regular``: ``num`` evenly
    spaced values in ``[0, 1]`` per threshold axis, in the cached outputs'
    dtype.

    The model is switched to eval mode and left there. The thresholds returned
    by ``model.get_thresholds()`` before the call are restored afterwards, even
    if precomputation or the sweep raises.

    Args:
        fabric: Provides the target device and the global-zero flag used to
            show progress bars on one process only.
        dl: Yields ``(x_gt, s, y)`` batches. ``y`` is fed to the model;
            ``x_gt`` and ``s`` are passed to the metric.
        model: Must expose ``get_thresholds``, ``set_thresholds`` (accepting
            ``None``) and ``apply_thresholds``.
        num: Number of grid values per threshold axis.

    Returns:
        ``(all_t, all_metrics)``. ``all_t`` has shape ``(num**d, d)`` for ``d``
        thresholds, or ``(num,)`` for a scalar threshold. ``all_metrics[i]`` is
        the ``MatchingMetrics.compute()`` result for ``all_t[i]``.

    Note:
        All model outputs are held in memory at once; this trades memory for
        avoiding ``num**d`` passes over the data.
    """
    model.eval()
    t0 = model.get_thresholds()
    d = len(t0) if t0.ndim != 0 else 0

    try:
        model.set_thresholds(None)  # model now returns raw, un-thresholded outputs

        # Cache per-sample outputs, targets and ``s`` values.
        all_x, all_x_gt, all_s = [], [], []
        for x_gt, s, y in tqdm(
            dl, leave=False, desc="precompute outputs", disable=not fabric.is_global_zero
        ):
            x = model(y)
            all_x.extend(x.unbind())
            all_x_gt.extend(x_gt.unbind())
            all_s.extend(s.unbind())

        # Jagged layout allows samples with different lengths to share one tensor.
        all_x = torch.nested.nested_tensor(all_x, layout=torch.jagged)

        # Threshold grid: one axis per threshold; a scalar threshold still gets
        # one axis. The dtype matches the cached outputs (presumably so the
        # comparison in apply_thresholds needs no cast).
        axis = torch.linspace(0, 1, num, dtype=all_x.dtype, device=fabric.device)
        grids = torch.meshgrid(*([axis] * max(d, 1)), indexing="ij")
        all_t = torch.stack(grids, dim=-1)
        all_t = all_t.reshape(-1, d) if d != 0 else all_t.flatten()

        metric = MatchingMetrics(sync_on_compute=True)
        metric = metric.to(device=fabric.device)
        all_metrics = []
        for t in tqdm(all_t, disable=not fabric.is_global_zero, desc="applying thresholds"):
            x = model.apply_thresholds(all_x, thresholds=t)
            metric.update(x, all_x_gt, s=all_s)
            r = metric.compute()
            # Mandatory: the metric object is reused across grid points, so its
            # state must be cleared or results would accumulate across thresholds.
            metric.reset()
            all_metrics.append(r)
    finally:
        # Always leave raw-output mode and restore the original thresholds,
        # even on failure; otherwise later inference would be un-thresholded.
        model.set_thresholds(t0)
    return all_t, all_metrics
