from sklearn.decomposition import PCA
from .base_uncertainty import LLMUncertaintyEstimator
from typing import Dict, List
import torch
import numpy as np


class EigenScore(LLMUncertaintyEstimator):
    """Uncertainty estimator based on the PCA spectrum of a set of embeddings.

    The score is the mean of the log eigenvalues obtained by fitting a PCA on
    the (value-clamped) embedding matrix.
    """

    def __init__(
        self,
        hmin: float = -5.,
        hmax: float = 5.,
        use_eig_ratios: bool = True,
        gamma: float = 1e-10,
    ):
        """Store the scoring hyper-parameters.

        Args:
            hmin: Lower bound every embedding value is clamped to.
            hmax: Upper bound every embedding value is clamped to.
            use_eig_ratios: If True, score the explained-variance ratios,
                otherwise the raw explained variances.
            gamma: Floor applied to the eigenvalues before taking the log.
        """
        # NOTE(review): the base-class initializer is not called here; confirm
        # that LLMUncertaintyEstimator needs no setup of its own.
        self.gamma = gamma
        self.use_eig_ratios = use_eig_ratios
        self.hmin = hmin
        self.hmax = hmax

    def compute_uncertainty(
        self,
        hidden_at_xy
    ) -> Dict:
        """Compute the eigenscore of ``hidden_at_xy``.

        Args:
            hidden_at_xy: Tensor of embeddings, forwarded unchanged to
                ``_compute_eigenscore`` (which expects
                ``[n_sample, embedding_size]``).

        Returns:
            A dict with the single key ``'EigenScore'``.
        """
        # Forward every configured hyper-parameter, including gamma, so that
        # the value given to the constructor is the one actually used.
        es = EigenScore._compute_eigenscore(
            hidden_at_xy,
            hmin=self.hmin,
            hmax=self.hmax,
            use_eig_ratios=self.use_eig_ratios,
            gamma=self.gamma,
        )
        return {'EigenScore': es}

    @property
    def get_required_fields(self) -> List[str]:
        """Names of the fields ``compute_uncertainty`` takes as arguments."""
        return ['hidden_at_xy']

    def prepare_records(self, records) -> List:
        """Select the ``'ms'`` entry of ``records``.

        The original author's note mentions choosing between ``ms`` and ``bs``
        (or combining them); only ``ms`` is used.
        """
        # NOTE(review): the meaning of 'ms' / 'bs' is not visible in this
        # file -- confirm against the code that builds `records`.
        return records['ms']

    @staticmethod
    def _compute_eigenscore(
        embs,  # [n_sample, embedding_size]
        hmin = -5.,
        hmax = 5.,
        use_eig_ratios = True,
        gamma = 1e-10,
    ):
        """Return the mean log PCA eigenvalue of the clamped embeddings.

        Args:
            embs: Torch tensor of embeddings, ``[n_sample, embedding_size]``.
            hmin: Lower clamp bound applied to every value of ``embs``.
            hmax: Upper clamp bound applied to every value of ``embs``.
            use_eig_ratios: If True, use ``explained_variance_ratio_``,
                otherwise ``explained_variance_``.
            gamma: Floor applied to the eigenvalues so the log stays finite.

        Returns:
            The mean of the log eigenvalues (a numpy float scalar).
        """
        clamped_embs = torch.clamp(embs, max=hmax, min=hmin)

        # Eigen decomposition of the embedding vectors via PCA. detach()/cpu()
        # make the numpy conversion work for tensors that track gradients or
        # live on an accelerator; for plain CPU tensors nothing changes.
        pca = PCA().fit(clamped_embs.detach().cpu().numpy())
        if use_eig_ratios:
            eigs = pca.explained_variance_ratio_
        else:
            eigs = pca.explained_variance_

        # The ratios can be NaN when redundant (identical) samples are
        # provided, which is common in QA datasets. NaNs are replaced by 0 and
        # are then raised to `gamma` by the clip below, so such inputs score
        # log(gamma).
        # NOTE(review): this call used to read `np.nan_to_num(eigs, 1.)` with a
        # comment saying NaNs should become 1. The second positional argument
        # of nan_to_num is `copy`, not `nan`, so the replacement has always
        # been 0. The keyword form below keeps that behaviour; confirm which
        # value is actually intended before changing it.
        # TODO: would behave poorly if used without eig ratios
        eigs = np.nan_to_num(eigs, nan=0.0)

        # Floor the eigenvalues at gamma so the log is finite, then average.
        eigs = np.clip(eigs, a_min=gamma, a_max=None)
        eigenscore = np.log(eigs).mean()

        return eigenscore
