from abc import ABC, abstractmethod
from typing import Any, Dict, List
import numpy as np
import torch
from lm_polygraph.estimators.sar import SAR
from lm_polygraph.utils.deberta import Deberta
from lm_polygraph.estimators import KernelLanguageEntropy
from lm_polygraph.stat_calculators import SemanticMatrixCalculator
from lm_polygraph.stat_calculators.cross_encoder_similarity import (
    CrossEncoderSimilarityMatrixCalculator,
)

# Calculator for greedy-sample similarity
from lm_polygraph.stat_calculators.greedy_cross_encoder_similarity import (
    GreedyCrossEncoderSimilarityMatrixCalculator,
)
from lm_polygraph.utils.model import WhiteboxModel
from structured_llmuq.utils.postprocessing import EntailmentDeberta, get_semantic_ids
from lm_polygraph.stat_calculators.prompt import BasePromptCalculator
from lm_polygraph.estimators.p_true import PTrue


class UQEstimator(ABC):
    """
    Abstract interface for uncertainty-quantification (UQ) estimators.

    Subclasses must implement ``__call__``. The concrete estimators in this
    module return a dict with an ``"estimate"`` entry and an
    ``"additional_data"`` dict.
    """

    @abstractmethod
    def __call__(self, *args, **kwds):
        """Compute an uncertainty estimate; must be overridden by subclasses."""
        # ``object`` defines no ``__call__``, so delegating to ``super()`` here
        # would raise an AttributeError; fail with the conventional error.
        raise NotImplementedError


class SemanticEntropyEstimator(UQEstimator):
    """
    Concrete implementation of UQEstimator for semantic entropy estimation.

    Answers are grouped into semantic clusters by ``get_semantic_ids`` using an
    entailment model (``EntailmentDeberta``). The estimate is the Shannon
    entropy (natural log) of the resulting cluster distribution.
    https://www.nature.com/articles/s41586-024-07421-0
    """

    def __init__(self):
        """
        Initialize the SemanticEntropyEstimator.
        """
        self.name = "se"
        self.entailment_model = EntailmentDeberta()

    @staticmethod
    def _semantic_pmf(semantic_ids, probs=None) -> np.ndarray:
        """
        Build the probability mass function over semantic clusters.

        :param semantic_ids: cluster id of each answer.
        :param probs: optional per-answer probabilities (list or numpy array)
            aligned with ``semantic_ids``. When given and non-empty, each
            cluster's mass is the sum of its answers' probabilities,
            renormalized to sum to one; clusters appear in first-seen order.
            Otherwise the empirical cluster frequencies are used (discrete
            semantic entropy), in sorted cluster-id order.
        :return: 1-D array of cluster probabilities.
        :raises ValueError: if ``probs`` and ``semantic_ids`` differ in length.
        """
        # Explicit None/len checks instead of truthiness: ``if probs:`` raises
        # for numpy arrays with more than one element.
        if probs is not None and len(probs) > 0:
            if len(probs) != len(semantic_ids):
                raise ValueError(
                    f"probs has {len(probs)} entries but there are "
                    f"{len(semantic_ids)} answers"
                )
            mass = {}
            for sem_id, prob in zip(semantic_ids, probs):
                mass[sem_id] = mass.get(sem_id, 0.0) + prob
            pmf = np.array(list(mass.values()))
            return pmf / np.sum(pmf)  # normalize

        # Discrete semantic entropy: cluster frequencies among the answers.
        _, counts = np.unique(semantic_ids, return_counts=True)
        return counts / np.sum(counts)

    def __call__(
        self, answers: List[str], question: str = None, probs: List[float] = None
    ) -> Dict[str, Any]:
        """
        Compute semantic entropy for a list of answers to one question.

        :param answers: sampled answer strings.
        :param question: optional question. When given, it is prepended
            (separated by a space) to every answer before clustering, so the
            entailment model sees the context.
        :param probs: optional per-answer probabilities aligned with
            ``answers``. When omitted or empty, the discrete (count-based)
            semantic entropy is computed.
        :return: ``{"estimate": entropy, "additional_data": {"pmf": ...,
            "semantic_ids": ...}}``.
        :raises ValueError: if ``probs`` is non-empty and its length differs
            from the number of answers.
        """
        if question is not None:
            answers = [question + " " + answer for answer in answers]

        semantic_ids = get_semantic_ids(
            strings_list=answers,
            model=self.entailment_model,
            strict_entailment=False,
            example=None,
        )

        pmf = self._semantic_pmf(semantic_ids, probs)

        # Small epsilon keeps log finite for zero-probability clusters.
        entropy = -np.sum(pmf * np.log(pmf + 1e-12))
        return {
            "estimate": entropy,
            "additional_data": {"pmf": pmf, "semantic_ids": semantic_ids},
        }


class SAREstimator(UQEstimator):
    """
    Concrete implementation of UQEstimator for SAR using LM_Polygraph
    https://arxiv.org/abs/2307.01379.
    https://github.com/IINemo/lm-polygraph/blob/main/src/lm_polygraph/estimators/sar.py
    """

    def __init__(self, model):
        """
        Initialize the SAR estimator with a WhiteboxModel. Use CausalLM Model wrapper

        :param model: wrapper of a causal LM exposing ``.model`` and ``.tokenizer``.
        """
        self.name = "sar"

        # Necessary to wrap in WhiteboxModel for lm_polygraph compatibility
        self.model = WhiteboxModel(model=model.model, tokenizer=model.tokenizer)
        self.calculator = CrossEncoderSimilarityMatrixCalculator(batch_size=256)
        self.metric = SAR()

    def __call__(
        self,
        answers: List[str],
        tokens: List[List[int]] = None,
        log_likelihoods: List[List[float]] = None,
    ) -> Dict[str, Any]:
        """
        Compute the SAR estimate for a set of sampled answers to one question.

        :param answers: list of sampled answer strings.
        :param tokens: token ids of each answer (one list per answer).
            Required despite the ``None`` default.
        :param log_likelihoods: log-likelihoods of each answer (one list per
            answer). Required despite the ``None`` default.
        :return: ``{"estimate": sar_value, "additional_data": {"similarities":
            ..., "stats": ...}}``.
        :raises TypeError: if ``tokens`` or ``log_likelihoods`` is missing.
        """
        # Fail early with a clear message instead of an obscure error from
        # ``map``/SAR deep inside lm-polygraph.
        if tokens is None or log_likelihoods is None:
            raise TypeError(
                "SAREstimator requires both `tokens` and `log_likelihoods`"
            )

        # Build dependencies for the lm-polygraph calculator. Each answer is
        # wrapped as its own single-sample entry and also passed as
        # "greedy_tokens", so the calculator's "token_similarity" output is
        # produced per answer (presumably why this layout was chosen; see the
        # stats mapping below).
        dependencies = {
            "sample_tokens": [[answer_tokens] for answer_tokens in tokens],
            "sample_texts": [answers],
            "input_texts": [""] * len(answers),
            "greedy_tokens": tokens,
        }

        # Text seems not to be used in the SAR calculation, but is required by the calculator
        similarities = self.calculator(
            dependencies=dependencies, texts=[], model=self.model
        )

        # The greedy output is not used for SAR of the calculator method
        stats = {
            "sample_log_likelihoods": [log_likelihoods],
            "sample_token_similarity": [similarities["token_similarity"]],
            "sample_sentence_similarity": similarities["sample_sentence_similarity"],
        }

        return {
            "estimate": self.metric(stats)[0],
            "additional_data": {
                "similarities": similarities,
                "stats": stats,
            },
        }


class KLEEstimator(UQEstimator):
    """
    UQEstimator based on Kernel Language Entropy (lm-polygraph implementation).
    https://arxiv.org/abs/2405.20003
    """

    def __init__(self):
        """
        Set up the NLI model, the semantic-matrix calculator and the KLE estimator.
        """
        self.name = "kle"
        self.nli_model = Deberta()
        self.nli_model.setup()
        self.calculator = SemanticMatrixCalculator(nli_model=self.nli_model)
        self.kl_estimator = KernelLanguageEntropy()

    def __call__(self, answers: List[str]) -> Dict[str, Any]:
        """
        Compute the KLE-based estimate for a list of sampled answers.

        :param answers: sampled answer strings for a single question.
        :return: ``{"estimate": 1 - kle_value, "additional_data": {"stats": ...}}``.
        """
        # TODO: check whether the question is needed for the entailment model
        stats = self.calculator(
            dependencies={"sample_texts": [answers]}, model=None, texts=None
        )
        heat_value = self.kl_estimator(stats)[0]
        # NOTE(review): the estimate is reported as ``1 - value`` where value is
        # the first output of KernelLanguageEntropy; confirm this convention
        # against lm-polygraph's definition of the estimator's output.
        return {"estimate": 1 - heat_value, "additional_data": {"stats": stats}}


class DummyModel:
    """Minimal model stand-in for stat calculators that only query a device."""

    def device(self):
        """Return the torch device name to use: "cuda" when available, else "cpu"."""
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"


class CocoaEstimator(UQEstimator):
    """
    COCOA estimator, adapted from lm-polygraph's CocoaMSP.
    https://arxiv.org/abs/2406.04370

    Combines a sequence-level confidence term with the semantic consistency
    of sampled answers relative to the greedy answer. Here the confidence
    term is the length-normalized negative log-likelihood of the greedy
    answer (the mean of its log-likelihoods, negated). It is multiplied by
    the average dissimilarity (1 - similarity) between the greedy answer and
    the samples.
    """

    def __init__(self):
        """
        Create the greedy-vs-sample cross-encoder similarity calculator.
        """
        self.name = "cocoa"
        self.calculator = GreedyCrossEncoderSimilarityMatrixCalculator()
        self.dummy_model = DummyModel()

    def __call__(
        self,
        answers: List[str],
        greedy_answer: str,
        greedy_log_likelihoods: List[float],
    ) -> Dict[str, Any]:
        """
        Compute the COCOA estimate for one question.

        :param answers: sampled answer strings.
        :param greedy_answer: the greedy-decoded answer string.
        :param greedy_log_likelihoods: log-likelihoods of the greedy answer.
        :return: ``{"estimate": value, "additional_data":
            {"greedy_sentence_similarity": ...}}``.
        """
        # A batch of size one in the layout lm-polygraph calculators expect.
        request = {
            "sample_texts": [answers],
            "greedy_texts": [greedy_answer],
        }
        sim_to_greedy = self.calculator(
            dependencies=request, texts=[], model=self.dummy_model
        )["greedy_sentence_similarity"]

        # Length-normalized negative log-likelihood of the greedy answer.
        nll = -np.mean(greedy_log_likelihoods)

        # Greedy-vs-sample similarities have no self-similarity diagonal, so
        # the plain mean over all entries is the average dissimilarity.
        dissimilarity = np.mean(1 - sim_to_greedy)

        return {
            "estimate": nll * dissimilarity,
            "additional_data": {"greedy_sentence_similarity": sim_to_greedy},
        }


# Prompt calculator for the PTrue estimator that accepts a custom prompt template:
# the default template with (A)/(B) options can lead the model to answer "(A)" or "(B)".
class PromptCalculatorOwn(BasePromptCalculator):
    """
    Prompt calculator for PTrue with a caller-supplied prompt template.

    The prompt, the expected answer string and the statistic name are
    forwarded unchanged to ``BasePromptCalculator``.
    """

    @staticmethod
    def meta_info():
        """
        Declare the statistics produced and the dependencies required.

        :return: ``(["p_true"], ["greedy_texts"])``.
        """
        # NOTE(review): callers also pass "input_texts" in the dependencies;
        # confirm whether it should be declared here as well.
        produced = ["p_true"]
        required = ["greedy_texts"]
        return produced, required

    def __init__(self, prompt: str, expected: str, name: str = "p_true"):
        # Forwarded positionally as (prompt, expected, name); the third base
        # argument is presumably the statistic/method name.
        super().__init__(prompt, expected, name)


class PTrueEstimator(UQEstimator):
    """
    Concrete implementation of UQEstimator for PTrue using LM_Polygraph.

    The model is prompted with a question and a candidate answer. lm-polygraph's
    ``PTrue`` turns the resulting prompt statistic into the estimate.
    """

    def __init__(
        self,
        model,
        prompt: str = "Question: {q}\n Possible answer:{a}\n "
        "Is the possible answer:\n-True\n-False\n The possible answer is:",
        expected: str = "True",
        name: str = "p_true",
    ):
        """
        Initialize the PTrue estimator with a WhiteboxModel. Use CausalLM Model wrapper

        :param model: wrapper of a causal LM exposing ``.model`` and ``.tokenizer``.
        :param prompt: prompt template; the default uses ``{q}`` for the
            question and ``{a}`` for the answer.
        :param expected: answer string whose probability is measured.
        :param name: estimator name, also forwarded to the prompt calculator.
            NOTE(review): PTrue presumably reads the "p_true" statistic, so a
            non-default name may not be picked up — confirm.
        """
        self.name = name

        # Needed for compatibility. NOTE: this mutates the caller's model in
        # place, so any other user of the same ``model.model`` also sees
        # ``num_return_sequences = 1`` afterwards.
        model.model.generation_config.num_return_sequences = 1
        # Necessary to wrap in WhiteboxModel for lm_polygraph compatibility
        self.model = WhiteboxModel(model=model.model, tokenizer=model.tokenizer)
        self.prompt_calculator = PromptCalculatorOwn(prompt, expected, name)
        self.metric = PTrue()

    def __call__(self, question: str, answer: str) -> Dict[str, Any]:
        """
        Compute the PTrue estimate for a single question/answer pair.

        args:
            question: str, question string
            answer: str, answer string
        returns:
            Dict with "estimate" and "additional_data" (the calculator stats)
        """
        # Batch of size one in the layout the lm-polygraph calculator expects.
        dependencies = {"greedy_texts": [answer], "input_texts": [question]}
        stats = self.prompt_calculator(
            dependencies=dependencies, texts=[], model=self.model
        )
        return {
            "estimate": self.metric(stats)[0],
            "additional_data": {
                "stats": stats,
            },
        }
