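'''
Ask a model to answer generated questions and measure how often it abstains.

Answer() reads questions from {save_base}/{model.name}/{problem_type}/questions_{version}.json,
asks the model to answer each one (telling it to reply "unanswerable" when it is unsure),
and saves the answers to answers_{version}.json in the same directory.
evaluate_uncertainty() then counts how many of those answers were "unanswerable".
'''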
import json
import os
from utils.ECE import *
from utils.utils import *
from utils.conversation import conv_templates,  DEFAULT_IMAGE_TOKEN
import csv
from utils.prompt import *
from utils.GPTprompt import *
import re
def Answer(model, args):

    problem_type=args.problem_type
    if problem_type not in ["answerable", "unanswerable"]:
        raise ValueError("problem_type must be 'answerable' or 'unanswerable'")

    
    question_path = f'{args.save_base}/{model.name}'
    image_path = args.data_path

    prompt_generators = {
        "answerable": generate_answerable_prompts,
        "unanswerable": generate_unanswerable_prompts
    }
    prompt_versions = prompt_generators[problem_type]()

    sorted_versions = sorted(
        prompt_versions.keys(),
        key=lambda x: int(x[1:]) )


    for version in sorted_versions:
        print("version", version)
        results = []

        input_filename = f"questions_{version}.json"
        output_filename = f"answers_{version}.json"
        input_file = os.path.join(question_path, problem_type, input_filename)
        
        print("input_file",input_file)
        try:
            with open(input_file, 'r', encoding='utf-8') as f:
                answers = json.load(f)
        except FileNotFoundError:
            print(f"Question file {input_file} not found, skipping {version}...")
            continue
        save_dir = os.path.join(question_path, problem_type)
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, output_filename)

        results = []
        existing_question_subids = set()
        if os.path.exists(save_path):
            with open(save_path, "r", encoding="utf-8") as f:
                results = json.load(f)
        existing_question_subids = {r["question_subid"] for r in results}


        for item in answers:
            image = item['image']
            image_id = image.split(".")[0]
            if not (image.lower().endswith('.jpg') or image.lower().endswith('.jpeg')):
                continue

            qs_str = item.get('qs_str', '')  

            pattern = re.compile(
                                r"""
                                (                                   # -------- begin -------
                                    (?:^\s*\d+\.\s)                                   # 1.  1. xxx
                                | (?:^\s*\d+\)\s)                                   # 2.  1) xxx
                                | (?:^\s*(?:\d+\.\s*)?(?:\*\*\s*)?Question           # 3.  Question: / **Question 2:**
                                    (?:\s+\d+)?(?:\s*\*\*)?\s*:)
                                | (?:^\s*"qs_str"\s*:\s*")                          # 4.  "qs_str": "
                                | (?:^\s*[A-Z][^\n]*\?\s*$)                         # 5.  end with ? 
                                )
                                .*?                                                  
                                (?=                                                  
                                    \n(?:\s*\d+\.\s|\s*\d+\)\s|                     # 1. / 1)
                                        \s*(?:\*\*\s*)?Question(?:\s+\d+)?(?:\s*\*\*)?\s*: |
                                        \s*"qs_str"\s*: |                            # "qs_str":
                                        \s*[A-Z][^\n]*\?\s*$ )                       
                                | \Z                                               
                                )
                                """,
                                flags=re.MULTILINE | re.DOTALL | re.VERBOSE,
                            )
            
            question_list = [match.group(0).strip() for match in pattern.finditer(qs_str)]
                

            if len(question_list) == 0: 
                pattern = re.compile(
                    r'(#{2,4}\s+\*{0,2}Question\s+\d+\*{0,2}:?.*?)(?=\n#{2,4}\s+\*{0,2}Question\s+\d+\*{0,2}:?|\Z)',
                    flags=re.MULTILINE | re.DOTALL
                )

                question_list = [match.group(0).strip() for match in pattern.finditer(qs_str)]
            
            if len(question_list) == 0:
                pattern = re.compile(
                    r'(?:^#{2,4}\s+\*{0,2}Question\s+\d+\s*\*{0,2}:).*?(?=\n#{2,4}\s+\*{0,2}Question\s+\d+\s*\*{0,2}:|\Z)',
                    flags=re.MULTILINE | re.DOTALL | re.IGNORECASE
                )
                question_list = [match.group(0).strip() for match in pattern.finditer(qs_str)]
            

            for idx, single_q in enumerate(question_list, start=1):
                single_q = single_q.strip()
                if not single_q:
                    continue
                question_subid = f"{image_id}_{idx}"
                if question_subid in existing_question_subids:
                    continue
                conv = conv_templates[args.ans_conv_mode].copy()
                if version.startswith(('v1', 'v2')):
                    context = 'if unsure, reply with "unanswerable" without guessing. Please answer directly with the option letter (A/B/C/D).'
                else:
                    context = 'If unsure, respond with "unanswerable" without guessing. Please use a single-word or phrase to answer. '

                full_prompt = f"{DEFAULT_IMAGE_TOKEN}\n{single_q}\n{context}\n"

                conv.append_message(conv.roles[0], full_prompt)
                conv.append_message(conv.roles[1], None)
                prompt = conv.get_prompt()
                model.generate_prompt(prompt)
                
                answer, _ , _= model.get_answer(os.path.join(image_path, image))
                print(answer)

                if version.startswith(('v2')):
                    ans = extract_answer_after_answerability(answer.strip()) 
                    ans = process_open_answer(ans)
                else:
                    words = answer.lower()
                    ans = process_open_answer(words)

                record = {
                    'version': version,
                    'image_id': image_id,
                    'question_id': idx,
                    'question_subid': question_subid,
                    'prompt': prompt,
                    'question': single_q,
                    'raw_answer': answer,
                    'ans': ans
                }

                results.append(record)
                with open(save_path, "w", encoding="utf-8") as f:
                    json.dump(results, f, ensure_ascii=False, indent=2)
        
        print(f"Saved streaming results to: {save_path}")

def load_json_or_jsonl(file_path):
    """Load records from either a JSON-array file or a JSONL file.

    The format is detected from the first non-whitespace character. ``[`` means
    a JSON document, parsed as a whole. Anything else is treated as JSONL, one
    JSON value per non-blank line. An empty file yields ``[]``.

    Args:
        file_path: path of the file to read (UTF-8, an optional BOM is ignored).

    Returns:
        The parsed JSON document, or a list with one parsed value per JSONL line.
    """
    # utf-8-sig drops a leading BOM. Otherwise the BOM would hide the opening "["
    # and json.loads would reject the document.
    with open(file_path, "r", encoding="utf-8-sig") as f:
        text = f.read()
    # Skip leading whitespace. A pretty-printed array may start with a newline,
    # and checking only the first raw character would misparse it as JSONL.
    if text.lstrip().startswith("["):
        return json.loads(text)
    # Text-mode reading already normalised line endings to "\n".
    return [json.loads(line) for line in text.split("\n") if line.strip()]

def evaluate_uncertainty(model, version, args):
    """Count the "unanswerable" answers for one prompt version and log them.

    Reads ``{args.save_base}/{model.name}/{args.problem_type}/answers_{version}.json``
    (JSON array or JSONL). It then appends one row to ``{args.save_base}/results.csv``,
    writing the header first when that CSV is new or empty, and prints a summary.

    Args:
        model: object exposing ``name``.
        version: prompt-version key, e.g. ``"v1"``.
        args: namespace providing ``problem_type`` and ``save_base``.
    """
    problem_type = args.problem_type
    answer_path = f"{args.save_base}/{model.name}/{problem_type}"
    os.makedirs(answer_path, exist_ok=True)
    file_path = os.path.join(answer_path, f"answers_{version}.json")

    data = load_json_or_jsonl(file_path)

    count = len(data)
    count_unans = 0
    for item in data:
        ans = item.get('ans')
        # Non-string answers (None, numbers, ...) can never be "unanswerable";
        # checking the type avoids crashing on .strip().
        if isinstance(ans, str) and ans.strip().lower() == "unanswerable":
            count_unans += 1

    ratio = count_unans / count if count else 0.0

    csv_file = f"{args.save_base}/results.csv"
    fieldnames = ['model', 'version', 'problem_type', 'total_answers', 'unanswerable_count', 'unanswerable_ratio']
    file_exists = os.path.isfile(csv_file)

    with open(csv_file, 'a', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        # Also write the header when an existing file is empty.
        if not file_exists or f.tell() == 0:
            writer.writeheader()
        writer.writerow({
            'model': model.name,
            'version': version,
            'problem_type': problem_type,
            'total_answers': count,
            'unanswerable_count': count_unans,
            'unanswerable_ratio': f"{ratio:.2f}",
        })

    print("\n\033[1;36m=== Counting Result ===\033[0m")
    print(f"\033[33m Problem Type\033[0m: {problem_type}")
    print(f"\033[33m Total answers\033[0m: {count}")
    print(f"\033[33m Unanswerable count\033[0m: {count_unans}")
    print(f"\033[33m Unanswerable ratio\033[0m: \033[1;32m{ratio:.2%}\033[0m")