import numpy as np

from typing import Dict

from .estimator import Estimator
from .process_probs import process_probs


class MonteCarloNormalizedSequenceEntropy(Estimator):
    def __init__(self, samples_source: str = "sample"):
        """
        Sequence-level uncertainty via Monte-Carlo, length-normalized predictive entropy
        (https://arxiv.org/abs/2302.09664). Requires a whitebox model
        (lm_polygraph.utils.model.WhiteboxModel).

        Parameters:
            samples_source (str): prefix of the statistics to read; the estimator depends on
                f"{samples_source}_log_probs" and f"{samples_source}_tokens".
                With the default "sample", the number of samples is governed by the
                'samples_n' parameter of
                lm_polygraph.stat_calculators.sample.SamplingGenerationCalculator.
        """
        super().__init__([f"{samples_source}_log_probs", f"{samples_source}_tokens"], "sequence")
        self.samples_source = samples_source

    def __str__(self):
        # Only non-default sources get a suffix, so the default keeps the bare name.
        suffix = "" if self.samples_source == "sample" else f'_{self.samples_source}'
        return "MonteCarloNormalizedSequenceEntropy" + suffix

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Compute the Monte-Carlo length-normalized entropy for every input.

        Parameters:
            stats (Dict[str, np.ndarray]): must contain, for each input, per-sample
                log-probabilities under f"{samples_source}_log_probs" and per-sample
                token sequences under f"{samples_source}_tokens".
        Returns:
            np.ndarray: one float per input; higher means more uncertain.
                Samples with zero tokens are ignored.
        """
        sample_log_probs = stats[f"{self.samples_source}_log_probs"]
        sample_tokens = stats[f"{self.samples_source}_tokens"]

        uncertainties = []
        for input_log_probs, input_tokens in zip(sample_log_probs, sample_tokens):
            # Divide each sample's log-probability by its length; skip empty samples
            # (len() is used rather than truthiness so numpy token arrays also work).
            per_token = [
                log_prob / len(toks)
                for log_prob, toks in zip(input_log_probs, input_tokens)
                if len(toks)
            ]
            uncertainties.append(-np.mean(per_token))
        return np.array(uncertainties)


class MonteCarloNormalizedSequenceEntropyP(Estimator):
    def __init__(
            self,
            samples_source: str = "beamsearch",
            **process_probs_args,
    ):
        """
        Estimates the sequence-level uncertainty of a language model following the method of
        "Length-normalized predictive entropy" as provided in the paper https://arxiv.org/abs/2302.09664.
        Works only with whitebox models (initialized using lm_polygraph.utils.model.WhiteboxModel).

        Unlike MonteCarloNormalizedSequenceEntropy, the per-sample length-normalized
        log-probabilities are combined as a weighted sum. The weights are the samples'
        probabilities (exp of their log-probabilities) after post-processing with
        `process_probs`.

        Parameters:
            samples_source (str): prefix of the statistics to read; the estimator depends on
                f"{samples_source}_log_probs" and f"{samples_source}_tokens".
            **process_probs_args: keyword arguments forwarded unchanged to `process_probs`.
        """
        super().__init__([f"{samples_source}_log_probs", f"{samples_source}_tokens"], "sequence")
        self.samples_source = samples_source
        self.process_probs_args = process_probs_args

    def __str__(self):
        # NOTE(review): the check compares against "sample", not this class's default
        # "beamsearch", so a default instance is named
        # "MonteCarloNormalizedSequenceEntropyP_beamsearch". Kept as-is because the
        # name may be used as a result key elsewhere.
        base = "MonteCarloNormalizedSequenceEntropyP"
        if self.samples_source != "sample":
            base += f'_{self.samples_source}'
        return base

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Estimates probability-weighted, length-normalized entropy for each input.

        Parameters:
            stats (Dict[str, np.ndarray]): input statistics, which for each input include:
                * per-sample log-probabilities, in f"{samples_source}_log_probs"
                * per-sample token sequences, in f"{samples_source}_tokens"
        Returns:
            np.ndarray: float uncertainty for each input in the statistics.
                Higher values indicate more uncertain inputs. Samples with zero
                tokens are ignored.
        """
        logprobs = stats[f"{self.samples_source}_log_probs"]
        tokens = stats[f"{self.samples_source}_tokens"]
        res = []
        for lp, t in zip(logprobs, tokens):
            # Filter empty samples once so the weights and the normalized
            # log-probs stay element-wise aligned.
            kept = [(lp_i, len(t_i)) for lp_i, t_i in zip(lp, t) if len(t_i)]
            norm_logp = np.array([lp_i / n_tokens for lp_i, n_tokens in kept])
            # Exponentiate each sample's own log-prob. Previously the whole `lp` list
            # was exponentiated per element, which produced a 2-D array and a
            # silently broadcast (wrong) weighted sum.
            probs = np.array([np.exp(lp_i) for lp_i, _ in kept])
            probs = process_probs(probs, **self.process_probs_args)
            res.append(-(probs * norm_logp).sum())
        return np.array(res)
