"""Stage 1: mint silver labels by V3-SC on randomly sampled spans.

Random-shuffles the parquet's (trace, episode) pool, materializes the first
N spans (deterministic-first-span-of-episode like sample_spans.py), then
invokes judge_runner.py via subprocess (--model deepseek-chat; --n-samples and
--temperature default to 5 and 0.7) for V3-SC labels. Joins the LLM labels
back onto the span records and writes the result to --out (e.g.
silver_train.jsonl).

Reuses materialization machinery from sample_spans.py; differs only in the
sampling step (random vs primitive-stratified).
"""
from __future__ import annotations

import argparse
import json
import os
import random
import subprocess
import sys
from collections import Counter
from pathlib import Path

import pandas as pd

from analysis.exploration.llm_validation._client import PRIMITIVES
from analysis.exploration.llm_validation.sample_spans import (
    collect_responses_by_pair,
    explode_episode_index,
    materialise_span_for_pick,
)


def n_episodes_per_trace(episodes: pd.DataFrame) -> dict[tuple, int]:
    """Count how many episode rows each trace has.

    The returned dict is keyed by the tuple
    (checkpoint_id, task_name, doc_id, trace_id) and maps to the number of
    rows of `episodes` that share that key.
    """
    trace_key = ["checkpoint_id", "task_name", "doc_id", "trace_id"]
    counts = episodes.groupby(trace_key).size()
    return counts.to_dict()


def materialize_spans(
    picks: list[dict], raw_root: Path, eps_per_trace: dict[tuple, int],
) -> list[dict]:
    """Turn episode picks into span rows.

    Looks up the raw responses for all `picks` under `raw_root` in one call,
    then materializes one span per pick. A pick is skipped when no response
    is found for its (checkpoint_id, task_name, doc_id, trace_id) key or when
    the span helper returns None. Each surviving row gets an
    `n_episodes_in_trace` field taken from `eps_per_trace` (1 if the key is
    absent).

    Same flow as sample_for_primitive in sample_spans.py, minus the
    primitive-filter retry loop.
    """
    by_trace = collect_responses_by_pair(picks, raw_root)
    rows: list[dict] = []
    for pick in picks:
        trace_key = (
            pick["checkpoint_id"],
            pick["task_name"],
            pick["doc_id"],
            pick["trace_id"],
        )
        response = by_trace.get(trace_key)
        span = None if response is None else materialise_span_for_pick(pick, response)
        if span is not None:
            span["n_episodes_in_trace"] = eps_per_trace.get(trace_key, 1)
            rows.append(span)
    return rows


def _read_jsonl(path: Path) -> list[dict]:
    """Parse a JSONL file into a list of dicts, skipping blank lines."""
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]


def _print_label_distribution(rows: list[dict]) -> None:
    """Print the count and percentage of each primitive label in `rows`."""
    counts = Counter(r["llm_label"] for r in rows)
    total = sum(counts.values())
    for p in PRIMITIVES:
        c = counts.get(p, 0)
        # With zero surviving labels there is no meaningful percentage;
        # report 0.0 instead of dividing by zero.
        pct = c / total * 100 if total else 0.0
        print(f"  {p:13} {c:>5} ({pct:>5.1f}%)")


def main() -> None:
    """CLI entry point: sample spans, label them via judge_runner, write --out.

    Steps:
      1. Load the parquet and build the episode index.
      2. Shuffle episode positions with --seed and materialize spans in
         batches until --n spans exist. Each span is appended to
         `<out stem>_spans.jsonl`; if that file already exists its rows are
         reloaded first so an interrupted run can resume.
      3. Run judge_runner as a subprocess over the spans file, writing
         `<out stem>_judgments.jsonl`.
      4. Join judgments onto spans by span_id, dropping spans whose judgment
         is missing, PARSE_ERROR, API_FAILURE or not in PRIMITIVES, and write
         the result to --out via a temp file + os.replace.

    Raises:
        subprocess.CalledProcessError: if judge_runner exits non-zero.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--parquet", required=True, type=Path)
    ap.add_argument("--raw-results-root", required=True, type=Path)
    ap.add_argument("--n", type=int, default=10_000)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--out", required=True, type=Path)
    ap.add_argument(
        "--workers", type=int, default=100,
        help="Pass-through to judge_runner",
    )
    ap.add_argument(
        "--n-samples", type=int, default=5,
        help="Self-consistency samples for V3-SC",
    )
    ap.add_argument(
        "--temperature", type=float, default=0.7,
        help="Sampling temperature for V3-SC",
    )
    ap.add_argument(
        "--keep-intermediate", action="store_true",
        help="Don't delete the materialized-spans + judgments intermediate files",
    )
    args = ap.parse_args()

    args.out.parent.mkdir(parents=True, exist_ok=True)
    spans_path = args.out.with_name(args.out.stem + "_spans.jsonl")
    judge_path = args.out.with_name(args.out.stem + "_judgments.jsonl")

    print(f"Loading parquet: {args.parquet}")
    df = pd.read_parquet(args.parquet)
    print(f"  traces: {len(df)}")

    print("Building episode index...")
    episodes = explode_episode_index(df)
    print(f"  episodes: {len(episodes)}")
    eps_per_trace = n_episodes_per_trace(episodes)

    rng = random.Random(args.seed)
    indices = list(range(len(episodes)))
    rng.shuffle(indices)

    # ---- Materialize spans in batches with oversample for drift attrition ----
    if spans_path.exists():
        spans_so_far = _read_jsonl(spans_path)
        print(f"Resume: found {len(spans_so_far)} already-materialized spans at {spans_path}")
    else:
        spans_so_far = []

    cursor = 0
    BATCH = 1000
    # span_ids already on disk; used to avoid appending duplicates on resume.
    seen_ids = {r["span_id"] for r in spans_so_far}

    print(f"Materializing up to {args.n} spans...")
    with open(spans_path, "a") as spans_file:
        while len(spans_so_far) < args.n and cursor < len(indices):
            end = min(cursor + BATCH, len(indices))
            # Every element of `indices` is already a shuffled row position,
            # so it indexes `episodes` directly (no second lookup through
            # `indices`, which would apply the permutation twice).
            picks = [episodes.iloc[i].to_dict() for i in indices[cursor:end]]
            cursor = end
            rows = materialize_spans(picks, args.raw_results_root, eps_per_trace)
            for r in rows:
                if r["span_id"] in seen_ids:
                    continue
                spans_file.write(json.dumps(r) + "\n")
                # Flush per row so a crash loses at most the row in flight.
                spans_file.flush()
                spans_so_far.append(r)
                seen_ids.add(r["span_id"])
                if len(spans_so_far) >= args.n:
                    break
            print(f"  cursor={cursor}/{len(indices)}  spans={len(spans_so_far)}", flush=True)

    if len(spans_so_far) < args.n:
        print(f"WARNING: only materialized {len(spans_so_far)} spans (target {args.n}).")

    # ---- Run V3-SC labeling via judge_runner ----
    print()
    print("Running V3-SC via judge_runner.py...")
    cmd = [
        sys.executable, "-m", "analysis.exploration.llm_validation.judge_runner",
        "--in", str(spans_path), "--out", str(judge_path),
        "--model", "deepseek-chat",
        "--temperature", str(args.temperature),
        "--workers", str(args.workers),
        "--n-samples", str(args.n_samples),
    ]
    print("  cmd:", " ".join(cmd))
    subprocess.run(cmd, check=True)

    # ---- Join span records with judgments ----
    print()
    print("Joining spans + judgments...")
    # If a span_id appears more than once, the last judgment in the file wins.
    judgments: dict[str, dict] = {
        j["span_id"]: j for j in _read_jsonl(judge_path)
    }

    final_rows: list[dict] = []
    n_dropped = {"missing": 0, "parse_error": 0, "api_failure": 0}
    for span in spans_so_far:
        j = judgments.get(span["span_id"])
        if j is None:
            n_dropped["missing"] += 1
            continue
        label = j["llm_label"]
        if label == "API_FAILURE":
            n_dropped["api_failure"] += 1
            continue
        # Explicit PARSE_ERROR and any other label outside PRIMITIVES are
        # both counted as parse errors.
        if label == "PARSE_ERROR" or label not in PRIMITIVES:
            n_dropped["parse_error"] += 1
            continue
        final_rows.append({
            **span,
            "llm_label": label,
            "vote_count": j.get("vote_count"),
            "all_labels": j.get("llm_labels"),
            "tokens_used": j.get("tokens_used"),
        })

    # Write to a sibling temp file first so --out is replaced in one step.
    tmp = args.out.with_suffix(args.out.suffix + ".tmp")
    with open(tmp, "w") as f:
        for r in final_rows:
            f.write(json.dumps(r) + "\n")
    os.replace(tmp, args.out)

    print(f"Wrote {len(final_rows)} silver labels -> {args.out}")
    print(f"Dropped: {n_dropped}")
    print()
    print("Label distribution:")
    _print_label_distribution(final_rows)

    # NOTE(review): nothing in this script deletes the intermediate files, so
    # --keep-intermediate only suppresses this message. Confirm whether
    # deletion was intended before changing it.
    if not args.keep_intermediate:
        print()
        print("(intermediate files retained for resumption: "
              f"{spans_path}, {judge_path})")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
