"""§4.1 primitive + motif distribution figure (2×2 grid).

Spec: reports/neurips/tasks/section4_1_primitive_motif_distributions.md

Sources (trace-level parquets from `analysis.exploration.pipeline`):
  Base   math    -> results/exploration_analysis/v90_base_math/
  Base   puzzles -> results/exploration_analysis/v90_base_puzzles/   (may be missing)
  DSR    puzzles -> results/exploration_analysis/v90_dsr_puzzles/
  SFT    math    -> results/exploration_analysis/v90_sft_math/
  SFT    puzzles -> results/exploration_analysis/v90_sft_puzzles/

Outputs:
  results/exploration_analysis/section4_1_distributions.csv
    long-form: condition, domain, target_type, target_name, scheme,
               mean_per_trace, n_traces
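  writing/neurips_paper/figures/fig_primitive_motif_distributions_{scheme}.png
    one figure per rendered motif scheme, with suffix `_k3`, `_k5` or
    `_k3plus5` (k3+5 adds the k=3 and k=5 counts). Only the chosen scheme is
    rendered unless `--motif_scheme all` is passed, which renders all three
    for comparison; final filename `fig_primitive_motif_distributions.png`
    is a copy of the chosen scheme (default k=3).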

Usage:
    python scripts/analysis/section4_1_primitive_motif_distributions.py
    python scripts/analysis/section4_1_primitive_motif_distributions.py \\
        --motif_scheme k3 --layout 2x2

Re-renders in <30s when the CSV is already present (--use_cache).
"""

from __future__ import annotations

import argparse
import json
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Primitive labels counted per trace. The list order is also the x-axis order
# of the primitive panels and the row order of the report's primitive table.
PRIMITIVES = [
    "PLAN", "SETUP", "ENUMERATE", "HYPOTHESIZE",
    "COMPUTE", "CHECK", "BACKTRACK", "SUMMARIZE", "OTHER",
]

# Conditions: (label, domain, parquet_path)
# Paths are relative to the current working directory; a path that does not
# exist is reported and skipped by compute_distributions().
SOURCES = [
    ("Base", "math",    "results/exploration_analysis/v90_base_math/trace_level_metrics.parquet"),
    ("SFT",  "math",    "results/exploration_analysis/v90_sft_math/trace_level_metrics.parquet"),
    ("Base", "puzzles", "results/exploration_analysis/v90_base_puzzles/trace_level_metrics.parquet"),
    ("DSR",  "puzzles", "results/exploration_analysis/v90_dsr_puzzles/trace_level_metrics.parquet"),
    ("SFT",  "puzzles", "results/exploration_analysis/v90_sft_puzzles/trace_level_metrics.parquet"),
]

# Pastel palette — keep in sync with CLAUDE.md "Figure colour scheme".
# PALETTE is the bar face colour per condition label.
PALETTE = {
    "Base": "#cfd8dc",   # blue-grey 100
    "DSR":  "#e0e0e0",   # grey 300 — DSR isn't a checkpoint stage; neutral
    "SFT":  "#c8e6c9",   # green 100
}
# Darker outline colour per condition label (bar edges and legend swatches).
EDGE = {
    "Base": "#607d8b",
    "DSR":  "#757575",
    "SFT":  "#2e7d32",
}

# Motif-category definitions (3 schemes, picked at runtime via --motif_scheme).
# ALLOWED_MIDDLE is the set of primitives accepted in the non-anchor slots of
# the Verification motifs (`CHK -> X -> CHK` and its 5-gram extension).
ALLOWED_MIDDLE = {
    "PLAN", "SETUP", "ENUMERATE", "COMPUTE", "SUMMARIZE", "OTHER"
    # explicitly NOT: CHECK (would break the anchor pattern),
    # HYPOTHESIZE / BACKTRACK (those are Recovery, not Verification)
}

# Output locations, relative to the current working directory.
OUT_CSV   = Path("results/exploration_analysis/section4_1_distributions.csv")
OUT_FIG_DIR = Path("writing/neurips_paper/figures")
OUT_REPORT = Path("reports/neurips/section4_1_distributions.md")
# NOTE: runs at import time, so merely importing this module creates the
# figure directory.
OUT_FIG_DIR.mkdir(parents=True, exist_ok=True)


# ---------------------------------------------------------------------------
# Trace helpers
# ---------------------------------------------------------------------------

def parse_seq(s) -> list[str]:
    """Decode one `primitive_sequence` cell into a list of primitive labels.

    Args:
        s: Either a JSON-encoded array (str/bytes) or a value that is already
            a sequence (list, tuple, NumPy array). The latter happens when the
            column is stored as a native list type rather than a JSON string;
            previously such cells raised ``TypeError`` inside ``json.loads``
            and were silently turned into empty traces.

    Returns:
        The decoded list. An empty list is returned for ``None``, for values
        ``json.loads`` cannot handle (e.g. NaN floats, malformed JSON) and for
        valid JSON that is not an array (a bare number, string or object), so
        callers can always iterate the result as a sequence of labels.
    """
    if s is None:
        return []
    # Already-decoded cells: pass through instead of discarding them.
    if isinstance(s, np.ndarray):
        return s.tolist()
    if isinstance(s, (list, tuple)):
        return list(s)
    try:
        parsed = json.loads(s)
    except (json.JSONDecodeError, TypeError):
        return []
    # Guard against JSON scalars/objects, which are not primitive sequences.
    return parsed if isinstance(parsed, list) else []


def primitive_counts(seq: list[str]) -> dict[str, int]:
    """Tally how often each label of PRIMITIVES occurs in one trace.

    Every entry of PRIMITIVES is a key of the result (0 when it never occurs
    in `seq`); labels that are not in PRIMITIVES are ignored.
    """
    tally = dict.fromkeys(PRIMITIVES, 0)
    for label in seq:
        if label not in tally:
            continue
        tally[label] = tally[label] + 1
    return tally


# ---- Motif counters (overlapping sliding-window, stride 1) ----

def count_recovery_window(seq: list[str], k: int) -> int:
    """Count stride-1 windows of length `k` holding both recovery markers.

    A window counts when it contains at least one "HYPOTHESIZE" and at least
    one "BACKTRACK", in any order. Windows overlap, so one HYP/BTK pair can
    contribute to several windows. Sequences shorter than `k` yield 0.
    """
    last_start = len(seq) - k
    if last_start < 0:
        return 0
    windows = (seq[start:start + k] for start in range(last_start + 1))
    return sum(
        1
        for window in windows
        if "HYPOTHESIZE" in window and "BACKTRACK" in window
    )


def count_exact_motif(seq: list[str], motif: tuple[str, ...]) -> int:
    """Count the positions at which `motif` occurs contiguously in `seq`.

    Matches may overlap. An empty motif, or a sequence shorter than the
    motif, yields 0. Each window is compared as a tuple, so `motif` must be a
    tuple for any match to be found.
    """
    width = len(motif)
    if width == 0 or len(seq) < width:
        return 0
    starts = range(len(seq) - width + 1)
    return sum(1 for start in starts if tuple(seq[start:start + width]) == motif)


def count_exploitation_k3(seq):
    """Overlapping count of the `COMPUTE -> CHECK -> COMPUTE` trigram in `seq`."""
    pattern = ("COMPUTE", "CHECK", "COMPUTE")
    return count_exact_motif(seq, pattern)


def count_exploitation_k5(seq):
    """Overlapping count of the `CMP -> CHK -> CMP -> CHK -> CMP` 5-gram in `seq`."""
    pattern = ("COMPUTE", "CHECK", "COMPUTE", "CHECK", "COMPUTE")
    return count_exact_motif(seq, pattern)


def count_verification_k3(seq):
    """Count overlapping `CHK -> X -> CHK` trigrams whose X is in ALLOWED_MIDDLE.

    Sequences with fewer than three episodes yield 0.
    """
    if len(seq) < 3:
        return 0
    triples = zip(seq, seq[1:], seq[2:])
    return sum(
        1
        for left, middle, right in triples
        if left == "CHECK" and right == "CHECK" and middle in ALLOWED_MIDDLE
    )


def count_verification_k5(seq):
    """Count overlapping `CHK -> X -> CHK -> Y -> CHK` 5-grams.

    Both middle slots X and Y must be members of ALLOWED_MIDDLE. Sequences
    with fewer than five episodes yield 0.
    """
    if len(seq) < 5:
        return 0
    total = 0
    for start in range(len(seq) - 4):
        first, x, second, y, third = seq[start:start + 5]
        anchors_ok = first == "CHECK" and second == "CHECK" and third == "CHECK"
        if anchors_ok and x in ALLOWED_MIDDLE and y in ALLOWED_MIDDLE:
            total += 1
    return total


def motif_counts_for_scheme(seq: list[str], scheme: str) -> dict[str, int]:
    """Per-trace motif counts for one scheme.

    Returns a dict with keys Recovery, Exploitation and Verification (in that
    order). "k3" and "k5" use the length-3 / length-5 counters; "k3+5" is the
    element-wise sum of the two. Any other scheme raises ValueError.
    """
    def at_k3():
        return {
            "Recovery": count_recovery_window(seq, 3),
            "Exploitation": count_exploitation_k3(seq),
            "Verification": count_verification_k3(seq),
        }

    def at_k5():
        return {
            "Recovery": count_recovery_window(seq, 5),
            "Exploitation": count_exploitation_k5(seq),
            "Verification": count_verification_k5(seq),
        }

    if scheme == "k3":
        return at_k3()
    if scheme == "k5":
        return at_k5()
    if scheme == "k3+5":
        short, long_ = at_k3(), at_k5()
        return {category: short[category] + long_[category] for category in short}
    raise ValueError(f"unknown scheme: {scheme}")


# ---------------------------------------------------------------------------
# Compute long-form table
# ---------------------------------------------------------------------------

def compute_distributions() -> pd.DataFrame:
    """Build the long-form table of per-trace mean primitive and motif counts.

    For every `(condition, domain, path)` entry in SOURCES the parquet is
    loaded, its `primitive_sequence` column is decoded with `parse_seq`, and
    one row is emitted per primitive (scheme "n/a") plus one row per motif
    category for each of the schemes k3, k5 and k3+5. A source whose parquet
    file does not exist is reported on stdout and skipped.

    Returns:
        DataFrame with columns condition, domain, target_type, target_name,
        scheme, mean_per_trace, n_traces. The columns are present even when
        no source could be read, so the CSV round-trip and the plotting code
        still work (they then see only missing values) instead of failing on
        a column-less frame.
    """
    columns = [
        "condition", "domain", "target_type", "target_name",
        "scheme", "mean_per_trace", "n_traces",
    ]
    motif_categories = ("Recovery", "Exploitation", "Verification")

    rows: list[dict] = []
    for cond, dom, path in SOURCES:
        parquet = Path(path)
        if not parquet.exists():
            print(f"  [missing] {cond}/{dom}: {path}")
            continue
        df = pd.read_parquet(parquet)
        seqs = df["primitive_sequence"].apply(parse_seq).tolist()
        n = len(seqs)
        print(f"  {cond:5s} {dom:8s}  n={n}  ({path})")

        # Primitives: count once per trace, then average each label.
        per_trace_prims = [primitive_counts(s) for s in seqs]
        for p_name in PRIMITIVES:
            # np.mean of an empty list warns and returns NaN; be explicit.
            mean = (float(np.mean([c[p_name] for c in per_trace_prims]))
                    if n else float("nan"))
            rows.append({
                "condition": cond, "domain": dom,
                "target_type": "primitive", "target_name": p_name,
                "scheme": "n/a",
                "mean_per_trace": mean, "n_traces": n,
            })

        # Motifs (3 schemes)
        for scheme in ("k3", "k5", "k3+5"):
            per_trace_motifs = [motif_counts_for_scheme(s, scheme) for s in seqs]
            for cat in motif_categories:
                mean = (float(np.mean([m[cat] for m in per_trace_motifs]))
                        if per_trace_motifs else float("nan"))
                rows.append({
                    "condition": cond, "domain": dom,
                    "target_type": "motif_category", "target_name": cat,
                    "scheme": scheme,
                    "mean_per_trace": mean, "n_traces": n,
                })

    # Pin the schema so an empty result still has the expected columns.
    return pd.DataFrame(rows, columns=columns)


# ---------------------------------------------------------------------------
# Plot
# ---------------------------------------------------------------------------

def _draw_grouped_bars(ax, *, conditions, x_labels, values_by_cond, title,
                       ylabel, ymax=None, label_conditions=None):
    """Draw one grouped bar panel on `ax`.

    `conditions` gives the bar order within each group, `x_labels` the groups,
    and `values_by_cond[cond]` the bar heights aligned with `x_labels`. Face
    and edge colours come from PALETTE / EDGE.

    Missing values (None or NaN) are plotted with height 0 and restyled with
    a `///` hatch, no face colour and alpha 0.35; they never get a numeric
    annotation.

    `label_conditions`: iterable of condition names whose bars get a one
    decimal annotation above them; None annotates every condition.
    Annotations are only drawn when `ymax` is a positive number. When `ymax`
    is given, the y-axis is fixed to [0, 1.15 * ymax].
    """
    def _absent(value):
        return value is None or np.isnan(value)

    group_count = len(conditions)
    bar_width = 0.8 / max(group_count, 1)
    centers = np.arange(len(x_labels))
    if label_conditions is None:
        annotated = set(conditions)
    else:
        annotated = set(label_conditions)

    for slot, cond in enumerate(conditions):
        # Centre the cluster of bars on each tick.
        positions = centers + (slot - (group_count - 1) / 2) * bar_width
        raw = values_by_cond[cond]
        missing = [_absent(v) for v in raw]
        heights = [0.0 if gone else v for v, gone in zip(raw, missing)]
        rects = ax.bar(positions, heights, width=bar_width * 0.95,
                       color=PALETTE[cond], edgecolor=EDGE[cond],
                       linewidth=1.0, label=cond)
        for rect, gone in zip(rects, missing):
            if not gone:
                continue
            rect.set_hatch("///")
            rect.set_facecolor("none")
            rect.set_alpha(0.35)
        if ymax is not None and ymax > 0 and cond in annotated:
            lift = 0.015 * ymax
            for x, height, gone in zip(positions, heights, missing):
                if gone:
                    continue
                ax.text(x, height + lift, f"{height:.1f}",
                        ha="center", va="bottom", fontsize=7)

    ax.set_xticks(centers)
    ax.set_xticklabels(x_labels, rotation=20, ha="right", fontsize=8)
    ax.set_ylabel(ylabel, fontsize=9)
    ax.set_title(title, fontsize=10)
    ax.grid(True, axis="y", alpha=0.3)
    if ymax is not None:
        ax.set_ylim(0, ymax * 1.15)

def render_figure(df: pd.DataFrame, scheme: str, *, nrows=2, ncols=2,
                  out_path: Path | None = None) -> Path:
    """Render the four-panel primitive/motif figure for one motif scheme.

    Panels, in row-major axes order: primitives on puzzles, motifs on puzzles,
    primitives on math, motifs on math. Puzzle panels show Base/DSR/SFT, math
    panels show Base/SFT. Primitive panels annotate the SFT bars only; motif
    panels annotate every bar.

    Args:
        df: Long-form table as produced by `compute_distributions` (or read
            back from its CSV).
        scheme: One of "k3", "k5", "k3+5"; selects the motif rows to plot.
        nrows, ncols: Subplot grid; must provide at least four axes.
        out_path: Destination PNG. Defaults to
            `OUT_FIG_DIR / fig_primitive_motif_distributions_<scheme>.png`
            with "+" spelled "plus". Its parent directory is created if needed.

    Returns:
        The path the figure was written to.

    Raises:
        KeyError: `scheme` is not one of the three known schemes.
        ValueError: the grid has fewer than four axes.
    """
    # Validate inputs before allocating a figure so a bad call leaks nothing.
    scheme_suffix = {"k3": "trigrams", "k5": "5-grams", "k3+5": "trigrams + 5-grams"}[scheme]
    if nrows * ncols < 4:
        raise ValueError(
            f"need at least 4 axes for the four panels, got {nrows}x{ncols}")

    if out_path is None:
        out_path = OUT_FIG_DIR / f"fig_primitive_motif_distributions_{scheme.replace('+','plus')}.png"
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------ data
    def get_vals(condition, domain, target_type, names, this_scheme=None):
        """Means for `names` in one (condition, domain) cell; NaN if absent."""
        sub = df[(df["condition"] == condition)
                 & (df["domain"] == domain)
                 & (df["target_type"] == target_type)]
        if this_scheme is not None:
            sub = sub[sub["scheme"] == this_scheme]
        idx = sub.set_index("target_name")["mean_per_trace"]
        return [float(idx[n]) if n in idx.index else float("nan") for n in names]

    motif_cats = ["Recovery", "Exploitation", "Verification"]
    domains_conditions = {
        "puzzles": ["Base", "DSR", "SFT"],
        "math":    ["Base", "SFT"],
    }

    # (domain, target_type, x labels, scheme filter, title, annotated conds)
    panels = [
        ("puzzles", "primitive", PRIMITIVES, None,
         "Primitives — Puzzles", ["SFT"]),
        ("puzzles", "motif_category", motif_cats, scheme,
         f"Motifs — Puzzles ({scheme_suffix})", None),
        ("math", "primitive", PRIMITIVES, None,
         "Primitives — Math", ["SFT"]),
        ("math", "motif_category", motif_cats, scheme,
         f"Motifs — Math ({scheme_suffix})", None),
    ]

    # ---------------------------------------------------------------- figure
    fig, axes = plt.subplots(nrows, ncols, figsize=(13, 7), squeeze=False)

    # With squeeze=False, `axes.flat` is row-major, which gives the intended
    # panel placement for both the 2x2 and the 1x4 layout.
    for ax, (domain, target_type, names, lookup_scheme, title, labelled) in zip(
            axes.flat, panels):
        conds = domains_conditions[domain]
        values = {c: get_vals(c, domain, target_type, names,
                              this_scheme=lookup_scheme)
                  for c in conds}
        # Per-panel y-max shared by all conditions, so bars within a panel are
        # directly comparable. Fall back to 1.0 when nothing is present or
        # every mean is 0: a zero y-max would request the singular limits
        # (0, 0) and suppress the value labels.
        finite = [v for vals in values.values() for v in vals
                  if not np.isnan(v)]
        top = max(finite, default=0.0)
        _draw_grouped_bars(
            ax,
            conditions=conds,
            x_labels=names,
            values_by_cond=values,
            title=title,
            ylabel="Mean count per trace",
            ymax=top if top > 0 else 1.0,
            label_conditions=labelled,
        )

    # Shared legend (Base / DSR / SFT) at top
    handles = [plt.Rectangle((0, 0), 1, 1,
                             facecolor=PALETTE[c], edgecolor=EDGE[c],
                             linewidth=1.0, label=c)
               for c in ["Base", "DSR", "SFT"]]
    fig.legend(handles=handles, loc="upper center", ncol=3,
               bbox_to_anchor=(0.5, 1.00), frameon=False, fontsize=10)

    # Leave the top 3% of the canvas for the shared legend.
    fig.tight_layout(rect=(0, 0, 1, 0.97))
    fig.savefig(out_path, dpi=160, bbox_inches="tight")
    plt.close(fig)
    return out_path


# ---------------------------------------------------------------------------
# Report
# ---------------------------------------------------------------------------

def _fmt(v):
    if v is None or (isinstance(v, float) and np.isnan(v)):
        return "n/a"
    return f"{v:.2f}"


def write_report(df: pd.DataFrame, scheme: str, fig_path: Path) -> Path:
    """Emit the §4.1 distributions report with figure embed + tables.

    Args:
        df: Long-form table as produced by `compute_distributions`.
        scheme: Motif scheme ("k3", "k5" or "k3+5") used for the main motif
            table and the interpretation text.
        fig_path: Path of the figure to embed. "../../" is prepended to form
            the markdown link, i.e. the path is taken to be relative to the
            directory two levels above the report file.

    Returns:
        OUT_REPORT, the path of the written markdown file. The file is always
        written as UTF-8 because the text contains non-ASCII characters
        (§, —, →, ×, ≈, ↔) that a locale-default encoding such as cp1252
        cannot represent.
    """
    from datetime import date

    def get(condition, domain, target_type, name, this_scheme=None):
        """Return (mean_per_trace, n_traces) for one cell, or (NaN, 0)."""
        sub = df[(df["condition"] == condition)
                 & (df["domain"] == domain)
                 & (df["target_type"] == target_type)
                 & (df["target_name"] == name)]
        if this_scheme is not None:
            sub = sub[sub["scheme"] == this_scheme]
        if sub.empty:
            return float("nan"), 0
        row = sub.iloc[0]
        return float(row["mean_per_trace"]), int(row["n_traces"])

    # Column order of the five-cell tables.
    cells = [
        ("Base", "puzzles"), ("DSR", "puzzles"), ("SFT", "puzzles"),
        ("Base", "math"), ("SFT", "math"),
    ]

    # Primitive table — one row per primitive, one value per cell.
    prim_rows = [
        (p_name, *(get(c, d, "primitive", p_name)[0] for c, d in cells))
        for p_name in PRIMITIVES
    ]

    # Episodes/trace totals (sum across primitives) — sanity check vs dsr_vs_sft_primitives.md
    def total_episodes(condition, domain):
        means = (get(condition, domain, "primitive", p)[0] for p in PRIMITIVES)
        # Missing cells are NaN; skip them so they do not poison the total.
        return sum(v for v in means if not np.isnan(v))

    # Every primitive row of a cell carries the same n_traces; read it off PLAN.
    n_traces = {
        (c, d): get(c, d, "primitive", "PLAN")[1] for c, d in cells
    }

    # Motif table at chosen scheme
    motif_cats = ["Recovery", "Exploitation", "Verification"]
    motif_rows = [
        (cat, *(get(c, d, "motif_category", cat, this_scheme=scheme)[0]
                for c, d in cells))
        for cat in motif_cats
    ]

    # Cross-scheme comparison table for motifs (DSR-puzzles, SFT-puzzles and
    # SFT-math means — gives a sense of scale per scheme)
    scheme_compare = []
    for sch in ("k3", "k5", "k3+5"):
        for cat in motif_cats:
            v_dsr, _ = get("DSR",  "puzzles", "motif_category", cat, this_scheme=sch)
            v_sft, _ = get("SFT",  "puzzles", "motif_category", cat, this_scheme=sch)
            v_sft_m, _ = get("SFT", "math",   "motif_category", cat, this_scheme=sch)
            scheme_compare.append((sch, cat, v_dsr, v_sft, v_sft_m))

    # Markdown links want forward slashes even when run on Windows.
    rel_fig_path = "../../" + str(fig_path).replace("\\", "/")

    body = f"""# §4.1 Primitive + Motif Distributions

**Date**: {date.today().isoformat()}
**Spec**: `reports/neurips/tasks/section4_1_primitive_motif_distributions.md`
**Generator**: `scripts/analysis/section4_1_primitive_motif_distributions.py`
**Numbers CSV**: `{OUT_CSV}`
**Canonical motif scheme**: `{scheme}`

This report backs the body figure for §4.1 ("Puzzle SFT Induces a Reasoning
Primitive Vocabulary"). The 2×2 grid shows per-trace mean primitive and
motif counts on puzzles (top row) and math (bottom row). DSR teacher traces
are included on the puzzle row only — DSR was never run on math problems
in this project, so the math row has just two bars (Base, SFT).

![Primitive and motif distributions]({rel_fig_path})

## Primitives — per-trace mean episode count

| Primitive | Puzzles Base | Puzzles DSR | Puzzles SFT | Math Base | Math SFT |
|---|---:|---:|---:|---:|---:|
""" + "\n".join(
        f"| {name} | {_fmt(b)} | {_fmt(d)} | {_fmt(s)} | {_fmt(bm)} | {_fmt(sm)} |"
        for name, b, d, s, bm, sm in prim_rows
    ) + f"""
| **Mean episodes/trace** | {_fmt(total_episodes('Base', 'puzzles'))} | {_fmt(total_episodes('DSR', 'puzzles'))} | {_fmt(total_episodes('SFT', 'puzzles'))} | {_fmt(total_episodes('Base', 'math'))} | {_fmt(total_episodes('SFT', 'math'))} |
| n traces | {n_traces[('Base','puzzles')]} | {n_traces[('DSR','puzzles')]} | {n_traces[('SFT','puzzles')]} | {n_traces[('Base','math')]} | {n_traces[('SFT','math')]} |

Sanity check: DSR-puzzles, SFT-puzzles and SFT-math episode totals reproduce
`reports/neurips/dsr_vs_sft_primitives.md` Table 1 (32.4 / 49.4 / 19.6 episodes
per trace, see `dsr_vs_sft_primitives.md`).

## Motif categories — per-trace mean count (scheme = `{scheme}`)

Definitions:
- **Recovery** — sliding length-{scheme.replace('k3+5','3+5').replace('k', '')}
  windows containing both `HYPOTHESIZE` and `BACKTRACK`.
- **Exploitation** — exact contiguous matches of compute-anchored chains
  (`COMPUTE→CHECK→COMPUTE` for k=3, plus `COMPUTE→CHECK→COMPUTE→CHECK→COMPUTE`
  for k=5).
- **Verification** — exact contiguous matches of CHECK-anchored chains with
  non-`{{HYP, BTK, CHK}}` middle slots (`CHK→X→CHK` for k=3, `CHK→X→CHK→Y→CHK`
  for k=5).
  Allowed middle: `{{PLAN, SETUP, ENUMERATE, COMPUTE, SUMMARIZE, OTHER}}`.

| Category | Puzzles Base | Puzzles DSR | Puzzles SFT | Math Base | Math SFT |
|---|---:|---:|---:|---:|---:|
""" + "\n".join(
        f"| {name} | {_fmt(b)} | {_fmt(d)} | {_fmt(s)} | {_fmt(bm)} | {_fmt(sm)} |"
        for name, b, d, s, bm, sm in motif_rows
    ) + f"""

### Cross-scheme magnitude check

To make the choice of `k` explicit, here are the same motif means at all
three schemes for the puzzles-DSR / puzzles-SFT / math-SFT cells (these are
the cells that drive the §4.1 narrative).

| Scheme | Category | DSR-puzzles | SFT-puzzles | SFT-math |
|---|---|---:|---:|---:|
""" + "\n".join(
        f"| {sch} | {cat} | {_fmt(d)} | {_fmt(s)} | {_fmt(m)} |"
        for sch, cat, d, s, m in scheme_compare
    ) + f"""

The k=3 scheme is the canonical choice for the body figure: motif counts on
math are non-trivial enough to read (verification 0–1+, exploitation 0–1+),
and the puzzle-side magnitudes still differentiate DSR from SFT cleanly.
k=5 collapses the math cells toward zero; k=3+5 is just k=3 scaled up.

## Interpretation

### Primitive vocabulary is induced by SFT and inherited from DSR.

The Base column on puzzles shows the un-puzzle-SFT model's reasoning
vocabulary when attempting puzzles it cannot solve (0% pass@32). Traces
are short (mean {total_episodes('Base','puzzles'):.1f} episodes vs.
{total_episodes('DSR','puzzles'):.1f} for DSR and
{total_episodes('SFT','puzzles'):.1f} for SFT) and every primitive count
is well below DSR/SFT. Notably HYPOTHESIZE
({get('Base','puzzles','primitive','HYPOTHESIZE')[0]:.1f}) and BACKTRACK
({get('Base','puzzles','primitive','BACKTRACK')[0]:.1f}) are non-zero —
the chat-tuned base already has these markers in its repertoire — but
puzzle-SFT amplifies them 4-7× ({get('SFT','puzzles','primitive','HYPOTHESIZE')[0]:.1f} HYP,
{get('SFT','puzzles','primitive','BACKTRACK')[0]:.1f} BTK). The DSR
teacher sits between them.

DSR-vs-SFT puzzle alignment on the four "exploitation-like" primitives
(CHECK, COMPUTE, ENUMERATE, PLAN) is within ~1pp on episode-fraction
(per `dsr_vs_sft_primitives.md`); on raw counts the SFT student is
~1.4–1.7× the DSR teacher because SFT traces are longer (49 vs 32
episodes/trace). The two exploration primitives are amplified
2.4–3.0× — a known artefact of the 5%-truncated training corpus /
chat-base inheritance.

### Primitive vocabulary survives the puzzle→math transfer.

Comparing Math-Base to Math-SFT: every primitive present in the puzzle
SFT distribution is also present in the math SFT traces, with the
qualitative shape preserved. `PLAN` and `SUMMARIZE` are amplified on math
(reflecting the different shape of math problems — define, then sum up),
while `BACKTRACK` is suppressed (math SFT rarely admits a wrong turn).
Crucially, `CHECK` and `COMPUTE` — the two highest-frequency primitives
in DSR puzzle traces — remain dominant on math.

### Motifs (recovery / exploitation / verification) follow the same chain.

- **Recovery**: near-zero on Base/Math (model never enters a HYP↔BTK
  arc); rises on SFT-math from {get("Base","math","motif_category","Recovery", this_scheme=scheme)[0]:.2f} to
  {get("SFT","math","motif_category","Recovery", this_scheme=scheme)[0]:.2f}; large on SFT-puzzles
  ({get("SFT","puzzles","motif_category","Recovery", this_scheme=scheme)[0]:.1f}), substantially above DSR
  ({get("DSR","puzzles","motif_category","Recovery", this_scheme=scheme)[0]:.1f}) — the SFT student amplifies
  recovery motifs more than the teacher does.
- **Exploitation** (`CMP→CHK→CMP`): the canonical compute/check loop §6
  studies. Base-math is essentially zero
  ({get("Base","math","motif_category","Exploitation", this_scheme=scheme)[0]:.2f}); SFT-math reaches
  {get("SFT","math","motif_category","Exploitation", this_scheme=scheme)[0]:.2f}, SFT-puzzles
  {get("SFT","puzzles","motif_category","Exploitation", this_scheme=scheme)[0]:.2f}.
- **Verification** (`CHK→X→CHK`): same shape — rises from
  {get("Base","math","motif_category","Verification", this_scheme=scheme)[0]:.2f} (Base-math) to
  {get("SFT","math","motif_category","Verification", this_scheme=scheme)[0]:.2f} (SFT-math).

The §4.3 / §5.1 finding that RL training erodes recovery motifs is
load-bearing on the existence of those motifs in the SFT distribution.
This figure provides that baseline.

## Files

- Generator: `scripts/analysis/section4_1_primitive_motif_distributions.py`
- Numbers CSV (long-form, all schemes): `{OUT_CSV}`
- Figure (canonical, k=3): `writing/neurips_paper/figures/fig_primitive_motif_distributions.png`
- Per-scheme variants: `fig_primitive_motif_distributions_k3.png`,
  `_k5.png`, `_k3plus5.png`

## Caveats

- Base-puzzles eval (n={n_traces[('Base','puzzles')]} traces, OLMo3-Instruct-SFT × 100 problems × 32
  rollouts on `bridges_8x8de_pass32` and `undead_5x5de_pass32`) found 0%
  exact-match — the base model cannot solve these puzzles, so the Base
  bar measures the reasoning vocabulary it deploys when *attempting*
  them. Many trajectories are short ("I cannot solve this") or hit the
  22K-token max in unproductive loops; both are valid evidence of
  vocabulary absence.
- v90 classifier F1 on V2 puzzles ≈ 0.80, on in-domain math ≈ 0.74
  (per `reports/primitive_classifier_unified.md`). Numbers carry ~5pp
  absolute classifier noise per primitive.
- Mean episodes/trace differ across corpora (Base puzzles {total_episodes('Base','puzzles'):.1f},
  Base math {total_episodes('Base','math'):.1f}, DSR puzzles {total_episodes('DSR','puzzles'):.1f},
  SFT puzzles {total_episodes('SFT','puzzles'):.1f}, SFT math {total_episodes('SFT','math'):.1f}).
  Per-trace mean count is the right unit for vocabulary-presence claims;
  episode-fraction normalises trace-length differences (see
  `dsr_vs_sft_primitives.md` for the fraction-normalised view).
"""
    OUT_REPORT.parent.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8: the locale default cannot be trusted to encode the
    # arrows and other symbols in the report text.
    OUT_REPORT.write_text(body, encoding="utf-8")
    return OUT_REPORT


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """CLI entry point.

    Loads the cached CSV (with --use_cache, when it exists) or recomputes and
    rewrites it, renders the requested scheme(s), copies the canonical
    scheme's figure to `fig_primitive_motif_distributions.png`, and writes
    the markdown report when --write_report is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--motif_scheme", default="k3",
                        choices=("k3", "k5", "k3+5", "all"),
                        help="Which motif scheme to use for the canonical figure. "
                             "`all` renders one figure per scheme.")
    parser.add_argument("--layout", default="2x2", choices=("2x2", "1x4"))
    parser.add_argument("--use_cache", action="store_true",
                        help="Reuse existing CSV instead of recomputing.")
    parser.add_argument("--write_report", action="store_true",
                        help="Also emit reports/neurips/section4_1_distributions.md")
    args = parser.parse_args()

    grid_shapes = {"2x2": (2, 2), "1x4": (1, 4)}
    nrows, ncols = grid_shapes[args.layout]

    # ------ Numbers
    cache_hit = args.use_cache and OUT_CSV.exists()
    if not cache_hit:
        print("computing distributions...")
        df = compute_distributions()
        OUT_CSV.parent.mkdir(parents=True, exist_ok=True)
        df.to_csv(OUT_CSV, index=False)
        print(f"wrote {OUT_CSV} ({len(df)} rows)")
    else:
        df = pd.read_csv(OUT_CSV)
        print(f"loaded cached numbers from {OUT_CSV} ({len(df)} rows)")

    # ------ Figures
    render_all = args.motif_scheme == "all"
    schemes = ["k3", "k5", "k3+5"] if render_all else [args.motif_scheme]
    figs = []
    for scheme in schemes:
        rendered = render_figure(df, scheme, nrows=nrows, ncols=ncols)
        print(f"wrote {rendered}")
        figs.append(rendered)

    # Canonical filename: copy of the chosen scheme's figure (k=3 when every
    # scheme was rendered).
    canonical_scheme = "k3" if render_all else args.motif_scheme
    source_fig = figs[schemes.index(canonical_scheme)]
    canonical = OUT_FIG_DIR / "fig_primitive_motif_distributions.png"
    shutil.copy2(source_fig, canonical)
    print(f"wrote {canonical}  (canonical = {canonical_scheme})")

    # ------ Report
    if args.write_report:
        report_path = write_report(df, canonical_scheme, canonical)
        print(f"wrote {report_path}")


if __name__ == "__main__":
    main()
