from lm_polygraph.estimators import *
from lm_polygraph.estimators.eccentricity import *


class EccentricityConf(Estimator):
    """Spectral-embedding score of answer 0 relative to the other answers.

    For each input, a pairwise similarity matrix over the answers stored under
    ``{samples_source}_texts`` is converted into a symmetric normalized graph
    Laplacian. The eigenvectors whose eigenvalues are below ``thres`` embed
    every answer, and the score measures how far answer 0 (presumably the
    greedy generation, given the ``greedy+`` stat prefix) lies from the rest.

    Args:
        similarity_score: ``"NLI_score"`` reads a precomputed semantic matrix
            from ``stats``; any other value delegates to ``compute_sim_score``.
        affinity: For the NLI case, ``"entail"`` uses the entailment matrix
            directly, anything else uses ``1 - contradiction matrix``.
        verbose: When true, the answers of every input are logged at debug level.
        thres: Eigenvalue cut-off; ``None`` keeps all eigenvectors.
        formula: ``'dist_to_mean'`` (squared distance from answer 0 to the mean
            embedding of the others) or ``'mean_dist'`` (mean squared distance
            from answer 0 to each other answer).
        samples_source: Prefix of the stats to read; ``'greedy+'`` is prepended
            when missing.
    """

    def __init__(
        self,
        similarity_score: Literal["NLI_score", "Jaccard_score"] = "NLI_score",
        affinity: Literal["entail", "contra"] = "entail",  # relevant for NLI score case
        verbose: bool = False,
        thres: float = 0.9,
        formula: str = 'dist_to_mean',
        samples_source: str = "sample",
    ):
        assert formula in ['dist_to_mean', 'mean_dist'], (
            f"formula must be 'dist_to_mean' or 'mean_dist', got {formula!r}"
        )
        self.formula = formula
        if not samples_source.startswith('greedy+'):
            samples_source = 'greedy+' + samples_source

        # The texts are always required; the NLI variant additionally needs
        # the semantic matrix matching the chosen affinity.
        dependencies = [f"{samples_source}_texts"]
        if similarity_score == "NLI_score":
            matrix_kind = "entail" if affinity == "entail" else "contra"
            dependencies.insert(0, f"{samples_source}_semantic_matrix_{matrix_kind}")
        super().__init__(dependencies, "sequence")

        self.similarity_score = similarity_score
        self.affinity = affinity
        self.verbose = verbose
        self.thres = thres
        self.samples_source = samples_source

    def __str__(self):
        base = "EccentricityConf" if self.formula == 'dist_to_mean' else "EigVecDissimilarity"
        # NOTE: __init__ always stores samples_source with the 'greedy+'
        # prefix, so this suffix is appended for every instance.
        if self.samples_source != "sample":
            base += f'_{self.samples_source}'
        if self.similarity_score == "NLI_score":
            return f"{base}_{self.similarity_score}_{self.affinity}"
        return f"{base}_{self.similarity_score}"

    def U_Eccentricity(self, i, stats):
        """Compute the score for the ``i``-th input of the batch.

        Args:
            i: Index of the input inside the batch-level ``stats`` entries.
            stats: Mapping holding the stats declared as dependencies.

        Returns:
            The scalar score selected by ``self.formula``.

        Raises:
            ValueError: If ``self.formula`` is not a supported formula name.
        """
        if self.similarity_score == "NLI_score":
            if self.affinity == "entail":
                W = np.array(stats[f"{self.samples_source}_semantic_matrix_entail"])[i, :, :]
            else:
                W = 1 - np.array(stats[f"{self.samples_source}_semantic_matrix_contra"])[i, :, :]
            # Symmetrise the matrix before handing it to the symmetric eigensolver.
            W = (W + np.transpose(W)) / 2
        else:
            W = compute_sim_score(
                answers=stats[f"{self.samples_source}_texts"][i],
                affinity=self.affinity,
                similarity_score=self.similarity_score,
            )

        # Symmetric normalized Laplacian: I - D^{-1/2} W D^{-1/2}.
        D = np.diag(W.sum(axis=1))
        D_inverse_sqrt = np.linalg.inv(np.sqrt(D))
        L = np.eye(D.shape[0]) - D_inverse_sqrt @ W @ D_inverse_sqrt

        eigenvalues, eigenvectors = eigh(L)

        # Keep only eigenvectors whose eigenvalue is below the threshold;
        # with no threshold every eigenvector is used.
        if self.thres is not None:
            embeddings = eigenvectors[:, eigenvalues < self.thres]
        else:
            embeddings = eigenvectors

        # After the transpose, column j is the embedding of answer j.
        embeddings = embeddings.T
        reference = embeddings[:, 0]

        if self.formula == 'dist_to_mean':
            mean_vector = embeddings[:, 1:].mean(-1)
            assert mean_vector.shape == reference.shape
            C_Ecc = np.linalg.norm(mean_vector - reference) ** 2
        elif self.formula == 'mean_dist':
            C_Ecc = np.mean([
                np.linalg.norm(embeddings[:, j] - reference) ** 2
                for j in range(1, embeddings.shape[1])
            ])
        else:
            raise ValueError(
                f"Unsupported formula {self.formula!r}; "
                "expected 'dist_to_mean' or 'mean_dist'"
            )
        return C_Ecc

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """Return one score per input found under ``{samples_source}_texts``."""
        res = []
        for i, answers in enumerate(stats[f"{self.samples_source}_texts"]):
            if self.verbose:
                log.debug(f"generated answers: {answers}")
            res.append(self.U_Eccentricity(i, stats))
        return np.array(res)


class EccentricityPConf(Estimator):
    """Probability-weighted spectral-embedding score of answer 0.

    Works like a spectral eccentricity score: a pairwise similarity matrix
    over the answers under ``{samples_source}_texts`` is converted into a
    symmetric normalized graph Laplacian, and the eigenvectors with eigenvalue
    below ``thres`` embed every answer. The answers other than answer 0 are
    weighted by their sequence probabilities after ``process_probs`` has been
    applied to them.

    Args:
        similarity_score: ``"NLI_score"`` reads a precomputed semantic matrix
            from ``stats``; any other value delegates to ``compute_sim_score``.
        affinity: For the NLI case, ``"entail"`` uses the entailment matrix
            directly, anything else uses ``1 - contradiction matrix``.
        verbose: When true, the answers of every input are logged at debug level.
        thres: Eigenvalue cut-off; ``None`` keeps all eigenvectors.
        samples_source: Prefix of the stats to read; ``'greedy+'`` is prepended
            when missing.
        formula: ``'dist_to_mean'`` (squared distance from answer 0 to the
            weighted sum of the other embeddings) or ``'mean_dist'`` (weighted
            sum of squared distances from answer 0 to each other answer).
        **process_probs_args: Forwarded unchanged to ``process_probs``.
    """

    def __init__(
        self,
        similarity_score: Literal["NLI_score", "Jaccard_score"] = "NLI_score",
        affinity: Literal["entail", "contra"] = "entail",  # relevant for NLI score case
        verbose: bool = False,
        thres: float = 0.9,
        samples_source: str = "beamsearch",
        formula: str = 'dist_to_mean',
        **process_probs_args,
    ):
        assert formula in ['dist_to_mean', 'mean_dist'], (
            f"formula must be 'dist_to_mean' or 'mean_dist', got {formula!r}"
        )
        self.formula = formula
        if not samples_source.startswith('greedy+'):
            samples_source = 'greedy+' + samples_source

        # Log-likelihoods and texts are always required; the NLI variant
        # additionally needs the semantic matrix matching the chosen affinity.
        dependencies = [f"{samples_source}_log_likelihoods"]
        if similarity_score == "NLI_score":
            matrix_kind = "entail" if affinity == "entail" else "contra"
            dependencies.append(f"{samples_source}_semantic_matrix_{matrix_kind}")
        dependencies.append(f"{samples_source}_texts")
        super().__init__(dependencies, "sequence")

        self.similarity_score = similarity_score
        self.affinity = affinity
        self.verbose = verbose
        self.thres = thres
        self.samples_source = samples_source
        self.process_probs_args = process_probs_args

    def __str__(self):
        base = "EccentricityPConf" if self.formula == 'dist_to_mean' else "EigVecDissimilarityP"
        # NOTE: __init__ always stores samples_source with the 'greedy+'
        # prefix, so this suffix is appended for every instance.
        if self.samples_source != "sample":
            base += f'_{self.samples_source}'
        if self.similarity_score == "NLI_score":
            return f"{base}_{self.similarity_score}_{self.affinity}"
        return f"{base}_{self.similarity_score}"

    def U_Eccentricity(self, i, stats):
        """Compute the weighted score for the ``i``-th input of the batch.

        Args:
            i: Index of the input inside the batch-level ``stats`` entries.
            stats: Mapping holding the stats declared as dependencies.

        Returns:
            The scalar score selected by ``self.formula``.

        Raises:
            ValueError: If ``self.formula`` is not a supported formula name.
        """
        if self.similarity_score == "NLI_score":
            if self.affinity == "entail":
                W = np.array(stats[f"{self.samples_source}_semantic_matrix_entail"])[i, :, :]
            else:
                W = 1 - np.array(stats[f"{self.samples_source}_semantic_matrix_contra"])[i, :, :]
            # Symmetrise the matrix before handing it to the symmetric eigensolver.
            W = (W + np.transpose(W)) / 2
        else:
            W = compute_sim_score(
                answers=stats[f"{self.samples_source}_texts"][i],
                affinity=self.affinity,
                similarity_score=self.similarity_score,
            )

        # Symmetric normalized Laplacian: I - D^{-1/2} W D^{-1/2}.
        D = np.diag(W.sum(axis=1))
        D_inverse_sqrt = np.linalg.inv(np.sqrt(D))
        L = np.eye(D.shape[0]) - D_inverse_sqrt @ W @ D_inverse_sqrt

        eigenvalues, eigenvectors = eigh(L)

        # Keep only eigenvectors whose eigenvalue is below the threshold;
        # with no threshold every eigenvector is used.
        if self.thres is not None:
            embeddings = eigenvectors[:, eigenvalues < self.thres]
        else:
            embeddings = eigenvectors

        # After the transpose, column j is the embedding of answer j.
        embeddings = embeddings.T
        reference = embeddings[:, 0]

        # Sequence probability of every answer from its token log-likelihoods;
        # entry 0 belongs to the reference answer and is dropped.
        sample_token_lls = stats[f"{self.samples_source}_log_likelihoods"][i]
        probs = np.array([np.exp(sum(s)) for s in sample_token_lls])
        probs = probs[1:]
        probs = process_probs(probs, **self.process_probs_args)

        if self.formula == 'dist_to_mean':
            mean_vector = (embeddings[:, 1:] * probs.reshape(1, len(probs))).sum(-1)
            assert mean_vector.shape == reference.shape
            C_Ecc = np.linalg.norm(mean_vector - reference) ** 2
        elif self.formula == 'mean_dist':
            squared_dists = np.array([
                np.linalg.norm(embeddings[:, j] - reference) ** 2
                for j in range(1, embeddings.shape[1])
            ])
            C_Ecc = (squared_dists * probs).sum()
        else:
            raise ValueError(
                f"Unsupported formula {self.formula!r}; "
                "expected 'dist_to_mean' or 'mean_dist'"
            )

        return C_Ecc

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """Return one score per input found under ``{samples_source}_texts``."""
        res = []
        for i, answers in enumerate(stats[f"{self.samples_source}_texts"]):
            if self.verbose:
                log.debug(f"generated answers: {answers}")
            res.append(self.U_Eccentricity(i, stats))
        return np.array(res)


import numpy as np

from typing import Dict

from lm_polygraph.estimators.estimator import Estimator
from lm_polygraph.estimators.process_probs import process_probs

class CocoaMSP(Estimator):
    """Negative greedy sequence log-likelihood scaled by mean dissimilarity.

    The dissimilarity is ``1 - similarity`` averaged over the first row of the
    ``greedy+{samples_source}_semantic_matrix_entail`` matrix with its first
    entry removed (presumably the greedy answer versus each sample, given the
    stat name).

    Args:
        samples_source: Name of the sampling stat family to read.
    """

    def __init__(
        self,
        samples_source: str = "sample"
    ):
        # Declare the semantic matrix that __call__ actually reads, so it is
        # guaranteed to be computed before this estimator runs.
        super().__init__(
            [f"greedy+{samples_source}_semantic_matrix_entail", "greedy_log_likelihoods"], "sequence"
        )
        self.samples_source = samples_source

    def __str__(self):
        base = "CocoaMSP"
        if self.samples_source != "sample":
            base += f'_{self.samples_source}'
        return base

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """Return one score per input.

        Args:
            stats: Mapping with ``greedy_log_likelihoods`` and
                ``greedy+{samples_source}_semantic_matrix_entail``.

        Returns:
            Array of ``(-sum of greedy token log-likelihoods) * mean dissimilarity``.
        """
        semantic_matrices = stats[f"greedy+{self.samples_source}_semantic_matrix_entail"]

        enriched_metrics = []
        for token_lls, matrix in zip(stats["greedy_log_likelihoods"], semantic_matrices):
            # Negative log-probability of the whole greedy sequence.
            neg_log_prob = -np.sum(token_lls)

            # First row without its first (self-similarity) entry.
            similarities = np.array(matrix[0][1:])
            avg_dissimilarity = np.mean(1 - similarities)

            enriched_metrics.append(neg_log_prob * avg_dissimilarity)

        return np.array(enriched_metrics)


class CocoaMSPP(Estimator):
    """Negative greedy sequence log-likelihood scaled by weighted dissimilarity.

    The dissimilarities (``1 - similarity`` over the first row of the
    ``greedy+{samples_source}_semantic_matrix_entail`` matrix without its
    first entry) are weighted by the sample sequence probabilities after
    ``process_probs`` has been applied, then summed.

    Args:
        samples_source: Name of the sampling stat family to read.
        **process_probs_args: Forwarded unchanged to ``process_probs``.
    """

    def __init__(
        self,
        samples_source: str = "sample",
        **process_probs_args,
    ):
        # Declare the semantic matrix that __call__ actually reads, so it is
        # guaranteed to be computed before this estimator runs.
        super().__init__(
            [
                f"greedy+{samples_source}_semantic_matrix_entail",
                f"{samples_source}_log_likelihoods",
                "greedy_log_likelihoods",
            ], "sequence"
        )
        self.samples_source = samples_source
        self.process_probs_args = process_probs_args

    def __str__(self):
        base = "CocoaMSPP"
        if self.samples_source != "sample":
            base += f'_{self.samples_source}'
        return base

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """Return one score per input.

        Args:
            stats: Mapping with ``greedy_log_likelihoods``,
                ``{samples_source}_log_likelihoods`` and
                ``greedy+{samples_source}_semantic_matrix_entail``.

        Returns:
            Array of ``(-sum of greedy token log-likelihoods) * weighted dissimilarity``.
        """
        semantic_matrices = stats[f"greedy+{self.samples_source}_semantic_matrix_entail"]
        batch_sample_token_lls = stats[f"{self.samples_source}_log_likelihoods"]

        enriched_metrics = []
        for sample_token_lls, token_lls, matrix in zip(
            batch_sample_token_lls, stats["greedy_log_likelihoods"], semantic_matrices
        ):
            # Negative log-probability of the whole greedy sequence.
            neg_log_prob = -np.sum(token_lls)

            # First row without its first (self-similarity) entry.
            dissimilarities = 1 - np.array(matrix[0][1:])

            # Sequence probability of every sample from its token log-likelihoods.
            probs = np.array([np.exp(sum(x)) for x in sample_token_lls])
            probs = process_probs(probs, **self.process_probs_args)
            weighted_dissimilarity = (dissimilarities * probs).sum()

            enriched_metrics.append(neg_log_prob * weighted_dissimilarity)

        return np.array(enriched_metrics)


class CocoaPPL(Estimator):
    """Negative mean greedy token log-likelihood scaled by mean dissimilarity.

    The dissimilarity is ``1 - similarity`` averaged over the first row of the
    ``greedy+{samples_source}_semantic_matrix_entail`` matrix with its first
    entry removed (presumably the greedy answer versus each sample, given the
    stat name).

    Args:
        samples_source: Name of the sampling stat family to read.
    """

    def __init__(
        self,
        samples_source: str = "sample"
    ):
        # Declare the semantic matrix that __call__ actually reads, so it is
        # guaranteed to be computed before this estimator runs.
        super().__init__(
            [f"greedy+{samples_source}_semantic_matrix_entail", "greedy_log_likelihoods"], "sequence"
        )
        self.samples_source = samples_source

    def __str__(self):
        base = "CocoaPPL"
        if self.samples_source != "sample":
            base += f'_{self.samples_source}'
        return base

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """Return one score per input.

        Args:
            stats: Mapping with ``greedy_log_likelihoods`` and
                ``greedy+{samples_source}_semantic_matrix_entail``.

        Returns:
            Array of ``(-mean of greedy token log-likelihoods) * mean dissimilarity``.
        """
        semantic_matrices = stats[f"greedy+{self.samples_source}_semantic_matrix_entail"]

        enriched_ppl = []
        for token_lls, matrix in zip(stats["greedy_log_likelihoods"], semantic_matrices):
            # Length-normalised negative log-likelihood of the greedy sequence
            # (the log of the perplexity, not the perplexity itself).
            ppl = -np.mean(token_lls)

            # First row without its first (self-similarity) entry.
            avg_dissimilarity = np.mean(1 - np.array(matrix[0][1:]))

            enriched_ppl.append(ppl * avg_dissimilarity)

        return np.array(enriched_ppl)


class CocoaPPLP(Estimator):
    """Negative mean greedy token log-likelihood scaled by weighted dissimilarity.

    The dissimilarities (``1 - similarity`` over the first row of the
    ``greedy+{samples_source}_semantic_matrix_entail`` matrix without its
    first entry) are weighted by the sample sequence probabilities after
    ``process_probs`` has been applied, then summed.

    Args:
        samples_source: Name of the sampling stat family to read.
        **process_probs_args: Forwarded unchanged to ``process_probs``.
    """

    def __init__(
        self,
        samples_source: str = "sample",
        **process_probs_args,
    ):
        # Declare the semantic matrix that __call__ actually reads, so it is
        # guaranteed to be computed before this estimator runs.
        super().__init__(
            [
                f"greedy+{samples_source}_semantic_matrix_entail",
                f"{samples_source}_log_likelihoods",
                "greedy_log_likelihoods",
            ], "sequence"
        )
        self.samples_source = samples_source
        self.process_probs_args = process_probs_args

    def __str__(self):
        base = "CocoaPPLP"
        if self.samples_source != "sample":
            base += f'_{self.samples_source}'
        return base

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """Return one score per input.

        Args:
            stats: Mapping with ``greedy_log_likelihoods``,
                ``{samples_source}_log_likelihoods`` and
                ``greedy+{samples_source}_semantic_matrix_entail``.

        Returns:
            Array of ``(-mean of greedy token log-likelihoods) * weighted dissimilarity``.
        """
        semantic_matrices = stats[f"greedy+{self.samples_source}_semantic_matrix_entail"]
        batch_sample_token_lls = stats[f"{self.samples_source}_log_likelihoods"]

        enriched_ppl = []
        for sample_token_lls, token_lls, matrix in zip(
            batch_sample_token_lls, stats["greedy_log_likelihoods"], semantic_matrices
        ):
            # Length-normalised negative log-likelihood of the greedy sequence
            # (the log of the perplexity, not the perplexity itself).
            ppl = -np.mean(token_lls)

            # First row without its first (self-similarity) entry.
            dissimilarities = 1 - np.array(matrix[0][1:])

            # Sequence probability of every sample from its token log-likelihoods.
            probs = np.array([np.exp(sum(x)) for x in sample_token_lls])
            probs = process_probs(probs, **self.process_probs_args)
            weighted_dissimilarity = (dissimilarities * probs).sum()

            enriched_ppl.append(ppl * weighted_dissimilarity)

        return np.array(enriched_ppl)


class CocoaMTE(Estimator):
    """Mean of the ``entropy`` stat scaled by mean dissimilarity.

    The dissimilarity is ``1 - similarity`` averaged over the first row of the
    ``greedy+{samples_source}_semantic_matrix_entail`` matrix with its first
    entry removed (presumably the greedy answer versus each sample, given the
    stat name).

    Args:
        samples_source: Name of the sampling stat family to read.
    """

    def __init__(
        self,
        samples_source: str = "sample"
    ):
        # Declare the semantic matrix that __call__ actually reads, so it is
        # guaranteed to be computed before this estimator runs.
        super().__init__([f"greedy+{samples_source}_semantic_matrix_entail", "entropy"], "sequence")
        self.samples_source = samples_source

    def __str__(self):
        base = "CocoaMTE"
        if self.samples_source != "sample":
            base += f'_{self.samples_source}'
        return base

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """Return one score per input.

        Args:
            stats: Mapping with ``entropy`` and
                ``greedy+{samples_source}_semantic_matrix_entail``.

        Returns:
            Array of ``mean(entropy) * mean dissimilarity``.
        """
        semantic_matrices = stats[f"greedy+{self.samples_source}_semantic_matrix_entail"]

        enriched_entropy = []
        for greedy_entropy, matrix in zip(stats["entropy"], semantic_matrices):
            # First row without its first (self-similarity) entry.
            avg_dissimilarity = np.mean(1 - np.array(matrix[0][1:]))

            entropy = np.mean(greedy_entropy)
            enriched_entropy.append(entropy * avg_dissimilarity)

        return np.array(enriched_entropy)


class CocoaMTEP(Estimator):
    """Mean of the ``entropy`` stat scaled by weighted dissimilarity.

    The dissimilarities (``1 - similarity`` over the first row of the
    ``greedy+{samples_source}_semantic_matrix_entail`` matrix without its
    first entry) are weighted by the sample sequence probabilities after
    ``process_probs`` has been applied, then summed.

    Args:
        samples_source: Name of the sampling stat family to read.
        **process_probs_args: Forwarded unchanged to ``process_probs``.
    """

    def __init__(
        self,
        samples_source: str = "beamsearch",
        **process_probs_args,
    ):
        # Declare the semantic matrix that __call__ actually reads, so it is
        # guaranteed to be computed before this estimator runs.
        super().__init__([
            f"greedy+{samples_source}_semantic_matrix_entail",
            f"{samples_source}_log_likelihoods",
            "entropy",
        ], "sequence")
        self.samples_source = samples_source
        self.process_probs_args = process_probs_args

    def __str__(self):
        base = "CocoaMTEP"
        if self.samples_source != "sample":
            base += f'_{self.samples_source}'
        return base

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """Return one score per input.

        Args:
            stats: Mapping with ``entropy``,
                ``{samples_source}_log_likelihoods`` and
                ``greedy+{samples_source}_semantic_matrix_entail``.

        Returns:
            Array of ``mean(entropy) * weighted dissimilarity``.
        """
        semantic_matrices = stats[f"greedy+{self.samples_source}_semantic_matrix_entail"]
        batch_sample_token_lls = stats[f"{self.samples_source}_log_likelihoods"]

        enriched_entropy = []
        for sample_token_lls, greedy_entropy, matrix in zip(
            batch_sample_token_lls, stats["entropy"], semantic_matrices
        ):
            # First row without its first (self-similarity) entry.
            dissimilarities = 1 - np.array(matrix[0][1:])

            # Sequence probability of every sample from its token log-likelihoods.
            probs = np.array([np.exp(sum(x)) for x in sample_token_lls])
            probs = process_probs(probs, **self.process_probs_args)
            weighted_dissimilarity = (dissimilarities * probs).sum()

            entropy = np.mean(greedy_entropy)
            enriched_entropy.append(entropy * weighted_dissimilarity)

        return np.array(enriched_entropy)


class Dissimilarity(Estimator):
    """Mean dissimilarity between item 0 and the remaining items.

    The dissimilarity is ``1 - similarity`` averaged over the first row of the
    ``greedy+{samples_source}_semantic_matrix_entail`` matrix with its first
    entry removed (presumably the greedy answer versus each sample, given the
    stat name).

    Args:
        samples_source: Name of the sampling stat family to read.
    """

    def __init__(
            self,
            samples_source: str = "sample"
    ):
        # Declare the semantic matrix that __call__ actually reads, so it is
        # guaranteed to be computed before this estimator runs.
        super().__init__([f"greedy+{samples_source}_semantic_matrix_entail"], "sequence")
        self.samples_source = samples_source

    def __str__(self):
        base = "Dissimilarity"
        if self.samples_source != "sample":
            base += f'_{self.samples_source}'
        return base

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """Return one mean dissimilarity per input.

        Args:
            stats: Mapping with ``greedy+{samples_source}_semantic_matrix_entail``.

        Returns:
            Array with the mean of ``1 - similarity`` for every input.
        """
        semantic_matrices = stats[f"greedy+{self.samples_source}_semantic_matrix_entail"]

        # First row of each matrix without its first (self-similarity) entry.
        dissims = [
            np.mean(1 - np.array(matrix[0][1:]))
            for matrix in semantic_matrices
        ]

        return np.array(dissims)


class DissimilarityP(Estimator):
    """Probability-weighted dissimilarity between item 0 and the remaining items.

    The dissimilarities (``1 - similarity`` over the first row of the
    ``greedy+{samples_source}_semantic_matrix_entail`` matrix without its
    first entry) are weighted by the sample sequence probabilities after
    ``process_probs`` has been applied, then summed.

    Args:
        samples_source: Name of the sampling stat family to read.
        **process_probs_args: Forwarded unchanged to ``process_probs``.
    """

    def __init__(
            self,
            samples_source: str = "beamsearch",
            **process_probs_args,
    ):
        # Declare the semantic matrix that __call__ actually reads, so it is
        # guaranteed to be computed before this estimator runs.
        super().__init__([
            f"greedy+{samples_source}_semantic_matrix_entail",
            f"{samples_source}_log_likelihoods",
        ], "sequence")
        self.samples_source = samples_source
        self.process_probs_args = process_probs_args

    def __str__(self):
        base = "DissimilarityP"
        if self.samples_source != "sample":
            base += f'_{self.samples_source}'
        return base

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """Return one weighted dissimilarity per input.

        Args:
            stats: Mapping with ``{samples_source}_log_likelihoods`` and
                ``greedy+{samples_source}_semantic_matrix_entail``.

        Returns:
            Array with the probability-weighted sum of ``1 - similarity``
            for every input.
        """
        semantic_matrices = stats[f"greedy+{self.samples_source}_semantic_matrix_entail"]
        batch_sample_token_lls = stats[f"{self.samples_source}_log_likelihoods"]

        dissims = []
        for sample_token_lls, matrix in zip(batch_sample_token_lls, semantic_matrices):
            # First row without its first (self-similarity) entry.
            dissimilarities = 1 - np.array(matrix[0][1:])

            # Sequence probability of every sample from its token log-likelihoods.
            probs = np.array([np.exp(sum(x)) for x in sample_token_lls])
            probs = process_probs(probs, **self.process_probs_args)
            dissims.append((dissimilarities * probs).sum())

        return np.array(dissims)
