#!/usr/bin/env python3
import argparse, json, sys
from statistics import mean

def load_preds(p):
    """Load a predictions JSONL file and normalise each row's ``p_flagged``.

    Blank lines (e.g. a trailing newline at end of file) are skipped. Rows
    that lack ``p_flagged`` -- or carry ``"p_flagged": null`` -- get a hard
    0/1 probability derived from their ``pred`` label: 1.0 when ``pred`` is
    "FLAGGED", otherwise 0.0. Present values are coerced with ``float()``.

    Args:
        p: path to a UTF-8 JSONL file with one JSON object per line.

    Returns:
        List of row dicts in file order, each with a float ``p_flagged``.

    Raises:
        ValueError: if a line is not valid JSON or not a JSON object (the
            message includes the path and 1-based line number), or if
            ``p_flagged`` is a string that ``float()`` cannot parse.
        TypeError: if ``p_flagged`` has a type ``float()`` cannot convert.
        OSError: if the file cannot be opened.
    """
    rows = []
    with open(p, "r", encoding="utf-8") as f:
        for lineno, ln in enumerate(f, start=1):
            if not ln.strip():
                # json.loads("\n") would raise; blank lines carry no row.
                continue
            try:
                r = json.loads(ln)
            except json.JSONDecodeError as err:
                raise ValueError(f"{p}:{lineno}: invalid JSON ({err.msg})") from err
            if not isinstance(r, dict):
                raise ValueError(f"{p}:{lineno}: expected a JSON object, got {type(r).__name__}")
            p_flag = r.get("p_flagged")
            if p_flag is None:
                # Fall back from the discrete label to a hard probability.
                r["p_flagged"] = 1.0 if r.get("pred") == "FLAGGED" else 0.0
            else:
                r["p_flagged"] = float(p_flag)
            rows.append(r)
    return rows

def _index_by_id(rows, label):
    """Map ``text_id`` -> row, warning on stderr when ids repeat (the last row wins)."""
    index = {}
    for r in rows:
        index[r["text_id"]] = r
    n_dupes = len(rows) - len(index)
    if n_dupes:
        print(f"[WARN] {label}: {n_dupes} duplicate text_id row(s); keeping the last occurrence",
              file=sys.stderr)
    return index


def _c_label(rc, pc, thresh):
    """Return classifier C's discrete label, deriving it from ``pc`` if ``pred`` is absent."""
    pred = rc.get("pred")
    if pred is None:
        # load_preds() accepts rows that have p_flagged but no pred; instead of
        # failing with KeyError, threshold C's own probability at --thresh.
        pred = "FLAGGED" if pc >= thresh else "NOT FLAGGED"
    return pred


def _summarize(out_rows):
    """Return ``(n_labeled, accuracy, FLAGGED recall)`` over gold-labeled rows, or None."""
    pairs = [
        (r["true"] == "FLAGGED", r["pred"] == "FLAGGED")
        for r in out_rows
        if r.get("true") in ("FLAGGED", "NOT FLAGGED")
    ]
    if not pairs:
        return None
    acc = sum(y == p for y, p in pairs) / len(pairs)
    n_pos = sum(y for y, _ in pairs)
    rec = sum(y and p for y, p in pairs) / max(1, n_pos)
    return len(pairs), acc, rec


def main():
    """Blend B+ and classifier-C predictions into one ensemble JSONL file.

    For every ``text_id`` present in both inputs the blended probability is
    ``p_ens = alpha * p_c + (1 - alpha) * p_bplus``. Unless ``--no-veto`` is
    given, rows where ``|p_c - p_bplus| <= veto_range`` keep classifier C's
    discrete prediction and its probability. All other rows are decided by
    ``p_ens >= thresh``. Writes one JSON object per row to ``--out``, creating
    parent directories as needed. Then prints a one-line summary that includes
    accuracy and FLAGGED recall over the rows whose ``true`` label is
    "FLAGGED" or "NOT FLAGGED".

    Raises:
        ValueError: if the two inputs share no ``text_id``.
        SystemExit: on invalid command-line arguments (via argparse).
    """
    from pathlib import Path  # local import: only needed for the output path

    ap = argparse.ArgumentParser()
    ap.add_argument("--bplus", required=True, help="B+ preds.jsonl (has p_flagged)")
    ap.add_argument("--c",      required=True, help="Classifier C preds.jsonl (has p_flagged)")
    ap.add_argument("--alpha", type=float, default=0.70, help="blend weight for C: p_ens = alpha*p_c + (1-alpha)*p_bplus")
    ap.add_argument("--thresh", type=float, default=0.50, help="decision threshold on blended p_ens")
    # "--veto-range" is an alias matching the hyphenated "--no-veto"; dest is unchanged.
    ap.add_argument("--veto_range", "--veto-range", dest="veto_range", type=float, default=0.12,
                    help="if |p_c - p_b| ≤ this, apply veto logic (unless --no-veto)")
    ap.add_argument("--no-veto", action="store_true", help="disable veto behavior; pure blend + threshold")
    ap.add_argument("--out", default="outputs/preds_ensemble.jsonl")
    args = ap.parse_args()
    if not 0.0 <= args.alpha <= 1.0:
        # Outside [0, 1] the blend extrapolates and p_ens can leave [0, 1].
        ap.error(f"--alpha must be in [0, 1], got {args.alpha}")

    bm = _index_by_id(load_preds(args.bplus), "B+")
    cm = _index_by_id(load_preds(args.c), "C")
    ids = sorted(set(bm) & set(cm))
    if not ids:
        raise ValueError("No overlapping text_id between B+ and C.")
    if len(ids) < max(len(bm), len(cm)):
        # Unmatched ids are dropped from the output; report it rather than
        # silently shrinking the evaluated set.
        print(f"[WARN] overlap={len(ids)} rows; B+={len(bm)} C={len(cm)} (unmatched ids dropped)",
              file=sys.stderr)

    use_veto = not args.no_veto
    diffs = 0  # rows whose final label differs from C's own label
    out_rows = []
    for i in ids:
        rb, rc = bm[i], cm[i]
        pb = float(rb["p_flagged"])
        pc = float(rc["p_flagged"])
        c_pred = _c_label(rc, pc, args.thresh)

        # weighted blend
        p_ens = args.alpha * pc + (1.0 - args.alpha) * pb

        if use_veto and abs(pc - pb) <= args.veto_range:
            # Gray zone: the two probabilities are within veto_range of each
            # other, so keep C's discrete prediction and probability.
            pred, p_out = c_pred, pc
        else:
            # Outside the gray zone: threshold the blended probability.
            pred = "FLAGGED" if p_ens >= args.thresh else "NOT FLAGGED"
            p_out = p_ens

        if pred != c_pred:
            diffs += 1

        out_rows.append({
            "text_id": i,
            "dataset": rc.get("dataset", rb.get("dataset", "")),
            "true": rc.get("true", rb.get("true", None)),
            "pred": pred,
            "p_flagged": float(p_out),
            "p_c": pc,
            "p_bplus": pb,
            "p_blend": float(p_ens),
        })

    out_path = Path(args.out)
    # The default "outputs/..." directory may not exist yet.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(r) + "\n" for r in out_rows)

    msg = f"[DONE] wrote {args.out} (kept {len(out_rows)} rows, changed {diffs})"
    stats = _summarize(out_rows)
    if stats is not None:
        n_labeled, acc, rec_flagged = stats
        msg += f"; acc≈{acc:.4f}  rec_FLAGGED≈{rec_flagged:.4f}"
        if n_labeled < len(out_rows):
            msg += f" (over {n_labeled} labeled rows)"
    print(msg)

# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
