"""Aggressive v2-math silver scale-up.

Targets v2 checkpoint families that are missing or under-represented in the
existing math silver. Designed to produce ~6-8k new v2 math spans, which
are then minted with V3-SC v5 and added to silver.
"""
from __future__ import annotations

import argparse
import json
import os
import random
import re
from collections import Counter
from pathlib import Path

from analysis.exploration.segmentation import segment_trace, extract_reasoning


# Directory holding the scale-up (s3) results JSONL files.
_SILVER_S3_DIR = "<PROJECT_DIR>/results/_silver_scaleup_s3"

# (source_label, benchmark_label) pairs to sample from. Each pair's file is
# "<_SILVER_S3_DIR>/<source_label>__<benchmark_label>.jsonl".
_V2_MATH_PAIRS = (
    # NEW: full v2 puzzle SFT on olymp easy/hard
    ("olmo3_dsr_v2_sft_full", "olymp_math_easy"),
    ("olmo3_dsr_v2_sft_full", "olymp_math_hard"),
    # NEW: olmo3 v2 curriculum step15
    ("olmo3_v2_curriculum_s15", "olymp_math_easy"),
    ("olmo3_v2_curriculum_s15", "olymp_math_hard"),
    # NEW: gspo v2 sft s20 olymp_easy
    ("olmo3_gspo_v2_sft_s20", "olymp_math_easy"),
)

# (source_label, benchmark_label, file_path)
V2_MATH_SOURCES = [
    (source, benchmark, f"{_SILVER_S3_DIR}/{source}__{benchmark}.jsonl")
    for source, benchmark in _V2_MATH_PAIRS
]


# Non-greedy with DOTALL: captures the body of the first <think>...</think>
# block, even when it spans multiple lines.
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
# A sentence boundary: terminal punctuation followed by whitespace. Because
# whitespace is required, decimals such as "3.14" are not treated as boundaries.
_SENTENCE_END_RE = re.compile(r"[.!?]\s+")


def extract_reasoning_tag_aware(response: str) -> str:
    """Return the reasoning portion of a model *response*.

    Tries the project's ``extract_reasoning`` first and returns its result
    whenever it is not None. Otherwise falls back to the stripped body of the
    first ``<think>...</think>`` block, and finally to the whole response.
    """
    reasoning = extract_reasoning(response)
    if reasoning is not None:
        return reasoning
    m = _THINK_RE.search(response)
    if m:
        return m.group(1).strip()
    return response


def extract_preceding_context(
    reasoning_text: str,
    start_char: int,
    max_chars: int = 500,
    *,
    window_chars: int = 400,
) -> str:
    """Return a short snippet of *reasoning_text* just before *start_char*.

    Looks at the ``window_chars`` characters preceding *start_char* and picks
    the text after sentence boundaries (``_SENTENCE_END_RE``):

    * two or more boundaries: the text after the second-to-last one, i.e. the
      last complete sentence plus any unterminated tail;
    * exactly one boundary: the text after it;
    * no boundary, or the chosen slice is blank: the last 200 characters of
      the window.

    Args:
        reasoning_text: Full reasoning text the span was segmented from.
        start_char: Offset of the span's first character in *reasoning_text*.
        max_chars: Hard cap on the returned length. With the default
            ``window_chars=400`` this cap never bites.
        window_chars: How many characters before *start_char* to inspect.

    Returns:
        The stripped context snippet, or ``""`` when ``start_char <= 0``.
    """
    if start_char <= 0:
        return ""
    window = reasoning_text[max(0, start_char - window_chars):start_char]
    ends = [m.end() for m in _SENTENCE_END_RE.finditer(window)]
    if len(ends) >= 2:
        ctx = window[ends[-2]:]
    elif ends:
        ctx = window[ends[-1]:]
    else:
        ctx = ""
    ctx = ctx.strip()
    if not ctx:
        # A single boundary sitting exactly at the span start leaves nothing
        # after it; use the raw window tail instead of dropping all context.
        ctx = window[-200:].strip()
    return ctx[:max_chars]


def sample_spans_from_trace(response: str, n_spans_target: int, rng: random.Random) -> list[dict]:
    """Randomly pick up to *n_spans_target* segmented spans from one response.

    The reasoning portion of *response* is segmented with ``segment_trace``.
    A random subset of the segments (drawn without replacement via *rng*, then
    put back in trace order) is returned as dicts holding the span text, its
    preceding context, token count, span id, and the trace's total span count.
    Returns an empty list when segmentation yields no spans.
    """
    reasoning = extract_reasoning_tag_aware(response)
    segments = segment_trace(reasoning, tokenizer_name=None)
    if not segments:
        return []

    total = len(segments)
    picked = rng.sample(range(total), min(n_spans_target, total))
    picked.sort()  # keep the chosen spans in their original trace order

    return [
        {
            "span_text": seg.text,
            "preceding_context": extract_preceding_context(reasoning, seg.start_char),
            "n_tokens": seg.n_tokens,
            "span_idx": seg.span_id,
            "n_spans_in_trace": total,
        }
        for seg in (segments[k] for k in picked)
    ]


def _load_existing_ids(path: Path | None) -> set[str]:
    """Return the ``span_id`` values recorded in the JSONL file *path*.

    Returns an empty set when *path* is None. A path that does not exist is
    reported on stdout and treated as empty rather than aborting the run.
    Blank lines are skipped.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record lacks ``"span_id"``.
    """
    ids: set[str] = set()
    if path is None:
        return ids
    if not path.exists():
        print(f"  WARN exclude-ids file not found, nothing excluded: {path}")
        return ids
    with open(path, encoding="utf-8") as f:
        for line in f:
            # A stray blank line would otherwise make json.loads raise.
            if line.strip():
                ids.add(json.loads(line)["span_id"])
    return ids


def _collect_source_spans(
    source: str,
    benchmark: str,
    path: Path,
    *,
    n_traces: int,
    n_spans: int,
    rng: random.Random,
    existing_ids: set[str],
) -> list[dict]:
    """Sample traces from one results JSONL and return its new span rows.

    Draws up to *n_traces* non-blank records from *path*, takes the first
    response of each, samples up to *n_spans* spans from it, and annotates
    each span with row metadata. Spans whose ``span_id`` is already in
    *existing_ids* are dropped.
    """
    with open(path, encoding="utf-8") as f:
        # Blank lines would crash json.loads if sampled; drop them up front.
        lines = [line for line in f if line.strip()]

    out: list[dict] = []
    for line in rng.sample(lines, min(n_traces, len(lines))):
        d = json.loads(line)
        doc_id = d.get("doc_id", d.get("idx", 0))
        resps = d.get("resps", [[]])
        # "resps" may be nested ([[r0, r1, ...]]) or flat ([r0, r1, ...]);
        # normalise to the flat list of responses for this record.
        responses = resps[0] if resps and isinstance(resps[0], list) else resps
        if not responses:
            continue
        # Only the first response is used, hence trace_id is always 0 below.
        for sp in sample_spans_from_trace(responses[0], n_spans, rng):
            sp.update({
                "span_id": f"{source}|{benchmark}|{doc_id}|0|{sp['span_idx']}",
                "checkpoint_id": source,
                "task_name": benchmark,
                "doc_id": int(doc_id),
                "trace_id": 0,
                "episode_idx": sp["span_idx"],
                "n_episodes_in_trace": sp["n_spans_in_trace"],
                # NOTE(review): hard-coded; the record's correctness is not
                # read here. Confirm downstream does not rely on this field.
                "correct": False,
            })
            if sp["span_id"] in existing_ids:
                continue
            out.append(sp)
    return out


def _dedupe_by_span_id(rows: list[dict]) -> list[dict]:
    """Return *rows* with later duplicates of a ``span_id`` removed (first wins)."""
    seen: set[str] = set()
    out: list[dict] = []
    for r in rows:
        sid = r["span_id"]
        if sid in seen:
            continue
        seen.add(sid)
        out.append(r)
    return out


def _write_jsonl_atomic(path: Path, rows: list[dict]) -> None:
    """Write *rows* as JSONL to *path* via a temp file and ``os.replace``.

    The temp file is removed if writing or the final rename fails, so a
    partial ``.tmp`` is not left behind.
    """
    tmp = path.with_suffix(path.suffix + ".tmp")
    try:
        with open(tmp, "w", encoding="utf-8") as f:
            for r in rows:
                f.write(json.dumps(r) + "\n")
        os.replace(tmp, path)
    except BaseException:
        tmp.unlink(missing_ok=True)
        raise


def main() -> None:
    """CLI entry point: sample new v2 math spans and write them as JSONL.

    For every source file in ``V2_MATH_SOURCES`` that exists, samples traces
    and spans (seeded by ``--seed``), skips span ids listed in
    ``--exclude-ids``, de-duplicates, writes the result to ``--out``, and
    prints per-source and per-(checkpoint, task) counts.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out", required=True, type=Path)
    ap.add_argument("--exclude-ids", type=Path, default=None)
    ap.add_argument("--n-traces-per-source", type=int, default=80)
    ap.add_argument("--n-spans-per-trace", type=int, default=20)
    ap.add_argument("--seed", type=int, default=48)
    args = ap.parse_args()

    # A single RNG drives both trace and span sampling, so results depend on
    # the order of V2_MATH_SOURCES as well as the seed.
    rng = random.Random(args.seed)
    args.out.parent.mkdir(parents=True, exist_ok=True)

    existing_ids = _load_existing_ids(args.exclude_ids)

    rows: list[dict] = []
    for source, benchmark, path in V2_MATH_SOURCES:
        if not Path(path).exists():
            print(f"  SKIP missing: {source} | {benchmark}")
            continue
        new_rows = _collect_source_spans(
            source,
            benchmark,
            Path(path),
            n_traces=args.n_traces_per_source,
            n_spans=args.n_spans_per_trace,
            rng=rng,
            existing_ids=existing_ids,
        )
        rows.extend(new_rows)
        print(f"  {source:<26} | {benchmark:<14} | kept {len(new_rows)} new spans")

    deduped = _dedupe_by_span_id(rows)
    _write_jsonl_atomic(args.out, deduped)

    print(f"\nWrote {len(deduped)} new v2 math spans -> {args.out}")
    print()
    for k, v in Counter((r["checkpoint_id"], r["task_name"]) for r in deduped).most_common():
        print(f"  {k}: {v}")


if __name__ == "__main__":
    main()
