from collections import defaultdict

import numpy as np
from transformers import EvalPrediction

from src.metric.metric import Metric
from src.utils.logging_utils import get_logger

logger = get_logger()


class PerDsEvalLoss(Metric):
    """Accumulate evaluation losses grouped by dataset id and report per-dataset means.

    ``_compute`` is called in two modes:

    * ``compute_result=False``: record the losses of one batch and return ``{}``.
    * ``compute_result=True``: return ``{"<ds_id>_loss": mean, ..., "num_nan_loss": n}``
      and reset the accumulated state for the next evaluation run.

    Samples whose loss is ``None`` or NaN are skipped and only counted in
    ``num_nan_loss``.
    """

    def __init__(self) -> None:
        super().__init__()
        # ds_id -> list of valid (non-NaN) losses accumulated since the last reset.
        self.losses = defaultdict(list)
        # Number of samples skipped because their loss was None or NaN.
        self.num_nan_loss = 0

    def _compute(self, eval_preds: EvalPrediction, compute_result: bool = False):
        """Accumulate one batch of losses, or emit and reset the aggregated metrics.

        Args:
            eval_preds: Batch predictions. Reads ``eval_preds.inputs["_ds_id"]`` and
                ``eval_preds.losses``; both must support ``.tolist()``.
                NOTE(review): assumes both are 1-D and aligned sample-by-sample;
                ``zip`` silently truncates to the shorter one otherwise — confirm.
            compute_result: ``False`` to accumulate this batch, ``True`` to return the
                final metrics.

        Returns:
            ``{}`` when accumulating; otherwise a dict mapping ``f"{ds_id}_loss"`` to
            the mean loss of that dataset, plus ``"num_nan_loss"``.
        """
        if not compute_result:
            ds_id_list = eval_preds.inputs["_ds_id"].tolist()
            losses = eval_preds.losses.tolist()
            for ds_id, loss in zip(ds_id_list, losses):
                if loss is None or np.isnan(loss):
                    self.num_nan_loss += 1
                    continue
                self.losses[ds_id].append(loss)
            # Return only after the whole batch is processed. Returning inside the
            # loop dropped every sample after the first one.
            return {}

        # Build one entry per dataset instead of overwriting a single dict per
        # iteration; this also stays well-defined when nothing was accumulated.
        metrics = {
            f"{ds_id}_loss": np.mean(ds_losses)
            for ds_id, ds_losses in self.losses.items()
        }
        metrics["num_nan_loss"] = self.num_nan_loss

        # Reset so the next evaluation run starts from a clean state.
        self.losses = defaultdict(list)
        self.num_nan_loss = 0

        return metrics
