import torch
from tqdm import tqdm

from .base_uncertainty import LLMUncertaintyEstimator
# from .similarity import compute_similarity_matrix_nli
# from .clustering import compute_seq_id_clusters
from .common import compute_ll_from_transition_scores
from .ptrue import digest_txt_xy
from typing import Dict, List
from sentence_transformers import CrossEncoder, SentenceTransformer
import itertools
from transformers import AutoTokenizer


def itertools_batched(iterable, bs, strict=False):
    """Yield consecutive lists of at most ``bs`` items taken from ``iterable``.

    Stand-in for ``itertools.batched`` that yields lists instead of tuples.

    Args:
        iterable: any iterable; it is consumed lazily, one item at a time.
        bs: batch size. A batch is emitted as soon as it holds exactly ``bs``
            items. If ``bs`` is smaller than 1 that never happens, so all
            items end up in a single trailing batch.
        strict: when true, raise ``ValueError`` instead of yielding a final
            batch that holds fewer than ``bs`` items.

    Yields:
        Lists of items in their original order; only the last list may be
        shorter than ``bs`` (and only when ``strict`` is false).

    Raises:
        ValueError: if ``strict`` is true and the last batch is incomplete.
    """
    batch = []
    for item in iterable:
        batch.append(item)
        if len(batch) == bs:
            yield batch
            batch = []
    if batch:
        # leftovers that did not fill a whole batch
        if strict:
            raise ValueError('itertools_batched(): incomplete batch')
        yield batch


def sentran_compute_cross_encoder_like(
    model: "SentenceTransformer",
    inputs
):
    """Score text pairs with a bi-encoder, mimicking a cross-encoder call.

    The left and right members of every pair are encoded separately with
    ``model.encode`` and then compared element-wise with
    ``model.similarity_pairwise``.

    Args:
        model: object exposing ``encode`` and ``similarity_pairwise``
            (a ``SentenceTransformer`` in this module).
        inputs: sequence of ``(text_a, text_b)`` pairs.

    Returns:
        Whatever ``model.similarity_pairwise`` returns for the two batches of
        encodings: one score per input pair.

    Raises:
        Exception: any error raised while splitting or encoding the pairs is
            re-raised unchanged after the offending inputs have been printed.
            An empty ``inputs`` fails with ``ValueError`` because it cannot be
            split into two columns.
    """
    try:
        # transpose [(a0, b0), (a1, b1), ...] into two lists of texts
        ins_a, ins_b = (list(column) for column in zip(*inputs))
        enc_a = model.encode(ins_a)
        enc_b = model.encode(ins_b)
    except Exception as err:
        # dump the offending batch to make encoder failures debuggable
        print(inputs)
        try:
            print(len(inputs[0]), len(inputs))
        except (IndexError, KeyError, TypeError):
            # the size report is best-effort: it must never replace the
            # original error (e.g. when `inputs` is empty or not indexable)
            pass
        print(err)

        raise

    return model.similarity_pairwise(enc_a, enc_b)


class ShiftingAttentionToRelevance(LLMUncertaintyEstimator):
    """Uncertainty estimator based on Shifting Attention to Relevance (SAR).

    Token relevance is measured leave-one-token-out: each generated token is
    removed in turn and the semantic similarity between the full and the
    reduced text is scored. Sentence relevance is the pairwise similarity
    between the generated sequences. Both are combined with the transition
    scores of the generations into token-level, sentence-level and combined
    SAR values.
    """

    def __init__(
        self,
        nli_model='cross-encoder/ms-marco-MiniLM-L-6-v2', #"allenai/longformer-base-4096",
        use_device='cpu',
        return_cache = False,
        compute_token_sar = True,
        compute_sentence_sar = True,
        cross_encode_batch = 128,
        limit_token_number = 500,
        **nli_model_kwargs,
    ):
        """Load the similarity model and store the configuration.

        Args:
            nli_model: model name or path handed to ``SentenceTransformer``.
            use_device: device the similarity model is loaded on.
            return_cache: when true, ``compute_uncertainty`` also returns the
                scored text pairs it built.
            compute_token_sar: stored on the instance; not consulted by
                ``compute_uncertainty`` in this class.
            compute_sentence_sar: stored on the instance; not consulted by
                ``compute_uncertainty`` in this class.
            cross_encode_batch: number of text pairs scored per model call.
            limit_token_number: accepted for interface compatibility; unused
                in this class.
            **nli_model_kwargs: forwarded to ``SentenceTransformer``.
        """
        super().__init__()

        # basic configuration
        self.return_cache = return_cache
        self.use_device = use_device
        self.ce_batch_size = cross_encode_batch
        self.sar_temperature = 1.

        self.compute_token_sar = compute_token_sar
        self.compute_sentence_sar = compute_sentence_sar

        # the sentence encoder
        # self.sem_model = CrossEncoder(nli_model, **nli_model_kwargs, device=use_device)
        self.sem_model = SentenceTransformer(nli_model, **nli_model_kwargs, device=use_device)

        # for tokenizing / detokenizing the original input;
        # loaded lazily on the first compute_uncertainty call
        self.model_id = None
        self.tokenizer = None

    @torch.no_grad()
    def compute_uncertainty(
        self,
        txt_y,
        txt_x_no_prompt,
        sequences_len,
        sequences_no_input,
        seq_no_input_transition,
        model,
    ) -> Dict:
        """Compute token-level, sentence-level and combined SAR.

        Args:
            txt_y: unused by this estimator.
            txt_x_no_prompt: input text; it is prepended to every generation
                before token relevances are scored.
            sequences_len: per-sequence number of generated tokens; indexed
                like a 1-D tensor (``.shape[0]`` is the number of sequences).
            sequences_no_input: generated token ids, one row per sequence,
                rows longer than ``sequences_len[i]`` are treated as padding.
            seq_no_input_transition: transition scores aligned with
                ``sequences_no_input`` -- presumably token log-probabilities,
                TODO confirm against the caller.
            model: identifier handed to ``AutoTokenizer.from_pretrained``.

        Returns:
            Dict with ``TOKEN_SAR``, ``SENT_SAR``, ``SAR`` (floats) and
            ``_sar_temp``; with ``return_cache`` also ``_sentsar_pairs`` and
            ``_toksar_pairs`` holding the scored text pairs.
        """
        if self.model_id != model:
            self.model_id = model
            # get the right tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

        # create the leave-one-token-out (LOTO) grid
        loto_grid_y = []
        sentsar_grid_y = []
        for i, seqlen in enumerate(sequences_len):
            # for a given sequence create strings
            sel_tok_seq = sequences_no_input[i, :seqlen].tolist()
            full_text = digest_txt_xy(self.tokenizer.decode(sel_tok_seq))
            sentsar_grid_y.append(full_text)

            for j in range(seqlen):
                # the generation with token j dropped; the decoded text still
                # contains special tokens, digest_txt_xy removes them
                reduced_text = digest_txt_xy(
                    self.tokenizer.decode(sel_tok_seq[:j] + sel_tok_seq[j + 1:])
                )
                # dropping a special token leaves the text unchanged, so it
                # ends up with similarity 1, i.e. relevance 0
                loto_grid_y.append({'text_o': full_text, 'text_m': reduced_text, 'seq': i, 'tok_position': j})

        # now run the similarity model against the reference sequences.
        # Similarities start at 1 so that positions which never receive a
        # score (padding beyond sequences_len) get relevance 1 - 1 = 0 and
        # do not leak into the per-sequence normalisation below.
        tok_similarities = torch.ones_like(sequences_no_input, dtype=torch.float)
        for tok_seqs in itertools_batched(loto_grid_y, self.ce_batch_size):
            # scores = self.sem_model.predict([(txt_x_no_prompt+' '+seq['text_o'], txt_x_no_prompt+' '+seq['text_m']) for seq in tok_seqs])
            scores = sentran_compute_cross_encoder_like(
                self.sem_model,
                [(txt_x_no_prompt+' '+seq['text_o'], txt_x_no_prompt+' '+seq['text_m']) for seq in tok_seqs],
            )
            for k, entry in enumerate(tok_seqs):
                entry['score'] = scores[k].item()
                # write into the similarity array
                tok_similarities[entry['seq'], entry['tok_position']] = scores[k]
        tok_relevances = 1 - tok_similarities

        # congrats, we got the token relevances!
        # now compute the sentence relevance on all ordered pairs (i, j)
        sentsar_pairs = [
            {'si': si, 'sj': sj, 'i': i, 'j': j}
            for i, si in enumerate(sentsar_grid_y)
            for j, sj in enumerate(sentsar_grid_y)
        ]

        n_seqs = sequences_len.shape[0]
        sent_similarities = torch.zeros((n_seqs, n_seqs))
        for sents in itertools_batched(sentsar_pairs, self.ce_batch_size):
            # scores = self.sem_model.predict([(s['si'], s['sj']) for s in sents])
            scores = sentran_compute_cross_encoder_like(self.sem_model, [(s['si'], s['sj']) for s in sents])
            for k, pair in enumerate(sents):
                pair['score'] = scores[k]
                sent_similarities[pair['i'], pair['j']] = scores[k]

        # now we have sentence relevances
        # compute all types of SAR
        # token SAR first: weight each transition score by the normalised
        # relevance of its token
        rel_total = tok_relevances.sum(-1, keepdim=True)
        # a sequence whose relevances sum to exactly 0 would divide by zero;
        # keep its (all-zero) weights instead of producing NaNs
        rel_total = torch.where(rel_total == 0, torch.ones_like(rel_total), rel_total)
        norm_tok_rel = tok_relevances / rel_total
        Et = - seq_no_input_transition*norm_tok_rel
        token_sar = compute_ll_from_transition_scores(Et, batch_lens=sequences_len, length_normalize=False)
        mean_token_sar = token_sar.mean()

        # sentence SAR computation
        # NOTE(review): despite the name, the value is exponentiated directly
        # below, so it is presumably a log-likelihood -- confirm in
        # compute_ll_from_transition_scores.
        raw_sentence_nlls = compute_ll_from_transition_scores(seq_no_input_transition, sequences_len, length_normalize=True)
        # a sequence must not contribute relevance to itself
        sent_similarities = sent_similarities*(1-torch.eye(sent_similarities.shape[0]))
        R_S = sent_similarities@raw_sentence_nlls.exp()
        E_S = - torch.log(raw_sentence_nlls.exp() + R_S/self.sar_temperature)
        sent_sar = E_S.mean()

        # compute the combined SAR: same shifting as above, with the
        # token-SAR based probability exp(-E_T) in place of the raw one.
        # exp_ts is already exponentiated and must not be exp()-ed again.
        exp_ts = (-token_sar).exp()
        R_S = sent_similarities@exp_ts
        E_S = - torch.log(exp_ts + R_S/self.sar_temperature)
        all_sar = E_S.mean()

        retdict = {
            'TOKEN_SAR': mean_token_sar.item(),
            'SENT_SAR': sent_sar.item(),
            'SAR': all_sar.item(),
            '_sar_temp': self.sar_temperature,
        }

        if self.return_cache:
            retdict['_sentsar_pairs'] = sentsar_pairs
            retdict['_toksar_pairs'] = loto_grid_y

        return retdict

    @property
    def get_required_fields(self) -> List[str]:
        """Names of the record fields ``compute_uncertainty`` takes as arguments."""
        return [
            'txt_y',
            'txt_x_no_prompt',
            'sequences_len',
            'sequences_no_input',
            'seq_no_input_transition',
            'model',
        ]

    def prepare_records(self, records) -> List:
        """Select the ``'ms'`` entry of ``records`` as input for this estimator."""
        # select ms or bs or combine somehow (unlikely)
        return records['ms']
