import sys
import pandas as pd
import numpy as np
from collections import defaultdict
from pathlib import Path

# —— Comparison pairs, kept consistent with the generation script ——
# Each entry is (left answer key, right answer key, column prefix).
# evaluate_one_file() reads the columns "<prefix>_model_verdict" and
# "<prefix>_human" for every entry, and canon_human() maps a human label that
# names the left key to "resp1" and one naming the right key to "resp2".
# NOTE(review): the original comment spoke of 6 pairs, but only 4 are listed
# here — confirm against the generation script whether two were dropped on
# purpose.
PAIR_KEYS = [
    ("student_answer_a", "student_answer_b",
     "student_answer_a_vs_student_answer_b"),
    ("student_answer_a", "model_answer_a",
     "student_answer_a_vs_model_answer_a"),
    ("student_answer_a", "model_answer_b",
     "student_answer_a_vs_model_answer_b"),
    ("model_answer_a",  "model_answer_b",
     "model_answer_a_vs_model_answer_b"),
]

# ———— 归一化工具 ————
def canon_model(txt: str) -> str:
    """Normalize a model verdict to ``resp1`` / ``resp2`` / ``equal`` / ``unknown``.

    The text is lower-cased and then searched for marker phrases in a fixed
    priority order. The first label whose markers appear anywhere in the text
    wins:

    1. "the two answers are the same" markers  -> ``"equal"``
    2. "answer one" markers                    -> ``"resp1"``
    3. "answer two" markers                    -> ``"resp2"``

    Text matching none of them yields ``"unknown"``.
    """
    lowered = txt.lower()
    # Order matters: earlier rows take precedence when several markers occur.
    marker_table = (
        ("equal", ("【两个回答相同】", "两个回答相同")),
        ("resp1", ("【回答一】", "**回答一**")),
        ("resp2", ("【回答二】", "**回答二**")),
    )
    for label, markers in marker_table:
        if any(marker in lowered for marker in markers):
            return label
    return "unknown"

def canon_human(txt: str, k1: str, k2: str) -> str:
    """Normalize a human label to ``resp1`` / ``resp2`` / ``equal`` / ``unknown``.

    Matching is a case-insensitive substring search on *txt*, applied in
    this order (first hit wins):

    * contains ``"equal"`` or ``"same"``       -> ``"equal"``
    * contains the left answer key (*k1*)      -> ``"resp1"``
    * contains the right answer key (*k2*)     -> ``"resp2"``
    * otherwise                                -> ``"unknown"``

    A label mentioning both keys is therefore classified as ``"resp1"``.

    Args:
        txt: Raw human label cell, already converted to ``str``.
        k1: Key of the left-hand answer (e.g. ``"student_answer_a"``). If it
            contains ``" vs "``, only the part before it is used.
        k2: Key of the right-hand answer, handled like *k1*.

    Returns:
        One of ``"resp1"``, ``"resp2"``, ``"equal"``, ``"unknown"``.
    """
    t = txt.lower()
    if "equal" in t or "same" in t:
        return "equal"
    # The label text is lower-cased, so the keys must be too; otherwise a key
    # containing uppercase letters could never match.
    left = k1.split(" vs ")[0].lower()
    right = k2.split(" vs ")[0].lower()
    # An empty key would match every label ("" is in any string); skip it.
    if left and left in t:
        return "resp1"
    if right and right in t:
        return "resp2"
    return "unknown"

# ———— F1 计算函数 ————
def calculate_f1(stats: dict) -> dict:
    """Compute per-class precision / recall / F1 plus macro and micro F1.

    Args:
        stats: Mapping holding integer counts ``tp_<cls>``, ``fp_<cls>`` and
            ``fn_<cls>`` for each class ``resp1``, ``resp2``, ``equal``.
            Missing keys are treated as 0.

    Returns:
        Dict with ``precision_<cls>``, ``recall_<cls>`` and ``f1_<cls>`` for
        every class, followed by ``macro_f1`` and ``micro_f1``. Any ratio
        whose denominator is 0 is reported as 0.0.

    Macro F1 averages only the classes that actually occur in the reference
    labels or the predictions (tp + fp + fn > 0), which matches
    scikit-learn's default label set. Including a class that never appears
    would add a spurious 0.0 and drag the average down.
    """
    classes = ["resp1", "resp2", "equal"]
    metrics = {}

    def _safe_div(num, den):
        # A zero denominator means "no evidence"; report 0.0 instead of raising.
        return num / den if den > 0 else 0.0

    def _harmonic(p, r):
        return 2 * p * r / (p + r) if (p + r) > 0 else 0.0

    total_tp = total_fp = total_fn = 0
    observed_f1 = []  # F1 of classes seen in gold labels or predictions

    for cls in classes:
        tp = stats.get(f"tp_{cls}", 0)
        fp = stats.get(f"fp_{cls}", 0)
        fn = stats.get(f"fn_{cls}", 0)

        precision = _safe_div(tp, tp + fp)
        recall = _safe_div(tp, tp + fn)
        f1 = _harmonic(precision, recall)

        metrics[f"precision_{cls}"] = precision
        metrics[f"recall_{cls}"] = recall
        metrics[f"f1_{cls}"] = f1

        total_tp += tp
        total_fp += fp
        total_fn += fn
        if tp + fp + fn > 0:
            observed_f1.append(f1)

    # Macro: unweighted mean over the classes that occur.
    metrics["macro_f1"] = (
        sum(observed_f1) / len(observed_f1) if observed_f1 else 0.0
    )

    # Micro: pool TP / FP / FN over all classes, then compute a single F1.
    micro_precision = _safe_div(total_tp, total_tp + total_fp)
    micro_recall = _safe_div(total_tp, total_tp + total_fn)
    metrics["micro_f1"] = _harmonic(micro_precision, micro_recall)

    return metrics

# ———— 主评估函数 ————
def evaluate_one_file(path: Path,
                      totals: defaultdict,
                      unknown_cases: list) -> None:
    """Score every model verdict in one workbook against its human label.

    For each entry of ``PAIR_KEYS`` the columns ``<prefix>_model_verdict`` and
    ``<prefix>_human`` are read. Pairs where either column is missing are
    skipped. Results are accumulated in place:

    * ``totals[prefix]`` gains ``total`` / ``correct`` counts and per-class
      ``tp_<cls>`` / ``fp_<cls>`` / ``fn_<cls>`` counts. Rows whose human
      label normalizes to ``"unknown"`` are not scored.
    * ``unknown_cases`` gains ``(file name, Excel row, Excel column)`` for
      every model or human cell that normalizes to ``"unknown"``. Row and
      column are 1-based spreadsheet coordinates.
    """
    label_set = ("resp1", "resp2", "equal")
    sheet = pd.read_excel(path)

    for left_key, right_key, prefix in PAIR_KEYS:
        verdict_col = f"{prefix}_model_verdict"
        human_col = f"{prefix}_human"
        if not (verdict_col in sheet.columns and human_col in sheet.columns):
            continue  # some shards may lack one of the columns

        # pandas column positions are 0-based; Excel columns are 1-based.
        verdict_excel_col = sheet.columns.get_loc(verdict_col) + 1
        human_excel_col = sheet.columns.get_loc(human_col) + 1

        for row_label, record in sheet.iterrows():
            model = canon_model(str(record[verdict_col]))
            human = canon_human(str(record[human_col]), left_key, right_key)
            # +1 for the header row, +1 because the index starts at 0
            # (assumes read_excel's default RangeIndex).
            excel_row = row_label + 2

            if model == "unknown":
                unknown_cases.append((path.name, excel_row, verdict_excel_col))
            if human == "unknown":
                unknown_cases.append((path.name, excel_row, human_excel_col))
                continue  # no human reference -> row is not scored

            bucket = totals[prefix]
            bucket["total"] += 1
            if model == human:
                # Correct prediction: a true positive for the gold class.
                bucket["correct"] += 1
                bucket[f"tp_{human}"] += 1
            else:
                # Missed the gold class; the predicted class (if it is a real
                # label rather than "unknown") gets a false positive.
                bucket[f"fn_{human}"] += 1
                if model in label_set:
                    bucket[f"fp_{model}"] += 1

def main(files):
    """Evaluate every workbook in *files* and print an accuracy / F1 report.

    The report prints, in order:

    1. Per-pair accuracy and F1 metrics. Pairs with no scored rows are skipped.
    2. Pooled accuracy and F1 over all pairs.
    3. Every cell whose verdict or label could not be normalized.
    """
    label_order = ("resp1", "resp2", "equal")
    counter_keys = [f"{kind}_{cls}"
                    for cls in label_order
                    for kind in ("tp", "fp", "fn")]

    # Nested defaultdicts so every counter starts at 0.
    totals = defaultdict(lambda: defaultdict(int))
    unknown_cases = []   # (xlsx file name, Excel row, Excel column)
    for file_path in files:
        evaluate_one_file(Path(file_path), totals, unknown_cases)

    def show_class_scores(scores, bullet):
        # One line per class: F1 with its precision and recall.
        for cls in label_order:
            print(f"{bullet}- {cls}: {scores['f1_' + cls]:.4f} "
                  f"(P={scores['precision_' + cls]:.4f}, "
                  f"R={scores['recall_' + cls]:.4f})")

    pooled_correct = 0
    pooled_total = 0
    pooled_counts = defaultdict(int)

    print("\n=== Per-category metrics ===")
    for _, _, prefix in PAIR_KEYS:
        bucket = totals[prefix]
        n_correct, n_total = bucket["correct"], bucket["total"]
        if not n_total:
            continue

        pooled_correct += n_correct
        pooled_total += n_total
        for key in counter_keys:
            pooled_counts[key] += bucket[key]

        scores = calculate_f1(bucket)
        print(f"{prefix:<40s}:")
        print(f"  Accuracy: {n_correct}/{n_total} = {n_correct / n_total:.2%}")
        print(f"  Macro F1: {scores['macro_f1']:.4f}")
        print(f"  Micro F1: {scores['micro_f1']:.4f}")
        print("  Class F1 scores:")
        show_class_scores(scores, "    ")
        print("")

    print("\n" + "-" * 58)
    if not pooled_total:
        print("No evaluable rows found.")
    else:
        print(f"Overall accuracy: {pooled_correct}/{pooled_total} = "
              f"{pooled_correct / pooled_total:.2%}")
        pooled_scores = calculate_f1(dict(pooled_counts))
        print(f"Overall Macro F1: {pooled_scores['macro_f1']:.4f}")
        print(f"Overall Micro F1: {pooled_scores['micro_f1']:.4f}")
        print("\nClass F1 scores across all categories:")
        show_class_scores(pooled_scores, "  ")

    # Cells whose model verdict or human label could not be normalized.
    if not unknown_cases:
        print("\nNo unknown cases.")
        return
    print("\n=== Unable-to-judge cells ===")
    for file_name, row_no, col_no in unknown_cases:
        print(f"{file_name:<30s}  Row {row_no:<5d}  Col {col_no}")

if __name__ == "__main__":
    # Every command-line argument is treated as an .xlsx results file.
    cli_files = sys.argv[1:]
    if not cli_files:
        sys.exit("用法: python evaluate_accuracy_with_f1.py  results_part1.xlsx  [more.xlsx ...]")
    main(cli_files)