import torch
from tqdm import tqdm

from .base_uncertainty import LLMUncertaintyEstimator
from .similarity import compute_seq_id_similarity
from .clustering import compute_seq_id_clusters
from .common import compute_clustered_entropy, compute_ll_from_transition_scores
from typing import Dict, List


class PredictiveEntropy(LLMUncertaintyEstimator):
    """Estimate uncertainty as the entropy over clusters of generated sequences.

    The sequences are grouped with ``compute_seq_id_clusters`` using the matrix
    produced by ``compute_seq_id_similarity``, and the entropy is computed by
    ``compute_clustered_entropy`` from normalized sequence log scores.

    Args:
        length_normalize: Forwarded to ``compute_ll_from_transition_scores``
            when sequence scores have to be computed here; also selects the
            name under which the metric is reported.
    """

    def __init__(
        self,
        length_normalize: bool = False,
    ):
        super().__init__()
        self.length_normalize = length_normalize

    def compute_uncertainty(
        self,
        txt_y,
        sequences_len,
        seq_no_input_transition,
        norm_log_scores=None,
        id_sem_sim=None,
    ) -> Dict:
        """Compute the predictive entropy for one set of generated sequences.

        Args:
            txt_y: Generated outputs handed to ``compute_seq_id_similarity``
                (presumably the decoded texts -- TODO confirm against caller).
            sequences_len: Per-sequence lengths handed to
                ``compute_ll_from_transition_scores``.
            seq_no_input_transition: Transition scores of the generated part
                of each sequence (presumably prompt tokens excluded, going by
                the name -- TODO confirm).
            norm_log_scores: Optional precomputed normalized log scores. When
                given, ``sequences_len`` and ``seq_no_input_transition`` are
                not used.
            id_sem_sim: Optional precomputed similarity matrix. When given,
                ``txt_y`` is not used.

        Returns:
            A single-entry dict mapping the metric name to the entropy as a
            Python scalar: ``'len_norm_predent'`` when ``length_normalize`` is
            set, ``'predent'`` otherwise.
        """
        # Sequence scores: only computed when the caller did not supply them.
        if norm_log_scores is None:
            seq_lls = compute_ll_from_transition_scores(
                seq_no_input_transition,
                sequences_len,
                length_normalize=self.length_normalize,
            )
            # log-softmax over the last dim so the scores form a normalized
            # log-distribution over the available sequences.
            norm_log_scores = torch.log_softmax(seq_lls, -1)

        # Similarity matrix: only computed when the caller did not supply it.
        if id_sem_sim is None:
            id_sem_sim = compute_seq_id_similarity(txt_y)

        clusters = compute_seq_id_clusters(id_sem_sim)
        pent = compute_clustered_entropy(clusters, norm_log_scores).item()

        # The key must reflect the configuration: the length-normalized
        # variant is reported under the 'len_norm_' prefixed name.
        metric_name = 'len_norm_predent' if self.length_normalize else 'predent'
        return {metric_name: pent}

    @property
    def get_required_fields(self) -> List[str]:
        """Names of the record fields ``compute_uncertainty`` needs as input."""
        return ['txt_y', 'sequences_len', 'seq_no_input_transition']

    def prepare_records(self, records) -> List:
        """Select the records this estimator works on.

        Only the ``'ms'`` entry of ``records`` is used; a ``'bs'`` entry or a
        combination of both is deliberately not considered here.
        """
        return records['ms']
