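"""Active sampling: use the production 5-way ensemble (the `--ensemble` joblib
bundles, with their class probabilities averaged) to score a candidate span
pool, then pick spans by class with priority on rare/weak classes.

Output: a JSONL of selected spans (subset of input) ready for V3-SC minting.

Per-class quotas favor classes the ensemble predicts but we have low support
on (e.g. math BACKTRACK, HYPOTHESIZE). For each class listed in `--quotas` we
pick:
  - the top N spans (N from `--quotas CLASS:N`) ranked by the ensemble's
    probability for that class (precision picks)
  - up to `--n-boundary-per-class` spans with the smallest top-1/top-2
    probability margin, among spans where this class is ranked top-1 or
    top-2 (recall picks)
A span is selected at most once, so classes listed earlier in `--quotas` get
first pick.
"""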
from __future__ import annotations

import argparse
import json
import glob
from collections import defaultdict
from pathlib import Path

import joblib
import numpy as np

from analysis.exploration.llm_validation._client import PRIMITIVES
from analysis.exploration.llm_validation.classifier.ensemble import predict_proba


def _load_jsonl(path: Path) -> list[dict]:
    """Return the records of the JSONL file at *path*, skipping blank lines.

    The file is read inside a ``with`` block so the handle is closed even if a
    line fails to parse.
    """
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _parse_quotas(spec: str, known_classes) -> dict[str, int]:
    """Parse a ``CLASS:N,CLASS:N`` spec into an insertion-ordered dict.

    Empty chunks (e.g. from a trailing comma) are ignored. A repeated class
    keeps its first position with the last count given.

    Raises:
        ValueError: on a malformed chunk, a class not in *known_classes*, a
            non-integer count, or a negative count.
    """
    quotas: dict[str, int] = {}
    for chunk in spec.split(","):
        chunk = chunk.strip()
        if not chunk:
            continue
        cls, sep, count = chunk.partition(":")
        cls = cls.strip()
        if not sep or not cls:
            raise ValueError(f"malformed quota {chunk!r}; expected CLASS:N")
        if cls not in known_classes:
            raise ValueError(
                f"unknown class {cls!r}; expected one of {sorted(known_classes)}"
            )
        n_pick = int(count.strip())  # ValueError on a non-integer count
        if n_pick < 0:
            raise ValueError(f"quota for {cls!r} must be >= 0, got {n_pick}")
        quotas[cls] = n_pick
    return quotas


def _ensemble_proba(bundle_paths, rows) -> np.ndarray:
    """Average ``predict_proba`` over every joblib bundle in *bundle_paths*.

    Presumably returns one row per candidate and one column per entry of
    PRIMITIVES (the caller indexes it that way) -- confirm against the
    ensemble module.
    """
    proba_sum = None
    for p in bundle_paths:
        # NOTE: joblib.load unpickles arbitrary objects; only pass trusted paths.
        bundle = joblib.load(p)
        proba = predict_proba(bundle, rows)
        proba_sum = proba if proba_sum is None else proba_sum + proba
        print(f"  + {Path(p).stem}")
    # argparse's nargs="+" guarantees at least one bundle, so no divide-by-zero.
    return proba_sum / len(bundle_paths)


def _pick_in_order(order, rows, selected, limit, eligible=None) -> int:
    """Walk candidate indices in *order*, adding unseen span_ids to *selected*.

    Indices whose ``eligible[idx]`` is false are skipped when *eligible* is
    given. Stops after *limit* new picks; a *limit* <= 0 picks nothing.
    Mutates *selected* in place and returns the number of new picks.
    """
    picked = 0
    if limit <= 0:
        return picked
    for idx in order:
        if eligible is not None and not eligible[idx]:
            continue
        span_id = rows[idx]["span_id"]
        if span_id in selected:
            continue
        selected.add(span_id)
        picked += 1
        if picked >= limit:
            break
    return picked


def main() -> None:
    """Score candidate spans with the ensemble and write the selected subset.

    Phase 1: for each class in ``--quotas`` (in order), take the N
    not-yet-selected candidates with the highest averaged ensemble
    probability for that class.
    Phase 2: for each quota class, take up to ``--n-boundary-per-class`` more
    candidates, choosing those with the smallest top-1 minus top-2 probability
    margin among candidates whose top-1 or top-2 class is that class.
    Each span_id is selected at most once. Selected rows are written to
    ``--out`` in input order. A summary by predicted top-1 class is printed.
    """
    from collections import Counter

    ap = argparse.ArgumentParser()
    ap.add_argument("--candidates", type=Path, required=True,
                    help="JSONL of unlabeled span candidates (must have span_id, span_text, preceding_context).")
    ap.add_argument("--ensemble", nargs="+", required=True,
                    help="Joblib bundle paths to use as the production ensemble.")
    ap.add_argument("--out", type=Path, required=True)
    ap.add_argument(
        "--quotas", type=str, default="BACKTRACK:600,HYPOTHESIZE:600,ENUMERATE:400,SUMMARIZE:300,OTHER:200,PLAN:400,SETUP:300,COMPUTE:300,CHECK:300",
        help="Comma-separated CLASS:N high-conf picks per class.",
    )
    ap.add_argument("--n-boundary-per-class", type=int, default=150,
                    help="Additional entropy-near-boundary picks per class.")
    args = ap.parse_args()

    label_idx = {label: i for i, label in enumerate(PRIMITIVES)}

    # Validate CLI options before the expensive model loading and scoring.
    try:
        quotas = _parse_quotas(args.quotas, label_idx)
    except ValueError as err:
        ap.error(f"--quotas: {err}")
    if args.n_boundary_per_class < 0:
        ap.error("--n-boundary-per-class must be >= 0")

    rows = _load_jsonl(args.candidates)
    print(f"Loaded {len(rows)} candidates")
    if not rows:
        ap.error(f"no candidates found in {args.candidates}")
    print(f"Quotas: {quotas}")

    print(f"Predicting with {len(args.ensemble)} models...")
    proba = _ensemble_proba(args.ensemble, rows)
    print(f"Proba shape: {proba.shape}")

    selected: set = set()

    # Phase 1: highest per-class probability first.
    for cls, n_pick in quotas.items():
        order = np.argsort(-proba[:, label_idx[cls]])
        picked = _pick_in_order(order, rows, selected, n_pick)
        print(f"  high-conf {cls}: picked {picked}")

    # Phase 2: smallest top-1/top-2 margin first, restricted to rows where the
    # class is ranked top-1 or top-2.
    n_rows = len(rows)
    sorted_idx = np.argsort(-proba, axis=1)
    row_ids = np.arange(n_rows)
    margins = proba[row_ids, sorted_idx[:, 0]] - proba[row_ids, sorted_idx[:, 1]]
    by_margin = np.argsort(margins)  # class-independent, so computed once
    top2 = sorted_idx[:, :2]
    for cls in quotas:
        in_top2 = (top2 == label_idx[cls]).any(axis=1)
        picked = _pick_in_order(
            by_margin, rows, selected, args.n_boundary_per_class, eligible=in_top2
        )
        print(f"  boundary  {cls}: picked {picked}")

    # Output picked rows, preserving input order.
    selected_idx = [i for i, row in enumerate(rows) if row["span_id"] in selected]
    with open(args.out, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(rows[i]) + "\n" for i in selected_idx)
    print(f"\nWrote {len(selected_idx)} selected spans -> {args.out}")

    # Distribution of picked spans by predicted top-1.
    pred_top = proba.argmax(axis=1)
    by_top = Counter(PRIMITIVES[pred_top[i]] for i in selected_idx)
    print("\nPicked spans by top-1 ensemble prediction:")
    for cls, count in by_top.most_common():
        print(f"  {cls:<14} {count}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
