import torch
import re
from transformers import AutoTokenizer, AutoModelForCausalLM
from generation.gen_utils import probe_token_id
from tqdm import tqdm
from typing import Dict, List
from copy import deepcopy

from generation.prompts import get_prompt_builder_by_model_id

# Chat turns appended after the model's prior interaction to ask it to judge
# its own answer (P(True)). They are rendered by the prompt builder's
# ``build_prompt``; the scored quantity is the next-token logit of "yes"/"no".
ptrue_query_message = [
    dict(
        role="user",
        message=(
            "Within the context of the question, would you assess that the "
            "proposed answer is correct? Respond only with yes or no."
        ),
    ),
    dict(role="assistant", message="Response:"),
]


# Reversed form of a trailing special token such as ``<|eot_id|>``: matching
# ``>|...|<`` at the start of the reversed text finds ``<|...|>`` at the end of
# the original text. The lazy ``.*?`` stops each match at the nearest ``<|``.
_REVERSED_SPECIAL_TOKEN = re.compile(r'>\|.*?\|<')


def digest_txt_xy(s):
    """Strip all trailing ``<|...|>``-style special tokens from ``s``.

    Tokens are removed repeatedly from the end (``"hi<|a|><|b|>"`` -> ``"hi"``).
    Text that does not end in such a token is returned unchanged. Because
    ``.`` does not match newlines, a token that spans a newline is not removed.

    Args:
        s: Text to clean.

    Returns:
        ``s`` with its trailing special tokens removed.
    """
    reversed_s = s[::-1]
    stripped = 0
    while True:
        # Anchored match at the current offset. This is equivalent to searching
        # the freshly stripped string and requiring the match to start at 0,
        # without re-reversing or rescanning the whole text every round.
        match = _REVERSED_SPECIAL_TOKEN.match(reversed_s, stripped)
        if match is None:
            break
        stripped = match.end()
    return s[:len(s) - stripped]


from .base_uncertainty import LLMUncertaintyEstimator
from transformers import AutoTokenizer, AutoModelForCausalLM

class PTrueEstimator(LLMUncertaintyEstimator):
    """P(True) self-evaluation uncertainty estimator.

    Re-feeds the model its own prior interaction followed by
    ``ptrue_query_message`` and reads the next-token logits for "yes" and
    "no". The tokenizer and model are loaded lazily and cached per model id.
    The prompt builder is cached per (model id, dataset).
    """

    def __init__(
        self,
        model_kwargs: "dict | None" = None,
        tokenizer_kwargs: "dict | None" = None
    ):
        """
        Args:
            model_kwargs: Extra kwargs for ``AutoModelForCausalLM.from_pretrained``.
                They are deep-copied per load, and some keys are overridden
                depending on the model id.
            tokenizer_kwargs: Extra kwargs for ``AutoTokenizer.from_pretrained``.
        """
        self.model_id = None
        self.model = None
        self.tokenizer = None
        self.prompt_builder = None
        # Dataset the current prompt_builder was built for.
        self._prompt_dataset = None

        # `None` defaults avoid one mutable `{}` being shared by every instance.
        self.model_kwargs = {} if model_kwargs is None else model_kwargs
        self.tokenizer_kwargs = {} if tokenizer_kwargs is None else tokenizer_kwargs

    def _load_model(self, model_id):
        """Load tokenizer + model for ``model_id`` and make them current."""
        # Release the previous model before allocating the new one, so both
        # are not held at once. Clearing model_id means a failed load is
        # retried on the next call instead of silently reusing stale weights.
        self.model = None
        self.tokenizer = None
        self.prompt_builder = None
        self.model_id = None

        tokenizer = AutoTokenizer.from_pretrained(model_id, **self.tokenizer_kwargs)
        temp_kwargs = deepcopy(self.model_kwargs)
        # NOTE(review): substring heuristics. 'amba' presumably targets
        # (Ma|Ja)mba-style models (no flash-attention override) and '70'
        # presumably 70B checkpoints (4-bit load); confirm these still match
        # the intended model ids.
        if 'amba' not in model_id:
            if '70' in model_id:
                temp_kwargs['load_in_4bit'] = True
                temp_kwargs['torch_dtype'] = "auto"
            temp_kwargs['attn_implementation'] = "flash_attention_2"
        print(temp_kwargs)
        loaded_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            **temp_kwargs,
        )

        self.tokenizer = tokenizer
        self.model = loaded_model
        self.model_id = model_id

    @torch.no_grad()
    def compute_uncertainty(self, model, dataset, txt_xy_full, **kwargs) -> Dict:
        """Score P(True) for the first interaction in ``txt_xy_full``.

        Args:
            model: Model id/path passed to ``from_pretrained``. The model is
                reloaded only when it differs from the cached one.
            dataset: Passed to ``get_prompt_builder_by_model_id``. The prompt
                builder is rebuilt whenever this (or the model) changes.
                Assumes it is a comparable identifier such as a name string;
                TODO confirm.
            txt_xy_full: Sequence of full interaction strings. Only element 0
                is used.
            **kwargs: Ignored.

        Returns:
            Dict with:
            - ``ptrue_pos_log_prob`` / ``ptrue_neg_log_prob``: log-probabilities
              renormalized over just {yes, no}.
            - ``ptrue_raw_pos_log_prob`` / ``ptrue_raw_neg_log_prob``: the
              unnormalized next-token logits, despite the key names.
        """
        if self.model_id != model:
            self._load_model(model)
        if self.prompt_builder is None or self._prompt_dataset != dataset:
            self.prompt_builder = get_prompt_builder_by_model_id(self.model_id, dataset=dataset)
            self._prompt_dataset = dataset

        # Drop trailing special tokens so the query continues the prior turn.
        prior_interaction = digest_txt_xy(txt_xy_full[0])
        text_in = self.prompt_builder.build_prompt(ptrue_query_message, starting_string=prior_interaction)
        device = next(self.model.parameters()).device
        llm_in = self.tokenizer([text_in], add_special_tokens=False, return_tensors='pt').to(device)

        # Optionally expand to more surface forms; each list is reduced with
        # logsumexp below, so multiple ids per answer are supported.
        toklist_yes = ['yes']
        toklist_no = ['no']
        idlist_yes = list({probe_token_id(t, self.tokenizer) for t in toklist_yes})
        idlist_no = list({probe_token_id(t, self.tokenizer) for t in toklist_no})

        # Logits for the token immediately following the prompt.
        logits = self.model(**llm_in).logits[0, -1, :]

        # logsumexp equals the logit itself for a single id, and unlike
        # `.item()` on the indexed tensor it stays valid for several ids.
        raw_yes_logits = torch.logsumexp(logits[idlist_yes], dim=0).item()
        raw_no_logits = torch.logsumexp(logits[idlist_no], dim=0).item()

        # Renormalize over the two-way {yes, no} choice.
        fin_logits = torch.log_softmax(torch.tensor([raw_yes_logits, raw_no_logits]), dim=0)
        yes_logits = fin_logits[0].item()
        no_logits = fin_logits[1].item()

        return {
            'ptrue_pos_log_prob': yes_logits,
            'ptrue_neg_log_prob': no_logits,
            'ptrue_raw_pos_log_prob': raw_yes_logits,
            'ptrue_raw_neg_log_prob': raw_no_logits,
        }

    @property
    def get_required_fields(self) -> List[str]:
        """Record fields consumed as keyword arguments by ``compute_uncertainty``."""
        return ['dataset', 'model', 'txt_xy_full']

    def prepare_records(self, records) -> List:
        """Select the ``'bs'`` entry of ``records``.

        NOTE(review): the meaning of 'bs' vs 'ms' is not visible here
        (presumably two decoding variants). Confirm against the producer.
        """
        return records['bs']
