#!/usr/bin/env python3
"""
Export Paper Artifacts

Generates all tables and figures for the KR paper, organized into:
- paper/auto/tables/ (main body tables)
- paper/auto/figures/ (main body figures)
- paper/auto/appendix/ (appendix tables/figures)
- paper/auto/manifest.json (provenance tracking)

Usage (with pre-generated eval records):
    python -m concept_synth.analysis.export_paper_artifacts \
        --fo-records artifacts/analysis/v1/fo/eval_records.jsonl \
        --ci-records artifacts/analysis/v1/ci/eval_records.jsonl \
        --ec-records artifacts/analysis/v1/ec/eval_records.jsonl \
        --out paper/auto

Usage (with EC dataset YAML - evaluates directly):
    python -m concept_synth.analysis.export_paper_artifacts \
        --fo-records artifacts/analysis/v1/fo/eval_records.jsonl \
        --ci-records artifacts/analysis/v1/ci/eval_records.jsonl \
        --ec-dataset datasets/v1/e_benchmark_v1b.yaml \
        --out paper/auto
"""

import argparse
import json
import os
import shutil
import subprocess
import sys
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

# Bootstrap path
# Try the normal package import first. If ``concept_synth`` is not importable
# (e.g. the file is executed directly), walk upwards from this file until a
# directory literally named "concept_synth" is found and put the directory
# that contains it on sys.path, then retry the import.
try:
    from concept_synth.bootstrap import add_repo_root
except ModuleNotFoundError:
    _path = os.path.abspath(__file__)
    while True:
        parent = os.path.dirname(_path)
        if os.path.basename(_path) == "concept_synth":
            # ``parent`` is the directory holding the package; add it once.
            if parent not in sys.path:
                sys.path.insert(0, parent)
            break
        if parent == _path:
            # dirname() stopped changing: filesystem root reached, give up.
            break
        _path = parent
    # Retry; still raises ModuleNotFoundError if the walk found nothing usable.
    from concept_synth.bootstrap import add_repo_root
add_repo_root(__file__)

# pandas is treated as optional at import time: HAS_PANDAS records whether the
# import succeeded. The flag is not read anywhere in the visible part of this
# file; presumably it is checked further down -- TODO confirm.
try:
    import pandas as pd

    HAS_PANDAS = True
except ImportError:
    HAS_PANDAS = False

from concept_synth.analysis.schema import load_records
from concept_synth.analysis.budget_curves import plot_budget_curves
from concept_synth.analysis.export_generator_spec_table import generate_world_gen_params_table

# Model display names
# Maps the model identifier found in the metrics/records to the short label
# printed in the LaTeX tables. Table generators fall back to the raw
# identifier when a model has no entry here.
MODEL_DISPLAY_NAMES = {
    "grok4": "Grok4",
    "gpt-5.2": "GPT-5.2",
    "grok4.1fast": "Grok4.1f",
    "gemini-3-pro-preview": "Gemini 3",
    "deepseek-reasoner": "DSR",
    "claude-opus-4-5-20251101": "Opus 4.5",
    "hermes4": "Hermes4",
    "gpt-4o": "GPT-4o",
}

# Canonical model ordering. The table generators build their row list by
# filtering this list, so a model that is not listed here does not appear in
# the generated tables.
MODEL_ORDER = [
    "grok4",
    "gpt-5.2",
    "grok4.1fast",
    "gemini-3-pro-preview",
    "deepseek-reasoner",
    "claude-opus-4-5-20251101",
    "hermes4",
    "gpt-4o",
]

# Short labels for formula families. Not referenced in the visible part of
# this file; presumably used by generators further down -- TODO confirm.
FAMILY_DISPLAY_NAMES = {
    "OTHER": "oth",
}


def get_git_commit() -> Optional[str]:
    """Return the first 8 characters of the current ``HEAD`` hash, if any.

    Runs ``git rev-parse HEAD`` with a 5 second timeout. This is a
    best-effort lookup: a non-zero exit status or any exception (git not
    installed, timeout, ...) yields ``None`` instead of propagating.
    """
    command = ["git", "rev-parse", "HEAD"]
    try:
        proc = subprocess.run(command, capture_output=True, text=True, timeout=5)
        short_hash = proc.stdout.strip()[:8] if proc.returncode == 0 else None
    except Exception:
        # Provenance info is optional; never let it break the export.
        short_hash = None
    return short_hash


def compute_task_metrics(df: "pd.DataFrame", task: str) -> Dict[str, Dict[str, float]]:
    """Compute per-model rate metrics for one task.

    Args:
        df: One row per evaluated instance. The columns read here are
            ``model``, ``parse_ok``, ``valid``, ``budget_ok_25``,
            ``fail_mode`` and ``ast_delta``.
        task: Task identifier; when it equals ``"ci"`` the ``yes_fail`` and
            ``no_fail`` rates are added to each model's entry.

    Returns:
        ``{model: {metric: value}}`` in order of first appearance of each
        model. Every rate uses the model's total row count as denominator;
        ``total`` holds that count.

    Notes:
        Boolean flag columns are normalised with ``.eq(True)`` so that
        object-dtype columns and missing values (``None``/``NaN``) count as
        false instead of breaking the ``~`` operator. ``ast_delta`` is coerced
        to numeric and missing deltas are treated as 0, as before.
        The ``df`` annotation is a string so that defining this function does
        not require pandas to be importable.
    """
    metrics: Dict[str, Dict[str, float]] = {}

    # sort=False keeps models in order of first appearance.
    for model, model_df in df.groupby("model", sort=False):
        total = len(model_df)
        if total == 0:
            continue

        # Normalised boolean masks (missing values -> False).
        parse_ok = model_df["parse_ok"].eq(True)
        valid = model_df["valid"].eq(True)
        budget_ok = model_df["budget_ok_25"].eq(True)
        is_missing = model_df["fail_mode"].eq("missing")

        # AST size relative to gold; unknown deltas count as 0.
        delta = pd.to_numeric(model_df["ast_delta"], errors="coerce").fillna(0)

        entry = {
            # Coverage: fraction with parsed formula
            "coverage": parse_ok.sum() / total,
            # Validity/Accuracy over all instances
            "acc_all": valid.sum() / total,
            # Budgeted accuracy @+25
            "acc_25": budget_ok.sum() / total,
            "missing": is_missing.sum() / total,
            # Unparsed but not missing, so coverage + parse_err + missing = 100%
            "parse_err": (~parse_ok & ~is_missing).sum() / total,
            # Valid with AST > gold+25
            "bloat": (valid & (delta > 25)).sum() / total,
            # Valid with AST < gold
            "compact": (valid & (delta < 0)).sum() / total,
            # Valid with delta in [0, 1]
            "equal": (valid & (delta >= 0) & (delta <= 1)).sum() / total,
            # Valid with delta in (1, 25]
            "longer": (valid & (delta > 1) & (delta <= 25)).sum() / total,
            "total": total,
        }

        # CI-specific: YES-fail and NO-fail rates
        if task == "ci":
            entry["yes_fail"] = model_df["fail_mode"].eq("yes_fail").sum() / total
            entry["no_fail"] = model_df["fail_mode"].eq("no_fail").sum() / total

        metrics[model] = entry

    return metrics


def compute_band_metrics(df: "pd.DataFrame") -> Dict[str, Dict[str, Dict[str, float]]]:
    """Compute accuracy metrics per difficulty band and model.

    Args:
        df: One row per evaluated instance. The columns read here are
            ``band``, ``model``, ``valid`` and ``budget_ok_25``.

    Returns:
        ``{band: {model: {"acc_all", "acc_25", "total"}}}`` where both
        accuracies use the number of rows for that (band, model) pair as
        denominator and ``total`` is that row count.

    Notes:
        The ``df`` annotation is a string so that defining this function does
        not require pandas to be importable. Flag columns are normalised with
        ``.eq(True)`` so missing values count as false.
    """
    result: Dict[str, Dict[str, Dict[str, float]]] = {}

    for band in df["band"].unique():
        band_df = df[df["band"] == band]
        per_model: Dict[str, Dict[str, float]] = {}
        result[band] = per_model

        for model in band_df["model"].unique():
            model_df = band_df[band_df["model"] == model]
            total = len(model_df)
            if total == 0:
                continue

            per_model[model] = {
                "acc_all": model_df["valid"].eq(True).sum() / total,
                "acc_25": model_df["budget_ok_25"].eq(True).sum() / total,
                "total": total,
            }

    return result


def fmt_pct(val: float, bold: bool = False) -> str:
    """Render a fraction as a LaTeX percentage with one decimal place.

    ``val`` is multiplied by 100 and followed by an escaped ``%``; when
    ``bold`` is true the result is wrapped in ``\\textbf{...}``.
    """
    text = "{:.1f}\\%".format(val * 100)
    if bold:
        return "\\textbf{" + text + "}"
    return text


def find_best(metrics: Dict[str, Dict[str, float]], key: str, higher_better: bool = True) -> float:
    """Return the best value of ``key`` over all models.

    "Best" is the maximum when ``higher_better`` is true, otherwise the
    minimum. Models lacking ``key`` contribute the worst possible value
    (``-inf`` or ``+inf`` respectively). An empty ``metrics`` mapping gives 0.
    """
    if not metrics:
        return 0
    if higher_better:
        return max(entry.get(key, float("-inf")) for entry in metrics.values())
    return min(entry.get(key, float("inf")) for entry in metrics.values())


def is_column_all_zero(metrics: Dict[str, Dict[str, float]], key: str, threshold: float = 1e-9) -> bool:
    """Tell whether metric ``key`` is (numerically) zero for every model.

    A value counts as non-zero when its magnitude exceeds ``threshold``;
    models without ``key`` are treated as 0.
    """
    return not any(abs(entry.get(key, 0)) > threshold for entry in metrics.values())


# =============================================================================
# TABLE GENERATORS
# =============================================================================


def _summary_task_cells(
    metrics: Dict[str, Dict[str, float]], model: str, best_acc: float, best_25: float
) -> List[str]:
    """Format one task's four summary cells (Acc/Valid, @+25, Cov, Parse) for a model."""
    if model not in metrics:
        return ["---"] * 4
    m = metrics[model]
    return [
        fmt_pct(m["acc_all"], abs(m["acc_all"] - best_acc) < 0.001),
        fmt_pct(m["acc_25"], abs(m["acc_25"] - best_25) < 0.001),
        fmt_pct(m["coverage"], False),
        fmt_pct(m["parse_err"], False),
    ]


def generate_across_task_summary(
    fo_metrics: Dict[str, Dict[str, float]],
    ci_metrics: Dict[str, Dict[str, float]],
    ec_metrics: Dict[str, Dict[str, float]],
    output_path: Path,
) -> None:
    """Generate the across-task summary table (Table 1).

    Uses table* for full-width spanning both columns.
    Shows Acc, @+25, Cov (coverage), and Parse (parse error rate) per task.

    Args:
        fo_metrics: Per-model metrics for FullObs (keys ``acc_all``,
            ``acc_25``, ``coverage``, ``parse_err`` are read).
        ci_metrics: Same, for CI.
        ec_metrics: Same, for EC (``acc_all`` is printed as "Valid").
        output_path: Destination ``.tex`` file; parent directories are
            created as needed.

    Rows follow ``MODEL_ORDER``; a model missing from a task gets ``---`` in
    that task's columns. The best Acc and @+25 per task are set in bold.
    The problem counts in the column header are fixed text, not derived from
    the metrics.
    """
    # One (metrics, best Acc, best @+25) triple per task, in column order.
    tasks = [
        (task_metrics, find_best(task_metrics, "acc_all"), find_best(task_metrics, "acc_25"))
        for task_metrics in (fo_metrics, ci_metrics, ec_metrics)
    ]

    all_models = set(fo_metrics) | set(ci_metrics) | set(ec_metrics)
    models = [m for m in MODEL_ORDER if m in all_models]

    lines = [
        "% Across-task summary table (auto-generated)",
        "% Uses table* for full-width to fit all columns",
        "\\begin{table*}[t]",
        "\\centering",
        "\\caption{\\textbf{Across-task v1 summary (snapshot).}",
        "\\textsc{FullObs} (full observation), \\textsc{CI} (contrastive induction), and \\textsc{EC} (existential completion).",
        "\\emph{Acc\\_all} (exact-match accuracy) with denominator=\\emph{all} instances",
        "(missing or unparsable outputs count as incorrect).",
        "\\textsc{EC} reports \\emph{Validity} (unbounded existential-completion success)",
        "and budgeted accuracy Acc@gold$+25$.",
        "Cov = coverage (fraction with parseable formula); Parse = parse error rate.}",
        "\\label{tab:summary_across_tasks}",
        "\\footnotesize",
        "\\begin{tabular}{@{}l|rrrr|rrrr|rrrr@{}}",
        "\\toprule",
        " & \\multicolumn{4}{c|}{FullObs (375)} & \\multicolumn{4}{c|}{CI (200)} & \\multicolumn{4}{c}{EC (200)} \\\\",
        "Model & Acc & @+25 & Cov & Parse & Acc & @+25 & Cov & Parse & Valid & @+25 & Cov & Parse \\\\",
        "\\midrule",
    ]

    for model in models:
        cells = [MODEL_DISPLAY_NAMES.get(model, model)]
        for task_metrics, best_acc, best_25 in tasks:
            cells.extend(_summary_task_cells(task_metrics, model, best_acc, best_25))
        lines.append(" & ".join(cells) + " \\\\")

    lines.append("\\bottomrule")
    lines.append("\\end{tabular}")
    lines.append("\\end{table*}")  # table* for full-width

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        f.write("\n".join(lines))
    print(f"Generated {output_path}")


def generate_fo_overall_table(metrics: Dict[str, Dict[str, float]], output_path: Path) -> None:
    """Generate FullObs overall table (Table 1) with coverage.

    Args:
        metrics: Per-model metrics. ``acc_25``, ``acc_all`` and ``coverage``
            are required. ``acc_10`` and ``acc_0`` are optional; when a model
            has no value for them the cell shows ``---``.
        output_path: Destination ``.tex`` file; parent directories are
            created as needed.

    Rows are the ``MODEL_ORDER`` models present in ``metrics``, sorted by
    ``acc_25`` (descending). The best value of each column is set in bold.

    NOTE(review): the @+10 / @+0 columns used to be filled with scaled copies
    of ``acc_25`` (placeholders). They are now only populated from real
    ``acc_10`` / ``acc_0`` values -- confirm the upstream metrics step
    provides those keys under these names.
    """
    models = sorted(
        (m for m in MODEL_ORDER if m in metrics),
        key=lambda m: metrics[m]["acc_25"],
        reverse=True,
    )

    optional_keys = ("acc_10", "acc_0")
    best = {
        key: find_best(metrics, key)
        for key in ("acc_25", "acc_10", "acc_0", "acc_all", "coverage")
    }

    def cell(row: Dict[str, float], key: str) -> str:
        """Format one cell, bolding the column best; '---' for absent optional data."""
        value = row.get(key) if key in optional_keys else row[key]
        if value is None:
            return "---"
        return fmt_pct(value, abs(value - best[key]) < 0.001)

    lines = [
        "% FullObs v1 overall accuracy table (auto-generated)",
        "\\begin{table}[t]",
        "\\centering",
        "\\caption{\\textbf{FullObs v1 overall accuracy} (sorted by Acc@+25). Bold = best per column.",
        "All metrics use denominator = 375 (total problems).}",
        "\\label{tab:fo_overall}",
        "\\small",
        "\\begin{tabular}{@{}lrrrrr@{}}",
        "\\toprule",
        "Model & @+25 & @+10 & @+0 & Acc\\_all & Cov \\\\",
        "\\midrule",
    ]

    for model in models:
        m = metrics[model]
        display = MODEL_DISPLAY_NAMES.get(model, model)
        cells = [cell(m, key) for key in ("acc_25", "acc_10", "acc_0", "acc_all", "coverage")]
        lines.append(f"{display} & {' & '.join(cells)} \\\\")

    lines.append("\\bottomrule")
    lines.append("\\end{tabular}")
    lines.append("\\end{table}")

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        f.write("\n".join(lines))
    print(f"Generated {output_path}")


def generate_fo_band_table(
    band_metrics: Dict[str, Dict[str, Dict[str, float]]], output_path: Path
) -> None:
    """Write the FullObs band-wise accuracy table (models as rows, bands as columns).

    Args:
        band_metrics: ``{band: {model: {"acc_all", "total", ...}}}``.
        output_path: Destination ``.tex`` file; parent directories are
            created as needed.

    The two ``extreme_*`` bands are pooled into one "extreme" column using a
    total-weighted mean of ``acc_all``. Per column, the best non-zero value
    is bolded; missing (band, model) pairs are shown as ``---``.
    """
    source_bands = ["simple", "easy", "medium", "hard", "extreme_logic", "extreme_context"]
    columns = ["simple", "easy", "medium", "hard", "extreme"]

    # Rows: every model seen in any band, in canonical order.
    seen = set()
    for per_model in band_metrics.values():
        seen.update(per_model.keys())
    row_models = [name for name in MODEL_ORDER if name in seen]

    # Column data: plain bands are passed through, extreme bands are pooled.
    merged = {}
    pooled = None
    for band in source_bands:
        if band not in band_metrics:
            continue
        if not band.startswith("extreme"):
            merged[band] = band_metrics[band]
            continue
        if pooled is None:
            pooled = {}
            merged["extreme"] = pooled
        for name, stats in band_metrics[band].items():
            slot = pooled.setdefault(name, {"acc_all": 0, "total": 0})
            slot["acc_all"] += stats["acc_all"] * stats["total"]
            slot["total"] += stats["total"]

    # Turn the pooled weighted sums back into accuracies.
    if pooled is not None:
        for slot in pooled.values():
            if slot["total"] > 0:
                slot["acc_all"] /= slot["total"]

    # Best accuracy per column.
    column_best = {
        label: max(
            (merged[label].get(name, {}).get("acc_all", 0) for name in row_models), default=0
        )
        for label in columns
        if label in merged
    }

    out = [
        "% FullObs v1 band-wise accuracy (auto-generated)",
        "\\begin{table}[t]",
        "\\centering",
        "\\caption{\\textbf{FullObs v1 band-wise accuracy} (Acc\\_all, denominator = total problems per",
        "band). Extremes are aggregated.}",
        "\\label{tab:fo_bands}",
        "\\small",
        "\\begin{tabular}{@{}l" + "r" * len(columns) + "@{}}",
        "\\toprule",
        "Model & " + " & ".join(columns) + " \\\\",
        "\\midrule",
    ]

    for name in row_models:
        cells = []
        for label in columns:
            if label not in merged or name not in merged[label]:
                cells.append("---")
                continue
            acc = merged[label][name]["acc_all"]
            is_best = abs(acc - column_best.get(label, 0)) < 0.001 and acc > 0
            cells.append(fmt_pct(acc, is_best))
        out.append(MODEL_DISPLAY_NAMES.get(name, name) + " & " + " & ".join(cells) + " \\\\")

    out += ["\\bottomrule", "\\end{tabular}", "\\end{table}"]

    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text("\n".join(out))
    print(f"Generated {output_path}")

def generate_ci_overall_table(metrics: Dict[str, Dict[str, float]], output_path: Path) -> None:
    """Write the CI overall accuracy table (Table 5).

    Rows are the ``MODEL_ORDER`` models present in ``metrics``, sorted by
    ``acc_25`` descending. The best @+25, Acc_all and Cov values are bolded;
    Bloat is never bolded. Parent directories of ``output_path`` are created.
    """
    ranked = sorted(
        (name for name in MODEL_ORDER if name in metrics),
        key=lambda name: metrics[name]["acc_25"],
        reverse=True,
    )

    top_25 = find_best(metrics, "acc_25")
    top_acc = find_best(metrics, "acc_all")
    top_cov = find_best(metrics, "coverage")

    def cell(value, reference):
        # Bold when within 0.001 of the column's best value.
        return fmt_pct(value, abs(value - reference) < 0.001)

    out = [
        "% CI v1 overall accuracy table (auto-generated)",
        "\\begin{table}[t]",
        "\\centering",
        "\\caption{\\textbf{CI v1 overall accuracy} (sorted by Acc@+25). Bold = best per column.",
        "All metrics use denominator = 200 (total problems).}",
        "\\label{tab:ci_overall}",
        "\\small",
        "\\begin{tabular}{@{}lrrrr@{}}",
        "\\toprule",
        "Model & @+25 & Acc\\_all & Cov & Bloat \\\\",
        "\\midrule",
    ]

    for name in ranked:
        stats = metrics[name]
        cells = [
            MODEL_DISPLAY_NAMES.get(name, name),
            cell(stats["acc_25"], top_25),
            cell(stats["acc_all"], top_acc),
            cell(stats["coverage"], top_cov),
            fmt_pct(stats["bloat"], False),
        ]
        out.append(" & ".join(cells) + " \\\\")

    out += ["\\bottomrule", "\\end{tabular}", "\\end{table}"]

    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text("\n".join(out))
    print(f"Generated {output_path}")


def generate_ci_failure_table(metrics: Dict[str, Dict[str, float]], output_path: Path) -> None:
    """Write the CI failure decomposition table (Table 7).

    Rows are the ``MODEL_ORDER`` models present in ``metrics``, sorted by
    ``acc_all`` descending. "Correct" bolds the highest value; the four
    failure columns bold the lowest. ``yes_fail`` / ``no_fail`` default to 0
    for a model that lacks them. Parent directories of ``output_path`` are
    created.
    """
    ranked = sorted(
        (name for name in MODEL_ORDER if name in metrics),
        key=lambda name: metrics[name]["acc_all"],
        reverse=True,
    )

    top_correct = find_best(metrics, "acc_all")
    # For failure rates, lower is better.
    low_yes = find_best(metrics, "yes_fail", higher_better=False)
    low_no = find_best(metrics, "no_fail", higher_better=False)
    low_parse = find_best(metrics, "parse_err", higher_better=False)
    low_missing = find_best(metrics, "missing", higher_better=False)

    def cell(value, reference):
        # Bold when within 0.001 of the column's best value.
        return fmt_pct(value, abs(value - reference) < 0.001)

    out = [
        "% CI v1 failure decomposition table (auto-generated)",
        "\\begin{table}[t]",
        "\\centering",
        "\\caption{\\textbf{CI v1 failure mode decomposition.} YES-fail = formula doesn't match",
        "positive examples; NO-fail = formula accidentally matches negative examples;",
        "Parse = output returned but formula extraction failed.}",
        "\\label{tab:ci_failure}",
        "\\small",
        "\\begin{tabular}{@{}lrrrrr@{}}",
        "\\toprule",
        "Model & Correct & YES-fail & NO-fail & Parse & Missing \\\\",
        "\\midrule",
    ]

    for name in ranked:
        stats = metrics[name]
        cells = [
            MODEL_DISPLAY_NAMES.get(name, name),
            cell(stats["acc_all"], top_correct),
            cell(stats.get("yes_fail", 0), low_yes),
            cell(stats.get("no_fail", 0), low_no),
            cell(stats["parse_err"], low_parse),
            cell(stats["missing"], low_missing),
        ]
        out.append(" & ".join(cells) + " \\\\")

    out += ["\\bottomrule", "\\end{tabular}", "\\end{table}"]

    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text("\n".join(out))
    print(f"Generated {output_path}")


def generate_ec_overall_table(metrics: Dict[str, Dict[str, float]], output_path: Path) -> None:
    """Write the EC overall accuracy table (Table 9).

    Rows are the ``MODEL_ORDER`` models present in ``metrics``, sorted by
    ``acc_25`` descending. ``acc_all`` is printed in the "Valid" column. The
    best @+25, Valid and Cov values are bolded; Bloat is never bolded.
    Parent directories of ``output_path`` are created.
    """
    ranked = sorted(
        (name for name in MODEL_ORDER if name in metrics),
        key=lambda name: metrics[name]["acc_25"],
        reverse=True,
    )

    top_25 = find_best(metrics, "acc_25")
    top_valid = find_best(metrics, "acc_all")
    top_cov = find_best(metrics, "coverage")

    def cell(value, reference):
        # Bold when within 0.001 of the column's best value.
        return fmt_pct(value, abs(value - reference) < 0.001)

    out = [
        "% EC v1 overall accuracy table (auto-generated)",
        "\\begin{table}[t]",
        "\\centering",
        "\\caption{\\textbf{EC v1 overall accuracy} (sorted by Acc@+25). Bold = best per column.}",
        "\\label{tab:ec_overall}",
        "\\small",
        "\\begin{tabular}{@{}lrrrr@{}}",
        "\\toprule",
        "Model & @+25 & Valid & Cov & Bloat \\\\",
        "\\midrule",
    ]

    for name in ranked:
        stats = metrics[name]
        cells = [
            MODEL_DISPLAY_NAMES.get(name, name),
            cell(stats["acc_25"], top_25),
            cell(stats["acc_all"], top_valid),
            cell(stats["coverage"], top_cov),
            fmt_pct(stats["bloat"], False),
        ]
        out.append(" & ".join(cells) + " \\\\")

    out += ["\\bottomrule", "\\end{tabular}", "\\end{table}"]

    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text("\n".join(out))
    print(f"Generated {output_path}")


def generate_ec_band_table(
    band_metrics: Dict[str, Dict[str, Dict[str, float]]], output_path: Path
) -> None:
    """Write the EC band table comparing the "core" and "hard" bands.

    Args:
        band_metrics: ``{band: {model: {"acc_25", "acc_all", ...}}}``; only
            the ``core`` and ``hard`` bands are printed.
        output_path: Destination ``.tex`` file; parent directories are
            created as needed.

    Rows are ``MODEL_ORDER`` models found in any band, sorted by the mean of
    their available core/hard ``acc_25`` values (descending). Each of the
    four columns bolds its highest value; missing pairs show ``---``.
    """
    shown_bands = ("core", "hard")

    # Candidate rows: every model seen in any band, in canonical order.
    seen = set()
    for per_model in band_metrics.values():
        seen.update(per_model.keys())

    def mean_acc_25(name):
        # Average @+25 over whichever of core/hard has data for this model.
        found = [
            band_metrics[band][name]["acc_25"]
            for band in shown_bands
            if band in band_metrics and name in band_metrics[band]
        ]
        return sum(found) / len(found) if found else 0

    ranked = sorted((name for name in MODEL_ORDER if name in seen), key=mean_acc_25, reverse=True)

    def column_best(band, key):
        per_model = band_metrics.get(band, {})
        return max((per_model.get(name, {}).get(key, 0) for name in ranked), default=0)

    # (best @+25, best Valid) per printed band.
    bests = {band: (column_best(band, "acc_25"), column_best(band, "acc_all")) for band in shown_bands}

    def band_cells(band, name):
        if band not in band_metrics or name not in band_metrics[band]:
            return ["---", "---"]
        stats = band_metrics[band][name]
        best_25, best_val = bests[band]
        return [
            fmt_pct(stats["acc_25"], abs(stats["acc_25"] - best_25) < 0.001),
            fmt_pct(stats["acc_all"], abs(stats["acc_all"] - best_val) < 0.001),
        ]

    out = [
        "% EC v1 band table (auto-generated)",
        "\\begin{table}[t]",
        "\\centering",
        "\\caption{\\textbf{EC v1 accuracy by band.} Core = QD=1 (120 problems); Hard = QD=2 (80 problems).}",
        "\\label{tab:ec_bands}",
        "\\small",
        "\\begin{tabular}{@{}lrrrr@{}}",
        "\\toprule",
        "Model & Core @+25 & Core Val & Hard @+25 & Hard Val \\\\",
        "\\midrule",
    ]

    for name in ranked:
        cells = [MODEL_DISPLAY_NAMES.get(name, name)]
        for band in shown_bands:
            cells.extend(band_cells(band, name))
        out.append(" & ".join(cells) + " \\\\")

    out += ["\\bottomrule", "\\end{tabular}", "\\end{table}"]

    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text("\n".join(out))
    print(f"Generated {output_path}")


# =============================================================================
# GENERALIZATION AND ERROR PROFILE TABLES (Analysis A and B)
# =============================================================================


def _holdout_pct(bucket: Dict[str, Any]) -> Optional[float]:
    """Return a bucket's ``holdout_exact`` as a percentage, or None if unavailable."""
    if not bucket:
        return None
    value = bucket.get("holdout_exact", 0)
    # A bucket that exists but carries no rate is treated like a missing bucket.
    return None if value is None else value * 100


def _write_generalization_table(
    holdout_data: Dict[str, Any], output_path: Path, task_tag: str, preamble: List[str]
) -> None:
    """Write a near-gold vs above-gold generalization table shared by FO and CI.

    ``preamble`` holds the task-specific lines (header comments through the
    ``\\label``); the column layout and row logic are common to both tasks.
    """
    aggregates = holdout_data.get("aggregates", {})
    if not aggregates:
        print(f"  Warning: No holdout aggregates for {task_tag} generalization table")
        return

    # Sort models by train_correct_count (descending) - performance order
    models = sorted(
        [m for m in MODEL_ORDER if m in aggregates],
        key=lambda m: aggregates[m].get("train_correct_count", 0),
        reverse=True,
    )

    # Collect values; only models with both buckets get a row.
    model_values = {}
    for model in models:
        agg = aggregates[model]
        buckets = agg.get("holdout_by_ast_ratio", {}).get("compact_vs_bloated", {})
        compact_gen = _holdout_pct(buckets.get("compact", {}))
        bloated_gen = _holdout_pct(buckets.get("bloated", {}))
        if compact_gen is None or bloated_gen is None:
            continue
        model_values[model] = {
            "n_valid": agg.get("train_correct_count", 0),
            "compact_gen": compact_gen,
            "bloated_gen": bloated_gen,
            "delta": compact_gen - bloated_gen,
        }
    models_with_both = [m for m in models if m in model_values]

    # Column bests among the printed rows (higher is better for all three).
    best_n_valid = max((model_values[m]["n_valid"] for m in models_with_both), default=0)
    best_compact = max((model_values[m]["compact_gen"] for m in models_with_both), default=0)
    best_bloated = max((model_values[m]["bloated_gen"] for m in models_with_both), default=0)

    lines = list(preamble)
    lines.append("\\small")
    lines.append("\\begin{tabular}{@{}lrrrr@{}}")
    lines.append("\\toprule")
    lines.append(
        "Model & \\#Valid & \\shortstack{Near-Gold\\\\Gen\\%} & \\shortstack{Above-Gold\\\\Gen\\%} & $\\Delta$ \\\\"
    )
    lines.append("\\midrule")

    for model in models_with_both:
        v = model_values[model]
        display = MODEL_DISPLAY_NAMES.get(model, model)

        n_valid = v["n_valid"]
        n_valid_str = f"\\textbf{{{n_valid}}}" if n_valid == best_n_valid else str(n_valid)

        compact_str = f"{v['compact_gen']:.1f}\\%"
        if abs(v["compact_gen"] - best_compact) < 0.01:
            compact_str = f"\\textbf{{{compact_str}}}"

        bloated_str = f"{v['bloated_gen']:.1f}\\%"
        if abs(v["bloated_gen"] - best_bloated) < 0.01:
            bloated_str = f"\\textbf{{{bloated_str}}}"

        # Delta column: no bolding, explicit sign for non-negative values.
        delta_str = f"+{v['delta']:.1f}\\%" if v["delta"] >= 0 else f"{v['delta']:.1f}\\%"

        lines.append(f"{display} & {n_valid_str} & {compact_str} & {bloated_str} & {delta_str} \\\\")

    lines.append("\\bottomrule")
    lines.append("\\end{tabular}")
    lines.append("\\end{table}")

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        f.write("\n".join(lines))
    print(f"Generated {output_path}")


def generate_fo_generalization_table(holdout_data: Dict[str, Any], output_path: Path) -> None:
    """Generate FullObs generalization table: Compact vs Bloated valid formulas.

    This is the main finding: bloated valid formulas generalize worse to held-out worlds.

    Args:
        holdout_data: Mapping with an ``aggregates`` entry keyed by model;
            each aggregate supplies ``train_correct_count`` and
            ``holdout_by_ast_ratio -> compact_vs_bloated -> compact/bloated``
            buckets with a ``holdout_exact`` rate.
        output_path: Destination ``.tex`` file; parent directories are
            created as needed.

    Nothing is written (a warning is printed) when ``aggregates`` is empty.
    Models lacking either bucket are left out of the table.
    """
    preamble = [
        "% FullObs Generalization: Near-Gold vs Above-Gold (auto-generated)",
        "% Main finding: above-gold valid formulas generalize worse to held-out worlds",
        "\\begin{table}[t]",
        "\\centering",
        "\\caption{\\textbf{Held-out generalization by formula complexity (FullObs).}",
        "Near-gold: AST $\\leq$ gold+1; above-gold: AST $>$ gold+1.",
        "$\\Delta$ = near-gold $-$ above-gold.}",
        "\\label{tab:fo_generalization}",
    ]
    _write_generalization_table(holdout_data, output_path, "FO", preamble)


def generate_ci_generalization_table(holdout_data: Dict[str, Any], output_path: Path) -> None:
    """Generate CI generalization table: Compact vs Bloated valid formulas.

    For CI, we measure YES holdout exact-match rate (formula matches gold on new YES worlds).

    Args:
        holdout_data: Mapping with an ``aggregates`` entry keyed by model;
            each aggregate supplies ``train_correct_count`` and
            ``holdout_by_ast_ratio -> compact_vs_bloated -> compact/bloated``
            buckets with a ``holdout_exact`` rate.
        output_path: Destination ``.tex`` file; parent directories are
            created as needed.

    Nothing is written (a warning is printed) when ``aggregates`` is empty.
    Models lacking either bucket are left out of the table.
    """
    preamble = [
        "% CI Generalization: Near-Gold vs Above-Gold (auto-generated)",
        "% Shows YES holdout exact-match rate by formula complexity",
        "\\begin{table}[t]",
        "\\centering",
        "\\caption{\\textbf{Held-out generalization by formula complexity (CI).}",
        "For valid (train-correct) formulas, we compare YES holdout exact-match rates",
        "between \\emph{near-gold} formulas (AST $\\leq$ gold+1) and \\emph{above-gold}",
        "formulas (AST $>$ gold+1). $\\Delta$ = near-gold $-$ above-gold; positive values",
        "indicate near-gold formulas generalize better to new YES worlds.}",
        "\\label{tab:ci_generalization}",
    ]
    _write_generalization_table(holdout_data, output_path, "CI", preamble)

def aggregate_fpfn_profiles(profiles_data: Dict[str, Any]) -> Dict[str, Dict[str, float]]:
    """Average the per-instance FP/FN profiles of each model.

    Profiles are read from ``profiles_data["profiles"]`` (an absent key means
    no profiles) and grouped by their ``"model"`` field (absent -> ``""``).

    Returns:
        A dict keyed by model, in first-seen order, whose values hold:

        - ``count``: number of profiles for the model
        - ``correct``: number of those with a truthy ``"correct"`` field
        - ``mean_fp`` / ``mean_fn``: mean of ``mean_fp_rate`` /
          ``mean_fn_rate`` over all profiles, missing or ``None`` counted as 0
        - ``yes_fp`` / ``yes_fn``: mean of ``yes_mean_fp_rate`` /
          ``yes_mean_fn_rate`` over the profiles whose ``yes_mean_fp_rate``
          is not ``None``; ``None`` when there are no such profiles
        - ``no_margin``: mean of ``no_mean_margin`` over the profiles where
          it is not ``None``; ``None`` when there are no such profiles
    """

    def _empty_tally():
        return {
            "count": 0,
            "correct": 0,
            "total_fp_rate": 0.0,
            "total_fn_rate": 0.0,
            "yes_fp_rate": 0.0,
            "yes_fn_rate": 0.0,
            "yes_count": 0,
            "no_margin_sum": 0.0,
            "no_count": 0,
        }

    def _average(total, denominator):
        return total / denominator if denominator > 0 else None

    tallies: Dict[str, Dict[str, Any]] = {}
    for profile in profiles_data.get("profiles", []):
        name = profile.get("model", "")
        if name not in tallies:
            tallies[name] = _empty_tally()
        tally = tallies[name]

        tally["count"] += 1
        if profile.get("correct"):
            tally["correct"] += 1

        # Instance-level rates (already averaged over the instance's worlds).
        tally["total_fp_rate"] += profile.get("mean_fp_rate", 0) or 0
        tally["total_fn_rate"] += profile.get("mean_fn_rate", 0) or 0

        # CI-only fields: the YES-world FP rate gates both YES accumulators.
        yes_fp = profile.get("yes_mean_fp_rate")
        if yes_fp is not None:
            tally["yes_fp_rate"] += yes_fp
            tally["yes_fn_rate"] += profile.get("yes_mean_fn_rate", 0) or 0
            tally["yes_count"] += 1

        no_margin = profile.get("no_mean_margin")
        if no_margin is not None:
            tally["no_margin_sum"] += no_margin
            tally["no_count"] += 1

    return {
        name: {
            "count": tally["count"],
            "correct": tally["correct"],
            "mean_fp": tally["total_fp_rate"] / tally["count"],
            "mean_fn": tally["total_fn_rate"] / tally["count"],
            "yes_fp": _average(tally["yes_fp_rate"], tally["yes_count"]),
            "yes_fn": _average(tally["yes_fn_rate"], tally["yes_count"]),
            "no_margin": _average(tally["no_margin_sum"], tally["no_count"]),
        }
        for name, tally in tallies.items()
        if tally["count"] != 0
    }


def generate_ec_best_completion_table(
    ec_best_completion_data: Dict[str, Any], output_path: Path
) -> None:
    """Write the appendix LaTeX table for the EC best-completion analysis.

    Entries of ``ec_best_completion_data["results"]`` are grouped by their
    ``"model"`` field (entries without one are ignored).  One row is emitted
    per model that appears in ``MODEL_ORDER``, ordered by EC-valid count
    (descending), with these columns:

    - Total: number of predictions analyzed
    - Valid: count and percentage of predictions flagged ``ec_valid``
    - Mean MM: mean ``total_min_mismatches`` over the invalid predictions
      that produced a usable measurement
    - 1--2 and >=3: share of those measured invalid predictions whose
      minimum mismatch count lies in each range

    The best value of every column is typeset in bold.  If there are no
    results, a warning is printed and no file is written.
    """
    records = ec_best_completion_data.get("results", [])
    if not records:
        print("  Warning: No EC best-completion results")
        return

    # Per-model tallies: predictions seen, EC-valid ones, and the minimum
    # mismatch counts of the invalid predictions that could be measured.
    tallies: Dict[str, Dict[str, Any]] = {}
    for record in records:
        name = record.get("model", "")
        if not name:
            continue

        tally = tallies.setdefault(name, {"total": 0, "ec_valid": 0, "measured": []})
        tally["total"] += 1
        if record.get("ec_valid", False):
            tally["ec_valid"] += 1
            continue

        mismatches = record.get("total_min_mismatches", 0)
        # A record is not a usable measurement when any of its worlds reports
        # a negative min_mismatches or a solver status mentioning "error".
        errored = any(
            world.get("min_mismatches", 0) < 0
            or "error" in world.get("solver_status", "").lower()
            for world in record.get("world_results", [])
        )
        if mismatches >= 0 and not errored:
            tally["measured"].append(mismatches)

    # Rows follow EC-valid count, highest first (ties keep MODEL_ORDER order).
    ranked = [m for m in MODEL_ORDER if m in tallies]
    ranked.sort(key=lambda m: tallies[m]["ec_valid"], reverse=True)

    summary = {}
    for name in ranked:
        measured = tallies[name]["measured"]
        if measured:
            n_measured = len(measured)
            mean_mm = sum(measured) / n_measured
            pct_1_2 = 100 * sum(1 for mm in measured if 1 <= mm <= 2) / n_measured
            pct_3plus = 100 * sum(1 for mm in measured if mm >= 3) / n_measured
        else:
            mean_mm = pct_1_2 = pct_3plus = None
        summary[name] = {
            "total": tallies[name]["total"],
            "ec_valid": tallies[name]["ec_valid"],
            "mean_mm": mean_mm,
            "mm_1_2_pct": pct_1_2,
            "mm_3plus_pct": pct_3plus,
        }

    def _known(column):
        return [row[column] for row in summary.values() if row[column] is not None]

    # Column bests: Valid and 1--2 share are better when higher (closer to
    # valid); Mean MM and >=3 share are better when lower.
    best_valid = max((row["ec_valid"] for row in summary.values()), default=0)
    best_mean_mm = min(_known("mean_mm"), default=100)
    best_mm_1_2 = max(_known("mm_1_2_pct"), default=0)
    best_mm_3plus = min(_known("mm_3plus_pct"), default=100)

    def _cell(value, best, template, tolerance):
        """Format one numeric cell, bolding it when it matches the column best."""
        if value is None:
            return "---"
        text = template.format(value)
        if abs(value - best) < tolerance:
            return f"\\textbf{{{text}}}"
        return text

    lines = [
        "% EC Best-Completion Analysis (auto-generated)",
        "% Shows how close invalid predictions are to being valid",
        "\\begin{table}[h]",
        "\\centering",
        "\\caption{\\textbf{EC best-completion error analysis.}",
        "For predictions that fail the exists-completion (EC) validity check,",
        "we compute the \\emph{minimum mismatches} achievable under any completion",
        "of unknown atoms. A formula with min-mismatch=0 would be EC-valid;",
        "higher values indicate how far the formula is from validity.",
        "Mean MM = mean minimum mismatches for invalid predictions.",
        "Distribution columns show the fraction of invalid predictions in each",
        "mismatch range.}",
        "\\label{tab:ec_best_completion}",
        "\\small",
        "\\begin{tabular}{@{}lrrrrr@{}}",
        "\\toprule",
        "Model & Total & Valid & Mean MM & 1--2 & $\\geq$3 \\\\",
        "\\midrule",
    ]

    for name in ranked:
        row = summary[name]
        display = MODEL_DISPLAY_NAMES.get(name, name)
        n_total = row["total"]
        n_valid = row["ec_valid"]

        share = f"({100 * n_valid / n_total:.0f}\\%)" if n_total > 0 else ""
        # Only the count is bolded; the percentage stays in regular weight.
        count_text = f"\\textbf{{{n_valid}}}" if n_valid == best_valid else f"{n_valid}"
        valid_cell = f"{count_text} {share}"

        mean_cell = _cell(row["mean_mm"], best_mean_mm, "{:.1f}", 0.01)
        low_cell = _cell(row["mm_1_2_pct"], best_mm_1_2, "{:.0f}\\%", 0.1)
        high_cell = _cell(row["mm_3plus_pct"], best_mm_3plus, "{:.0f}\\%", 0.1)

        lines.append(
            f"{display} & {n_total} & {valid_cell} & {mean_cell} & {low_cell} & {high_cell} \\\\"
        )

    lines += ["\\bottomrule", "\\end{tabular}", "\\end{table}"]

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as handle:
        handle.write("\n".join(lines))
    print(f"Generated {output_path}")


def generate_error_profiles_table(
    fo_fpfn: Dict[str, Any], ci_fpfn: Dict[str, Any], output_path: Path
) -> None:
    """Write the LaTeX table of training error profiles (FP/FN rates).

    Both inputs are aggregated per model with ``aggregate_fpfn_profiles``
    (a falsy input contributes no models).  One row is emitted for every
    model of ``MODEL_ORDER`` present in either aggregate, ordered by the
    FullObs ``correct`` count (descending).  Columns are the FullObs FP% and
    FN%, the CI YES-world FP% and FN%, and the CI NO margin; unavailable
    values are rendered as ``---`` and the best value per column is bold
    (lowest for the rate columns, highest for the NO margin).
    """
    fo_by_model = aggregate_fpfn_profiles(fo_fpfn) if fo_fpfn else {}
    ci_by_model = aggregate_fpfn_profiles(ci_fpfn) if ci_fpfn else {}

    present = set(fo_by_model) | set(ci_by_model)
    ordered = [m for m in MODEL_ORDER if m in present]
    ordered.sort(key=lambda m: fo_by_model.get(m, {}).get("correct", 0), reverse=True)

    # Table columns, left to right; rate columns are scaled to percentages.
    columns = ("fo_fp", "fo_fn", "ci_yes_fp", "ci_yes_fn", "ci_no_margin")

    values = {}
    for model in ordered:
        fo_stats = fo_by_model.get(model, {})
        ci_stats = ci_by_model.get(model, {})
        yes_fp = ci_stats.get("yes_fp")
        yes_fn = ci_stats.get("yes_fn")
        values[model] = {
            "fo_fp": fo_stats.get("mean_fp", 0) * 100 if fo_stats else None,
            "fo_fn": fo_stats.get("mean_fn", 0) * 100 if fo_stats else None,
            "ci_yes_fp": None if yes_fp is None else yes_fp * 100,
            "ci_yes_fn": None if yes_fn is None else yes_fn * 100,
            "ci_no_margin": ci_stats.get("no_margin") if ci_stats else None,
        }

    def _known(column):
        return [row[column] for row in values.values() if row[column] is not None]

    # Lower is better for every FP/FN column; higher is better for NO margin.
    best = {column: min(_known(column), default=100) for column in columns[:4]}
    best["ci_no_margin"] = max(_known("ci_no_margin"), default=0)

    def _cell(value, target):
        """Render one value with one decimal, bold when it equals the best."""
        if value is None:
            return "---"
        text = f"{value:.1f}"
        if abs(value - target) < 0.01:
            return f"\\textbf{{{text}}}"
        return text

    lines = [
        "% Training Error Profiles: FP/FN rates (auto-generated)",
        "\\begin{table}[h]",
        "\\centering",
        "\\caption{\\textbf{Training error profiles.}",
        "Mean per-world false positive (FP\\%) and false negative (FN\\%) rates.",
        "For FullObs, rates are averaged across all worlds.",
        "For CI, YES rates are on YES worlds only; NO Margin is mean mismatches",
        "on NO worlds (higher = better separation from NO targets).}",
        "\\label{tab:error_profiles}",
        "\\small",
        "\\begin{tabular}{@{}l|rr|rrr@{}}",
        "\\toprule",
        " & \\multicolumn{2}{c|}{FullObs} & \\multicolumn{3}{c}{CI} \\\\",
        "Model & FP\\% & FN\\% & YES FP\\% & YES FN\\% & NO Marg \\\\",
        "\\midrule",
    ]

    for model in ordered:
        display = MODEL_DISPLAY_NAMES.get(model, model)
        row = values[model]
        cells = " & ".join(_cell(row[column], best[column]) for column in columns)
        lines.append(f"{display} & {cells} \\\\")

    lines += ["\\bottomrule", "\\end{tabular}", "\\end{table}"]

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as handle:
        handle.write("\n".join(lines))
    print(f"Generated {output_path}")


def analyze_equality_usage(df: pd.DataFrame) -> Dict[str, Dict[str, float]]:
    """Count how often each model's predicted formula uses equality.

    Each row contributes to its ``model`` (rows with a falsy model are
    skipped).  The predicted formula is read from ``metadata['pred_sexpr']``,
    where ``metadata`` is either a dict or a JSON string; rows whose metadata
    cannot be decoded, is not a dict, or has no predicted formula are skipped.

    Args:
        df: Evaluation records; the columns read are ``model``, ``metadata``,
            ``pred_ast`` and ``valid``.

    Returns:
        A dict keyed by model with:

        - ``total``: rows for which a predicted formula was found
        - ``with_eq``: of those, formulas containing ``(= `` or ``(=(``
        - ``with_eq_valid``: equality-containing formulas with truthy ``valid``
        - ``eq_ast_sum``: summed ``pred_ast`` of equality-containing formulas,
          a missing ``pred_ast`` (``None``/NaN) counting as 0
    """
    stats = defaultdict(lambda: {'total': 0, 'with_eq': 0, 'with_eq_valid': 0, 'eq_ast_sum': 0})

    for _, row in df.iterrows():
        model = row.get('model', '')
        if not model:
            continue

        # Metadata may arrive as a dict or as its JSON serialization.
        metadata = row.get('metadata', {})
        if isinstance(metadata, str):
            try:
                metadata = json.loads(metadata)
            except ValueError:
                # json.JSONDecodeError is a ValueError; treat malformed JSON
                # as "no metadata" without also swallowing KeyboardInterrupt
                # or SystemExit the way a bare ``except`` would.
                metadata = {}

        pred_sexpr = metadata.get('pred_sexpr', '') if isinstance(metadata, dict) else ''
        if not pred_sexpr:
            continue

        # Only count problems where the model returned a response.
        bucket = stats[model]
        bucket['total'] += 1

        # Equality shows up as "(= x y)" or "(=(f x) y)" in the s-expression.
        formula = str(pred_sexpr)
        if '(= ' in formula or '(=(' in formula:
            bucket['with_eq'] += 1

            # A missing value in a DataFrame row is NaN, which is truthy, so
            # ``value or 0`` would let a single NaN poison the running sum.
            ast_size = row.get('pred_ast', 0)
            if not pd.isna(ast_size):
                bucket['eq_ast_sum'] += ast_size

            if row.get('valid', False):
                bucket['with_eq_valid'] += 1

    return dict(stats)


def generate_equality_usage_table(
    fo_df: pd.DataFrame, ci_df: pd.DataFrame, ec_df: pd.DataFrame, output_path: Path
) -> None:
    """Write the appendix file with one equality-usage table per task.

    For each of FullObs, CI and EC the per-model statistics come from
    ``analyze_equality_usage`` (an empty frame yields a table with only the
    totals row).  Every table lists, per model of ``MODEL_ORDER`` that has
    statistics: the number of responses, the share using equality, the mean
    AST size of equality-containing formulas and their validity rate, with
    the highest share and validity rate in bold.  A final row aggregates
    over all models that have statistics.

    Args:
        fo_df: FullObs evaluation records.
        ci_df: CI evaluation records.
        ec_df: EC evaluation records.
        output_path: Destination ``.tex`` file; parent directories are created.
    """

    def make_table(stats: Dict, task_name: str, label: str) -> List[str]:
        """Return the LaTeX lines of a single equality-usage table."""
        tbl = [
            "\\begin{table}[h]",
            "\\centering",
            "\\small",
            f"\\caption{{\\textbf{{Equality usage in {task_name} predictions.}} "
            f"\\emph{{Total}} = number of responses returned by model; "
            f"\\emph{{\\% Using =}} = fraction of predictions containing equality; "
            f"\\emph{{Avg AST}} = mean AST size of equality-containing formulas; "
            f"\\emph{{Valid \\%}} = validity rate among equality-containing predictions.}}",
            f"\\label{{{label}}}",
            "\\begin{tabular}{@{}lrrrr@{}}",
            "\\toprule",
            "Model & Total & \\% Using = & Avg AST & Valid \\% \\\\",
            "\\midrule",
        ]

        models = [m for m in MODEL_ORDER if m in stats]

        # Column bests used for bolding (both columns: higher is bolded).
        eq_pcts = [stats[m]['with_eq'] / stats[m]['total'] * 100 for m in models if stats[m]['total'] > 0]
        valid_pcts = [stats[m]['with_eq_valid'] / stats[m]['with_eq'] * 100 for m in models if stats[m]['with_eq'] > 0]
        best_eq_pct = max(eq_pcts) if eq_pcts else 0
        best_valid_pct = max(valid_pcts) if valid_pcts else 0

        for model in models:
            s = stats[model]
            display = MODEL_DISPLAY_NAMES.get(model, model)
            total = s['total']

            if total > 0:
                eq_pct = 100 * s['with_eq'] / total
                eq_pct_str = f"{eq_pct:.1f}\\%"
                # A 0% share is never highlighted, even if it is the maximum.
                if abs(eq_pct - best_eq_pct) < 0.1 and eq_pct > 0:
                    eq_pct_str = f"\\textbf{{{eq_pct_str}}}"
            else:
                eq_pct_str = "---"

            if s['with_eq'] > 0:
                avg_ast = s['eq_ast_sum'] / s['with_eq']
                avg_ast_str = f"{avg_ast:.1f}"
                valid_pct = 100 * s['with_eq_valid'] / s['with_eq']
                valid_pct_str = f"{valid_pct:.1f}\\%"
                if abs(valid_pct - best_valid_pct) < 0.1:
                    valid_pct_str = f"\\textbf{{{valid_pct_str}}}"
            else:
                # Without equality-containing formulas neither value exists.
                avg_ast_str = "---"
                valid_pct_str = "---"

            tbl.append(f"{display} & {total} & {eq_pct_str} & {avg_ast_str} & {valid_pct_str} \\\\")

        # Totals row: aggregates every model in ``stats``, including any
        # that are not part of MODEL_ORDER and so have no row of their own.
        total_all = sum(s['total'] for s in stats.values())
        total_eq = sum(s['with_eq'] for s in stats.values())
        total_eq_valid = sum(s['with_eq_valid'] for s in stats.values())
        total_eq_ast = sum(s['eq_ast_sum'] for s in stats.values())

        tbl.append("\\midrule")
        if total_all > 0 and total_eq > 0:
            eq_pct = 100 * total_eq / total_all
            avg_ast = total_eq_ast / total_eq
            valid_pct = 100 * total_eq_valid / total_eq
            tbl.append(f"\\textit{{Total}} & {total_all} & {eq_pct:.1f}\\% & {avg_ast:.1f} & {valid_pct:.1f}\\% \\\\")
        else:
            tbl.append(f"\\textit{{Total}} & {total_all} & --- & --- & --- \\\\")

        tbl.extend(["\\bottomrule", "\\end{tabular}", "\\end{table}", ""])
        return tbl

    lines = [
        "% Equality Usage Analysis (auto-generated)",
        "% Shows frequency and validity of formulas containing equality predicates",
        "",
    ]

    # One table per task, in paper order.
    tasks = (
        (fo_df, "FullObs", "tab:eq_fullobs"),
        (ci_df, "CI", "tab:eq_ci"),
        (ec_df, "EC", "tab:eq_ec"),
    )
    for frame, task_name, label in tasks:
        stats = analyze_equality_usage(frame) if len(frame) > 0 else {}
        lines.extend(make_table(stats, task_name, label))

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        f.write("\n".join(lines))
    print(f"Generated {output_path}")


# =============================================================================
# APPENDIX TABLE GENERATORS
# =============================================================================


def generate_fo_appendix_tables(
    metrics: Dict[str, Dict[str, float]], df: pd.DataFrame, output_path: Path
) -> None:
    """Write the FullObs appendix LaTeX file, which holds three tables.

    1. Overall accuracy (Acc@+25, Acc_all, coverage, bloat) for every model
       of ``MODEL_ORDER`` found in ``metrics``, sorted by Acc@+25 descending.
    2. Acc@+25 per formula family, computed from ``df`` as the mean of
       ``budget_ok_25`` per (``family_id_gold``, ``model``) group, for the
       first five models of ``MODEL_ORDER`` found in ``metrics``.
    3. Formula size breakdown (compact / equal / longer / bloat) per model,
       in the same order as the overall table.

    Column bests come from ``find_best`` and are bolded via ``fmt_pct``.
    """
    table_footer = ["\\bottomrule", "\\end{tabular}", "\\end{table}"]

    def _metric_cells(row, columns):
        """Format the (key, best) columns of one row, bolding column bests."""
        return " & ".join(
            fmt_pct(row[key], abs(row[key] - best) < 0.001) for key, best in columns
        )

    listed = [m for m in MODEL_ORDER if m in metrics]
    ranked = sorted(listed, key=lambda m: metrics[m]["acc_25"], reverse=True)

    # ---- Overall table (moved from main) ---------------------------------
    overall_columns = [
        ("acc_25", find_best(metrics, "acc_25")),
        ("acc_all", find_best(metrics, "acc_all")),
        ("coverage", find_best(metrics, "coverage")),
        ("bloat", find_best(metrics, "bloat", higher_better=False)),
    ]

    out = [
        "% FullObs Appendix Tables (auto-generated)",
        "",
        "% FullObs overall table",
        "\\begin{table}[h]",
        "\\centering",
        "\\caption{\\textbf{FullObs v1 overall accuracy} (sorted by Acc@+25). Bold = best per column.}",
        "\\label{tab:fo_overall}",
        "\\small",
        "\\begin{tabular}{@{}lrrrr@{}}",
        "\\toprule",
        "Model & @+25 & Acc\\_all & Cov & Bloat \\\\",
        "\\midrule",
    ]
    for model in ranked:
        display = MODEL_DISPLAY_NAMES.get(model, model)
        out.append(f"{display} & {_metric_cells(metrics[model], overall_columns)} \\\\")
    out += table_footer + [""]

    # ---- Family breakdown -------------------------------------------------
    out += [
        "% Table 2: FullObs family breakdown",
        "\\begin{table}[h]",
        "\\centering",
        "\\caption{\\textbf{FullObs v1 Acc@+25 by formula family.}}",
        "\\label{tab:fo_family}",
        "\\small",
    ]

    # family (as str) -> model -> {"acc_25", "total"}; NaN families are skipped.
    per_family = {}
    for family in df["family_id_gold"].unique():
        if pd.isna(family):
            continue
        in_family = df[df["family_id_gold"] == family]
        scores = per_family[str(family)] = {}
        for model in in_family["model"].unique():
            subset = in_family[in_family["model"] == model]
            n_rows = len(subset)
            if n_rows > 0:
                scores[model] = {
                    "acc_25": subset["budget_ok_25"].sum() / n_rows,
                    "total": n_rows,
                }

    # Only the first five listed models get a column; headers are cut to
    # eight characters.
    shown = listed[:5]
    header = " & ".join([MODEL_DISPLAY_NAMES.get(m, m)[:8] for m in shown])
    out += [
        "\\begin{tabular}{@{}l" + "r" * len(shown) + "@{}}",
        "\\toprule",
        f"Family & {header} \\\\",
        "\\midrule",
    ]

    for family in sorted(per_family):
        scores = per_family[family]
        label = FAMILY_DISPLAY_NAMES.get(family, family)
        top = max((scores.get(m, {}).get("acc_25", 0) for m in shown), default=0)
        cells = []
        for model in shown:
            if model not in scores:
                cells.append("---")
                continue
            acc = scores[model]["acc_25"]
            # A zero score is never bolded, even when it is the row maximum.
            cells.append(fmt_pct(acc, abs(acc - top) < 0.001 and top > 0))
        out.append(f"{label} & {' & '.join(cells)} \\\\")
    out += table_footer + [""]

    # ---- Table 3: formula size breakdown (Missing/Parse columns omitted) ---
    # Compact and Equal: higher is better; Longer and Bloat: lower is better.
    size_columns = [
        ("compact", find_best(metrics, "compact")),
        ("equal", find_best(metrics, "equal")),
        ("longer", find_best(metrics, "longer", higher_better=False)),
        ("bloat", find_best(metrics, "bloat", higher_better=False)),
    ]

    out += [
        "% Table 3: FullObs formula size breakdown",
        "\\begin{table}[h]",
        "\\centering",
        "\\caption{\\textbf{FullObs v1 formula size breakdown for valid predictions.}",
        "Compact = AST $<$ gold; Equal = gold $\\leq$ AST $\\leq$ gold+1;",
        "Longer = gold+1 $<$ AST $\\leq$ gold+25; Bloat = AST $>$ gold+25.}",
        "\\label{tab:fo_failures}",
        "\\small",
        "\\begin{tabular}{@{}lrrrr@{}}",
        "\\toprule",
        "Model & Compact & Equal & Longer & Bloat \\\\",
        "\\midrule",
    ]
    for model in ranked:
        display = MODEL_DISPLAY_NAMES.get(model, model)
        out.append(f"{display} & {_metric_cells(metrics[model], size_columns)} \\\\")
    out += table_footer

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as handle:
        handle.write("\n".join(out))
    print(f"Generated {output_path}")


def generate_ci_appendix_tables(
    metrics: Dict[str, Dict[str, float]],
    df: pd.DataFrame,
    band_metrics: Dict[str, Dict[str, Dict[str, float]]],
    output_path: Path,
) -> None:
    """Write the CI appendix LaTeX file: overall, family and per-band tables.

    1. Overall accuracy (Acc@+25, Acc_all, coverage, bloat) for every model
       of ``MODEL_ORDER`` found in ``metrics``, sorted by Acc@+25 descending.
    2. Acc@+25 per formula family, computed from ``df`` as the mean of
       ``budget_ok_25`` per (``family_id_gold``, ``model``) group, for the
       first five models of ``MODEL_ORDER`` found in ``metrics``.
    3. One failure-mode table per key of ``band_metrics`` (only the keys are
       used; the rates are recomputed from the rows of ``df`` whose ``band``
       equals the key), showing the correct, YES-fail and NO-fail shares.

    Column bests are bolded via ``fmt_pct``.
    """
    table_footer = ["\\bottomrule", "\\end{tabular}", "\\end{table}"]

    def _metric_cells(row, columns):
        """Format the (key, best) columns of one row, bolding column bests."""
        return " & ".join(
            fmt_pct(row[key], abs(row[key] - best) < 0.001) for key, best in columns
        )

    listed = [m for m in MODEL_ORDER if m in metrics]
    ranked = sorted(listed, key=lambda m: metrics[m]["acc_25"], reverse=True)

    # ---- CI overall table (moved from main) --------------------------------
    overall_columns = [
        ("acc_25", find_best(metrics, "acc_25")),
        ("acc_all", find_best(metrics, "acc_all")),
        ("coverage", find_best(metrics, "coverage")),
        ("bloat", find_best(metrics, "bloat", higher_better=False)),
    ]

    out = [
        "% CI Appendix Tables (auto-generated)",
        "",
        "% CI overall table",
        "\\begin{table}[h]",
        "\\centering",
        "\\caption{\\textbf{CI v1 overall accuracy} (sorted by Acc@+25). Bold = best per column.}",
        "\\label{tab:ci_overall}",
        "\\small",
        "\\begin{tabular}{@{}lrrrr@{}}",
        "\\toprule",
        "Model & @+25 & Acc\\_all & Cov & Bloat \\\\",
        "\\midrule",
    ]
    for model in ranked:
        display = MODEL_DISPLAY_NAMES.get(model, model)
        out.append(f"{display} & {_metric_cells(metrics[model], overall_columns)} \\\\")
    out += table_footer + [""]

    # ---- Family breakdown -------------------------------------------------
    out += [
        "% Table 6: CI family breakdown",
        "\\begin{table}[h]",
        "\\centering",
        "\\caption{\\textbf{CI v1 Acc@+25 by formula family.}}",
        "\\label{tab:ci_family}",
        "\\small",
    ]

    # family (as str) -> model -> {"acc_25", "total"}; NaN families are skipped.
    per_family = {}
    for family in df["family_id_gold"].unique():
        if pd.isna(family):
            continue
        in_family = df[df["family_id_gold"] == family]
        scores = per_family[str(family)] = {}
        for model in in_family["model"].unique():
            subset = in_family[in_family["model"] == model]
            n_rows = len(subset)
            if n_rows > 0:
                scores[model] = {
                    "acc_25": subset["budget_ok_25"].sum() / n_rows,
                    "total": n_rows,
                }

    # Only the first five listed models get a column; headers are cut to
    # eight characters.
    shown = listed[:5]
    header = " & ".join([MODEL_DISPLAY_NAMES.get(m, m)[:8] for m in shown])
    out += [
        "\\begin{tabular}{@{}l" + "r" * len(shown) + "@{}}",
        "\\toprule",
        f"Family & {header} \\\\",
        "\\midrule",
    ]

    for family in sorted(per_family):
        scores = per_family[family]
        label = FAMILY_DISPLAY_NAMES.get(family, family)
        top = max((scores.get(m, {}).get("acc_25", 0) for m in shown), default=0)
        cells = []
        for model in shown:
            if model not in scores:
                cells.append("---")
                continue
            acc = scores[model]["acc_25"]
            # A zero score is never bolded, even when it is the row maximum.
            cells.append(fmt_pct(acc, abs(acc - top) < 0.001 and top > 0))
        out.append(f"{label} & {' & '.join(cells)} \\\\")
    out += table_footer + [""]

    # ---- Table 8: one failure-mode table per band --------------------------
    out.append("% Table 8: CI band failure breakdown")
    for band in sorted(band_metrics.keys()):
        band_escaped = band.replace("_", "\\_")
        in_band = df[df["band"] == band]

        # model -> share of each outcome among that model's rows in the band.
        # "parse" and "missing" are computed as in the other outcomes but
        # have no column in the rendered table.
        rates = {}
        for model in in_band["model"].unique():
            subset = in_band[in_band["model"] == model]
            n_rows = len(subset)
            if n_rows <= 0:
                continue
            fail_mode = subset["fail_mode"]
            rates[model] = {
                "correct": subset["valid"].sum() / n_rows,
                "yes_fail": (fail_mode == "yes_fail").sum() / n_rows,
                "no_fail": (fail_mode == "no_fail").sum() / n_rows,
                "parse": (fail_mode == "parse").sum() / n_rows,
                "missing": (fail_mode == "missing").sum() / n_rows,
            }

        # Correct: highest is best; YES-fail / NO-fail: lowest is best.
        band_columns = [
            ("correct", max((r["correct"] for r in rates.values()), default=0)),
            ("yes_fail", min((r["yes_fail"] for r in rates.values()), default=1)),
            ("no_fail", min((r["no_fail"] for r in rates.values()), default=1)),
        ]

        out += [
            "\\begin{table}[h]",
            "\\centering",
            f"\\caption{{\\textbf{{CI v1 failure modes: {band_escaped} band.}} YES-fail = formula doesn't match positive examples; NO-fail = formula accidentally matches negative examples.}}",
            f"\\label{{tab:ci_band_{band}}}",
            "\\small",
            # Parse and Missing columns are omitted for cleaner presentation.
            "\\begin{tabular}{@{}lrrr@{}}",
            "\\toprule",
            "Model & Correct & YES-fail & NO-fail \\\\",
            "\\midrule",
        ]

        band_ranked = [m for m in MODEL_ORDER if m in rates]
        band_ranked.sort(key=lambda m: rates[m]["correct"], reverse=True)
        for model in band_ranked:
            display = MODEL_DISPLAY_NAMES.get(model, model)
            out.append(f"{display} & {_metric_cells(rates[model], band_columns)} \\\\")

        out += table_footer + [""]

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as handle:
        handle.write("\n".join(out))
    print(f"Generated {output_path}")


def generate_ec_appendix_tables(
    metrics: Dict[str, Dict[str, float]], df: pd.DataFrame, output_path: Path
) -> None:
    """Generate the EC appendix tables and write them to ``output_path``.

    Three LaTeX tables are emitted in order, separated by blank lines: the
    overall accuracy table, the Acc@+25-by-family table, and the formula-size
    breakdown table.

    Args:
        metrics: Per-model metric dicts keyed by model id. The keys read here
            are ``acc_25``, ``acc_all``, ``coverage``, ``bloat``, ``compact``,
            ``equal`` and ``longer``.
        df: EC evaluation records. The columns read here are
            ``family_id_gold``, ``model`` and ``budget_ok_25``.
        output_path: Destination ``.tex`` file; missing parent directories
            are created.
    """
    # Both per-model tables list models in the same order (best Acc@+25 first),
    # so the ranking is computed once and shared.
    ranked_models = sorted(
        [m for m in MODEL_ORDER if m in metrics],
        key=lambda m: metrics[m]["acc_25"],
        reverse=True,
    )

    lines: List[str] = ["% EC Appendix Tables (auto-generated)", ""]
    lines.extend(_ec_overall_table_lines(metrics, ranked_models))
    lines.append("")
    lines.extend(_ec_family_table_lines(metrics, df))
    lines.append("")
    lines.extend(_ec_size_table_lines(metrics, ranked_models))

    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8 so display names with non-ASCII characters are written
    # identically regardless of the platform's default locale encoding.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
    print(f"Generated {output_path}")


def _ec_pct_cell(value: float, best: float) -> str:
    """Format ``value`` with ``fmt_pct``, flagging it as best when within 0.001 of ``best``."""
    return fmt_pct(value, abs(value - best) < 0.001)


def _ec_overall_table_lines(
    metrics: Dict[str, Dict[str, float]], ranked_models: List[str]
) -> List[str]:
    """Build the LaTeX lines of the EC overall-accuracy table (one row per model)."""
    best_25 = find_best(metrics, "acc_25")
    best_val = find_best(metrics, "acc_all")
    best_cov = find_best(metrics, "coverage")
    # Bloat is a "lower is better" column.
    best_bloat = find_best(metrics, "bloat", higher_better=False)

    lines = [
        "% EC overall table",
        "\\begin{table}[h]",
        "\\centering",
        "\\caption{\\textbf{EC v1 overall accuracy} (sorted by Acc@+25). Bold = best per column.}",
        "\\label{tab:ec_overall}",
        "\\small",
        "\\begin{tabular}{@{}lrrrr@{}}",
        "\\toprule",
        "Model & @+25 & Valid & Cov & Bloat \\\\",
        "\\midrule",
    ]

    for model in ranked_models:
        m = metrics[model]
        display = MODEL_DISPLAY_NAMES.get(model, model)
        acc_25 = _ec_pct_cell(m["acc_25"], best_25)
        valid = _ec_pct_cell(m["acc_all"], best_val)
        cov = _ec_pct_cell(m["coverage"], best_cov)
        bloat = _ec_pct_cell(m["bloat"], best_bloat)
        lines.append(f"{display} & {acc_25} & {valid} & {cov} & {bloat} \\\\")

    lines.extend(["\\bottomrule", "\\end{tabular}", "\\end{table}"])
    return lines


def _ec_family_table_lines(metrics: Dict[str, Dict[str, float]], df: pd.DataFrame) -> List[str]:
    """Build the LaTeX lines of the EC Acc@+25-by-formula-family table."""
    # Layout limits of the table: at most this many model columns, and column
    # headers truncated to this many characters.
    max_model_columns = 5
    max_header_chars = 8

    lines = [
        "% Table 10: EC family breakdown",
        "\\begin{table}[h]",
        "\\centering",
        "\\caption{\\textbf{EC v1 Acc@+25 by formula family.}}",
        "\\label{tab:ec_family}",
        "\\small",
    ]

    # family id (as str) -> model id -> fraction of that model's records in the
    # family with budget_ok_25 set.
    family_acc: Dict[str, Dict[str, float]] = {}
    for family in df["family_id_gold"].unique():
        if pd.isna(family):
            continue
        family_df = df[df["family_id_gold"] == family]
        per_model: Dict[str, float] = {}
        for model in family_df["model"].unique():
            model_df = family_df[family_df["model"] == model]
            total = len(model_df)
            # A missing (NaN) model value never equals itself, which yields an
            # empty selection; skip it rather than divide by zero.
            if total > 0:
                per_model[model] = model_df["budget_ok_25"].sum() / total
        family_acc[str(family)] = per_model

    models = [m for m in MODEL_ORDER if m in metrics][:max_model_columns]

    header_models = " & ".join(MODEL_DISPLAY_NAMES.get(m, m)[:max_header_chars] for m in models)
    lines.append("\\begin{tabular}{@{}l" + "r" * len(models) + "@{}}")
    lines.append("\\toprule")
    lines.append(f"Family & {header_models} \\\\")
    lines.append("\\midrule")

    for family in sorted(family_acc):
        family_display = FAMILY_DISPLAY_NAMES.get(family, family)
        per_model = family_acc[family]
        # Models without data for this family count as 0 when finding the best.
        best = max((per_model.get(m, 0) for m in models), default=0)
        cells = []
        for model in models:
            if model in per_model:
                acc = per_model[model]
                # Do not highlight anything when every model scored zero.
                is_best = abs(acc - best) < 0.001 and best > 0
                cells.append(fmt_pct(acc, is_best))
            else:
                cells.append("---")
        lines.append(f"{family_display} & {' & '.join(cells)} \\\\")

    lines.extend(["\\bottomrule", "\\end{tabular}", "\\end{table}"])
    return lines


def _ec_size_table_lines(
    metrics: Dict[str, Dict[str, float]], ranked_models: List[str]
) -> List[str]:
    """Build the LaTeX lines of the EC formula-size breakdown table."""
    # Compact and Equal: higher is better; Longer and Bloat: lower is better.
    best_compact = find_best(metrics, "compact")
    best_equal = find_best(metrics, "equal")
    best_longer = find_best(metrics, "longer", higher_better=False)
    best_bloat = find_best(metrics, "bloat", higher_better=False)

    # The caption deliberately spans three output lines; its closing brace is
    # at the end of the third one.
    lines = [
        "% Table 11: EC formula size breakdown",
        "\\begin{table}[h]",
        "\\centering",
        "\\caption{\\textbf{EC v1 formula size breakdown for valid predictions.}",
        "Compact = AST $<$ gold; Equal = gold $\\leq$ AST $\\leq$ gold+1;",
        "Longer = gold+1 $<$ AST $\\leq$ gold+25; Bloat = AST $>$ gold+25.}",
        "\\label{tab:ec_failures}",
        "\\small",
        "\\begin{tabular}{@{}lrrrr@{}}",
        "\\toprule",
        "Model & Compact & Equal & Longer & Bloat \\\\",
        "\\midrule",
    ]

    for model in ranked_models:
        m = metrics[model]
        display = MODEL_DISPLAY_NAMES.get(model, model)
        compact_str = _ec_pct_cell(m["compact"], best_compact)
        equal_str = _ec_pct_cell(m["equal"], best_equal)
        longer_str = _ec_pct_cell(m["longer"], best_longer)
        bloat_str = _ec_pct_cell(m["bloat"], best_bloat)
        lines.append(
            f"{display} & "
            f"{compact_str} & {equal_str} & {longer_str} & {bloat_str} \\\\"
        )

    lines.extend(["\\bottomrule", "\\end{tabular}", "\\end{table}"])
    return lines


# =============================================================================
# MAIN EXPORT FUNCTION
# =============================================================================


def export_paper_artifacts(
    fo_records_path: Path,
    ci_records_path: Path,
    ec_records_path: Path,
    output_dir: Path,
    fo_holdout_path: Optional[Path] = None,
    ci_holdout_path: Optional[Path] = None,
    fo_fpfn_path: Optional[Path] = None,
    ci_fpfn_path: Optional[Path] = None,
    ec_best_completion_path: Optional[Path] = None,
    verbose: bool = True,
) -> Dict[str, Any]:
    """
    Export all paper artifacts.

    Loads the FO/CI/EC evaluation records, computes per-task and per-band
    metrics, writes the main-body and appendix LaTeX tables, regenerates the
    FO and EC budget-curve figures, copies the CI failure-modes figure when an
    existing copy is found, and finally writes ``manifest.json``.

    Args:
        fo_records_path: Path to FO eval_records.jsonl
        ci_records_path: Path to CI eval_records.jsonl
        ec_records_path: Path to EC eval_records.jsonl
        output_dir: Output directory (paper/auto)
        fo_holdout_path: Path to FO holdout analysis JSON (optional)
        ci_holdout_path: Path to CI holdout analysis JSON (optional)
        fo_fpfn_path: Path to FO FP/FN analysis JSON (optional)
        ci_fpfn_path: Path to CI FP/FN analysis JSON (optional)
        ec_best_completion_path: Path to EC best-completion analysis JSON (optional)
        verbose: Print progress

    Returns:
        Manifest dict (the same content that is written to
        ``<output_dir>/manifest.json``)

    Raises:
        ImportError: If pandas is not installed.
    """
    if not HAS_PANDAS:
        raise ImportError("pandas required")

    # Create directories
    tables_dir = output_dir / "tables"
    figures_dir = output_dir / "figures"
    appendix_dir = output_dir / "appendix"

    for d in [tables_dir, figures_dir, appendix_dir]:
        d.mkdir(parents=True, exist_ok=True)

    # Load records
    if verbose:
        print("Loading evaluation records...")

    fo_df = load_records([fo_records_path])
    ci_df = load_records([ci_records_path])
    ec_df = load_records([ec_records_path])

    if verbose:
        print(f"  FO: {len(fo_df)} records")
        print(f"  CI: {len(ci_df)} records")
        print(f"  EC: {len(ec_df)} records")

    # Compute metrics
    if verbose:
        print("Computing metrics...")

    fo_metrics = compute_task_metrics(fo_df, "fo")
    ci_metrics = compute_task_metrics(ci_df, "ci")
    ec_metrics = compute_task_metrics(ec_df, "ec")

    fo_band_metrics = compute_band_metrics(fo_df)
    ci_band_metrics = compute_band_metrics(ci_df)
    ec_band_metrics = compute_band_metrics(ec_df)

    # Generate tables
    if verbose:
        print("Generating tables...")

    # Optional analysis inputs: each is None when its path was not supplied or
    # the file does not exist, in which case the dependent table is skipped.
    fo_holdout_data = _load_optional_json(fo_holdout_path, "FO holdout", verbose)
    ci_holdout_data = _load_optional_json(ci_holdout_path, "CI holdout", verbose)
    fo_fpfn_data = _load_optional_json(fo_fpfn_path, "FO FP/FN", verbose)
    ci_fpfn_data = _load_optional_json(ci_fpfn_path, "CI FP/FN", verbose)
    ec_best_completion_data = _load_optional_json(
        ec_best_completion_path, "EC best-completion", verbose
    )

    # Main body tables
    generate_across_task_summary(
        fo_metrics, ci_metrics, ec_metrics, tables_dir / "summary_across_tasks.tex"
    )
    generate_fo_overall_table(fo_metrics, tables_dir / "fo_overall.tex")
    generate_fo_band_table(fo_band_metrics, tables_dir / "fo_bands.tex")

    # Generalization tables (main body - key finding)
    if fo_holdout_data:
        generate_fo_generalization_table(fo_holdout_data, tables_dir / "fo_generalization.tex")
    if ci_holdout_data:
        generate_ci_generalization_table(ci_holdout_data, tables_dir / "ci_generalization.tex")
    generate_ci_overall_table(ci_metrics, tables_dir / "ci_overall.tex")
    generate_ci_failure_table(ci_metrics, tables_dir / "ci_failure.tex")
    generate_ec_overall_table(ec_metrics, tables_dir / "ec_overall.tex")
    generate_ec_band_table(ec_band_metrics, tables_dir / "ec_bands.tex")

    # Appendix tables
    generate_fo_appendix_tables(fo_metrics, fo_df, appendix_dir / "fo_appendix_tables.tex")
    generate_ci_appendix_tables(
        ci_metrics, ci_df, ci_band_metrics, appendix_dir / "ci_appendix_tables.tex"
    )
    generate_ec_appendix_tables(ec_metrics, ec_df, appendix_dir / "ec_appendix_tables.tex")

    # Error profiles table (appendix); generated when at least one of the two
    # FP/FN inputs is available, with an empty dict standing in for the other.
    if fo_fpfn_data or ci_fpfn_data:
        generate_error_profiles_table(
            fo_fpfn_data or {}, ci_fpfn_data or {}, appendix_dir / "error_profiles.tex"
        )

    # EC best-completion table (appendix)
    if ec_best_completion_data:
        generate_ec_best_completion_table(
            ec_best_completion_data, appendix_dir / "ec_best_completion.tex"
        )

    # World generation hyperparameters table (appendix)
    generate_world_gen_params_table(appendix_dir / "tab_world_gen_params.tex")

    # Equality usage analysis tables (appendix)
    generate_equality_usage_table(fo_df, ci_df, ec_df, appendix_dir / "equality_usage.tex")

    # Generate budget curve figures (regenerate to ensure correct Y-axis limits)
    if verbose:
        print("Generating figures...")

    # Generate FO budget curve (Y-axis 0-60%)
    plot_budget_curves(fo_df, "fo", figures_dir / "fo_budget_curve.pdf")

    # Generate EC budget curve (Y-axis 0-100%)
    plot_budget_curves(ec_df, "ec", figures_dir / "ec_budget_curve.pdf")

    # Copy CI failure modes figure from existing location (not regenerated here).
    # First look three directory levels above the FO records file, then fall
    # back to a path relative to the current working directory.
    existing_figures_dir = fo_records_path.parent.parent.parent / "figures" / "induction"
    if not existing_figures_dir.exists():
        existing_figures_dir = Path("concept_synth/figures/induction")

    ci_failure_src = existing_figures_dir / "ci_failure_modes.pdf"
    ci_failure_dst = figures_dir / "ci_failure_modes.pdf"
    if ci_failure_src.exists():
        shutil.copy(ci_failure_src, ci_failure_dst)
        if verbose:
            print(f"  Copied {ci_failure_dst}")
    elif verbose:
        print(f"  Warning: {ci_failure_src} not found")

    # Generate manifest.
    # NOTE: the artifact lists reflect every matching file currently present in
    # each directory, which includes files left over from earlier runs.
    manifest = {
        "timestamp": datetime.now().isoformat(),
        "git_commit": get_git_commit(),
        "inputs": {
            "fo_records": str(fo_records_path),
            "ci_records": str(ci_records_path),
            "ec_records": str(ec_records_path),
        },
        "record_counts": {
            "fo": len(fo_df),
            "ci": len(ci_df),
            "ec": len(ec_df),
        },
        "artifacts": {
            "tables": _list_relative_artifacts(tables_dir, "*.tex", output_dir),
            "figures": _list_relative_artifacts(figures_dir, "*.pdf", output_dir),
            "appendix": _list_relative_artifacts(appendix_dir, "*.tex", output_dir),
        },
    }

    manifest_path = output_dir / "manifest.json"
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2)

    if verbose:
        print(f"Generated manifest: {manifest_path}")
        print("\nArtifacts generated:")
        print(f"  Tables: {len(manifest['artifacts']['tables'])}")
        print(f"  Figures: {len(manifest['artifacts']['figures'])}")
        print(f"  Appendix: {len(manifest['artifacts']['appendix'])}")

    return manifest


def _load_optional_json(path: Optional[Path], label: str, verbose: bool) -> Optional[Any]:
    """Return the parsed JSON at ``path``, or None if no path was given or it does not exist."""
    if not path or not path.exists():
        return None
    if verbose:
        print(f"Loading {label} data from {path}...")
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _list_relative_artifacts(directory: Path, pattern: str, root: Path) -> List[str]:
    """List files in ``directory`` matching ``pattern`` as sorted paths relative to ``root``.

    Sorting keeps the manifest stable across runs and machines, since the
    order in which glob yields entries depends on the file system.
    """
    return sorted(str(p.relative_to(root)) for p in directory.glob(pattern))


def main():
    """Command-line entry point: parse the arguments and run the export.

    When ``--ec-dataset`` is given, EC eval records are first dumped from that
    dataset into a temporary ``.jsonl`` file, which is then used in place of
    ``--ec-records``.
    """
    cli = argparse.ArgumentParser(
        description="Export paper artifacts (tables, figures, appendix)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # (flag, default, help) for every plain path-valued option, in the order
    # they should appear in --help.
    path_options = (
        (
            "--fo-records",
            "concept_synth/artifacts/analysis/v1/fo/eval_records.jsonl",
            "Path to FO eval_records.jsonl",
        ),
        (
            "--ci-records",
            "concept_synth/artifacts/analysis/v1/ci/eval_records.jsonl",
            "Path to CI eval_records.jsonl",
        ),
        (
            "--ec-records",
            "concept_synth/artifacts/analysis/v1/ec/eval_records.jsonl",
            "Path to EC eval_records.jsonl",
        ),
        (
            "--ec-dataset",
            None,
            "Path to EC dataset YAML file (alternative to --ec-records). If provided, evaluates the dataset directly.",
        ),
        (
            "--fo-holdout",
            "concept_synth/artifacts/analysis/v1/holdout/fo_holdout.json",
            "Path to FO holdout analysis JSON",
        ),
        (
            "--ci-holdout",
            "concept_synth/artifacts/analysis/v1/holdout/ci_holdout.json",
            "Path to CI holdout analysis JSON",
        ),
        (
            "--fo-fpfn",
            "concept_synth/artifacts/analysis/v1/per_world_fpfn/fo_fpfn_profiles.json",
            "Path to FO FP/FN analysis JSON",
        ),
        (
            "--ci-fpfn",
            "concept_synth/artifacts/analysis/v1/per_world_fpfn/ci_fpfn_profiles.json",
            "Path to CI FP/FN analysis JSON",
        ),
        (
            "--ec-best-completion",
            "concept_synth/artifacts/analysis/v1/ec_best_completion/ec_best_completion.json",
            "Path to EC best-completion analysis JSON",
        ),
    )
    for flag, default_value, help_text in path_options:
        cli.add_argument(flag, default=default_value, help=help_text)

    cli.add_argument("--out", "-o", default="concept_synth/paper/auto", help="Output directory")
    cli.add_argument("--quiet", "-q", action="store_true", help="Suppress progress output")

    opts = cli.parse_args()
    show_progress = not opts.quiet

    def optional_path(raw):
        # An empty or missing value means "no file supplied".
        return Path(raw) if raw else None

    # Handle --ec-dataset option: generate eval records from YAML dataset
    ec_records_path = Path(opts.ec_records)
    if opts.ec_dataset:
        from concept_synth.analysis.dump_eval_records import dump_records_for_task
        import tempfile

        if show_progress:
            print(f"Processing EC dataset: {opts.ec_dataset}")

        # Only the name is reserved here; the file is created with
        # delete=False, so it stays on disk after the run.
        with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as scratch:
            scratch_path = Path(scratch.name)

        dump_records_for_task(
            task="ec",
            dataset_path=Path(opts.ec_dataset),
            output_path=scratch_path,
            benchmark_version="v1",
            verbose=show_progress,
        )
        ec_records_path = scratch_path

        if show_progress:
            print(f"Generated EC eval records at: {scratch_path}")

    export_paper_artifacts(
        fo_records_path=Path(opts.fo_records),
        ci_records_path=Path(opts.ci_records),
        ec_records_path=ec_records_path,
        output_dir=Path(opts.out),
        fo_holdout_path=optional_path(opts.fo_holdout),
        ci_holdout_path=optional_path(opts.ci_holdout),
        fo_fpfn_path=optional_path(opts.fo_fpfn),
        ci_fpfn_path=optional_path(opts.ci_fpfn),
        ec_best_completion_path=optional_path(opts.ec_best_completion),
        verbose=show_progress,
    )


# Run the command-line interface only when this file is executed as a script
# (or via ``python -m``), not when it is imported as a module.
if __name__ == "__main__":
    main()
