#!/usr/bin/env python3
import json, os, math
from typing import List, Dict, Tuple
import numpy as np
import pandas as pd
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    classification_report, confusion_matrix,
    precision_recall_curve, average_precision_score
)
import matplotlib.pyplot as plt

# Seeded module-level generator: bootstrap_ci() draws its resample indices from
# it, so the reported CIs are reproducible from run to run. Because the one
# generator is shared, each CI also depends on how many draws earlier calls
# consumed, i.e. on the order in which the CIs are computed.
RNG = np.random.default_rng(42)

def load_preds(path: str) -> pd.DataFrame:
    """Load a JSON-Lines predictions file into a DataFrame.

    Each non-blank line is parsed with ``json.loads``; with one JSON object per
    line, the object keys become the DataFrame columns. Blank or
    whitespace-only lines (e.g. a trailing empty line) are skipped rather than
    being fed to ``json.loads``, which would raise on them.

    Args:
        path: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        A DataFrame with one row per record (empty if the file has no records).

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        rows = [json.loads(line) for line in f if line.strip()]
    return pd.DataFrame(rows)

def summarize_classification(
    df: pd.DataFrame,
    name: str,
    *,
    labels: Tuple[str, ...] = ("FLAGGED", "NOT FLAGGED"),
    positive_label: str = "FLAGGED",
) -> Dict[str, float]:
    """Print and return headline classification metrics for one pipeline.

    Args:
        df: Predictions with ``"true"`` and ``"pred"`` label columns.
        name: Display name used in the printed header.
        labels: Class labels used for the per-class F2 average, the
            classification report and the confusion matrix (keyword-only).
        positive_label: The sensitive class whose recall is reported on its
            own line (keyword-only).

    Returns:
        Dict with ``acc``, ``f1_macro``, ``f2_macro`` and ``rec_flagged``.
        ``rec_flagged`` holds the recall of ``positive_label``; the key name
        is kept for compatibility with existing consumers.

    Note:
        Macro/weighted P/R/F1 are computed over every label sklearn finds in
        ``true``/``pred`` (no ``labels=`` restriction). F2_macro, the report
        and the confusion matrix use ``labels`` only. The two sets differ if
        ``pred`` contains a label outside ``labels``.
    """
    y_true = df["true"].tolist()
    y_pred = df["pred"].tolist()
    labels = list(labels)

    acc = float(accuracy_score(y_true, y_pred))
    prec_w, rec_w, f1_w, _ = precision_recall_fscore_support(
        y_true, y_pred, average="weighted", zero_division=0
    )
    prec_m, rec_m, f1_m, _ = precision_recall_fscore_support(
        y_true, y_pred, average="macro", zero_division=0
    )

    # One-vs-rest F2 per label in `labels`, then an unweighted mean.
    # beta=2 weights recall more heavily than precision.
    f2_per_class = precision_recall_fscore_support(
        y_true, y_pred, labels=labels, beta=2.0, average=None, zero_division=0
    )[2]
    f2_m = float(np.mean(f2_per_class))

    # Recall of the sensitive class alone. Restricting `labels` to that class
    # makes the "macro" average equal to its own recall.
    rec_pos = precision_recall_fscore_support(
        y_true, y_pred, labels=[positive_label], average="macro", zero_division=0
    )[1]

    print(f"\n=== {name} CLASSIFICATION ===")
    print(f"Accuracy: {acc:.4f}")
    print(f"Macro    P:{prec_m:.4f} R:{rec_m:.4f} F1:{f1_m:.4f}  |  F2_macro:{f2_m:.4f}")
    print(f"Weighted P:{prec_w:.4f} R:{rec_w:.4f} F1:{f1_w:.4f}")
    print(f"{positive_label} recall (sensitive class): {rec_pos:.4f}")
    print("\nReport:\n", classification_report(y_true, y_pred, labels=labels, zero_division=0))
    print("Confusion:\n", pd.DataFrame(confusion_matrix(y_true, y_pred, labels=labels), index=labels, columns=labels))
    return {"acc": acc, "f1_macro": float(f1_m), "f2_macro": f2_m, "rec_flagged": float(rec_pos)}

def bootstrap_ci(metric_fn, y_true, y_pred, B=2000, alpha=0.05, *, rng=None) -> Tuple[float, float]:
    """Percentile-bootstrap confidence interval for a paired metric.

    Draws ``B`` resamples of row indices with replacement. The same indices
    are applied to ``y_true`` and ``y_pred``, so pairs stay aligned.
    ``metric_fn(yt, yp)`` is evaluated on each resample, and the function
    returns the ``alpha/2`` and ``1 - alpha/2`` empirical quantiles.

    Args:
        metric_fn: Callable taking two lists and returning a number.
        y_true: Indexable sequence of ground-truth labels.
        y_pred: Indexable sequence of predictions, same length as ``y_true``.
        B: Number of bootstrap resamples (must be >= 1).
        alpha: Two-sided miss rate; ``0.05`` yields a 95% interval.
        rng: Optional ``numpy.random.Generator`` (keyword-only). When omitted,
            the shared module-level ``RNG`` is used, so results then depend on
            how many draws earlier calls consumed.

    Returns:
        ``(lo, hi)`` as Python floats.

    Raises:
        ValueError: if the inputs are empty, differ in length, or ``B < 1``.
    """
    n = len(y_true)
    if n == 0:
        raise ValueError("bootstrap_ci needs at least one sample")
    if len(y_pred) != n:
        raise ValueError(f"length mismatch: len(y_true)={n}, len(y_pred)={len(y_pred)}")
    if B < 1:
        raise ValueError(f"B must be >= 1, got {B}")

    gen = RNG if rng is None else rng
    vals = []
    for _ in range(B):
        idx = gen.integers(0, n, size=n)
        vals.append(metric_fn([y_true[i] for i in idx], [y_pred[i] for i in idx]))
    lo, hi = np.quantile(vals, [alpha / 2, 1 - alpha / 2])
    return float(lo), float(hi)

def mcnemar(a_df: pd.DataFrame, b_df: pd.DataFrame, *, correction: bool = False) -> Tuple[int, int, float]:
    """McNemar's test comparing two classifiers on the same items.

    The frames are inner-joined on ``text_id``, so only items present in both
    are compared. Discordant pairs are then counted against ``true_a`` as the
    reference label.

    Args:
        a_df: Columns ``text_id``, ``true_a``, ``pred_a``.
        b_df: Columns ``text_id``, ``pred_b`` (other columns are carried
            through the join but not used).
        correction: If True, apply Edwards' continuity correction
            ``(|b01 - b10| - 1)^2 / (b01 + b10)`` (keyword-only, default off).

    Returns:
        ``(b01, b10, p)``: the count where A is right and B wrong, the count
        where A is wrong and B right, and the chi-square (1 dof) approximate
        p-value. ``p`` is 1.0 when there are no discordant pairs.

    Raises:
        pandas.errors.MergeError: if ``text_id`` is duplicated in either frame.
            A duplicate would otherwise produce a cross join and inflate the
            counts.
    """
    merged = a_df.merge(b_df, on="text_id", suffixes=("_a", "_b"), validate="one_to_one")
    b01 = 0  # A correct, B wrong
    b10 = 0  # A wrong, B correct
    for t, xa, xb in zip(merged["true_a"].tolist(), merged["pred_a"].tolist(), merged["pred_b"].tolist()):
        a_ok = xa == t
        b_ok = xb == t
        if a_ok and not b_ok:
            b01 += 1
        elif b_ok and not a_ok:
            b10 += 1

    n_discordant = b01 + b10
    if n_discordant == 0:
        return b01, b10, 1.0

    diff = abs(b01 - b10)
    if correction:
        diff = max(diff - 1, 0)
    chi2 = diff ** 2 / n_discordant
    # Survival function of chi-square with 1 degree of freedom:
    # P(X > x) = P(|Z| > sqrt(x)) = erfc(sqrt(x / 2)).
    # Note: exp(-x / 2) is the 2-dof tail and would overstate p here.
    p = math.erfc(math.sqrt(chi2 / 2))
    return b01, b10, p

def load_retrieval(path: str) -> pd.DataFrame:
    """Load a JSON-Lines retrieval-diagnostics file into a DataFrame.

    Each non-blank line is parsed with ``json.loads``; with one JSON object per
    line, the object keys become the DataFrame columns. Blank or
    whitespace-only lines (e.g. a trailing empty line) are skipped rather than
    being fed to ``json.loads``, which would raise on them.

    Args:
        path: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        A DataFrame with one row per record (empty if the file has no records).

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        rows = [json.loads(line) for line in f if line.strip()]
    return pd.DataFrame(rows)

def main():
    """Evaluate pipelines A and B end to end and write reports under ``outputs/``.

    Reads ``outputs/preds_a.jsonl`` and ``outputs/preds_b.jsonl``, then in order:
    prints each pipeline's classification summary; prints 95% bootstrap CIs
    for accuracy and macro F1; prints McNemar discordant counts; prints a
    best-effort retrieval summary from ``outputs/retrieval_{a,b}.jsonl``
    (any error there is reported and skipped); saves one FLAGGED PR plot per
    pipeline under ``outputs/figures/``; and finally writes
    ``outputs/summary.json``.
    """
    tags = ("a", "b")
    os.makedirs("outputs/figures", exist_ok=True)

    preds = {t: load_preds(f"outputs/preds_{t}.jsonl") for t in tags}

    # Per-pipeline summaries, printed for A first and then B.
    summaries = {t: summarize_classification(preds[t], f"Pipeline {t.upper()}") for t in tags}

    # Bootstrap CIs. The call order is acc A, acc B, F1 A, F1 B. Keep it fixed,
    # because bootstrap_ci draws its resamples from the shared module-level RNG.
    pairs = {t: (preds[t]["true"].tolist(), preds[t]["pred"].tolist()) for t in tags}

    def macro_f1(yt, yp):
        return precision_recall_fscore_support(yt, yp, average="macro", zero_division=0)[2]

    acc_ci, f1_ci = {}, {}
    for bucket, metric in ((acc_ci, accuracy_score), (f1_ci, macro_f1)):
        for t in tags:
            bucket[t] = bootstrap_ci(metric, *pairs[t])

    print(f"\nCIs (95%): A acc {acc_ci['a']}, A f1_macro {f1_ci['a']}")
    print(f"CIs (95%): B acc {acc_ci['b']}, B f1_macro {f1_ci['b']}")

    # McNemar on the items the two pipelines share (joined on text_id).
    renamed = {
        t: preds[t].rename(columns={"true": f"true_{t}", "pred": f"pred_{t}"})[
            ["text_id", f"true_{t}", f"pred_{t}"]
        ]
        for t in tags
    }
    b01, b10, p = mcnemar(renamed["a"], renamed["b"])
    print(f"\nMcNemar: b01(A correct,B wrong)={b01}, b10(A wrong,B correct)={b10}, p≈{p:.4f}")

    # Retrieval diagnostics (averages per k). Best effort: both files are loaded
    # before anything is printed, and any failure is reported as a warning.
    wanted = ["hit", "precision", "ndcg", "label_precision", "diversity"]
    try:
        retrieval = [(t.upper(), load_retrieval(f"outputs/retrieval_{t}.jsonl")) for t in tags]
        for label, rdf in retrieval:
            if rdf.empty:
                print(f"\nRetrieval summary (Pipeline {label}): <empty>")
                continue
            present = [c for c in wanted if c in rdf.columns]
            print(f"\nRetrieval summary (Pipeline {label}) — showing columns present: {present}")
            if "k" not in rdf.columns or not present:
                print("(missing 'k' or metrics columns)")
            else:
                print(rdf.groupby("k")[present].mean().round(3))
    except Exception as e:
        print(f"\n[WARN] Retrieval summary skipped due to error: {e}")

    # FLAGGED PR plots. Hard predictions serve as 0/1 scores, so each "curve"
    # has only a few distinct operating points (a discrete-point proxy).
    for t, frame in preds.items():
        plot_name = f"Pipeline{t.upper()}"
        y_bin = np.array([1 if lab == "FLAGGED" else 0 for lab in frame["true"]])
        scores = np.array([1 if guess == "FLAGGED" else 0 for guess in frame["pred"]], dtype=float)
        prec_pts, rec_pts, _ = precision_recall_curve(y_bin, scores)
        ap = average_precision_score(y_bin, scores)
        out_png = f"outputs/figures/pr_{plot_name.lower()}.png"
        plt.figure()
        plt.step(rec_pts, prec_pts, where="post")
        plt.xlabel("Recall")
        plt.ylabel("Precision")
        plt.title(f"PR (FLAGGED) – {plot_name} (AP≈{ap:.3f})")
        plt.grid(True)
        plt.savefig(out_png, dpi=180)
        plt.close()
        print(f"Saved {out_png}")

    report = {
        f"pipeline_{t}": {**summaries[t], "acc_ci": acc_ci[t], "f1_macro_ci": f1_ci[t]}
        for t in tags
    }
    report["mcnemar"] = {"b01": b01, "b10": b10, "p_value_approx": p}
    with open("outputs/summary.json", "w") as fh:
        json.dump(report, fh, indent=2)
    print("\nSaved outputs/summary.json and PR plots.")

# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
