from transformers import LlamaForCausalLM, LlamaTokenizer
import math
import torch
from tqdm import tqdm

def construct_p_true_prompt(questions, responses):
    """Build one few-shot "is this answer true?" prompt per (question, response) pair.

    Each prompt is six worked examples followed by the query, all rendered
    with the same template. The prompt ends at ``'The possible answer is:'``
    so that the model's next word is its verdict. The worked examples answer
    with the words ``True`` / ``False``, not the option letters ``A`` / ``B``.

    Args:
        questions: Iterable of question strings.
        responses: Iterable of candidate-answer strings, paired with
            ``questions`` position by position. As with ``zip``, surplus
            items in the longer iterable are ignored.

    Returns:
        A list of prompt strings, one per pair.
    """

    def render(question, answer):
        # Shared layout for both the worked examples and the final query.
        return ''.join([
            'Question: ', question, '\n',
            'Possible answer: ', answer, '\n',
            'Is the possible answer:\n',
            'A) True\n',
            'B) False\n',
            'The possible answer is:',
        ])

    worked_examples = [
        ("Who was the third president of the United States?", "James Monroe", "False"),
        ("Calculate 33 + 4", "37", "True"),
        ("Fill in the blank in the sentence 'I went to the grocery and then to the pharmacy. I was disappointed that they didn't have any vegetarian sausage at the _____.'", "grocery", "True"),
        ("Name a celebrated civil rights leader.", "Martin Luther King", "True"),
        ("Calculate 33 * 849", "28347", "False"),
        ("Fill in the blank in the sentence 'I shot the _____ and it went swish. We walked away the winners of that battle!'", "gun", "False"),
    ]
    # Each worked example is the template plus its verdict, followed by a blank line.
    preamble = ''.join(
        render(example_q, example_a) + ' ' + verdict + '\n\n'
        for example_q, example_a, verdict in worked_examples
    )

    return [preamble + render(q, r) for q, r in zip(questions, responses)]

def get_p_true(input_data, model, tokenizer):
    """Return the model's log-probability score for the verdict "True" on ``input_data``.

    ``' True'`` is appended to the prompt. The causal-LM loss is then computed
    with every label masked except the final token, so the returned value is
    the negated loss on that single token (higher = more confident).

    Args:
        input_data: Prompt string, e.g. one produced by ``construct_p_true_prompt``.
        model: Causal language model called as ``model(input_ids, labels=...)``
            and returning an object with a ``.loss`` tensor. If it exposes a
            ``device`` attribute (as Hugging Face models do), inputs are moved
            there; otherwise they are moved to ``'cuda'`` as before.
        tokenizer: Tokenizer called as ``tokenizer(text, return_tensors='pt')``.

    Returns:
        float: ``-loss`` for the final token of ``input_data + ' True'``.
    """
    # NOTE(review): only the last token is scored. This assumes ' True'
    # tokenizes to a single trailing token and that the tokenizer appends no
    # EOS token. Confirm this for the tokenizer in use.
    prompt_with_answer = input_data + ' True'

    # Place inputs on the model's own device rather than assuming the default CUDA device.
    device = getattr(model, 'device', 'cuda')
    input_ids = tokenizer(prompt_with_answer, return_tensors='pt').to(device)['input_ids']

    # The loss computation follows https://huggingface.co/docs/transformers/perplexity.
    # Labels of -100 are ignored by the loss, so only the final token contributes.
    labels = input_ids.clone()
    labels[0, :-1] = -100

    with torch.no_grad():
        output = model(input_ids, labels=labels)

    return -output.loss.item()

def calculate_p_true(questions, responses, model, tokenizer, knowledge_list=None, with_doc=False):
    """Score each (question, response) pair with the p_true uncertainty metric.

    Args:
        questions: Question strings.
        responses: Candidate answers, paired positionally with ``questions``.
        model: Causal LM, forwarded to ``get_p_true``.
        tokenizer: Tokenizer, forwarded to ``get_p_true``.
        knowledge_list: Extra context for the document-augmented prompt
            builder; only used when ``with_doc`` is true.
        with_doc: If true, build prompts with
            ``construct_p_true_plus_doc_prompt`` instead of
            ``construct_p_true_prompt``.

    Returns:
        list[float]: For each prompt, the negation of ``get_p_true``'s score.
        Larger values mean the model is less inclined to call the response true.
    """
    # NOTE(review): construct_p_true_plus_doc_prompt is not defined in the
    # visible part of this file. Confirm it is in scope before relying on
    # with_doc=True.
    if not with_doc:
        prompts = construct_p_true_prompt(questions, responses)
    else:
        prompts = construct_p_true_plus_doc_prompt(questions, responses, knowledge_list)

    progress = tqdm(prompts, desc='Calculating p_true')
    # Flip the sign so the result is an uncertainty (higher = less confident).
    return [-get_p_true(prompt, model, tokenizer) for prompt in progress]