import os
import re
import random
import json
from tqdm import tqdm
import argparse
import pathlib
import requests
import ollama

from load_aokvqa import load_aokvqa, get_coco_path

def prompt_element(args, d, context=None, include_choices=False, answer=False):
    """Build the text of a single question prompt.

    Args:
        args: Object whose ``options`` attribute is indexed by
            ``d['correct_choice_idx']``; only read when ``answer`` is true.
        d: Mapping providing ``'question'`` and, depending on the flags,
            ``'choices'`` and ``'correct_choice_idx'``.
        context: Optional text emitted on a leading ``Context:`` line;
            the line is omitted when this is ``None``.
        include_choices: When true, add a ``Choices:`` line listing
            ``d['choices']`` separated by commas.
        answer: When true, append the correct option after the answer stub.

    Returns:
        The assembled prompt string, ending in ``A:The answer is``
        (followed by a space and the option when ``answer`` is true).
    """
    pieces = []
    if context is not None:
        pieces.append(f"Context: {context}\n")
    pieces.append(f"Q: {d['question']}\n")
    if include_choices:
        listed = ', '.join(d['choices'])
        pieces.append(f"Choices: {listed}.\n")
    pieces.append("A:The answer is")
    if answer:
        correct_option = args.options[d['correct_choice_idx']]
        pieces.append(f" {correct_option}")
    return ''.join(pieces)



# Matches "The answer is X" where X is one capital letter that is not the
# start of a longer word. The lookahead replaces a bare "." that demanded an
# extra character after the letter, which rejected a reply ending in
# "The answer is B" and accepted "The answer is Blue" as "B".
_ANSWER_PATTERN = re.compile(r'The answer is ([A-Z])(?![A-Za-z])')


def get_qwen_result(image_path, prompt, args):
    """Send one prompt (and optionally an image) to the model and parse the answer.

    Despite the function name, the request goes to the hard-coded Ollama
    model ``llama3.2-vision:11b`` with temperature 0.

    Args:
        image_path: Path of the image to attach. The image is attached only
            when the path is non-empty and exists on disk; otherwise the
            request is sent as text only.
        prompt: Full prompt text, sent as the single user message.
        args: Parsed command-line arguments. Currently unused; kept so the
            call signature stays stable for existing callers.

    Returns:
        A tuple ``(answer, output)``. ``answer`` is the option letter when
        the reply contains exactly one "The answer is <letter>" statement,
        and the string ``"FAILED"`` otherwise. ``output`` is the raw reply
        text.
    """
    message = {"role": "user", "content": prompt}

    # Guard against None/empty paths, which os.path.exists() rejects with
    # a TypeError instead of returning False.
    if image_path and os.path.exists(image_path):
        message["images"] = [image_path]

    response = ollama.chat(
        model="llama3.2-vision:11b",
        stream=False,
        messages=[message],
        options={"temperature": 0}
    )

    print(response)
    output = response['message']['content']

    matches = _ANSWER_PATTERN.findall(output)
    # Zero or several answer statements are treated as unparseable.
    answer = matches[0] if len(matches) == 1 else "FAILED"

    return answer, output


def main():
    """Run a single-example smoke test of the vision model on A-OKVQA.

    Parses the command-line options, loads the training split, builds a
    multiple-choice prompt for one fixed training example (index 4),
    queries the model with the matching COCO image, and prints the parsed
    prediction followed by the raw model output.

    Only ``--prefix`` and ``--context`` influence the run; the sampling
    options (temperature, max_tokens, ...) and ``--include-choices`` are
    parsed but not consulted here (choices are always included).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--train-context', type=argparse.FileType('r'), dest='train_context_file')
    parser.add_argument('--prefix', type=str, default='Please provide an answer based on the given choice\n', dest='prompt_prefix')
    parser.add_argument('--include-choices', action='store_true', dest='include_choices')
    parser.add_argument('--context', type=argparse.FileType('r'), dest='context_file')
    parser.add_argument('--temperature', type=float, default=0.5)
    parser.add_argument('--max_tokens',
                        type=int,
                        default=512,
                        help='The maximum number of tokens allowed for the generated answer.')
    parser.add_argument('--top_p', type=float, default=1.0)
    parser.add_argument('--frequency_penalty', type=float, default=0.0)
    parser.add_argument('--presence_penalty', type=float, default=0.0)
    args = parser.parse_args()

    aokvqa_dir = "../data/aokvqa/"
    coco_dir = "/home/test/yxl/MCoT/data/COCO"

    train_dataset = load_aokvqa(aokvqa_dir, 'train')
    dataset_example = train_dataset[4]

    # The train-context file is accepted for CLI compatibility but its
    # contents are not used; just release the handle argparse opened.
    # (It used to be JSON-loaded inside the --context branch, which crashed
    # when --context was given without --train-context.)
    if args.train_context_file is not None:
        args.train_context_file.close()

    context = {}
    if args.context_file is not None:
        with args.context_file as context_file:
            context = json.load(context_file)

    # JSON object keys are always strings, so the former lookup with the
    # integer key 0 could never find an entry.
    # NOTE(review): assumes the context file maps question_id -> context
    # text, as in the A-OKVQA reference scripts -- confirm the schema.
    example_context = None
    if context:
        example_context = context.get(dataset_example.get('question_id'))

    prompt = args.prompt_prefix
    prompt += prompt_element(args,
                             dataset_example,
                             context=example_context,
                             include_choices=True,
                             answer=False
                             )
    image_path = get_coco_path('train', dataset_example['image_id'], coco_dir)

    print(os.path.exists(image_path))

    prediction, output = get_qwen_result(image_path, prompt, args)

    print(prediction)
    print(output)




if __name__ == '__main__':
    # main() loads the dataset and resolves the image path itself. The
    # duplicate load that used to live here produced values nothing read
    # and doubled the start-up I/O.
    main()

