import numpy as np

from typing import List, Dict, Optional

from .estimator import Estimator
from .process_probs import process_probs


class SemanticEntropy(Estimator):
    """
    Estimates the sequence-level uncertainty of a language model following the method of
    "Semantic entropy" as provided in the paper https://arxiv.org/abs/2302.09664.
    Works only with whitebox models (initialized using lm_polygraph.utils.model.WhiteboxModel).

    This method calculates the generation entropy estimations merged by semantic classes using Monte-Carlo.
    The number of samples is controlled by lm_polygraph.stat_calculators.sample.SamplingGenerationCalculator
    'samples_n' parameter.

    The log-probability of a semantic class is estimated in one of two ways,
    selected by ``class_probability_estimation``:

    * ``"sum"``: log-sum-exp of the sequence log-likelihoods of the samples
      that belong to the class;
    * ``"frequency"``: log of the fraction of samples that belong to the class.
    """

    def __init__(
            self,
            verbose: bool = False,
            class_probability_estimation: str = "sum",
            samples_source: str = "sample",
    ):
        """
        Parameters:
            verbose (bool): stored on the instance; not used by this class itself.
            class_probability_estimation (str): 'sum' or 'frequency'.
            samples_source (str): prefix of the statistics keys to read
                (e.g. 'sample' -> 'sample_texts').

        Raises:
            ValueError: if `class_probability_estimation` is not 'sum' or 'frequency'.
        """
        self.class_probability_estimation = class_probability_estimation
        if class_probability_estimation == "sum":
            deps = [
                f"{samples_source}_log_likelihoods",
                f"{samples_source}_texts",
                f"{samples_source}_semantic_classes_entail",
            ]
        elif class_probability_estimation == "frequency":
            # Class frequencies do not need the model likelihoods.
            deps = [
                f"{samples_source}_texts",
                f"{samples_source}_semantic_classes_entail",
            ]
        else:
            raise self._unknown_estimation_error()

        super().__init__(deps, "sequence")
        self.verbose = verbose
        self.samples_source = samples_source

    def _unknown_estimation_error(self) -> ValueError:
        """Build the error raised for an unsupported `class_probability_estimation`."""
        return ValueError(
            f"Unknown class_probability_estimation: {self.class_probability_estimation}. Use 'sum' or 'frequency'."
        )

    def __str__(self):
        """Return the estimator name, suffixed with the samples source when it is not 'sample'."""
        if self.class_probability_estimation == "sum":
            base = "SemanticEntropy"
        elif self.class_probability_estimation == "frequency":
            base = "SemanticEntropyEmpirical"
        else:
            raise self._unknown_estimation_error()
        if self.samples_source != "sample":
            base += f'_{self.samples_source}'
        return base

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Estimates the semantic entropy for each sample in the input statistics.

        Parameters:
            stats (Dict[str, np.ndarray]): input statistics, which for multiple samples includes:
                * generated samples in '{samples_source}_texts',
                * log-likelihoods in '{samples_source}_log_likelihoods' (only read in 'sum' mode);
                  the innermost sequences are summed to get one value per generated sample,
                * semantic classes in '{samples_source}_semantic_classes_entail', holding
                  'class_to_sample' and 'sample_to_class' mappings.
        Returns:
            np.ndarray: float semantic entropy for each sample in input statistics.
                Higher values indicate more uncertain samples.
        Raises:
            ValueError: if `class_probability_estimation` is not 'sum' or 'frequency'.
        """
        if self.class_probability_estimation == "sum":
            # Collapse the innermost level into a single log-likelihood per generated sample.
            loglikelihoods_list = [
                [sum(token_lls) for token_lls in sample_lls]
                for sample_lls in stats[f"{self.samples_source}_log_likelihoods"]
            ]
        elif self.class_probability_estimation == "frequency":
            loglikelihoods_list = None
        else:
            raise self._unknown_estimation_error()
        hyps_list = stats[f"{self.samples_source}_texts"]

        self._load_semantic_classes(stats[f"{self.samples_source}_semantic_classes_entail"])

        return self.batched_call(hyps_list, loglikelihoods_list)

    def _load_semantic_classes(self, classes_info) -> None:
        """
        Store the class<->sample mappings on the instance for `batched_call`.

        `classes_info` is either a single mapping with 'class_to_sample' and
        'sample_to_class' entries, or a list of such mappings whose entries are
        concatenated in sorted key order.
        """
        if not isinstance(classes_info, list):
            self._class_to_sample = classes_info["class_to_sample"]
            self._sample_to_class = classes_info["sample_to_class"]
            return

        self._class_to_sample = []
        self._sample_to_class = []
        for chunk in classes_info:
            # NOTE(review): assumes 'sample_to_class' shares the keys of
            # 'class_to_sample' within a chunk - confirm against the producer.
            for key in sorted(chunk['class_to_sample'].keys()):
                self._class_to_sample.append(chunk['class_to_sample'][key])
                self._sample_to_class.append(chunk['sample_to_class'][key])

    def _class_log_probs(
            self,
            i: int,
            num_samples: int,
            loglikelihoods_list: Optional[List[List[float]]],
    ):
        """Return the log-probability of every semantic class of batch element `i`."""
        if self.class_probability_estimation == "sum":
            if loglikelihoods_list is None:
                raise ValueError(
                    "loglikelihoods_list is required when class_probability_estimation is 'sum'."
                )
            sample_lls = np.array(loglikelihoods_list[i])
            # log(sum of sample probabilities in the class), computed in log space.
            return [
                np.logaddexp.reduce(sample_lls[np.array(class_idx)])
                for class_idx in self._class_to_sample[i]
            ]
        if self.class_probability_estimation == "frequency":
            return np.log(
                [
                    len(class_idx) / num_samples
                    for class_idx in self._class_to_sample[i]
                ]
            )
        raise self._unknown_estimation_error()

    def batched_call(
            self,
            hyps_list: List[List[str]],
            loglikelihoods_list: Optional[List[List[float]]],
            log_weights: Optional[List[List[float]]] = None,
    ) -> np.ndarray:
        """
        Compute the semantic entropy for every element of the batch.

        Reads `self._class_to_sample` and `self._sample_to_class`, which
        `__call__` sets before delegating here.

        Parameters:
            hyps_list: generated texts, one list of samples per batch element.
            loglikelihoods_list: one log-likelihood per sample; required in
                'sum' mode, ignored in 'frequency' mode.
            log_weights: optional per-sample log-weights; a missing list (or a
                `None` entry for a batch element) means a weight of 1 for each
                of its samples. The argument is not modified.
        Returns:
            np.ndarray: for each batch element, the negated mean over samples of
                (log-probability of the sample's class) * exp(log-weight).
        Raises:
            ValueError: on an unknown estimation mode, or when
                `loglikelihoods_list` is None in 'sum' mode.
        """
        if log_weights is None:
            log_weights = [None] * len(hyps_list)

        entropies = []
        # Iteration over batch
        for i, hyps in enumerate(hyps_list):
            class_lp = self._class_log_probs(i, len(hyps), loglikelihoods_list)

            # Default to zero log-weights locally instead of writing into the
            # caller's list.
            sample_log_weights = log_weights[i]
            if sample_log_weights is None:
                sample_log_weights = [0] * len(hyps)

            sample_to_class = self._sample_to_class[i]
            entropies.append(
                -np.mean(
                    [
                        class_lp[sample_to_class[j]] * np.exp(sample_log_weights[j])
                        for j in range(len(hyps))
                    ]
                )
            )
        return np.array(entropies)


class SemanticEntropyP(Estimator):
    """
    Estimates the sequence-level uncertainty of a language model following the method of
    "Semantic entropy" as provided in the paper https://arxiv.org/abs/2302.09664.
    Works only with whitebox models (initialized using lm_polygraph.utils.model.WhiteboxModel).

    This method calculates the generation entropy estimations merged by semantic classes using Monte-Carlo.
    The number of samples is controlled by lm_polygraph.stat_calculators.sample.SamplingGenerationCalculator
    'samples_n' parameter.

    In this variant each sample's class log-probability is weighted by the
    sample probability after it has been passed through `process_probs`, and
    the weighted terms are summed (not averaged).

    The log-probability of a semantic class is estimated in one of two ways,
    selected by ``class_probability_estimation``:

    * ``"sum"``: log-sum-exp of the sequence log-likelihoods of the samples
      that belong to the class;
    * ``"frequency"``: log of the fraction of samples that belong to the class.
    """

    def __init__(
            self,
            verbose: bool = False,
            class_probability_estimation: str = "sum",
            samples_source: str = "sample",
            **process_probs_args,
    ):
        """
        Parameters:
            verbose (bool): stored on the instance; not used by this class itself.
            class_probability_estimation (str): 'sum' or 'frequency'.
            samples_source (str): prefix of the statistics keys to read
                (e.g. 'sample' -> 'sample_texts').
            **process_probs_args: forwarded unchanged to `process_probs`.

        Raises:
            ValueError: if `class_probability_estimation` is not 'sum' or 'frequency'.
        """
        self.class_probability_estimation = class_probability_estimation
        if class_probability_estimation not in ("sum", "frequency"):
            raise self._unknown_estimation_error()

        # The per-sample weights are derived from the sample likelihoods in
        # both modes, so the log-likelihoods are always a dependency.
        deps = [
            f"{samples_source}_log_likelihoods",
            f"{samples_source}_texts",
            f"{samples_source}_semantic_classes_entail",
        ]

        super().__init__(deps, "sequence")
        self.verbose = verbose
        self.samples_source = samples_source
        self.process_probs_args = process_probs_args

    def _unknown_estimation_error(self) -> ValueError:
        """Build the error raised for an unsupported `class_probability_estimation`."""
        return ValueError(
            f"Unknown class_probability_estimation: {self.class_probability_estimation}. Use 'sum' or 'frequency'."
        )

    def __str__(self):
        """Return the estimator name, suffixed with the samples source when it is not 'sample'."""
        if self.class_probability_estimation == "sum":
            base = "SemanticEntropyP"
        elif self.class_probability_estimation == "frequency":
            base = "SemanticEntropyEmpiricalP"
        else:
            raise self._unknown_estimation_error()
        if self.samples_source != "sample":
            base += f'_{self.samples_source}'
        return base

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Estimates the semantic entropy for each sample in the input statistics.

        Parameters:
            stats (Dict[str, np.ndarray]): input statistics, which for multiple samples includes:
                * generated samples in '{samples_source}_texts',
                * log-likelihoods in '{samples_source}_log_likelihoods'; the innermost
                  sequences are summed to get one value per generated sample,
                * semantic classes in '{samples_source}_semantic_classes_entail', holding
                  'class_to_sample' and 'sample_to_class' mappings.
        Returns:
            np.ndarray: float semantic entropy for each sample in input statistics.
                Higher values indicate more uncertain samples.
        Raises:
            ValueError: if `class_probability_estimation` is not 'sum' or 'frequency'.
        """
        if self.class_probability_estimation not in ("sum", "frequency"):
            raise self._unknown_estimation_error()

        # Needed in both modes: 'sum' uses them for the class probabilities,
        # and both modes use them for the per-sample weights.
        loglikelihoods_list = [
            [sum(token_lls) for token_lls in sample_lls]
            for sample_lls in stats[f"{self.samples_source}_log_likelihoods"]
        ]
        hyps_list = stats[f"{self.samples_source}_texts"]

        self._load_semantic_classes(stats[f"{self.samples_source}_semantic_classes_entail"])

        return self.batched_call(hyps_list, loglikelihoods_list)

    def _load_semantic_classes(self, classes_info) -> None:
        """
        Store the class<->sample mappings on the instance for `batched_call`.

        `classes_info` is either a single mapping with 'class_to_sample' and
        'sample_to_class' entries, or a list of such mappings whose entries are
        concatenated in sorted key order.
        """
        if not isinstance(classes_info, list):
            self._class_to_sample = classes_info["class_to_sample"]
            self._sample_to_class = classes_info["sample_to_class"]
            return

        self._class_to_sample = []
        self._sample_to_class = []
        for chunk in classes_info:
            # NOTE(review): assumes 'sample_to_class' shares the keys of
            # 'class_to_sample' within a chunk - confirm against the producer.
            for key in sorted(chunk['class_to_sample'].keys()):
                self._class_to_sample.append(chunk['class_to_sample'][key])
                self._sample_to_class.append(chunk['sample_to_class'][key])

    def batched_call(
            self,
            hyps_list: List[List[str]],
            loglikelihoods_list: Optional[List[List[float]]],
    ) -> np.ndarray:
        """
        Compute the probability-weighted semantic entropy for every batch element.

        Reads `self._class_to_sample` and `self._sample_to_class`, which
        `__call__` sets before delegating here.

        Parameters:
            hyps_list: generated texts, one list of samples per batch element.
            loglikelihoods_list: one log-likelihood per sample. Required in
                both modes because the per-sample weights are computed from it.
        Returns:
            np.ndarray: for each batch element, the negated sum over samples of
                (log-probability of the sample's class) * (processed sample probability).
        Raises:
            ValueError: if `loglikelihoods_list` is None or the estimation mode
                is unknown.
        """
        if loglikelihoods_list is None:
            raise ValueError(
                "loglikelihoods_list is required: sample probabilities are used as weights."
            )

        entropies = []
        # Iteration over batch
        for i, hyps in enumerate(hyps_list):
            sample_lls = np.array(loglikelihoods_list[i])

            if self.class_probability_estimation == "sum":
                # log(sum of sample probabilities in the class), computed in log space.
                class_lp = [
                    np.logaddexp.reduce(sample_lls[np.array(class_idx)])
                    for class_idx in self._class_to_sample[i]
                ]
            elif self.class_probability_estimation == "frequency":
                num_samples = len(hyps)
                class_lp = np.log(
                    [
                        len(class_idx) / num_samples
                        for class_idx in self._class_to_sample[i]
                    ]
                )
            else:
                raise self._unknown_estimation_error()

            # Per-sample weights: raw sample probabilities post-processed by the
            # project helper (its exact transformation is defined elsewhere).
            probs = process_probs(np.exp(sample_lls), **self.process_probs_args)

            sample_to_class = self._sample_to_class[i]
            class_lps = np.array([
                class_lp[sample_to_class[j]]
                for j in range(len(hyps))
            ])
            entropies.append(-(class_lps * probs).sum())
        return np.array(entropies)
