import json


# Input files: the question bank and one model run's prediction file.
PROBLEMS_PATH = "/home/test/yxl/MCoT/data/scienceqa/problems.json"
RESULTS_PATH = "/home/test/yxl/MCoT/sqa/results/mistral-small3.1_test/exp1_test_QCM-ALE_seed_4_SC.json"

# Grade band being evaluated. To evaluate the lower band instead, use
# "grade1" through "grade6" here.
GRADES = frozenset(
    {"grade7", "grade8", "grade9", "grade10", "grade11", "grade12"}
)


def load_json(path):
    """Parse and return the JSON document stored at *path*.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the locale's default encoding.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def select_test_questions(questions, grades):
    """Return the subset of *questions* in the test split whose grade is in *grades*.

    *questions* maps a question id to a dict of details. Entries lacking a
    "grade" or "split" key are simply excluded. To slice by subject instead,
    filter on ``details.get("subject")`` (e.g. "natural science",
    "social science", "language science").
    """
    return {
        qid: details
        for qid, details in questions.items()
        if details.get("grade") in grades and details.get("split") == "test"
    }


def answers_match(predicted, expected):
    """Return True when *predicted* denotes the same answer as *expected*.

    Values that already compare equal match. In addition, a prediction stored
    as a numeric string (e.g. "2") is converted to an integer before the
    comparison, so it matches the integer answer 2.
    """
    if predicted == expected:
        return True
    if isinstance(predicted, str):
        try:
            return int(predicted) == expected
        except ValueError:
            # Non-numeric text can never denote an integer answer.
            return False
    return False


def score_predictions(selected, results):
    """Score *results* against the true answers in *selected*.

    Args:
        selected: mapping of question id -> details dict with an "answer" key.
        results: mapping of question id -> predicted answer. Keys must be of
            the same type as the keys of *selected*.

    Returns:
        A ``(total, correct, missing)`` tuple, where *missing* is the number
        of selected questions with no entry in *results*. Missing questions
        still count toward *total* (i.e. they are scored as wrong).

    Raises:
        KeyError: if a selected question has no "answer" field.
    """
    correct = 0
    missing = 0
    for qid, details in selected.items():
        true_answer = details["answer"]
        if qid not in results:
            # Track absent predictions separately so a key mismatch between
            # the two files is visible instead of looking like wrong answers.
            missing += 1
            continue
        if answers_match(results[qid], true_answer):
            correct += 1
    return len(selected), correct, missing


def main():
    """Load both files, score the selected questions, and print the summary."""
    import sys

    questions = load_json(PROBLEMS_PATH)
    results = load_json(RESULTS_PATH)["results"]

    selected = select_test_questions(questions, GRADES)
    total, correct, missing = score_predictions(selected, results)

    # Guard against an empty selection to avoid dividing by zero.
    accuracy = (correct / total) * 100 if total > 0 else 0

    if missing:
        print(
            f"warning: {missing} of {total} selected questions have no prediction",
            file=sys.stderr,
        )

    # NOTE(review): the first label reads "natural-science test questions",
    # but the active filter is by grade band; the text is kept unchanged so
    # the script's stdout stays the same.
    print(f"自然科测试题数: {total}")
    print(f"正确数: {correct}")
    print(f"准确率: {accuracy:.2f}%")


if __name__ == "__main__":
    main()
