import json
from tqdm import tqdm
import base64
from vllm import LLM, SamplingParams
from transformers import AutoProcessor, AutoTokenizer
import time
import argparse, os
import torch._dynamo
import re
from mathruler.grader import grade_answer

# Set TorchDynamo's suppress_errors flag.
# NOTE(review): presumably so that compile failures fall back to eager execution
# instead of aborting the run — confirm against the PyTorch version in use.
torch._dynamo.config.suppress_errors = True

# Command-line options. They are parsed at import time, so `args` is a
# module-level global that the script body below reads.
parser = argparse.ArgumentParser()
# Benchmark name; used to build the input/output paths and to decide whether
# the benchmark is multiple-choice.
parser.add_argument("--benchmark", default='mathverse', type=str)
# Name of the model whose answers are judged; used to build the file names.
parser.add_argument("--model", default='qwen2.5vl-7b', type=str)
args = parser.parse_args()
# Benchmarks treated as multiple-choice: for these, an option-letter comparison
# is attempted when the MathRuler grader does not accept the answer.
mcq_benchmarks = ["mmstar", "hrbench-4k", "hrbench-8k","vstar", "cvbench-2d", "cvbench-3d", "colorbench", "mme-realworld", "mme-realworld-cn"]

def extract_first_option(text):
    """Return the first option letter found in ``text``.

    The patterns are tried in priority order:

    1. a capital letter wrapped in parentheses, e.g. ``(A)``;
    2. a standalone capital letter (not preceded by another letter) that is
       followed by ``.``, ``)``, whitespace, or the end of the text,
       e.g. ``A.``, ``A)``, ``A foo`` or a trailing ``... is B``;
    3. fallback: the first capital letter anywhere in the text.

    Args:
        text: The model answer to scan. ``None`` and ``""`` are accepted.

    Returns:
        The matched capital letter, or ``""`` when ``text`` is empty or
        contains no capital letter.
    """
    if not text:
        return ""

    # Pattern 1: a letter in parentheses, e.g. "(A)".
    match = re.search(r'\(([A-Z])\)', text)
    if match:
        return match.group(1)

    # Pattern 2: a standalone letter. The lookbehind rejects the last letter
    # of an ordinary word or acronym ("USA."), and the `$` alternative accepts
    # a letter that ends the text ("Answer: B"); without it such answers fell
    # through to the fallback and returned the 'A' of "Answer".
    match = re.search(r'(?<![A-Za-z])([A-Z])(?:[.)\s]|$)', text)
    if match:
        return match.group(1)

    # Pattern 3 (last resort): the first capital letter anywhere.
    match = re.search(r'([A-Z])', text)
    if match:
        return match.group(1)

    return ""

def extract_mcq_option(answer):
    """Return the option letter (A-F) that a multiple-choice answer starts with.

    After stripping surrounding whitespace, the answer must begin with optional
    spaces / ``(`` / ``[``, then a single letter ``A``-``F``, and that letter
    must be followed by one of:

    * the end of the string (``"A"``),
    * ``.``, ``)`` or ``]`` (``"A."``, ``"(A)"``, ``"[A]"``),
    * ``:`` or ``-`` followed by whitespace (``"A: foo"``, ``"A- foo"``).

    Ordinary words such as "Any", "Apple" or "Area" are therefore rejected,
    and so is a letter followed only by a space (``"A xxx"``).

    Args:
        answer: The answer to inspect; any non-string or empty value is
            treated as "no option".

    Returns:
        The captured letter, or an empty string when nothing matches.
    """
    # Anything that is not a non-empty string cannot carry an option letter.
    if not (isinstance(answer, str) and answer):
        return ''

    option_re = r'^[ (\[]*([A-F])(?:(?=$)|[\.\)\]]|(?:[\:\-]\s+))'
    found = re.match(option_re, answer.strip())
    return found.group(1) if found else ""
  
def first_letter_match(gt, answer):
    """Tell whether the predicted option letter equals the ground-truth one.

    The ground truth is parsed with ``extract_mcq_option`` and the model
    answer with ``extract_first_option``.

    Args:
        gt: Ground-truth answer.
        answer: Model answer text.

    Returns:
        ``True`` only when ``gt`` is truthy, a letter was found in ``answer``,
        and both extracted letters are equal; ``False`` otherwise.
    """
    expected = extract_mcq_option(gt)
    predicted = extract_first_option(answer)
    return bool(gt and predicted and expected == predicted)

# Checkpoint of the LLM used as a judge for answers the rule-based checks reject.
JUDGE_MODEL_PATH = '/r-contentsecurity/share/checkpoints/opensources/Qwen3-30B-A3B-Instruct-2507'

# Prompt sent to the judge model; filled with the question, ground truth and
# extracted model answer.
JUDGE_PROMPT_TEMPLATE = "Your task is to judge whether the response expresses the same meaning as the answer of a question.\nThe question is: {question}\nThe answer is: {gt}\nThe response is: {response}\nPlease check and compare them and then judge. If the response is correct, your output should be Yes. Otherwise, your output should be No. Directly give me your output."


def load_jsonl(path):
    """Read a JSON-lines file and return its records as a list.

    Blank lines are skipped instead of crashing ``json.loads``.
    """
    records = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                records.append(json.loads(line))
    return records


def extract_answer_text(model_answer_raw):
    """Pull the final-answer portion out of a raw model response.

    * If an ``<answer>`` tag is present, return the text between the first
      ``<answer>`` and the next ``</answer>``; when the closing tag is missing
      the rest of the response is used.
    * Otherwise, if ``Answer:`` is present, return the text after its first
      occurrence. The marker itself is dropped so that its capital "A" cannot
      be mistaken for option A by the first-letter matcher.
    * Otherwise return the last three lines of the response.

    A non-string input is converted with ``str`` (``None`` becomes ``""``).
    """
    if not isinstance(model_answer_raw, str):
        model_answer_raw = '' if model_answer_raw is None else str(model_answer_raw)

    if '<answer>' in model_answer_raw:
        after_open = model_answer_raw.split('<answer>', 1)[1]
        # split() keeps the whole remainder when '</answer>' is absent, whereas
        # slicing with find() == -1 silently dropped the last character.
        inner = after_open.split('</answer>', 1)[0]
        return inner.replace('<answer>', '')
    if 'Answer:' in model_answer_raw:
        return model_answer_raw.split('Answer:', 1)[1]
    return '\n'.join(model_answer_raw.split('\n')[-3:])


def is_yes_verdict(judge):
    """Return True when a stored judge string counts as a correct verdict.

    A leading "yes"/"no" (any case, optionally preceded by punctuation or
    markup) decides the verdict. If the text starts with neither, fall back to
    a plain 'Yes'/'yes' substring check.
    """
    if not isinstance(judge, str):
        return False
    leading = re.match(r'\W*(yes|no)\b', judge, re.IGNORECASE)
    if leading:
        # Prevents "No, ... yes ..." from counting as correct and accepts "YES".
        return leading.group(1).lower() == 'yes'
    return 'Yes' in judge or 'yes' in judge


def main():
    """Judge every model answer of one benchmark and write the results.

    Reads ``model_answer/<benchmark>/<model>_answer.json`` (JSON lines), grades
    each record with MathRuler, then (for multiple-choice benchmarks) with a
    first-letter comparison, and finally sends the remaining records to the LLM
    judge. Each record gains ``extracted_answer``, ``judge`` and
    ``judge_source``; the list is written to
    ``judge/<benchmark>/<model>_answer.json``.
    """
    answer_path = f"model_answer/{args.benchmark}/{args.model}_answer.json"
    save_path = f"judge/{args.benchmark}/{args.model}_answer.json"
    os.makedirs(f"judge/{args.benchmark}", exist_ok=True)
    is_mcq = args.benchmark in mcq_benchmarks

    data_list = load_jsonl(answer_path)

    # Step 1: rule-based grading; collect what still needs the LLM judge.
    to_llm_indices = []    # indices into data_list that need the LLM
    pending_messages = []  # chat messages for those records, same order

    print("Step 1: Running MathRuler Grader...")
    for i, item in enumerate(tqdm(data_list)):
        question = item['query'].replace('<image>', '')
        extracted_answer = extract_answer_text(item['model_answer'])
        gt = item['response']
        item['extracted_answer'] = extracted_answer  # keep it for inspection

        # Best-effort: a grader crash is treated as "not correct".
        # NOTE(review): MathRuler documents grade_answer(given_answer,
        # ground_truth); here the ground truth is passed first — confirm.
        try:
            is_correct = grade_answer(gt, extracted_answer)
        except Exception as e:
            print(f"Grader error at index {i}: {e}")
            is_correct = False

        # Multiple-choice benchmarks get a second chance via option letters.
        is_letter_correct = False
        if not is_correct and is_mcq:
            try:
                is_letter_correct = first_letter_match(gt, extracted_answer)
            except Exception as e:
                print(f"Grader error at index {i}: {e}")
                is_letter_correct = False

        if is_correct:
            item['judge'] = 'Yes'
            item['judge_source'] = 'mathruler'
        elif is_letter_correct:
            item['judge'] = 'Yes'
            item['judge_source'] = 'first letter'
        else:
            content = JUDGE_PROMPT_TEMPLATE.format(gt=gt, response=extracted_answer, question=question)
            to_llm_indices.append(i)
            pending_messages.append([{"role": "user", "content": content}])

    # Step 2: run the LLM judge on the leftovers. The tokenizer and the model
    # are loaded only here, so a run with nothing left to judge never pays
    # for loading the checkpoint.
    if pending_messages:
        print(f"Step 2: Calling LLM for {len(pending_messages)} remaining cases...")
        tokenizer = AutoTokenizer.from_pretrained(JUDGE_MODEL_PATH)
        prompt_lists = [
            tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            for messages in pending_messages
        ]
        sampling_params = SamplingParams(max_tokens=2048, temperature=0)
        llm = LLM(model=JUDGE_MODEL_PATH, tensor_parallel_size=1, dtype=torch.bfloat16, gpu_memory_utilization=0.9)

        # One batched call; outputs are paired with the prompts by position.
        outputs = llm.generate(prompt_lists, sampling_params)
        for original_idx, output in zip(to_llm_indices, outputs):
            data_list[original_idx]['judge'] = output.outputs[0].text.strip()
            data_list[original_idx]['judge_source'] = 'llm'

    # Step 3: statistics and saving.
    correct_num = sum(1 for item in data_list if is_yes_verdict(item.get('judge')))
    # Guard against an empty input file.
    accuracy = correct_num / len(data_list) if data_list else 0.0

    print(f"Final Accuracy: {accuracy:.4f}")
    print(f"Total: {len(data_list)}, LLM used: {len(pending_messages)}")

    with open(save_path, 'w', encoding='utf-8') as out_file:
        json.dump(data_list, out_file, ensure_ascii=False, indent=4)


if __name__ == '__main__':
    main()