"""Build a HELD-OUT puzzle test set from v2 rollouts (gspo_v2 + dsr_v2_sft).

These checkpoints are NOT in the training silver. The resulting silver
is used solely as a generalization test set for the production ensemble.
Mirrors the math sampler: tag-aware extraction, dedup, V3-SC v6 minting.
"""
from __future__ import annotations

import argparse
import json
import os
import random
import re
from collections import Counter
from pathlib import Path

from analysis.exploration.segmentation import segment_trace, extract_reasoning


# Rollout files to sample spans from. Each entry is a 3-tuple that main()
# unpacks as (source, benchmark, path):
#   source    -- checkpoint label, stored on every row as "checkpoint_id"
#   benchmark -- puzzle task name, stored on every row as "task_name"
#   path      -- rollout samples file, read as one JSON record per line
# Entries whose path does not exist are skipped at run time with a message.
V2_PUZZLE_SOURCES = [
    # gspo_v2 step 20 — RL-tuned olmo3 v2
    (
        "olmo3_gspo_v2_step20", "bridges_8x8de",
        "<PROJECT_DIR>/results/gspo_v2_s20/hard_puzzle_pass32/checkpoints__olmo3-puzzle-grpo__multi_puzzle_gspo_olmo3_v2_sft_v2__merged_step_20/samples_bridges_8x8de_pass32_2026-04-24T23-03-07.158909.jsonl",
    ),
    (
        "olmo3_gspo_v2_step20", "undead_5x5de",
        "<PROJECT_DIR>/results/gspo_v2_s20/hard_puzzle_pass32/checkpoints__olmo3-puzzle-grpo__multi_puzzle_gspo_olmo3_v2_sft_v2__merged_step_20/samples_undead_5x5de_pass32_2026-04-24T23-03-07.158909.jsonl",
    ),
    # sft_v2 ep5 — SFT v2 baseline
    (
        "olmo3_dsr_v2_sft_ep5", "bridges_8x8de",
        "<PROJECT_DIR>/results/sft_v2_ep5/hard_puzzle_pass32/checkpoints__olmo3_7b_multi_puzzle_dsr_v2__merged_ep5_fp32/samples_bridges_8x8de_pass32_2026-04-24T15-21-45.043307.jsonl",
    ),
    (
        "olmo3_dsr_v2_sft_ep5", "undead_5x5de",
        "<PROJECT_DIR>/results/sft_v2_ep5/hard_puzzle_pass32/checkpoints__olmo3_7b_multi_puzzle_dsr_v2__merged_ep5_fp32/samples_undead_5x5de_pass32_2026-04-24T15-21-45.043307.jsonl",
    ),
    (
        "olmo3_dsr_v2_sft_ep5", "bridges_7x7dm",
        "<PROJECT_DIR>/results/sft_v2_ep5/puzzle_pass32/checkpoints__olmo3_7b_multi_puzzle_dsr_v2__merged_ep5_fp32/samples_bridges_7x7dm_pass32_2026-04-24T07-16-11.701519.jsonl",
    ),
    (
        "olmo3_dsr_v2_sft_ep5", "pattern_4x4",
        "<PROJECT_DIR>/results/sft_v2_ep5/puzzle_pass32/checkpoints__olmo3_7b_multi_puzzle_dsr_v2__merged_ep5_fp32/samples_pattern_4x4_pass32_2026-04-24T07-16-11.701519.jsonl",
    ),
    (
        "olmo3_dsr_v2_sft_ep5", "undead_4x4de",
        "<PROJECT_DIR>/results/sft_v2_ep5/puzzle_pass32/checkpoints__olmo3_7b_multi_puzzle_dsr_v2__merged_ep5_fp32/samples_undead_4x4de_pass32_2026-04-24T07-16-11.701519.jsonl",
    ),
]


# Non-greedy capture of the text inside a <think>...</think> pair; DOTALL lets
# the captured group span multiple lines.
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
# A sentence terminator (".", "!" or "?") followed by one or more whitespace
# characters; the match end is used as a sentence boundary offset.
_SENTENCE_END_RE = re.compile(r"[.!?]\s+")


def extract_reasoning_tag_aware(response: str) -> str:
    """Return the reasoning portion of a model response.

    The imported ``extract_reasoning`` helper is tried first and its result is
    returned whenever it is not ``None``. Otherwise the stripped contents of
    the first ``<think>...</think>`` block are used, and if no such block
    exists the response is returned unchanged.
    """
    extracted = extract_reasoning(response)
    if extracted is None:
        think_match = _THINK_RE.search(response)
        # No recognisable reasoning markers at all: treat the whole response
        # as the reasoning text.
        extracted = response if think_match is None else think_match.group(1).strip()
    return extracted


def extract_preceding_context(reasoning_text: str, start_char: int, max_chars: int = 500) -> str:
    """Return a short snippet of text that ends right before ``start_char``.

    Only the 400 characters before ``start_char`` are inspected. Within that
    window the snippet starts after the second-to-last sentence boundary
    (terminator followed by whitespace); with a single boundary it starts
    after that one; with none it is the last 200 characters of the window.

    Args:
        reasoning_text: Text the offset refers to.
        start_char: Offset of the span start; values <= 0 yield ``""``.
        max_chars: Upper bound on the length of the returned snippet.

    Returns:
        The whitespace-stripped snippet, truncated to ``max_chars``.
    """
    if start_char <= 0:
        return ""
    lookback = reasoning_text[max(0, start_char - 400):start_char]
    boundaries = [hit.end() for hit in _SENTENCE_END_RE.finditer(lookback)]
    if boundaries:
        # The first of the last two boundaries: the second-to-last one when
        # there are at least two, otherwise the only one.
        cut = boundaries[-2:][0]
        tail = lookback[cut:]
    else:
        tail = lookback[-200:]
    return tail.strip()[:max_chars]


def sample_spans_from_trace(response: str, n_spans_target: int, rng: random.Random) -> list[dict]:
    """Segment one response and return a random subset of its spans as dicts.

    Args:
        response: Raw model response; its reasoning part is extracted first.
        n_spans_target: Maximum number of spans to return.
        rng: Random generator used to pick spans without replacement.

    Returns:
        One dict per chosen span, in ascending position order, with keys
        ``span_text``, ``preceding_context``, ``n_tokens``, ``span_idx`` and
        ``n_spans_in_trace``. Empty when segmentation yields no spans.
    """
    reasoning = extract_reasoning_tag_aware(response)
    segments = segment_trace(reasoning, tokenizer_name=None)
    if not segments:
        return []
    total = len(segments)
    picked = rng.sample(range(total), min(n_spans_target, total))
    picked.sort()
    return [
        {
            "span_text": segments[pos].text,
            # start_char is presumably an offset into `reasoning` — confirm
            # against segment_trace.
            "preceding_context": extract_preceding_context(reasoning, segments[pos].start_char),
            "n_tokens": segments[pos].n_tokens,
            "span_idx": segments[pos].span_id,
            "n_spans_in_trace": total,
        }
        for pos in picked
    ]


def main() -> None:
    """Sample spans from the v2 rollout files and write them as a JSONL file.

    Command-line flags:
        --out: Destination JSONL path (required); parent directories are
            created if missing.
        --n-traces-per-source: Maximum number of rollout records sampled from
            each source file (default 20).
        --n-spans-per-trace: Maximum number of spans sampled from each trace
            (default 10).
        --seed: Seed of the single ``random.Random`` that drives both record
            and span sampling (default 46).

    For every entry of ``V2_PUZZLE_SOURCES`` whose file exists, records are
    sampled, the first response of each record is segmented, and each sampled
    span is tagged with source metadata. Rows are de-duplicated by
    ``span_id`` (first occurrence wins) and written to a temporary sibling
    file that is then moved over ``--out`` with ``os.replace``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out", required=True, type=Path)
    ap.add_argument("--n-traces-per-source", type=int, default=20)
    ap.add_argument("--n-spans-per-trace", type=int, default=10)
    ap.add_argument("--seed", type=int, default=46)
    args = ap.parse_args()

    rng = random.Random(args.seed)
    args.out.parent.mkdir(parents=True, exist_ok=True)

    rows: list[dict] = []
    for source, benchmark, path in V2_PUZZLE_SOURCES:
        if not Path(path).exists():
            print(f"  SKIP missing: {source} | {benchmark}")
            continue
        # Decode explicitly as UTF-8 instead of the locale default, and drop
        # blank lines so they can neither be sampled nor crash json.loads.
        with open(path, encoding="utf-8") as f:
            lines = [line for line in f if line.strip()]
        sampled_lines = rng.sample(lines, min(args.n_traces_per_source, len(lines)))
        kept = 0
        for line in sampled_lines:
            d = json.loads(line)
            doc_id = d.get("doc_id", d.get("idx", 0))
            # "resps" may be a list of lists or a flat list; in both cases
            # only the first response is used, hence the constant trace id 0.
            resps = d.get("resps", [[]])
            responses = resps[0] if resps and isinstance(resps[0], list) else resps
            if not responses:
                continue
            response = responses[0]
            spans = sample_spans_from_trace(response, args.n_spans_per_trace, rng)
            for sp in spans:
                sp.update({
                    "span_id": f"{source}|{benchmark}|{doc_id}|0|{sp['span_idx']}",
                    "checkpoint_id": source,
                    "task_name": benchmark,
                    "doc_id": int(doc_id),
                    "trace_id": 0,
                    "episode_idx": sp["span_idx"],
                    "n_episodes_in_trace": sp["n_spans_in_trace"],
                    # Hard-coded: correctness is not read from the rollout
                    # record by this script.
                    "correct": False,
                })
                rows.append(sp)
                kept += 1
        print(f"  {source:<28} | {benchmark:<14} | kept {kept} spans")

    # Keep the first row seen for each span_id, preserving insertion order.
    seen = set()
    deduped = []
    for r in rows:
        if r["span_id"] in seen:
            continue
        seen.add(r["span_id"])
        deduped.append(r)

    # Write to a sibling temp file first so a partially written file never
    # appears at the destination path.
    tmp = args.out.with_suffix(args.out.suffix + ".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        for r in deduped:
            f.write(json.dumps(r) + "\n")
    os.replace(tmp, args.out)

    print(f"\nWrote {len(deduped)} v2-rollout puzzle spans -> {args.out}")
    print()
    print("By source:")
    for k, v in Counter((r["checkpoint_id"], r["task_name"]) for r in deduped).most_common():
        print(f"  {k}: {v}")


# Run the sampler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
