import json
import os
import re

import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# System message placed at the top of every prompt.
system_prompt = """You are a helpful human assistant who always strictly follow the instructions given by users."""
# Task instruction placed before the question/answer pair in the USER turn.
instruction = "Given the question and an answer, please verify whether the answer is true or not. unidentifiable can only be used when asked about the information in a certain image. Remember you only need to output yes/no/unidentifiable. No further explanation or reason."
# Renders one question together with the candidate answer to be verified.
input_format = """Question: {question}\nAnswer: {answer}"""
# Single-turn USER/ASSISTANT prompt; the trailing "ASSISTANT:" is left open for the model to complete.
# NOTE: this literal previously opened with four quotes, which prepended a stray '"' to every prompt.
text_format = """{system_prompt}\nUSER: {user_input}\nASSISTANT:"""


def _build_query(data):
    """Return the question text, appending the answer options for closed-choice items.

    ``data`` is one dataset record; the fields ``question``, ``task`` and
    ``choices`` are read exactly as the evaluation loop uses them.
    """
    query = data["question"]
    if data["task"] == "closed choice":
        # Join outside the f-string: reusing the outer quote character and a
        # backslash inside an f-string expression is a SyntaxError before Python 3.12.
        choices = "\n".join(data["choices"])
        query += f"\nChoices:\n{choices}"
    return query


def _parse_verdict(output):
    """Return "yes" or "no" for the first standalone yes/no word in ``output``, else None.

    Whole-word matching keeps words such as "not", "know", "cannot" or
    "unknown" from being counted as a "no" verdict, which a plain substring
    test would do.
    """
    match = re.search(r"\b(yes|no)\b", output)
    return match.group(1) if match else None


def transform_into_multiple_choice(
    model_name="lmsys/vicuna-7b-v1.5",
    device="cuda:2",
    output_path="outputs/analysis/scienceqa/align_vicuna.txt",
    split="train",
):
    """Ask a causal LM to verify every ScienceQA gold answer and report the conflict rate.

    Despite its name, this function does not build multiple-choice data. For
    each record in the ``split`` of ``derek-thomas/ScienceQA`` it prompts the
    model to judge whether the gold answer is true. It appends the model's raw
    (lower-cased) reply to ``output_path`` as one JSON object ``{index: reply}``
    per line. Finally it prints how many replies contained a yes/no verdict
    ("valid") and how many of those were "no" ("conflict").

    Args:
        model_name: Hugging Face id used to load both the model and tokenizer.
        device: torch device string the model is moved to.
        output_path: file the per-item replies are appended to; its parent
            directory is created if it does not exist.
        split: dataset split to evaluate.
    """
    dataset = load_dataset("derek-thomas/ScienceQA")[split]

    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16
    ).to(torch.device(device))
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    cnt_conflict = 0
    cnt_valid = 0

    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Open the log once (append mode, as before) rather than re-opening it per item.
    with open(output_path, "a", encoding="utf-8") as fout:
        for data_id, data in enumerate(tqdm(dataset, total=len(dataset))):
            query = _build_query(data)
            answer = data["choices"][data["answer"]]
            user_input = instruction + "\n" + input_format.format(question=query, answer=answer)
            message = text_format.format(system_prompt=system_prompt, user_input=user_input)

            inputs = tokenizer(message, return_tensors="pt")
            input_ids = inputs.input_ids.to(model.device)
            attention_mask = inputs.attention_mask.to(model.device)
            prompt_len = input_ids.size(1)

            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=10,
            )
            # Decode only the newly generated tokens, not the echoed prompt.
            output = tokenizer.batch_decode(
                outputs[:, prompt_len:], skip_special_tokens=True
            )[0].strip().lower()

            verdict = _parse_verdict(output)
            if verdict is not None:
                cnt_valid += 1
                if verdict == "no":
                    cnt_conflict += 1

            fout.write(f"{json.dumps({data_id: output})}\n")
            # Flush per item so progress survives a crash, as the old per-item reopen did.
            fout.flush()

    print(f"Valid: {cnt_valid}")
    print(f"Conflict: {cnt_conflict}")
    # Avoid ZeroDivisionError when the model never produced a yes/no verdict.
    conflict_rate = cnt_conflict / cnt_valid if cnt_valid else float("nan")
    print(f"Conflict Rate: {conflict_rate}")
            
        
if __name__ == "__main__":
    # Script entry point: run the full verification pass with the default settings.
    transform_into_multiple_choice()