import lightning.pytorch as pl
import torch
import torch.distributed as dist
import numpy as np


class BestPredictionSaver(pl.Callback):
    """
    Keep the validation predictions of the epoch that gave the best `monitor`
    metric.

    The LightningModule is expected to expose ``val_preds`` and ``val_labels``
    (lists of tensors, concatenated along dim 0) and, optionally,
    ``val_probs``.  This callback only reads those lists; it does not clear
    them (presumably the module resets them each epoch — TODO confirm).

    After training, the gathered predictions of the best epoch are available
    in ``self.best_predictions`` (only on rank 0; ``None`` on other ranks) as
    a dict of numpy arrays with keys ``"preds"``, ``"labels"`` and, when the
    module provides probabilities, ``"probs"``.

    Under DDP the improvement decision is taken from rank 0's metric value
    and broadcast, so every rank enters the gathering collectives together.

    Args:
        monitor: key in ``trainer.callback_metrics`` to track.
        mode: ``"min"`` if lower is better, ``"max"`` if higher is better.

    Raises:
        ValueError: if ``mode`` is neither ``"min"`` nor ``"max"``.
    """

    def __init__(self, monitor: str = "val_loss", mode: str = "min"):
        super().__init__()
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.monitor = monitor
        self.mode = mode
        self.best_value = np.inf if mode == "min" else -np.inf
        self.best_predictions: dict | None = None

    # ------------------------------------------------------------------
    # helper to do DDP all-gathering and turn everything into numpy
    # ------------------------------------------------------------------
    def _gather_to_rank0(self, tensor: torch.Tensor) -> np.ndarray:
        """
        Concatenate ``tensor`` from every rank along dim 0 and return numpy.

        When torch.distributed is initialised this is a *collective*: it must
        be called on every rank, in the same order, or the job hangs.  Ranks
        may hold a different number of rows; trailing dims must match.  Every
        rank receives the full concatenated result.
        """
        tensor = tensor.detach()
        if dist.is_available() and dist.is_initialized():
            tensor = self._all_gather_variable_length(tensor)
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()  # numpy has no bfloat16 dtype
        return tensor.cpu().numpy()

    @staticmethod
    def _all_gather_variable_length(tensor: torch.Tensor) -> torch.Tensor:
        """All-gather tensors whose dim-0 size may differ between ranks."""
        tensor = tensor.contiguous()
        world_size = dist.get_world_size()

        # 1) exchange row counts so each rank knows how much padding to strip
        local_len = torch.tensor(
            [tensor.shape[0]], dtype=torch.long, device=tensor.device
        )
        all_lens = [torch.zeros_like(local_len) for _ in range(world_size)]
        dist.all_gather(all_lens, local_len)
        lengths = [int(n.item()) for n in all_lens]

        # 2) pad to the longest chunk, since all_gather needs equal shapes
        max_len = max(lengths)
        if tensor.shape[0] < max_len:
            pad = tensor.new_zeros((max_len - tensor.shape[0], *tensor.shape[1:]))
            tensor = torch.cat([tensor, pad], dim=0)

        gathered = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(gathered, tensor)

        # 3) drop the padding again, preserving rank order
        return torch.cat(
            [chunk[:n] for chunk, n in zip(gathered, lengths)], dim=0
        )

    # ------------------------------------------------------------------
    def on_validation_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ):
        """Gather and store this epoch's predictions if ``monitor`` improved."""
        if trainer.sanity_checking:
            return  # sanity-check runs only see a few batches; never "best"

        current = trainer.callback_metrics.get(self.monitor)
        # Let rank 0's value decide for everyone.  If ranks disagreed (e.g. a
        # metric logged without sync_dist), some would enter the all_gather
        # calls below and others would not, which deadlocks.
        current = trainer.strategy.broadcast(
            None if current is None else float(current), src=0
        )
        if current is None:
            return  # metric not available yet

        is_better = (
            current < self.best_value
            if self.mode == "min"
            else current > self.best_value
        )
        if not is_better:
            return

        # Update on *every* rank so future decisions stay identical everywhere.
        self.best_value = current

        # --------------------------------------------------------------
        # New best → gather predictions / labels / (optional) probs
        # --------------------------------------------------------------
        preds = torch.cat(pl_module.val_preds)
        labels = torch.cat(pl_module.val_labels)
        val_probs = getattr(pl_module, "val_probs", None)
        probs = torch.cat(val_probs) if val_probs else None

        # These are collectives: all ranks must call them, not only rank 0.
        best = {
            "preds": self._gather_to_rank0(preds),
            "labels": self._gather_to_rank0(labels),
        }
        if probs is not None:
            best["probs"] = self._gather_to_rank0(probs)

        if trainer.is_global_zero:  # only rank 0 keeps a copy
            self.best_predictions = best
