
import sys
sys.path.append("the relative path")
import argparse
import os
from tqdm import tqdm
import pandas as pd
from evaluator.VizWiz import eval_model
import pandas as pd
from Models.LLaVA import LLaVA_7B, LLaVA_13B
from Models.Qwen2VL import Qwen2VL
from Models.Qwen_VL_Chat import QwenVLChat
from Models.Qwen import Qwen
from Models.Qwen2p5_VL_Instruct import Qwen2p5_VL_Instruct
from Models.InstructBLIP import InstructBLIP
from utils.ECE import *
from utils.utils import *
import json
from PIL import Image
import glob
from utils.conversation import conv_templates, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_TOKEN
from utils.GPTprompt import generate_k_answerable_GPT_prompts, generate_k_unanswerable_GPT_prompts
from evaluator.ARSAnswer import ars_answer, ars_evaluate_uncertainty
from utils.prompt import *
from main.Ask import get_model_class

def _dump_json_atomic(path, data):
    """Serialize ``data`` to ``path`` as pretty-printed UTF-8 JSON, atomically.

    The data is written to ``<path>.tmp`` first and then moved over ``path``
    with ``os.replace``. An interruption mid-write therefore leaves the previous
    file intact, instead of a truncated one that ``json.load`` would reject when
    the run is resumed.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, path)


def generate_k_questions(model, args):
    """Ask ``model`` to generate ``k`` questions per image, once per prompt version.

    For every prompt template of the requested ``problem_type``, this walks the
    first ``args.num_test_samples`` files matching ``*.*`` in ``args.data_path``
    (sorted by base name), queries the model, and appends one record per file to
    ``<save_base>/<k>_question/<model.name>/<problem_type>/questions_<version>.json``.
    Records already present in that file are kept and their images are skipped,
    so an interrupted run can be resumed. Per-image failures are printed and
    skipped.

    Args:
        model: Model wrapper exposing ``name``, ``generate_prompt(str)`` and
            ``get_answer(image_path)``; the first element of the 3-tuple returned
            by ``get_answer`` is stored as the generated questions.
        args: Parsed CLI namespace. Reads ``num_generate_questions``,
            ``problem_type``, ``data_path``, ``num_test_samples``, ``save_base``
            and ``ask_conv_mode``.

    Raises:
        ValueError: If ``args.problem_type`` is not ``"answerable"`` or
            ``"unanswerable"``.
    """
    k = args.num_generate_questions
    problem_type = args.problem_type
    if problem_type not in ["answerable", "unanswerable"]:
        raise ValueError("problem_type must be 'answerable' or 'unanswerable'")

    # Kimi and InternVL3_14b use the GPT-style templates; all other models use
    # the default ones. Only the generator for the requested type is invoked.
    if model.name in ["Kimi", "InternVL3_14b"]:
        prompt_generators = {
            "answerable": generate_k_answerable_GPT_prompts,
            "unanswerable": generate_k_unanswerable_GPT_prompts,
        }
    else:
        prompt_generators = {
            "answerable": generate_k_answerable_prompts,
            "unanswerable": generate_k_unanswerable_prompts,
        }
    prompt_versions = prompt_generators[problem_type](k)

    image_paths = sorted(glob.glob(os.path.join(args.data_path, "*.*")),
                         key=os.path.basename)
    selected_paths = image_paths[:args.num_test_samples]

    # The output directory is the same for every prompt version; create it once.
    save_dir = os.path.join(args.save_base, f"{k}_question", model.name, problem_type)
    os.makedirs(save_dir, exist_ok=True)

    for prompt_version, template in prompt_versions.items():
        print(f"\nProcessing {problem_type} prompt version: {prompt_version}")
        save_path = os.path.join(save_dir, f"questions_{prompt_version}.json")

        # Resume support: keep earlier results and skip images already done.
        if os.path.exists(save_path):
            with open(save_path, "r", encoding="utf-8") as f:
                results = json.load(f)
            done_images = {entry['image'] for entry in results}
        else:
            results, done_images = [], set()

        for image_path in tqdm(selected_paths, desc="Processing images"):
            image_id = os.path.basename(image_path)
            if image_id in done_images:
                continue

            try:
                conv = conv_templates[args.ask_conv_mode].copy()
                qs = template["qs"]
                context = template["context"]
                full_prompt = f"{DEFAULT_IMAGE_TOKEN}\n{qs}\n{context}\n"

                conv.append_message(conv.roles[0], full_prompt)
                # Second role gets no content (presumably the slot for the model's reply).
                conv.append_message(conv.roles[1], None)
                prompt = model.generate_prompt(conv.get_prompt())

                # NOTE(review): `prompt` is not passed to get_answer; presumably
                # generate_prompt stores it on the model for this call — confirm.
                questions, _, _ = model.get_answer(image_path)

                results.append({
                    'image': image_id,
                    'prompt_version': prompt_version,
                    'full_prompt': prompt,
                    'qs_str': questions,
                })

                # Checkpoint after every image so an interrupted run can resume;
                # written atomically so a crash cannot corrupt the checkpoint.
                _dump_json_atomic(save_path, results)

            except Exception as e:
                # Best-effort per image: report and continue with the next one.
                print(f"[Error] {prompt_version} - {image_path}: {e}")

        print(f"Saved {prompt_version} results to: {save_path}")


def main(args):
    """Run the pipeline: build the model, generate questions, answer, evaluate.

    The model class is chosen from ``args.model_path``. After question
    generation and answering, uncertainty is evaluated either for the single
    ``args.prompt_version`` given on the command line or, when that is None,
    for every prompt version of ``args.problem_type``. The versions are ordered
    by the integer that follows their first character.
    """
    model = get_model_class(args.model_path)(args)

    generate_k_questions(model, args)
    ars_answer(model, args)

    if args.prompt_version is not None:
        ars_evaluate_uncertainty(model, args.prompt_version, args)
        return

    version_builder = {
        "answerable": generate_answerable_prompts,
        "unanswerable": generate_unanswerable_prompts,
    }[args.problem_type]

    # Drop the leading character of each version key and sort on the number.
    for version in sorted(version_builder(), key=lambda name: int(name[1:])):
        ars_evaluate_uncertainty(model, version, args)


def build_arg_parser():
    """Return the command-line parser for this script.

    Kept separate from the ``__main__`` guard so the CLI can be built and
    inspected without running the pipeline. ``--problem_type`` is restricted to
    the two values ``generate_k_questions`` accepts. ``--data_path`` is required
    because it is passed straight to ``os.path.join``.
    """
    parser = argparse.ArgumentParser(description="UnSAF framework with k = 1")
    parser.add_argument('--model_path', type=str, help="Path to the pre-trained or distilled model file. Specify the location of the model.")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument('--device', type=str, default="cuda:0")
    parser.add_argument('--data_path', type=str, required=True, help="Path to the dataset.")
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--num_test_samples', type=int, default=1)
    parser.add_argument('--prompt_version', type=str, default=None)
    # Known conv modes for --ask_conv_mode / --ans_conv_mode: mistral_instruct,
    # chatml_direct, llava_v1, mpt, llava_v0, llava_llama_2, conv_llama_v2_vqa.
    parser.add_argument("--ask_conv_mode", default="conv_llama_v2_vqa", help="The ask prompt of model")
    parser.add_argument("--ans_conv_mode", default="llava_llama_2", help="The answer prompt of model")
    parser.add_argument("--problem_type", default="answerable", choices=["answerable", "unanswerable"],
                        help="The type of questions you want the MllM to generate.")
    parser.add_argument('--num_generate_questions', type=int, default=1)
    parser.add_argument("--save_base", default="./result/arsquestion", help="Directory path where the results will be saved.")
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    main(args)

