import os
import random
import json
from tqdm import tqdm
import argparse
import pathlib
from load_aokvqa import load_aokvqa, get_coco_path
import ollama
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
from collections import Counter

import os
# Restrict this process to the GPU with index 3.
# NOTE(review): this runs after `transformers` has already been imported above;
# it presumably still takes effect because CUDA is initialised lazily, but
# moving it before the heavy imports would be safer -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# Fixed seed so the `random.sample` draw of few-shot examples in main() is reproducible.
random.seed(3)


def get_qwen_result(image_path, prompt, processor, model, max_new_tokens=128):
    """Run a single-turn chat request through a Qwen2.5-VL model.

    Args:
        image_path: Path to the image to attach. If the file does not exist
            the request is sent as text only (best-effort, no error raised).
        prompt: Text of the user message.
        processor: Processor providing ``apply_chat_template``, ``__call__``
            and ``batch_decode`` (e.g. the ``AutoProcessor`` for the model).
        model: Model providing ``generate`` and a ``device`` attribute.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded text of the model's reply, with the prompt tokens and
        special tokens stripped.
    """
    content_list = []

    # Attach the image only when it is actually on disk; otherwise fall back
    # to a text-only message instead of failing the whole run.
    if os.path.exists(image_path):
        content_list.append({"type": "image", "image": image_path})
    content_list.append({"type": "text", "text": prompt})

    messages = [{"role": "user", "content": content_list}]

    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Follow the model's own placement rather than assuming the default CUDA
    # device, so the call also works when the model was loaded elsewhere.
    inputs = inputs.to(model.device)

    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # `generate` returns prompt + continuation; keep only the continuation.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    # A single conversation was submitted, so there is exactly one reply.
    return output_text[0]

def prompt_element(d, context=None, include_choices=False, answer=False):
    """Format one question (optionally with its answer) as a prompt snippet.

    Args:
        d: Mapping with a ``'question'`` key; ``'choices'`` is read when
            ``include_choices`` or ``answer`` is set, and
            ``'correct_choice_idx'`` / ``'rationales'`` when ``answer`` is set.
        context: Optional context text emitted as a leading ``Context:`` line.
        include_choices: Whether to add an ``Options:`` line listing the choices.
        answer: Whether to append the correct choice and its rationale(s)
            after ``A:`` (used for few-shot examples).

    Returns:
        The formatted prompt fragment. It ends with ``"A:"`` when ``answer``
        is false, so the model is expected to continue from there.
    """
    parts = []
    if context is not None:
        parts.append(f"Context: {context}\n")
    parts.append(f"Q: {d['question']}\n")
    if include_choices:
        parts.append(f"Options: {', '.join(d['choices'])}.\n")
    parts.append("A:")

    if answer:
        rationales = d['rationales']
        # Rationales may be a list of sentences; join them so the prompt does
        # not contain a Python list repr such as "['...', '...']".
        if not isinstance(rationales, str):
            rationales = ' '.join(str(r) for r in rationales)
        correct = d['choices'][d['correct_choice_idx']]
        parts.append(f"The correct answer is {correct},because {rationales}")

    return ''.join(parts)

def main():
    """Query Qwen2.5-VL on an A-OKVQA split with a few-shot prompt and save predictions.

    Parses the command line, loads the train and evaluation splits, draws
    ``--n`` few-shot examples from the train split, and for every evaluation
    question asks the model several times, keeping the most frequent reply.
    The resulting ``{question_id: reply}`` mapping is written as JSON to
    ``--out``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
    parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
    parser.add_argument('--n', type=int, default=4, dest='num_examples')
    parser.add_argument('--train-context', type=argparse.FileType('r'), dest='train_context_file')
    parser.add_argument('--prefix', type=str, default='', dest='prompt_prefix')
    parser.add_argument('--include-choices', action='store_true', dest='include_choices')
    parser.add_argument('--context', type=argparse.FileType('r'), dest='context_file')
    parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
    # NOTE: the sampling options below are accepted for CLI compatibility but
    # are not currently forwarded to the model.
    parser.add_argument('--temperature', type=float, default=0.5)
    parser.add_argument('--max_tokens',
                        type=int,
                        default=512,
                        help='The maximum number of tokens allowed for the generated answer.')
    parser.add_argument('--top_p', type=float, default=1.0)
    parser.add_argument('--frequency_penalty', type=float, default=0.0)
    parser.add_argument('--presence_penalty', type=float, default=0.0)
    parser.add_argument('--coco-dir', type=str, default="/home/test/yxl/MCoT/data/COCO",
                        dest='coco_dir',
                        help='Root directory of the COCO images.')
    args = parser.parse_args()

    train_set = load_aokvqa(args.aokvqa_dir, 'train')
    eval_set = load_aokvqa(args.aokvqa_dir, args.split)

    # Each context file is optional and loaded independently, so supplying
    # only one of them no longer dereferences a missing file handle.
    train_context = {}
    context = {}
    if args.train_context_file is not None:
        train_context = json.load(args.train_context_file)
    if args.context_file is not None:
        context = json.load(args.context_file)

    predictions = {}

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

    prompt_examples = random.sample(train_set, args.num_examples)

    # The few-shot part of the prompt is identical for every evaluation
    # question, so build it once. Each example's context is looked up by the
    # example's own question id, not by the id of the question being evaluated.
    few_shot_prefix = args.prompt_prefix + ''.join(
        prompt_element(
            e,
            context=train_context.get(e['question_id'], None),
            include_choices=args.include_choices,
            answer=True,
        ) + '\n\n'
        for e in prompt_examples
    )

    num_votes = 5

    for d in tqdm(eval_set):
        q = d['question_id']

        prompt = few_shot_prefix
        prompt += "Let's think step by step\n"
        # The evaluated question always lists its options, independently of
        # --include-choices (which only affects the few-shot examples).
        prompt += prompt_element(d,
                                 context=context.get(q, None),
                                 include_choices=True,
                                 answer=False
                                 )
        # Images live in the COCO split matching the evaluated A-OKVQA split.
        image_path = get_coco_path(args.split, d['image_id'], args.coco_dir)

        # Majority vote over the full decoded replies of repeated generations;
        # on ties the earliest-seen reply wins (Counter.most_common ordering).
        all_outputs = [
            get_qwen_result(image_path, prompt, processor, model)
            for _ in range(num_votes)
        ]
        predictions[q] = Counter(all_outputs).most_common(1)[0][0]

    json.dump(predictions, args.output_file)



# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()