"""Math silver scale-up sampler.

Picks ~3000 NEW math spans from previously-untouched (source, benchmark)
pairs covering aime25, olymp_math_easy, and gsm8k, plus a few new checkpoints
on aime24 and olymp_math_hard. Uses tag-aware extraction of the reasoning text
(handles <reasoning>, <think>, and bare-text responses). Deduplicates by
span_id against existing math silver (see --exclude-ids).

Mirrors `sample_math_diverse.py` but with a fresh source list and dedup
against the prior 2.4k silver labels.
"""
from __future__ import annotations

import argparse
import json
import os
import random
import re
from collections import Counter
from pathlib import Path

from analysis.exploration.segmentation import segment_trace, extract_reasoning


# (source_label, benchmark_label, file_path)
# source_label becomes the row's checkpoint_id and benchmark_label its task_name;
# both also form the first two '|'-separated fields of each span_id in main().
# Entries whose file_path does not exist are skipped (with a message) at runtime.
NEW_SOURCES: list[tuple[str, str, str]] = [
    # === aime25 (new benchmark — completely untouched in previous silver) ===
    (
        "qwen3_dsr_v3", "aime25",
        "<PROJECT_DIR>/results/qwen3_thinking_mode_avg8/dsr_v3_sft/aime25_r1_think_avg8/checkpoints__qwen3_8b_multi_puzzle_dsr_v3__merged_step_462/samples_aime25_r1_think_avg8_2026-04-01T17-01-49.319489.jsonl",
    ),
    (
        "qwen3_gspo_v1_step35", "aime25",
        "<PROJECT_DIR>/results/qwen3_thinking_mode_avg8/gspo_v1_s35/aime25_r1_think_avg8/checkpoints__qwen3-puzzle-grpo__multi_puzzle_gspo_qwen3_v1__merged_step_35/samples_aime25_r1_think_avg8_2026-04-01T17-13-42.868662.jsonl",
    ),
    (
        "qwen3_8b_baseline", "aime25",
        "<PROJECT_DIR>/results/qwen3_8b_baseline/aime25_think_avg8/Qwen__Qwen3-8B/samples_aime25_r1_think_avg8_2026-04-01T21-15-39.151618.jsonl",
    ),
    (
        "olmo3_gspo_v2_step20", "aime25",
        "<PROJECT_DIR>/results/gspo_v2_sft_s20_math_eval_v2/checkpoints__olmo3-puzzle-grpo__multi_puzzle_gspo_olmo3_v2_sft_v2__merged_step_20/samples_aime25_r1_pass64_2026-04-19T17-12-55.026735.jsonl",
    ),

    # === olymp_math_easy (new benchmark) ===
    (
        "olmo3_gspo_v2_step20", "olymp_math_easy",
        "<PROJECT_DIR>/results/gspo_v2_sft_s20_math_eval_v2/checkpoints__olmo3-puzzle-grpo__multi_puzzle_gspo_olmo3_v2_sft_v2__merged_step_20/samples_olymp_math_easy_avg8_2026-04-19T10-56-42.758673.jsonl",
    ),
    (
        "novelty_mini_step20", "olymp_math_easy",
        "<PROJECT_DIR>/results/novelty_mini_step20/olymp_math_easy_pass32/checkpoints__olmo3-puzzle-grpo__novelty_mini_gspo_topk100_a01_n4__merged_step_20/samples_olymp_math_easy_pass32_2026-04-23T20-19-21.191409.jsonl",
    ),

    # === gsm8k (new benchmark, simpler problems — good distribution diversity) ===
    (
        "olmo3_exploration_fix_step10", "gsm8k",
        "<PROJECT_DIR>/results/olmo3_exploration_fix/step10/gsm8k_r1_pass8/samples_gsm8k_r1_pass8_2026-04-08T18-58-32.041443.jsonl",
    ),

    # === olymp_math_hard from new checkpoint ===
    (
        "qwen3_gspo_v1_step35", "olymp_math_hard",
        "<PROJECT_DIR>/results/qwen3_thinking_mode_avg8/gspo_v1_s35/olymp_math_hard_avg8/checkpoints__qwen3-puzzle-grpo__multi_puzzle_gspo_qwen3_v1__merged_step_35/samples_olymp_math_hard_avg8_2026-04-01T18-44-19.014614.jsonl",
    ),

    # === olmo3 sources downloaded from S3 (different checkpoints than already sampled) ===
    (
        "olmo3_dsr_v2_sft_baseline", "aime25",
        "<PROJECT_DIR>/results/_silver_scaleup_s3/olmo3_dsr_v2_sft_baseline__aime25.jsonl",
    ),
    (
        "olmo3_dsr_v2_sft_baseline", "olymp_math_hard",
        "<PROJECT_DIR>/results/_silver_scaleup_s3/olmo3_dsr_v2_sft_baseline__olymp_math_hard.jsonl",
    ),
    (
        "olmo3_sft_dsr", "olymp_math_easy",
        "<PROJECT_DIR>/results/_silver_scaleup_s3/olmo3_sft_dsr__olymp_math_easy.jsonl",
    ),
    (
        "olmo3_sft_dsr", "olymp_math_hard",
        "<PROJECT_DIR>/results/_silver_scaleup_s3/olmo3_sft_dsr__olymp_math_hard.jsonl",
    ),
    (
        "olmo3_explfix_step20", "aime24",
        "<PROJECT_DIR>/results/_silver_scaleup_s3/olmo3_explfix_step20__aime24.jsonl",
    ),
    (
        "olmo3_sft_dynamics_ep8", "olymp_math_easy",
        "<PROJECT_DIR>/results/_silver_scaleup_s3/olmo3_sft_dynamics_ep8__olymp_math_easy.jsonl",
    ),
]


# Non-greedy capture of the body of the first closed <think>...</think> block;
# DOTALL lets the body span multiple lines.
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
# A sentence terminator (., ! or ?) followed by at least one whitespace char.
# Used to locate sentence boundaries when building a span's preceding context;
# note a terminator at the very end of a string (no trailing whitespace) does
# not match.
_SENTENCE_END_RE = re.compile(r"[.!?]\s+")


def extract_reasoning_tag_aware(response: str) -> str:
    """Return the reasoning portion of a model response.

    Strategies are tried in order:

    1. ``extract_reasoning`` from the segmentation module; its result is used
       whenever it is not ``None`` (per the module docstring it presumably
       covers ``<reasoning>``-tagged output — confirm against that helper).
    2. The body of the first closed ``<think>...</think>`` block, stripped.
    3. An unclosed ``<think>`` (e.g. a generation cut off at the token limit):
       everything after the opening tag, stripped. Previously the raw
       response, including the literal ``<think>`` tag, was returned.
    4. A ``</think>`` with no opening tag (the opening tag may have been part
       of the prompt template — assumption, confirm per model): everything
       before the closing tag, stripped, so the final answer is not mixed
       into the reasoning.
    5. Otherwise the response unchanged (bare-text traces).

    Args:
        response: One raw model response string.

    Returns:
        The extracted reasoning text.
    """
    rsg = extract_reasoning(response)
    if rsg is not None:
        return rsg
    m = _THINK_RE.search(response)
    if m:
        return m.group(1).strip()
    open_tag, close_tag = "<think>", "</think>"
    open_idx = response.find(open_tag)
    if open_idx != -1:
        # Truncated thinking block: keep what follows the opening tag.
        return response[open_idx + len(open_tag):].strip()
    close_idx = response.find(close_tag)
    if close_idx != -1:
        # Opening tag absent from the completion: reasoning is the prefix.
        return response[:close_idx].strip()
    return response


def extract_preceding_context(
    reasoning_text: str,
    start_char: int,
    max_chars: int = 500,
    window_chars: int = 400,
) -> str:
    """Return a short snippet of the reasoning text just before a span.

    Looks at the ``window_chars`` characters preceding ``start_char`` and keeps
    the tail beginning at the second-to-last sentence boundary (or at the last
    one when only one exists). A sentence boundary is ``.``, ``!`` or ``?``
    followed by whitespace. If the window has no boundary, or the selected
    tail is blank, the last 200 characters of the window are used instead.

    The blank-tail fallback fixes a case where the context silently came back
    empty. This happens when the window's only boundary sits exactly at
    ``start_char``, for example when the span begins a new sentence after a
    sentence longer than the window.

    Args:
        reasoning_text: Full reasoning text the span was segmented from.
        start_char: Offset of the span's first character in ``reasoning_text``.
        max_chars: Cap on the length of the returned (stripped) context. It
            only has an effect when smaller than ``window_chars``.
        window_chars: How far back to look; defaults to the previously
            hard-coded 400 characters.

    Returns:
        The stripped context, or ``""`` when ``start_char <= 0``.
    """
    if start_char <= 0:
        return ""
    window = reasoning_text[max(0, start_char - window_chars):start_char]
    ends = [m.end() for m in _SENTENCE_END_RE.finditer(window)]
    ctx = ""
    if len(ends) >= 2:
        ctx = window[ends[-2]:]
    elif ends:
        ctx = window[ends[-1]:]
    if not ctx.strip():
        # No boundary, or the only boundary is at the end of the window.
        ctx = window[-200:]
    return ctx.strip()[:max_chars]


def sample_spans_from_trace(response: str, n_spans_target: int, rng: random.Random) -> list[dict]:
    """Segment one response's reasoning and draw a random subset of its spans.

    Args:
        response: Raw model response; its reasoning is pulled out with
            ``extract_reasoning_tag_aware`` before segmentation.
        n_spans_target: Maximum number of spans to draw from this trace.
        rng: Random source used to choose span indices.

    Returns:
        One dict per chosen span, in trace order, with keys ``span_text``,
        ``preceding_context``, ``n_tokens``, ``span_idx`` and
        ``n_spans_in_trace``. Empty when segmentation yields no spans.
    """
    text = extract_reasoning_tag_aware(response)
    segments = segment_trace(text, tokenizer_name=None)
    if not segments:
        return []
    total = len(segments)
    # Draw without replacement, then restore trace order.
    picked = sorted(rng.sample(range(total), min(n_spans_target, total)))
    return [
        {
            "span_text": seg.text,
            "preceding_context": extract_preceding_context(text, seg.start_char),
            "n_tokens": seg.n_tokens,
            "span_idx": seg.span_id,
            "n_spans_in_trace": total,
        }
        for seg in (segments[k] for k in picked)
    ]


def _load_existing_ids(path: Path | None) -> set[str]:
    """Load the ``span_id`` of every record in an existing silver JSONL file.

    Args:
        path: JSONL file whose records carry a ``span_id`` field, or None.

    Returns:
        The set of span_ids. Empty when ``path`` is None or does not exist.
        A missing file prints a warning instead of being ignored silently, so
        a mistyped ``--exclude-ids`` path does not quietly disable dedup.
    """
    if path is None:
        return set()
    if not path.exists():
        print(f"WARNING: --exclude-ids file not found, not deduplicating: {path}")
        return set()
    with open(path, encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing empty line) rather than crash in json.loads.
        ids = {json.loads(line)["span_id"] for line in f if line.strip()}
    print(f"Excluding {len(ids)} existing span_ids")
    return ids


def _first_response(record: dict) -> str | None:
    """Return the first response of a sample record, or None if it has none.

    ``record["resps"]`` is accepted either nested (``[[r0, r1, ...]]``) or
    flat (``[r0, r1, ...]``); a missing ``resps`` is treated as empty.
    """
    resps = record.get("resps", [[]])
    responses = resps[0] if resps and isinstance(resps[0], list) else resps
    if not responses:
        return None
    return responses[0]


def main() -> None:
    """Sample spans from every source in NEW_SOURCES and write them as JSONL.

    For each (source, benchmark, path) whose file exists, draw up to
    ``--n-traces-per-source`` records. Take the first response of each and
    sample up to ``--n-spans-per-trace`` spans from it. Spans are dropped if
    their span_id is in ``--exclude-ids`` or was already emitted in this run.
    Output is written atomically via a ``.tmp`` sibling plus ``os.replace``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out", required=True, type=Path)
    ap.add_argument("--exclude-ids", type=Path, default=None,
                    help="JSONL of existing silver to deduplicate against (uses span_id field).")
    ap.add_argument("--n-traces-per-source", type=int, default=25)
    ap.add_argument("--n-spans-per-trace", type=int, default=15)
    ap.add_argument("--seed", type=int, default=44)
    args = ap.parse_args()

    rng = random.Random(args.seed)
    args.out.parent.mkdir(parents=True, exist_ok=True)

    existing_ids = _load_existing_ids(args.exclude_ids)
    # One set covers both prior silver and spans already kept in this run, so
    # duplicates are dropped inline and each source's "kept" count is exact.
    seen: set[str] = set(existing_ids)

    rows: list[dict] = []
    for source, benchmark, path in NEW_SOURCES:
        if not Path(path).exists():
            print(f"  SKIP missing: {source} | {benchmark} | {path}")
            continue
        with open(path, encoding="utf-8") as f:
            # Drop blank lines so json.loads never sees an empty string. When the
            # file has none, this equals f.readlines(), so seeded sampling is unchanged.
            lines = [line for line in f if line.strip()]
        # NOTE: rng is shared across sources, so what a source samples depends on
        # which earlier sources were present (missing ones consume no draws).
        sampled_lines = rng.sample(lines, min(args.n_traces_per_source, len(lines)))
        kept = 0
        for line in sampled_lines:
            d = json.loads(line)
            doc_id = d.get("doc_id", d.get("idx", 0))
            response = _first_response(d)
            if response is None:
                continue
            for sp in sample_spans_from_trace(response, args.n_spans_per_trace, rng):
                span_id = f"{source}|{benchmark}|{doc_id}|0|{sp['span_idx']}"
                if span_id in seen:
                    continue
                seen.add(span_id)
                sp.update({
                    "span_id": span_id,
                    "checkpoint_id": source,
                    "task_name": benchmark,
                    "doc_id": int(doc_id),
                    # Only the first response of each record is used, hence trace 0.
                    "trace_id": 0,
                    "episode_idx": sp["span_idx"],
                    "n_episodes_in_trace": sp["n_spans_in_trace"],
                    # NOTE(review): constant placeholder; correctness is not read
                    # from the sample record here.
                    "correct": False,
                })
                rows.append(sp)
                kept += 1
        print(f"  {source:>26} | {benchmark:<18} | kept {kept} new spans")

    tmp = args.out.with_suffix(args.out.suffix + ".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r) + "\n")
    # Atomic swap: readers never observe a half-written output file.
    os.replace(tmp, args.out)

    print(f"\nWrote {len(rows)} new math spans -> {args.out}")
    print()
    print("By source:")
    for k, v in Counter((r["checkpoint_id"], r["task_name"]) for r in rows).most_common():
        print(f"  {k}: {v}")


if __name__ == "__main__":
    main()
