import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import evaluate
import numpy as np
from lm_polygraph.estimators.sar import SAR
from lm_polygraph.stat_calculators.cross_encoder_similarity import (
    CrossEncoderSimilarityMatrixCalculator,
)
from lm_polygraph.utils.model import WhiteboxModel
from scipy.optimize import minimize
from semantic_uncertainty.uncertainty.uncertainty_measures.semantic_entropy import (
    EntailmentDeberta,
    get_semantic_ids,
)
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the ROUGE metric once, at import time (module-level side effect).
# NOTE(review): the visible part of this file never reads `rouge_metric`
# (RougeModel loads its own instance) -- presumably kept for external
# importers; confirm before removing.
rouge_metric = evaluate.load("rouge")


def argmin_objective(
    x: np.ndarray, samples: np.ndarray, distance_func: Callable
) -> float:
    """
    Mean distance from a candidate vector to a collection of sample vectors.

    Parameters:
    - x: candidate vector (the decision variable of the optimisation).
    - samples: array whose first axis enumerates the sample vectors.
    - distance_func: callable invoked as distance_func(x, sample) for every sample.

    Returns:
    - the sum of distance_func(x, sample) over all samples, divided by samples.shape[0].
    """
    per_sample = (distance_func(x, sample) for sample in samples)
    return sum(per_sample, 0.0) / samples.shape[0]


def _find_argmin(p_samples: np.ndarray, distance_func: Callable) -> np.ndarray:
    """
    Optimize to find x in the simplex that minimizes the average distance to the given samples.

    Parameters:
    - samples: numpy array of shape (N, n) where each row is a probability vector.

    Returns:
    - x_opt: the optimal probability vector.
    """

    n = p_samples.shape[1]
    if n == 1:
        return np.array([1.0])
    # Initial guess: uniform distribution
    x0 = np.full(n, 1 / n)

    # Constraints: sum(x) == 1
    cons = {"type": "eq", "fun": lambda x: np.sum(x) - 1}

    # Bounds: x_i in [0, 1]
    bnds = [(0, 1) for _ in range(n)]

    res = minimize(
        argmin_objective,
        x0,
        args=(p_samples, distance_func),
        method="SLSQP",
        bounds=bnds,
        constraints=cons,
    )

    # if not res.success:
    #    raise RuntimeError("Optimization did not converge: " + res.message)

    return res.x


class UEMetric(ABC):
    """Abstract interface shared by the metric classes that subclass it."""

    @abstractmethod
    def compute_metric(self, run_data: Dict[str, Any]) -> Dict[str, Union[float, Any]]:
        """Return the metric of interest computed from the experimental run data."""


class JSDivergenceMetric(UEMetric):
    """
    Jensen-Shannon divergence between the model's predictive distribution over
    answer clusters and the ground-truth answer distribution.
    """

    def __init__(
        self, similarity_threshold: float = 0.3, implication_model: str = "rouge"
    ):
        """
        Parameters
        ----------
        similarity_threshold : float
            Stored on the instance and passed to RougeModel as its threshold
            (only consulted when implication_model == "rouge").
        implication_model : str
            "rouge" or "deberta"; selects the model used to merge equivalent answers.

        Raises
        ------
        ValueError
            If implication_model is neither "rouge" nor "deberta".
        """
        self.metric_name = "js"
        self.similarity_threshold = similarity_threshold
        if implication_model == "rouge":
            self.implication_model = RougeModel(threshold=similarity_threshold)
        elif implication_model == "deberta":
            self.implication_model = EntailmentDeberta()
        else:
            raise ValueError("Invalid implication model")

    def compute_metric(self, run_data: Dict[str, Any]) -> Dict[str, Union[float, Any]]:
        """
        Compute the Jensen-Shannon divergence between the predicted and the true answer distribution.

        The run_data dictionary is expected to have the following keys:
          - "semantic_ids": cluster id of every generation.
          - "tokens_decoded_generated_truncated": the generations, indexable like "semantic_ids".
          - "log_likelihood_per_semantic_id": log-likelihoods, one per cluster.
          - "answer" and "gt_probabilities": ground-truth answers and their probabilities.
          - "question": the question text (prepended to answers for the DeBERTa model).

        Returns a dict with "metric" (the divergence) and "metric_data" holding
        the shared vocabulary, both aligned distributions and the answer mapping.
        """

        # 1. Predictive distribution of the model run: one answer per cluster,
        #    represented by the first generation that belongs to the cluster.
        _, first_idx = np.unique(
            np.asarray(run_data["semantic_ids"]), return_index=True
        )
        generations = run_data["tokens_decoded_generated_truncated"]
        x0_answers = [generations[i] for i in first_idx]
        probs = np.exp(run_data["log_likelihood_per_semantic_id"])
        probs = probs / probs.sum()  # ensure sum to 1
        p_hat = dict(zip(x0_answers, probs))
        # presumably each entry of run_data["answer"] is a sequence of aliases
        # and the first one is used as the canonical form -- TODO confirm
        p_true = {
            ans[0]: prob
            for ans, prob in zip(run_data["answer"], run_data["gt_probabilities"])
        }

        # 2. Get union support, i.e. align both distributions on one vocabulary
        out = self._get_union_support([p_hat, p_true], run_data["question"])
        vocab = out["vocab"]
        mapping = out["mapping"]
        p_hat_vec = out["p_samples"][0]
        p_true_vec = out["p_samples"][1]

        # 3. Calculate the JS divergence on the aligned vectors.
        metric = self.js_divergence(p_hat_vec, p_true_vec)

        # js_divergence does not modify its inputs, so normalise explicitly to
        # report the distributions the divergence was computed on. When either
        # vector has zero mass both are reported unchanged.
        hat_mass = p_hat_vec.sum()
        true_mass = p_true_vec.sum()
        if hat_mass != 0 and true_mass != 0:
            p_hat_vec = p_hat_vec / hat_mass
            p_true_vec = p_true_vec / true_mass

        return {
            "metric": metric,
            "metric_data": {
                "vocab": vocab,
                "p_hat": p_hat_vec,
                "p_true": p_true_vec,
                "mapping": mapping,
            },
        }

    def js_divergence(self, p: np.ndarray, q: np.ndarray, base: float = 2) -> float:
        """
        Compute the Jensen–Shannon divergence between two probability distributions using only NumPy.

        The inputs are not modified.

        Parameters
        ----------
        p : np.ndarray
            First non-negative array (normalized to sum to 1 internally).
        q : np.ndarray
            Second non-negative array (same shape as p; normalized internally).
        base : float, optional
            Logarithm base for the divergence (default=2, so result is in bits).

        Returns
        -------
        float
            Jensen–Shannon divergence; 1.0 (with a logged warning) when either
            input sums to zero.

        Raises
        ------
        ValueError
            If p and q have different shapes.
        """
        p = np.asarray(p, dtype=np.float64)
        q = np.asarray(q, dtype=np.float64)
        if p.shape != q.shape:
            raise ValueError("p and q must have the same shape")

        p_sum = p.sum()
        q_sum = q.sum()
        if p_sum == 0 or q_sum == 0:
            logging.warning("One of the distributions is empty, returning 1")
            return 1.0

        # Out-of-place division: np.asarray may alias the caller's array, and
        # an in-place `/=` would silently rescale it.
        p = p / p_sum
        q = q / q_sum

        # Mixture distribution
        m = 0.5 * (p + q)

        def kl_div(u, v):
            # KL(u || v) = sum u * log(u / v); only terms with u > 0 are
            # included to avoid 0 * log(0).
            mask = u > 0
            return np.sum(u[mask] * np.log(u[mask] / v[mask]))

        js = 0.5 * (kl_div(p, m) + kl_div(q, m))

        # Natural log was used above; rescale to the requested base.
        if base != np.e:
            js /= np.log(base)

        return js

    def _aggregate_sample(self, sample_dict, mapping):
        """
        For one sample dictionary, aggregate probabilities according to the canonical mapping.
        If two keys map to the same canonical answer, sum their probabilities.
        """
        aggregated = {}
        for ans, prob in sample_dict.items():
            canon = mapping[ans]
            aggregated[canon] = aggregated.get(canon, 0) + prob
        return aggregated

    def _get_union_support(
        self, p_c_x: List[Dict[str, float]], question: Optional[str] = None
    ) -> Dict[str, Union[List[str], np.ndarray]]:
        """
        Given the predictive distributions, find the union support of the distributions.
        I.e. If we have
        p(c_1|x) = {London: 0.5, Paris: 0.5}
        p(c_2|x) = {Paris: 0.3, Berlin: 0.7}

        -> vocab = {Berlin, London, Paris}
        -> vectors = [[0, 0.5, 0.5], [0.7, 0, 0.3]]

        Answers that the implication model places in the same cluster are merged
        and represented by the first of them.

        Returns a dict with "vocab" (sorted representatives), "p_samples" (one
        row per input distribution, columns ordered like vocab) and "mapping"
        (answer -> representative).
        NOTE: `question` must be a string when the DeBERTa model is used, since
        it is concatenated with every answer.
        """

        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # choice of cluster representatives does not depend on string hashing.
        union_support = list(dict.fromkeys(ans for dist in p_c_x for ans in dist))

        if isinstance(self.implication_model, EntailmentDeberta):
            # the entailment model sees every answer together with the question
            entailment_input = [question + " " + answer for answer in union_support]
        else:
            entailment_input = union_support

        semantic_ids = np.asarray(
            get_semantic_ids(
                strings_list=entailment_input,
                model=self.implication_model,
                strict_entailment=False,
                example=None,
            )
        )

        # Representative of each cluster = its first occurrence. Keying by the
        # id itself (not by position) keeps this correct for any id values.
        unique_ids, first_idx = np.unique(semantic_ids, return_index=True)
        representative = {
            cluster_id: union_support[i]
            for cluster_id, i in zip(unique_ids.tolist(), first_idx.tolist())
        }
        mapping = {
            answer: representative[cluster_id]
            for answer, cluster_id in zip(union_support, semantic_ids.tolist())
        }

        # Sort the vocabulary for consistent ordering
        vocab = sorted(representative.values())

        # Aggregate every distribution onto the canonical vocabulary.
        vectors = []
        for dist in p_c_x:
            aggregated = self._aggregate_sample(dist, mapping)
            vectors.append([aggregated.get(word, 0) for word in vocab])

        # Rows follow the order of p_c_x.
        return {"vocab": vocab, "p_samples": np.array(vectors), "mapping": mapping}


class WassersteinDistanceMetric(UEMetric):
    """
    Wasserstein-p distance (with Hellinger ground distance) between the
    predictive distributions of the clean run and its decoys and a surrogate
    reference distribution.
    """

    def __init__(
        self,
        similarity_threshold: float = 0.3,
        p: int = 1,
        surrogate: str = "argmin",
        implication_model: str = "rouge",
    ):
        """
        Parameters
        ----------
        similarity_threshold : float
            Stored on the instance and passed to RougeModel as its threshold.
        p : int
            Order of the Wasserstein distance.
        surrogate : str
            "argmin" (reference = minimizer of the mean Hellinger distance to
            all distributions) or "p0" (reference = first distribution).
        implication_model : str
            "rouge" or "deberta".

        Raises
        ------
        ValueError
            If implication_model is neither "rouge" nor "deberta".
        """
        self.metric_name = "wd"
        self.similarity_threshold = similarity_threshold
        self.p = p
        self.surrogate = surrogate
        if implication_model == "rouge":
            self.implication_model = RougeModel(threshold=similarity_threshold)
        elif implication_model == "deberta":
            self.implication_model = EntailmentDeberta()
        else:
            raise ValueError("Invalid implication model")

    def compute_metric(
        self, run_data: Dict[str, Any]
    ) -> Dict[str, Union[float, Dict[str, Any]]]:
        """
        Compute the Wasserstein distance between the surrogate and the predictive distributions.

        The run_data dictionary is expected to have the following keys:
          - "semantic_ids", "tokens_decoded_generated_truncated" and
            "log_likelihood_per_semantic_id" for the clean run.
          - "decoys": a dict mapping decoy labels to outputs with the same three keys.
          - "question": the question text.

        Raises
        ------
        ValueError
            If self.surrogate is neither "argmin" nor "p0".
        """
        # Fail before any model call instead of hitting an unbound variable later.
        if self.surrogate not in ("argmin", "p0"):
            raise ValueError(f"Invalid surrogate: {self.surrogate!r}")

        p_c_x = self._get_predictive_distributions(run_data)
        union_support = self._get_union_support(p_c_x, run_data["question"])
        p_samples = union_support["p_samples"]

        if self.surrogate == "argmin":
            p_hat = _find_argmin(p_samples, self._hellinger_distance)
            compared = p_samples
        else:
            # "p0": the first row is the clean run; compare it to the decoys only
            p_hat = p_samples[0]
            compared = p_samples[1:]
        w_d = self._wasserstein_mc(p_hat, compared, self._hellinger_distance)

        if np.isnan(w_d):
            logging.warning("Wasserstein distance is NaN, returning dummy value 2")
            w_d = 2  # dummy value

        return {
            "metric": w_d,
            "metric_data": {
                "union_support": union_support,
                "p_hat": p_hat,
                "p_c_x": p_c_x,
            },
        }

    def _wasserstein_mc(
        self, p_hat: np.ndarray, p_samples: np.ndarray, distance_func: Callable
    ) -> float:
        """
        Compute the W_p distance between the empirical distribution of p_samples and a dirac delta distribution at p_hat
        """

        return np.mean([distance_func(p_hat, p) ** self.p for p in p_samples]) ** (
            1 / self.p
        )

    def _hellinger_distance(self, p, q):
        """Hellinger distance between p and q: ||sqrt(p) - sqrt(q)||_2 / sqrt(2)."""
        # Clip tiny negative entries (numerical noise from the optimizer) so the
        # square root does not produce NaN.
        root_p = np.sqrt(np.clip(p, 0, None))
        root_q = np.sqrt(np.clip(q, 0, None))
        return np.sqrt(np.sum((root_p - root_q) ** 2)) / np.sqrt(2)

    def _distribution_from_output(self, output: Dict[str, Any]) -> Dict[str, float]:
        """Map the first generation of every semantic cluster to the cluster's normalized probability."""
        probs = np.exp(output["log_likelihood_per_semantic_id"])
        probs = probs / probs.sum()  # ensure sum to 1

        # index of the first occurrence of each cluster (important when decoys
        # are top-k tokens)
        _, first_idx = np.unique(np.asarray(output["semantic_ids"]), return_index=True)
        generations = output["tokens_decoded_generated_truncated"]
        return {generations[i]: prob for i, prob in zip(first_idx, probs)}

    def _get_predictive_distributions(
        self, run_data: Dict[str, Any]
    ) -> List[Dict[str, float]]:
        """
        Given the run data, extract the predictive distributions p(c|x) for the clean run and every decoy.

        The first element belongs to the clean run; the rest follow the
        iteration order of run_data["decoys"].
        """
        p_c_x = [self._distribution_from_output(run_data)]
        p_c_x.extend(
            self._distribution_from_output(output)
            for output in run_data["decoys"].values()
        )
        return p_c_x

    def _aggregate_sample(self, sample_dict, mapping):
        """
        For one sample dictionary, aggregate probabilities according to the canonical mapping.
        If two keys map to the same canonical answer, sum their probabilities.
        """
        aggregated = {}
        for ans, prob in sample_dict.items():
            canon = mapping[ans]
            aggregated[canon] = aggregated.get(canon, 0) + prob
        return aggregated

    def _get_union_support(
        self, p_c_x: List[Dict[str, float]], question: Optional[str] = None
    ) -> Dict[str, Union[List[str], np.ndarray]]:
        """
        Given the predictive distributions, find the union support of the distributions.
        I.e. If we have
        p(c_1|x) = {London: 0.5, Paris: 0.5}
        p(c_2|x) = {Paris: 0.3, Berlin: 0.7}

        -> vocab = {Berlin, London, Paris}
        -> vectors = [[0, 0.5, 0.5], [0.7, 0, 0.3]]

        Answers that the implication model places in the same cluster are merged
        and represented by the first of them.

        Returns a dict with "vocab" (sorted representatives) and "p_samples"
        (one row per input distribution, columns ordered like vocab).
        NOTE: `question` must be a string when the DeBERTa model is used, since
        it is concatenated with every answer.
        """

        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # choice of cluster representatives does not depend on string hashing.
        union_support = list(dict.fromkeys(ans for dist in p_c_x for ans in dist))

        if isinstance(self.implication_model, EntailmentDeberta):
            # the entailment model sees every answer together with the question
            entailment_input = [question + " " + answer for answer in union_support]
        else:
            entailment_input = union_support

        semantic_ids = np.asarray(
            get_semantic_ids(
                strings_list=entailment_input,
                model=self.implication_model,
                strict_entailment=False,
                example=None,
            )
        )

        # Representative of each cluster = its first occurrence. Keying by the
        # id itself (not by position) keeps this correct for any id values.
        unique_ids, first_idx = np.unique(semantic_ids, return_index=True)
        representative = {
            cluster_id: union_support[i]
            for cluster_id, i in zip(unique_ids.tolist(), first_idx.tolist())
        }
        mapping = {
            answer: representative[cluster_id]
            for answer, cluster_id in zip(union_support, semantic_ids.tolist())
        }

        # Sort the vocabulary for consistent ordering
        vocab = sorted(representative.values())

        # Aggregate every distribution onto the canonical vocabulary.
        vectors = []
        for dist in p_c_x:
            aggregated = self._aggregate_sample(dist, mapping)
            vectors.append([aggregated.get(word, 0) for word in vocab])

        # The first vector in vectors is the clean run
        return {"vocab": vocab, "p_samples": np.array(vectors)}


class ArgMaxMetric(UEMetric):
    """
    Compares the length-normalized likelihood of the clean run's answer(s)
    with the likelihood of the semantically matching answer of every decoy.
    """

    def __init__(
        self,
        similarity_threshold: float = 0.3,
        implication_model: str = "deberta",
        **kwargs,
    ):
        """
        Parameters
        ----------
        similarity_threshold : float
            Stored on the instance and passed to RougeModel as its threshold.
        implication_model : str
            "rouge" or "deberta".
        **kwargs
            Accepted and ignored.

        Raises
        ------
        ValueError
            If implication_model is neither "rouge" nor "deberta".
        """
        self.metric_name = "argmax"
        self.similarity_threshold = similarity_threshold
        if implication_model == "rouge":
            self.implication_model = RougeModel(threshold=similarity_threshold)
        elif implication_model == "deberta":
            self.implication_model = EntailmentDeberta()
        else:
            raise ValueError("Invalid implication model")

    def compute_metric(self, run_data: Dict[str, Any]) -> Dict[str, Union[float, Any]]:
        """Compute the argmax-based metric for one run.

        Two modes, selected by the presence of run_data["clusters"]:
          - multi-mode: run_data["clusters"] maps candidate answers to values
            that are exponentiated like log-probabilities; the metric is the
            mean squared difference between the clean vector and each decoy's.
          - single-mode: the candidate is the first entry of
            run_data["model_answer_dict"]; the metric is
            mean(sqrt(1 - sqrt(prob))) over the clean run and all decoys.

        Each value of run_data["decoys"] is passed to match_to_argmax_candidates.

        Args:
            run_data (Dict[str, Any]): data of one experimental run.

        Returns:
            Dict[str, Union[float, Any]]: "metric" (1.1 when the computation
            yields NaN) and "metric_data" with the probability vectors and the
            per-decoy matches.
        """

        multi_modes = "clusters" in run_data
        if multi_modes:
            clusters = run_data["clusters"]
        else:
            model_answer = run_data["model_answer_dict"]
            token_log_liks = model_answer["log_likelihood_truncated"][0]
            s_argmax = model_answer["tokens_decoded_generated_truncated"][0]
            # Normalize by the number of tokens -- the same normalization
            # match_to_argmax_candidates applies to the decoys. (Dividing by
            # len(s_argmax) would use the character count of the decoded text.)
            clean_log_prob = token_log_liks.sum() / len(token_log_liks)
            # single-entry dict so match_to_argmax_candidates can be reused
            clusters = {s_argmax: clean_log_prob}

        all_matches = [
            self.match_to_argmax_candidates(out, clusters)
            for out in run_data["decoys"].values()
        ]

        # Row 0 is the clean run, followed by one row per decoy; columns follow
        # the iteration order of `clusters`.
        vectors = [list(clusters.values())]
        for matches in all_matches:
            vectors.append([match["logprob"] for match in matches.values()])

        vectors = np.exp(np.stack(vectors))  # exp to convert to probabilities

        if multi_modes:
            metric = np.mean([np.sum((vectors[0] - row) ** 2) for row in vectors[1:]])
        else:
            # shape (n_decoys + 1, n_clusters) with n_clusters == 1
            metric = np.mean(np.sqrt(1 - np.sqrt(vectors)))

        if np.isnan(metric):
            metric = 1.1  # dummy value

        return {
            "metric": metric,
            "metric_data": {
                "vectors": vectors,
                "matches": all_matches,
            },
        }

    def match_to_argmax_candidates(self, out: dict, clusters: dict) -> dict:
        """
        For every candidate in `clusters`, find the most likely answer of `out`
        that the implication model puts in the same semantic cluster.

        `out` must provide "log_likelihood_truncated" (per-answer token
        log-likelihoods), "tokens_decoded_generated_truncated" (the answers)
        and "question".

        Returns a dict candidate -> {"answer", "logprob"}; candidates without a
        match keep answer None and logprob -inf.
        """
        log_likelihoods = out["log_likelihood_truncated"]
        answers = out["tokens_decoded_generated_truncated"]
        question = out["question"]
        sentence_ll = [ll.sum() / len(ll) for ll in log_likelihoods]

        # answers ranked from highest to lowest length-normalized log-likelihood
        order = sorted(
            range(len(sentence_ll)), key=lambda k: sentence_ll[k], reverse=True
        )
        ranked = [(answers[i], sentence_ll[i]) for i in order]

        # the entailment model sees every text together with the question
        with_question = isinstance(self.implication_model, EntailmentDeberta)

        matches = {}
        for cluster in clusters:
            best = {"answer": None, "logprob": -np.inf}  # case where no match is found
            for answer, log_prob in ranked:
                t1, t2 = cluster, answer
                if with_question:
                    t1 = question + " " + t1
                    t2 = question + " " + t2

                semantic_ids = get_semantic_ids(
                    strings_list=[t1, t2],
                    model=self.implication_model,
                    strict_entailment=False,
                    example=None,
                )

                if semantic_ids[0] == semantic_ids[1]:
                    # answers are ranked, so the first hit is the most likely one
                    best = {"answer": answer, "logprob": log_prob}
                    break
            matches[cluster] = best

        return matches


class RougeModel:
    """ROUGE-L based stand-in that can be passed as `model` to get_semantic_ids."""

    def __init__(self, threshold: float = 0.3):
        self.rouge = evaluate.load("rouge")
        self.threshold = threshold

    def check_implication(self, text1, text2, example=None):
        """
        Label the pair with the entailment-model convention
        (0: contradiction, 1: neutral, 2: entailment).

        Returns 2 when the ROUGE-L score of text1 against text2 reaches the
        threshold and 0 otherwise; `example` is unused.
        """
        scores = self.rouge.compute(
            predictions=[text1],
            references=[text2],
            rouge_types=["rougeL"],
        )
        return 2 if scores["rougeL"] >= self.threshold else 0


class MutualInformationMetric(UEMetric):
    """
    Mutual-information estimate built from decoy marginals and decoy-to-decoy
    conditional probabilities.

    Reference: https://openreview.net/pdf?id=k6iyUfwdI9
    """

    def __init__(self, similarity_threshold: float = 0.3, **kwargs):
        """Store the configuration; extra keyword arguments are accepted and ignored."""
        self.name = "mi"
        self.similarity_threshold = similarity_threshold

    def compute_metric(self, run_data: Dict[str, Any]) -> Dict[str, Union[float, Any]]:
        """
        Compute the mutual-information estimate for one run.

        Reads run_data["decoys"] and run_data["log_likelihood_per_semantic_id"].
        Returns "metric" plus the intermediate distributions in "metric_data".
        """
        marginals = self._extract_marginals(run_data)
        conditionals = self._extract_conditionals(run_data)
        joint = self._build_joint_distribution(conditionals, marginals)
        product = self._build_product_distribution(conditionals, marginals)
        mi = self._get_mi_estimate(joint, product, 1e-9, 1e-9)
        return {
            "metric": mi,
            "metric_data": {
                "joint": joint,
                "product": product,
                "marginals": marginals,
                "conditionals": conditionals,
            },
        }

    def _extract_marginals(self, data: dict) -> Dict[str, float]:
        """
        Extract the marginal distribution p(x) for decoys.
        Assumes that the order of decoys in data["decoys"] matches the order in data["log_likelihood_per_semantic_id"].
        When the two differ in length, both are cut to the shorter one and a warning is logged.
        """
        decoy_order = list(data["decoys"].keys())
        log_likelihoods = data["log_likelihood_per_semantic_id"]

        # Lengths can differ (e.g. truncated to 10 for compute budget)
        if len(decoy_order) != len(log_likelihoods):
            logging.warning(
                f"Decoy order and log likelihoods have different lengths. {len(decoy_order)} vs {len(log_likelihoods)}. Using only the first {min(len(decoy_order), len(log_likelihoods))} decoys."
            )
            n_common = min(len(decoy_order), len(log_likelihoods))
            decoy_order = decoy_order[:n_common]
            log_likelihoods = log_likelihoods[:n_common]

        # Convert log-likelihoods to probabilities and renormalize
        p_x = np.array([np.exp(ll) for ll in log_likelihoods])
        p_x = p_x / p_x.sum()
        return dict(zip(decoy_order, p_x))

    def _extract_conditionals(self, data: dict) -> Dict[str, Dict[str, float]]:
        """
        Returns p(X_t|X_i) for each decoy X_i.

        Assumes data["decoys"][decoy1] maps decoy2 -> prob(decoy2 | decoy1).
        Every conditional is renormalized to sum to 1; one that sums to 0 is
        replaced by the uniform distribution (with a logged warning).
        """
        conditionals = {decoy1: dict(output) for decoy1, output in data["decoys"].items()}

        for decoy1, cond in conditionals.items():
            total = sum(cond.values())
            if total == 0:
                logging.warning(
                    f"Conditional distribution for {decoy1} sums to 0, setting to uniform distribution."
                )
                for decoy2 in cond:
                    cond[decoy2] = 1.0 / len(cond)
            else:
                for decoy2 in cond:
                    cond[decoy2] /= total

        return conditionals

    def _build_joint_distribution(
        self, conditionals: Dict[str, Dict[str, float]], marginals: Dict[str, float]
    ) -> np.ndarray:
        """
        Build matrix such that joint[i,t]  = p(X_i, X_t) = p(X_i) * p(X_t|X_i)

        Rows and columns follow the iteration order of `marginals`.
        """
        n = len(marginals)
        joint = np.zeros((n, n))
        for i, decoy1 in enumerate(marginals):
            for j, decoy2 in enumerate(marginals):
                joint[i, j] = marginals[decoy1] * conditionals[decoy1][decoy2]
        return joint

    def _build_product_distribution(
        self, conditionals: Dict[str, Dict[str, float]], marginals: Dict[str, float]
    ) -> np.ndarray:
        """
        Build matrix such that product[i,t] = p(X_i) * p(X_t),
        where p(X_t) = sum_k p(X_k) * p(X_t|X_k).

        Rows and columns follow the iteration order of `marginals`.
        """
        # p(X_t) does not depend on the row, so compute it once per column
        # instead of once per matrix entry.
        target_marginal = {
            decoy2: sum(
                marginals[decoy3] * conditionals[decoy3][decoy2] for decoy3 in marginals
            )
            for decoy2 in marginals
        }

        n = len(marginals)
        product = np.zeros((n, n))
        for i, decoy1 in enumerate(marginals):
            for j, decoy2 in enumerate(marginals):
                product[i, j] = target_marginal[decoy2] * marginals[decoy1]
        return product

    def _get_mi_estimate(
        self,
        joint: np.ndarray,
        product: np.ndarray,
        gamma1: float = 1e-9,
        gamma2: float = 1e-9,
    ) -> float:
        """
        Compute the mutual information from the joint and product distributions:
            MI = sum_{i,t} joint[i,t] * log( (joint[i,t] + gamma1) / (product[i,t] + gamma2) )

        joint[i,t]   = p(X_i, X_t)
        product[i,t] = p(X_i) * p(X_t)

        gamma1 and gamma2 are added inside the logarithm to keep it finite;
        entries with joint[i,t] == 0 contribute 0 to the sum.
        """

        joint = joint.flatten()
        product = product.flatten()

        # both matrices must cover the same support
        assert len(joint) == len(product)

        return np.sum(joint * np.log((joint + gamma1) / (product + gamma2)))


class SemanticEntropyEstimator(UEMetric):
    """
    Concrete implementation of UQEstimator for semantic entropy estimation.
    """

    def __init__(self):
        """
        Initialize the SemanticEntropyEstimator.
        """
        self.name = "se"

    def compute_metric(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """
        Estimate the semantic entropy for a given sample.

        Estimate := -sum_i p_i * log(p_i)   (natural logarithm, terms with p_i == 0 skipped)

        :param sample: A dictionary whose "p_hat" entry holds the probabilities.
        :return: A dict with "metric" (the entropy) and "metric_data"
                 (the renormalized p_hat).
        :raises AssertionError: if p_hat does not sum to 1 within 1e-3.
        """

        p_hat = np.array(sample["p_hat"])

        # assert it sums to 1 with a tolerance of 1e-3
        assert np.isclose(np.sum(p_hat), 1.0, atol=1e-3), "p_hat does not sum to 1"

        # Normalize again. Out-of-place, because an in-place true division
        # fails on integer-typed input such as [1, 0].
        p_hat = p_hat / np.sum(p_hat)

        mask = p_hat > 0  # to avoid log(0)
        return {
            "metric": -np.sum(p_hat[mask] * np.log(p_hat[mask])),
            "metric_data": {"p_hat": p_hat},
        }


class PmaxEstimator(UEMetric):
    """
    Concrete implementation of UQEstimator for argmax probability estimation.
    """

    def __init__(self):
        """
        Initialize the PmaxEstimator.
        """
        self.name = "pmax"

    def compute_metric(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """
        Estimate the uncertainty of the most probable outcome for a given sample.

        estimate := 1 - max_i p_i

        :param sample: A dictionary whose "p_hat" entry holds the probabilities.
        :return: A dict with "metric" (1 - max probability) and "metric_data"
                 (the renormalized p_hat).
        :raises AssertionError: if p_hat does not sum to 1 within 1e-3.
        """

        p_hat = np.array(sample["p_hat"])

        # assert it sums to 1 with a tolerance of 1e-3
        assert np.isclose(np.sum(p_hat), 1.0, atol=1e-3), "p_hat does not sum to 1"

        # Normalize again. Out-of-place, because an in-place true division
        # fails on integer-typed input such as [1, 0].
        p_hat = p_hat / np.sum(p_hat)

        # "metric_data" (no leading space) is the key every other metric uses.
        return {"metric": 1 - np.max(p_hat), "metric_data": {"p_hat": p_hat}}


class PMarginEstimator(UEMetric):
    """
    Concrete implementation of UQEstimator for PMargin estimation.
    """

    def __init__(self):
        """
        Initialize the PMarginEstimator.
        """
        self.name = "pmargin"

    def compute_metric(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """
        Estimate the PMargin for a given sample.

        PMargin := max(p_i) - second_max(p_i); with a single probability the
        margin is that probability itself.

        :param sample: A dictionary whose "p_hat" entry holds the probabilities.
        :return: A dict with "metric" (the margin) and "metric_data"
                 (the renormalized p_hat).
        :raises AssertionError: if p_hat does not sum to 1 within 1e-3.
        """

        p_hat = np.array(sample["p_hat"])

        # assert it sums to 1 with a tolerance of 1e-3
        assert np.isclose(np.sum(p_hat), 1.0, atol=1e-3), "p_hat does not sum to 1"

        # Normalize again. Out-of-place, because an in-place true division
        # fails on integer-typed input such as [1, 0].
        p_hat = p_hat / np.sum(p_hat)

        if len(p_hat) > 1:
            # one sort yields both the largest and the second largest value
            ascending = np.sort(p_hat)
            margin = ascending[-1] - ascending[-2]
        else:
            margin = p_hat[0]

        return {
            "metric": margin,
            "metric_data": {"p_hat": p_hat},
        }


class MaximumSentenceProbabilityEstimator(UEMetric):
    """
    Concrete implementation of UQEstimator for Maximum Sentence Probability estimation.
    """

    def __init__(self):
        """
        Initialize the MaximumSentenceProbabilityEstimator.
        """
        self.name = "msp"

    def compute_metric(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """
        Estimate the maximum sentence probability for a given sample.

        estimate := 1 - max_i exp(l_i), where l_i are the entries of
        sample["sentence_ln_log_lik"].

        :param sample: A dictionary containing "sentence_ln_log_lik"
                       (presumably length-normalized sentence log-likelihoods
                       -- TODO confirm against the producer of this key).
        :return: A dict with "metric" (1 - maximum sentence probability) and
                 "metric_data" (all sentence probabilities).
        """

        sentence_probs = np.exp(sample["sentence_ln_log_lik"])

        return {
            "metric": 1 - np.max(sentence_probs),
            "metric_data": {"sentence_probs": sentence_probs},
        }


class SAREstimator(UEMetric):
    """
    Concrete implementation of UQEstimator for SAR (Shifting Attention to
    Relevance) using LM_Polygraph
    https://arxiv.org/abs/2307.01379.
    https://github.com/IINemo/lm-polygraph/blob/main/src/lm_polygraph/estimators/sar.py

    """

    def __init__(self, model):
        """
        Initialize the SAR estimator with a WhiteboxModel. Use CausalLM Model wrapper

        :param model : wrapper of CAUSALLM exposing `.model` and `.tokenizer`
        """
        self.name = "sar"

        # Necessary to wrap in WhiteboxModel for lm_polygraph compatibility
        self.model = WhiteboxModel(model=model.model, tokenizer=model.tokenizer)
        self.calculator = CrossEncoderSimilarityMatrixCalculator()
        self.metric = SAR()

    def compute_metric(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """
        Estimate SAR for a given sample by delegating to lm_polygraph.

        The sample is converted into the dependency dict expected by
        CrossEncoderSimilarityMatrixCalculator; the resulting similarities and
        the sample log-likelihoods are then passed to the SAR estimator.

        :param sample: A dictionary containing "tokens_truncated",
                       "tokens_decoded_generated_truncated", "question" and
                       "log_likelihood_truncated".
        :return: A dict with "metric" (first element of the SAR output) and
                 "metric_data" (the similarities and the stats passed to SAR).
        """

        generations = sample["tokens_decoded_generated_truncated"]

        # build dependencies for lm polygraph call
        dependencies = {
            "sample_tokens": [[tokens] for tokens in sample["tokens_truncated"]],
            "sample_texts": [generations],
            "input_texts": [sample["question"]] * len(generations),
            "greedy_tokens": sample["tokens_truncated"],
        }

        # Text seems not to be used in the SAR calculation, but is required by the calculator
        similarities = self.calculator(
            dependencies=dependencies, texts=[], model=self.model
        )

        stats = {
            "sample_log_likelihoods": [sample["log_likelihood_truncated"]],
            "sample_token_similarity": [similarities["token_similarity"]],
            "sample_sentence_similarity": similarities["sample_sentence_similarity"],
        }

        return {
            "metric": self.metric(stats)[0],
            "metric_data": {
                "similarities": similarities,
                "stats": stats,
            },
        }
