import json
import os

import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, LlavaProcessor, LlavaForConditionalGeneration

# Prompt pieces used to build the single-turn text sent to the model.
# The final prompt is text_format filled with system_prompt and with
# instruction + "\n" + input_format (itself filled with the question).
system_prompt = """You are a helpful human assistant who always strictly follow the instructions given by users."""
instruction = "Given the question. Please answer the question with an answer. Just output your choice, no further explanation."
input_format = """Question: {question}"""
# The template previously began with four quote characters, which put a
# stray, never-closed '"' at the very start of every prompt.
text_format = """{system_prompt}\nUSER: {user_input}\nASSISTANT:"""


def _build_query(data):
    """Return the question text for one dataset record, with task-specific suffix.

    For "closed choice" records the choices are appended as lettered lines
    ("A: ...", "B: ...", each followed by a newline) under a "Choices:"
    header. For "yes or no" records a short instruction is appended. Any
    other task value leaves the question unchanged.
    """
    query = data["question"]
    task = data["task"]
    if task == "closed choice":
        choices_text = "".join(
            f"{chr(ord('A') + idx)}: {choice}\n"
            for idx, choice in enumerate(data["choices"])
        )
        query += f"\nChoices:\n{choices_text}"
    elif task == "yes or no":
        query += " Answer the question with yes or no."
    return query


def transform_into_multiple_choice():
    """Run the LLaVA checkpoint over the ScienceQA train split and log its answers.

    Despite its name, this function does not transform the dataset: for every
    record it builds a text-only prompt (no image is passed to the processor),
    greedily generates at most 10 new tokens, and appends one JSON line
    ``{"<record index>": "<stripped, lower-cased answer>"}`` to
    ``outputs/analysis/scienceqa/predict_llava.txt``.

    The file is opened in append mode, so repeated runs add to earlier output.
    The model is placed on ``cuda:0``.
    """
    dataset = load_dataset("derek-thomas/ScienceQA")["train"]

    # NOTE(review): v1.6 checkpoints are documented under the LLaVA-NeXT
    # classes in transformers - confirm that the plain Llava classes load
    # this checkpoint as intended.
    processor = LlavaProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
    model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf").to(torch.device("cuda:0"))

    output_path = "outputs/analysis/scienceqa/predict_llava.txt"
    # Create the directory up front so a missing path cannot fail the run
    # only after the model has been loaded and the first answer generated.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Open the output once instead of once per record.
    with open(output_path, "a") as fout:
        for data_id, data in enumerate(tqdm(dataset, total=len(dataset))):
            user_input = instruction + "\n" + input_format.format(question=_build_query(data))
            message = text_format.format(system_prompt=system_prompt, user_input=user_input)

            inputs = processor(text=message, return_tensors="pt")
            input_ids = inputs.input_ids.to(model.device)
            attention_mask = inputs.attention_mask.to(model.device)

            # generate() returns prompt + continuation; remember the prompt
            # length so that only the newly generated tokens are decoded.
            prompt_len = input_ids.size(1)
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=10,
                do_sample=False,
            )
            output = processor.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)[0].strip().lower()

            fout.write(f"{json.dumps({data_id: output})}\n")
            # Flush per record so results survive an interrupted run, as they
            # did when the file was closed after every write.
            fout.flush()
        
            
        
# Script entry point: run the full prediction pass only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    transform_into_multiple_choice()