"""Stage 6a: sample spans from math benchmarks for OOD evaluation.

Reads the representative samples_*.jsonl files listed in DEFAULT_FILES (four
benchmarks for the GSPO checkpoint, two of them also for the SFT baseline),
segments each picked trace with the existing puzzle segmenter, and writes
~500 spans in the same schema as silver_train.jsonl (sans `llm_label`).
"""
from __future__ import annotations

import argparse
import json
import os
import random
import re
from pathlib import Path

from analysis.exploration.segmentation import (
    extract_reasoning,
    segment_response,
)


# Representative files, keyed by (benchmark, checkpoint_id) -> samples JSONL path.
# Coverage is not a full 4 × 2 grid: the GSPO checkpoint has all four
# benchmarks, while the SFT baseline has only aime25 and olymp_hard (six files).
# All these were verified present at planning time; main() skips any path that
# is missing at run time.
DEFAULT_FILES = {
    # benchmark         checkpoint_id        path
    ("aime24",       "gspo_v2_step20"): "<PROJECT_DIR>/results/gspo_v2_sft_s20_math_eval_v2/checkpoints__olmo3-puzzle-grpo__multi_puzzle_gspo_olmo3_v2_sft_v2__merged_step_20/samples_aime24_r1_pass64_2026-04-19T10-14-11.780093.jsonl",
    ("aime25",       "gspo_v2_step20"): "<PROJECT_DIR>/results/gspo_v2_sft_s20_math_eval_v2/checkpoints__olmo3-puzzle-grpo__multi_puzzle_gspo_olmo3_v2_sft_v2__merged_step_20/samples_aime25_r1_pass64_2026-04-19T17-12-55.026735.jsonl",
    ("olymp_easy",   "gspo_v2_step20"): "<PROJECT_DIR>/results/gspo_v2_sft_s20_math_eval_v2/checkpoints__olmo3-puzzle-grpo__multi_puzzle_gspo_olmo3_v2_sft_v2__merged_step_20/samples_olymp_math_easy_avg8_2026-04-19T10-56-42.758673.jsonl",
    ("olymp_hard",   "gspo_v2_step20"): "<PROJECT_DIR>/results/gspo_v2_sft_s20_math_eval_v2/checkpoints__olmo3-puzzle-grpo__multi_puzzle_gspo_olmo3_v2_sft_v2__merged_step_20/samples_olymp_math_hard_pass32_2026-04-19T15-35-20.004553.jsonl",
    ("aime25",       "sft_baseline"):    "<PROJECT_DIR>/results/sft_baseline_math_eval_diverse/checkpoints__olmo3_7b_multi_puzzle_dsr_v2__merged_ep5_fp32/samples_aime25_r1_pass64_2026-04-20T20-30-40.103331.jsonl",
    ("olymp_hard",   "sft_baseline"):    "<PROJECT_DIR>/results/sft_baseline_math_eval_diverse/checkpoints__olmo3_7b_multi_puzzle_dsr_v2__merged_ep5_fp32/samples_olymp_math_hard_pass32_2026-04-20T21-22-45.641899.jsonl",
}


_SENTENCE_END_RE = re.compile(r"[.!?]\s+")


def extract_preceding_context(reasoning_text: str, start_char: int, max_chars: int = 500) -> str:
    """Same logic as sample_spans.extract_preceding_context."""
    if start_char <= 0:
        return ""
    window = reasoning_text[max(0, start_char - 400):start_char]
    ends = [m.end() for m in _SENTENCE_END_RE.finditer(window)]
    if len(ends) >= 2:
        ctx = window[ends[-2]:]
    elif ends:
        ctx = window[ends[-1]:]
    else:
        ctx = window[-200:]
    return ctx.strip()[:max_chars]


def sample_spans_from_trace(
    response: str, n_spans_target: int, rng: random.Random,
) -> list[dict]:
    """Segment one math trace and randomly keep at most ``n_spans_target`` spans.

    Args:
        response: Raw model response for a single trace.
        n_spans_target: Maximum number of spans to return.
        rng: Random generator used to choose which spans are kept.

    Returns:
        One dict per kept span, in ascending segment order, with keys
        ``span_text``, ``preceding_context``, ``n_tokens``, ``span_idx`` and
        ``n_spans_in_trace``. Empty list when segmentation yields nothing.
    """
    # Fall back to the raw response when no reasoning section is extracted.
    context_source = extract_reasoning(response) or response
    segments = segment_response(response)
    if not segments:
        return []

    total = len(segments)
    picked = rng.sample(range(total), min(n_spans_target, total))
    picked.sort()

    # NOTE(review): start_char is applied to the extracted reasoning text while
    # segmentation receives the full response — confirm the offsets line up.
    return [
        {
            "span_text": seg.text,
            "preceding_context": extract_preceding_context(context_source, seg.start_char),
            "n_tokens": seg.n_tokens,
            "span_idx": seg.span_id,
            "n_spans_in_trace": total,
        }
        for seg in (segments[pos] for pos in picked)
    ]


def main(argv: list[str] | None = None) -> None:
    """Sample spans from the math-benchmark result files and write them as JSONL.

    For every ``DEFAULT_FILES`` entry whose path exists, up to
    ``--n-traces-per-file`` records are sampled. The first response of each
    record is segmented, and up to ``--n-spans-per-trace`` spans are kept.
    Each span row is tagged with checkpoint/benchmark/doc identifiers. All
    rows are written to ``--out`` through a temporary file that is then
    renamed into place, followed by a per-(benchmark, checkpoint) summary.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which preserves the original
            zero-argument behaviour.
    """
    from collections import Counter

    ap = argparse.ArgumentParser()
    ap.add_argument("--out", required=True, type=Path)
    ap.add_argument(
        "--n-traces-per-file", type=int, default=8,
        help="Number of doc_ids to sample per JSONL file",
    )
    ap.add_argument(
        "--n-spans-per-trace", type=int, default=10,
        help="Max spans per sampled trace",
    )
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args(argv)

    rng = random.Random(args.seed)
    args.out.parent.mkdir(parents=True, exist_ok=True)

    rows: list[dict] = []
    for (benchmark, ckpt), path in DEFAULT_FILES.items():
        if not Path(path).exists():
            print(f"  SKIP missing: {path}")
            continue
        # Each non-blank line is one doc. Read as explicit UTF-8 so decoding
        # does not depend on the locale, and drop blank lines (e.g. a trailing
        # empty line) so they can never be sampled and break json.loads.
        with open(path, encoding="utf-8") as f:
            lines = [line for line in f if line.strip()]
        sampled = rng.sample(lines, min(args.n_traces_per_file, len(lines)))

        n_kept = 0  # spans kept for this file only; avoids rescanning `rows`
        for line in sampled:
            d = json.loads(line)
            doc_id = d.get("doc_id", d.get("idx", 0))
            resps = d.get("resps", [[]])
            # `resps` may be a list of lists or a flat list of responses.
            responses = resps[0] if resps and isinstance(resps[0], list) else resps
            if not responses:
                continue
            # Pick the first response (resps[0][0]) — could randomize later.
            response = responses[0]
            spans = sample_spans_from_trace(response, args.n_spans_per_trace, rng)
            for sp in spans:
                sp.update({
                    "span_id": f"{ckpt}|{benchmark}|{doc_id}|0|{sp['span_idx']}",
                    "checkpoint_id": ckpt,
                    "task_name": benchmark,
                    "doc_id": int(doc_id),
                    "trace_id": 0,
                    "episode_idx": sp["span_idx"],
                    "n_episodes_in_trace": sp["n_spans_in_trace"],
                    "correct": False,  # not used for OOD; just placeholder
                })
                rows.append(sp)
            n_kept += len(spans)
        print(f"  {benchmark:>11} / {ckpt:<14}: kept {n_kept} spans")

    # Write to a sibling temp file first, then rename, so a partially written
    # output never appears under the final name.
    tmp = args.out.with_suffix(args.out.suffix + ".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r) + "\n")
    os.replace(tmp, args.out)
    print(f"\nWrote {len(rows)} math spans -> {args.out}")

    print("\nBy benchmark:")
    per_source = Counter((r["task_name"], r["checkpoint_id"]) for r in rows)
    for bm, c in sorted(per_source.items()):
        print(f"  {bm}: {c}")


if __name__ == "__main__":
    main()
