"""§6 motif counts at the 4 main checkpoints.

Three motifs:
  k=3  : CHECK -> COMPUTE -> CHECK              (basic verify loop)
  k=5  : CHECK -> COMPUTE -> CHECK -> COMPUTE -> CHECK   (longer alternation)
  k=7  : CHECK -> COMPUTE -> CHECK -> COMPUTE -> CHECK -> ENUMERATE -> COMPUTE
         (deep motif candidate from motif_kgram_analysis.md §3)

For each motif and each checkpoint we report the mean number of non-overlapping
motif occurrences per trace, with a 95% percentile-bootstrap CI (10,000
resamples). Data: the primitive_sequence column of the trace-level parquets.

Outputs:
  results/within_problem_paired/motif_examples_counts.csv  (table for the report)
  writing/neurips_paper/figures/fig_motif_examples.png      (3-panel figure; directory
                                                             overridable via $FIG_OUT_DIR)
"""

import json
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# One row per checkpoint: (label, trace-level parquet, pastel fill, edge colour).
# Row order is the plotting order used by the figure.
# Pastel palette — keep in sync with CLAUDE.md "Figure colour scheme".
_CHECKPOINTS = (
    ("Base", "/tmp/v90_csvs/v90_base_traces.parquet", "#cfd8dc", "#607d8b"),
    ("SFT", "/tmp/v90_csvs/v90_sft_traces.parquet", "#c8e6c9", "#2e7d32"),
    ("Vanilla GSPO", "/tmp/v90_csvs/v90_gspo_traces.parquet",
     "#bbdefb", "#1565c0"),
    ("Novelty bonus", "/tmp/v90_csvs/v90_prod_s15_traces.parquet",
     "#f8bbd0", "#c2185b"),
)
PATHS = {label: path for label, path, _, _ in _CHECKPOINTS}
PALETTE = {label: fill for label, _, fill, _ in _CHECKPOINTS}
EDGE = {label: edge for label, _, _, edge in _CHECKPOINTS}

# Primitive names, abbreviated so the motif tuples below read like their labels.
_CH, _CO, _EN = "CHECK", "COMPUTE", "ENUMERATE"
MOTIFS = [
    ("k=3: CH→CO→CH", (_CH, _CO, _CH)),
    ("k=5: CH→CO→CH→CO→CH", (_CH, _CO, _CH, _CO, _CH)),
    ("k=7: CH→CO→CH→CO→CH→EN→CO", (_CH, _CO, _CH, _CO, _CH, _EN, _CO)),
]

# Output locations; the figure directory can be redirected with $FIG_OUT_DIR.
_FIG_DIR = Path(os.environ.get("FIG_OUT_DIR", "writing/neurips_paper/figures"))
OUT_FIG = _FIG_DIR / "fig_motif_examples.png"
OUT_CSV = Path("results/within_problem_paired/motif_examples_counts.csv")
# Both output directories are created at import time, figure directory first.
for _out_dir in (OUT_FIG.parent, OUT_CSV.parent):
    _out_dir.mkdir(parents=True, exist_ok=True)
del _out_dir


def parse_seq(s):
    """Decode one ``primitive_sequence`` cell into a list of primitive labels.

    Accepts either a JSON-encoded array (``str``/``bytes``) or an
    already-decoded sequence (``list``/``tuple``/``numpy.ndarray`` — e.g. if
    the parquet column is stored as a list type rather than a JSON string).

    Missing values (``None``, NaN, ``pd.NA``), undecodable text, and JSON that
    is not an array all map to ``[]``. A malformed row therefore contributes
    zero motif counts instead of aborting the run or crashing ``len()`` later.
    """
    if s is None:
        return []
    # Already-decoded sequences would make json.loads raise TypeError and be
    # silently dropped, so pass them through as plain Python lists.
    if isinstance(s, np.ndarray):
        return s.tolist()
    if isinstance(s, (list, tuple)):
        return list(s)
    try:
        seq = json.loads(s)
    except (ValueError, TypeError):
        # ValueError covers JSONDecodeError and UnicodeDecodeError (bad bytes);
        # TypeError covers non-text scalars such as NaN floats.
        return []
    # Scalars/objects are not sequences of primitives.
    return seq if isinstance(seq, list) else []


def count_motif(seq, motif, *, overlapping=False):
    """Count occurrences of ``motif`` in ``seq``.

    By default the scan is greedy left-to-right and non-overlapping: after a
    match it resumes just past the matched window, so ``A B A B A`` contains
    one ``A B A``. With ``overlapping=True`` every matching start position is
    counted (two in that example).

    Parameters
    ----------
    seq : sequence
        Primitive labels for one trace (e.g. the output of ``parse_seq``).
    motif : sequence
        Labels to search for. Normalised to a tuple, so a list motif matches
        too (previously a list never compared equal to the tuple windows).
    overlapping : bool, keyword-only
        Count overlapping occurrences instead of non-overlapping ones.

    Returns
    -------
    int
        Number of occurrences; 0 if ``motif`` is empty or longer than ``seq``.
    """
    motif = tuple(motif)
    n, m = len(seq), len(motif)
    if m == 0 or n < m:
        return 0
    step_on_match = 1 if overlapping else m
    count = i = 0
    while i <= n - m:
        if tuple(seq[i:i + m]) == motif:
            count += 1
            i += step_on_match
        else:
            i += 1
    return count


def boot_mean_ci(x, n_boot=10_000, seed=42):
    """Return ``(mean, ci_low, ci_high)`` for the sample mean of ``x``.

    The interval is the 2.5th/97.5th percentile of ``n_boot`` bootstrap
    resample means drawn with a generator seeded by ``seed``. An empty input
    yields a triple of NaNs.
    """
    values = np.asarray(x, dtype=float)
    n_obs = len(values)
    if not n_obs:
        nan = float("nan")
        return nan, nan, nan
    generator = np.random.default_rng(seed)
    # Each row is one resample of the data, drawn with replacement.
    resamples = generator.choice(values, size=(n_boot, n_obs), replace=True)
    resample_means = resamples.mean(axis=1)
    bounds = np.percentile(resample_means, [2.5, 97.5])
    return float(values.mean()), float(bounds[0]), float(bounds[1])


def _load_sequences():
    """Read each checkpoint's parquet and decode its ``primitive_sequence`` cells.

    Returns ``{checkpoint: [parsed sequence per trace]}`` in ``PATHS`` order.
    """
    # Only one column is used, so avoid loading the rest of the trace table.
    return {
        ckpt: (
            pd.read_parquet(path, columns=["primitive_sequence"])["primitive_sequence"]
            .apply(parse_seq)
            .tolist()
        )
        for ckpt, path in PATHS.items()
    }


def _summarise(sequences):
    """Build the per-(motif, checkpoint) table, printing one progress line per row.

    Each row holds the bootstrap mean count per trace with its 95% CI, the
    number of traces, and how many traces contain the motif at least once.
    """
    rows = []
    for motif_label, motif in MOTIFS:
        for ckpt in PATHS:
            counts = [count_motif(s, motif) for s in sequences[ckpt]]
            n_with = sum(c > 0 for c in counts)  # used for both the row and the log
            mean, lo, hi = boot_mean_ci(counts)
            rows.append({
                "motif": motif_label,
                "checkpoint": ckpt,
                "mean_count_per_trace": mean,
                "ci_low": lo,
                "ci_high": hi,
                "n_traces": len(counts),
                "n_traces_with_motif": int(n_with),
            })
            print(f"{motif_label:60s}  {ckpt:18s}  mean={mean:6.3f}  "
                  f"95%CI=[{lo:.3f},{hi:.3f}]  n_with_motif={n_with:4d}/"
                  f"{len(counts)}")
    return pd.DataFrame(rows)


def _plot(out):
    """Draw one bar panel per motif (checkpoint means with 95% CI) and save to OUT_FIG."""
    ckpt_order = list(PATHS)
    colors = [PALETTE[c] for c in ckpt_order]
    edges = [EDGE[c] for c in ckpt_order]
    xs = np.arange(len(ckpt_order))
    n_panels = len(MOTIFS)
    # 14 in wide for the three paper panels; width scales if MOTIFS changes length.
    fig, axes = plt.subplots(1, n_panels, figsize=(14 * n_panels / 3, 4.2),
                             squeeze=False)
    try:
        for ax, (motif_label, _motif) in zip(axes[0], MOTIFS):
            sub = out[out["motif"] == motif_label].set_index("checkpoint").loc[ckpt_order]
            means = sub["mean_count_per_trace"].to_numpy(dtype=float)
            lows = sub["ci_low"].to_numpy(dtype=float)
            highs = sub["ci_high"].to_numpy(dtype=float)
            # Percentile-bootstrap bounds are not guaranteed to bracket the
            # sample mean. Clip so a bound on the "wrong" side draws as a
            # zero-length bar instead of errorbar rejecting negative values.
            errs = np.clip([means - lows, highs - means], 0.0, None)
            ax.bar(xs, means, yerr=errs, color=colors, edgecolor=edges, linewidth=1.0,
                   capsize=4)
            pad = 0.02 * np.nanmax(np.append(highs, 0.1))
            for x, m, hi in zip(xs, means, highs):
                # Place the value above the CI cap so it doesn't collide with the
                # error bar (falls back to the bar top if hi is NaN or below m).
                top = hi if hi > m else m
                ax.text(x, top + pad, f"{m:.2f}", ha="center", va="bottom", fontsize=9)
            ax.set_xticks(xs)
            ax.set_xticklabels(ckpt_order, rotation=20, ha="right", fontsize=9)
            ax.set_ylabel("Mean count per trace")
            ax.set_title(motif_label, fontsize=10)
            ax.grid(True, axis="y", alpha=0.3)
        fig.suptitle("Specific motif counts confirm the depth-of-exploitation pattern",
                     fontsize=11)
        fig.tight_layout()
        fig.savefig(OUT_FIG, dpi=160, bbox_inches="tight")
    finally:
        # Release the figure even if drawing/saving fails.
        plt.close(fig)


def main():
    """Compute per-checkpoint motif statistics, write the CSV, and draw the figure.

    Reads every parquet in ``PATHS``, counts each motif in ``MOTIFS`` per trace,
    summarises the counts with a bootstrap CI, writes the table to ``OUT_CSV``
    and a one-panel-per-motif bar chart to ``OUT_FIG``.
    """
    sequences = _load_sequences()
    out = _summarise(sequences)
    out.to_csv(OUT_CSV, index=False)
    print(f"\nwrote {OUT_CSV}")
    _plot(out)
    print(f"wrote {OUT_FIG}")


if __name__ == "__main__":
    main()
