import argparse
import os
from tqdm import tqdm
from Models.LLaVA import LLaVA_7B, LLaVA_13B
from Models.Qwen2VL import Qwen2VL
from Models.InternVL3 import InternVL3
from Models.InternVL2 import InternVL2
from Models.Qwen_VL_Chat import QwenVLChat
from Models.Qwen import Qwen
from Models.Qwen2p5_VL_Instruct import Qwen2p5_VL_Instruct
from Models.BLIP2 import BLIP2
from Models.InstructBLIP import InstructBLIP
from Models.Mantis import Mantis
from Models.Idefics import Idefics
from Models.mPLUG import mPLUG
from Models.MiniCPM import MiniCPM
from utils.ECE import *
from utils.utils import *
import json
import glob
from utils.conversation import conv_templates,DEFAULT_IMAGE_TOKEN
from evaluator.Answer import Answer, evaluate_uncertainty
from utils.prompt import *
from utils.GPTprompt import *
import random
# Seed the global RNG at import time so the random.sample() image selection
# in generate_questions picks the same files on every run.
random.seed(a=42)

def get_model_class(model_path):
    """Return the wrapper class for the model family named in ``model_path``.

    Families are recognised by case-insensitive substring tests evaluated top
    to bottom. The first match wins, so a more specific pattern must appear
    before any broader pattern that it contains.

    Args:
        model_path: Local path or hub identifier of the checkpoint.

    Returns:
        The model wrapper class itself (not an instance). ``main`` constructs
        it with the parsed CLI arguments.

    Raises:
        ValueError: If ``model_path`` is empty/None or matches no known family.
    """
    if not model_path:
        raise ValueError("model_path must be a non-empty string")

    key = model_path.lower()

    if 'llava-1.5-7b' in key:
        return LLaVA_7B
    elif 'llava-1.5-13b' in key:
        return LLaVA_13B
    elif 'qwen2-vl' in key:
        return Qwen2VL
    elif 'qwen-vl-chat' in key:
        return QwenVLChat
    elif 'qwen-7b' in key:
        return Qwen
    elif 'qwen2.5-vl' in key:
        return Qwen2p5_VL_Instruct
    elif 'blip2' in key:
        return BLIP2
    elif 'instructblip' in key:
        return InstructBLIP
    elif 'internvl3' in key:
        return InternVL3
    elif 'internvl2' in key:  # substring test: also matches longer names containing "internvl2"
        return InternVL2
    elif 'mantis' in key:
        return Mantis
    elif 'idefics-9b-instruct' in key:
        return Idefics
    elif 'mplug' in key:  # `key` is lowercased, so this covers "mPLUG" as well
        return mPLUG
    elif 'cpm' in key:
        return MiniCPM

    # Echo the path exactly as given; the lowercased form can misrepresent a
    # path on a case-sensitive filesystem.
    raise ValueError(f"Unsupported model path: {model_path}")


def _dump_json_atomic(path, data):
    """Write ``data`` to ``path`` as pretty-printed UTF-8 JSON, atomically.

    The payload goes to a sibling temporary file that is then moved over
    ``path`` with ``os.replace``. An interrupted run therefore never leaves a
    truncated checkpoint that a later resume would fail to ``json.load``.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, path)


def generate_questions(model, args):
    """Generate questions for a random sample of images and checkpoint them as JSON.

    Draws ``args.num_test_samples`` files from ``args.data_path``, clamped to
    the number of files available. For each selected prompt version of
    ``args.problem_type``, it asks ``model`` for questions about every sampled
    image.

    Results are written to
    ``<save_base>/<model.name>/<problem_type>/questions_<version>.json``. The
    file is rewritten after each image, so an interrupted run resumes by
    skipping images already recorded there.

    Models whose name contains "lora" only run prompt version "v4"; all other
    models run "v2" and "v4".

    Args:
        model: Wrapper exposing ``name``, ``generate_prompt(prompt)`` and
            ``get_answer(image_path)``. The latter returns a 3-tuple whose
            first item is stored as ``qs_str``.
        args: Parsed CLI namespace. Reads ``problem_type``, ``data_path``,
            ``num_test_samples``, ``save_base`` and ``ask_conv_mode``.

    Raises:
        ValueError: If ``args.problem_type`` is not "answerable" or
            "unanswerable", or ``args.num_test_samples`` is negative.
    """
    problem_type = args.problem_type
    if problem_type not in ["answerable", "unanswerable"]:
        raise ValueError("problem_type must be 'answerable' or 'unanswerable'")

    prompt_generators = {
        "answerable": generate_answerable_prompts,
        "unanswerable": generate_unanswerable_prompts,
    }

    image_paths = sorted(glob.glob(os.path.join(args.data_path, "*.*")),
                         key=os.path.basename)
    # random.sample raises when k exceeds the population. Use every file
    # instead, and say so, rather than aborting the whole run.
    num_samples = min(args.num_test_samples, len(image_paths))
    if num_samples < args.num_test_samples:
        print(f"Warning: requested {args.num_test_samples} samples but only "
              f"{len(image_paths)} files found in {args.data_path}; "
              f"using {num_samples}.")
    selected_paths = random.sample(image_paths, k=num_samples)

    is_lora = "lora" in model.name.lower()
    wanted_versions = {"v4"} if is_lora else {"v2", "v4"}

    prompt_versions = prompt_generators[problem_type]()

    for prompt_version, template in prompt_versions.items():
        if prompt_version not in wanted_versions:
            continue

        print(f"\nProcessing {problem_type} prompt version: {prompt_version}")

        save_dir = os.path.join(args.save_base, model.name, problem_type)
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, f"questions_{prompt_version}.json")

        # Resume support: reload prior results and skip images already done.
        results = []
        if os.path.exists(save_path):
            with open(save_path, "r", encoding="utf-8") as f:
                results = json.load(f)
        done_images = {entry['image'] for entry in results}

        for image_path in tqdm(selected_paths, desc="Processing images"):
            image_id = os.path.basename(image_path)
            if image_id in done_images:
                continue
            try:
                conv = conv_templates[args.ask_conv_mode].copy()
                full_prompt = (f"{DEFAULT_IMAGE_TOKEN}\n{template['qs']}\n"
                               f"{template['context']}\n")
                conv.append_message(conv.roles[0], full_prompt)
                conv.append_message(conv.roles[1], None)
                prompt = model.generate_prompt(conv.get_prompt())

                # NOTE(review): `prompt` is recorded below but never passed to
                # get_answer(). Confirm that generate_prompt() stores it on the
                # model, otherwise the template above has no effect.
                questions, _, _ = model.get_answer(image_path)

                results.append({
                    'image': image_id,
                    'prompt_version': prompt_version,
                    'full_prompt': prompt,
                    'qs_str': questions,
                })
                _dump_json_atomic(save_path, results)

            except Exception as e:
                # Best-effort per image: log the failure and keep going so one
                # bad file does not abort the batch.
                print(f"Error in {prompt_version} - {image_path}: {str(e)}")
                continue

        print(f"Saved {prompt_version} results to: {save_path}")



def main(args):
    """Run the end-to-end pipeline for one model.

    Steps:
    1. Build the wrapper class selected by ``args.model_path``.
    2. Generate questions.
    3. Call ``Answer(model, args)``.
    4. Run ``evaluate_uncertainty``. This covers ``args.prompt_version``
       alone, or, when that is unset, every prompt version of
       ``args.problem_type`` in ascending numeric order.
    """
    model_cls = get_model_class(args.model_path)
    model = model_cls(args)

    generate_questions(model, args)
    Answer(model, args)

    if args.prompt_version is None:
        generator_by_type = {
            "answerable": generate_answerable_prompts,
            "unanswerable": generate_unanswerable_prompts,
        }
        version_names = generator_by_type[args.problem_type]().keys()
        # Sort by the integer after the leading "v" (so "v10" follows "v9").
        # Assumes every key has the form "v<int>".
        versions_to_eval = sorted(version_names, key=lambda name: int(name[1:]))
    else:
        versions_to_eval = [args.prompt_version]

    for version in versions_to_eval:
        evaluate_uncertainty(model, version, args)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="UnSAF Evaluation")
    # --model_path and --data_path have no usable default. get_model_class()
    # calls .lower() on the model path and generate_questions() globs inside
    # the data path, so omitting either one would crash later with an opaque
    # AttributeError/TypeError instead of a clear usage error.
    parser.add_argument('--model_path', type=str, required=True, help="Path to the pre-trained or distilled model file. Specify the location of the model.")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument('--device', type=str, default="cuda:0")
    parser.add_argument('--data_path', type=str, required=True, help="Path to the dataset.")
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--num_test_samples', type=int, default=1)
    parser.add_argument('--prompt_version', type=str, default=None)
    # Conversation-template keys for the two modes below, e.g. mistral_instruct,
    # chatml_direct, llava_v1, mpt, llava_v0, llava_llama_2, conv_llama_v2_vqa.
    parser.add_argument("--ask_conv_mode", default="conv_llama_v2_vqa", help="The ask prompt of model")
    parser.add_argument("--ans_conv_mode", default="llava_llama_2", help="The answer prompt of model")
    # generate_questions() accepts only these two values; reject anything else
    # at parse time.
    parser.add_argument("--problem_type", default="answerable", choices=["answerable", "unanswerable"], help="The type of questions you want the MllM to generate.")
    parser.add_argument('--save_base', type=str, default="./result/question", help="Directory path where the results will be saved.")
    args = parser.parse_args()
    main(args)

