from lm_polygraph.estimators.estimator import Estimator

import numpy as np
from typing import Dict

import logging

log = logging.getLogger(__name__)


class PRMEstimator(Estimator):
    """
    Sequence-level uncertainty estimator built from PRM scores.

    For every sample it reads the sequence of PRM scores stored under
    ``prm_scores_{model_id}`` (presumably one score per claim — confirm with
    the producing stat calculator), negates each score so that a higher PRM
    score maps to a lower uncertainty, and collapses the per-sample values
    into a single float using ``reduction``.
    """

    # Supported reduction names mapped to the numpy function implementing them.
    _REDUCTIONS = {
        'mean': np.mean,
        'min': np.min,
        'max': np.max,
    }

    def __init__(
            self,
            reduction: str = 'min',
            model_id: str = '',
    ):
        """
        Parameters:
            reduction (str): how a sample's per-claim uncertainties are
                aggregated; one of 'mean', 'min', 'max'. Default: 'min'.
            model_id (str): suffix selecting the ``prm_scores_{model_id}``
                statistic and used in the estimator's name. Default: ''.

        Raises:
            ValueError: if ``reduction`` is not one of the supported names.
        """
        # Validate up front so a typo fails at construction instead of
        # midway through an evaluation run.
        if reduction not in self._REDUCTIONS:
            raise ValueError(f"Unknown reduction type: {reduction}")
        self.reduction = reduction
        self.model_id = model_id
        # initialize dependencies now that model_id is known
        deps = [f"prm_scores_{self.model_id}", "claims"]
        super().__init__(deps, "sequence")

    def __str__(self):
        base_name = "PRM"
        if self.model_id:
            base_name = f"PRM_{self.model_id}_{self.reduction}"
        return base_name

    def _reduce(self, x):
        """
        Aggregate one sample's uncertainty values into a single number.

        Parameters:
            x: sequence of numeric uncertainty values for one sample.

        Returns:
            The reduced value, or ``np.nan`` when ``x`` is empty. The empty
            case is handled explicitly because ``np.min``/``np.max`` raise on
            zero-size arrays and ``np.mean`` emits a RuntimeWarning.

        Raises:
            ValueError: if ``self.reduction`` is not a supported reduction
                (possible if the attribute was changed after construction).
        """
        try:
            reduce_fn = self._REDUCTIONS[self.reduction]
        except KeyError:
            raise ValueError(
                f"Unknown reduction type: {self.reduction}"
            ) from None
        values = np.asarray(x, dtype=float)
        if values.size == 0:
            return np.nan
        return reduce_fn(values)

    def __call__(self, stats: Dict[str, np.ndarray]) -> list[float]:
        """
        Compute one uncertainty value per sample.

        Parameters:
            stats: must contain ``prm_scores_{model_id}`` (an iterable with
                one sequence of scores per sample) and ``claims``.

        Returns:
            list[float]: the reduced uncertainty for each sample.
        """
        seq_ue = []
        # "claims" is only iterated alongside the scores; its contents are not
        # read. Note zip truncates the output to the shorter of the two.
        for sample_prms, _sample_claims in zip(
                stats[f"prm_scores_{self.model_id}"],
                stats["claims"],
        ):
            # Negate so that a higher PRM score means lower uncertainty.
            # NOTE(review): the reduction is applied to these negated values,
            # so reduction='min' yields -max(score) (the most confident claim).
            # Confirm this is intended rather than reducing the raw scores.
            claim_ue = -np.asarray(sample_prms, dtype=float)
            seq_ue.append(self._reduce(claim_ue))
        return seq_ue
