"""Stage 1b: heuristic-prefiltered top-up of rare-class silver labels.

Augments silver_train.jsonl with extra labels concentrated on classes that
were starved in the initial random 10k mint (SUMMARIZE/100, OTHER/234,
PLAN/392, HYPOTHESIZE/426, BACKTRACK/478). For each target class, samples
candidates from the parquet's heuristic primitive_sequence in the
appropriate heuristic class(es), materializes spans, and runs V3-SC.

Yield is imperfect (the heuristic both over- and under-fires), but most
spans sampled this way still produce useful labels: even if V3-SC
reclassifies a heuristic-SUMMARIZE candidate as COMPUTE, that is still a
fresh COMPUTE label that augments the training set.
"""
from __future__ import annotations

import argparse
import json
import os
import random
import subprocess
import sys
from collections import Counter
from pathlib import Path

import pandas as pd

from analysis.exploration.llm_validation._client import (
    PRIMITIVES, to_new_taxonomy,
)
from analysis.exploration.llm_validation.classifier.mint_silver import (
    materialize_spans, n_episodes_per_trace,
)
from analysis.exploration.llm_validation.sample_spans import (
    explode_episode_index,
)


# Rare 9-class target -> list of heuristic (10-class) labels whose episodes
# form that target's candidate pool. Iteration order matters: candidates are
# drawn pool by pool in this order from one seeded RNG.
#
# Rationale for the pools: the heuristic tends to over-detect
# SUMMARIZE/BACKTRACK and under-detect HYPOTHESIZE (caught only via
# "Suppose"/"Assume" patterns). PLAN is rarely heuristic-PLAN; most plans
# are heuristic-OTHER or DECOMPOSE, hence the extra DECOMPOSE pool.
TARGET_HEURISTIC_POOLS = dict(
    SUMMARIZE=["SUMMARIZE"],
    OTHER=["OTHER"],
    BACKTRACK=["BACKTRACK"],
    HYPOTHESIZE=["HYPOTHESIZE"],
    PLAN=["PLAN", "DECOMPOSE"],
)


def _iter_jsonl(path: Path):
    """Yield one parsed JSON object per non-blank line of the file at *path*."""
    with open(path) as f:
        for line in f:
            if line.strip():  # tolerate blank / trailing empty lines
                yield json.loads(line)


def _load_exclusions(silver_path: Path) -> tuple[set[str], set[tuple]]:
    """Read the initial silver file once and collect what must be excluded.

    Returns:
        ``(span_ids, episode_keys)``: the ``span_id`` of every existing row,
        and each row's ``(checkpoint_id, task_name, doc_id, trace_id,
        episode_idx)`` key with the three numeric fields coerced to ``int``.
    """
    span_ids: set[str] = set()
    episode_keys: set[tuple] = set()
    for r in _iter_jsonl(silver_path):
        span_ids.add(r["span_id"])
        episode_keys.add((
            r["checkpoint_id"], r["task_name"],
            int(r["doc_id"]), int(r["trace_id"]),
            int(r["episode_idx"]),
        ))
    return span_ids, episode_keys


def _drop_excluded(episodes: pd.DataFrame, excluded: set[tuple]) -> pd.DataFrame:
    """Return the rows of *episodes* whose episode key is not in *excluded*.

    Episode-row span_id format is <ckpt>|<task>|<doc>|<trace>|<span_idx>, and
    the span index is unknown before materialization. Exclusion is therefore
    done on the deterministic (ckpt, task, doc, trace, episode_idx) key.
    Keys are built from column lists and checked with a set lookup per row,
    which is much faster than ``df.apply(axis=1)`` on large tables.
    """
    keys = zip(
        episodes["checkpoint_id"].tolist(),
        episodes["task_name"].tolist(),
        episodes["doc_id"].astype(int).tolist(),
        episodes["trace_id"].astype(int).tolist(),
        episodes["episode_idx"].astype(int).tolist(),
    )
    # An explicit boolean Series (not a bare list) keeps an empty frame's
    # columns intact: df[[]] would be read as an empty *column* selection.
    keep = pd.Series(
        [k not in excluded for k in keys], index=episodes.index, dtype=bool,
    )
    return episodes[keep]


def _sample_candidates(
    episodes: pd.DataFrame, rng: random.Random, per_class: int,
) -> list[dict]:
    """Draw up to *per_class* episodes per rare target from its heuristic pool.

    Pools are visited in TARGET_HEURISTIC_POOLS order and each is shuffled
    with the shared *rng*. For the same input table and seed, the picks are
    therefore identical.
    """
    picks: list[dict] = []
    for target_class, heuristic_labels in TARGET_HEURISTIC_POOLS.items():
        pool = episodes[episodes["label"].isin(heuristic_labels)]
        avail = len(pool)
        n_take = min(per_class, avail)
        # Deterministic shuffle of positional indices (RNG call sequence is
        # part of the reproducibility contract; do not change it).
        order = list(range(avail))
        rng.shuffle(order)
        for r in pool.iloc[order[:n_take]].itertuples():
            picks.append({
                "checkpoint_id": r.checkpoint_id,
                "task_name": r.task_name,
                "doc_id": int(r.doc_id),
                "trace_id": int(r.trace_id),
                "correct": bool(r.correct),
                "episode_idx": int(r.episode_idx),
                "label": r.label,
            })
        print(f"  {target_class:<12}: pool={avail:>6}, sampled={n_take}")
    return picks


def _materialize_resumable(
    picks: list[dict],
    raw_results_root: Path,
    eps_per_trace,
    spans_path: Path,
    existing_ids: set[str],
) -> list[dict]:
    """Materialize *picks* into spans, appending new ones to *spans_path*.

    Rows already present in *spans_path* (from an earlier run) are reused.
    Newly materialized spans whose span_id was already seen, or which
    belongs to the existing silver set, are skipped.

    Returns:
        All spans for this top-up: resumed rows followed by newly written ones.
    """
    if spans_path.exists():
        materialized = list(_iter_jsonl(spans_path))
        print(f"Resume: {len(materialized)} already materialized at {spans_path}")
    else:
        materialized = []
    seen = {r["span_id"] for r in materialized}
    n_resumed = len(materialized)

    print("Materializing...")
    rows = materialize_spans(picks, raw_results_root, eps_per_trace)
    n_skipped = 0
    with open(spans_path, "a") as spans_file:
        for r in rows:
            sid = r["span_id"]
            if sid in seen or sid in existing_ids:
                n_skipped += 1
                continue
            spans_file.write(json.dumps(r) + "\n")
            materialized.append(r)
            seen.add(sid)
    print(
        f"  materialized={len(materialized) - n_resumed} new, "
        f"{n_resumed} resumed (drift/dup skipped: {n_skipped})"
    )
    return materialized


def _run_judge(
    spans_path: Path, judge_path: Path, *,
    temperature: float, workers: int, n_samples: int,
) -> None:
    """Run judge_runner (V3-SC, deepseek-chat) as a subprocess.

    Raises:
        subprocess.CalledProcessError: if the judge exits non-zero.
    """
    cmd = [
        sys.executable, "-m", "analysis.exploration.llm_validation.judge_runner",
        "--in", str(spans_path), "--out", str(judge_path),
        "--model", "deepseek-chat",
        "--temperature", str(temperature),
        "--workers", str(workers),
        "--n-samples", str(n_samples),
    ]
    print("  cmd:", " ".join(cmd))
    subprocess.run(cmd, check=True)


def _join_judgments(
    materialized: list[dict], judge_path: Path,
) -> tuple[list[dict], dict[str, int]]:
    """Attach judge labels to spans and drop spans without a usable label.

    Returns:
        ``(final_rows, n_dropped)``. ``n_dropped`` counts spans with no
        judgment, ``PARSE_ERROR``, ``API_FAILURE``, or a label outside
        PRIMITIVES (kept separate from parse errors so a taxonomy mismatch
        is visible).
    """
    # Later lines for the same span_id overwrite earlier ones.
    judgments = {j["span_id"]: j for j in _iter_jsonl(judge_path)}

    final_rows: list[dict] = []
    n_dropped = {"missing": 0, "parse_error": 0, "api_failure": 0, "unknown_label": 0}
    for span in materialized:
        j = judgments.get(span["span_id"])
        if j is None:
            n_dropped["missing"] += 1
            continue
        label = j["llm_label"]
        if label == "PARSE_ERROR":
            n_dropped["parse_error"] += 1
        elif label == "API_FAILURE":
            n_dropped["api_failure"] += 1
        elif label not in PRIMITIVES:
            n_dropped["unknown_label"] += 1
        else:
            final_rows.append({
                **span,
                "llm_label": label,
                "vote_count": j.get("vote_count"),
                "all_labels": j.get("llm_labels"),
                "tokens_used": j.get("tokens_used"),
            })
    return final_rows, n_dropped


def _write_jsonl_atomic(rows: list[dict], out: Path) -> None:
    """Write *rows* as JSONL to a sibling ``.tmp`` file, then rename it onto *out*."""
    tmp = out.with_suffix(out.suffix + ".tmp")
    with open(tmp, "w") as f:
        for r in rows:
            f.write(json.dumps(r) + "\n")
    os.replace(tmp, out)


def _print_distribution(final_rows: list[dict]) -> None:
    """Print the per-primitive count and percentage of the new labels."""
    counts = Counter(r["llm_label"] for r in final_rows)
    total = sum(counts.values())
    print("New label distribution:")
    for p in PRIMITIVES:
        c = counts.get(p, 0)
        # Guard against ZeroDivisionError when every span was dropped.
        pct = c / total * 100 if total else 0.0
        print(f"  {p:13} {c:>5} ({pct:>5.1f}%)")


def main():
    """CLI entry point: top up rare-class silver labels via V3-SC.

    Pipeline:
      1. Read ``--existing-silver`` once to collect span_ids and episode keys
         to exclude.
      2. Load ``--parquet``, explode it to one row per episode, and drop
         excluded episodes.
      3. For each entry of TARGET_HEURISTIC_POOLS, take up to ``--per-class``
         candidates from a seeded shuffle of the matching heuristic pool.
      4. Materialize spans into ``<out stem>_spans.jsonl``. The file is
         appended to, and rows already in it are reused on re-run.
      5. Run judge_runner on the spans file, writing
         ``<out stem>_judgments.jsonl``.
      6. Join spans with judgments, drop unusable labels, and write
         ``--out`` atomically (top-up rows only; not merged with existing).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--parquet", required=True, type=Path)
    ap.add_argument("--raw-results-root", required=True, type=Path)
    ap.add_argument(
        "--existing-silver", required=True, type=Path,
        help="silver_train.jsonl from initial mint (its span_ids are excluded)",
    )
    ap.add_argument(
        "--per-class", type=int, default=400,
        help="Target number of candidates per rare class (before V3-SC filtering)",
    )
    ap.add_argument("--seed", type=int, default=43)
    ap.add_argument("--out", required=True, type=Path,
                    help="Output topup-only labels file (NOT merged with existing)")
    ap.add_argument("--workers", type=int, default=100)
    ap.add_argument("--n-samples", type=int, default=5)
    ap.add_argument("--temperature", type=float, default=0.7)
    args = ap.parse_args()
    if args.per_class < 0:
        # A negative slice bound would silently mean "all but the last N".
        ap.error("--per-class must be >= 0")

    # ---- Load existing silver (single pass) to deduplicate ----
    print(f"Loading existing silver: {args.existing_silver}")
    existing_ids, excluded_episodes = _load_exclusions(args.existing_silver)
    print(f"  {len(existing_ids)} existing span_ids to exclude")
    print(f"  excluded {len(excluded_episodes)} (ckpt,task,doc,trace,ep) tuples")

    # ---- Load parquet, build episode index ----
    print(f"Loading parquet: {args.parquet}")
    df = pd.read_parquet(args.parquet)
    episodes = explode_episode_index(df)
    print(f"  total episodes: {len(episodes)}")
    # Computed on the full episode table, before exclusion (as originally).
    # NOTE(review): presumably materialize_spans needs unfiltered per-trace
    # counts; keep this ahead of _drop_excluded.
    eps_per_trace = n_episodes_per_trace(episodes)

    print("Filtering excluded episodes...")
    episodes = _drop_excluded(episodes, excluded_episodes)
    print(f"  episodes after exclusion: {len(episodes)}")

    # ---- Sample candidates per target class ----
    rng = random.Random(args.seed)
    args.out.parent.mkdir(parents=True, exist_ok=True)
    spans_path = args.out.with_name(args.out.stem + "_spans.jsonl")
    judge_path = args.out.with_name(args.out.stem + "_judgments.jsonl")

    all_picks = _sample_candidates(episodes, rng, args.per_class)
    print(f"\nTotal candidates: {len(all_picks)}")

    # ---- Materialize spans (resumable on spans_path) ----
    materialized = _materialize_resumable(
        all_picks, args.raw_results_root, eps_per_trace, spans_path, existing_ids,
    )

    # ---- Run V3-SC labeling via judge_runner ----
    print()
    print("Running V3-SC...")
    _run_judge(
        spans_path, judge_path,
        temperature=args.temperature,
        workers=args.workers,
        n_samples=args.n_samples,
    )

    # ---- Join + write ----
    final_rows, n_dropped = _join_judgments(materialized, judge_path)
    _write_jsonl_atomic(final_rows, args.out)

    print()
    print(f"Top-up: {len(final_rows)} new labels -> {args.out}")
    print(f"Dropped: {n_dropped}")
    print()
    _print_distribution(final_rows)


# Run the top-up pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
