"""Pilot sampler for math spans across DIVERSE model checkpoints.

Mirrors the puzzle pilot approach: stratify across (source, benchmark),
segment each trace with the existing puzzle segmenter, and write a
sampled_spans-style JSONL. Designed for prompt iteration: ~240 spans is
enough to compute V3-SC vs R1-SC agreement and inspect disagreements.

Key differences from the OOD math sampler we used earlier:
1. **Diverse sources** — mix of OLMo-3 (SFT, DSR, GSPO) + Qwen3-8B (base, DSR, GSPO).
2. **Tag-aware extraction** — handles <reasoning>, <think>, and bare-text traces.
"""
from __future__ import annotations

import argparse
import json
import os
import random
import re
from collections import Counter
from pathlib import Path

from analysis.exploration.segmentation import (
    Span, segment_trace, extract_reasoning,
)


# (source_label, benchmark_label, file_path)
# One entry per sampling cell. main() iterates these in order, skips any entry
# whose file_path does not exist, and stamps each sampled row with
# source_label (as "checkpoint_id") and benchmark_label (as "task_name").
# NOTE(review): every path contains a literal "<PROJECT_DIR>" placeholder;
# unless it is substituted before running, Path(...).exists() is False and the
# cell is skipped — confirm how this placeholder gets resolved.
DIVERSE_SOURCES = [
    # OLMo-3 family ------------------------------------------------------
    (
        "olmo3_sft", "aime24",
        "<PROJECT_DIR>/results/pass_at_k_s3/olmo3_base/aime24_r1_pass64/allenai__OLMo-3-7B-Instruct-SFT/samples_aime24_r1_pass64_2026-04-09T17-05-11.508418.jsonl",
    ),
    (
        "olmo3_dsr", "aime24",
        "<PROJECT_DIR>/results/pass_at_k_s3/olmo3_dsr/aime24_r1_pass64/checkpoints__olmo3_7b_multi_puzzle_dsr__merged_best/samples_aime24_r1_pass64_2026-04-09T16-43-15.630497.jsonl",
    ),
    (
        "olmo3_gspo_v2_step20", "aime24",
        "<PROJECT_DIR>/results/gspo_v2_sft_s20_math_eval_v2/checkpoints__olmo3-puzzle-grpo__multi_puzzle_gspo_olmo3_v2_sft_v2__merged_step_20/samples_aime24_r1_pass64_2026-04-19T10-14-11.780093.jsonl",
    ),
    (
        "olmo3_gspo_v2_step20", "olymp_math_hard",
        "<PROJECT_DIR>/results/gspo_v2_sft_s20_math_eval_v2/checkpoints__olmo3-puzzle-grpo__multi_puzzle_gspo_olmo3_v2_sft_v2__merged_step_20/samples_olymp_math_hard_pass32_2026-04-19T15-35-20.004553.jsonl",
    ),
    # Qwen3 family -------------------------------------------------------
    (
        "qwen3_8b_base", "aime24",
        "<PROJECT_DIR>/results/qwen3_8b_baseline/aime24_pass8/Qwen__Qwen3-8B/samples_aime24_r1_pass8_2026-03-31T08-29-51.903279.jsonl",
    ),
    (
        "qwen3_dsr_v3", "aime24",
        "<PROJECT_DIR>/results/qwen3_thinking_mode_avg8/dsr_v3_sft/aime24_r1_think_avg8/checkpoints__qwen3_8b_multi_puzzle_dsr_v3__merged_step_462/samples_aime24_r1_think_avg8_2026-04-01T16-43-28.032540.jsonl",
    ),
    (
        "qwen3_gspo_v1_step35", "aime24",
        "<PROJECT_DIR>/results/qwen3_thinking_mode_avg8/gspo_v1_s35/aime24_r1_think_avg8/checkpoints__qwen3-puzzle-grpo__multi_puzzle_gspo_qwen3_v1__merged_step_35/samples_aime24_r1_think_avg8_2026-04-01T16-55-54.828082.jsonl",
    ),
]


# Captures the text between a <think> tag and the nearest following </think>.
# Non-greedy so it stops at the first closing tag; DOTALL so the captured
# reasoning may span newlines.
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
# Heuristic sentence boundary: '.', '!' or '?' followed by one or more
# whitespace characters. It also fires on abbreviations such as "e.g. ", and
# it never matches a terminator at the very end of a string (no whitespace
# follows it there).
_SENTENCE_END_RE = re.compile(r"[.!?]\s+")


def extract_reasoning_tag_aware(response: str) -> str:
    """Return the reasoning portion of a model response, whatever its tag format.

    Strategies, tried in order:

    1. ``extract_reasoning`` (the segmenter's ``<reasoning>`` extractor); its
       result is returned whenever it is not ``None``.
    2. A complete ``<think>...</think>`` pair: the enclosed text, stripped.
    3. A ``</think>`` with no ``<think>`` before it. This happens when the
       opening tag was supplied outside the generation (NOTE(review): presumed
       to be a chat template pre-filling ``<think>`` in thinking mode —
       confirm). Returns the text before the closing tag, stripped, so the
       final answer after ``</think>`` is not mistaken for reasoning.
    4. A ``<think>`` that is never closed, e.g. a generation cut off at the
       token limit. Returns the text after the opening tag, stripped.
    5. Otherwise, the full response unchanged (bare-text traces).

    Args:
        response: One raw generated response.

    Returns:
        The extracted reasoning text.
    """
    rsg = extract_reasoning(response)
    if rsg is not None:
        return rsg
    m = _THINK_RE.search(response)
    if m:
        return m.group(1).strip()
    # No complete pair was found, so any "<think>" present occurs after every
    # "</think>": the text before the first closing tag is pure reasoning.
    before_close, close_tag, _ = response.partition("</think>")
    if close_tag:
        return before_close.strip()
    # Opening tag but no closing tag at all: truncated reasoning.
    _, open_tag, after_open = response.partition("<think>")
    if open_tag:
        return after_open.strip()
    return response


def extract_preceding_context(
    reasoning_text: str,
    start_char: int,
    max_chars: int = 500,
    window_chars: int = 400,
    fallback_chars: int = 200,
) -> str:
    """Return roughly the sentence or two of ``reasoning_text`` that end at ``start_char``.

    Looks back at most ``window_chars`` characters from ``start_char`` and cuts
    at boundaries matched by ``_SENTENCE_END_RE``. When the look-back window
    reaches the start of the trace (and contains at least one real boundary),
    the trace start also counts as a boundary, because it is a genuine
    sentence start.

    Which text is returned:

    * Two or more boundaries: text from the second-to-last boundary onward,
      i.e. one complete sentence plus any partial sentence directly before
      the span.
    * One boundary: text from that boundary onward.
    * No boundary: the last ``fallback_chars`` characters of the window.
    * If the chosen slice is blank, the last ``fallback_chars`` characters are
      used instead of returning an empty context. This happens when the span
      starts exactly at the only boundary of a window that does not reach the
      trace start.

    Args:
        reasoning_text: Text that ``start_char`` indexes into.
        start_char: Offset of the span's first character in ``reasoning_text``.
        max_chars: Maximum length of the result; when exceeded, the characters
            closest to the span are kept.
        window_chars: How many characters before ``start_char`` to examine.
        fallback_chars: Tail length used when no usable boundary is found.

    Returns:
        The stripped context string; ``""`` when ``start_char <= 0``.
    """
    if start_char <= 0:
        return ""
    win_start = max(0, start_char - window_chars)
    window = reasoning_text[win_start:start_char]
    cuts = [m.end() for m in _SENTENCE_END_RE.finditer(window)]
    if win_start == 0 and cuts:
        # The window begins at the trace start, so the text before the first
        # boundary is a complete sentence rather than a truncated fragment.
        cuts.insert(0, 0)
    if len(cuts) >= 2:
        ctx = window[cuts[-2]:]
    elif cuts:
        ctx = window[cuts[-1]:]
    else:
        ctx = window[-fallback_chars:]
    ctx = ctx.strip()
    if not ctx:
        # The span starts right after the window's only boundary; use the raw
        # tail rather than dropping all context.
        ctx = window[-fallback_chars:].strip()
    # Keep the end of the context: it is the text adjacent to the span.
    return ctx[max(0, len(ctx) - max_chars):]


def sample_spans_from_trace(
    response: str, n_spans_target: int, rng: random.Random,
) -> list[dict]:
    """Segment one response and return a random subset of its spans as dicts.

    The reasoning is pulled out with ``extract_reasoning_tag_aware`` and split
    by ``segment_trace`` (no tokenizer). Up to ``n_spans_target`` distinct
    spans are drawn with ``rng`` and returned in ascending span-list order.

    Args:
        response: One raw generated response.
        n_spans_target: Maximum number of spans to return.
        rng: Random source used for the draw.

    Returns:
        One dict per chosen span with keys ``span_text``,
        ``preceding_context``, ``n_tokens`` (as reported by ``segment_trace``),
        ``span_idx`` (the span's ``span_id``) and ``n_spans_in_trace``.
        Returns ``[]`` if segmentation yields no spans.
    """
    reasoning = extract_reasoning_tag_aware(response)
    segments = segment_trace(reasoning, tokenizer_name=None)
    if not segments:
        return []
    total = len(segments)
    picks = rng.sample(range(total), min(n_spans_target, total))
    picks.sort()
    chosen = [segments[k] for k in picks]
    return [
        {
            "span_text": seg.text,
            "preceding_context": extract_preceding_context(reasoning, seg.start_char),
            "n_tokens": seg.n_tokens,
            "span_idx": seg.span_id,
            "n_spans_in_trace": total,
        }
        for seg in chosen
    ]


def _read_jsonl_lines(path) -> list[str]:
    """Return the non-blank lines of a JSONL file, decoded as UTF-8."""
    # Explicit UTF-8: model outputs can contain non-ASCII text, and the
    # platform default encoding is locale-dependent.
    with open(path, encoding="utf-8") as f:
        return [line for line in f if line.strip()]


def _first_response(record: dict) -> str | None:
    """Return the first generated response in a sample record, or None.

    ``record["resps"]`` may be a list whose first element is itself a list of
    responses (``[[r0, r1, ...]]``), or a flat list of responses. Only the
    first response is used.
    """
    resps = record.get("resps", [[]])
    responses = resps[0] if resps and isinstance(resps[0], list) else resps
    if not responses:
        return None
    return responses[0]


def _dedupe_by_span_id(rows: list[dict]) -> list[dict]:
    """Drop rows whose ``span_id`` was already seen, keeping the first occurrence."""
    seen = set()
    deduped = []
    for r in rows:
        if r["span_id"] in seen:
            continue
        seen.add(r["span_id"])
        deduped.append(r)
    return deduped


def main():
    """CLI entry point: sample spans from every source cell and write a JSONL.

    For each (source, benchmark, path) cell in ``DIVERSE_SOURCES`` whose file
    exists, draw up to ``--n-traces-per-source`` records, segment the first
    response of each, and keep up to ``--n-spans-per-trace`` random spans per
    trace. Rows are de-duplicated by ``span_id`` and written atomically
    (temporary file + ``os.replace``) to ``--out``. A per-cell summary is
    printed.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out", required=True, type=Path)
    ap.add_argument("--n-traces-per-source", type=int, default=12,
                    help="Sample N traces per (source, benchmark) cell")
    ap.add_argument("--n-spans-per-trace", type=int, default=4,
                    help="Up to N spans per sampled trace")
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()
    if args.n_traces_per_source < 0 or args.n_spans_per_trace < 0:
        ap.error("--n-traces-per-source and --n-spans-per-trace must be >= 0")

    # One RNG drives both trace and span sampling, so the output for a given
    # seed depends on the order of DIVERSE_SOURCES and on which files exist.
    rng = random.Random(args.seed)
    args.out.parent.mkdir(parents=True, exist_ok=True)

    rows = []
    for source, benchmark, path in DIVERSE_SOURCES:
        if not Path(path).exists():
            print(f"  SKIP missing: {source} | {benchmark} | {path}")
            continue
        lines = _read_jsonl_lines(path)
        sampled_lines = rng.sample(lines, min(args.n_traces_per_source, len(lines)))
        kept = 0
        for line in sampled_lines:
            d = json.loads(line)
            doc_id = d.get("doc_id", d.get("idx", 0))
            response = _first_response(d)
            if response is None:
                continue
            spans = sample_spans_from_trace(response, args.n_spans_per_trace, rng)
            for sp in spans:
                sp.update({
                    "span_id": f"{source}|{benchmark}|{doc_id}|0|{sp['span_idx']}",
                    "checkpoint_id": source,
                    "task_name": benchmark,
                    "doc_id": int(doc_id),
                    # Only the first response of each record is used, hence 0.
                    "trace_id": 0,
                    "episode_idx": sp["span_idx"],
                    "n_episodes_in_trace": sp["n_spans_in_trace"],
                    # NOTE(review): placeholder — correctness is not read from
                    # the sample record; confirm downstream ignores this field.
                    "correct": False,
                })
                rows.append(sp)
                kept += 1
        print(f"  {source:>26} | {benchmark:<18} | kept {kept} spans")

    n_before = len(rows)
    rows = _dedupe_by_span_id(rows)
    if len(rows) < n_before:
        print(f"  dropped {n_before - len(rows)} duplicate span_id(s)")

    # Write to a sibling temp file, then atomically replace the target.
    tmp = args.out.with_name(args.out.name + ".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r) + "\n")
    os.replace(tmp, args.out)

    print(f"\nWrote {len(rows)} math spans -> {args.out}")
    print()
    print("By source:")
    for k, v in Counter((r["checkpoint_id"], r["task_name"]) for r in rows).most_common():
        print(f"  {k}: {v}")


if __name__ == "__main__":
    main()
