"""Config-driven CLI entry point for the exploration analysis pipeline.

Usage:
    python -m analysis.exploration.pipeline --config analysis/exploration/configs/default.yaml
    python -m analysis.exploration.pipeline --config ... --stage segment
    python -m analysis.exploration.pipeline --config ... --stage correlate
"""
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import yaml

from . import data_loading, segmentation, primitive_classification, primitive_metrics
from . import passk, analysis, plotting, report
from . import clustering, diversity_metrics, novelty_metrics

# Directory three `.parent` steps above this file — for
# analysis/exploration/pipeline.py, the directory that contains `analysis/`.
# resolve_path() anchors relative config paths here.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent


def load_config(config_path: str) -> dict:
    """Read the YAML pipeline config at ``config_path``.

    Args:
        config_path: Filesystem path of the YAML file.

    Returns:
        The parsed top-level mapping. An empty file yields ``{}`` (instead of
        ``None``) so callers can use ``config.get(...)`` unconditionally.

    Raises:
        ValueError: If the top level of the file is not a mapping.
    """
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    if config is None:
        # safe_load returns None for an empty document.
        return {}
    if not isinstance(config, dict):
        raise ValueError(
            f"Config {config_path} must contain a YAML mapping at the top level, "
            f"got {type(config).__name__}"
        )
    return config


def resolve_path(p: str | None) -> Path | None:
    if p is None:
        return None
    path = Path(p)
    if not path.is_absolute():
        path = _PROJECT_ROOT / path
    return path


# ---------------------------------------------------------------------------
# Stage: load — Load JSONL samples and score responses
# ---------------------------------------------------------------------------

def stage_load(config: dict) -> dict:
    """Load all samples from the configured results dirs.

    For every entry under ``sft_baselines`` (stored under its ``label``,
    falling back to the entry key) and ``gspo_checkpoints`` (stored under the
    checkpoint key), both ``puzzle_results`` and ``math_results`` dirs are
    loaded if they exist. The optional ``task_filter`` (substring match on
    task names; a single string or a list of strings) and
    ``max_docs_per_task`` settings are applied afterwards.

    Returns:
        ``{checkpoint_id: {task_name: list_of_prompt_samples}}``.
    """
    print("\n=== Stage: LOAD ===")
    all_data: dict = {}  # checkpoint_id -> {task_name -> list[PromptSamples]}

    def _load_checkpoint_dirs(checkpoint_id: str, entry_cfg: dict) -> None:
        # Merge the puzzle and math results of one checkpoint into all_data.
        for dir_key in ("puzzle_results", "math_results"):
            results_dir = resolve_path(entry_cfg.get(dir_key))
            if results_dir and results_dir.exists():
                print(f"  Loading {checkpoint_id}/{dir_key}: {results_dir}")
                data = data_loading.load_results_dir(results_dir, checkpoint_id=checkpoint_id)
                all_data.setdefault(checkpoint_id, {}).update(data)

    # SFT baselines
    for base_key, base_cfg in config.get("sft_baselines", {}).items():
        _load_checkpoint_dirs(base_cfg.get("label", base_key), base_cfg)

    # GSPO checkpoints
    for ckpt_key, ckpt_cfg in config.get("gspo_checkpoints", {}).items():
        _load_checkpoint_dirs(ckpt_key, ckpt_cfg)

    # Optional task filter
    task_filter = config.get("task_filter")
    if task_filter:
        if isinstance(task_filter, str):
            # A bare YAML scalar (`task_filter: bridges`) would otherwise be
            # iterated character by character and match nearly every task.
            task_filter = [task_filter]
        for ckpt in all_data:
            all_data[ckpt] = {
                t: v for t, v in all_data[ckpt].items()
                if any(f in t for f in task_filter)
            }
        print(f"  Filtered to tasks matching: {task_filter}")

    # Optional max docs per task
    max_docs = config.get("max_docs_per_task")
    if max_docs:
        for tasks in all_data.values():
            for task in tasks:
                tasks[task] = tasks[task][:max_docs]
        print(f"  Capped at {max_docs} docs per task")

    print(f"  Loaded {len(all_data)} checkpoints")
    for ckpt, tasks in all_data.items():
        n_prompts = sum(len(v) for v in tasks.values())
        print(f"    {ckpt}: {len(tasks)} tasks, {n_prompts} prompts")

    return all_data


# ---------------------------------------------------------------------------
# Stage: segment + classify — Segment traces and classify primitives
# ---------------------------------------------------------------------------

def _segment_one(args):
    """Segment a single response inside a pool worker.

    Args:
        args: ``(response, tokenizer_name, min_tokens, max_tokens)`` packed in
            one tuple so the function can be mapped over a list of arguments.

    Returns:
        A list of ``(span_id, text, n_tokens)`` tuples — plain, picklable
        values that can be sent back from the worker process.
    """
    response, tokenizer_name, min_tokens, max_tokens = args
    from analysis.exploration import segmentation as _seg

    segmented = _seg.segment_response(response, tokenizer_name, min_tokens, max_tokens)
    return [(span.span_id, span.text, span.n_tokens) for span in segmented]


def stage_segment_classify(
    all_data: dict,
    config: dict,
    classifier_type: str = "heuristic",
    n_workers: int = 8,
    max_traces: int | None = None,
) -> dict:
    """Segment traces and classify primitives. Returns per-trace metrics.

    Works in three phases:

    1. Collect metadata for every trace and segment the selected traces into
       spans in a process pool. Tasks are selected by puzzle keyword in the
       task name; when ``task_filter`` is set in the config, every loaded task
       is selected.
    2. If ``classifier_type == "learned"``, label all spans in one batched
       call to the BERT ensemble.
    3. Build a primitive summary per trace and an aggregate per prompt.

    Args:
        all_data: ``{checkpoint_id: {task_name: prompt_samples}}`` as returned
            by ``stage_load``.
        config: Pipeline config; reads ``segmentation``,
            ``primitive_confidence_threshold``, ``task_filter`` and the
            ``primitive_*`` classifier settings.
        classifier_type: "heuristic" (regex, 10 labels) or "learned" (BERT ensemble, 9 labels).
        n_workers: Number of parallel workers for segmentation.
        max_traces: If set, only the first ``max_traces`` traces of each prompt are used.

    Returns:
        ``{"trace_records": [...], "prompt_records": [...]}``. Traces of
        unselected tasks get a minimal record with an estimated token count
        and produce no prompt record.
    """
    import time
    from collections import defaultdict
    from concurrent.futures import ProcessPoolExecutor

    from analysis.exploration.segmentation import Span
    from . import sequence_diversity as sd

    print("\n=== Stage: SEGMENT + CLASSIFY ===")
    print(f"  Classifier: {classifier_type}, segmentation workers: {n_workers}")

    seg_cfg = config.get("segmentation", {})
    tokenizer_name = seg_cfg.get("tokenizer", "allenai/OLMo-3-7B-Instruct-SFT")
    min_tokens = seg_cfg.get("min_tokens", 80)
    max_tokens = seg_cfg.get("max_tokens", 250)
    conf_threshold = config.get("primitive_confidence_threshold", 0.01)

    # Load learned classifier if requested
    learned_clf = None
    if classifier_type == "learned":
        batch_size = config.get("primitive_batch_size", 256)
        use_fp16 = config.get("primitive_fp16", True)
        multi_gpu = config.get("primitive_multi_gpu", False)
        n_models = config.get("primitive_n_models", 3)
        print(f"  Loading BERT ensemble (n_models={n_models}, batch_size={batch_size}, fp16={use_fp16})...")
        t0 = time.time()
        learned_clf = primitive_classification.load_ensemble_classifier(
            batch_size=batch_size, use_fp16=use_fp16, multi_gpu=multi_gpu, n_models=n_models,
        )
        sd.use_learned_labels()
        print(f"  Loaded in {time.time()-t0:.1f}s — label set: {sd._PRIM_LIST}")

    # ---------------------------------------------------------------
    # Phase 1: Collect all trace metadata and segment in parallel
    # ---------------------------------------------------------------
    # One tuple per trace: (ckpt_id, task_name, doc_id, trace_id, correct,
    # response_len_chars, is_puzzle, n_correct, n_total)
    trace_meta = []
    # Segmentation args aligned index-for-index with trace_meta; None marks a
    # trace that is not segmented.
    seg_args = []

    if max_traces is not None:
        print(f"  [smoke test] Capping at {max_traces} traces per prompt")

    # When task_filter is set, segment ALL matching tasks (including math)
    segment_all = bool(config.get("task_filter"))

    for ckpt_id, tasks in all_data.items():
        for task_name, prompt_samples_list in tasks.items():
            is_puzzle = segment_all or any(p in task_name for p in
                           ["bridges", "pattern", "undead", "galaxies", "loopy"])
            for ps in prompt_samples_list:
                traces = ps.traces[:max_traces] if max_traces is not None else ps.traces
                for trace in traces:
                    trace_meta.append((
                        ckpt_id, task_name, ps.doc_id, trace.trace_id,
                        trace.correct, len(trace.response), is_puzzle,
                        ps.n_correct, ps.n_total,
                    ))
                    if is_puzzle:
                        seg_args.append((trace.response, tokenizer_name, min_tokens, max_tokens))
                    else:
                        seg_args.append(None)

    puzzle_indices = [i for i, a in enumerate(seg_args) if a is not None]
    puzzle_args = [seg_args[i] for i in puzzle_indices]
    print(f"  Segmenting {len(puzzle_indices)} puzzle traces with {n_workers} workers...")
    t0 = time.time()

    # Parallel segmentation; pool.map preserves input order, so results can be
    # written back to the matching trace index.
    span_results = [None] * len(seg_args)
    if puzzle_args:
        with ProcessPoolExecutor(max_workers=n_workers) as pool:
            for idx, result in zip(puzzle_indices, pool.map(_segment_one, puzzle_args, chunksize=20)):
                span_results[idx] = result  # list of (span_id, text, n_tokens)

    print(f"  Segmented in {time.time()-t0:.1f}s")

    # ---------------------------------------------------------------
    # Phase 2: Batch GPU classification (all spans at once)
    # ---------------------------------------------------------------
    trace_labels: dict[int, list] = {}
    if learned_clf is not None:
        # Build one flat row list for ALL spans across ALL puzzle traces
        flat_rows = []
        flat_trace_idx = []  # owning trace index for each row
        for i in puzzle_indices:
            spans_i = span_results[i] or []
            for j, (_, text, _) in enumerate(spans_i):
                # The previous span of the same trace is passed as context.
                prev_text = spans_i[j - 1][1] if j > 0 else ""
                flat_rows.append({"span_text": text, "preceding_context": prev_text})
                flat_trace_idx.append(i)

        print(f"  Classifying {len(flat_rows)} spans in one GPU batch...")
        t0 = time.time()
        flat_labels = learned_clf(flat_rows) if flat_rows else []
        print(f"  Classified in {time.time()-t0:.1f}s")

        # Reassemble per-trace label lists. Assumes learned_clf returns one
        # label per input row, in input order — TODO confirm.
        trace_labels = defaultdict(list)
        for label, trace_i in zip(flat_labels, flat_trace_idx):
            trace_labels[trace_i].append(label)

    # ---------------------------------------------------------------
    # Phase 3: Build episode summaries and per-prompt aggregates
    # ---------------------------------------------------------------
    # Group traces by prompt, keeping first-seen order. is_puzzle depends only
    # on task_name, so it is constant within a prompt and stored per group.
    prompt_groups: dict[tuple, dict] = {}
    for i, meta in enumerate(trace_meta):
        ckpt_id, task_name, doc_id, trace_id, correct, resp_len, is_puzzle, n_correct, n_total = meta
        grp = prompt_groups.setdefault(
            (ckpt_id, task_name, doc_id),
            {"n_correct": n_correct, "n_total": n_total, "is_puzzle": is_puzzle, "traces": []},
        )
        grp["traces"].append((i, trace_id, correct, resp_len))

    trace_records = []
    prompt_records = []
    for (ckpt_id, task_name, doc_id), grp in prompt_groups.items():
        is_puzzle = grp["is_puzzle"]
        trace_summaries = []
        trace_successes = []

        for i, trace_id, correct, resp_len in grp["traces"]:
            if is_puzzle:
                raw_spans = span_results[i] or []
                # Reconstruct lightweight Span objects; character offsets are
                # not returned by the workers, so they are set to 0.
                spans = [Span(span_id=sid, text=txt, start_char=0, end_char=0, n_tokens=ntok)
                         for sid, txt, ntok in raw_spans]
                total_tokens = sum(s.n_tokens for s in spans)

                if learned_clf is not None:
                    labels = trace_labels.get(i, [])
                    # Learned labels are given a fixed confidence of 1.0.
                    labeled = [(span, lbl, 1.0) for span, lbl in zip(spans, labels)]
                else:
                    labeled = primitive_classification.classify_trace_spans(spans, conf_threshold)

                episodes = primitive_classification.merge_episodes(labeled)
                prim_seq = primitive_classification.extract_primitive_sequence(episodes)
                summary = primitive_metrics.trace_primitive_summary(episodes, total_tokens)
                summary["primitive_sequence"] = prim_seq
            else:
                # Rough estimate of ~4 characters per token.
                summary = {"total_tokens": resp_len // 4}

            summary.update({
                "checkpoint_id": ckpt_id,
                "task_name": task_name,
                "doc_id": doc_id,
                "trace_id": trace_id,
                "correct": correct,
                "response_length_chars": resp_len,
            })
            trace_records.append(summary)
            trace_summaries.append(summary)
            trace_successes.append(correct)

        # Per-prompt aggregation (segmented prompts only)
        if is_puzzle and trace_summaries:
            agg = primitive_metrics.aggregate_primitive_metrics(
                trace_summaries, success_mask=trace_successes
            )
            agg.update({
                "checkpoint_id": ckpt_id,
                "task_name": task_name,
                "doc_id": doc_id,
                # NOTE(review): hard-coded even when task_filter makes
                # non-puzzle tasks get segmented — confirm intended.
                "domain": "puzzle",
                "n_correct": grp["n_correct"],
                "n_total": grp["n_total"],
            })
            prompt_records.append(agg)

    print(f"  Processed {len(trace_records)} traces, {len(prompt_records)} prompts")
    return {
        "trace_records": trace_records,
        "prompt_records": prompt_records,
    }


# ---------------------------------------------------------------------------
# Stage: cluster — Cluster traces per prompt
# ---------------------------------------------------------------------------

def stage_cluster(
    all_data: dict,
    segment_results: dict,
    config: dict,
) -> dict:
    """Cluster traces per prompt and compute diversity metrics.

    Only puzzle tasks (task names containing a known puzzle keyword) and
    prompts with at least two traces are clustered. If the segment stage
    produced exactly as many primitive sequences for a prompt as it has
    traces, the embedding distance is combined with primitive n-gram vectors
    (``clustering.semantic_weight``). Otherwise plain cosine distance between
    embeddings is used. Prompts whose clustering raises are skipped with a
    warning.

    Returns:
        ``{"diversity_records": [...]}`` — one record per clustered prompt.
    """
    print("\n=== Stage: CLUSTER ===")

    clust_cfg = config.get("clustering", {})
    dist_threshold = clust_cfg.get("distance_threshold", 0.3)
    semantic_weight = clust_cfg.get("semantic_weight", 0.7)
    puzzle_keywords = ("bridges", "pattern", "undead", "galaxies", "loopy")

    # Index primitive sequences by prompt in a single pass (keeping trace-record
    # order) instead of rescanning every trace record for every prompt, which
    # was O(prompts x traces).
    prim_seqs_by_prompt: dict[tuple, list] = {}
    for tr in segment_results.get("trace_records", []):
        key = (tr["checkpoint_id"], tr["task_name"], tr["doc_id"])
        prim_seqs_by_prompt.setdefault(key, []).append(tr.get("primitive_sequence", []))

    diversity_records = []

    for ckpt_id, tasks in all_data.items():
        for task_name, prompt_samples_list in tasks.items():
            if not any(p in task_name for p in puzzle_keywords):
                continue

            for ps in prompt_samples_list:
                if len(ps.traces) < 2:
                    continue

                traces_text = [t.response for t in ps.traces]
                success_mask = np.array([t.correct for t in ps.traces])
                prim_seqs = prim_seqs_by_prompt.get((ckpt_id, task_name, ps.doc_id), [])

                # Compute embeddings and cluster (best-effort per prompt)
                try:
                    embs = clustering.embed_traces(traces_text)
                    # Primitive features are only usable with one sequence per trace.
                    if prim_seqs and len(prim_seqs) == len(traces_text):
                        prim_vecs = clustering.primitive_ngram_vectors(prim_seqs)
                        dist_mat = clustering.combined_distance_matrix(
                            embs, prim_vecs, semantic_weight
                        )
                    else:
                        from scipy.spatial.distance import cdist
                        dist_mat = cdist(embs, embs, metric="cosine")
                        # Cosine distance is undefined (NaN) for zero vectors;
                        # treat such pairs as maximally distant.
                        dist_mat = np.nan_to_num(dist_mat, nan=1.0)

                    labels = clustering.cluster_traces(dist_mat, dist_threshold)
                    div = diversity_metrics.prompt_diversity_metrics(labels, success_mask)
                except Exception as e:
                    print(f"  Warning: clustering failed for {ckpt_id}/{task_name}/doc{ps.doc_id}: {e}")
                    continue

                div.update({
                    "checkpoint_id": ckpt_id,
                    "task_name": task_name,
                    "doc_id": ps.doc_id,
                })
                diversity_records.append(div)

    print(f"  Clustered {len(diversity_records)} prompts")
    return {"diversity_records": diversity_records}


# ---------------------------------------------------------------------------
# Stage: novelty — Joint SFT+RL clustering
# ---------------------------------------------------------------------------

def stage_novelty(
    all_data: dict,
    config: dict,
) -> dict:
    """Compute cluster novelty by jointly clustering SFT + RL traces.

    For every GSPO checkpoint with an ``sft_base`` entry, the RL traces and
    the traces of that SFT baseline (looked up by its ``label``, falling back
    to the baseline key) are clustered together. This is done per prompt,
    over the puzzle tasks and doc ids both sides have loaded. Prompts whose
    novelty computation raises are skipped with a warning.

    Returns:
        ``{"novelty_records": [...]}`` holding the non-array fields of each
        novelty result plus identifying keys.
    """
    print("\n=== Stage: NOVELTY ===")

    clust_cfg = config.get("clustering", {})
    tau = clust_cfg.get("novel_tau", 0.1)
    dist_threshold = clust_cfg.get("distance_threshold", 0.3)
    # .get like stage_load: a missing sft_baselines section must not raise KeyError.
    sft_baselines = config.get("sft_baselines", {})
    puzzle_keywords = ("bridges", "pattern", "undead", "galaxies", "loopy")

    novelty_records = []

    # For each GSPO checkpoint, find its SFT baseline and compute novelty
    for ckpt_key, ckpt_cfg in config.get("gspo_checkpoints", {}).items():
        sft_base_key = ckpt_cfg.get("sft_base")
        if not sft_base_key:
            continue

        sft_label = sft_baselines.get(sft_base_key, {}).get("label", sft_base_key)

        if ckpt_key not in all_data or sft_label not in all_data:
            continue

        # Find matching puzzle tasks
        sft_tasks = all_data[sft_label]
        rl_tasks = all_data[ckpt_key]
        common_tasks = set(sft_tasks) & set(rl_tasks)
        puzzle_tasks = [t for t in common_tasks if any(p in t for p in puzzle_keywords)]

        for task_name in sorted(puzzle_tasks):
            sft_prompts = {ps.doc_id: ps for ps in sft_tasks[task_name]}
            rl_prompts = {ps.doc_id: ps for ps in rl_tasks[task_name]}
            common_docs = set(sft_prompts) & set(rl_prompts)

            for doc_id in sorted(common_docs):
                sft_ps = sft_prompts[doc_id]
                rl_ps = rl_prompts[doc_id]

                sft_texts = [t.response for t in sft_ps.traces]
                rl_texts = [t.response for t in rl_ps.traces]
                rl_success = [t.correct for t in rl_ps.traces]

                try:
                    result = novelty_metrics.joint_cluster_novelty_with_success(
                        sft_texts, rl_texts, rl_success,
                        tau=tau, distance_threshold=dist_threshold,
                    )
                except Exception as e:
                    print(f"  Warning: novelty failed for {ckpt_key}/{task_name}/doc{doc_id}: {e}")
                    continue

                # Extract scalar metrics only
                record = {
                    k: v for k, v in result.items()
                    if not isinstance(v, np.ndarray)
                }
                record.update({
                    "checkpoint_id": ckpt_key,
                    "sft_base": sft_base_key,
                    "task_name": task_name,
                    "doc_id": doc_id,
                })
                novelty_records.append(record)

    print(f"  Computed novelty for {len(novelty_records)} prompts")
    return {"novelty_records": novelty_records}


# ---------------------------------------------------------------------------
# Stage: aggregate — Build checkpoint-level tables
# ---------------------------------------------------------------------------

def stage_aggregate(
    config: dict,
    segment_results: dict,
    cluster_results: dict,
    novelty_results: dict,
) -> dict:
    """Aggregate all metrics to checkpoint level.

    For each checkpoint this combines:
      * the mean of every numeric prompt-level column from the segment stage
        (except ``doc_id``, ``n_correct``, ``n_total``),
      * ``diversity_metrics.aggregate_diversity`` over its diversity records,
      * ``novelty_metrics.aggregate_novelty`` over its novelty records.
    Later sources overwrite earlier keys with the same name.

    Returns:
        ``{"checkpoint_metrics": {ckpt_id: {metric: value}},
        "checkpoint_table": <result of analysis.build_checkpoint_table>}``.
    """
    print("\n=== Stage: AGGREGATE ===")

    checkpoint_metrics: dict = {}
    excluded_cols = ("doc_id", "n_correct", "n_total")

    # Average numeric prompt-level columns per checkpoint
    prompt_df = pd.DataFrame(segment_results.get("prompt_records", []))
    if not prompt_df.empty:
        for ckpt_id, group in prompt_df.groupby("checkpoint_id"):
            numeric_cols = group.select_dtypes(include=[np.number]).columns
            checkpoint_metrics[ckpt_id] = {
                col: float(group[col].mean())
                for col in numeric_cols
                if col not in excluded_cols
            }

    # Add diversity metrics (aggregated per checkpoint only; a file-wide
    # aggregate used to be computed here and then discarded)
    div_df = pd.DataFrame(cluster_results.get("diversity_records", []))
    if not div_df.empty:
        for ckpt_id, group in div_df.groupby("checkpoint_id"):
            agg = diversity_metrics.aggregate_diversity(group.to_dict("records"))
            checkpoint_metrics.setdefault(ckpt_id, {}).update(agg)

    # Add novelty metrics
    nov_df = pd.DataFrame(novelty_results.get("novelty_records", []))
    if not nov_df.empty:
        for ckpt_id, group in nov_df.groupby("checkpoint_id"):
            agg = novelty_metrics.aggregate_novelty(group.to_dict("records"))
            checkpoint_metrics.setdefault(ckpt_id, {}).update(agg)

    # Build checkpoint table
    table = analysis.build_checkpoint_table(checkpoint_metrics)
    print(f"  Built checkpoint table: {table.shape}")

    return {
        "checkpoint_metrics": checkpoint_metrics,
        "checkpoint_table": table,
    }


# ---------------------------------------------------------------------------
# Stage: correlate — Run correlation and regression analysis
# ---------------------------------------------------------------------------

def stage_correlate(
    checkpoint_table: pd.DataFrame,
    config: dict,
) -> dict:
    """Run correlations and regressions against ``math_pass32_gain``.

    Every numeric (non-boolean) column except ``checkpoint_id``, ``sft_base``
    and the target is Spearman-correlated with the target. OLS regressions
    are fit for the five metrics with the largest ``|rho|``, and a
    trace-length confound check is run.

    Returns:
        dict with ``correlation_table`` (DataFrame), ``regressions`` (list)
        and ``length_confounded``. The same three keys are present when the
        analysis is skipped (empty table / empty lists).
    """
    print("\n=== Stage: CORRELATE ===")

    math_col = "math_pass32_gain"
    skipped = {"correlation_table": pd.DataFrame(), "regressions": [], "length_confounded": []}
    if math_col not in checkpoint_table.columns:
        print(f"  Warning: {math_col} not in checkpoint table, skipping correlations")
        return skipped

    def _is_metric_dtype(dtype) -> bool:
        # Any real numeric dtype (float32/int32/nullable ints included, which an
        # exact float64/int64 match missed); booleans are flags, not metrics.
        return pd.api.types.is_numeric_dtype(dtype) and not pd.api.types.is_bool_dtype(dtype)

    # Find metric columns
    metric_cols = [c for c in checkpoint_table.columns
                   if c not in ("checkpoint_id", "sft_base", math_col)
                   and _is_metric_dtype(checkpoint_table[c].dtype)]
    if not metric_cols:
        print("  Warning: no numeric metric columns in checkpoint table, skipping correlations")
        return skipped

    # Spearman correlations, strongest (by |rho|) first
    corr_table = analysis.spearman_correlation_table(
        checkpoint_table, metric_cols, math_col
    )
    corr_table = corr_table.sort_values("rho", key=abs, ascending=False)
    print(f"  Computed {len(corr_table)} correlations")

    # Simple regressions for top correlated metrics
    regressions = []
    for m in corr_table.head(5)["metric"].tolist():
        regressions.append(analysis.ols_regression(checkpoint_table, m, math_col))

    # Length confound check
    length_confounded = analysis.length_correlation_check(
        checkpoint_table, metric_cols
    )
    if length_confounded:
        print(f"  Warning: metrics highly correlated with trace length: {length_confounded}")

    return {
        "correlation_table": corr_table,
        "regressions": regressions,
        "length_confounded": length_confounded,
    }


# ---------------------------------------------------------------------------
# Stage: report — Generate plots and markdown report
# ---------------------------------------------------------------------------

def stage_report(
    config: dict,
    checkpoint_table: pd.DataFrame,
    checkpoint_metrics: dict,
    correlation_results: dict,
    output_dir: Path,
):
    """Render plots, run sanity checks and write the markdown report.

    Args:
        config: Pipeline config; ``name`` is used as the report title key.
        checkpoint_table: Checkpoint-level metric table.
        checkpoint_metrics: ``{ckpt_id: {metric: value}}`` passed to plotting.
        correlation_results: Output of ``stage_correlate`` (its
            ``correlation_table`` and ``regressions`` entries are used).
        output_dir: Directory handed to the plotting and report helpers.
    """
    print("\n=== Stage: REPORT ===")

    plotting.generate_all_plots(
        checkpoint_table=checkpoint_table,
        checkpoint_metrics=checkpoint_metrics,
        output_dir=output_dir,
    )

    # Sanity check: share of the OTHER rate among all per-1k primitive rates
    sanity = {}
    other_col = "OTHER_per_1k_mean"
    if not checkpoint_table.empty and other_col in checkpoint_table.columns:
        other_rate = checkpoint_table[other_col].mean()
        rate_cols = [
            f"{prim}_per_1k_mean"
            for prim in primitive_classification.PRIMITIVES
            if f"{prim}_per_1k_mean" in checkpoint_table.columns
        ]
        total_rate = sum(checkpoint_table[col].mean() for col in rate_cols)
        if total_rate > 0:
            other_share = other_rate / total_rate * 100
            suffix = " (patterns may be too narrow)" if other_share >= 80 else ""
            sanity["primitive_distribution"] = {
                "ok": other_share < 80,
                "message": f"OTHER is {other_share:.1f}% of all primitives" + suffix,
            }

    report.generate_report(
        output_dir=output_dir,
        checkpoint_table=checkpoint_table,
        correlation_table=correlation_results.get("correlation_table"),
        regression_results=correlation_results.get("regressions"),
        sanity_checks=sanity,
        config_name=config.get("name", "exploration_analysis"),
    )


# ---------------------------------------------------------------------------
# Helper: save segment outputs immediately (called right after segment stage)
# ---------------------------------------------------------------------------

def _save_segment_outputs(output_dir: Path, segment_results: dict):
    """Write trace- and prompt-level tables (parquet + CSV) into ``output_dir``.

    NumPy-array fields are dropped from trace records and
    ``primitive_sequence`` is JSON-encoded so each row is flat. A table is
    skipped when its record list is empty or missing. This function does not
    create ``output_dir``.
    """
    outputs = (
        ("trace_level_metrics", "trace_records"),
        ("prompt_level_metrics", "prompt_records"),
    )
    for stem, results_key in outputs:
        records = segment_results.get(results_key, [])
        if not records:
            continue
        if results_key == "trace_records":
            flattened = []
            for record in records:
                row = {key: val for key, val in record.items() if not isinstance(val, np.ndarray)}
                if "primitive_sequence" in row:
                    row["primitive_sequence"] = json.dumps(row["primitive_sequence"])
                flattened.append(row)
            records = flattened
        frame = pd.DataFrame(records)
        frame.to_parquet(output_dir / f"{stem}.parquet", index=False)
        frame.to_csv(output_dir / f"{stem}.csv", index=False)
        print(f"  Saved {stem}: {frame.shape}")


# ---------------------------------------------------------------------------
# Stage: save — Save all intermediate outputs
# ---------------------------------------------------------------------------

def stage_save(
    output_dir: Path,
    segment_results: dict,
    cluster_results: dict,
    novelty_results: dict,
    agg_results: dict,
    correlation_results: dict,
):
    """Save all intermediate and final outputs into ``output_dir``.

    Creates ``output_dir`` if needed, then writes each non-empty result:
    trace/prompt-level metrics (parquet + CSV), ``diversity_metrics.parquet``,
    ``novelty_metrics.parquet``, ``checkpoint_metrics`` (parquet + CSV, index
    not preserved) and ``correlation_table.csv``.
    """
    print("\n=== Stage: SAVE ===")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Trace- and prompt-level metrics: reuse the segment-stage writer so both
    # code paths produce identical files (this was previously a verbatim copy).
    _save_segment_outputs(output_dir, segment_results)

    # Diversity
    div_records = cluster_results.get("diversity_records", [])
    if div_records:
        df = pd.DataFrame(div_records)
        df.to_parquet(output_dir / "diversity_metrics.parquet", index=False)
        print(f"  Saved diversity_metrics: {df.shape}")

    # Novelty
    nov_records = novelty_results.get("novelty_records", [])
    if nov_records:
        df = pd.DataFrame(nov_records)
        df.to_parquet(output_dir / "novelty_metrics.parquet", index=False)
        print(f"  Saved novelty_metrics: {df.shape}")

    # Checkpoint table
    table = agg_results.get("checkpoint_table")
    if table is not None and not table.empty:
        table.to_parquet(output_dir / "checkpoint_metrics.parquet", index=False)
        table.to_csv(output_dir / "checkpoint_metrics.csv", index=False)
        print(f"  Saved checkpoint_metrics: {table.shape}")

    # Correlations
    corr = correlation_results.get("correlation_table")
    if corr is not None and not corr.empty:
        corr.to_csv(output_dir / "correlation_table.csv", index=False)
        print(f"  Saved correlation_table: {corr.shape}")


# ---------------------------------------------------------------------------
# Main pipeline
# ---------------------------------------------------------------------------

# Stage names accepted by --stage, listed in the order a full run executes them.
STAGES = ["load", "segment", "cluster", "novelty", "aggregate", "correlate", "report"]


def run_pipeline(config_path: str, stage: str | None = None, classifier_type: str = "heuristic", max_traces: int | None = None):
    """Run the full pipeline or a single stage.

    Args:
        config_path: Path to the YAML config file.
        stage: One of ``STAGES`` to run on its own, or ``None`` to run all.
        classifier_type: Primitive classifier, "heuristic" or "learned".
        max_traces: Optional per-prompt trace cap for smoke tests. It is
            applied every time samples are loaded, whichever stage triggers
            the load.

    A standalone ``correlate`` or ``report`` run does not recompute upstream
    stages. It reuses ``checkpoint_metrics.parquet`` (and, for ``report``,
    ``correlation_table.csv``) written to the output dir by an earlier run.
    """
    config = load_config(config_path)
    output_dir = resolve_path(config.get("output_dir", "results/exploration_analysis"))
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Pipeline: {config.get('name', 'unnamed')}")
    print(f"Output: {output_dir}")
    print(f"Classifier: {classifier_type}")

    if stage and stage not in STAGES:
        print(f"Unknown stage: {stage}. Available: {STAGES}")
        sys.exit(1)

    run_all = stage is None

    def load_samples() -> dict:
        # Every load goes through here so --max-traces is honored regardless of
        # which stage triggers loading (previously only the load/full-run path capped).
        data = stage_load(config)
        if max_traces is not None:
            for tasks in data.values():
                for prompt_samples_list in tasks.values():
                    for ps in prompt_samples_list:
                        ps.traces = ps.traces[:max_traces]
            print(f"  [smoke test] Capped at {max_traces} traces per prompt")
        return data

    def reuse_saved(filename: str) -> pd.DataFrame:
        # Read an artifact stage_save wrote in an earlier run; empty frame if absent.
        # NOTE: stage_save writes with index=False, so any index is not restored.
        path = output_dir / filename
        if not path.exists():
            print(f"  Warning: {path} not found; run the full pipeline first to create it")
            return pd.DataFrame()
        print(f"  Reusing saved {path}")
        return pd.read_parquet(path) if path.suffix == ".parquet" else pd.read_csv(path)

    # Load
    all_data = {}
    if run_all or stage == "load":
        all_data = load_samples()

    # Segment + Classify (or load from cache)
    segment_results = {"trace_records": [], "prompt_records": []}
    cached_dirs = config.get("cached_segment_dirs")
    if cached_dirs:
        print(f"\n=== Loading cached segment results from {len(cached_dirs)} dirs ===")
        all_trace = []
        all_prompt = []
        for d in cached_dirs:
            p = resolve_path(d)
            tf = p / "trace_level_metrics.parquet"
            pf = p / "prompt_level_metrics.parquet"
            if tf.exists():
                df = pd.read_parquet(tf)
                recs = df.to_dict("records")
                # Deserialize primitive_sequence back to list
                for r in recs:
                    if "primitive_sequence" in r and isinstance(r["primitive_sequence"], str):
                        r["primitive_sequence"] = json.loads(r["primitive_sequence"])
                all_trace.extend(recs)
                print(f"  {d}: {len(recs)} traces")
            if pf.exists():
                all_prompt.extend(pd.read_parquet(pf).to_dict("records"))
        segment_results = {"trace_records": all_trace, "prompt_records": all_prompt}
        print(f"  Total: {len(all_trace)} traces, {len(all_prompt)} prompts")
    elif run_all or stage == "segment":
        if not all_data:
            all_data = load_samples()
        segment_results = stage_segment_classify(all_data, config, classifier_type, max_traces=max_traces)
        # Always save trace/prompt outputs immediately after segment
        _save_segment_outputs(output_dir, segment_results)

    # Cluster
    cluster_results = {"diversity_records": []}
    if run_all or stage == "cluster":
        if not all_data:
            all_data = load_samples()
        if not segment_results["trace_records"]:
            if cached_dirs:
                print("  Warning: cached segment results empty, cannot cluster")
            else:
                segment_results = stage_segment_classify(
                    all_data, config, classifier_type, max_traces=max_traces
                )
        cluster_results = stage_cluster(all_data, segment_results, config)

    # Novelty
    novelty_results = {"novelty_records": []}
    if run_all or stage == "novelty":
        if not all_data:
            all_data = load_samples()
        novelty_results = stage_novelty(all_data, config)

    # Aggregate. Run standalone, this only has inputs when cached_dirs is set,
    # since cluster/novelty results are only produced in this same invocation.
    agg_results = {"checkpoint_metrics": {}, "checkpoint_table": pd.DataFrame()}
    if run_all or stage == "aggregate":
        agg_results = stage_aggregate(config, segment_results, cluster_results, novelty_results)
    elif stage in ("correlate", "report"):
        agg_results["checkpoint_table"] = reuse_saved("checkpoint_metrics.parquet")

    # Correlate
    correlation_results = {"correlation_table": pd.DataFrame(), "regressions": []}
    if run_all or stage == "correlate":
        table = agg_results.get("checkpoint_table", pd.DataFrame())
        if not table.empty:
            correlation_results = stage_correlate(table, config)
            if stage == "correlate":
                # Standalone run: persist the result (a full run saves it in stage_save).
                corr = correlation_results.get("correlation_table")
                if corr is not None and not corr.empty:
                    corr.to_csv(output_dir / "correlation_table.csv", index=False)
                    print(f"  Saved correlation_table: {corr.shape}")
    elif stage == "report":
        correlation_results["correlation_table"] = reuse_saved("correlation_table.csv")

    # Save outputs
    if run_all or stage == "report":
        stage_save(output_dir, segment_results, cluster_results, novelty_results,
                   agg_results, correlation_results)
        stage_report(
            config=config,
            checkpoint_table=agg_results.get("checkpoint_table", pd.DataFrame()),
            checkpoint_metrics=agg_results.get("checkpoint_metrics", {}),
            correlation_results=correlation_results,
            output_dir=output_dir,
        )

    print(f"\n=== Pipeline complete. Output: {output_dir} ===")


def main():
    """Command-line entry point: parse arguments and run the pipeline."""
    parser = argparse.ArgumentParser(
        description="Exploration analysis pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    add = parser.add_argument
    add("--config", required=True, help="Path to YAML config file")
    add("--stage", default=None, choices=STAGES,
        help="Run a specific stage only (default: all)")
    add("--classifier", default="heuristic", choices=["heuristic", "learned"],
        help="Primitive classifier: heuristic (regex) or learned (BERT ensemble)")
    add("--max-traces", type=int, default=None,
        help="Cap traces per prompt (for smoke tests)")
    opts = parser.parse_args()
    run_pipeline(
        opts.config,
        stage=opts.stage,
        classifier_type=opts.classifier,
        max_traces=opts.max_traces,
    )


if __name__ == "__main__":
    # Script entry point (see the `python -m` usage in the module docstring).
    main()
