import torch

from .base_uncertainty import LLMUncertaintyEstimator
from .common import compute_ll_from_transition_scores
from typing import Dict, List


class PerplexityUncertainty(LLMUncertaintyEstimator):
    """Sequence-level uncertainty from the negated log-likelihood of a generation.

    ``mode`` selects the reduction and is also used as the key of the returned dict:

    * ``'log_pplx'`` -- negated log-likelihood with ``length_normalize=True``
      (i.e. log-perplexity).
    * ``'GNLL'`` -- negated log-likelihood with ``length_normalize=False``.

    Any other mode raises ``NotImplementedError`` when ``compute_uncertainty`` is called.
    """

    def __init__(
        self,
        mode='log_pplx'  # one of: 'log_pplx', 'GNLL'
    ):
        super().__init__()
        self.mode = mode

    def compute_uncertainty(
        self,
        sequences_len,
        seq_no_input_transition,
    ) -> Dict:
        """Return ``{self.mode: -log_likelihood}`` for one generation.

        Args:
            sequences_len: generation length(s), forwarded unchanged to
                ``compute_ll_from_transition_scores``.
            seq_no_input_transition: per-token transition scores of the generation,
                forwarded unchanged to ``compute_ll_from_transition_scores``.

        Raises:
            NotImplementedError: if ``self.mode`` is neither ``'log_pplx'`` nor ``'GNLL'``.
        """
        # The two modes differ only in whether the log-likelihood is length-normalized.
        if self.mode == 'log_pplx':
            length_normalize = True
        elif self.mode == 'GNLL':
            length_normalize = False
        else:
            raise NotImplementedError()

        # NOTE(review): assumes the helper returns a single-element tensor, since
        # .item() is called on its negation -- confirm against common.py.
        log_likelihood = compute_ll_from_transition_scores(
            seq_no_input_transition,
            sequences_len,
            length_normalize=length_normalize,
        )
        negated = -log_likelihood
        return {self.mode: negated.item()}

    @property
    def get_required_fields(self) -> List[str]:
        # Record fields that compute_uncertainty consumes, in argument order.
        return ['sequences_len', 'seq_no_input_transition']

    def prepare_records(self, records) -> List:
        # Always scores the 'bs' record set; the 'ms' set is never used here.
        # presumably 'bs'/'ms' are two decoding variants -- verify against caller.
        bs_records = records['bs']
        return bs_records


class PickLogitUncertainty(LLMUncertaintyEstimator):
    """Uncertainty from the single least likely token of a generation.

    Reports the negated minimum of the finite per-token scores in
    ``seq_no_input_transition``. If the scores are log-probabilities, this is the
    NLL of the least likely generated token. The result is always stored under
    the key ``'min_logprob'``.
    """

    def __init__(
        self,
        mode='min_logprob'  # only 'min_logprob' exists; stored but not used by compute_uncertainty
    ):
        super(PickLogitUncertainty, self).__init__()
        self.mode = mode

    def compute_uncertainty(
        self,
        seq_no_input_transition,
    ) -> Dict:
        """Return ``{'min_logprob': -min(finite scores)}``.

        Args:
            seq_no_input_transition: tensor of per-token scores for the generation.

        Returns:
            A dict with a single ``'min_logprob'`` entry. The value is NaN when no
            finite score remains, i.e. the tensor is empty or every entry is infinite.
        """
        # Drop infinite scores so a -inf is never picked as the minimum.
        finite_scores = seq_no_input_transition[~torch.isinf(seq_no_input_transition)]
        if finite_scores.numel() == 0:
            # min() raises on an empty tensor; report NaN for this case explicitly
            # instead of swallowing every possible exception.
            return {"min_logprob": torch.nan}
        # Negated score of the least likely token in the generation.
        return {"min_logprob": -finite_scores.min().item()}

    @property
    def get_required_fields(self) -> List[str]:
        # Record fields that compute_uncertainty consumes.
        return ['seq_no_input_transition']

    def prepare_records(self, records) -> List:
        # Always scores the 'bs' record set; the 'ms' set is not used here.
        return records['bs']


class GenerationLengthUncertainty(LLMUncertaintyEstimator):
    """Uses the mean generation length as an uncertainty signal.

    ``mode`` both picks the record set and names the output key
    (``'seqlen_<mode>'``):

    * ``'correct'`` reads ``records['bs']``;
    * any other value (documented: ``'sample'``) reads ``records['ms']``.
    """

    def __init__(
        self,
        mode='correct'  # documented values: 'sample', 'correct'
    ):
        super().__init__()
        self.mode = mode

    def compute_uncertainty(
        self,
        sequences_len,
    ) -> Dict:
        """Return ``{'seqlen_<mode>': mean of sequences_len along dim 0}``.

        ``.item()`` requires the mean to contain a single element, so this
        assumes ``sequences_len`` is effectively 1-D -- TODO confirm with callers.
        """
        # Cast to float first: mean() is not defined for integer tensors.
        lengths = sequences_len.float()
        mean_length = lengths.mean(dim=0)
        return {f'seqlen_{self.mode}': mean_length.item()}

    @property
    def get_required_fields(self) -> List[str]:
        # Record fields that compute_uncertainty consumes.
        return ['sequences_len']

    def prepare_records(self, records) -> List:
        # 'correct' scores the 'bs' records; every other mode falls through to 'ms'.
        source_key = 'bs' if self.mode == 'correct' else 'ms'
        return records[source_key]
