import transformers
import torch
import os
import json
from tqdm import tqdm


# Checkpoint to load; transformers resolves this as a hub id or local path.
model_id = "Meta-Llama-3.1-70B-Instruct"

# Chat-capable text-generation pipeline. Weights are loaded in bfloat16 and
# device_map="auto" lets transformers place the model across available devices.
pipeline = transformers.pipeline(
    task="text-generation",
    model=model_id,
    device_map="auto",
    model_kwargs=dict(torch_dtype=torch.bfloat16),
)

def get_prompt(data):
    """Build the (system, user) prompt pair for one multiple-choice question.

    Args:
        data: A question record with a ``'question'`` string and an
            ``'options'`` mapping keyed by consecutive capital letters
            (``'A'``, ``'B'``, ...). Any number of options is supported.

    Returns:
        ``(system_prompt, prompt)``. ``prompt`` is
        ``"Question: <q>\\nOptions: \\n"`` followed by one ``"(X) <text>\\n"``
        line per option and a trailing ``"Output only one answer."``.

    Raises:
        KeyError: if ``data`` lacks ``'question'``/``'options'`` or ``options``
            is missing one of the expected letter keys.
    """
    question = data['question']
    options = data['options']
    system_prompt = "Give a single choice question with several options, you must only select one answer."

    # Look options up by letter rather than iterating the dict, so the listing
    # is always A, B, C, ... regardless of the mapping's insertion order.
    letters = [chr(ord('A') + k) for k in range(len(options))]
    option_lines = ''.join(f"({letter}) {options[letter]}\n" for letter in letters)

    prompt = "Question: " + question + '\n' + 'Options: \n' + option_lines + "Output only one answer."
    return system_prompt, prompt

# Paths and run configuration.
batch_size = 1
inp_path = 'MMComposition_finalized/images'
questions_path = 'MMComposition_finalized/questions.json'
result_path = 'one_result_llama31-70b.json'
max_questions = 3942  # upper bound on how many questions to evaluate
save_every = 50  # write a checkpoint after this many newly processed questions

with open(questions_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

# Image files are named "<int>.<ext>"; sort numerically so img_keys[k] pairs
# with data[k]. NOTE(review): assumes questions.json is stored in the same
# numeric order as the image files — confirm.
img_keys = sorted(os.listdir(inp_path), key=lambda x: int(x.split('.')[0]))

# Never iterate past the data actually available, so a shorter questions file
# or image directory cannot raise IndexError deep into a long run.
end = min(max_questions, len(data), len(img_keys)) // batch_size


def _save_results(results, path):
    """Write the results dict to ``path`` as indented JSON."""
    with open(path, 'w') as json_file:
        json.dump(results, json_file, indent=4)


# Resume from a previous run if one exists; on the very first run there is no
# results file yet, so start from an empty dict instead of crashing.
try:
    with open(result_path, 'r', encoding='utf-8') as f:
        res_dict = json.load(f)
except FileNotFoundError:
    res_dict = {}

# Results are added in question order, so the number already stored is the
# position of the first question still to be processed.
start = len(res_dict) // batch_size
try:
    for i in tqdm(range(start, end)):
        idx = batch_size * i  # one index for both the question and its image key
        system_prompt, prompt = get_prompt(data[idx])
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        outputs = pipeline(messages, max_new_tokens=32)
        # generated_text is a list of chat messages; the reply is read from the last one.
        res_dict[img_keys[idx]] = outputs[0]["generated_text"][-1]['content']
        if (i - start + 1) % save_every == 0:
            _save_results(res_dict, result_path)
finally:
    # Always persist progress, even if generation raises or is interrupted,
    # so the resume logic above has something to pick up next time.
    _save_results(res_dict, result_path)