import json
import re

import pandas as pd

def extract_answers_from_json(answers_file_path):
    # 读取答案 JSON 文件
    with open(answers_file_path, 'r') as file:
        data = json.load(file)
    
    # 定义正则表达式，匹配左括号、左方括号或左花括号后跟字母 A 到 Z（不区分大小写）
    pattern = re.compile(r'[\(\[\{]\s*(\d{1,2}|[A-Z])\s*[\)\]\}]|([A-Z])\s*\.', re.IGNORECASE)
    
    extracted_answers = []
    none_count = 0  # 计数 None 的数量
    
    for entry in data:
        answer = entry.get('answer', '')  # 获取 answer 字段
        
        # 查找所有匹配的字符串
        matches = pattern.findall(answer)
        
        if matches:
            # 获取最后一个匹配项，注意 matches[-1] 是个 tuple
            match = matches[-1]
            val = match[0] if match[0] else match[1]  # 选出有效的捕获内容

            if val and val.strip().isdigit():
                num = int(val.strip())
                if 1 <= num <= 26:
                    extracted_answers.append(chr(num + 64))  # 转换为大写字母
                else:
                    extracted_answers.append(None)
                    none_count += 1
            else:
                extracted_answers.append(val.strip().upper())
        else:
            extracted_answers.append(None)
            none_count += 1
    
    return extracted_answers, none_count

def get_correct_answers(correct_file_path):
    """
    从 JSON 文件中读取正确答案列表，并将每个答案转换为大写字母。
    
    参数:
        correct_file_path (str): JSON 文件的路径。
        
    返回:
        List[str]: 包含大写字母的正确答案列表。
    """
    with open(correct_file_path, 'r', encoding='utf-8') as f:
        correct_answers = json.load(f)
    
    # 确保所有答案为大写字母
    correct_answers = [str(ans).strip().upper() for ans in correct_answers]
    
    return correct_answers

def calculate_accuracy(extracted_answers, correct_answers):
    # 计算准确率
    total = len(extracted_answers)
    correct_count = 0
    
    for extracted, correct in zip(extracted_answers, correct_answers):
        if extracted and extracted == correct[-1].upper():  # 统一转换为大写进行比较
            correct_count += 1
    
    accuracy = correct_count / total if total > 0 else 0
    return accuracy


# 调用函数并打印提取的结果、准确率和 None 的数量
answers_file_path = "LLAMA3_MATHQA_EVAL.json"  # 替换为你的答案 JSON 文件路径
correct_file_path = "AI_School_main_vllm/eval/GENERAL_EVAL/SCI-Q/sciq_answer_keys.json"  # 替换为你的正确答案 JSON 文件路径

# 提取答案和正确答案
extracted_answers, none_count = extract_answers_from_json(answers_file_path)
correct_answers = get_correct_answers(correct_file_path)

# 计算准确率
accuracy = calculate_accuracy(extracted_answers, correct_answers)

print(f"Accuracy: {accuracy * 100:.2f}%")
print(f"Number of None entries: {none_count}")
