import json
import os
from utils.ECE import *
from utils.utils import *
from utils.conversation import conv_templates, DEFAULT_IMAGE_TOKEN
from utils.prompt import *
from utils.GPTprompt import *
import os, json, re

# Splits a free-form ``qs_str`` into individual questions. A question starts at
# a numbered line ("1. " / "1) "), a "Question:" / "**Question 2:**" header, a
# '"qs_str": "' line, or a capitalised line ending in "?", and extends lazily
# up to the next line that looks like such a start, or to the end of the text.
# Compiled once at import time instead of once per question item.
_QUESTION_SPLIT_PATTERN = re.compile(
    r"""
    (                                   # -------- begin -------
        (?:^\s*\d+\.\s)                                   # 1.  1. xxx
    | (?:^\s*\d+\)\s)                                   # 2.  1) xxx
    | (?:^\s*(?:\d+\.\s*)?(?:\*\*\s*)?Question           # 3.  Question: / **Question 2:**
        (?:\s+\d+)?(?:\s*\*\*)?\s*:)
    | (?:^\s*"qs_str"\s*:\s*")                          # 4.  "qs_str": "
    | (?:^\s*[A-Z][^\n]*\?\s*$)                         # 5.  end with ?
    )
    .*?
    (?=
        \n(?:\s*\d+\.\s|\s*\d+\)\s|                     # 1. / 1)
            \s*(?:\*\*\s*)?Question(?:\s+\d+)?(?:\s*\*\*)?\s*: |
            \s*"qs_str"\s*: |                            # "qs_str":
            \s*[A-Z][^\n]*\?\s*$ )
    | \Z
    )
    """,
    flags=re.MULTILINE | re.DOTALL | re.VERBOSE,
)


def _build_image_dict(data, dataset_name):
    """Map each row's id (as ``str``) to its raw image bytes.

    MMBench rows are keyed by their ``index`` field, every other dataset by
    ``question_id``; rows lacking that field fall back to the ``iterrows``
    index. Assumes ``data`` is a pandas DataFrame whose ``image`` column holds
    dicts with a ``bytes`` entry — TODO confirm against the data loader.
    """
    id_field = "index" if dataset_name in ["MMBench"] else "question_id"
    return {
        str(row.get(id_field, idx)): row["image"]["bytes"]
        for idx, row in data.iterrows()
    }


def _load_saved_answers(save_path):
    """Return the answer records already saved at ``save_path`` ([] if absent)."""
    if not os.path.exists(save_path):
        return []
    with open(save_path, "r", encoding="utf-8") as f:
        return json.load(f)


def _is_multiple_choice(version):
    """Return True if ``version`` is answered as a multiple-choice (A/B/C/D) set."""
    # NOTE(review): prefix matching also catches "v10".."v29" should such
    # versions exist; kept as-is to preserve current behaviour — confirm intent.
    return version.startswith(("v1", "v2"))


def _answer_instruction(version):
    """Return the instruction appended after each question for ``version``."""
    if _is_multiple_choice(version):
        return (
            'if unsure, reply with "unanswerable" without guessing. '
            "Please answer directly with the option letter (A/B/C/D)."
        )
    return (
        'If unsure, respond with "unanswerable" without guessing. '
        "Please use a single-word or phrase to answer."
    )


def img_Answer(data, model, args):
    """Answer every generated question for each prompt version and save results.

    For each version returned by the problem type's prompt generator (sorted
    by the integer after the first character, so "v2" precedes "v10"), reads
    ``{save_base}/{model.name}/{problem_type}/questions_{version}.json``,
    splits each item's ``qs_str`` into individual questions, asks ``model``
    each question together with the item's image, and writes the records to
    ``answers_{version}.json`` in the same directory.

    Re-running resumes: records already in the answers file (matched by
    ``question_subid``) are kept and are not queried again.

    Args:
        data: table of images iterated with ``iterrows()`` (presumably a
            pandas DataFrame — see ``_build_image_dict``).
        model: exposes ``name``, ``generate_prompt(prompt)`` and
            ``get_answer(image_bytes)``, which returns a 3-tuple whose first
            item is the raw answer text.
        args: must provide ``problem_type``, ``dataset_name``, ``save_base``
            and ``ans_conv_mode``.

    Raises:
        ValueError: if ``args.problem_type`` is neither "answerable" nor
            "unanswerable".
    """
    problem_type = args.problem_type
    if problem_type not in ["answerable", "unanswerable"]:
        raise ValueError("problem_type must be 'answerable' or 'unanswerable'")

    image_dict = _build_image_dict(data, args.dataset_name)

    question_path = f"{args.save_base}/{model.name}"
    if model.name == "GPT4o":
        prompt_generators = {
            "answerable": generate_answerable_GPT_prompts,
            "unanswerable": generate_unanswerable_GPT_prompts,
        }
    else:
        prompt_generators = {
            "answerable": generate_answerable_prompts,
            "unanswerable": generate_unanswerable_prompts,
        }
    prompt_versions = prompt_generators[problem_type]()
    sorted_versions = sorted(prompt_versions, key=lambda x: int(x[1:]))

    save_dir = os.path.join(question_path, problem_type)
    for version in sorted_versions:
        input_file = os.path.join(save_dir, f"questions_{version}.json")
        save_path = os.path.join(save_dir, f"answers_{version}.json")

        # Seed with the records from a previous run: they are skipped below
        # and written back, so resuming never discards earlier answers.
        results = _load_saved_answers(save_path)
        answered_subids = {r.get("question_subid") for r in results}

        try:
            with open(input_file, "r", encoding="utf-8") as f:
                questions = json.load(f)
        except FileNotFoundError:
            print(f"[WARN] {input_file} not found, skip {version}.")
            continue

        multiple_choice = _is_multiple_choice(version)
        instruction = _answer_instruction(version)

        for item in questions:
            question_id = item.get("question_id")
            if question_id is None or str(question_id).strip() == "":
                print("[WARN] No question_id, skip.")
                continue

            question_id_str = str(question_id).strip()
            if question_id_str not in image_dict:
                print(f"[WARN] question_id {question_id_str} not found in image_dict keys.")
                continue
            image_bytes = image_dict[question_id_str]

            qs_str = item.get("qs_str") or ""
            sub_questions = [
                match.group(0).strip()
                for match in _QUESTION_SPLIT_PATTERN.finditer(qs_str)
            ]

            for idx, single_q in enumerate(sub_questions, 1):
                question_subid = f"{question_id}_{idx}"
                if question_subid in answered_subids:
                    continue

                conv = conv_templates[args.ans_conv_mode].copy()
                full_prompt = f"{DEFAULT_IMAGE_TOKEN}\n{single_q}\n{instruction}\n"
                conv.append_message(conv.roles[0], full_prompt)
                conv.append_message(conv.roles[1], None)
                prompt = conv.get_prompt()
                model.generate_prompt(prompt)

                raw_ans, _, _ = model.get_answer(image_bytes)
                if multiple_choice:
                    ans = extract_answer(single_q, raw_ans.strip()) or "unanswerable"
                else:
                    ans = process_open_answer(raw_ans.lower())

                results.append(
                    {
                        "version": version,
                        "question_id": question_id,
                        # 1-based position of this question within the item's qs_str.
                        "question_idx": idx,
                        "question_subid": question_subid,
                        "prompt": prompt,
                        "question": single_q,
                        "raw_answer": raw_ans,
                        "ans": ans,
                    }
                )

        os.makedirs(save_dir, exist_ok=True)
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        print(f"[INFO] Saved {version} {problem_type} results → {save_path}")

def evaluate_uncertainty_ratio(model, version, args):
    """Measure how often the model answered "unanswerable" for one prompt version.

    Reads ``{args.save_base}/{model.name}/{args.problem_type}/answers_{version}.json``
    (the file ``img_Answer`` writes) and counts records whose ``ans`` field,
    after stripping and lower-casing, equals ``"unanswerable"``. Records with a
    missing, null or non-string ``ans`` count toward the total but never as
    unanswerable.

    Args:
        model: object exposing ``name`` (used as a path component).
        version: prompt version string, e.g. ``"v1"``.
        args: must provide ``save_base`` and ``problem_type``.

    Returns:
        Tuple ``(count_unans, count, ratio)``; ``ratio`` is ``0.0`` when the
        file contains no records.

    Raises:
        FileNotFoundError: if the answers file does not exist.
    """
    file_path = os.path.join(
        f"{args.save_base}/{model.name}/{args.problem_type}", f"answers_{version}.json"
    )
    # The answers file is written as UTF-8 with ensure_ascii=False, so decode it
    # explicitly instead of relying on the platform's default encoding.
    with open(file_path, "r", encoding="utf-8") as f:
        records = json.load(f)

    count = len(records)
    count_unans = sum(
        1
        for item in records
        if isinstance(item.get("ans"), str)
        and item["ans"].strip().lower() == "unanswerable"
    )
    ratio = count_unans / count if count else 0.0
    return count_unans, count, ratio
