import numpy as np

from typing import Dict

from lm_polygraph.estimators.estimator import Estimator


class EigenScore(Estimator):
    """
    Estimates the sequence-level uncertainty of a language model following the method of
    "EigenScore" as provided in the paper https://openreview.net/forum?id=Zj12nzlQbz.
    Works only with whitebox models (initialized using lm_polygraph.utils.model.WhiteboxModel).
    Uses embeddings for the last generated token from the middle layer of the model.

    For every input, the embeddings of its sampled generations are centered,
    their regularized Gram matrix is built, and the score is the mean of the
    logarithms of that matrix's eigenvalues.
    """

    def __init__(
        self,
        alpha: float = 1e-3,
        samples_source: str = "sample",
    ):
        """
        Parameters:
            alpha (float): regularization weight; `alpha * I` is added to the
                Gram matrix before its eigenvalues are computed.
            samples_source (str): prefix of the statistics key that holds the
                embeddings; the estimator reads `f"{samples_source}_embeddings"`.
        """
        # NOTE(review): the two positional arguments are presumably the list of
        # required statistics and the estimation level - confirm against Estimator.
        super().__init__([f"{samples_source}_embeddings"], "sequence")
        self.alpha = alpha
        # Centering matrix I - (1/d) * 1 1^T; built lazily in __call__ because
        # the embedding dimension d is only known once data is seen.
        self.J_d = None
        self.samples_source = samples_source

    def __str__(self):
        """Return the estimator name, suffixed with a non-default samples source."""
        base = "EigenScore"
        if self.samples_source != "sample":
            base += f"_{self.samples_source}"
        return base

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Estimates the EigenScore score for each sample in the input statistics.

        Parameters:
            stats (Dict[str, np.ndarray]): input statistics, which for multiple samples includes:
                *  embeddings for several sampled texts in
                   f'{samples_source}_embeddings' ('sample_embeddings' by default)
        Returns:
            np.ndarray: float EigenScore score for each sample in input statistics.
                Higher values indicate more uncertain samples.
        """
        sample_embeddings = stats[f"{self.samples_source}_embeddings"]
        ue = []
        for embeddings in sample_embeddings:
            sentence_embeddings = np.array(embeddings)
            dim = sentence_embeddings.shape[-1]
            # (Re)build the cached centering matrix when it is missing or was
            # created for a different embedding dimension; a stale matrix would
            # otherwise make the product below fail with a shape mismatch.
            if self.J_d is None or self.J_d.shape[0] != dim:
                self.J_d = np.eye(dim) - 1 / dim * np.ones((dim, dim))
            covariance = sentence_embeddings @ self.J_d @ sentence_embeddings.T
            reg_covariance = covariance + self.alpha * np.eye(covariance.shape[0])
            # The matrix is symmetric, so use the symmetric eigenvalue routine:
            # it always returns real eigenvalues (the general solver may yield a
            # complex dtype from round-off) and skips the unused eigenvectors.
            eigenvalues = np.linalg.eigvalsh(reg_covariance)
            # Non-positive eigenvalues can only arise from round-off; replace
            # them with a small positive constant so the logarithm is defined.
            clipped = np.where(eigenvalues > 0, eigenvalues, 1e-10)
            ue.append(np.mean(np.log(clipped)))
        return np.array(ue)
