import json
import os
from utils.ECE import *
from utils.utils import *
from utils.conversation import conv_templates, DEFAULT_IMAGE_TOKEN
import csv
from utils.prompt import *
from utils.GPTprompt import *
import re

# Splits a free-form ``qs_str`` blob into individual questions. A question
# starts at any of:
#   1. a numbered line "1. xxx"
#   2. a numbered line "1) xxx"
#   3. a "Question:" / "**Question 2:**" header (optionally numbered)
#   4. a line beginning with the JSON fragment `"qs_str": "`
#   5. a line starting with a capital letter and ending in "?"
# and extends (non-greedily, across lines) up to the next such start or the end
# of the text. Compiled once at import rather than once per question item.
_QUESTION_PATTERN = re.compile(
    r"""
    (                                                   # -------- begin -------
        (?:^\s*\d+\.\s)                                 # 1.  1. xxx
      | (?:^\s*\d+\)\s)                                 # 2.  1) xxx
      | (?:^\s*(?:\d+\.\s*)?(?:\*\*\s*)?Question        # 3.  Question: / **Question 2:**
            (?:\s+\d+)?(?:\s*\*\*)?\s*:)
      | (?:^\s*"qs_str"\s*:\s*")                        # 4.  "qs_str": "
      | (?:^\s*[A-Z][^\n]*\?\s*$)                       # 5.  end with ?
    )
    .*?
    (?=
        \n(?:\s*\d+\.\s|\s*\d+\)\s|                     # 1. / 1)
            \s*(?:\*\*\s*)?Question(?:\s+\d+)?(?:\s*\*\*)?\s*: |
            \s*"qs_str"\s*: |                           # "qs_str":
            \s*[A-Z][^\n]*\?\s*$ )
      | \Z
    )
    """,
    flags=re.MULTILINE | re.DOTALL | re.VERBOSE,
)

# Extra instruction appended to the answer context, chosen by the first key that
# is a substring of ``model.name``. Insertion order matters: "Qwen2VL" must be
# tried before the more general "Qwen".
_INSTRUCTION_MAP = {
    "LLaVA": " Think carefully before answering. Respond only if you are truly confident in your knowledge. If there is any doubt or uncertainty, answer with 'unanswerable'.",
    "Qwen2VL": " Think carefully before answering. Respond only if you are truly confident in your knowledge. If there is any doubt or uncertainty, answer with 'unanswerable'.",
    "Qwen": " Do your best to provide an answer. Trust your instincts and go for it. Only use 'unanswerable' if you're completely uncertain.",
    # mPLUG variant wraps "unanswerable" in double quotes instead of single quotes
    # (the original note concerned mPLUG's handling of ' vs ").
    "mPLUG_7B": ' Do your best to provide an answer. Trust your instincts and go for it. Only use "unanswerable" if you are completely uncertain.',
    "MiniCPM": ' Do your best to provide an answer. Trust your instincts and go for it. Only use "unanswerable" if you are completely uncertain.',
    "InternVL3_8b": " Think carefully before answering. Respond only if you are truly confident in your knowledge. If there is any doubt or uncertainty, answer with 'unanswerable'.",
}


def _version_number(version):
    """Return the integer part of a prompt-version key such as ``"v3"``."""
    return int(version[1:])


def _build_context(version, model_name):
    """Return the answering instructions for one prompt version and model.

    Versions 1 and 2 ask for an option letter (A/B/C/D); every other version
    asks for a short free-form answer. A model-specific instruction from
    ``_INSTRUCTION_MAP`` is appended when one of its keys occurs in
    ``model_name``.
    """
    # Compare the parsed number, not a string prefix: startswith("v1") would
    # also match "v10", "v11", ... and give them the multiple-choice context.
    if _version_number(version) in (1, 2):
        context = 'if unsure, reply with "unanswerable" without guessing. Please answer directly with the option letter (A/B/C/D).'
    else:
        context = 'If unsure, respond with "unanswerable" without guessing. Please use a single-word or phrase to answer. '
    for key, instruction in _INSTRUCTION_MAP.items():
        if key in model_name:
            return context + instruction
    return context


def _split_questions(qs_str):
    """Return every question matched in ``qs_str``, stripped, in source order."""
    return [match.group(0).strip() for match in _QUESTION_PATTERN.finditer(qs_str)]


def _dump_json_atomic(obj, path):
    """Write ``obj`` as JSON to ``path`` via a temp file + rename.

    A crash mid-write leaves the previous checkpoint intact instead of a
    truncated file that would make the next resume fail in ``json.load``.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, path)


def Answer(model, args):
    """Ask ``model`` every generated question and stream the answers to disk.

    For each prompt version produced by the generator for ``args.problem_type``
    (processed in numeric order), reads
    ``./question/<model.name>/<problem_type>/questions_<version>.json``. Versions
    whose question file is missing are skipped. Only ``.jpg``/``.jpeg`` images
    are processed. Each item's ``qs_str`` is split into single questions. Every
    question is sent to the model together with the item's image
    (``<args.data_path>/<image>``). One record per question is appended to
    ``<args.save_base>/<model.name>/<problem_type>/answers_<version>.json``.

    The answers file is rewritten after every record so an interrupted run can
    be resumed: questions whose ``question_subid`` (``<image_id>_<n>``) is
    already present in that file are skipped.

    Args:
        model: model wrapper exposing ``name``, ``generate_prompt(prompt)`` and
            ``get_answer(image_path)``. The first element of the 3-tuple
            returned by ``get_answer`` is taken as the raw answer text.
        args: namespace with ``problem_type`` ("answerable" or
            "unanswerable"), ``data_path``, ``save_base`` and
            ``ans_conv_mode`` (key into ``conv_templates``).

    Raises:
        ValueError: if ``args.problem_type`` is not a supported value.
    """
    problem_type = args.problem_type
    if problem_type not in ("answerable", "unanswerable"):
        raise ValueError("problem_type must be 'answerable' or 'unanswerable'")

    question_path = f'./question/{model.name}'
    image_path = args.data_path

    prompt_generators = {
        "answerable": generate_answerable_prompts,
        "unanswerable": generate_unanswerable_prompts,
    }
    prompt_versions = prompt_generators[problem_type]()
    # Version keys look like "v1", "v2", ...; sort numerically so v10 follows v9.
    sorted_versions = sorted(prompt_versions, key=_version_number)

    save_dir = os.path.join(args.save_base, model.name, problem_type)

    for version in sorted_versions:
        print("version", version)
        input_file = os.path.join(question_path, problem_type, f"questions_{version}.json")
        print("input_file", input_file)
        try:
            with open(input_file, 'r', encoding='utf-8') as f:
                questions = json.load(f)
        except FileNotFoundError:
            print(f"Question file {input_file} not found, skipping {version}...")
            continue

        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, f"answers_{version}.json")

        # Resume support: reload earlier answers and skip their question_subids.
        results = []
        if os.path.exists(save_path):
            with open(save_path, "r", encoding="utf-8") as f:
                results = json.load(f)
        done_subids = {r["question_subid"] for r in results}

        # Both depend only on the version/model, not on the individual question.
        context = _build_context(version, model.name)
        # Presumably v2 prompts ask for an answerability judgement before the
        # answer, hence the dedicated extraction step — TODO confirm.
        is_v2 = _version_number(version) == 2

        for item in questions:
            image = item['image']
            if not image.lower().endswith(('.jpg', '.jpeg')):
                continue
            # splitext keeps dots inside the stem, so "a.1.jpg" and "a.2.jpg"
            # no longer collapse to the same id (which made the second image's
            # questions look already answered).
            image_id = os.path.splitext(image)[0]

            question_list = _split_questions(item.get('qs_str', ''))
            # Index over all matches (blank ones included) so question ids keep
            # the position of the question in qs_str.
            for idx, single_q in enumerate(question_list, start=1):
                if not single_q:
                    continue
                question_subid = f"{image_id}_{idx}"
                if question_subid in done_subids:
                    continue

                conv = conv_templates[args.ans_conv_mode].copy()
                full_prompt = f"{DEFAULT_IMAGE_TOKEN}\n{single_q}\n{context}\n"
                conv.append_message(conv.roles[0], full_prompt)
                conv.append_message(conv.roles[1], None)  # open slot for the reply
                prompt = conv.get_prompt()
                model.generate_prompt(prompt)

                answer, _, _ = model.get_answer(os.path.join(image_path, image))
                print(answer)

                if is_v2:
                    ans = process_open_answer(extract_answer_after_answerability(answer.strip()))
                else:
                    ans = process_open_answer(answer.lower())

                results.append({
                    'version': version,
                    'image_id': image_id,
                    'question_id': idx,
                    'question_subid': question_subid,
                    'prompt': prompt,
                    'question': single_q,
                    'raw_answer': answer,
                    'ans': ans,
                })
                done_subids.add(question_subid)
                # Checkpoint after every answer so a crash loses at most one record.
                _dump_json_atomic(results, save_path)

        print(f"Saved streaming results to: {save_path}")

def load_json_or_jsonl(file_path):
    """Load records from a JSON-array file or a JSON Lines file.

    The format is detected from the first non-whitespace character. ``[``
    means the whole file is one JSON document, returned as parsed. Anything
    else is treated as JSON Lines, with one JSON value per non-blank line.

    A leading UTF-8 byte-order mark and leading whitespace or blank lines are
    ignored. A pretty-printed array that starts with a newline is therefore
    still parsed as one document rather than line by line.

    Args:
        file_path: path of the file to read.

    Returns:
        The parsed JSON array, or a list holding one parsed value per JSONL
        line (``[]`` for an empty file).
    """
    # "utf-8-sig" drops a BOM if present; text mode normalises newlines to "\n".
    with open(file_path, "r", encoding="utf-8-sig") as f:
        text = f.read()
    if text.lstrip().startswith("["):
        # JSON format (json.loads tolerates surrounding whitespace)
        return json.loads(text)
    # JSONL format. Split on "\n" only, not str.splitlines(), so raw U+2028 /
    # U+2029 characters inside JSON strings don't break a record in two.
    return [json.loads(line) for line in text.split("\n") if line.strip()]

def evaluate_uncertainty(model, version, args):
    """Count "unanswerable" replies for one prompt version and record the tally.

    Reads ``<save_base>/<model.name>/<problem_type>/answers_<version>.json``
    (JSON array or JSONL, via ``load_json_or_jsonl``). A record counts as
    unanswerable when its ``ans`` field, stripped and lower-cased, equals
    "unanswerable". Records whose ``ans`` is ``None`` never count. One summary
    row is appended to ``<save_base>/results.csv``; the header is written when
    that file is new or empty. A colored summary is printed to stdout.

    Args:
        model: object exposing ``name``.
        version: prompt-version key, e.g. ``"v1"``.
        args: namespace with ``problem_type`` and ``save_base``.
    """
    problem_type = args.problem_type
    answer_dir = f"{args.save_base}/{model.name}/{problem_type}"
    os.makedirs(answer_dir, exist_ok=True)
    records = load_json_or_jsonl(os.path.join(answer_dir, f"answers_{version}.json"))

    total = 0
    unanswerable = 0
    for record in records:
        reply = record.get('ans', '')
        total += 1
        if reply is not None and reply.strip().lower() == "unanswerable":
            unanswerable += 1

    ratio = unanswerable / total if total else 0.0

    csv_file = f"{args.save_base}/results.csv"
    is_new_file = not os.path.isfile(csv_file)
    columns = ['model', 'version', 'problem_type', 'total_answers', 'unanswerable_count', 'unanswerable_ratio']

    with open(csv_file, 'a', newline='', encoding='utf-8') as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        # Append mode starts at end-of-file, so tell() == 0 means the file is empty.
        if is_new_file or handle.tell() == 0:
            writer.writeheader()
        writer.writerow(dict(zip(columns, (
            model.name,
            version,
            problem_type,
            total,
            unanswerable,
            f"{ratio:.2f}",
        ))))

    print("\n\033[1;36m=== Counting Result ===\033[0m")
    for label, value in (
        ("Problem Type", problem_type),
        ("Total answers", total),
        ("Unanswerable count", unanswerable),
    ):
        print(f"\033[33m {label}\033[0m: {value}")
    print(f"\033[33m Unanswerable ratio\033[0m: \033[1;32m{ratio:.2%}\033[0m")

