#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
compare_judgement.py

统计 LRM “思考阶段” 与 “最终回答” 对人工标签的符合情况，并导出不一致样本。
"""

import json
from pathlib import Path

# ========= Configuration: edit as needed ========= #
# NOTE: each value is an empty ``Path()`` placeholder, which pathlib treats as
# the current directory; set every one to a real file path before running.
GROUND_TRUTH_PATH      = Path()         # input: file whose records carry "id" and "label"
LRM_JUDGEMENT_PATH     = Path()        # input: file whose records carry thinking_result / judge_result
PARTIAL_MISMATCH_PATH  = Path()      # output: samples counted as TWJR or TRJW
FULL_MISMATCH_PATH     = Path()         # output: samples counted as TWJW
# ================================================= #


def load_json(path: Path):
    """Read the UTF-8 text file at ``path`` and return its decoded JSON value.

    Callers in this script iterate over the result, so the file is expected
    to hold a JSON array; that shape is not validated here.
    """
    raw_text = path.read_text(encoding="utf-8")
    return json.loads(raw_text)


def main():
    """Compare the model's thinking-stage and final verdicts with the labels.

    Reads the ground-truth file and the model-judgement file, then places
    every judged record into one of four categories:

    * TRJR - thinking verdict and final verdict both equal the label
    * TWJR - thinking verdict differs, final verdict equals the label
    * TRJW - thinking verdict equals the label, final verdict differs
    * TWJW - both verdicts differ from the label

    The four counts are printed. TWJR/TRJW records are written to
    ``PARTIAL_MISMATCH_PATH`` and TWJW records to ``FULL_MISMATCH_PATH``.
    """
    # Step 1: load both inputs, ground truth first.
    reference_records = load_json(GROUND_TRUTH_PATH)
    model_records = load_json(LRM_JUDGEMENT_PATH)

    # Index the labels by stringified id so that numeric and string ids match.
    gold_by_id = {}
    for reference in reference_records:
        key = str(reference["id"])
        gold_by_id[key] = reference["label"]

    tally = {"TRJR": 0, "TWJR": 0, "TRJW": 0, "TWJW": 0}

    partial_bucket = []  # collects TWJR + TRJW records
    full_bucket = []     # collects TWJW records
    export_bucket = {
        "TWJR": partial_bucket,
        "TRJW": partial_bucket,
        "TWJW": full_bucket,
    }

    # (thinking verdict matches, final verdict matches) -> category name
    category_by_outcome = {
        (True, True): "TRJR",
        (False, True): "TWJR",
        (True, False): "TRJW",
        (False, False): "TWJW",
    }

    # Step 2: classify every judged record.
    for record in model_records:
        _id = str(record["id"])
        gold = gold_by_id.get(_id)

        if gold is None:
            # No usable label for this id: warn and leave it out of the counts.
            print(f"[Warn] id={_id} 在标准答案文件中不存在，跳过。")
            continue

        if gold == "A=B":
            # Rule: a human label of "A=B" is always counted as TRJR,
            # regardless of what the model answered.
            tally["TRJR"] += 1
            continue

        # Accept both key spellings: with an underscore or with a space.
        thinking = record.get("thinking_result") or record.get("thinking result")
        verdict = record.get("judge_result") or record.get("judge result")

        category = category_by_outcome[(thinking == gold, verdict == gold)]
        tally[category] += 1
        if category in export_bucket:
            export_bucket[category].append(record)

    # Step 3: report the counts.
    TRJR, TWJR, TRJW, TWJW = (
        tally["TRJR"],
        tally["TWJR"],
        tally["TRJW"],
        tally["TWJW"],
    )
    print("\n=== 统计结果 ===")
    print(f"TRJR (两次判断都对)        : {TRJR}")
    print(f"TWJR (思考错, 最终对)      : {TWJR}")
    print(f"TRJW (思考对, 最终错)      : {TRJW}")
    print(f"TWJW (两次判断都错)        : {TWJW}")

    # Step 4: write the mismatching samples. Both parent directories are
    # created before either file is opened.
    exports = (
        (PARTIAL_MISMATCH_PATH, partial_bucket),
        (FULL_MISMATCH_PATH, full_bucket),
    )
    for target, _ in exports:
        target.parent.mkdir(parents=True, exist_ok=True)

    for target, samples in exports:
        with target.open("w", encoding="utf-8") as handle:
            json.dump(samples, handle, ensure_ascii=False, indent=2)

    print(f"\n已导出部分不一致样本 → {PARTIAL_MISMATCH_PATH}")
    print(f"已导出全部不一致样本 → {FULL_MISMATCH_PATH}")

if __name__ == "__main__":
    main()
