"""Motif (k-gram) enrichment analysis for reasoning primitive chains.

Analyses:
1. Global motif enrichment — which k-grams discriminate solved vs unsolved?
2. Per-problem motif consistency — controlling for problem difficulty
3. Cross-checkpoint comparison — what motifs does RL gain over SFT?
4. Predictive power — logistic regression on k-gram features

Usage:
    python -m analysis.exploration.motif_analysis \
        --checkpoints v90_gspo_puzzles v90_sft_puzzles \
        --k-range 2 8 --sft-baseline v90_sft_puzzles --verbose
"""
from __future__ import annotations

import argparse
import json
import logging
import sys
from collections import Counter
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
from scipy.spatial.distance import jensenshannon
from scipy.special import expit
from scipy.stats import chi2_contingency, false_discovery_control, fisher_exact

from .sequence_diversity import kmer_counts

# Module-level logger used for the progress and warning messages in this file.
log = logging.getLogger(__name__)

# Root under which each checkpoint has its own sub-directory holding
# ``trace_level_metrics.parquet`` (read by load_checkpoint_data).  This is a
# relative path, so it resolves against the current working directory.
RESULTS_BASE = Path("results/exploration_analysis")

# Sentinel token used to break k-grams at filtered-out primitive positions.
# Any k-gram containing this token is discarded during extraction.
_BREAK = "__BREAK__"

# ---------------------------------------------------------------------------
# Section 0: Data loading & utilities
# ---------------------------------------------------------------------------


def _filter_sequence(seq: list[str], exclude: set[str]) -> list[str]:
    """Swap excluded primitives for a boundary sentinel so k-grams never bridge them.

    Every maximal run of excluded primitives becomes exactly one ``__BREAK__``
    token; all other primitives are kept in their original order.  Example
    with ``exclude={"OTHER"}``::

        [SETUP, CHECK, OTHER, OTHER, CHECK, SETUP]
     -> [SETUP, CHECK, __BREAK__, CHECK, SETUP]

    When ``exclude`` is empty the input list object itself is returned.
    """
    if not exclude:
        return seq
    filtered: list[str] = []
    previous_excluded = False
    for token in seq:
        is_excluded = token in exclude
        if not is_excluded:
            filtered.append(token)
        elif not previous_excluded:
            # First excluded token of a run: emit a single boundary marker.
            filtered.append(_BREAK)
        previous_excluded = is_excluded
    return filtered


def load_checkpoint_data(
    checkpoint_name: str,
    exclude_primitives: Optional[set[str]] = None,
) -> pd.DataFrame:
    """Read the trace-level parquet for *checkpoint_name* and decode its sequences.

    ``primitive_sequence`` is stored as JSON text and is parsed into Python
    lists.  When ``exclude_primitives`` is non-empty, excluded primitives are
    replaced by boundary tokens via ``_filter_sequence``.  Two columns are
    added: ``seq_len`` (sequence length after filtering) and ``dataset``
    (the checkpoint name).

    Raises:
        FileNotFoundError: if the parquet file does not exist.
    """
    parquet_path = RESULTS_BASE / checkpoint_name / "trace_level_metrics.parquet"
    if not parquet_path.exists():
        raise FileNotFoundError(f"No parquet at {parquet_path}")
    frame = pd.read_parquet(parquet_path)
    sequences = frame["primitive_sequence"].apply(json.loads)
    if exclude_primitives:
        sequences = sequences.apply(_filter_sequence, args=(exclude_primitives,))
    frame["primitive_sequence"] = sequences
    frame["seq_len"] = sequences.apply(len)
    frame["dataset"] = checkpoint_name
    return frame


def load_all_checkpoints(
    names: list[str],
    exclude_primitives: Optional[set[str]] = None,
) -> pd.DataFrame:
    """Load every checkpoint in *names* and stack them into one DataFrame.

    Checkpoints whose parquet is missing are logged and skipped instead of
    aborting the run.  The concatenated result gets a fresh RangeIndex.

    Raises:
        RuntimeError: if none of the requested checkpoints could be loaded.
    """
    frames: list[pd.DataFrame] = []
    for ckpt in names:
        try:
            frame = load_checkpoint_data(ckpt, exclude_primitives)
        except FileNotFoundError:
            log.warning("Skipping %s — parquet not found", ckpt)
            continue
        log.info("Loaded %s: %d traces", ckpt, len(frame))
        frames.append(frame)
    if not frames:
        raise RuntimeError("No checkpoint data loaded")
    return pd.concat(frames, ignore_index=True)


def kgram_to_str(kg: tuple) -> str:
    """Render a k-gram tuple as an arrow-joined string.

    Example: ``('COMPUTE', 'CHECK')`` becomes ``'COMPUTE→CHECK'``.
    """
    arrow = "→"
    return arrow.join(kg)


def extract_all_kgrams(df: pd.DataFrame, k: int) -> tuple[pd.DataFrame, int]:
    """Extract k-grams from all traces, returning per-trace counts.

    K-grams containing the ``_BREAK`` sentinel are discarded so that
    filtered-out primitives do not create false adjacencies.

    Args:
        df: trace table with ``primitive_sequence``, ``correct`` and
            ``dataset`` columns; ``doc_id`` is optional (``-1`` when absent).
        k: k-gram length.

    Returns:
        ``(kgram_df, n_skipped)``.  ``kgram_df`` has one row per
        (trace, k-gram) pair with columns ``trace_idx`` (index label in *df*),
        ``kgram`` (tuple), ``count``, ``n_positions`` (number of kept k-gram
        positions in that trace), ``correct``, ``dataset`` and ``doc_id``.
        ``n_skipped`` counts traces that are shorter than *k* OR whose every
        k-gram spans a break.
    """
    # Iterate plain columns rather than DataFrame.iterrows(): iterrows builds
    # a Series per row, which dominates runtime on large trace tables.
    doc_ids = df["doc_id"] if "doc_id" in df.columns else [-1] * len(df)
    rows = []
    n_skipped = 0
    for idx, seq, correct, dataset, doc_id in zip(
        df.index, df["primitive_sequence"], df["correct"], df["dataset"], doc_ids
    ):
        if len(seq) < k:
            n_skipped += 1
            continue
        # Drop k-grams that straddle a filtered-out span.
        counts = {kg: c for kg, c in kmer_counts(seq, k).items() if _BREAK not in kg}
        if not counts:
            n_skipped += 1
            continue
        n_pos = sum(counts.values())
        rows.extend(
            (idx, kg, cnt, n_pos, correct, dataset, doc_id)
            for kg, cnt in counts.items()
        )
    cols = ["trace_idx", "kgram", "count", "n_positions", "correct",
            "dataset", "doc_id"]
    return pd.DataFrame(rows, columns=cols), n_skipped


def build_observed_vocab(kgram_df: pd.DataFrame, min_count: int) -> list[tuple]:
    """List the k-grams whose summed ``count`` over all rows is >= *min_count*.

    The list is sorted so downstream feature/column order is deterministic.
    """
    per_kgram_total = kgram_df.groupby("kgram")["count"].sum()
    frequent = per_kgram_total[per_kgram_total >= min_count]
    return sorted(frequent.index)


# ---------------------------------------------------------------------------
# Section 1: Global motif enrichment
# ---------------------------------------------------------------------------


def _two_by_two_pvalue(a: int, b: int, c: int, d: int) -> float:
    """Two-sided p-value for the 2x2 contingency table ``[[a, b], [c, d]]``.

    Uses Yates-corrected chi-squared when every *expected* cell count is >= 5
    (Cochran's rule) and Fisher's exact test otherwise.  Tables scipy rejects
    (e.g. a zero margin) yield 1.0.
    """
    n = a + b + c + d
    if n == 0:
        return 1.0
    row_sums = (a + b, c + d)
    col_sums = (a + c, b + d)
    min_expected = min(r * s for r in row_sums for s in col_sums) / n
    table = [[a, b], [c, d]]
    try:
        if min_expected >= 5:
            return float(chi2_contingency(table, correction=True)[1])
        return float(fisher_exact(table)[1])
    except ValueError:
        return 1.0


def compute_kgram_enrichment(
    df: pd.DataFrame,
    k: int,
    min_count: int = 5,
    pseudocount: float = 0.5,
    checkpoint_label: str = "all",
    fdr_threshold: float = 0.05,
) -> pd.DataFrame:
    """Compute per-k-gram enrichment in solved vs unsolved traces.

    ``log_odds`` is ``log2(f_solved / f_unsolved)``, where each frequency is
    the k-gram's count divided by all kept k-gram positions in that outcome
    group, smoothed with *pseudocount* per vocabulary entry.  Despite the
    column name it is a log frequency ratio, not an odds ratio.  Significance
    comes from a 2x2 test of (this k-gram vs other positions) x (solved vs
    unsolved), followed by Benjamini-Hochberg correction.  Overlapping k-gram
    positions within a trace are not independent, so p-values are best read
    as a ranking signal.

    Args:
        df: trace table (see extract_all_kgrams for required columns).
        k: k-gram length.
        min_count: minimum total occurrences for a k-gram to be tested.
        pseudocount: additive smoothing used for the frequency ratio.
        checkpoint_label: value written to the ``checkpoint`` column.
        fdr_threshold: BH-adjusted p-value below which ``significant`` is True.

    Returns:
        DataFrame with columns kgram_str, kgram, k, checkpoint, count_solved,
        count_unsolved, freq_solved, freq_unsolved, log_odds, p_value, p_adj,
        significant, sorted by log_odds descending.  Empty when there is no
        usable data or one outcome group has no k-gram positions.
    """
    kgram_df, n_skipped = extract_all_kgrams(df, k)
    n_usable = len(df) - n_skipped
    log.info("  k=%d: %d usable traces (%d skipped: seq<k or only break-spanning "
             "k-grams), %d (trace, kgram) rows",
             k, n_usable, n_skipped, len(kgram_df))

    if kgram_df.empty:
        return pd.DataFrame()

    vocab = build_observed_vocab(kgram_df, min_count)
    log.info("  k=%d: %d k-grams pass min_count=%d", k, len(vocab), min_count)
    if not vocab:
        return pd.DataFrame()

    # Cast defensively: with an integer 0/1 column ``~`` would be bitwise and
    # the "mask" would be treated as column labels.
    is_solved = kgram_df["correct"].astype(bool)
    solved = kgram_df[is_solved]
    unsolved = kgram_df[~is_solved]

    s_counts = solved.groupby("kgram")["count"].sum()
    u_counts = unsolved.groupby("kgram")["count"].sum()

    # Denominators: n_positions repeats on every row of a trace, so count it
    # once per trace.
    total_s = int(solved.drop_duplicates("trace_idx")["n_positions"].sum())
    total_u = int(unsolved.drop_duplicates("trace_idx")["n_positions"].sum())

    if total_s == 0 or total_u == 0:
        log.warning("  k=%d: no solved (%d) or unsolved (%d) positions — skipping",
                    k, total_s, total_u)
        return pd.DataFrame()

    V = len(vocab)
    denom_s = total_s + pseudocount * V
    denom_u = total_u + pseudocount * V
    rows = []
    for kg in vocab:
        cs = int(s_counts.get(kg, 0))
        cu = int(u_counts.get(kg, 0))
        fs = (cs + pseudocount) / denom_s
        fu = (cu + pseudocount) / denom_u
        lo = np.log2(fs / fu) if fu > 0 else 0.0
        pv = _two_by_two_pvalue(cs, total_s - cs, cu, total_u - cu)
        rows.append({
            "kgram_str": kgram_to_str(kg),
            "kgram": kg,
            "k": k,
            "checkpoint": checkpoint_label,
            "count_solved": cs,
            "count_unsolved": cu,
            "freq_solved": cs / total_s,
            "freq_unsolved": cu / total_u,
            "log_odds": lo,
            "p_value": pv,
        })

    result = pd.DataFrame(rows)

    # Benjamini-Hochberg FDR correction; fall back to raw p-values (loudly)
    # if scipy rejects the input, e.g. because of a NaN p-value.
    try:
        result["p_adj"] = false_discovery_control(
            result["p_value"].to_numpy(), method="bh")
    except ValueError as err:
        log.warning("  k=%d: BH correction failed (%s); using unadjusted p-values",
                    k, err)
        result["p_adj"] = result["p_value"]
    result["significant"] = result["p_adj"] < fdr_threshold

    n_sig = int(result["significant"].sum())
    log.info("  k=%d: %d/%d significant (FDR<%g), solved_positions=%d, unsolved_positions=%d",
             k, n_sig, len(result), fdr_threshold, total_s, total_u)

    return result.sort_values("log_odds", ascending=False).reset_index(drop=True)


def global_enrichment_all_k(
    df: pd.DataFrame,
    k_range: tuple[int, int],
    min_count: int,
    min_count_high_k: int,
    pseudocount: float,
    checkpoint_label: str = "all",
) -> pd.DataFrame:
    """Run compute_kgram_enrichment for every k in the inclusive *k_range*.

    k <= 5 uses *min_count*; larger k use *min_count_high_k*.  Empty per-k
    results are dropped, and an empty DataFrame is returned when nothing
    survives.
    """
    per_k = (
        compute_kgram_enrichment(
            df,
            k,
            min_count if k <= 5 else min_count_high_k,
            pseudocount,
            checkpoint_label,
        )
        for k in range(k_range[0], k_range[1] + 1)
    )
    frames = [frame for frame in per_k if len(frame) > 0]
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)


# ---------------------------------------------------------------------------
# Section 2: Per-problem motif analysis
# ---------------------------------------------------------------------------


def per_problem_enrichment(
    df: pd.DataFrame,
    k: int,
    min_solved: int = 3,
    min_count: int = 2,
    pseudocount: float = 0.5,
) -> pd.DataFrame:
    """Compute within-problem k-gram enrichment, then aggregate across problems.

    A problem (``doc_id``) is eligible when at least *min_solved* of its
    traces with ``seq_len >= k`` are solved.  For each eligible problem that
    also has k-grams from unsolved traces, every k-gram with >= *min_count*
    occurrences in that problem gets a smoothed log2 frequency ratio (solved
    vs unsolved, same formula as the global enrichment).  These per-problem
    values are then summarised per k-gram.

    Returns:
        DataFrame with kgram_str, kgram, k, n_problems_tested,
        n_enriched_solved (problems where the log-ratio is > 0),
        frac_enriched, mean_log_odds and median_log_odds.  It is sorted by
        frac_enriched, then n_problems_tested, then mean_log_odds (all
        descending), so ties favour motifs supported by more problems.  Empty
        when nothing is eligible.
    """
    kgram_df, _ = extract_all_kgrams(df, k)
    if kgram_df.empty:
        return pd.DataFrame()

    # Find problems with enough solved traces
    problem_stats = df[df["seq_len"] >= k].groupby("doc_id")["correct"].agg(
        n_solved="sum", n_total="count"
    )
    eligible = problem_stats[problem_stats["n_solved"] >= min_solved].index.tolist()
    log.info("  k=%d per-problem: %d/%d problems with >=%d solved traces",
             k, len(eligible), len(problem_stats), min_solved)
    if not eligible:
        return pd.DataFrame()

    # Partition the k-gram table once; filtering it per problem inside the
    # loop would cost O(problems x rows).
    rows_by_problem = {
        doc_id: group for doc_id, group in kgram_df.groupby("doc_id", sort=False)
    }

    kgram_log_odds: dict[tuple, list[float]] = {}  # kgram -> per-problem log-ratios
    for doc_id in eligible:
        prob_kg = rows_by_problem.get(doc_id)
        if prob_kg is None or prob_kg.empty:
            continue
        vocab = build_observed_vocab(prob_kg, min_count)
        if not vocab:
            continue

        is_solved = prob_kg["correct"].astype(bool)
        solved_kg = prob_kg[is_solved]
        unsolved_kg = prob_kg[~is_solved]
        # n_positions repeats on each row of a trace; count it once per trace.
        total_s = solved_kg.drop_duplicates("trace_idx")["n_positions"].sum()
        total_u = unsolved_kg.drop_duplicates("trace_idx")["n_positions"].sum()
        if total_s == 0 or total_u == 0:
            # Problem has no solved or no unsolved k-grams: nothing to contrast.
            continue
        s_counts = solved_kg.groupby("kgram")["count"].sum()
        u_counts = unsolved_kg.groupby("kgram")["count"].sum()

        V = len(vocab)
        denom_s = total_s + pseudocount * V
        denom_u = total_u + pseudocount * V
        for kg in vocab:
            fs = (int(s_counts.get(kg, 0)) + pseudocount) / denom_s
            fu = (int(u_counts.get(kg, 0)) + pseudocount) / denom_u
            lo = np.log2(fs / fu) if fu > 0 else 0.0
            kgram_log_odds.setdefault(kg, []).append(lo)

    # Aggregate
    rows = []
    for kg, odds_list in kgram_log_odds.items():
        n_tested = len(odds_list)
        n_enriched = sum(1 for x in odds_list if x > 0)
        rows.append({
            "kgram_str": kgram_to_str(kg),
            "kgram": kg,
            "k": k,
            "n_problems_tested": n_tested,
            "n_enriched_solved": n_enriched,
            "frac_enriched": n_enriched / n_tested if n_tested > 0 else 0.0,
            "mean_log_odds": np.mean(odds_list),
            "median_log_odds": np.median(odds_list),
        })

    result = pd.DataFrame(rows)
    if len(result) > 0:
        # frac_enriched alone ties every single-problem motif at 0.0 or 1.0;
        # break ties by breadth of support, then effect size.
        result = result.sort_values(
            ["frac_enriched", "n_problems_tested", "mean_log_odds"],
            ascending=False,
        ).reset_index(drop=True)
        top = result.head(3)
        log.info("  k=%d per-problem: %d motifs tracked, top enriched: %s",
                 k, len(result),
                 ", ".join(f"{r['kgram_str']}({r['frac_enriched']:.0%})"
                           for _, r in top.iterrows()))
    return result


def per_problem_all_k(
    df: pd.DataFrame,
    k_range: tuple[int, int],
    min_solved: int,
    pseudocount: float,
) -> pd.DataFrame:
    """Run per_problem_enrichment for each k in the inclusive *k_range*.

    The per-problem min_count is fixed at 2 for k <= 5 and 3 above that.
    Empty per-k results are dropped.
    """
    collected: list[pd.DataFrame] = []
    first_k, last_k = k_range[0], k_range[1]
    for k in range(first_k, last_k + 1):
        per_problem_mc = 3 if k > 5 else 2
        frame = per_problem_enrichment(df, k, min_solved, per_problem_mc, pseudocount)
        if len(frame) == 0:
            continue
        collected.append(frame)
    if collected:
        return pd.concat(collected, ignore_index=True)
    return pd.DataFrame()


# ---------------------------------------------------------------------------
# Section 3: Cross-checkpoint comparison
# ---------------------------------------------------------------------------


def checkpoint_kgram_profile(
    df: pd.DataFrame,
    k: int,
    min_count: int = 5,
) -> dict[tuple, float]:
    """Return ``{k-gram: share of all kept k-gram positions}`` for one checkpoint.

    The denominator counts every kept k-gram position (each trace's
    ``n_positions`` once).  Only k-grams with at least *min_count* total
    occurrences are included, so the values need not sum to 1.  Returns
    ``{}`` when there is no data.
    """
    kgram_df, _ = extract_all_kgrams(df, k)
    if kgram_df.empty:
        return {}
    per_trace = kgram_df.drop_duplicates("trace_idx")
    denominator = per_trace["n_positions"].sum()
    if denominator == 0:
        return {}
    occurrences = kgram_df.groupby("kgram")["count"].sum()
    profile: dict[tuple, float] = {}
    for kgram, occ in occurrences.items():
        if occ >= min_count:
            profile[kgram] = occ / denominator
    return profile


def _checkpoint_kgram_counts(df: pd.DataFrame, k: int) -> tuple[pd.Series, int]:
    """Return un-thresholded per-k-gram counts and total kept k-gram positions."""
    kgram_df, _ = extract_all_kgrams(df, k)
    if kgram_df.empty:
        return pd.Series(dtype=np.int64), 0
    total = int(kgram_df.drop_duplicates("trace_idx")["n_positions"].sum())
    return kgram_df.groupby("kgram")["count"].sum(), total


def cross_checkpoint_comparison(
    sft_df: pd.DataFrame,
    rl_df: pd.DataFrame,
    k: int,
    enrichment_df: pd.DataFrame,
    min_count: int = 5,
    ratio_threshold: float = 2.0,
    sft_label: str = "sft",
    rl_label: str = "rl",
) -> pd.DataFrame:
    """Compare k-gram frequency profiles between SFT and RL checkpoints.

    A k-gram is compared when it reaches *min_count* occurrences in either
    checkpoint.  Its frequency on each side is its true share of that
    checkpoint's k-gram positions.  ``ratio_rl_sft`` (with a 1e-6 pseudocount)
    >= *ratio_threshold* is "RL-gained", <= 1/*ratio_threshold* is
    "RL-lost", otherwise "stable".  ``enriched_in_solved`` marks k-grams with
    positive ``log_odds`` for *rl_label* at this k in *enrichment_df*;
    significance is not required.

    The returned frame's ``attrs["jsd"]`` holds the Jensen-Shannon value
    between the two profiles restricted to the compared k-grams.  Note that
    scipy's ``jensenshannon`` returns the JS *distance* (square root of the
    divergence), not the divergence itself.

    Returns:
        DataFrame sorted by ``ratio_rl_sft`` descending; empty when no k-gram
        reaches *min_count* in either checkpoint.
    """
    sft_counts, sft_total = _checkpoint_kgram_counts(sft_df, k)
    rl_counts, rl_total = _checkpoint_kgram_counts(rl_df, k)

    # Select on either side's count but use true frequencies on both sides.
    # Thresholding each profile independently would zero out a k-gram seen
    # min_count-1 times in one checkpoint, giving it a ~1e6 ratio and a
    # spurious RL-gained / RL-lost label.
    all_kgrams = {kg for kg, c in sft_counts.items() if c >= min_count}
    all_kgrams |= {kg for kg, c in rl_counts.items() if c >= min_count}
    if not all_kgrams:
        return pd.DataFrame()
    sft_profile = {kg: c / sft_total for kg, c in sft_counts.items()} if sft_total else {}
    rl_profile = {kg: c / rl_total for kg, c in rl_counts.items()} if rl_total else {}

    # Build enrichment lookup for the RL checkpoint
    enriched_set = set()
    if len(enrichment_df) > 0:
        rl_enrich = enrichment_df[
            (enrichment_df["k"] == k) &
            (enrichment_df["checkpoint"] == rl_label)
        ]
        enriched_set = set(
            rl_enrich[rl_enrich["log_odds"] > 0]["kgram"].tolist()
        )

    # Pseudocount for ratio to avoid division by zero
    pseudo = 1e-6
    rows = []
    for kg in all_kgrams:
        fs = sft_profile.get(kg, 0.0)
        fr = rl_profile.get(kg, 0.0)
        ratio = (fr + pseudo) / (fs + pseudo)

        if ratio >= ratio_threshold:
            category = "RL-gained"
        elif ratio <= 1.0 / ratio_threshold:
            category = "RL-lost"
        else:
            category = "stable"

        rows.append({
            "kgram_str": kgram_to_str(kg),
            "kgram": kg,
            "k": k,
            "freq_sft": fs,
            "freq_rl": fr,
            "ratio_rl_sft": ratio,
            "category": category,
            "enriched_in_solved": kg in enriched_set,
            "sft_checkpoint": sft_label,
            "rl_checkpoint": rl_label,
        })

    # Jensen-Shannon distance between the renormalised profiles.
    union_kg = sorted(all_kgrams)
    p = np.array([sft_profile.get(kg, 0.0) for kg in union_kg])
    q = np.array([rl_profile.get(kg, 0.0) for kg in union_kg])
    p_sum, q_sum = p.sum(), q.sum()
    if p_sum > 0 and q_sum > 0:
        jsd = float(jensenshannon(p / p_sum, q / q_sum))
    else:
        jsd = float("nan")

    result = (pd.DataFrame(rows)
              .sort_values("ratio_rl_sft", ascending=False)
              .reset_index(drop=True))
    # Set attrs on the final frame so they don't depend on pandas propagating
    # them through sort/reset.
    result.attrs["jsd"] = jsd

    gained = result[result["category"] == "RL-gained"]
    gained_and_enriched = gained[gained["enriched_in_solved"]]
    log.info("  k=%d cross-ckpt: %d total, %d gained, %d gained+enriched, JSD=%.4f",
             k, len(result), len(gained), len(gained_and_enriched), jsd)

    return result


def cross_checkpoint_all_k(
    sft_df: pd.DataFrame,
    rl_df: pd.DataFrame,
    enrichment_df: pd.DataFrame,
    k_range: tuple[int, int],
    min_count: int,
    min_count_high_k: int,
    sft_label: str,
    rl_label: str,
) -> pd.DataFrame:
    """Run cross_checkpoint_comparison for each k in the inclusive *k_range*.

    Each non-empty result's ``attrs["jsd"]`` is collected into
    ``attrs["jsd_by_k"]`` (k -> value) on the concatenated output.  When no k
    produced rows, an empty DataFrame is returned that still carries
    ``jsd_by_k``.
    """
    frames: list[pd.DataFrame] = []
    divergence: dict[int, float] = {}
    for k in range(k_range[0], k_range[1] + 1):
        threshold = min_count_high_k if k > 5 else min_count
        comparison = cross_checkpoint_comparison(
            sft_df, rl_df, k, enrichment_df, threshold,
            sft_label=sft_label, rl_label=rl_label,
        )
        if len(comparison) == 0:
            continue
        divergence[k] = comparison.attrs.get("jsd", float("nan"))
        frames.append(comparison)
    merged = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
    merged.attrs["jsd_by_k"] = divergence
    return merged


# ---------------------------------------------------------------------------
# Section 4: Predictive power (logistic regression)
# ---------------------------------------------------------------------------


def build_feature_matrix(
    df: pd.DataFrame,
    k: int,
    min_count: int = 5,
) -> tuple[np.ndarray, np.ndarray, list[tuple]]:
    """Build ``(X, y, vocab)`` from per-trace k-gram frequencies.

    Rows are the traces of *df* with ``seq_len >= k``, in *df* order.  A
    trace whose k-grams were all dropped (each spans a break) keeps an
    all-zero row.  Columns follow *vocab*: k-grams with >= *min_count* total
    occurrences, sorted.

    X[i, j] = count of vocab[j] in trace i / kept k-gram positions in trace i.
    y[i] = 1.0 if trace i is solved, else 0.0.

    Returns ``(np.empty((0, 0)), np.empty(0), [])`` when no k-gram passes
    *min_count*.
    """
    kgram_df, _ = extract_all_kgrams(df, k)
    vocab = build_observed_vocab(kgram_df, min_count)
    if not vocab:
        return np.empty((0, 0)), np.empty(0), []

    vocab_idx = {kg: j for j, kg in enumerate(vocab)}
    usable = df[df["seq_len"] >= k]
    row_of = {idx: i for i, idx in enumerate(usable.index.tolist())}

    X = np.zeros((len(usable), len(vocab)), dtype=np.float64)
    # Zip over columns instead of iterrows(): building a Series per row is
    # the bottleneck on large k-gram tables.
    for tidx, kg, cnt, n_pos in zip(
        kgram_df["trace_idx"], kgram_df["kgram"],
        kgram_df["count"], kgram_df["n_positions"],
    ):
        i = row_of.get(tidx)
        j = vocab_idx.get(kg)
        if i is None or j is None:
            continue
        X[i, j] = cnt / n_pos

    y = usable["correct"].to_numpy(dtype=np.float64)
    return X, y, vocab


def _stratified_kfold(y: np.ndarray, n_folds: int = 5, seed: int = 42) -> list[tuple]:
    """Split sample indices into *n_folds* class-stratified folds.

    Positive (``y == 1``) and negative (``y == 0``) indices are shuffled
    separately with a seeded RandomState (positives first), each split into
    *n_folds* near-equal chunks, and fold i's test set is chunk i of both
    classes.  Returns a list of ``(train_idx, test_idx)`` arrays.
    """
    rng = np.random.RandomState(seed)
    per_class = []
    for label in (1, 0):
        members = np.where(y == label)[0]
        rng.shuffle(members)
        per_class.append(np.array_split(members, n_folds))
    pos_chunks, neg_chunks = per_class

    out = []
    for held_out in range(n_folds):
        rest = [j for j in range(n_folds) if j != held_out]
        test_part = np.concatenate([pos_chunks[held_out], neg_chunks[held_out]])
        train_part = np.concatenate([
            np.concatenate([pos_chunks[j] for j in rest]),
            np.concatenate([neg_chunks[j] for j in rest]),
        ])
        out.append((train_part, test_part))
    return out


def _auc_score(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """ROC AUC from the rank-sum formula.

    AUC = (R_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg), where R_pos is
    the sum of the 1-based ranks of positive samples and tied scores share
    their average rank.  Returns 0.5 for empty input or when only one class
    is present.
    """
    total = len(y_true)
    if total == 0:
        return 0.5
    n_pos = y_true.sum()
    n_neg = total - n_pos
    if n_pos == 0 or n_neg == 0:
        return 0.5

    order = np.argsort(y_score)
    sorted_scores = y_score[order]
    # Sorted positions where a new run of equal scores begins.
    run_starts = np.flatnonzero(sorted_scores[1:] != sorted_scores[:-1]) + 1
    starts = np.concatenate(([0], run_starts))
    stops = np.concatenate((run_starts, [total]))
    # Average 1-based rank of a run covering sorted positions [start, stop).
    run_rank = (starts + 1 + stops) / 2.0
    ranks = np.empty(total)
    ranks[order] = np.repeat(run_rank, stops - starts)

    rank_sum_pos = ranks[y_true == 1].sum()
    return (rank_sum_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)


def logistic_regression_cv(
    X: np.ndarray,
    y: np.ndarray,
    n_folds: int = 5,
    C: float = 1.0,
    max_iter: int = 500,
) -> dict:
    """Class-weighted, L2-regularised logistic regression scored by stratified CV.

    Each fold minimises the weighted negative log-likelihood plus
    ``0.5 / C * ||w||^2`` (the intercept is not penalised) with L-BFGS-B.
    Sample weights are ``len(y) / (2 * n_class)``, so both classes carry equal
    total weight.  The held-out fold is scored by AUC.

    Returns:
        dict with ``auc_mean``, ``auc_std``, ``auc_folds``, ``coefficients``
        (fold-averaged, intercept excluded), ``n_pos``, ``n_neg`` and
        ``n_unconverged`` (folds where the optimiser did not report success).
        If either class has fewer than 5 samples, returns AUC 0.5,
        ``coefficients=None`` and an explanatory ``note`` instead.
    """
    from scipy.optimize import minimize as sp_minimize

    n_pos = y.sum()
    n_neg = len(y) - n_pos
    if n_pos < 5 or n_neg < 5:
        return {"auc_mean": 0.5, "auc_std": 0.0, "coefficients": None,
                "n_pos": int(n_pos), "n_neg": int(n_neg),
                "note": "too few positives or negatives"}

    # Class weights
    w_pos = len(y) / (2 * n_pos)
    w_neg = len(y) / (2 * n_neg)
    sample_weights = np.where(y == 1, w_pos, w_neg)

    # No clipping of z: logaddexp and expit are already overflow-safe, and
    # clipping only the objective made it flat where the gradient was not,
    # which can stall the L-BFGS line search.
    def objective(w, X_t, y_t, sw):
        z = X_t @ w[:-1] + w[-1]
        ll = sw * (y_t * z - np.logaddexp(0, z))
        return -ll.sum() + 0.5 / C * np.dot(w[:-1], w[:-1])

    def gradient(w, X_t, y_t, sw):
        z = X_t @ w[:-1] + w[-1]
        residual = sw * (expit(z) - y_t)
        return np.append(X_t.T @ residual + w[:-1] / C, residual.sum())

    splits = _stratified_kfold(y, n_folds)
    aucs = []
    all_coefs = []
    n_unconverged = 0

    for train_idx, test_idx in splits:
        X_tr, y_tr = X[train_idx], y[train_idx]
        X_te, y_te = X[test_idx], y[test_idx]
        sw_tr = sample_weights[train_idx]

        w0 = np.zeros(X_tr.shape[1] + 1)
        result = sp_minimize(
            objective, w0, args=(X_tr, y_tr, sw_tr),
            jac=gradient, method="L-BFGS-B",
            options={"maxiter": max_iter, "ftol": 1e-8},
        )
        if not result.success:
            n_unconverged += 1
            log.warning("  logistic CV fold did not converge: %s", result.message)
        coefs, intercept = result.x[:-1], result.x[-1]

        # Predict on test
        y_score = expit(X_te @ coefs + intercept)
        aucs.append(_auc_score(y_te, y_score))
        all_coefs.append(coefs)

    return {
        "auc_mean": np.mean(aucs),
        "auc_std": np.std(aucs),
        "auc_folds": aucs,
        "coefficients": np.mean(all_coefs, axis=0),
        "n_pos": int(n_pos),
        "n_neg": int(n_neg),
        "n_unconverged": n_unconverged,
    }


def predictive_analysis(
    df: pd.DataFrame,
    k_range: tuple[int, int],
    min_count: int,
    min_count_high_k: int,
    n_folds: int = 5,
    checkpoint_label: str = "all",
) -> list[dict]:
    """Fit the cross-validated k-gram classifier once per k in *k_range* (inclusive).

    Each returned dict is the logistic_regression_cv result augmented with
    ``k``, ``n_features``, ``checkpoint`` and ``top_features``.
    ``top_features`` holds the ten k-grams with the largest absolute mean
    coefficient, or ``[]`` if no model was fit.  Values of k that yield no
    features are skipped.
    """
    summaries: list[dict] = []
    for k in range(k_range[0], k_range[1] + 1):
        threshold = min_count_high_k if k > 5 else min_count
        X, y, vocab = build_feature_matrix(df, k, threshold)
        n_rows, n_cols = X.shape
        if n_rows == 0 or n_cols == 0:
            log.info("  k=%d predictive: no features — skipping", k)
            continue
        log.info("  k=%d predictive: %d traces, %d features, %.1f%% positive",
                 k, n_rows, n_cols, 100 * y.mean())
        fit = logistic_regression_cv(X, y, n_folds)
        fit["k"] = k
        fit["n_features"] = n_cols
        fit["checkpoint"] = checkpoint_label

        coefs = fit["coefficients"]
        if coefs is None:
            fit["top_features"] = []
        else:
            strongest = np.argsort(np.abs(coefs))[::-1][:10]
            fit["top_features"] = [
                {"kgram": kgram_to_str(vocab[j]), "coef": float(coefs[j])}
                for j in strongest
            ]

        log.info("  k=%d predictive: AUC=%.3f±%.3f", k, fit["auc_mean"], fit["auc_std"])
        summaries.append(fit)
    return summaries


# ---------------------------------------------------------------------------
# Section 5: Report generation
# ---------------------------------------------------------------------------


def _fmt_table(df: pd.DataFrame, columns: list[str], max_rows: int = 30) -> str:
    """Render the first *max_rows* rows of ``df[columns]`` as a markdown table.

    Cell formatting:
      * booleans (Python or NumPy) -> ``"yes"`` / ``""``;
      * floats (Python or NumPy) -> 4 decimals, or 2-digit scientific
        notation for non-zero magnitudes below 0.001;
      * anything else -> ``str(value)`` with ``|`` escaped so it cannot split
        the cell.

    A trailing note is appended when rows were truncated.  An empty *df*
    yields ``"*No data*\\n"``.
    """
    if len(df) == 0:
        return "*No data*\n"

    def fmt_cell(v) -> str:
        # Check bools first: np.bool_ is neither a bool nor a float subclass.
        if isinstance(v, (bool, np.bool_)):
            return "yes" if v else ""
        if isinstance(v, (float, np.floating)):
            v = float(v)
            if v != 0 and abs(v) < 0.001:
                return f"{v:.2e}"
            return f"{v:.4f}"
        return str(v).replace("|", "\\|")

    lines = [
        "| " + " | ".join(columns) + " |",
        "| " + " | ".join("---" for _ in columns) + " |",
    ]
    # itertuples keeps per-column types; iterrows would upcast ints to floats
    # in all-numeric rows and hand back np.bool_ for all-bool frames.
    for values in df[columns].head(max_rows).itertuples(index=False, name=None):
        lines.append("| " + " | ".join(fmt_cell(v) for v in values) + " |")
    if len(df) > max_rows:
        lines.append(f"\n*... {len(df) - max_rows} more rows (see CSV)*\n")
    return "\n".join(lines) + "\n"


def generate_report(
    enrichment_df: pd.DataFrame,
    per_problem_df: pd.DataFrame,
    cross_ckpt_df: pd.DataFrame,
    predictive_results: list[dict],
    args: argparse.Namespace,
) -> str:
    """Assemble the markdown report covering all four analyses.

    Args:
        enrichment_df: global enrichment results (may be empty).
        per_problem_df: per-problem consistency results (may be empty).
        cross_ckpt_df: cross-checkpoint results.  Its ``attrs["jsd_by_k"]``,
            when present, feeds the divergence table.
        predictive_results: per-k dicts from the predictive analysis.
        args: parsed CLI namespace.  Reads ``checkpoints``, ``k_range``,
            ``min_count``, ``min_count_high_k``, ``exclude_primitives``,
            ``sft_baseline`` and, if present, ``cv_folds`` (defaults to 5).

    Returns:
        The report as a single markdown string.
    """
    enrichment_cols = ["kgram_str", "count_solved", "count_unsolved",
                       "freq_solved", "freq_unsolved", "log_odds", "p_adj", "significant"]
    per_problem_cols = ["kgram_str", "n_problems_tested", "n_enriched_solved",
                        "frac_enriched", "mean_log_odds", "median_log_odds"]

    lines = ["# Motif (k-gram) Enrichment Analysis\n"]

    # Summary
    lines.append("## Configuration\n")
    lines.append(f"- **Checkpoints**: {', '.join(args.checkpoints)}")
    lines.append(f"- **k range**: {args.k_range[0]}–{args.k_range[1]}")
    lines.append(f"- **Min count**: {args.min_count} (k≤5), {args.min_count_high_k} (k>5)")
    if args.exclude_primitives:
        lines.append(f"- **Excluded primitives**: {', '.join(args.exclude_primitives)} "
                     "(replaced with boundary tokens)")
    if args.sft_baseline:
        lines.append(f"- **SFT baseline**: {args.sft_baseline}")
    lines.append("")

    # Section 1: Global enrichment
    lines.append("## 1. Global Motif Enrichment\n")
    if len(enrichment_df) > 0:
        for ckpt in enrichment_df["checkpoint"].unique():
            ckpt_df = enrichment_df[enrichment_df["checkpoint"] == ckpt]
            lines.append(f"### Checkpoint: {ckpt}\n")
            for k in sorted(ckpt_df["k"].unique()):
                k_df = ckpt_df[ckpt_df["k"] == k]
                n_sig = k_df["significant"].sum()
                lines.append(f"#### k={k} ({len(k_df)} motifs, {n_sig} significant)\n")

                # Top enriched in solved (rows are already sorted by log_odds desc)
                top_solved = k_df[k_df["log_odds"] > 0].head(15)
                if len(top_solved) > 0:
                    lines.append("**Top enriched in SOLVED:**\n")
                    lines.append(_fmt_table(top_solved, enrichment_cols, max_rows=15))

                # Top enriched in unsolved
                top_unsolved = k_df[k_df["log_odds"] < 0].sort_values("log_odds").head(10)
                if len(top_unsolved) > 0:
                    lines.append("**Top enriched in UNSOLVED:**\n")
                    lines.append(_fmt_table(top_unsolved, enrichment_cols, max_rows=10))
    else:
        lines.append("*No enrichment data computed.*\n")

    # Section 2: Per-problem consistency
    lines.append("## 2. Per-Problem Motif Consistency\n")
    if len(per_problem_df) > 0:
        lines.append("Motifs consistently enriched in solved traces across multiple problems:\n")
        for k in sorted(per_problem_df["k"].unique()):
            k_df = per_problem_df[per_problem_df["k"] == k]
            # Show motifs tested in ≥3 problems with frac_enriched > 0.5
            top = k_df[(k_df["n_problems_tested"] >= 3) &
                       (k_df["frac_enriched"] > 0.5)].head(15)
            if len(top) > 0:
                lines.append(f"#### k={k}\n")
                lines.append(_fmt_table(top, per_problem_cols, max_rows=15))
            # Also show motifs consistently enriched in unsolved
            bottom = k_df[(k_df["n_problems_tested"] >= 3) &
                          (k_df["frac_enriched"] < 0.3)].sort_values(
                              "frac_enriched").head(10)
            if len(bottom) > 0:
                lines.append(f"**Consistently enriched in UNSOLVED (k={k}):**\n")
                lines.append(_fmt_table(bottom, per_problem_cols, max_rows=10))
    else:
        lines.append("*No per-problem data computed.*\n")

    # Section 3: Cross-checkpoint comparison
    lines.append("## 3. Cross-Checkpoint Comparison\n")
    if len(cross_ckpt_df) > 0:
        jsd_by_k = cross_ckpt_df.attrs.get("jsd_by_k", {})
        if jsd_by_k:
            # The stored values come from scipy's jensenshannon, which returns
            # the JS *distance* (sqrt of the divergence) — label accordingly.
            lines.append("### Profile Divergence (Jensen–Shannon distance)\n")
            lines.append("| k | JS distance |")
            lines.append("| --- | --- |")
            for k_val in sorted(jsd_by_k):
                lines.append(f"| {k_val} | {jsd_by_k[k_val]:.4f} |")
            lines.append("")

        # RL-gained motifs that are also enriched in solved
        gained = cross_ckpt_df[
            (cross_ckpt_df["category"] == "RL-gained") &
            (cross_ckpt_df["enriched_in_solved"])
        ]
        if len(gained) > 0:
            lines.append("### RL-Gained Motifs (enriched in solved)\n")
            lines.append("These motifs are ≥2x more frequent after RL AND enriched in solved traces:\n")
            for k in sorted(gained["k"].unique()):
                k_df = gained[gained["k"] == k].sort_values(
                    "ratio_rl_sft", ascending=False)
                lines.append(f"#### k={k}\n")
                lines.append(_fmt_table(
                    k_df,
                    ["kgram_str", "freq_sft", "freq_rl", "ratio_rl_sft"],
                    max_rows=20,
                ))

        # RL-lost motifs
        lost = cross_ckpt_df[cross_ckpt_df["category"] == "RL-lost"]
        if len(lost) > 0:
            lines.append("### RL-Lost Motifs\n")
            for k in sorted(lost["k"].unique()):
                k_df = lost[lost["k"] == k].sort_values("ratio_rl_sft")
                lines.append(f"#### k={k}\n")
                lines.append(_fmt_table(
                    k_df,
                    ["kgram_str", "freq_sft", "freq_rl", "ratio_rl_sft",
                     "enriched_in_solved"],
                    max_rows=10,
                ))
    else:
        lines.append("*No cross-checkpoint data computed.*\n")

    # Section 4: Predictive power
    lines.append("## 4. Motif Predictive Power\n")
    if predictive_results:
        # Report the configured fold count instead of a hard-coded "5".
        n_folds = getattr(args, "cv_folds", 5)
        lines.append(f"Logistic regression ({n_folds}-fold stratified CV) "
                     "predicting solved from k-gram features:\n")
        lines.append("| k | AUC (mean±std) | Features | Positives | Negatives | Top-3 features |")
        lines.append("| --- | --- | --- | --- | --- | --- |")
        for r in predictive_results:
            top3 = ", ".join(f['kgram'] for f in r.get("top_features", [])[:3])
            lines.append(
                f"| {r['k']} | {r['auc_mean']:.3f}±{r['auc_std']:.3f} "
                f"| {r['n_features']} | {r['n_pos']} | {r['n_neg']} | {top3} |"
            )
        lines.append("")

        # Detail top features for best k
        best = max(predictive_results, key=lambda r: r["auc_mean"])
        if best.get("top_features"):
            lines.append(f"### Best k={best['k']} (AUC={best['auc_mean']:.3f}) — Top Features\n")
            lines.append("| Motif | Coefficient |")
            lines.append("| --- | --- |")
            for f in best["top_features"]:
                lines.append(f"| {f['kgram']} | {f['coef']:+.4f} |")
            lines.append("")
    else:
        lines.append("*No predictive analysis computed.*\n")

    return "\n".join(lines)


# ---------------------------------------------------------------------------
# Section 6: CLI
# ---------------------------------------------------------------------------


def main():
    """Command-line entry point: run the selected motif analyses and write outputs.

    Loads every checkpoint named by ``--checkpoints`` (via
    ``load_all_checkpoints``), then, depending on ``--analyses``, runs:

    1. global enrichment, per checkpoint      -> ``global_enrichment.csv``
    2. per-problem consistency, pooled        -> ``per_problem_consistency.csv``
    3. SFT -> RL cross-checkpoint comparison  -> ``cross_checkpoint.csv`` and
       ``cross_checkpoint_jsd.json`` (requires ``--sft-baseline``)
    4. predictive models, per checkpoint      -> ``predictive_results.json``

    and finally renders ``motif_report.md`` with ``generate_report``. All
    outputs are written to ``--output-dir``.
    """

    def _json_fallback(obj):
        """``json.dump`` fallback: numpy scalars/arrays -> native Python, else ``str``."""
        # np.int64 / np.float32 / np.bool_ are not JSON-native; without this they
        # would be stringified (e.g. 42 -> "42") by a plain ``default=str``.
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return str(obj)

    parser = argparse.ArgumentParser(
        description="Motif (k-gram) enrichment analysis for reasoning primitive chains",
    )
    parser.add_argument("--checkpoints", nargs="+", required=True,
                        help="Checkpoint directory names under results/exploration_analysis/")
    parser.add_argument("--k-range", nargs=2, type=int, default=[2, 10],
                        metavar=("MIN_K", "MAX_K"),
                        help="Range of k values (default: 2 10)")
    parser.add_argument("--min-count", type=int, default=5,
                        help="Min k-gram occurrences for k<=5 (default: 5)")
    parser.add_argument("--min-count-high-k", type=int, default=10,
                        help="Min k-gram occurrences for k>5 (default: 10)")
    parser.add_argument("--min-solved-per-problem", type=int, default=3,
                        help="Min solved traces per problem for per-problem analysis (default: 3)")
    parser.add_argument("--sft-baseline", type=str, default=None,
                        help="SFT checkpoint name for cross-checkpoint comparison")
    parser.add_argument("--analyses", nargs="+",
                        default=["all"],
                        choices=["enrichment", "per_problem", "cross_checkpoint",
                                 "predictive", "all"],
                        help="Which analyses to run (default: all)")
    parser.add_argument("--output-dir", type=str,
                        default="results/exploration_analysis/motif_analysis",
                        help="Output directory")
    parser.add_argument("--fdr-threshold", type=float, default=0.05,
                        help="FDR threshold, passed to the report generator (default: 0.05)")
    parser.add_argument("--pseudocount", type=float, default=0.5,
                        help="Pseudocount for the enrichment and per-problem analyses "
                             "(default: 0.5)")
    parser.add_argument("--cv-folds", type=int, default=5,
                        help="Number of CV folds for the predictive models (default: 5)")
    parser.add_argument("--exclude-primitives", nargs="+", default=None,
                        help="Primitives to exclude (e.g. OTHER). "
                             "Replaced with boundary tokens to avoid false adjacencies.")
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    # Reject nonsensical ranges up front instead of failing deep inside an analysis.
    min_k, max_k = args.k_range
    if min_k < 1 or min_k > max_k:
        parser.error(f"--k-range must satisfy 1 <= MIN_K <= MAX_K (got {min_k} {max_k})")
    if args.cv_folds < 2:
        parser.error(f"--cv-folds must be >= 2 (got {args.cv_folds})")

    logging.basicConfig(
        level=logging.INFO if args.verbose else logging.WARNING,
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
    )

    analyses = set(args.analyses)
    if "all" in analyses:
        analyses = {"enrichment", "per_problem", "cross_checkpoint", "predictive"}

    outdir = Path(args.output_dir)
    outdir.mkdir(parents=True, exist_ok=True)
    k_range = (min_k, max_k)

    # Load data
    exclude = set(args.exclude_primitives) if args.exclude_primitives else None
    log.info("Loading checkpoints: %s", args.checkpoints)
    if exclude:
        log.info("Excluding primitives: %s (replaced with boundary tokens)", exclude)
    df = load_all_checkpoints(args.checkpoints, exclude_primitives=exclude)
    log.info("Total: %d traces, %d solved (%.1f%%), %d problems",
             len(df), df["correct"].sum(), 100 * df["correct"].mean(),
             df["doc_id"].nunique())

    enrichment_df = pd.DataFrame()
    per_problem_df = pd.DataFrame()
    cross_ckpt_df = pd.DataFrame()
    predictive_results = []

    # 1. Global enrichment — per checkpoint
    if "enrichment" in analyses:
        log.info("=== Global Motif Enrichment ===")
        parts = []
        for ckpt in args.checkpoints:
            ckpt_df = df[df["dataset"] == ckpt]
            n_solved = ckpt_df["correct"].sum()
            n_unsolved = len(ckpt_df) - n_solved
            log.info("Checkpoint %s: %d traces, %d solved (%.1f%%)",
                     ckpt, len(ckpt_df), n_solved, 100 * ckpt_df["correct"].mean())
            # Enrichment contrasts solved vs unsolved traces; both groups are needed.
            if n_solved < 2 or n_unsolved < 2:
                log.warning("  Skipping %s — too few solved (%d) or unsolved (%d) traces",
                            ckpt, n_solved, n_unsolved)
                continue
            res = global_enrichment_all_k(
                ckpt_df, k_range, args.min_count, args.min_count_high_k,
                args.pseudocount, checkpoint_label=ckpt)
            if len(res) > 0:
                parts.append(res)
        if parts:
            enrichment_df = pd.concat(parts, ignore_index=True)
            # Save — drop the tuple column for CSV
            save_df = enrichment_df.drop(columns=["kgram"], errors="ignore")
            save_df.to_csv(outdir / "global_enrichment.csv", index=False)
            log.info("Saved global_enrichment.csv (%d rows)", len(save_df))

    # 2. Per-problem analysis — pool all checkpoints
    if "per_problem" in analyses:
        log.info("=== Per-Problem Motif Analysis ===")
        per_problem_df = per_problem_all_k(
            df, k_range, args.min_solved_per_problem, args.pseudocount)
        if len(per_problem_df) > 0:
            save_df = per_problem_df.drop(columns=["kgram"], errors="ignore")
            save_df.to_csv(outdir / "per_problem_consistency.csv", index=False)
            log.info("Saved per_problem_consistency.csv (%d rows)", len(save_df))

    # 3. Cross-checkpoint comparison
    if "cross_checkpoint" in analyses and not args.sft_baseline:
        # Only warn loudly if the user named this analysis explicitly (not via "all").
        level = logging.WARNING if "cross_checkpoint" in args.analyses else logging.INFO
        log.log(level, "Skipping cross-checkpoint comparison: --sft-baseline not given")
    elif "cross_checkpoint" in analyses:
        log.info("=== Cross-Checkpoint Comparison ===")
        sft_df = df[df["dataset"] == args.sft_baseline]
        if len(sft_df) == 0:
            log.warning("SFT baseline %s not found in loaded data "
                        "(is it listed in --checkpoints?)", args.sft_baseline)
        else:
            rl_checkpoints = [c for c in args.checkpoints if c != args.sft_baseline]
            if not rl_checkpoints:
                log.warning("No other checkpoints to compare against SFT baseline %s",
                            args.sft_baseline)
            parts = []
            jsd_by_ckpt = {}
            for rl_ckpt in rl_checkpoints:
                rl_df = df[df["dataset"] == rl_ckpt]
                if len(rl_df) == 0:
                    log.warning("  Skipping %s — no traces loaded", rl_ckpt)
                    continue
                log.info("Comparing %s -> %s", args.sft_baseline, rl_ckpt)
                res = cross_checkpoint_all_k(
                    sft_df, rl_df, enrichment_df, k_range,
                    args.min_count, args.min_count_high_k,
                    sft_label=args.sft_baseline, rl_label=rl_ckpt)
                if len(res) > 0:
                    parts.append(res)
                    jsd_by_ckpt[rl_ckpt] = dict(res.attrs.get("jsd_by_k", {}))
            if parts:
                cross_ckpt_df = pd.concat(parts, ignore_index=True)
                # Merge JSD attrs. ``jsd_by_k`` keeps its original {k: jsd} shape for
                # the report; with several RL checkpoints later ones overwrite earlier
                # ones for the same k, so the full breakdown is kept alongside it.
                jsd_by_k = {}
                for p in parts:
                    jsd_by_k.update(p.attrs.get("jsd_by_k", {}))
                cross_ckpt_df.attrs["jsd_by_k"] = jsd_by_k
                cross_ckpt_df.attrs["jsd_by_checkpoint"] = jsd_by_ckpt
                save_df = cross_ckpt_df.drop(columns=["kgram"], errors="ignore")
                save_df.to_csv(outdir / "cross_checkpoint.csv", index=False)
                log.info("Saved cross_checkpoint.csv (%d rows)", len(save_df))

                # JSON object keys must be str; k may be a numpy integer.
                jsd_payload = {
                    "sft_baseline": args.sft_baseline,
                    "jsd_by_checkpoint": {
                        ckpt: {str(k): v for k, v in by_k.items()}
                        for ckpt, by_k in jsd_by_ckpt.items()
                    },
                }
                with open(outdir / "cross_checkpoint_jsd.json", "w", encoding="utf-8") as f:
                    json.dump(jsd_payload, f, indent=2, default=_json_fallback)
                log.info("Saved cross_checkpoint_jsd.json (%d RL checkpoints)",
                         len(jsd_by_ckpt))

    # 4. Predictive power — per checkpoint
    if "predictive" in analyses:
        log.info("=== Motif Predictive Power ===")
        for ckpt in args.checkpoints:
            ckpt_df = df[df["dataset"] == ckpt]
            n_solved = ckpt_df["correct"].sum()
            n_unsolved = len(ckpt_df) - n_solved
            # Stratified k-fold CV needs at least cv_folds examples of *each* class;
            # the historical floor of 5 solved traces is kept as well.
            if n_solved < max(5, args.cv_folds) or n_unsolved < args.cv_folds:
                log.warning("  Skipping %s — too few solved (%d) / unsolved (%d) "
                            "traces for %d-fold CV",
                            ckpt, n_solved, n_unsolved, args.cv_folds)
                continue
            log.info("Predictive model for %s:", ckpt)
            results = predictive_analysis(
                ckpt_df, k_range, args.min_count, args.min_count_high_k,
                args.cv_folds, checkpoint_label=ckpt)
            predictive_results.extend(results)
        if predictive_results:
            save_results = []
            for r in predictive_results:
                r2 = dict(r)
                if r2.get("coefficients") is not None:
                    del r2["coefficients"]  # large array, omitted from JSON
                save_results.append(r2)
            with open(outdir / "predictive_results.json", "w", encoding="utf-8") as f:
                json.dump(save_results, f, indent=2, default=_json_fallback)
            log.info("Saved predictive_results.json")

    # 5. Generate report
    log.info("=== Generating Report ===")
    report = generate_report(
        enrichment_df, per_problem_df, cross_ckpt_df, predictive_results, args)
    report_path = outdir / "motif_report.md"
    # The report contains non-ASCII characters ("±", "—"); don't depend on locale.
    report_path.write_text(report, encoding="utf-8")
    log.info("Saved %s", report_path)
    print(f"\nReport written to {report_path}")
    print(f"CSVs and JSON in {outdir}/")


if __name__ == "__main__":
    # Script entry point, e.g. `python -m analysis.exploration.motif_analysis --checkpoints ...`
    main()
