"""Stage v2.5: active sampling for rare classes.

Materialize a fresh pool of unlabeled spans, predict with v1 classifier,
take top-K by predicted probability for each rare target class.

Compared to mint_silver_topup.py (heuristic prefilter), v1's predictions
are calibrated for the 9-class taxonomy and have much higher precision on
rare classes — particularly SUMMARIZE, where heuristic prefilter had only
16% V3-SC yield.
"""
from __future__ import annotations

import argparse
import json
import os
import random
import subprocess
import sys
from collections import Counter
from pathlib import Path

import joblib
import numpy as np
import pandas as pd

from analysis.exploration.llm_validation._client import PRIMITIVES
from analysis.exploration.llm_validation.classifier.mint_silver import (
    materialize_spans, n_episodes_per_trace,
)
from analysis.exploration.llm_validation.sample_spans import (
    explode_episode_index,
)


def _iter_jsonl(path):
    """Yield one parsed JSON object per non-blank line of the JSONL file at *path*."""
    with open(path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                yield json.loads(line)


def _select_top_k(proba, label_order, pool_rows, target_classes, per_class):
    """Pick up to *per_class* highest-probability pool spans for each target class.

    Classes are processed in the order given; a span already claimed by an
    earlier class is skipped, so every selected span is attributed to exactly
    one target class.

    Returns:
        Dict mapping span_id -> the target class it was selected for.

    Raises:
        SystemExit: if a target class is absent from *label_order*.
    """
    # list() so that a tuple / numpy array label_order also supports .index().
    labels = list(label_order)
    selected_for: dict[str, str] = {}
    for cls in target_classes:
        if cls not in labels:
            raise SystemExit(
                f"class {cls} not in classifier label_order: {labels}"
            )
        # NOTE(review): assumes the columns of `proba` follow `label_order`
        # (i.e. label_order mirrors the classifier's class ordering) — confirm
        # against the code that writes the bundle.
        scores = proba[:, labels.index(cls)]
        order = np.argsort(-scores)  # descending by P(cls)
        picked = 0
        for i in order:
            # Checked before picking so that per_class <= 0 selects nothing.
            if picked >= per_class:
                break
            sid = pool_rows[int(i)]["span_id"]
            if sid in selected_for:
                continue
            selected_for[sid] = cls
            picked += 1
        top_score = float(scores[order[0]]) if len(order) else 0.0
        print(f"  {cls:<12}: {picked} picked, top-P(class) = {top_score:.3f}")
    return selected_for


def _join_judgments(selected, judgments):
    """Attach judge labels to *selected* spans, dropping unusable judgments.

    Returns:
        ``(final_rows, n_dropped)`` where ``final_rows`` are the span dicts
        extended with the judge fields, and ``n_dropped`` counts spans dropped
        as ``missing`` / ``parse_error`` / ``api_failure``.
    """
    final_rows = []
    n_dropped = {"missing": 0, "parse_error": 0, "api_failure": 0}
    for span in selected:
        j = judgments.get(span["span_id"])
        if j is None:
            n_dropped["missing"] += 1
            continue
        label = j["llm_label"]
        if label == "API_FAILURE":
            n_dropped["api_failure"] += 1
            continue
        # Any label outside the taxonomy (including the explicit PARSE_ERROR
        # sentinel) is counted as a parse error.
        if label == "PARSE_ERROR" or label not in PRIMITIVES:
            n_dropped["parse_error"] += 1
            continue
        final_rows.append({
            **span,
            "llm_label": label,
            "vote_count": j.get("vote_count"),
            "all_labels": j.get("llm_labels"),
            "tokens_used": j.get("tokens_used"),
        })
    return final_rows, n_dropped


def main():
    """CLI entry point: actively sample rare-class spans and label them.

    Steps:
      1. Read the existing silver JSONL files and collect their span_ids and
         episode keys so they are excluded.
      2. Build the episode index from the parquet, drop excluded episodes,
         shuffle with ``--seed`` and materialize spans in batches until the
         pool holds ``--pool-size`` rows. The pool is appended to
         ``<out stem>_pool_spans.jsonl`` and reused on a later run.
      3. Score the pool with the classifier bundle and keep the top
         ``--per-class`` spans for every target class.
      4. Write the selection to ``<out stem>_spans.jsonl``, run
         ``judge_runner`` as a subprocess, join its judgments back onto the
         spans and atomically write the result to ``--out``.

    Raises:
        SystemExit: on an unknown target class, a target class missing from
            the bundle's label order, or an empty candidate pool.
        subprocess.CalledProcessError: if the judge_runner subprocess fails.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--parquet", required=True, type=Path)
    ap.add_argument("--raw-results-root", required=True, type=Path)
    ap.add_argument(
        "--existing-silvers", nargs="+", required=True, type=Path,
        help="One or more existing silver JSONL files; their span_ids are excluded",
    )
    ap.add_argument("--bundle", required=True, type=Path,
                    help="Trained v1 classifier joblib")
    ap.add_argument(
        "--pool-size", type=int, default=8000,
        help="How many fresh random spans to materialize as the candidate pool",
    )
    ap.add_argument(
        "--target-classes", default="SUMMARIZE,HYPOTHESIZE,BACKTRACK",
        help="Comma-separated rare class names",
    )
    ap.add_argument(
        "--per-class", type=int, default=400,
        help="How many top-confidence candidates to sample per target class",
    )
    ap.add_argument("--seed", type=int, default=44)
    ap.add_argument("--out", required=True, type=Path)
    ap.add_argument("--workers", type=int, default=100)
    ap.add_argument("--n-samples", type=int, default=5)
    ap.add_argument("--temperature", type=float, default=0.7)
    args = ap.parse_args()

    # dict.fromkeys de-duplicates while keeping order, so a class listed twice
    # is not sampled twice.
    target_classes = list(dict.fromkeys(
        c.strip() for c in args.target_classes.split(",")
    ))
    for c in target_classes:
        if c not in PRIMITIVES:
            raise SystemExit(f"unknown class: {c}")

    args.out.parent.mkdir(parents=True, exist_ok=True)

    # ---- Load existing silver to dedupe ----
    excluded_episodes: set[tuple] = set()
    excluded_span_ids: set[str] = set()
    for sp in args.existing_silvers:
        for r in _iter_jsonl(sp):
            excluded_span_ids.add(r["span_id"])
            excluded_episodes.add((
                r["checkpoint_id"], r["task_name"],
                int(r["doc_id"]), int(r["trace_id"]),
                int(r["episode_idx"]),
            ))
    print(f"Excluded {len(excluded_span_ids)} existing span_ids "
          f"({len(excluded_episodes)} episode tuples)")

    # ---- Build episode index, exclude already-labeled ----
    print(f"Loading parquet: {args.parquet}")
    df = pd.read_parquet(args.parquet)
    episodes = explode_episode_index(df)
    print(f"  total episodes: {len(episodes)}")
    eps_per_trace = n_episodes_per_trace(episodes)

    keys = list(zip(
        episodes["checkpoint_id"].tolist(),
        episodes["task_name"].tolist(),
        episodes["doc_id"].astype(int).tolist(),
        episodes["trace_id"].astype(int).tolist(),
        episodes["episode_idx"].astype(int).tolist(),
    ))
    episodes = episodes.assign(_key=keys)
    episodes = episodes[~episodes["_key"].isin(excluded_episodes)].drop(columns=["_key"])
    print(f"  after exclusion: {len(episodes)}")

    # ---- Random-sample pool ----
    rng = random.Random(args.seed)
    indices = list(range(len(episodes)))
    rng.shuffle(indices)

    # Materialize spans in chunks (resumable on pool_spans.jsonl)
    pool_path = args.out.with_name(args.out.stem + "_pool_spans.jsonl")
    pool_seen: set[str] = set()
    pool_rows: list[dict] = []
    if pool_path.exists():
        # Single pass over the pool file: remember every span_id so it is not
        # written again, but only keep rows that are still unlabeled (the
        # silver set may have grown since the pool was written).
        for r in _iter_jsonl(pool_path):
            sid = r["span_id"]
            if sid in pool_seen:
                continue
            pool_seen.add(sid)
            if sid not in excluded_span_ids:
                pool_rows.append(r)
        print(f"Resume: {len(pool_seen)} already-materialized in pool")

    # NOTE: on resume the cursor restarts at 0, so episodes already in the pool
    # are materialized again and then discarded via pool_seen.
    cursor = 0
    batch_size = 1000
    print(f"Materializing pool (target {args.pool_size})...")
    with open(pool_path, "a", encoding="utf-8") as pool_file:
        while len(pool_rows) < args.pool_size and cursor < len(indices):
            end = min(cursor + batch_size, len(indices))
            # `indices` already holds shuffled row positions; use each one
            # directly. (Previously this indexed `indices` a second time, which
            # still visited every row once but not in the shuffled order; the
            # fix changes which spans a given seed yields.)
            picks = [episodes.iloc[i].to_dict() for i in indices[cursor:end]]
            cursor = end
            rows = materialize_spans(picks, args.raw_results_root, eps_per_trace)
            for r in rows:
                if r["span_id"] in pool_seen or r["span_id"] in excluded_span_ids:
                    continue
                pool_file.write(json.dumps(r) + "\n")
                pool_rows.append(r)
                pool_seen.add(r["span_id"])
                if len(pool_rows) >= args.pool_size:
                    break
            # One flush per batch keeps the pool file resumable without a
            # syscall per row.
            pool_file.flush()
            print(f"  cursor={cursor}/{len(indices)}  pool={len(pool_rows)}", flush=True)

    if len(pool_rows) < args.pool_size:
        print(f"WARNING: only materialized {len(pool_rows)} of {args.pool_size}")
    if not pool_rows:
        raise SystemExit("no candidate spans materialized; nothing to predict")

    # ---- v1 prediction over pool ----
    print()
    print(f"Predicting with classifier {args.bundle}")
    # SECURITY: joblib.load unpickles arbitrary objects; only pass bundles
    # from a trusted source.
    bundle = joblib.load(args.bundle)
    pipe = bundle["feature_pipeline"]
    clf = bundle["classifier"]
    label_order = bundle["label_order"]

    X = pipe.transform(pool_rows)
    proba = clf.predict_proba(X)
    print(f"  proba shape: {proba.shape}")

    # ---- Select top-K per target class ----
    selected_for = _select_top_k(
        proba, label_order, pool_rows, target_classes, args.per_class,
    )
    selected = [r for r in pool_rows if r["span_id"] in selected_for]
    print(f"\nTotal unique spans to V3-SC label: {len(selected)}")

    # ---- Write to a spans file for judge_runner ----
    spans_path = args.out.with_name(args.out.stem + "_spans.jsonl")
    judge_path = args.out.with_name(args.out.stem + "_judgments.jsonl")
    with open(spans_path, "w", encoding="utf-8") as f:
        for r in selected:
            f.write(json.dumps(r) + "\n")

    # ---- Run V3-SC labeling ----
    print()
    print("Running V3-SC...")
    cmd = [
        sys.executable, "-m", "analysis.exploration.llm_validation.judge_runner",
        "--in", str(spans_path), "--out", str(judge_path),
        "--model", "deepseek-chat",
        "--temperature", str(args.temperature),
        "--workers", str(args.workers),
        "--n-samples", str(args.n_samples),
    ]
    print("  cmd:", " ".join(cmd))
    subprocess.run(cmd, check=True)

    # ---- Join + write final ----
    # If a span_id appears more than once, the last judgment wins.
    judgments = {j["span_id"]: j for j in _iter_jsonl(judge_path)}
    final_rows, n_dropped = _join_judgments(selected, judgments)

    # Write to a temp file and rename so --out is never left half-written.
    tmp = args.out.with_suffix(args.out.suffix + ".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        for r in final_rows:
            f.write(json.dumps(r) + "\n")
    os.replace(tmp, args.out)
    print(f"\nWrote {len(final_rows)} active-sampled labels -> {args.out}")
    print(f"Dropped: {n_dropped}")
    print()
    print("New label distribution:")
    counts = Counter(r["llm_label"] for r in final_rows)
    total = sum(counts.values())
    for p in PRIMITIVES:
        c = counts.get(p, 0)
        # Guard against an empty result set (every judgment dropped).
        pct = c / total * 100 if total else 0.0
        print(f"  {p:13} {c:>5} ({pct:>5.1f}%)")
    print()
    print("Per-target-class precision (V3-SC says target | v1 said target):")
    for cls in target_classes:
        # Only the spans that were selected *for this class* count towards its
        # precision; a span in the top-K of several classes is attributed to
        # the first class that claimed it.
        cls_rows = [r for r in final_rows if selected_for[r["span_id"]] == cls]
        true_pos = sum(1 for r in cls_rows if r["llm_label"] == cls)
        pct = f"{true_pos / len(cls_rows) * 100:.1f}%" if cls_rows else "n/a"
        print(f"  {cls:13}: {true_pos} of {len(cls_rows)} ({pct})")


# Script entry point: run the CLI only when this file is executed directly
# (e.g. via `python -m ...`), not when it is imported as a module.
if __name__ == "__main__":
    main()
