import numpy as np
import torch

from typing import List, Dict, Optional

from lm_polygraph.estimators import Estimator
from lm_polygraph.stat_calculators.step.utils import flatten, reconstruct


class StepsSemanticEntropy(Estimator):
    """
    Estimates per-step uncertainty (registered at the "claim" level) of a language model
    following the method of "Semantic entropy" from https://arxiv.org/abs/2302.09664.

    For every step, sampled alternatives of that step are grouped into semantic classes
    (taken from the 'steps_semantic_classes_entail' statistic). The entropy is computed
    with Monte-Carlo over the samples, using class log-probabilities from one of two
    estimators:

    * 'sum'       - log of the summed sample probabilities within each class. This needs
                    the 'sample_steps_log_probs' statistic, so it needs a model that
                    exposes log-probabilities (e.g. lm_polygraph.utils.model.WhiteboxModel).
    * 'frequency' - empirical fraction of samples falling into each class. This needs only
                    the sampled texts and semantic classes.

    The number of samples is presumably controlled by the sampling stat calculator's
    'samples_n' parameter (e.g. SamplingGenerationCalculator) - confirm for the step variant.
    """

    def __init__(
        self, verbose: bool = False, class_probability_estimation: str = "sum"
    ):
        """
        Parameters:
            verbose (bool): stored on the instance; not used by this class itself.
            class_probability_estimation (str): 'sum' or 'frequency' (see class docstring).
        Raises:
            ValueError: if class_probability_estimation is not 'sum' or 'frequency'.
        """
        self.class_probability_estimation = class_probability_estimation
        if self.class_probability_estimation == "sum":
            deps = ["sample_steps_log_probs", "sample_steps_texts", "steps_semantic_classes_entail"]
        elif self.class_probability_estimation == "frequency":
            deps = ["sample_steps_texts", "steps_semantic_classes_entail"]
        else:
            raise ValueError(
                f"Unknown class_probability_estimation: {self.class_probability_estimation}. Use 'sum' or 'frequency'."
            )

        super().__init__(deps, "claim")
        self.verbose = verbose

    def __str__(self):
        """Return the estimator name; the 'frequency' variant gets an 'Empirical' suffix."""
        if self.class_probability_estimation == "sum":
            return "StepsSemanticEntropy"
        elif self.class_probability_estimation == "frequency":
            return "StepsSemanticEntropyEmpirical"

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Estimates the semantic entropy of every sampled step in the input statistics.

        Parameters:
            stats (Dict[str, np.ndarray]): input statistics, which include:
                * sampled step texts in 'sample_steps_texts' (nested: per input, per step,
                  a list of sampled alternatives),
                * their log probabilities in 'sample_steps_log_probs' (read only in 'sum' mode),
                * semantic classes in 'steps_semantic_classes_entail', a dict with
                  'class_to_sample' and 'sample_to_class' entries nested like the texts.
        Returns:
            Per-step semantic entropies, regrouped by `reconstruct` to follow the nesting of
            'sample_steps_texts'. Higher values indicate more uncertain steps.
        """
        hyps_list: list = stats["sample_steps_texts"]
        loglikelihoods_list: Optional[list] = (
            stats["sample_steps_log_probs"]
            if self.class_probability_estimation == "sum"
            else None
        )

        semantic_classes = stats["steps_semantic_classes_entail"]
        # Flatten the (input, step) nesting so that index i addresses one step everywhere.
        self._class_to_sample = flatten(semantic_classes["class_to_sample"])
        self._sample_to_class = flatten(semantic_classes["sample_to_class"])

        # 'frequency' mode has no log-probs. Only flatten them when present, so that
        # flatten() is never handed None (it presumably iterates its input).
        flat_loglikelihoods = (
            None if loglikelihoods_list is None else flatten(loglikelihoods_list)
        )
        all_se = self.batched_call(flatten(hyps_list), flat_loglikelihoods)
        return reconstruct(all_se, hyps_list)

    def _class_log_probs(
        self,
        i: int,
        num_samples: int,
        loglikelihoods_list: Optional[List[List[float]]],
    ):
        """Return the log-probability of each semantic class of flattened step `i`."""
        if self.class_probability_estimation == "sum":
            step_lls = np.array(loglikelihoods_list[i])
            # log(sum_k p_k) over the samples belonging to each class.
            return [
                np.logaddexp.reduce(step_lls[np.array(class_idx)])
                for class_idx in self._class_to_sample[i]
            ]
        if self.class_probability_estimation == "frequency":
            # Empirical class probability: share of samples assigned to the class.
            return np.log(
                [
                    len(class_idx) / num_samples
                    for class_idx in self._class_to_sample[i]
                ]
            )
        raise ValueError(
            f"Unknown class_probability_estimation: {self.class_probability_estimation}. Use 'sum' or 'frequency'."
        )

    def batched_call(
        self,
        hyps_list: List[List[str]],
        loglikelihoods_list: Optional[List[List[float]]],
        log_weights: Optional[List[List[float]]] = None,
    ) -> np.ndarray:
        """
        Computes semantic entropy for each flattened step.

        Requires `self._class_to_sample` and `self._sample_to_class` to be set first,
        aligned with `hyps_list` (as done in `__call__`).

        Parameters:
            hyps_list: for each step, the list of sampled alternatives.
            loglikelihoods_list: for each step, one log-likelihood per sample.
                Used only in 'sum' mode; may be None in 'frequency' mode.
            log_weights: optional per-step, per-sample log importance weights. A missing
                list (or a None entry) means uniform weights (log-weight 0). The caller's
                list is not modified.
        Returns:
            np.ndarray with one entropy value per step:
            -mean_j( log p(class(sample_j)) * exp(log_weight_j) ).
        """
        entropies = []
        for i, hyps in enumerate(hyps_list):
            num_samples = len(hyps)
            class_lp = self._class_log_probs(i, num_samples, loglikelihoods_list)

            step_weights = None if log_weights is None else log_weights[i]
            if step_weights is None:
                step_weights = [0] * num_samples

            sample_classes = self._sample_to_class[i]
            entropies.append(
                -np.mean(
                    [
                        class_lp[sample_classes[j]] * np.exp(step_weights[j])
                        for j in range(num_samples)
                    ]
                )
            )
        return np.array(entropies)
