import argparse
import os
from tqdm import tqdm
import pandas as pd
from evaluator.VizWiz import eval_model
import pandas as pd
from Models.LLaVA import LLaVA_7B, LLaVA_13B
from Models.Qwen2VL import Qwen2VL
from Models.InternVL3 import InternVL3
from Models.Qwen_VL_Chat import QwenVLChat
from Models.Qwen import Qwen
from Models.Qwen2p5_VL_Instruct import Qwen2p5_VL_Instruct
from Models.BLIP2 import BLIP2
from Models.InstructBLIP import InstructBLIP
from Models.InternVL3 import InternVL3
from utils.ECE import *
from utils.utils import *
import json
from PIL import Image
import glob
from utils.conversation import conv_templates, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_TOKEN
from evaluator.AnswerPrompt import Answer, evaluate_uncertainty
from utils.prompt import *
from utils.GPTprompt import *
from Ask import get_model_class

def generate_questions(model, args):
    """Generate questions for every prompt version over a sample of images.

    For each prompt template returned by the generator matching
    ``args.problem_type``, a conversation prompt is built from
    ``conv_templates[args.ask_conv_mode]`` and passed through
    ``model.generate_prompt``. ``model.get_answer`` is then queried for each
    selected image. One JSON file per prompt version is written to
    ``<args.save_base>/<model.name>/<problem_type>/questions_<version>.json``.

    Args:
        model: Model wrapper; this function uses its ``name``,
            ``generate_prompt`` and ``get_answer`` members.
        args: Parsed CLI namespace. Reads ``problem_type``, ``data_path``,
            ``num_test_samples``, ``ask_conv_mode`` and ``save_base``.

    Raises:
        ValueError: If ``args.problem_type`` is not ``"answerable"`` or
            ``"unanswerable"``.
    """
    problem_type = args.problem_type
    if problem_type not in ["answerable", "unanswerable"]:
        raise ValueError("problem_type must be 'answerable' or 'unanswerable'")

    # GPT-4o uses its own prompt set; every other model shares the default one.
    if model.name == "GPT4o":
        prompt_generators = {
            "answerable": generate_answerable_GPT_prompts,
            "unanswerable": generate_unanswerable_GPT_prompts,
        }
    else:
        prompt_generators = {
            "answerable": generate_answerable_prompts,
            "unanswerable": generate_unanswerable_prompts,
        }

    # Every file whose name contains a dot, sorted by file name so the
    # selected sample is deterministic across runs.
    image_paths = sorted(
        glob.glob(os.path.join(args.data_path, "*.*")), key=os.path.basename
    )
    selected_paths = image_paths[:args.num_test_samples]

    prompt_versions = prompt_generators[problem_type]()

    # The output directory does not depend on the prompt version. Create it
    # once, before any inference, so an unusable save_base fails immediately
    # instead of after a whole version's worth of model calls has been spent.
    save_dir = os.path.join(args.save_base, model.name, problem_type)
    print("save_base", args.save_base)
    os.makedirs(save_dir, exist_ok=True)

    for prompt_version, template in prompt_versions.items():
        print(f"\nProcessing {problem_type} prompt version: {prompt_version}")
        filename_suffix = f"questions_{prompt_version}.json"

        results = []
        for image_path in tqdm(selected_paths, desc="Processing images"):
            try:
                # Fresh copy per image so messages never accumulate across images.
                conv = conv_templates[args.ask_conv_mode].copy()
                qs = template["qs"]
                context = template["context"]

                full_prompt = f"{DEFAULT_IMAGE_TOKEN}\n{qs}\n{context}\n"

                conv.append_message(conv.roles[0], full_prompt)
                conv.append_message(conv.roles[1], None)
                prompt = model.generate_prompt(conv.get_prompt())

                image_id = os.path.basename(image_path)
                # NOTE(review): `prompt` is not passed to get_answer; presumably
                # generate_prompt stores it on the model — confirm.
                questions, _, _ = model.get_answer(image_path)

                results.append({
                    'image': image_id,
                    'prompt_version': prompt_version,
                    'full_prompt': prompt,
                    'qs_str': questions,
                })
            except Exception as e:
                # Deliberate best-effort: log and skip this image so a single
                # failure does not abort the whole prompt version.
                print(f"Error in {prompt_version} - {image_path}: {str(e)}")

        save_path = os.path.join(save_dir, filename_suffix)
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        print(f"Saved {prompt_version} results to: {save_path}")




def main(args):
    """Run answering with the selected model, then evaluate uncertainty.

    The model class is resolved from ``args.model_path`` and instantiated
    with ``args``, and ``Answer(model, args)`` is run. ``evaluate_uncertainty``
    is then called either for the single ``args.prompt_version`` or, when that
    is None, for every prompt version of ``args.problem_type`` in numeric
    order.

    Args:
        args: Parsed CLI namespace. Reads ``model_path``, ``prompt_version``
            and ``problem_type`` here; it is also forwarded to the model and
            the evaluators.

    Raises:
        ValueError: If ``args.prompt_version`` is None and
            ``args.problem_type`` is not ``"answerable"`` or ``"unanswerable"``.
    """
    prompt_generators = {
        "answerable": generate_answerable_prompts,
        "unanswerable": generate_unanswerable_prompts,
    }
    # Validate before loading the model. Previously a bad problem_type only
    # surfaced as a bare KeyError after the model load and the full Answer
    # pass had already run.
    if args.prompt_version is None and args.problem_type not in prompt_generators:
        raise ValueError("problem_type must be 'answerable' or 'unanswerable'")

    ModelClass = get_model_class(args.model_path)
    model = ModelClass(args)

    Answer(model, args)

    if args.prompt_version is not None:
        evaluate_uncertainty(model, args.prompt_version, args)
        return

    prompt_versions = prompt_generators[args.problem_type]()

    # Version keys are expected to be a one-character prefix followed by an
    # integer (int(v[1:]) requires it). Sorting numerically puts e.g. "v10"
    # after "v9".
    for version in sorted(prompt_versions, key=lambda v: int(v[1:])):
        evaluate_uncertainty(model, version, args)


if __name__ == "__main__":
    # Command-line entry point: parse arguments and hand them to main().
    parser = argparse.ArgumentParser(description="UnSAF Evaluation of prompt engineering")
    parser.add_argument('--model_path', type=str, help="Path to the pre-trained or distilled model file. Specify the location of the model.")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument('--device', type=str, default="cuda:0")
    parser.add_argument('--data_path', type=str, help="Path to the dataset.")
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--num_test_samples', type=int, default=1)
    # None means "evaluate every prompt version of --problem_type".
    parser.add_argument('--prompt_version', type=str, default=None)
    # Conversation templates seen in use: mistral_instruct, chatml_direct,
    # llava_v1, mpt, llava_v0, llava_llama_2, conv_llama_v2_vqa.
    parser.add_argument("--ask_conv_mode", default="conv_llama_v2_vqa", help="The ask prompt of model")
    parser.add_argument("--ans_conv_mode", default="llava_llama_2", help="The answer prompt of model")
    # Only these two values are handled downstream. Enforcing them at parse
    # time gives a clear usage error instead of a failure after model loading.
    parser.add_argument(
        "--problem_type",
        default="answerable",
        choices=["answerable", "unanswerable"],
        help="The type of questions you want the MllM to generate.",
    )
    parser.add_argument('--save_base', type=str, default="./result/pe", help="Directory path where the results will be saved.")
    args = parser.parse_args()
    main(args)