import json
import re
from datasets import load_dataset, concatenate_datasets
def _choice_from_token(token):
    """Map one regex capture (a letter or a 1-2 digit number) to 'A'-'Z', or None."""
    if token.isdigit():
        number = int(token)
        # 1 -> 'A', ..., 26 -> 'Z'; anything else is not a valid option index.
        if 1 <= number <= 26:
            return chr(ord('A') + number - 1)
        return None
    return token.upper()


def extract_answers_from_json(answers_file_path, student_field):
    """Extract the final multiple-choice letter from each entry of an answers file.

    The JSON file must contain a list of entries. For each entry the text at
    ``entry[student_field]['answer']`` is scanned for an opening bracket
    (``(``, ``[`` or ``{``) followed by either a letter or a 1-2 digit number.
    The last such occurrence decides the answer; numbers 1-26 are mapped to
    ``A``-``Z``.

    Exactly one item is produced per entry so the result stays aligned with the
    reference answers. ``None`` is produced when the student field or answer
    text is missing or malformed, when nothing matches, or when the last
    matched number is outside 1-26.

    Args:
        answers_file_path: Path to the JSON answers file.
        student_field: Key of the per-student sub-dictionary in each entry.

    Returns:
        A tuple ``(extracted_answers, none_count)``: the list of upper-case
        letters / ``None`` values, and how many of them are ``None``.
    """
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(answers_file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    # Opening bracket, optional whitespace, then a 1-2 digit number or a letter
    # (case-insensitive).
    pattern = re.compile(r'[\(\[\{]\s*(\d{1,2}|[A-Z])', re.IGNORECASE)

    extracted_answers = []
    none_count = 0

    for entry in data:
        student = entry.get(student_field) if isinstance(entry, dict) else None
        text = student.get('answer') if isinstance(student, dict) else None
        matches = pattern.findall(text) if isinstance(text, str) else []

        # The last bracketed candidate is taken as the final answer.
        choice = _choice_from_token(matches[-1]) if matches else None
        if choice is None:
            none_count += 1
        extracted_answers.append(choice)

    return extracted_answers, none_count

def get_correct_answers(correct_file_path):
    """Load the reference answers for every listed subject as upper-case letters.

    For each subject name below, the ``test`` split is requested from
    ``load_dataset`` with ``correct_file_path`` as the dataset path. Subjects
    that fail to load are reported on stdout and skipped. The loaded splits
    are concatenated in list order and each row's ``answer`` is normalised:

    * ``int`` in 0-25 -> the letter ``A``-``Z``;
    * ``int`` outside that range -> kept unchanged, with a warning printed;
    * ``str`` -> stripped and upper-cased;
    * any other type -> kept unchanged, with a warning printed.

    Args:
        correct_file_path: Dataset path handed to ``load_dataset``.

    Returns:
        The list of normalised answers, or ``[]`` if no subject could be loaded.
    """
    subject_names = (
        'abstract_algebra',
        'anatomy',
        'astronomy',
        'business_ethics',
        'clinical_knowledge',
        'college_biology',
        'college_chemistry',
        'college_computer_science',
        'college_mathematics',
        'college_medicine',
        'college_physics',
        'computer_security',
        'conceptual_physics',
        'econometrics',
        'electrical_engineering',
        'elementary_mathematics',
        'formal_logic',
        'global_facts',
        'high_school_biology',
        'high_school_chemistry',
        'high_school_computer_science',
        'high_school_european_history',
        'high_school_geography',
        'high_school_government_and_politics',
        'high_school_macroeconomics',
        'high_school_mathematics',
        'high_school_microeconomics',
        'high_school_physics',
        'high_school_psychology',
        'high_school_statistics',
        'high_school_us_history',
        'high_school_world_history',
        'human_aging',
        'human_sexuality',
        'international_law',
        'jurisprudence',
        'logical_fallacies',
        'machine_learning',
        'management',
        'marketing',
        'medical_genetics',
        'miscellaneous',
        'moral_disputes',
        'moral_scenarios',
        'nutrition',
        'philosophy',
        'prehistory',
        'professional_accounting',
        'professional_law',
        'professional_medicine',
        'professional_psychology',
        'public_relations',
        'security_studies',
        'sociology',
        'us_foreign_policy',
        'virology',
        'world_religions',
    )

    # Best-effort loading: a subject that cannot be loaded is reported and skipped.
    loaded_splits = []
    for subject in subject_names:
        try:
            loaded_splits.append(
                load_dataset(path=correct_file_path, name=subject, split='test')
            )
        except Exception as err:
            print(f"加载子类别 {subject} 时出错: {err}")

    if len(loaded_splits) == 0:
        print("未能加载任何子类别的测试集。请检查数据目录和子类别名称是否正确。")
        return []

    # Merge the test splits of all subjects into a single dataset.
    merged = concatenate_datasets(loaded_splits)

    normalized = []
    for row in merged:
        label = row.get('answer', '')
        if isinstance(label, str):
            label = label.strip().upper()
        elif not isinstance(label, int):
            print(f"Warning: 无法识别的 answer 类型：{type(label)}")
        elif label < 0 or label > 25:
            print(f"Warning: 数值 {label} 超出有效范围（0-25），无法转换为字母。")
        else:
            # Zero-based option index -> letter.
            label = chr(label + ord('A'))
        normalized.append(label)

    return normalized

def calculate_accuracy(extracted_answers, correct_answers):
    """Return the fraction of extracted answers that match the reference answers.

    An extracted answer counts as correct when it is non-empty and equals the
    upper-cased last character of the corresponding reference answer. A
    reference that is empty or not a string can never match (instead of
    raising), so one bad reference row does not abort the whole evaluation.

    Args:
        extracted_answers: Upper-case letters or ``None``, one per question.
        correct_answers: Reference answers, paired positionally with
            ``extracted_answers``.

    Returns:
        ``correct_count / len(extracted_answers)``, or ``0`` when
        ``extracted_answers`` is empty.
    """
    total = len(extracted_answers)
    if total == 0:
        return 0

    correct_count = sum(
        1
        for extracted, correct in zip(extracted_answers, correct_answers)
        if extracted
        and isinstance(correct, str)
        and correct
        # Compare against the last character, upper-cased for consistency.
        and extracted == correct[-1].upper()
    )
    return correct_count / total

def compare_student_answers(studentA_answers, studentC_answers, correct_answers):
    """Count the questions on which exactly one of the two students is correct.

    A student's answer is correct when it is non-empty and equals the
    upper-cased last character of the reference answer. A reference that is
    empty or not a string matches nobody (instead of raising).

    Args:
        studentA_answers: StudentA's letters or ``None``, one per question.
        studentC_answers: StudentC's letters or ``None``, one per question.
        correct_answers: Reference answers, paired positionally with both lists.

    Returns:
        A tuple ``(A_correct_C_wrong, A_wrong_C_correct)``.
    """
    A_correct_C_wrong = 0
    A_wrong_C_correct = 0

    for a_ans, c_ans, correct in zip(studentA_answers, studentC_answers, correct_answers):
        # Unusable references (empty string, non-string) have no expected letter.
        expected = correct[-1].upper() if isinstance(correct, str) and correct else None
        a_correct = bool(a_ans) and a_ans == expected
        c_correct = bool(c_ans) and c_ans == expected

        if a_correct and not c_correct:
            A_correct_C_wrong += 1
        elif c_correct and not a_correct:
            A_wrong_C_correct += 1

    return A_correct_C_wrong, A_wrong_C_correct

# Driver: extract both students' answers, score them against the reference
# answers, and print the accuracies, the None counts and the disagreement counts.
answers_file_path = "LLAMA3_MATHQA_EVAL.json"  # replace with the path to your answers JSON file
correct_file_path = "AI_School_main_vllm/data/eval_datasets/mmlu"  # replace with your reference dataset path (passed to load_dataset)

# Extract each student's answers from the same file, then load the reference answers.
studentA_answers, none_count_A = extract_answers_from_json(answers_file_path, '<StudentA>')
studentC_answers, none_count_C = extract_answers_from_json(answers_file_path, '<StudentC>')
correct_answers = get_correct_answers(correct_file_path)

# Compute each student's accuracy.
accuracy_A = calculate_accuracy(studentA_answers, correct_answers)
accuracy_C = calculate_accuracy(studentC_answers, correct_answers)

# Count the questions where only one of the two students is correct.
A_correct_C_wrong, A_wrong_C_correct = compare_student_answers(studentA_answers, studentC_answers, correct_answers)

print(f"Accuracy of StudentA: {accuracy_A * 100:.2f}%")
print(f"Accuracy of StudentC: {accuracy_C * 100:.2f}%")
print(f"Number of None entries for StudentA: {none_count_A}")
print(f"Number of None entries for StudentC: {none_count_C}")
print(f"Number of questions where StudentA is correct and StudentC is wrong: {A_correct_C_wrong}")
print(f"Number of questions where StudentA is wrong and StudentC is correct: {A_wrong_C_correct}")
