import math
import pandas as pd
from transformers import EvalPrediction

import wandb
from src.metric.metric import Metric
from src.utils.logging_utils import get_logger

logger = get_logger()


class ExPPLMetric(Metric):
    """Collect per-example perplexity (``exp(loss)``) and expose it as a W&B table.

    Designed for batched evaluation, where ``_compute(eval_preds, compute_result)``
    is called once per batch. Every call contributes its batch's per-example
    losses. The call with ``compute_result=True`` also closes the current
    evaluation round. When a W&B run is active, that call returns the
    accumulated rows as ``{"ex_ppl": wandb.Table}``.

    NOTE(review): assumes the Hugging Face ``batch_eval_metrics`` protocol, in
    which the *last* batch is delivered together with ``compute_result=True``.
    Confirm against the ``Metric`` base class.

    Rows are never cleared. Each row's ``log_idx`` records the evaluation round
    it was collected in, so the returned table holds the full history.
    """

    def __init__(self) -> None:
        super().__init__()
        # One dict per example: {"_id", "_ds_id", "ppl", "log_idx"}; kept across rounds.
        self.losses: list[dict] = []
        # Index of the current evaluation round; bumped each time a round is finalized.
        self.log_count = 0

    @staticmethod
    def _safe_exp(loss: float) -> float:
        """Return ``exp(loss)``, or ``inf`` if the result does not fit in a float."""
        try:
            return math.exp(loss)
        except OverflowError:
            # math.exp raises (rather than returning inf) for loss > ~709.78.
            return math.inf

    def _collect(self, eval_preds: EvalPrediction) -> None:
        """Append one row per example in *eval_preds* to ``self.losses``.

        Reads ``eval_preds.inputs["_id"]``, ``eval_preds.inputs["_ds_id"]`` and
        ``eval_preds.losses``. All three must support ``.tolist()`` and are
        presumably 1-D per-example tensors (TODO confirm).
        """
        id_list = eval_preds.inputs["_id"].tolist()
        ds_id_list = eval_preds.inputs["_ds_id"].tolist()
        losses = eval_preds.losses.tolist()
        for _id, ds_id, loss in zip(id_list, ds_id_list, losses):
            self.losses.append(
                {"_id": _id, "_ds_id": ds_id, "ppl": self._safe_exp(loss), "log_idx": self.log_count}
            )

    def _compute(self, eval_preds: EvalPrediction, compute_result: bool = False):
        """Accumulate this batch and, on the final call, emit the perplexity table.

        Args:
            eval_preds: Batch predictions carrying ``inputs`` and per-example ``losses``.
            compute_result: ``True`` on the call that ends the evaluation round.

        Returns:
            ``{"ex_ppl": wandb.Table}`` on the final call when a W&B run is
            active; otherwise ``{}``.
        """
        if not compute_result:
            self._collect(eval_preds)
            return {}

        # The final call may still carry the last batch; do not drop it. It is
        # tolerated to be data-less, since the finalize step never required data.
        if eval_preds is not None and eval_preds.inputs is not None and eval_preds.losses is not None:
            self._collect(eval_preds)

        self.log_count += 1
        if wandb.run is not None:
            return {"ex_ppl": wandb.Table(dataframe=pd.DataFrame(self.losses))}
        return {}
