import argparse
import os
from tqdm import tqdm
import pandas as pd
from evaluator.VizWiz import eval_model
import pandas as pd
from Models.LLaVA import LLaVA_7B, LLaVA_13B
from Models.Qwen2VL import Qwen2VL
from Models.Qwen_VL_Chat import QwenVLChat
from Models.Qwen import Qwen
from Models.Qwen2p5_VL_Instruct import Qwen2p5_VL_Instruct
from Models.BLIP2 import BLIP2
from Models.InstructBLIP import InstructBLIP
from utils.ECE import *
from utils.utils import *
import json
from PIL import Image
import glob
from utils.conversation import conv_templates, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_TOKEN
from evaluator.Answer_is import img_Answer, evaluate_uncertainty_ratio
from utils.prompt import *
from utils.GPTprompt import *
from Ask import get_model_class
import csv

def generate_questions(data, model, args, target_versions=("v2",)):
    """Have the model generate a question for every image and save them as JSON.

    The prompt templates come from the generator matching ``args.problem_type``
    (the GPT-specific generators are used when ``model.name == "GPT4o"``).
    Only prompt versions listed in ``target_versions`` are run. For each such
    version, every row of ``data`` is sent to the model and the collected
    entries are written to
    ``<args.save_base>/<model.name>/<problem_type>/questions_<version>.json``.

    Rows that raise are reported (with their question id) and skipped, so a
    single bad image does not abort the whole run.

    Args:
        data: DataFrame with an ``image`` column whose values are mappings
            holding ``"bytes"``, plus an id column (``index`` for MMBench,
            ``question_id`` otherwise).
        model: Model wrapper exposing ``name``, ``generate_prompt`` and
            ``get_answer``.
        args: Namespace providing ``problem_type``, ``dataset_name``,
            ``ask_conv_mode`` and ``save_base``.
        target_versions: Collection of prompt-version keys to process.
            Defaults to ``("v2",)``, the only version the script originally ran.

    Raises:
        ValueError: If ``args.problem_type`` is neither "answerable" nor
            "unanswerable".
    """
    problem_type = args.problem_type
    if problem_type not in ["answerable", "unanswerable"]:
        raise ValueError("problem_type must be 'answerable' or 'unanswerable'")

    if model.name == "GPT4o":
        prompt_generators = {
            "answerable": generate_answerable_GPT_prompts,
            "unanswerable": generate_unanswerable_GPT_prompts,
        }
    else:
        prompt_generators = {
            "answerable": generate_answerable_prompts,
            "unanswerable": generate_unanswerable_prompts,
        }

    prompt_versions = prompt_generators[problem_type]()

    # MMBench identifies rows by "index"; the other datasets use "question_id".
    id_column = "index" if args.dataset_name in ['MMBench'] else "question_id"

    for prompt_version, template in prompt_versions.items():
        if prompt_version not in target_versions:
            continue

        print(f"\nProcessing {problem_type} prompt version: {prompt_version}")
        filename_suffix = f"questions_{prompt_version}.json"
        results = []
        n_failed = 0

        for _, row in tqdm(
            data.iterrows(), total=len(data), desc="Processing images"
        ):
            # Series.get returns the default instead of raising on a missing
            # column, so this is safe outside the try and is always bound
            # for the error message below.
            question_id = row.get(id_column, "")
            try:
                image_bytes = row["image"]["bytes"]

                conv = conv_templates[args.ask_conv_mode].copy()
                full_prompt = f"{DEFAULT_IMAGE_TOKEN}\n{template['qs']}\n{template['context']}\n"
                conv.append_message(conv.roles[0], full_prompt)
                conv.append_message(conv.roles[1], None)
                # NOTE(review): `prompt` is not passed to get_answer; presumably
                # generate_prompt stores it on the model — confirm.
                prompt = model.generate_prompt(conv.get_prompt())

                raw_predicted, _, _ = model.get_answer(image_bytes)
            except Exception as e:
                # Best-effort per image: record which row failed and move on.
                n_failed += 1
                print(f"Error in {prompt_version} (question_id={question_id}): {str(e)}")
                continue

            results.append({
                "question_id": question_id,
                "prompt_version": prompt_version,
                "full_prompt": prompt,
                "qs_str": raw_predicted,
            })

        if n_failed:
            print(f"{n_failed}/{len(data)} rows failed for {prompt_version}")

        save_dir = os.path.join(args.save_base, model.name, problem_type)
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, filename_suffix)
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        print(f"Saved {prompt_version} results to: {save_path}")

def _load_dataset(args):
    """Load the evaluation data as a single DataFrame.

    MMBench is read from ``args.data_path`` as one parquet file. For the other
    datasets, every ``<args.split>*.parquet`` file in the ``args.data_path``
    directory is read in sorted filename order and concatenated, so that
    seeded sampling downstream is reproducible.

    Raises:
        FileNotFoundError: If no matching parquet file is found.
    """
    if args.dataset_name in ['MMBench']:
        return pd.read_parquet(args.data_path)

    # os.listdir order is filesystem-dependent; sort it so the concatenated row
    # order (and therefore the seeded sample) is stable across machines.
    selected_files = sorted(
        os.path.join(args.data_path, f)
        for f in os.listdir(args.data_path)
        if f.endswith('.parquet') and f.startswith(args.split)
    )
    if not selected_files:
        raise FileNotFoundError(
            f"No '{args.split}*.parquet' files found in {args.data_path}"
        )
    return pd.concat([pd.read_parquet(p) for p in selected_files], ignore_index=True)


def main(args, run_index=None):
    """Run one image-source experiment and append its metrics to a CSV.

    Steps:
    1. Build the model class chosen by ``args.model_path``.
    2. Load the data and draw a seeded random sample of rows.
    3. Generate questions and answers.
    4. For prompt versions "v2" and "v4", compute the unanswerable ratio with
       ``evaluate_uncertainty_ratio``.
    5. Append one row per version to ``<args.save_base>/avg_results.csv``.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` block).
        run_index: Zero-based index of this run. It offsets the sampling seed
            (``41 + run_index``) and is written to the CSV as
            ``run_index + 1``. When None, ``args.run_index`` is used if
            present, otherwise 0.

    Raises:
        FileNotFoundError: If no matching parquet file exists for a
            non-MMBench dataset.
    """
    if run_index is None:
        run_index = getattr(args, "run_index", 0)

    ModelClass = get_model_class(args.model_path)
    model = ModelClass(args)

    data = _load_dataset(args)

    # DataFrame.sample without replacement raises when n exceeds the row
    # count, so clamp to the available rows.
    n_samples = min(args.num_test_samples, len(data))
    if n_samples < args.num_test_samples:
        print(
            f"Requested {args.num_test_samples} samples but only {len(data)} "
            f"rows are available; using {n_samples}."
        )
    data = data.sample(n=n_samples, random_state=41 + run_index).reset_index(drop=True)

    generate_questions(data, model, args)

    img_Answer(data, model, args)

    os.makedirs(args.save_base, exist_ok=True)
    csv_path = os.path.join(args.save_base, "avg_results.csv")

    # NOTE(review): generate_questions in this file only writes v2 questions;
    # "v4" presumably comes from img_Answer / evaluate_uncertainty_ratio — confirm.
    for version in ["v2", "v4"]:
        unans_cnt, total_cnt, ratio = evaluate_uncertainty_ratio(model, version, args)

        # Re-check on every iteration: the header is written only when the
        # file is new or empty, so repeated runs append to one table.
        need_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0

        with open(csv_path, "a", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            if need_header:  # write the header only when the file is empty
                writer.writerow([
                    "model", "version", "problem_type", "run_index",
                    "unanswerable_cnt", "total_cnt", "unanswerable_ratio",
                    "num_test_samples",
                ])

            writer.writerow([
                model.name, version, args.problem_type, run_index + 1,
                unans_cnt, total_cnt, f"{ratio:.4f}", n_samples,
            ])

        print(
            f"\033[1;36m[Run {run_index + 1}] Version {version} → "
            f"Unanswerable ratio:\033[0m "
            f"\033[1;32m{ratio:.2%} "
            f"({unans_cnt}/{total_cnt})\033[0m"
        )


if __name__ == "__main__":
    # Command-line entry point: parse the experiment configuration and run a
    # single experiment via main().
    parser = argparse.ArgumentParser(description="UnSAF framework is consistent across image scource")
    # model_path and data_path are used unconditionally by main(), so they are
    # required rather than silently defaulting to None.
    parser.add_argument('--model_path', type=str, required=True, help="Path to the pre-trained or distilled model file. Specify the location of the model.")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument('--device', type=str, default="cuda:0")
    parser.add_argument('--data_path', type=str, required=True, help="Path to the dataset.")
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--num_test_samples', type=int, default=1)
    parser.add_argument('--prompt_version', type=str, default=None)
    # Conversation template names, e.g. mistral_instruct, chatml_direct,
    # llava_v1, mpt, llava_v0, llava_llama_2, conv_llama_v2_vqa.
    parser.add_argument("--ask_conv_mode", default="conv_llama_v2_vqa", help="The ask prompt of model")
    parser.add_argument("--ans_conv_mode", default="llava_llama_2", help="The answer prompt of model")
    # generate_questions raises on any other value, so reject it at parse time.
    parser.add_argument("--problem_type", default="answerable", choices=["answerable", "unanswerable"], help="The type of questions you want the MllM to generate.")
    parser.add_argument('--save_base', type=str, default="./result/ImageScource", help="Directory path where the results will be saved.")
    parser.add_argument('--dataset_name', type=str, required=True, help="Dataset Name, serve as savedir")
    parser.add_argument('--split', type=str, default="val", help="The Part of Dataset")
    # Zero-based run index: offsets the sampling seed and is recorded in the CSV.
    parser.add_argument('--run_index', type=int, default=0, help="Zero-based index of this run (sampling seed offset).")

    args = parser.parse_args()
    main(args)

