from typing import Union

import torch
from lightning_fabric import Fabric
from torch import Tensor, nn
from torch.utils.data import DataLoader
from tqdm import tqdm

import smlm
from smlm.metrics.maching_metrics import MatchingMetrics


@torch.no_grad()
def validation_loop(
    fabric: Fabric,
    dl: DataLoader,
    model: nn.Module,
    wooblecorr: bool = False,
    offset: Union[Tensor, float] = 0.0,
    desc: str = "validation",
):
    """Run ``model`` over ``dl`` and return the ``MatchingMetrics`` results.

    Args:
        fabric: Fabric instance. Its ``device`` hosts the metric, and
            ``is_global_zero`` controls whether the progress bar is shown.
        dl: Validation loader. Each batch must provide ``"y"`` (model input),
            ``"x"`` as a ``(values, lengths)`` pair of ground truth, and ``"s"``
            as a pair whose first element is forwarded to the metric as ``s``.
        model: Network applied to ``batch["y"]``. It is put in eval mode for
            the loop and restored to its previous train/eval mode afterwards,
            even if an exception is raised.
        wooblecorr: If True, first run an uncorrected pass (offset 0), read the
            ``"offset"`` entry of its metrics, and use that as the offset for
            the reported pass. The ``offset`` argument is ignored in that case.
        offset: Value added in place to the first three entries of the last
            dimension of every model output before matching (presumably the
            coordinate columns; confirm against the model's output layout).
        desc: Label for the progress bar.

    Returns:
        The mapping returned by ``MatchingMetrics.compute()``. When
        ``wooblecorr`` is True, it additionally holds the applied offset under
        the key ``"wooblecorr"``.
    """
    # Wobble correction: estimate the systematic offset from an uncorrected
    # pass, then evaluate again with that offset applied.
    if wooblecorr:
        metrics = validation_loop(
            fabric=fabric, dl=dl, model=model, wooblecorr=False, offset=0, desc="wooble"
        )
        offset = metrics["offset"]

    # Evaluation
    metric_func = MatchingMetrics().to(fabric.device)
    # Remember the caller's mode so a validation pass in the middle of
    # training does not leave the model stuck in eval mode.
    was_training = model.training
    model.eval()
    try:
        progress = tqdm(
            dl, leave=False, desc=desc, disable=not fabric.is_global_zero
        )
        for batch in progress:
            y = batch["y"]
            x = model(y)
            # In-place shift of the model output (safe under no_grad).
            x[..., :3] += offset

            x_gt, x_gt_lengths = batch["x"]
            s, _ = batch["s"]
            x_gt = smlm.utils.nested.expand_to_list(x_gt, lengths=x_gt_lengths)
            # NOTE(review): `s` is split using the ground-truth lengths; this
            # assumes `s` has the same per-sample lengths as `x` — confirm.
            s = smlm.utils.nested.expand_to_list(s, lengths=x_gt_lengths)

            metric_func.update(x, x_gt, s=s)
    finally:
        model.train(was_training)

    metrics = metric_func.compute()
    if wooblecorr:
        metrics["wooblecorr"] = offset
    return metrics
