"""Stage 1 — primitive-first stratified sample of spans for LLM judging.

Produces sampled_spans.jsonl with one row per span:
    span_id, span_text, preceding_context,
    heuristic_label, heuristic_confidence, n_tokens,
    checkpoint_id, task_name, doc_id, trace_id, episode_idx, correct.

Stratification: by default 25 spans per primitive x 10 primitives = 250
spans. Within each primitive, picks are balanced by a soft round-robin over
(checkpoint x correct x task_name) strata. Empty or exhausted strata are
skipped, so their share is redistributed to the remaining strata.
"""
from __future__ import annotations

import argparse
import json
import os
import random
import re
from collections import defaultdict
from pathlib import Path

import pandas as pd

from analysis.exploration.segmentation import (
    extract_reasoning,
    segment_response,
)
from analysis.exploration.primitive_classification import (
    classify_span,
    classify_trace_spans,
    merge_episodes,
)

from analysis.exploration.llm_validation._client import PRIMITIVES
from analysis.exploration.llm_validation.audit_drift import (
    CHECKPOINT_DIR_MAP,
    find_samples_jsonl,
)


# A sentence boundary: terminal punctuation followed by at least one
# whitespace character.
_SENTENCE_END_RE = re.compile(r"[.!?]\s+")


def extract_preceding_context(reasoning_text: str, start_char: int,
                              max_chars: int = 500) -> str:
    """Return the last 1-2 sentences that precede ``start_char``.

    Only the 400 characters immediately before ``start_char`` are examined.
    If that window contains no sentence boundary, its last 200 characters are
    used instead.

    Args:
        reasoning_text: Text that ``start_char`` indexes into.
        start_char: Offset where the span of interest begins.
        max_chars: Upper bound on the length of the returned context.

    Returns:
        The stripped context, or ``""`` when ``start_char <= 0``.
    """
    if start_char <= 0:
        return ""
    # Trailing whitespace is dropped before looking for boundaries. Otherwise
    # a window that ends right after "...sentence. " would count that final
    # boundary, which costs one sentence of context and yields an empty
    # string when it is the only boundary in the window.
    window = reasoning_text[max(0, start_char - 400):start_char].rstrip()
    ends = [m.end() for m in _SENTENCE_END_RE.finditer(window)]
    if len(ends) >= 2:
        # Start after the second-to-last boundary: two trailing sentences.
        ctx = window[ends[-2]:]
    elif ends:
        ctx = window[ends[-1]:]
    else:
        ctx = window[-200:]
    return ctx.strip()[:max_chars]


def explode_episode_index(df: pd.DataFrame) -> pd.DataFrame:
    """Expand `primitive_sequence` into one row per (trace, episode_idx, label).

    Args:
        df: One row per trace with columns ``checkpoint_id``, ``task_name``,
            ``doc_id``, ``trace_id``, ``correct`` and ``primitive_sequence``
            (a JSON-encoded list of labels).

    Returns:
        A DataFrame with columns ``checkpoint_id``, ``task_name``, ``doc_id``,
        ``trace_id``, ``correct``, ``episode_idx`` and ``label``. The columns
        are present even when there are no episodes, so callers can filter on
        ``label`` without special-casing empty input.
    """
    columns = [
        "checkpoint_id", "task_name", "doc_id", "trace_id",
        "correct", "episode_idx", "label",
    ]
    rows = [
        {
            "checkpoint_id": r.checkpoint_id,
            "task_name": r.task_name,
            "doc_id": int(r.doc_id),
            "trace_id": int(r.trace_id),
            "correct": bool(r.correct),
            "episode_idx": ep_idx,
            "label": label,
        }
        for r in df.itertuples()
        for ep_idx, label in enumerate(json.loads(r.primitive_sequence))
    ]
    # Pin the schema: DataFrame([]) would otherwise have no columns at all.
    return pd.DataFrame(rows, columns=columns)


def stratified_pick(pool: pd.DataFrame, n: int, rng: random.Random) -> list[dict]:
    """Draw up to `n` rows, cycling over (checkpoint, correct, task_name) strata.

    Every stratum is shuffled with `rng`. Strata are then visited in sorted
    key order, one row per visit, until `n` rows have been taken or every
    stratum is exhausted. A stratum with nothing left is passed over, so the
    remaining strata absorb its share.

    Each returned dict is the `_asdict()` form of the row's `itertuples`
    record.
    """
    if pool.empty:
        return []

    # Group rows by stratum, preserving first-seen order for the shuffle pass.
    buckets: dict[tuple, list[dict]] = {}
    for record in pool.itertuples():
        stratum = (record.checkpoint_id, record.correct, record.task_name)
        buckets.setdefault(stratum, []).append(record._asdict())
    for members in buckets.values():
        rng.shuffle(members)

    # Sorted key order makes the visiting sequence deterministic.
    ordered = [buckets[stratum] for stratum in sorted(buckets)]

    chosen: list[dict] = []
    remaining = n
    while remaining > 0:
        live = [bucket for bucket in ordered if bucket]
        if not live:
            break  # nothing left anywhere
        for bucket in live[:remaining]:
            chosen.append(bucket.pop())
        remaining = n - len(chosen)
    return chosen


def collect_responses_by_pair(picks: list[dict], raw_root: Path) -> dict[tuple, str]:
    """Resolve raw response text for every (ckpt, task, doc, trace) we picked.

    Picks are grouped by the samples JSONL file that `find_samples_jsonl`
    returns for their (checkpoint, task), so each file is read once.

    Each JSONL record is read for its `doc_id` (falling back to `idx`, then
    0) and its `resps` field. `resps` may be a list of lists, in which case
    the first inner list is used, or a flat list. A response's position in
    that list is treated as its trace id.

    Args:
        picks: Dicts with `checkpoint_id`, `task_name`, `doc_id`, `trace_id`.
        raw_root: Root directory handed to `find_samples_jsonl`.

    Returns:
        Mapping of (checkpoint_id, task_name, doc_id, trace_id) to response.
        Picks whose response is not found are simply absent.
    """
    by_jsonl: dict[Path, set[tuple[int, int]]] = defaultdict(set)
    ctx_for_jsonl: dict[Path, tuple[str, str]] = {}
    # Resolve each (checkpoint, task) to its file once, not once per pick.
    path_for_pair: dict[tuple[str, str], Path] = {}
    for p in picks:
        pair = (p["checkpoint_id"], p["task_name"])
        if pair not in path_for_pair:
            path_for_pair[pair] = find_samples_jsonl(raw_root, *pair)
        jsonl = path_for_pair[pair]
        by_jsonl[jsonl].add((p["doc_id"], p["trace_id"]))
        ctx_for_jsonl[jsonl] = pair

    out: dict[tuple, str] = {}
    for jsonl_path, wanted in by_jsonl.items():
        ckpt, task = ctx_for_jsonl[jsonl_path]
        wanted_doc_ids = {d for d, _ in wanted}
        # Decode as UTF-8 explicitly instead of the platform default.
        with open(jsonl_path, encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank / trailing lines
                d = json.loads(line)
                doc_id = int(d.get("doc_id", d.get("idx", 0)))
                if doc_id not in wanted_doc_ids:
                    continue
                resps = d.get("resps", [[]])
                responses = resps[0] if resps and isinstance(resps[0], list) else resps
                for tid, resp in enumerate(responses):
                    if (doc_id, tid) in wanted:
                        out[(ckpt, task, doc_id, tid)] = resp
    return out


def materialise_span_for_pick(pick: dict, response_text: str) -> dict | None:
    """Rebuild the picked episode from raw text and emit its output row.

    The response is segmented, classified and merged into episodes again.
    The episode at `pick["episode_idx"]` must carry `pick["label"]`.

    Returns None when segmentation yields no spans, merging yields no
    episodes, the episode index is out of range, or the label at that index
    no longer matches (drift). Otherwise returns the row for the episode's
    first span.
    """
    # NOTE(review): assumes span.start_char indexes into this text — confirm
    # against segment_response.
    reasoning_text = extract_reasoning(response_text) or response_text

    segments = segment_response(response_text)
    if not segments:
        return None

    merged = merge_episodes(classify_trace_spans(segments))
    if not merged:
        return None

    ep_idx = pick["episode_idx"]
    if ep_idx >= len(merged):
        return None

    episode = merged[ep_idx]
    if episode.label != pick["label"]:
        return None

    # Always take the episode's first span so the choice is deterministic.
    head = episode.spans[0]
    _label, confidence = classify_span(head.text)

    ckpt = pick["checkpoint_id"]
    task = pick["task_name"]
    doc = pick["doc_id"]
    trace = pick["trace_id"]
    span_id = f"{ckpt}|{task}|{doc}|{trace}|{head.span_id}"
    context = extract_preceding_context(reasoning_text, head.start_char)

    provenance_keys = (
        "checkpoint_id", "task_name", "doc_id",
        "trace_id", "episode_idx", "correct",
    )
    return {
        "span_id": span_id,
        "span_text": head.text,
        "preceding_context": context,
        "heuristic_label": episode.label,
        "heuristic_confidence": confidence,
        "n_tokens": head.n_tokens,
        **{key: pick[key] for key in provenance_keys},
    }


def sample_for_primitive(
    primitive: str,
    pool: pd.DataFrame,
    n: int,
    raw_root: Path,
    rng: random.Random,
    max_attempts: int = 100,
) -> tuple[list[dict], int]:
    """Pick rows for one primitive, materialise spans, oversample on drift.

    Repeatedly draws a stratified chunk from `pool`, resolves the raw
    responses, and materialises spans until `n` rows are collected, the pool
    runs dry, or `max_attempts` chunks have been tried. Every drawn pick is
    removed from the pool so it is never drawn twice.

    Args:
        primitive: Label being sampled. Not read by this function; `pool` is
            expected to be pre-filtered by the caller.
        pool: Candidate episodes with `checkpoint_id`, `task_name`, `doc_id`,
            `trace_id` and `episode_idx` columns, plus whatever
            `stratified_pick` needs. The caller's frame is not modified.
        n: Number of span rows wanted.
        raw_root: Root directory of the raw results.
        rng: Random source used for the stratified draws.
        max_attempts: Maximum number of chunks to draw.

    Returns:
        (rows, n_drift_skipped). The skip count includes picks with no
        resolvable response as well as picks that failed to materialise.
    """
    attempts = 0
    rows: list[dict] = []
    drift = 0

    while len(rows) < n and attempts < max_attempts:
        need = n - len(rows)
        # Pull a chunk a bit larger than need to absorb drift attrition
        chunk_n = max(need * 2, need + 5)
        picks = stratified_pick(pool, chunk_n, rng)
        if not picks:
            break
        responses = collect_responses_by_pair(picks, raw_root)
        for p in picks:
            if len(rows) >= n:
                break
            key = (p["checkpoint_id"], p["task_name"], p["doc_id"], p["trace_id"])
            resp = responses.get(key)
            if resp is None:
                drift += 1
                continue
            row = materialise_span_for_pick(p, resp)
            if row is None:
                drift += 1
                continue
            rows.append(row)
        if len(rows) >= n:
            break  # quota met; no need to prune the pool again

        # Remove already-used picks from pool to avoid re-drawing them.
        used = {
            (p["checkpoint_id"], p["task_name"], p["doc_id"],
             p["trace_id"], p["episode_idx"])
            for p in picks
        }
        # Zip the key columns instead of a row-wise DataFrame.apply: same
        # membership test, without building a Series per pool row.
        pool_keys = zip(
            pool["checkpoint_id"], pool["task_name"],
            pool["doc_id"], pool["trace_id"], pool["episode_idx"],
        )
        keep = [i for i, k in enumerate(pool_keys) if k not in used]
        pool = pool.iloc[keep]
        attempts += 1
        if pool.empty:
            break

    return rows, drift


def main():
    """CLI entry point: sample spans per primitive and write them as JSONL.

    Loads the trace parquet, expands it into an episode index, samples
    `--n-per-primitive` spans for each requested primitive with one shared
    seeded RNG, writes all rows to `--out` via a temporary file that is then
    renamed into place, and prints per-stratum coverage.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--parquet", required=True, type=Path)
    parser.add_argument("--raw-results-root", required=True, type=Path)
    parser.add_argument("--n-per-primitive", type=int, default=25)
    parser.add_argument("--seed", type=int, default=17)
    parser.add_argument("--out", required=True, type=Path)
    parser.add_argument(
        "--primitives",
        default=",".join(PRIMITIVES),
        help="Comma-separated subset; default is all 10",
    )
    opts = parser.parse_args()

    wanted = [name.strip() for name in opts.primitives.split(",")]
    # Abort on the first requested primitive that is not recognised.
    unknown = next((name for name in wanted if name not in PRIMITIVES), None)
    if unknown is not None:
        raise SystemExit(f"unknown primitive: {unknown}")

    print(f"Loading parquet: {opts.parquet}")
    traces = pd.read_parquet(opts.parquet)

    print(f"Building episode index from {len(traces)} traces...")
    episode_index = explode_episode_index(traces)
    print(f"  total episodes: {len(episode_index)}")

    # A single RNG is shared across primitives, so results depend on order.
    rng = random.Random(opts.seed)
    opts.out.parent.mkdir(parents=True, exist_ok=True)

    sampled: list[dict] = []
    coverage: dict[str, dict[tuple, int]] = {}
    for primitive in wanted:
        candidates = episode_index[episode_index.label == primitive].copy()
        print(f"\n[{primitive}] pool size: {len(candidates)}")
        if candidates.empty:
            print(f"  WARNING: no episodes with label {primitive}")
            continue

        rows, drift = sample_for_primitive(
            primitive,
            candidates,
            opts.n_per_primitive,
            opts.raw_results_root,
            rng,
        )
        print(f"  collected {len(rows)} spans (drift skipped: {drift})")

        tally: dict[tuple, int] = {}
        for row in rows:
            stratum = (row["checkpoint_id"], row["correct"], row["task_name"])
            tally[stratum] = tally.get(stratum, 0) + 1
        coverage[primitive] = tally
        sampled.extend(rows)

    # Write to a sibling temp file, then rename it over the target.
    tmp_path = opts.out.with_suffix(opts.out.suffix + ".tmp")
    with open(tmp_path, "w") as handle:
        handle.writelines(json.dumps(row) + "\n" for row in sampled)
    os.replace(tmp_path, opts.out)
    print(f"\nWrote {len(sampled)} rows -> {opts.out}")

    print("\nStratum coverage (checkpoint, correct, task) -> count:")
    for primitive in wanted:
        strata_counts = coverage.get(primitive)
        if strata_counts is None:
            continue  # primitive had an empty pool and was skipped
        print(f"  {primitive}:")
        for stratum, count in sorted(strata_counts.items()):
            print(f"    {stratum}: {count}")

# Run the sampler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
