import json
import os
from utils.ECE import *
from utils.utils import *
from utils.conversation import conv_templates, DEFAULT_IMAGE_TOKEN
import csv
from utils.prompt import *
from utils.GPTprompt import generate_k_answerable_GPT_prompts, generate_k_unanswerable_GPT_prompts
import time
import re
# Splits a block of generated questions on every newline that begins a new
# question. A new question is recognised (at a line start, after optional
# whitespace) by "N. ", "N) ", or a "Question[ N]:" marker that may be
# preceded by "N." and/or wrapped in markdown bold ("**"). The lookahead keeps
# each marker attached to the question that follows it.
_ARS_QUESTION_SPLIT_RE = re.compile(
    r'\n(?=(?:^\s*\d+\.\s)|(?:^\s*\d+\)\s)|(?:^\s*(?:\d+\.\s*)?(?:\*\*\s*)?Question(?:\s+\d+)?(?:\s*\*\*)?\s*:))',
    flags=re.MULTILINE,
)


def _ars_load_json_list(path):
    """Return the JSON content stored at *path*, or an empty list if the file is missing."""
    if not os.path.exists(path):
        return []
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _ars_write_json(path, data):
    """Write *data* to *path* as indented UTF-8 JSON (non-ASCII kept as-is)."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)


def _ars_split_questions(qs_str):
    """Split a generated question block into stripped, non-empty single questions."""
    return [q.strip() for q in _ARS_QUESTION_SPLIT_RE.split(qs_str) if q.strip()]


def _ars_get_answer_with_retry(model, image_file, image_id, max_retries, retry_delay):
    """Call ``model.get_answer`` up to *max_retries* times; return the answer or None.

    A ``None`` or whitespace-only answer counts as a failed attempt. Any exception
    raised by the model is reported and retried (deliberate retry boundary).
    """
    for attempt in range(1, max_retries + 1):
        try:
            answer, _, _ = model.get_answer(image_file)
            if answer is None or not answer.strip():
                raise ValueError("Received None or empty answer")
            return answer
        except Exception as e:
            print(f"Failed to get answer for {image_id} (Attempt {attempt}/{max_retries}): {type(e).__name__}: {e}")
            if attempt < max_retries:
                time.sleep(retry_delay)
    return None


def _ars_postprocess_answer(version_num, answer):
    """Normalise a raw model answer according to the prompt version number."""
    if version_num == 2:
        # v2 responses first go through extract_answer_after_answerability
        # (project helper; presumably isolates the answer from an
        # answerability judgement -- TODO confirm).
        return process_open_answer(extract_answer_after_answerability(answer.strip()))
    return process_open_answer(answer.lower())


def ars_answer(model, args, max_retries_per_question=3, retry_delay=2):
    """Answer the generated questions for every prompt version and save the results.

    For each prompt version ``v<N>`` (processed in numeric order) this reads
    ``{save_base}/{k}_question/{model.name}/{problem_type}/questions_v<N>.json``.
    For every ``.jpg`` item it splits ``qs_str`` into single questions and prompts
    the model with each one plus an answering instruction, retrying failures.
    Answers are appended to ``answers_v<N>.json``; questions that still fail
    after all retries are recorded in ``failed_questions_v<N>.json``. Both files
    are rewritten after every question, so an interrupted run can resume.
    Question ids already present in the answers file are skipped.

    Args:
        model: Object exposing ``name``, ``generate_prompt(prompt)`` and
            ``get_answer(image_path)`` returning a 3-tuple whose first element is
            the answer text.
        args: Namespace providing ``num_generate_questions``, ``problem_type``,
            ``save_base``, ``data_path`` (image directory) and ``ans_conv_mode``.
        max_retries_per_question: Maximum model calls per question.
        retry_delay: Seconds to sleep between attempts.

    Raises:
        ValueError: If ``args.problem_type`` is not "answerable" or "unanswerable".
    """
    k = args.num_generate_questions
    problem_type = args.problem_type
    if problem_type not in ["answerable", "unanswerable"]:
        raise ValueError("problem_type must be 'answerable' or 'unanswerable'")

    question_path = f'{args.save_base}/{k}_question/{model.name}'
    image_path = args.data_path
    save_dir = os.path.join(question_path, problem_type)

    # GPT4o has its own prompt set; build only the set that is actually used.
    if model.name == "GPT4o":
        generators = {
            "answerable": generate_k_answerable_GPT_prompts,
            "unanswerable": generate_k_unanswerable_GPT_prompts,
        }
    else:
        generators = {
            "answerable": generate_k_answerable_prompts,
            "unanswerable": generate_k_unanswerable_prompts,
        }
    prompt_versions = generators[problem_type](k)

    # Version keys look like "v<N>"; sort numerically so that v10 follows v9.
    sorted_versions = sorted(prompt_versions.keys(), key=lambda x: int(x[1:]))

    for version in sorted_versions:
        # Compare the parsed number exactly: a prefix test on "v1"/"v2" would
        # also match v10-v19 / v20-v29.
        version_num = int(version[1:])
        input_file = os.path.join(save_dir, f"questions_{version}.json")
        save_path = os.path.join(save_dir, f"answers_{version}.json")
        failed_path = os.path.join(save_dir, f"failed_questions_{version}.json")
        os.makedirs(save_dir, exist_ok=True)

        # Resume support: reload answers / failures saved by a previous run.
        results = _ars_load_json_list(save_path)
        existing_question_subids = {item['question_subid'] for item in results}
        failed_questions = _ars_load_json_list(failed_path)
        failed_subids = {q.get('question_subid') for q in failed_questions}

        try:
            with open(input_file, 'r', encoding='utf-8') as f:
                questions = json.load(f)
        except FileNotFoundError:
            print(f"Question file {input_file} not found, skipping {version}...")
            continue

        # v1 and v2 are answered with an option letter; other versions with a short phrase.
        if version_num in (1, 2):
            context = 'if unsure, reply with "unanswerable" without guessing. Please answer directly with the option letter (A/B/C/D).'
        else:
            context = 'Please use a single-word or phrase to answer. If unsure, respond with "unanswerable" without guessing.'

        for item in questions:
            image = item['image']
            if not image.endswith('.jpg'):
                continue
            image_id = image.split(".")[0]
            image_file = os.path.join(image_path, image)

            for idx, single_q in enumerate(_ars_split_questions(item.get('qs_str', '')), start=1):
                question_subid = f'{image_id}_{idx}'
                if question_subid in existing_question_subids:
                    continue  # answered in an earlier run

                conv = conv_templates[args.ans_conv_mode].copy()
                full_prompt = f"{DEFAULT_IMAGE_TOKEN}\n{single_q}\n{context}\n"
                conv.append_message(conv.roles[0], full_prompt)
                conv.append_message(conv.roles[1], None)
                prompt = conv.get_prompt()

                model.generate_prompt(prompt)

                answer = _ars_get_answer_with_retry(
                    model, image_file, image_id, max_retries_per_question, retry_delay
                )

                if answer is None:
                    print(f"Giving up on {image_id} after {max_retries_per_question} attempts.")
                    # Record each failing question once, even across resumed runs.
                    if question_subid not in failed_subids:
                        failed_questions.append({
                            'image_id': image_id,
                            'question_subid': question_subid
                        })
                        failed_subids.add(question_subid)
                    _ars_write_json(failed_path, failed_questions)
                    continue

                # A question that failed in an earlier run but succeeded now is no longer failed.
                if question_subid in failed_subids:
                    failed_questions = [
                        q for q in failed_questions if q.get('question_subid') != question_subid
                    ]
                    failed_subids.discard(question_subid)

                results.append({
                    'version': version,
                    'image_id': image_id,
                    'question_subid': question_subid,
                    'prompt': prompt,
                    'question': single_q,
                    'raw_answer': answer,
                    'ans': _ars_postprocess_answer(version_num, answer)
                })
                existing_question_subids.add(question_subid)

                # Persist after every answer so progress survives interruptions.
                _ars_write_json(save_path, results)
                _ars_write_json(failed_path, failed_questions)

        print(f"Saved result to {save_path}")

    print("✅ All versions done!")
                

def ars_evaluate_uncertainty(model, version, args):
    """Count how many saved answers for *version* are exactly "unanswerable".

    Reads ``{save_base}/{k}_question/{model.name}/{problem_type}/answers_{version}.json``.
    An item counts as unanswerable when its ``'ans'`` value, stripped and
    lower-cased, equals "unanswerable" (a missing or null ``'ans'`` counts as "").
    One summary row is appended to ``{save_base}/{k}_question/results.csv``; the
    header is written first if that file is new or empty. The counts are also
    printed.

    Args:
        model: Object exposing ``name``.
        version: Prompt version key, e.g. "v1".
        args: Namespace providing ``problem_type``, ``num_generate_questions``
            and ``save_base``.

    Raises:
        FileNotFoundError: If the answers file does not exist.
    """
    problem_type = args.problem_type
    k = args.num_generate_questions
    answer_path = f"{args.save_base}/{k}_question/{model.name}/{problem_type}"
    file_path = os.path.join(answer_path, f"answers_{version}.json")

    # The answers file is written as UTF-8 with ensure_ascii=False, so read it
    # as UTF-8 explicitly instead of relying on the platform default encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    count = len(data)
    count_unans = sum(
        1 for item in data
        if (item.get('ans') or '').strip().lower() == "unanswerable"
    )
    ratio = count_unans / count if count != 0 else 0.0

    csv_file = f"{args.save_base}/{k}_question/results.csv"
    os.makedirs(os.path.dirname(csv_file), exist_ok=True)
    # Write the header only when the results file is new or still empty.
    needs_header = not os.path.isfile(csv_file) or os.path.getsize(csv_file) == 0

    fieldnames = ['model', 'version', 'problem_type', 'total_answers', 'unanswerable_count', 'unanswerable_ratio', 'k']
    with open(csv_file, 'a', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if needs_header:
            writer.writeheader()
        writer.writerow({
            'model': model.name,
            'version': version,
            'problem_type': problem_type,
            'total_answers': count,
            'unanswerable_count': count_unans,
            'unanswerable_ratio': f"{ratio:.2f}",
            'k': k
        })

    print("\n\033[1;36m=== Counting Result ===\033[0m")
    print(f"\033[33m Problem Type\033[0m: {problem_type}")
    print(f"\033[33m Total answers\033[0m: {count}")
    print(f"\033[33m Unanswerable count\033[0m: {count_unans}")
    print(f"\033[33m Unanswerable ratio\033[0m: \033[1;32m{ratio:.2%}\033[0m")




