import os
import random
import json
from tqdm import tqdm
import argparse
import pathlib
from load_aokvqa import load_aokvqa, get_coco_path
import ollama

from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

import os
# Restrict this process to physical GPU 3. NOTE(review): this must take effect
# before CUDA is initialised (the model is only loaded later, inside main());
# consider making the GPU index configurable instead of hard-coding it here.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# Seed Python's `random` module for reproducibility. NOTE(review): nothing in
# this file draws from `random`, so this only matters if imported helpers do.
random.seed(0)

def get_qwen_result(image_path, prompt, processor, model, max_new_tokens=128):
    """Query a Qwen2.5-VL model with an optional image and a text prompt.

    Args:
        image_path: Path of the image to attach. If it does not exist on disk
            the image is silently omitted and the model sees only the text.
        prompt: Text prompt, placed after the image in the user message.
        processor: The processor matching ``model`` (chat template + vision
            preprocessing).
        model: A loaded Qwen2.5-VL model exposing ``generate``.
        max_new_tokens: Maximum number of tokens to generate. Defaults to 128,
            the value that was previously hard-coded.

    Returns:
        The decoded continuation for the single prompt, with special tokens
        stripped.
    """
    content = []
    # Best-effort: a missing image degrades to a text-only query rather than failing.
    if os.path.exists(image_path):
        content.append({"type": "image", "image": image_path})
    content.append({"type": "text", "text": prompt})
    messages = [{"role": "user", "content": content}]

    # Render the chat template as text; tokenization happens in processor() below
    # together with the preprocessed vision inputs.
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")

    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt tokens followed by new tokens; keep only the new ones.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    # Exactly one prompt was batched, so there is exactly one decoded string.
    return output_text[0]

def prompt_element(d, context=None, include_choices=False, answer=False):
    """Format one question record as a prompt segment.

    The segment has the shape::

        [Context: <context>\\n]Q: <question>\\n[Options: a, b, c.\\n]A:[ <answer>]

    Args:
        d: Question record. Reads ``d['question']``; ``d['choices']`` when
            ``include_choices`` or ``answer`` is true; ``d['correct_choice_idx']``
            and ``d['rationales']`` when ``answer`` is true.
        context: Optional extra text, emitted on a leading ``Context:`` line
            when not ``None``.
        include_choices: Whether to list the answer options.
        answer: Whether to append the gold answer and rationale (used for
            few-shot examples). When false the segment ends at ``A:`` so the
            model completes it.

    Returns:
        The formatted prompt segment as a string.
    """
    parts = []
    if context is not None:
        parts.append(f"Context: {context}\n")
    parts.append(f"Q: {d['question']}\n")
    if include_choices:
        parts.append(f"Options: {', '.join(d['choices'])}.\n")
    parts.append("A:")
    if answer:
        rationales = d['rationales']
        # A-OKVQA annotations store rationales as a list of strings; join them
        # into prose instead of embedding a Python list repr in the prompt.
        # A plain string is used as-is.
        if not isinstance(rationales, str):
            rationales = ' '.join(rationales)
        correct = d['choices'][d['correct_choice_idx']]
        # ",because" (no space) intentionally mirrors the answer format that
        # the instruction in main() asks the model to reproduce.
        parts.append(f" The correct answer is {correct},because {rationales}")
    return ''.join(parts)

def main():
    """Run few-shot Qwen2.5-VL inference on an A-OKVQA split and dump predictions.

    Builds a prompt prefix from the training examples named by ``--shot-ids``
    (in the given order), appends each evaluation question, queries the model
    with the question's COCO image, and writes ``{question_id: model_output}``
    as JSON to ``--out``.

    NOTE(review): ``--n``, ``--include-choices``, ``--temperature``,
    ``--max_tokens``, ``--top_p``, ``--frequency_penalty`` and
    ``--presence_penalty`` are accepted for CLI compatibility but are not used;
    choices are always included and generation uses get_qwen_result's defaults.
    """
    import sys  # local import: only used for warnings / stdout identity check

    parser = argparse.ArgumentParser()
    parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
    parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
    parser.add_argument('--shot-ids', nargs='+', required=True, help='List of question IDs to use as few-shot examples')
    parser.add_argument('--n', type=int, default=10, dest='num_examples')
    parser.add_argument('--train-context', type=argparse.FileType('r'), dest='train_context_file')
    parser.add_argument('--prefix', type=str, default='', dest='prompt_prefix')
    parser.add_argument('--include-choices', action='store_true', dest='include_choices')
    parser.add_argument('--context', type=argparse.FileType('r'), dest='context_file')
    parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
    parser.add_argument('--coco-dir', type=str, default="/home/test/yxl/MCoT/data/COCO", dest='coco_dir',
                        help='Root directory of the COCO images.')
    parser.add_argument('--temperature', type=float, default=0.5)
    parser.add_argument('--max_tokens',
                        type=int,
                        default=512,
                        help='The maximum number of tokens allowed for the generated answer.')
    parser.add_argument('--top_p', type=float, default=1.0)
    parser.add_argument('--frequency_penalty', type=float, default=0.0)
    parser.add_argument('--presence_penalty', type=float, default=0.0)
    args = parser.parse_args()

    train_set = load_aokvqa(args.aokvqa_dir, 'train')
    eval_set = load_aokvqa(args.aokvqa_dir, args.split)

    # Index training examples by question_id for fast lookup of the shots.
    train_dict = {e['question_id']: e for e in train_set}
    # Report unknown shot IDs instead of silently dropping them.
    missing = [qid for qid in args.shot_ids if qid not in train_dict]
    if missing:
        print(f"Warning: shot IDs not found in the train split, skipping: {missing}",
              file=sys.stderr)
    # Keep the shots in the order the user listed them.
    shot_examples = [train_dict[qid] for qid in args.shot_ids if qid in train_dict]

    print(shot_examples)

    # Each optional context file is loaded independently of the other.
    # NOTE(review): both are presumably JSON objects keyed by question_id — confirm.
    train_context = json.load(args.train_context_file) if args.train_context_file is not None else {}
    context = json.load(args.context_file) if args.context_file is not None else {}

    # The few-shot part of the prompt does not depend on the evaluation
    # question, so build it once. Each shot uses its own context entry.
    shots_prompt = args.prompt_prefix
    for e in shot_examples:
        shots_prompt += prompt_element(e,
                                       context=train_context.get(e['question_id'], None),
                                       include_choices=True,
                                       answer=True
                                       )
        shots_prompt += '\n\n'
    shots_prompt += "Please answer the following question in the form \"The correct answer is ,because\" based on the above example \n\n"

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

    predictions = {}
    for d in tqdm(eval_set):
        q = d['question_id']
        prompt = shots_prompt + prompt_element(d,
                                               context=context.get(q, None),
                                               include_choices=True,
                                               answer=False
                                               )
        # Images are resolved for the split being evaluated, not always 'val'.
        image_path = get_coco_path(args.split, d['image_id'], args.coco_dir)
        predictions[q] = get_qwen_result(image_path, prompt, processor, model)

    json.dump(predictions, args.output_file)
    # Flush/close the output file explicitly, but never close stdout ('-').
    if args.output_file is not sys.stdout:
        args.output_file.close()



# Script entry point: run inference only when executed directly, not on import.
if __name__ == '__main__':
    main()