import numpy as np

from typing import Dict

from .estimator import Estimator
from .process_probs import process_probs


class MonteCarloSequenceEntropy(Estimator):
    """
    Sequence-level "predictive entropy" uncertainty estimator
    (https://arxiv.org/abs/2302.09664), computed by Monte-Carlo estimation.

    For every input it returns the negated mean of the log-probabilities
    stored under ``f"{samples_source}_log_probs"`` in the statistics dict.
    """

    def __init__(self, samples_source: str = "sample"):
        """
        Build the estimator.

        Works only with whitebox models (initialized using lm_polygraph.utils.model.WhiteboxModel).
        The number of samples is controlled by the 'samples_n' parameter of
        lm_polygraph.stat_calculators.sample.SamplingGenerationCalculator.

        Parameters:
            samples_source (str): prefix of the required statistic; the estimator
                depends on ``f"{samples_source}_log_probs"``. Defaults to "sample".
        """
        required_stat = f"{samples_source}_log_probs"
        super().__init__([required_stat], "sequence")
        self.samples_source = samples_source

    def __str__(self):
        # The default source keeps the bare name; any other source is appended
        # as a suffix so estimators over different sources get distinct names.
        name = "MonteCarloSequenceEntropy"
        if self.samples_source == "sample":
            return name
        return f"{name}_{self.samples_source}"

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Compute the Monte-Carlo entropy estimate for each input.

        Parameters:
            stats (Dict[str, np.ndarray]): input statistics; must contain
                ``f"{self.samples_source}_log_probs"``, an iterable with one
                collection of log-probabilities per input.
        Returns:
            np.ndarray: one float per input, equal to minus the mean of that
                input's log-probabilities. Higher values indicate more uncertainty.
        """
        per_input_log_probs = stats[f"{self.samples_source}_log_probs"]
        entropies = [-np.mean(input_log_probs) for input_log_probs in per_input_log_probs]
        return np.array(entropies)


class MonteCarloSequenceEntropyP(Estimator):
    """
    Probability-weighted variant of the sequence-level "predictive entropy"
    estimator (https://arxiv.org/abs/2302.09664).

    For every input, each log-probability is weighted by its exponentiated
    value after it has been passed through ``process_probs``. The estimate is
    minus the sum of these weighted log-probabilities.
    """

    def __init__(self, samples_source: str = "beamsearch", **process_probs_args):
        """
        Build the estimator.

        Works only with whitebox models (initialized using lm_polygraph.utils.model.WhiteboxModel).
        The number of samples is controlled by the 'samples_n' parameter of
        lm_polygraph.stat_calculators.sample.SamplingGenerationCalculator.

        Parameters:
            samples_source (str): prefix of the required statistic; the estimator
                depends on ``f"{samples_source}_log_probs"``. Defaults to "beamsearch".
            **process_probs_args: keyword arguments forwarded verbatim to
                ``process_probs`` on every call.
        """
        required_stat = f"{samples_source}_log_probs"
        super().__init__([required_stat], "sequence")
        self.samples_source = samples_source
        self.process_probs_args = process_probs_args

    def __str__(self):
        # The suffix is added for every source other than "sample". This
        # includes the default "beamsearch", so the default instance is named
        # "MonteCarloSequenceEntropyP_beamsearch".
        name = "MonteCarloSequenceEntropyP"
        if self.samples_source == "sample":
            return name
        return f"{name}_{self.samples_source}"

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Compute the probability-weighted entropy estimate for each input.

        Parameters:
            stats (Dict[str, np.ndarray]): input statistics; must contain
                ``f"{self.samples_source}_log_probs"``, an iterable with one
                collection of log-probabilities per input.
        Returns:
            np.ndarray: one float per input, equal to
                ``-sum(log_probs * process_probs(exp(log_probs), **process_probs_args))``.
                Higher values indicate more uncertainty.
        """

        def weighted_entropy(input_log_probs):
            # NOTE(review): process_probs is defined elsewhere; presumably it
            # rescales/normalizes the probabilities used as weights. Confirm.
            weights = process_probs(np.exp(input_log_probs), **self.process_probs_args)
            return -(np.array(input_log_probs) * weights).sum()

        per_input_log_probs = stats[f"{self.samples_source}_log_probs"]
        return np.array([weighted_entropy(lps) for lps in per_input_log_probs])
