#!/usr/bin/env python3
import argparse, json, numpy as np

def load_preds(p):
    """Load a JSONL predictions file into a dict keyed by ``text_id``.

    Each non-blank line must be a JSON object containing a ``text_id`` key.
    Blank / whitespace-only lines (e.g. a trailing empty line) are skipped.
    If a ``text_id`` appears more than once, the last row wins.

    Args:
        p: path to the JSONL file; read as UTF-8.

    Returns:
        dict mapping each row's ``text_id`` to the full row dict.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a row has no ``text_id`` key.
    """
    # `with` guarantees the handle is closed even if parsing raises.
    with open(p, "r", encoding="utf-8") as fh:
        rows = [json.loads(line) for line in fh if line.strip()]
    return {r["text_id"]: r for r in rows}

def _average_ranks(x):
    """Return 0-based ascending ranks of 1-D array *x*, averaging over ties.

    Equal values get the same rank (the mean of the positions they would
    occupy), so identical probabilities are scored identically regardless
    of their position in the input. Example: [0.2, 0.5, 0.5, 0.9] ->
    [0.0, 1.5, 1.5, 3.0].
    """
    _, inverse, counts = np.unique(x, return_inverse=True, return_counts=True)
    # First 0-based rank occupied by each distinct value (np.unique sorts).
    starts = np.cumsum(counts) - counts
    return (starts + (counts - 1) / 2.0)[inverse]


def _blend_scores(p_b, p_c, alpha, rank_weight=0.85):
    """Combine B+ and C probabilities into one ensemble score per row.

    The score is ``rank_weight * rank_blend + (1 - rank_weight) * prob_blend``.
    Both blends weight C by *alpha* and B+ by ``1 - alpha``. Ranks are
    tie-averaged and normalized to [0, 1] when there is more than one row.
    """
    n = len(p_b)
    r_b = _average_ranks(p_b)
    r_c = _average_ranks(p_c)
    if n > 1:
        r_b /= n - 1
        r_c /= n - 1
    s_rank = (1.0 - alpha) * r_b + alpha * r_c
    # Small probability component keeps scores of tied ranks distinguishable
    # by their actual magnitudes.
    s_prob = (1.0 - alpha) * p_b + alpha * p_c
    return rank_weight * s_rank + (1.0 - rank_weight) * s_prob


def main():
    """Blend B+ and C flag predictions into a single ensemble JSONL file.

    Reads two JSONL prediction files (``--bplus`` and ``--c``) keyed by
    ``text_id`` and keeps only ids present in both. It then writes one JSON
    row per id, sorted by ``text_id``, to ``--out``.

    Decision rule:
      * Ensemble score ``s`` = 0.85 * rank blend + 0.15 * probability blend,
        with C weighted by ``--alpha`` in both blends.
      * Ensemble label is FLAGGED when ``s >= --thresh``.
      * Veto: when ``|pC - 0.5| >= --margin``, C's own hard class
        (``pC >= 0.5``) replaces the ensemble label.

    Raises:
        SystemExit: if ``--alpha`` is outside [0, 1] (via argparse), or if
            the two files share no ``text_id``.
    """
    from pathlib import Path

    ap = argparse.ArgumentParser()
    ap.add_argument("--bplus", required=True)
    ap.add_argument("--c", required=True)
    ap.add_argument("--alpha", type=float, default=0.80,
                    help="weight on classifier C in both the rank and prob blends (0..1)")
    ap.add_argument("--thresh", type=float, default=0.60,
                    help="decision threshold on the blended ensemble score (mostly rank-based, 0..1)")
    ap.add_argument("--margin", type=float, default=0.18, help="|pC-0.5| above which we trust C outright")
    ap.add_argument("--out", default="outputs/preds_ensemble.jsonl")
    args = ap.parse_args()
    if not 0.0 <= args.alpha <= 1.0:
        ap.error("--alpha must be in [0, 1]")

    bm = load_preds(args.bplus)
    cm = load_preds(args.c)
    ids = sorted(set(bm) & set(cm))
    if not ids:
        raise SystemExit("No overlapping text_id keys between B+ and C.")

    pB = np.array([float(bm[i]["p_flagged"]) for i in ids])
    pC = np.array([float(cm[i]["p_flagged"]) for i in ids])

    s = _blend_scores(pB, pC, args.alpha)

    # C counts as "confident" when its probability is far enough from 0.5.
    confC = np.abs(pC - 0.5) >= args.margin

    def to_label(flagged):
        return "FLAGGED" if flagged else "NOT FLAGGED"

    out_path = Path(args.out)
    # Create the directory that actually holds --out, not a fixed "outputs/".
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as f:
        for idx, k in enumerate(ids):
            pred_ens = to_label(s[idx] >= args.thresh)
            # Confidence is measured around 0.5, so C's own hard class is
            # pC >= 0.5. Using --thresh here would let a confidently positive
            # C be overridden to NOT FLAGGED whenever thresh > 0.5 + margin.
            pred_c = to_label(pC[idx] >= 0.5)
            final_pred = pred_c if confC[idx] else pred_ens

            f.write(json.dumps({
                "text_id": k,
                "dataset": bm[k].get("dataset", cm[k].get("dataset", "")),
                "true": bm[k].get("true", cm[k].get("true")),
                "pred": final_pred,
                "p_flagged": float(s[idx]),         # ensemble score (for reference)
                "p_flagged_c": float(pC[idx]),
                "p_flagged_b": float(pB[idx]),
                "veto_c": bool(confC[idx])
            }) + "\n")

    print(f"[DONE] wrote {args.out} (kept {len(ids)} overlapped rows)")

if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing this module
    # (e.g. to reuse load_preds) has no side effects.
    main()
