from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import json
from tqdm import tqdm

# Checkpoint to evaluate; passed straight to from_pretrained (a local
# directory or a Hub model id).
model_name = "Qwen2.5-72B-Instruct"

# Tokenizer and model are loaded independently from the same checkpoint.
# torch_dtype="auto" takes the dtype from the checkpoint config, and
# device_map="auto" lets transformers spread the weights over available devices.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", torch_dtype="auto"
)



def get_prompt(data):
    """Build the (system_prompt, user_prompt) pair for one multiple-choice item.

    Args:
        data: A question record with a ``'question'`` string and an
            ``'options'`` mapping keyed by consecutive capital letters
            (``'A'``, ``'B'``, ...), one key per option.

    Returns:
        A ``(system_prompt, prompt)`` tuple. ``prompt`` is the question,
        then every option on its own line as ``(X) text``, then the
        instruction "Output only one answer.".

    Raises:
        KeyError: if ``options`` lacks one of the letters implied by its size
            (e.g. three options keyed ``'A'``, ``'B'``, ``'D'``).
    """
    question = data['question']
    options = data['options']
    system_prompt = "Give a single choice question with several options, you must only select one answer."
    # Look options up by letter instead of iterating the mapping, so they are
    # always rendered in A, B, C, ... order regardless of the dict's key order.
    letters = [chr(ord('A') + i) for i in range(len(options))]
    option_lines = ''.join(f"({letter}) {options[letter]}\n" for letter in letters)
    prompt = "Question: " + question + '\n' + 'Options: \n' + option_lines + "Output only one answer."
    return system_prompt, prompt

# Blind (text-only) evaluation: each question is answered from its text and
# options alone; the image directory is only used to name the result keys.
inp_path = 'MMComposition_finalized/images'
max_items = 3942
# Read and write the same file so an interrupted run resumes where it stopped.
results_path = 'one_result_qwen25-72b.json'
# Checkpoint interval (in generated answers) so a crash loses little work.
save_every = 50


def _load_results(path):
    """Return previously saved results from *path*, or {} if none exist yet."""
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        return {}


def _save_results(results, path):
    """Write *results* to *path* as indented JSON via a temp file + rename.

    The rename means a crash mid-write never leaves a truncated results file.
    """
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4)
    os.replace(tmp_path, path)


with open('MMComposition_finalized/questions.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# Image files are named "<number>.<ext>"; sort numerically. The i-th name is
# used as the result key for question data[i]. Files without a numeric stem
# (e.g. .DS_Store) are skipped instead of crashing int().
img_keys = sorted(
    (name for name in os.listdir(inp_path) if name.split('.')[0].isdecimal()),
    key=lambda x: int(x.split('.')[0]),
)

res_dict = _load_results(results_path)

# Never index past the available questions or images.
end = min(max_items, len(data), len(img_keys))
# Results are appended in index order, so the count already saved is the
# index of the next item to run.
start = len(res_dict)
for i in tqdm(range(start, end)):
    _, prompt = get_prompt(data[i])
    messages = [
        {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant. Answer the question in 10 words"},
        {"role": "user", "content": prompt}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=128
    )
    # generate() returns prompt + continuation; keep only the new tokens.
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    res_dict[img_keys[i]] = response
    if (i + 1 - start) % save_every == 0:
        _save_results(res_dict, results_path)

_save_results(res_dict, results_path)
