import json
import os
from utils.ECE import *
from utils.utils import *
from utils.conversation import conv_templates, DEFAULT_IMAGE_TOKEN
import csv
from utils.prompt import *
from utils.GPTprompt import *
import re
# Matches one question inside a model-generated ``qs_str`` blob. A question
# starts at a numbered item ("1. " / "1) "), a "Question:" / "**Question 2:**"
# header, a JSON-style ``"qs_str": "`` prefix, or a capitalised line ending in
# "?", and extends lazily up to the next such start (or the end of the text).
# Compiled once at import time instead of once per item inside the loop.
_QUESTION_PATTERN = re.compile(
    r"""
    (                                                   # -------- begin -------
        (?:^\s*\d+\.\s)                                 # 1.  1. xxx
      | (?:^\s*\d+\)\s)                                 # 2.  1) xxx
      | (?:^\s*(?:\d+\.\s*)?(?:\*\*\s*)?Question        # 3.  Question: / **Question 2:**
            (?:\s+\d+)?(?:\s*\*\*)?\s*:)
      | (?:^\s*"qs_str"\s*:\s*")                        # 4.  "qs_str": "
      | (?:^\s*[A-Z][^\n]*\?\s*$)                       # 5.  line ending with ?
    )
    .*?                                                 # lazily take the question body
    (?=                                                 # stop before the next question start
        \n(?:\s*\d+\.\s|\s*\d+\)\s|                     # 1. / 1)
            \s*(?:\*\*\s*)?Question(?:\s+\d+)?(?:\s*\*\*)?\s*: |
            \s*"qs_str"\s*: |                           # "qs_str":
            \s*[A-Z][^\n]*\?\s*$ )
      | \Z                                              # ... or at the end of the text
    )
    """,
    flags=re.MULTILINE | re.DOTALL | re.VERBOSE,
)

# Version numbers whose instruction asks for an option letter (A/B/C/D).
_MULTIPLE_CHOICE_VERSIONS = (1, 2)
# Version number whose raw answer is routed through
# ``extract_answer_after_answerability`` before normalisation.
_ANSWERABILITY_VERSION = 2


def _version_number(version):
    """Return the integer part of a prompt-version key such as ``"v3"`` -> 3."""
    return int(version[1:])


def _split_questions(qs_str):
    """Split a generated question blob into individual, stripped questions."""
    return [match.group(0).strip() for match in _QUESTION_PATTERN.finditer(qs_str)]


def _answer_instruction(version_num):
    """Return the answering instruction appended after each question."""
    if version_num in _MULTIPLE_CHOICE_VERSIONS:
        return 'if unsure, reply with "unanswerable" without guessing. Please answer directly with the option letter (A/B/C/D).'
    return 'If unsure, respond with "unanswerable" without guessing. Please use a single-word or phrase to answer. '


def _ask_model(model, conv_mode, question, instruction, image_file):
    """Build the conversation prompt for one question and query the model.

    Returns:
        A ``(prompt, raw_answer)`` pair, where ``raw_answer`` is the first
        element of the 3-tuple returned by ``model.get_answer``.
    """
    conv = conv_templates[conv_mode].copy()
    full_prompt = f"{DEFAULT_IMAGE_TOKEN}\n{question}\n{instruction}\n"
    conv.append_message(conv.roles[0], full_prompt)
    conv.append_message(conv.roles[1], None)  # empty slot for the model's reply
    prompt = conv.get_prompt()
    # NOTE(review): generate_prompt presumably stores the prompt for the
    # following get_answer call; confirm against the model wrapper.
    model.generate_prompt(prompt)
    answer, _, _ = model.get_answer(image_file)
    return prompt, answer


def _postprocess_answer(answer, version_num):
    """Normalise a raw model answer according to the prompt version."""
    if version_num == _ANSWERABILITY_VERSION:
        return process_open_answer(extract_answer_after_answerability(answer.strip()))
    return process_open_answer(answer.lower())


def tuned_answer(model, args):
    """Answer every generated question for each prompt version and save results.

    Prompt versions (keys such as ``v1``, ``v2``, ...) are processed in numeric
    order. For each version this reads
    ``<save_base>/<model.name>/<problem_type>/questions_<version>.json``,
    splits every item's ``qs_str`` into individual questions, asks ``model``
    each question about the item's image (only ``.jpg``/``.jpeg`` files are
    used), and writes the collected records to ``answers_<version>.json`` in
    the same directory. Versions whose question file is missing are skipped.

    Args:
        model: Model wrapper exposing ``name``, ``generate_prompt(prompt)`` and
            ``get_answer(image_path)`` (which returns a 3-tuple whose first
            element is the answer text).
        args: Namespace providing ``problem_type``, ``save_base``,
            ``data_path`` (image directory) and ``ans_conv_mode``
            (conversation template key).

    Raises:
        ValueError: If ``args.problem_type`` is not "answerable" or
            "unanswerable".
    """
    problem_type = args.problem_type
    if problem_type not in ("answerable", "unanswerable"):
        raise ValueError("problem_type must be 'answerable' or 'unanswerable'")

    question_path = f'{args.save_base}/{model.name}'
    image_path = args.data_path
    save_dir = os.path.join(question_path, problem_type)

    if model.name == "GPT4o":
        prompt_generators = {
            "answerable": generate_answerable_GPT_prompts,
            "unanswerable": generate_unanswerable_GPT_prompts
        }
    else:
        prompt_generators = {
            "answerable": generate_answerable_prompts,
            "unanswerable": generate_unanswerable_prompts
        }
    prompt_versions = prompt_generators[problem_type]()

    for version in sorted(prompt_versions.keys(), key=_version_number):
        # Compare version numbers exactly: a string-prefix test such as
        # startswith('v1') would also match v10, v11, ...
        version_num = _version_number(version)
        input_file = os.path.join(save_dir, f"questions_{version}.json")

        try:
            with open(input_file, 'r', encoding='utf-8') as f:
                questions = json.load(f)
        except FileNotFoundError:
            print(f"Question file {input_file} not found, skipping {version}...")
            continue

        instruction = _answer_instruction(version_num)
        results = []
        for item in questions:
            image = item['image']
            if not image.lower().endswith(('.jpg', '.jpeg')):
                continue
            image_id = image.split(".")[0]
            image_file = os.path.join(image_path, image)

            # NOTE(review): an item whose qs_str matches none of the question
            # starts contributes no questions at all.
            question_list = _split_questions(item.get('qs_str', ''))
            for idx, single_q in enumerate(question_list, start=1):
                if not single_q:
                    continue

                prompt, answer = _ask_model(
                    model, args.ans_conv_mode, single_q, instruction, image_file
                )
                ans = _postprocess_answer(answer, version_num)

                results.append({
                    'version': version,
                    'image_id': image_id,
                    'question_id': idx,
                    'question_subid': f'{image_id}_{idx}',
                    'prompt': prompt,
                    'question': single_q,
                    'raw_answer': answer,
                    'ans': ans
                })
                print(ans)

        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, f"answers_{version}.json")
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        print(f"Saved {version} {problem_type} results to: {save_path}")
                

def evaluate_uncertainty(model, version, args):
    """Count "unanswerable" answers for one prompt version and record the ratio.

    Reads ``<save_base>/<model.name>/<problem_type>/answers_<version>.json``,
    counts the records whose ``ans`` equals "unanswerable" (ignoring case and
    surrounding whitespace), appends one summary row to
    ``<save_base>/results.csv`` (writing the header only when that file is new
    or empty), and prints a colored summary.

    Args:
        model: Object exposing ``name``.
        version: Prompt-version key used in the answers file name (e.g. "v1").
        args: Namespace providing ``save_base`` and ``problem_type``.

    Raises:
        FileNotFoundError: If the answers file does not exist.
    """
    problem_type = args.problem_type
    answer_path = os.path.join(args.save_base, model.name, problem_type)
    file_path = os.path.join(answer_path, f"answers_{version}.json")

    # Answers files are written as UTF-8 with ensure_ascii=False, so read them
    # back as UTF-8 explicitly instead of relying on the platform default.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    count = 0
    count_unans = 0
    for item in data:
        count += 1
        ans = item.get('ans', '')
        if ans is not None and ans.strip().lower() == "unanswerable":
            count_unans += 1

    ratio = count_unans / count if count else 0.0

    csv_file = os.path.join(args.save_base, "results.csv")
    fieldnames = ['model', 'version', 'problem_type', 'total_answers', 'unanswerable_count', 'unanswerable_ratio']
    # Emit the header only when starting a new (or empty) results file.
    needs_header = not os.path.isfile(csv_file) or os.path.getsize(csv_file) == 0

    with open(csv_file, 'a', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if needs_header:
            writer.writeheader()
        writer.writerow({
            'model': model.name,
            'version': version,
            'problem_type': problem_type,
            'total_answers': count,
            'unanswerable_count': count_unans,
            'unanswerable_ratio': f"{ratio:.2f}"
        })

    print("\n\033[1;36m=== Counting Result ===\033[0m")
    print(f"\033[33m Problem Type\033[0m: {problem_type}")
    print(f"\033[33m Total answers\033[0m: {count}")
    print(f"\033[33m Unanswerable count\033[0m: {count_unans}")
    print(f"\033[33m Unanswerable ratio\033[0m: \033[1;32m{ratio:.2%}\033[0m")
    print("save_path", csv_file)



