from typing import Dict, List, Tuple, Union

import torch
from tqdm import tqdm
from sentence_transformers import CrossEncoder

from .base_uncertainty import LLMUncertaintyEstimator
from .similarity import compute_similarity_matrix_nli
from .clustering import compute_seq_id_clusters
from .common import compute_clustered_entropy, compute_ll_from_transition_scores


class SemanticEntropy(LLMUncertaintyEstimator):
    """Uncertainty estimator that scores a set of generations by clustered entropy.

    The generations are compared pairwise with an NLI cross-encoder, grouped
    into clusters from the resulting similarity matrix, and the entropy over
    those clusters is reported under the key ``'sement'`` (or
    ``'len_norm_sement'`` when ``length_normalize`` is enabled).
    """

    def __init__(
        self,
        nli_model="cross-encoder/nli-deberta-v3-base",
        use_device='cpu',
        return_cache=False,
        length_normalize=False,
        **nli_model_kwargs,
    ):
        """Load the NLI cross-encoder and store the estimator options.

        Args:
            nli_model: Model name or path handed to ``CrossEncoder``.
            use_device: Device passed to ``CrossEncoder`` as ``device``.
            return_cache: When true, ``compute_uncertainty`` returns a
                ``(metrics, cache)`` tuple instead of just ``metrics``.
            length_normalize: Forwarded to
                ``compute_ll_from_transition_scores`` and also selects the
                name of the reported metric.
            **nli_model_kwargs: Extra keyword arguments forwarded to
                ``CrossEncoder``. ``num_labels`` and ``device`` are always
                passed explicitly, so supplying either here raises
                ``TypeError`` (duplicate keyword argument).
        """
        super().__init__()
        self.return_cache = return_cache
        self.model = CrossEncoder(nli_model, **nli_model_kwargs, num_labels=3, device=use_device)
        self.length_normalize = length_normalize

    def compute_uncertainty(
        self,
        txt_y,
        sequences_len,
        seq_no_input_transition,
        norm_log_scores=None,
        kuhn_sem_sim=None,
        txt_x_no_prompt=None,
        txt_x=None,
    ) -> Union[Dict, Tuple[Dict, Dict]]:
        """Compute the clustered entropy of a set of generations.

        Args:
            txt_y: Generated texts handed to ``compute_similarity_matrix_nli``
                (presumably one string per sampled sequence -- TODO confirm).
            sequences_len: Passed with ``seq_no_input_transition`` to
                ``compute_ll_from_transition_scores``; only used when
                ``norm_log_scores`` is not supplied.
            seq_no_input_transition: Transition scores of the generated part
                of each sequence (assumed from the name -- verify against the
                caller); only used when ``norm_log_scores`` is not supplied.
            norm_log_scores: Optional precomputed normalized log scores. When
                ``None`` they are derived from the transition scores.
            kuhn_sem_sim: Optional precomputed similarity matrix. When
                ``None`` it is computed with the NLI model.
            txt_x_no_prompt: Context given to the NLI similarity computation.
                Falls back to ``txt_x`` and then to the empty string.
            txt_x: Fallback context used when ``txt_x_no_prompt`` is ``None``.

        Returns:
            ``{metric_name: value}`` where ``metric_name`` is
            ``'len_norm_sement'`` if ``self.length_normalize`` else
            ``'sement'``. When ``self.return_cache`` is true, a tuple
            ``(metrics, cache)`` is returned instead; ``cache`` contains only
            the intermediates computed in this call (``'norm_log_scores'``
            and/or ``'kuhn_sem_sim'``), so it is empty when both were given.
        """
        # Resolve the NLI context: explicit value, then txt_x, then ''.
        if txt_x_no_prompt is None:
            txt_x_no_prompt = '' if txt_x is None else txt_x

        retcache = {}

        # Sequence scores: only derived when the caller did not provide them.
        if norm_log_scores is None:
            seq_lls = compute_ll_from_transition_scores(
                seq_no_input_transition,
                sequences_len,
                length_normalize=self.length_normalize,
            )
            # log-softmax over the last dim so the scores are normalized
            # across the provided sequences.
            norm_log_scores = torch.log_softmax(seq_lls, -1)
            retcache['norm_log_scores'] = norm_log_scores

        # Similarity matrix: only computed when the caller did not provide it.
        if kuhn_sem_sim is None:
            kuhn_sem_sim = compute_similarity_matrix_nli(
                txt_y,
                model=self.model,
                context=txt_x_no_prompt,
            )
            retcache['kuhn_sem_sim'] = kuhn_sem_sim

        # Group sequences into clusters based on the similarity matrix.
        clusters = compute_seq_id_clusters(kuhn_sem_sim)

        # Entropy over the clusters; .item() extracts a plain Python number.
        sent = compute_clustered_entropy(clusters, norm_log_scores).item()

        metric_name = 'len_norm_sement' if self.length_normalize else 'sement'
        metrics = {metric_name: sent}
        if self.return_cache:
            return metrics, retcache
        return metrics

    @property
    def get_required_fields(self) -> List[str]:
        """Names of the record fields consumed by ``compute_uncertainty``."""
        return ['txt_y', 'sequences_len', 'seq_no_input_transition', 'norm_log_scores', 'txt_x_no_prompt']

    def prepare_records(self, records) -> List:
        """Return the ``'ms'`` entry of ``records``.

        NOTE(review): the original comment mentions choosing between ``ms``
        and ``bs`` (presumably two generation strategies -- TODO confirm);
        only ``'ms'`` is used here.
        """
        return records['ms']
