import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import (
    AbstractSet,
    cast,
    Collection,
    Dict,
    Iterator,
    List,
    Literal,
    Sequence,
    TypedDict,
    Union,
)

# The speaker roles a chat message may carry.
Role = Literal["system", "user", "assistant"]


class Message(TypedDict):
    """One chat turn, stored as a plain dict with ``role`` and ``content`` keys."""
    # Speaker of the turn; restricted to the values listed in ``Role``.
    role: Role
    # Text of the turn.
    content: str


# A conversation: a sequence of ``Message`` dicts.
Dialog = Sequence[Message]

# Prompts for the LLM judge. The user message is built as
# ``instruction + "\n" + input_format.format(question=..., gold=..., pred=...)``.
system_prompt = """You are a helpful human assistant who always strictly follow the instructions given by users."""

# The judge's verdict is detected by searching its reply for "true", so the
# prompt must spell the expected token correctly ("True", not "Ture").
instruction_accuracy = """Given a gold answer and a predicted answer of the same question, please identify whether the predicted answer is correct. Just output True/False, do not give further explanations."""

# Both templates take the same three fields: question, gold, pred.
input_format_accuracy = """Question: {question}\nGold Answer: {gold}\nPredicted Answer: {pred}"""
input_format_conflict = """Question: {question}\nAnswer 1: {gold}\nAnswer 2: {pred}"""

# NOTE: despite its name, this prompt asks whether the two answers are
# *equivalent*, so a "True" verdict means the answers agree (no conflict).
instruction_conflict = """Your task is to determine whether two answers to a given question are equivalent, even if there are slight differences in syntax.

Instructions:

1. Read the Question and Answers: Carefully read the provided question and the two answers given.
2. Understand the Context: Understand the context of the question and the expected answers.
3. Identify Differences: Compare the syntax and content of the two answers.
4. Ignore Irrelevant Differences: Ignore minor differences in formatting, punctuation, and grammar that do not change the meaning of the answer.
5. Focus on Semantic Equivalence: Determine whether the two answers convey the same information and meaning, even if the wording differs.
6. Make a Decision: Decide whether the two answers are equivalent or not.
7. Output: True/False.
Note: You just need to output True/False. Please ensure that your evaluation is based on the semantic equivalence of the answers rather than solely on superficial differences in syntax or wording."""


@torch.inference_mode()
def evaluate_llama3(mode, questions, golds, preds, device="cuda:6"):
    """Score predictions using Llama-3-8B-Instruct as a True/False judge.

    For every (question, gold, pred) triple the judge is prompted with the
    instruction selected by ``mode`` and its greedy-decoded reply (at most
    10 new tokens) is checked for the substring "true", case-insensitively.

    Args:
        mode: ``"accuracy"`` to ask whether ``pred`` is correct given
            ``gold``, or ``"conflict"`` to use the ``instruction_conflict``
            prompt. That prompt asks whether the two answers are
            *equivalent*, so in this mode the score is the fraction of
            pairs judged equivalent.
        questions: Sized collection of question strings; its length is used
            as the progress-bar total.
        golds: Reference answers, aligned with ``questions``.
        preds: Predicted answers, aligned with ``questions``.
        device: Device the model is moved to. Defaults to ``"cuda:6"``,
            the value that used to be hard-coded.

    Returns:
        The fraction of evaluated triples whose judge reply contains
        "true", or ``0.0`` when there is nothing to evaluate.

    Raises:
        ValueError: If ``mode`` is neither ``"accuracy"`` nor ``"conflict"``.
    """
    # Validate the mode before loading any weights so a typo fails
    # immediately instead of after the model has been loaded.
    if mode == "accuracy":
        instruction = instruction_accuracy
        input_format = input_format_accuracy
    elif mode == "conflict":
        instruction = instruction_conflict
        input_format = input_format_conflict
    else:
        raise ValueError(f"mode must be 'accuracy' or 'conflict', got {mode!r}")

    model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16
    ).to(torch.device(device))
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Stop on either the regular EOS token or Llama-3's end-of-turn token.
    # Built once: it does not depend on the example being scored.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]
    generate_config = {
        "do_sample": False,
        "max_new_tokens": 10,
        "eos_token_id": terminators,
    }

    cnt_correct = 0
    cnt = 0
    # zip stops at the shortest input; the bar total follows ``questions``.
    examples = tqdm(zip(questions, golds, preds), total=len(questions))
    for question, gold, pred in examples:
        user_content = instruction + "\n" + input_format.format(
            question=question, gold=gold, pred=pred
        )
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_content},
        ]
        input_ids = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)
        prompt_len = input_ids.size(1)
        outputs = model.generate(input_ids, **generate_config)
        # Decode only the newly generated tokens, not the echoed prompt.
        answer = tokenizer.batch_decode(
            outputs[:, prompt_len:], skip_special_tokens=True
        )[0].strip()
        if "true" in answer.lower():
            cnt_correct += 1
        cnt += 1

    # Guard against empty input instead of raising ZeroDivisionError.
    return cnt_correct / cnt if cnt else 0.0