from PIL import Image
import requests
import torch
from torchvision import io
from typing import Dict
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
import os
import json
from tqdm import tqdm
import numpy as np

from qwen_vl_utils import process_vision_info

# Checkpoint identifier shared by the model and its processor.
MODEL_ID = "Qwen/Qwen2-VL-72B-Instruct"

# default: load the model on the available device(s)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="flash_attention_2",
)

# default processor for the same checkpoint
processor = AutoProcessor.from_pretrained(MODEL_ID)


def get_prompt(data):
    """Build the text prompt for one multiple-choice question.

    Args:
        data: Mapping with a ``'question'`` string and an ``'options'``
            mapping whose keys are consecutive capital letters starting
            at ``'A'`` (``'A'``, ``'B'``, ...).

    Returns:
        The prompt string: the question, an ``Options:`` header, one
        ``(X) caption`` line per option, and the closing instruction
        ``Only output one answer.``. When ``options`` is empty only the
        question and the header are returned.

    Raises:
        KeyError: If ``'question'`` or ``'options'`` is missing, or if an
            expected option letter is absent from ``options``.
    """
    question = data['question']
    options = data['options']
    prompt = "Question: " + question + '\n' + 'Options: \n'

    option_count = len(options)
    if option_count == 0:
        return prompt

    # Labels are derived from the option count, so every question gets its
    # options rendered instead of only those with 2-5 choices.
    labels = [chr(ord('A') + idx) for idx in range(option_count)]
    lines = [f"({label}) {options[label]}" for label in labels]
    lines.append("Only output one answer.")
    return prompt + "\n".join(lines)
# Run the model over every question and store one raw answer per image file.
# Each real image is replaced by an all-black image of the same size before
# inference (the output file name labels this run "blind").
batch_size = 1
inp_path = 'MMComposition_finalized/images'
# JSON is UTF-8 by specification; do not depend on the platform default.
with open('MMComposition_finalized/questions.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# File names are expected to look like "<integer>.<ext>"; sort them numerically.
img_keys = sorted(os.listdir(inp_path), key=lambda x: int(x.split('.')[0]))
length = len(img_keys)
print(length)

# Cap on the number of samples, clamped to what is actually on disk so a
# smaller dataset cannot raise IndexError after a long run and lose results.
max_samples = 4342
end = min(max_samples, length, len(data)) // batch_size
img_li = img_keys
res_dict = {}

start = 0 // batch_size
for i in tqdm(range(start, end)):
    img_name = img_li[batch_size * i]
    # Only the dimensions are needed; the context manager closes the file
    # handle right away instead of leaking one per iteration.
    with Image.open(os.path.join(inp_path, img_name)) as src_image:
        width, height = src_image.size
    image = Image.new("RGB", (width, height), (0, 0, 0))

    # NOTE(review): data[i] is paired with the i-th image in numeric file-name
    # order; this assumes questions.json is stored in that order -- confirm.
    prompt = get_prompt(data[i])
    conversation = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                },
                {"type": "text", "text": prompt},
            ],
        }
    ]

    text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

    inputs = processor(
        text=[text_prompt], images=[image], padding=True, return_tensors="pt"
    )
    inputs = inputs.to("cuda")

    # Inference: Generation of the output
    output_ids = model.generate(**inputs, max_new_tokens=128)
    # Drop the prompt tokens so that only the newly generated tokens are decoded.
    generated_ids = [
        sequence[len(prompt_ids):]
        for prompt_ids, sequence in zip(inputs.input_ids, output_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    res_dict[img_name] = output_text[0]

# All answers are written at once after the loop finishes.
with open('one_result_qwen2-vl-72b-blind.json', 'w') as json_file:
    json.dump(res_dict, json_file, indent=4)
