import numpy as np
import logging
from typing import Dict, Literal
from .estimator import Estimator
from .common import compute_sim_score
from scipy.linalg import eigh

from .process_probs import process_probs

# Module-level logger; the estimators below emit the sampled answers at DEBUG level when `verbose` is set.
log = logging.getLogger(__name__)


class Eccentricity(Estimator):
    """
    Estimates the sequence-level uncertainty of a language model following the method of
    "Eccentricity" as provided in the paper https://arxiv.org/abs/2305.19187.
    Works with both whitebox and blackbox models (initialized using
    lm_polygraph.utils.model.BlackboxModel/WhiteboxModel).

    For every input, a pairwise similarity matrix W over the sampled answers is turned into the
    symmetric normalized graph Laplacian L = I - D^(-1/2) W D^(-1/2). The eigenvectors of L whose
    eigenvalues are below `thres` give each answer an embedding, and the score is built from the
    squared Euclidean distances between those embeddings (larger spread -> greater uncertainty).
    """

    def __init__(
        self,
        similarity_score: Literal["NLI_score", "Jaccard_score"] = "NLI_score",
        affinity: Literal["entail", "contra"] = "entail",  # relevant for NLI score case
        verbose: bool = False,
        thres: float = 0.9,
        confidence: bool = False,
        samples_source: str = "sample",
    ):
        """
        See parameters descriptions in https://arxiv.org/abs/2305.19187.
        Parameters:
            similarity_score (str): similarity score for matrix construction. Possible values:
                - 'NLI_score': Natural Language Inference similarity
                - 'Jaccard_score': Jaccard score similarity
            affinity (str): affinity method, relevant only when similarity_score='NLI_score'. Possible values:
                - 'entail': similarity(response_1, response_2) = p_entail(response_1, response_2)
                - 'contra': similarity(response_1, response_2) = 1 - p_contra(response_1, response_2)
            verbose (bool): if True, the answers of every input are logged at DEBUG level.
            thres (float): only eigenvectors with eigenvalue strictly below this value are used;
                None keeps all eigenvectors.
            confidence (bool): if True, the score of the first answer is returned instead of the
                average over all answers, and samples_source is prefixed with 'greedy+'
                (unless it already starts with it).
            samples_source (str): prefix of the statistics keys to read
                ('<samples_source>_texts', '<samples_source>_semantic_matrix_*').
        """
        if confidence and not samples_source.startswith('greedy+'):
            samples_source = 'greedy+' + samples_source

        # The texts are always required; the NLI matrix only when it is the similarity source.
        dependencies = [f"{samples_source}_texts"]
        if similarity_score == "NLI_score":
            matrix_kind = "entail" if affinity == "entail" else "contra"
            dependencies.insert(0, f"{samples_source}_semantic_matrix_{matrix_kind}")
        super().__init__(dependencies, "sequence")

        self.similarity_score = similarity_score
        self.affinity = affinity
        self.verbose = verbose
        self.thres = thres
        self.confidence = confidence
        self.samples_source = samples_source

    def __str__(self):
        base = "Eccentricity"
        if self.confidence:
            base += 'Conf'
        if self.samples_source != "sample":
            base += f'_{self.samples_source}'
        if self.similarity_score == "NLI_score":
            return f"{base}_{self.similarity_score}_{self.affinity}"
        return f"{base}_{self.similarity_score}"

    def _similarity_matrix(self, i, stats):
        """Return the pairwise similarity matrix W of the answers generated for input `i`."""
        if self.similarity_score != "NLI_score":
            return compute_sim_score(
                answers=stats[f"{self.samples_source}_texts"][i],
                affinity=self.affinity,
                similarity_score=self.similarity_score,
            )

        # Index first, convert second: only the i-th matrix is materialised instead of
        # converting the whole batch on every call.
        if self.affinity == "entail":
            W = np.asarray(stats[f"{self.samples_source}_semantic_matrix_entail"][i])
        else:
            W = 1 - np.asarray(stats[f"{self.samples_source}_semantic_matrix_contra"][i])
        # NLI scores are directional; symmetrise so that the Laplacian is symmetric.
        return (W + np.transpose(W)) / 2

    def _spectral_embeddings(self, W):
        """
        Return one embedding per answer (rows) taken from the eigenvectors of the
        symmetric normalized Laplacian of W whose eigenvalues are below `self.thres`.

        Raises:
            np.linalg.LinAlgError: if some row of W sums to exactly zero, in which case
                D^(-1/2) does not exist.
        """
        degrees = np.asarray(W.sum(axis=1))
        if np.any(degrees == 0):
            raise np.linalg.LinAlgError(
                "Singular matrix: similarity matrix has a row with zero total weight"
            )
        # D is diagonal, so D^(-1/2) W D^(-1/2) is a plain row/column rescaling of W;
        # no dense matrix inverse is needed.
        inv_sqrt_deg = 1.0 / np.sqrt(degrees)
        laplacian = np.eye(degrees.shape[0]) - (inv_sqrt_deg[:, None] * W) * inv_sqrt_deg[None, :]

        eigenvalues, eigenvectors = eigh(laplacian)
        if self.thres is not None:
            eigenvectors = eigenvectors[:, eigenvalues < self.thres]
        # Row j holds the coordinates of answer j in the retained eigenvectors.
        return eigenvectors

    @staticmethod
    def _pairwise_sq_dists(embeddings):
        """Return the (n, n) matrix of squared Euclidean distances between the rows of `embeddings`."""
        diffs = embeddings[:, None, :] - embeddings[None, :, :]
        return (diffs ** 2).sum(axis=-1)

    def U_Eccentricity(self, i, stats):
        """
        Computes the eccentricity scores for the i-th input.

        Note: instead of the distance-to-centroid formula of the paper, the score of an answer
        is its mean squared distance to all answers of the same input.

        Parameters:
            i (int): index of the input inside `stats`.
            stats (Dict[str, np.ndarray]): input statistics (see `__call__`).
        Returns:
            Tuple[float, float]: (mean score over all answers, score of the first answer).
        """
        embeddings = self._spectral_embeddings(self._similarity_matrix(i, stats))
        # all_confs[t] = mean_j ||e_j - e_t||^2 (unweighted average over the answers).
        all_confs = self._pairwise_sq_dists(embeddings).mean(axis=1)
        return all_confs.mean(), all_confs[0]

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Estimates the uncertainties for each sample in the input statistics.

        Parameters:
            stats (Dict[str, np.ndarray]): input statistics, which for multiple samples includes:
                * generated samples in '<samples_source>_texts',
                * matrix with semantic similarities in '<samples_source>_semantic_matrix_entail' /
                  '<samples_source>_semantic_matrix_contra' (only read when similarity_score='NLI_score')
        Returns:
            np.ndarray: float uncertainty for each sample in input statistics.
                Higher values indicate more uncertain samples.
        """
        res = []
        for i, answers in enumerate(stats[f"{self.samples_source}_texts"]):
            if self.verbose:
                log.debug("generated answers: %s", answers)
            unc, conf = self.U_Eccentricity(i, stats)
            res.append(conf if self.confidence else unc)
        return np.array(res)


class EccentricityP(Estimator):
    """
    Probability-weighted variant of the "Eccentricity" sequence-level uncertainty estimator
    from the paper https://arxiv.org/abs/2305.19187.

    For every input, a pairwise similarity matrix W over the generated answers is turned into the
    symmetric normalized graph Laplacian L = I - D^(-1/2) W D^(-1/2). The eigenvectors of L whose
    eigenvalues are below `thres` give each answer an embedding, and the squared Euclidean
    distances between those embeddings are averaged with weights derived from the answers'
    sequence probabilities (exp of the summed token log-likelihoods, post-processed by
    `process_probs`).

    Requires '<samples_source>_log_likelihoods' in the statistics.
    NOTE(review): this presumably restricts the estimator to whitebox models - confirm.
    """

    def __init__(
        self,
        similarity_score: Literal["NLI_score", "Jaccard_score"] = "NLI_score",
        affinity: Literal["entail", "contra"] = "entail",  # relevant for NLI score case
        verbose: bool = False,
        thres: float = 0.9,
        confidence: bool = False,
        samples_source: str = "beamsearch",
        **process_probs_args,
    ):
        """
        See parameters descriptions in https://arxiv.org/abs/2305.19187.
        Parameters:
            similarity_score (str): similarity score for matrix construction. Possible values:
                - 'NLI_score': Natural Language Inference similarity
                - 'Jaccard_score': Jaccard score similarity
            affinity (str): affinity method, relevant only when similarity_score='NLI_score'. Possible values:
                - 'entail': similarity(response_1, response_2) = p_entail(response_1, response_2)
                - 'contra': similarity(response_1, response_2) = 1 - p_contra(response_1, response_2)
            verbose (bool): if True, the answers of every input are logged at DEBUG level.
            thres (float): only eigenvectors with eigenvalue strictly below this value are used;
                None keeps all eigenvectors.
            confidence (bool): if True, the score of the first answer is returned instead of the
                weighted sum over all answers, and samples_source is prefixed with 'greedy+'
                (unless it already starts with it).
            samples_source (str): prefix of the statistics keys to read
                ('<samples_source>_texts', '<samples_source>_log_likelihoods',
                '<samples_source>_semantic_matrix_*').
            **process_probs_args: extra keyword arguments forwarded to `process_probs`.
        """
        if confidence and not samples_source.startswith('greedy+'):
            samples_source = 'greedy+' + samples_source

        # Log-likelihoods and texts are always required; the NLI matrix only when it is
        # the similarity source.
        dependencies = [f"{samples_source}_log_likelihoods"]
        if similarity_score == "NLI_score":
            matrix_kind = "entail" if affinity == "entail" else "contra"
            dependencies.append(f"{samples_source}_semantic_matrix_{matrix_kind}")
        dependencies.append(f"{samples_source}_texts")
        super().__init__(dependencies, "sequence")

        self.similarity_score = similarity_score
        self.affinity = affinity
        self.verbose = verbose
        self.thres = thres
        self.confidence = confidence
        self.samples_source = samples_source
        self.process_probs_args = process_probs_args

    def __str__(self):
        base = "EccentricityP"
        if self.confidence:
            base += 'Conf'
        if self.samples_source != "sample":
            base += f'_{self.samples_source}'
        if self.similarity_score == "NLI_score":
            return f"{base}_{self.similarity_score}_{self.affinity}"
        return f"{base}_{self.similarity_score}"

    def _similarity_matrix(self, i, stats):
        """Return the pairwise similarity matrix W of the answers generated for input `i`."""
        if self.similarity_score != "NLI_score":
            return compute_sim_score(
                answers=stats[f"{self.samples_source}_texts"][i],
                affinity=self.affinity,
                similarity_score=self.similarity_score,
            )

        # Index first, convert second: only the i-th matrix is materialised instead of
        # converting the whole batch on every call.
        if self.affinity == "entail":
            W = np.asarray(stats[f"{self.samples_source}_semantic_matrix_entail"][i])
        else:
            W = 1 - np.asarray(stats[f"{self.samples_source}_semantic_matrix_contra"][i])
        # NLI scores are directional; symmetrise so that the Laplacian is symmetric.
        return (W + np.transpose(W)) / 2

    def _spectral_embeddings(self, W):
        """
        Return one embedding per answer (rows) taken from the eigenvectors of the
        symmetric normalized Laplacian of W whose eigenvalues are below `self.thres`.

        Raises:
            np.linalg.LinAlgError: if some row of W sums to exactly zero, in which case
                D^(-1/2) does not exist.
        """
        degrees = np.asarray(W.sum(axis=1))
        if np.any(degrees == 0):
            raise np.linalg.LinAlgError(
                "Singular matrix: similarity matrix has a row with zero total weight"
            )
        # D is diagonal, so D^(-1/2) W D^(-1/2) is a plain row/column rescaling of W;
        # no dense matrix inverse is needed.
        inv_sqrt_deg = 1.0 / np.sqrt(degrees)
        laplacian = np.eye(degrees.shape[0]) - (inv_sqrt_deg[:, None] * W) * inv_sqrt_deg[None, :]

        eigenvalues, eigenvectors = eigh(laplacian)
        if self.thres is not None:
            eigenvectors = eigenvectors[:, eigenvalues < self.thres]
        # Row j holds the coordinates of answer j in the retained eigenvectors.
        return eigenvectors

    @staticmethod
    def _pairwise_sq_dists(embeddings):
        """Return the (n, n) matrix of squared Euclidean distances between the rows of `embeddings`."""
        diffs = embeddings[:, None, :] - embeddings[None, :, :]
        return (diffs ** 2).sum(axis=-1)

    def U_Eccentricity(self, i, stats):
        """
        Computes the probability-weighted eccentricity scores for the i-th input.

        Note: instead of the distance-to-centroid formula of the paper, the score of an answer
        is its probability-weighted sum of squared distances to all answers of the same input.

        Parameters:
            i (int): index of the input inside `stats`.
            stats (Dict[str, np.ndarray]): input statistics (see `__call__`).
        Returns:
            Tuple[float, float]: (probability-weighted sum of the per-answer scores,
                score of the first answer).
        """
        embeddings = self._spectral_embeddings(self._similarity_matrix(i, stats))

        # Sequence probability of each answer = exp(sum of its token log-likelihoods).
        sample_token_lls = stats[f"{self.samples_source}_log_likelihoods"][i]
        probs = np.array([np.exp(sum(s)) for s in sample_token_lls])
        # assumes process_probs returns one weight per answer - TODO confirm
        probs = np.asarray(process_probs(probs, **self.process_probs_args))

        # all_confs[t] = sum_j probs[j] * ||e_j - e_t||^2; the distance matrix is symmetric,
        # so broadcasting probs along the last axis weights the "other" answer j.
        all_confs = (self._pairwise_sq_dists(embeddings) * probs).sum(axis=1)
        return (all_confs * probs).sum(), all_confs[0]

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Estimates the uncertainties for each sample in the input statistics.

        Parameters:
            stats (Dict[str, np.ndarray]): input statistics, which for multiple samples includes:
                * generated samples in '<samples_source>_texts',
                * token log-likelihoods of the samples in '<samples_source>_log_likelihoods',
                * matrix with semantic similarities in '<samples_source>_semantic_matrix_entail' /
                  '<samples_source>_semantic_matrix_contra' (only read when similarity_score='NLI_score')
        Returns:
            np.ndarray: float uncertainty for each sample in input statistics.
                Higher values indicate more uncertain samples.
        """
        res = []
        for i, answers in enumerate(stats[f"{self.samples_source}_texts"]):
            if self.verbose:
                log.debug("generated answers: %s", answers)
            unc, conf = self.U_Eccentricity(i, stats)
            res.append(conf if self.confidence else unc)
        return np.array(res)
