import sys
import pandas as pd
import numpy as np
from collections import defaultdict
from pathlib import Path

# The four answers that get compared against each other, in pairing order.
_ANSWER_NAMES = (
    "student_answer_a",
    "student_answer_b",
    "model_answer_a",
    "model_answer_b",
)

# One entry per unordered pair of answers:
#   (first answer key, second answer key, verdict-column prefix)
# The prefix is the stem of the "<prefix>_model_verdict" and
# "<prefix>_human" spreadsheet columns read by evaluate_one_file.
PAIR_KEYS = [
    (f"{first}_en", f"{second}_en", f"{first}_vs_{second}")
    for pos, first in enumerate(_ANSWER_NAMES)
    for second in _ANSWER_NAMES[pos + 1:]
]

def canon_model(txt: str) -> str:
    """Map a model verdict to ``resp1`` / ``resp2`` / ``equal`` / ``unknown``.

    Matching is case-insensitive and purely substring based:

    * ``equal``   - the text contains "both" together with "equal" or "same";
    * ``resp1``   - otherwise, the text contains "response 1";
    * ``resp2``   - otherwise, the text contains "response 2";
    * ``unknown`` - none of the above.

    Because "response 1" is tested first, a verdict that mentions both
    responses is classified as ``resp1``.
    """
    lowered = txt.lower()

    mentions_tie = "equal" in lowered or "same" in lowered
    if "both" in lowered and mentions_tie:
        return "equal"

    for needle, label in (("response 1", "resp1"), ("response 2", "resp2")):
        if needle in lowered:
            return label

    return "unknown"

def canon_human(txt: str, k1: str, k2: str) -> str:
    """Map a human verdict to ``resp1`` / ``resp2`` / ``equal`` / ``unknown``.

    Matching is case-insensitive and substring based. A text containing
    "equal" or "same" is a tie. Otherwise the verdict is ``resp1`` when it
    names the first answer and ``resp2`` when it names the second, checked
    in that order.

    Args:
        txt: The human verdict text.
        k1: Key of the first answer, e.g. ``"student_answer_a_en"``.
        k2: Key of the second answer.

    Returns:
        One of ``"equal"``, ``"resp1"``, ``"resp2"`` or ``"unknown"``.
    """
    lowered = txt.lower()

    if any(word in lowered for word in ("equal", "same")):
        return "equal"

    for key, label in ((k1, "resp1"), (k2, "resp2")):
        # Reduce the key to the bare answer name: keep what precedes any
        # "_vs_" and cut at the last "_en".
        answer_name = key.split("_vs_")[0].rsplit("_en", 1)[0]
        if answer_name in lowered:
            return label

    return "unknown"

def calculate_f1(stats: dict) -> dict:
    """Compute per-class, macro and micro F1 from confusion counts.

    Args:
        stats: Mapping holding the counts ``tp_<cls>``, ``fp_<cls>`` and
            ``fn_<cls>`` for each ``<cls>`` in ``resp1``, ``resp2``,
            ``equal``.

    Returns:
        A dict with ``precision_<cls>``, ``recall_<cls>`` and ``f1_<cls>``
        for every class, plus ``macro_f1`` (mean of the class F1 scores)
        and ``micro_f1`` (F1 of the pooled counts). Any ratio whose
        denominator is not positive is reported as ``0.0``.
    """

    def ratio(numerator, denominator):
        # Guarded division: an empty denominator yields 0.0, not an error.
        return numerator / denominator if denominator > 0 else 0.0

    labels = ("resp1", "resp2", "equal")
    metrics = {}
    pooled_tp = pooled_fp = pooled_fn = 0

    for label in labels:
        hits = stats[f"tp_{label}"]
        false_alarms = stats[f"fp_{label}"]
        misses = stats[f"fn_{label}"]

        prec = ratio(hits, hits + false_alarms)
        rec = ratio(hits, hits + misses)

        metrics[f"precision_{label}"] = prec
        metrics[f"recall_{label}"] = rec
        metrics[f"f1_{label}"] = ratio(2 * prec * rec, prec + rec)

        pooled_tp += hits
        pooled_fp += false_alarms
        pooled_fn += misses

    metrics["macro_f1"] = np.mean([metrics[f"f1_{label}"] for label in labels])

    micro_prec = ratio(pooled_tp, pooled_tp + pooled_fp)
    micro_rec = ratio(pooled_tp, pooled_tp + pooled_fn)
    metrics["micro_f1"] = ratio(2 * micro_prec * micro_rec,
                                micro_prec + micro_rec)

    return metrics

def evaluate_one_file(path: Path,
                      totals: defaultdict,
                      unknown_cases: list) -> None:
    """Score one workbook and accumulate the counts into ``totals``.

    For every pair in ``PAIR_KEYS`` whose ``<prefix>_model_verdict`` and
    ``<prefix>_human`` columns both exist, each row's two verdicts are
    canonicalised and compared. Rows whose human verdict is ``unknown``
    are not scored.

    Args:
        path: Workbook to read with ``pandas.read_excel``.
        totals: Per-prefix counters, updated in place with the keys
            ``total``, ``correct``, ``tp_<cls>``, ``fp_<cls>`` and
            ``fn_<cls>``.
        unknown_cases: List extended in place with a
            ``(file name, row, column)`` tuple for every verdict cell
            that canonicalised to ``unknown``.
    """
    frame = pd.read_excel(path)
    columns = frame.columns
    scored_labels = ("resp1", "resp2", "equal")

    for key_a, key_b, prefix in PAIR_KEYS:
        model_col = f"{prefix}_model_verdict"
        human_col = f"{prefix}_human"
        if not (model_col in columns and human_col in columns):
            continue

        # Report columns by 1-based position.
        model_col_no = columns.get_loc(model_col) + 1
        human_col_no = columns.get_loc(human_col) + 1

        for row_label, record in frame.iterrows():
            model = canon_model(str(record[model_col]))
            human = canon_human(str(record[human_col]), key_a, key_b)
            # Presumably the spreadsheet row: assumes the default 0-based
            # index and a single header row - confirm against the sheets.
            sheet_row = row_label + 2

            if model == "unknown":
                unknown_cases.append((path.name, sheet_row, model_col_no))
            if human == "unknown":
                unknown_cases.append((path.name, sheet_row, human_col_no))
                # Without a usable reference label the row cannot be scored.
                continue

            bucket = totals[prefix]
            bucket["total"] += 1

            if model == human:
                bucket["correct"] += 1
                bucket[f"tp_{human}"] += 1
            else:
                # The true class was missed ...
                bucket[f"fn_{human}"] += 1
                # ... and, unless the model verdict was unknown, the class
                # it picked instead gets a false positive.
                if model in scored_labels:
                    bucket[f"fp_{model}"] += 1

def _format_class_scores(cls, metrics):
    """Return the ``<cls>: F1 (P=..., R=...)`` report fragment for one class."""
    return (f"{cls}: {metrics['f1_' + cls]:.4f} "
            f"(P={metrics['precision_' + cls]:.4f}, "
            f"R={metrics['recall_' + cls]:.4f})")


def main(files):
    """Evaluate every workbook in ``files`` and print the metric reports.

    Three sections are printed, in order:

    1. Per-pair accuracy, macro/micro F1 and per-class scores for each
       prefix in ``PAIR_KEYS`` that has at least one scored row.
    2. Overall accuracy, macro/micro F1 and per-class scores computed from
       the counts pooled across those pairs (skipped when nothing was
       scored).
    3. The ``(file, row, column)`` of every verdict cell that could not be
       classified, or a note that there were none.

    Args:
        files: Iterable of paths to the Excel workbooks to evaluate.
    """
    totals = defaultdict(lambda: defaultdict(int))
    unknown_cases = []   # (xlsx_file, row, col)

    for fp in files:
        evaluate_one_file(Path(fp), totals, unknown_cases)

    classes = ["resp1", "resp2", "equal"]
    count_keys = [f"{kind}_{cls}" for cls in classes
                  for kind in ("tp", "fp", "fn")]

    grand_correct = grand_total = 0
    grand_stats = defaultdict(int)  # pooled tp/fp/fn counts over all pairs

    print("\n=== Per-category metrics ===")
    for _, _, prefix in PAIR_KEYS:
        stats = totals[prefix]
        c = stats["correct"]
        t = stats["total"]
        if t == 0:
            # Nothing was scored for this pair.
            continue

        grand_correct += c
        grand_total += t
        for key in count_keys:
            grand_stats[key] += stats[key]

        f1_metrics = calculate_f1(stats)

        print(f"{prefix:<40s}:")
        print(f"  Accuracy: {c}/{t} = {c / t:.2%}")
        print(f"  Macro F1: {f1_metrics['macro_f1']:.4f}")
        print(f"  Micro F1: {f1_metrics['micro_f1']:.4f}")
        print("  Class F1 scores:")
        for cls in classes:
            print(f"    - {_format_class_scores(cls, f1_metrics)}")
        print("")

    print("\n" + "-"*58)
    if grand_total:
        print(f"Overall accuracy: {grand_correct}/{grand_total} = {grand_correct/grand_total:.2%}")

        grand_f1 = calculate_f1(grand_stats)
        print(f"Overall Macro F1: {grand_f1['macro_f1']:.4f}")
        print(f"Overall Micro F1: {grand_f1['micro_f1']:.4f}")
        print("\nClass F1 scores across all categories:")
        # Previously this header was printed with nothing underneath it;
        # list the pooled per-class scores it announces.
        for cls in classes:
            print(f"  - {_format_class_scores(cls, grand_f1)}")

    if unknown_cases:
        print("\n=== Unable-to-judge cells ===")
        for fname, r, c in unknown_cases:
            print(f"{fname:<30s}  Row {r:<5d}  Col {c}")
    else:
        print("\nNo unknown cases.")

if __name__ == "__main__":
    # Script entry point: every command-line argument is a workbook path.
    workbook_paths = sys.argv[1:]
    if not workbook_paths:
        # No workbook given: exit with the usage message
        # (the message begins with the Chinese word for "Usage").
        sys.exit("用法: python evaluate_accuracy_with_f1.py  results_part1.xlsx  [more.xlsx ...]")
    main(workbook_paths)