#!/usr/bin/env python3
"""
Budget Curves and Bloat Histograms Analysis

Produces paper-ready plots showing:
1. Acc@(+Δ) curves for Δ in [0..100] for each model
2. Validity vs budgeted accuracy ("parsimony gap")
3. AST delta distributions for valid solutions (bloat histograms)

Usage:
    python -m concept_synth.analysis.budget_curves \
        --task fo \
        --records artifacts/analysis/v1/fo/eval_records.jsonl \
        --outdir artifacts/analysis/v1/fo/budget_curves
"""

import argparse
import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# Bootstrap path: make the `concept_synth` package importable even when this
# file is executed in a way that does not already have it on sys.path.
try:
    from concept_synth.bootstrap import add_repo_root
except ModuleNotFoundError:
    # Walk upward from this file's absolute path until a path component named
    # "concept_synth" is found, then put that directory's parent on sys.path.
    _path = os.path.abspath(__file__)
    while True:
        parent = os.path.dirname(_path)
        if os.path.basename(_path) == "concept_synth":
            if parent not in sys.path:
                sys.path.insert(0, parent)
            break
        if parent == _path:
            # dirname() no longer changes the path: the filesystem root was
            # reached without finding a "concept_synth" directory.
            break
        _path = parent
    # Retry the import; this still raises if the package could not be located.
    from concept_synth.bootstrap import add_repo_root
# NOTE(review): add_repo_root is defined outside this file; presumably it
# registers the repository root on sys.path -- confirm in concept_synth.bootstrap.
add_repo_root(__file__)

import numpy as np

try:
    import matplotlib

    matplotlib.use("Agg")  # Non-interactive backend
    import matplotlib.pyplot as plt

    HAS_MATPLOTLIB = True
except ImportError:
    HAS_MATPLOTLIB = False
    print("Warning: matplotlib not available, skipping plots")

try:
    import pandas as pd

    HAS_PANDAS = True
except ImportError:
    HAS_PANDAS = False
    print("Warning: pandas not available")

from concept_synth.analysis.schema import EvalRecord, load_records

# Model display names and colors
# Keys in all three tables are the raw model identifiers as they appear in the
# "model" column of the evaluation records.

# Short human-readable labels used for plot legends/titles and LaTeX table rows.
MODEL_DISPLAY_NAMES = {
    "grok4": "Grok4",
    "gpt-5.2": "GPT-5.2",
    "grok4.1fast": "Grok4.1f",
    "gemini-3-pro-preview": "Gemini 3",
    "deepseek-reasoner": "DSR",
    "claude-opus-4-5-20251101": "Opus 4.5",
    "hermes4": "Hermes4",
    "gpt-4o": "GPT-4o",
}

# Hex colors passed to matplotlib so each model keeps the same color in every figure.
MODEL_COLORS = {
    "grok4": "#1f77b4",
    "gpt-5.2": "#ff7f0e",
    "grok4.1fast": "#2ca02c",
    "gemini-3-pro-preview": "#d62728",
    "deepseek-reasoner": "#9467bd",
    "claude-opus-4-5-20251101": "#8c564b",
    "hermes4": "#e377c2",
    "gpt-4o": "#7f7f7f",
}

# Model order for consistent plotting (by expected performance)
# The plotting functions iterate over this list to decide the order of legend
# entries and histogram panels.
MODEL_ORDER = [
    "grok4",
    "gpt-5.2",
    "grok4.1fast",
    "gemini-3-pro-preview",
    "deepseek-reasoner",
    "claude-opus-4-5-20251101",
    "hermes4",
    "gpt-4o",
]


def compute_budget_curve(
    df: "pd.DataFrame", model: str, max_delta: int = 100, step: int = 1
) -> Tuple[List[int], List[float]]:
    """
    Compute the Acc@(+Δ) curve for a model.

    For every budget Δ in ``range(0, max_delta + 1, step)`` the accuracy is the
    fraction of *all* of the model's records that are valid and whose
    ``ast_delta`` is at most Δ. Records with a missing ``ast_delta`` never
    count as within budget.

    Args:
        df: DataFrame with evaluation records; the columns "model", "valid"
            and "ast_delta" are read.
        model: Model name to select in the "model" column.
        max_delta: Maximum delta to compute (inclusive).
        step: Step size for delta values.

    Returns:
        (deltas, accuracies) tuple of equal-length lists; both lists are empty
        when the model has no records.
    """
    model_df = df[df["model"] == model]
    total = len(model_df)

    if total == 0:
        return [], []

    # Loop-invariant work, done once: keep only the AST deltas of valid rows.
    # Dropping missing deltas is equivalent to treating them as +inf, because
    # such rows can never satisfy ``ast_delta <= delta``.
    valid_deltas = model_df.loc[
        model_df["valid"] == True, "ast_delta"  # noqa: E712 - element-wise comparison
    ].dropna()

    deltas = list(range(0, max_delta + 1, step))
    accuracies = [float((valid_deltas <= delta).sum()) / total for delta in deltas]

    return deltas, accuracies


def compute_parsimony_metrics(df: "pd.DataFrame") -> Dict[str, Dict[str, float]]:
    """
    Compute parsimony gap metrics for each model.

    All four metrics are fractions of the model's *total* record count:

    - ``validity``: records marked valid.
    - ``acc_25``: valid records with ``ast_delta <= 25`` (a missing
      ``ast_delta`` is treated as over budget).
    - ``gap``: ``validity - acc_25``.
    - ``bloat_pct``: valid records with ``ast_delta > 25`` (a missing
      ``ast_delta`` is not counted as bloat). Note the denominator is all
      records of the model, not only the valid ones.

    Args:
        df: DataFrame with evaluation records; the columns "model", "valid"
            and "ast_delta" are read.

    Returns:
        dict: model -> {validity, acc_25, gap, bloat_pct}, with plain Python
        floats as values, in order of each model's first appearance in ``df``.
    """
    # Row-level masks are computed once for the whole frame and then
    # restricted per model, so every count uses the same definition of "valid".
    is_valid = df["valid"] == True  # noqa: E712 - element-wise comparison
    ast_delta = df["ast_delta"]
    within_budget = is_valid & (ast_delta.fillna(float("inf")) <= 25)
    over_budget = is_valid & (ast_delta.fillna(0) > 25)

    metrics: Dict[str, Dict[str, float]] = {}

    for model in df["model"].unique():
        rows = df["model"] == model
        total = int(rows.sum())

        if total == 0:
            # A missing (NaN) model label never compares equal to itself.
            continue

        validity = int((rows & is_valid).sum()) / total
        acc_25 = int((rows & within_budget).sum()) / total

        metrics[model] = {
            "validity": validity,
            "acc_25": acc_25,
            "gap": validity - acc_25,
            # With no valid rows the count is 0, so no special case is needed.
            "bloat_pct": int((rows & over_budget).sum()) / total,
        }

    return metrics


def plot_budget_curves(
    df: "pd.DataFrame",
    task: str,
    output_path: Path,
    max_delta: int = 100,
    figsize: Tuple[int, int] = (10, 6),
    y_max: Optional[int] = None,
) -> None:
    """
    Plot Acc@(+Δ) curves for all models and save the figure.

    Does nothing when matplotlib is unavailable. Models listed in MODEL_ORDER
    are drawn first, in that order; any other model present in ``df`` is drawn
    afterwards (in order of first appearance) using its raw name as the label
    and matplotlib's default color cycle, instead of being silently omitted.

    Args:
        df: DataFrame with evaluation records.
        task: Task identifier; used in the title and to pick the default
            Y-axis limit.
        output_path: Destination file for the figure; its parent directory is
            created if needed.
        max_delta: Largest AST budget shown on the X axis.
        figsize: Matplotlib figure size in inches.
        y_max: Upper Y-axis limit in percent; defaults to 60 for task "fo" and
            100 otherwise.
    """
    if not HAS_MATPLOTLIB:
        return

    # Build the membership set once rather than scanning the array per model.
    present = list(df["model"].dropna().unique())
    present_set = set(present)
    known = set(MODEL_ORDER)
    models = [m for m in MODEL_ORDER if m in present_set]
    models.extend(m for m in present if m not in known)

    # Task-specific Y-axis limits: FO tops at 60%, EC tops at 100%
    if y_max is None:
        y_max = 60 if task == "fo" else 100

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=figsize)
    try:
        has_curve = False
        for model in models:
            deltas, accs = compute_budget_curve(df, model, max_delta)
            if not deltas:
                continue
            ax.plot(
                deltas,
                [a * 100 for a in accs],
                label=MODEL_DISPLAY_NAMES.get(model, model),
                color=MODEL_COLORS.get(model),
                linewidth=2,
            )
            has_curve = True

        ax.set_xlabel("AST Budget (Δ = pred - gold)", fontsize=12)
        ax.set_ylabel("Accuracy (%)", fontsize=12)
        ax.set_title(f"{task.upper()} v1: Budgeted Accuracy Curves", fontsize=14)
        if has_curve:
            # legend() emits a warning when there is nothing to list.
            ax.legend(loc="lower right", fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.set_xlim(0, max_delta)
        ax.set_ylim(0, y_max)

        # Add vertical lines at key thresholds
        label_y = y_max - 3  # Position labels near top
        for delta, label in [(0, "gold"), (10, "+10"), (25, "+25")]:
            ax.axvline(x=delta, color="gray", linestyle="--", alpha=0.5)
            ax.text(delta + 1, label_y, label, fontsize=9, color="gray")

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    print(f"Saved budget curve plot to {output_path}")


def plot_ast_delta_histogram(
    df: "pd.DataFrame", task: str, output_path: Path, figsize: Tuple[int, int] = (12, 6)
) -> None:
    """
    Plot the AST delta distribution of valid solutions, one panel per model.

    Does nothing when matplotlib is unavailable or when no valid record has an
    ``ast_delta``. Panels are laid out on a two-row grid. Models listed in
    MODEL_ORDER come first, in that order; any other model present in the data
    follows (in order of first appearance) instead of being silently omitted.
    Each panel is annotated with the share of that model's valid solutions
    whose ``ast_delta`` exceeds 25.

    Args:
        df: DataFrame with evaluation records.
        task: Task identifier; used in the figure title and messages.
        output_path: Destination file for the figure; its parent directory is
            created if needed.
        figsize: Matplotlib figure size in inches.
    """
    if not HAS_MATPLOTLIB:
        return

    # Filter to valid solutions with ast_delta
    valid_df = df[(df["valid"] == True) & (df["ast_delta"].notna())]  # noqa: E712

    if len(valid_df) == 0:
        print(f"No valid solutions with ast_delta for {task}")
        return

    present = list(valid_df["model"].dropna().unique())
    present_set = set(present)
    known = set(MODEL_ORDER)
    models = [m for m in MODEL_ORDER if m in present_set]
    models.extend(m for m in present if m not in known)
    n_models = len(models)

    if n_models == 0:
        return

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # squeeze=False always yields a 2-D array of Axes, so flattening is safe
    # for any model count (including exactly one model, where the default
    # squeezing would otherwise hand back a 1-D array of two Axes).
    n_cols = (n_models + 1) // 2
    fig, axes_grid = plt.subplots(2, n_cols, figsize=figsize, squeeze=False)
    axes = axes_grid.flatten()

    try:
        for ax, model in zip(axes, models):
            deltas = valid_df.loc[valid_df["model"] == model, "ast_delta"].to_numpy(dtype=float)

            # Histogram: 5-wide bins covering at least [-20, 60] and all data
            bins = np.arange(min(-20, deltas.min()), max(60, deltas.max()) + 5, 5)
            ax.hist(
                deltas,
                bins=bins,
                color=MODEL_COLORS.get(model, "blue"),
                alpha=0.7,
                edgecolor="black",
            )

            # Add vertical lines
            ax.axvline(x=0, color="green", linestyle="--", linewidth=2, label="gold")
            ax.axvline(x=25, color="red", linestyle="--", linewidth=2, label="+25")

            # Bloat annotation: share of this model's valid solutions over +25
            bloat_pct = (deltas > 25).sum() / len(deltas) * 100
            ax.text(
                0.95,
                0.95,
                f"Bloat: {bloat_pct:.1f}%",
                transform=ax.transAxes,
                ha="right",
                va="top",
                fontsize=10,
                color="red",
            )

            ax.set_title(MODEL_DISPLAY_NAMES.get(model, model), fontsize=11)
            ax.set_xlabel("AST Δ")
            ax.set_ylabel("Count")
            ax.grid(True, alpha=0.3)

        # Hide unused subplots
        for ax in axes[n_models:]:
            ax.set_visible(False)

        fig.suptitle(f"{task.upper()} v1: AST Delta Distribution (Valid Solutions)", fontsize=14)
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    print(f"Saved AST delta histogram to {output_path}")


def generate_parsimony_table(df: "pd.DataFrame", task: str, output_path: Path) -> str:
    """
    Generate a LaTeX table of parsimony gap metrics and write it to disk.

    Rows are sorted by Acc@+25, descending. In each column the best value is
    set in bold (ties within 0.001 are all bold): highest for Validity and
    Acc@+25, lowest for Gap and Bloat. For Gap and Bloat only models with
    non-zero validity are eligible, since a model with no valid solutions has
    a trivially zero gap and bloat.

    Args:
        df: DataFrame with evaluation records.
        task: Task identifier; used in the caption, label and header comment.
        output_path: Destination ``.tex`` file; its parent directory is created
            if needed.

    Returns:
        The LaTeX source that was written.
    """
    metrics = compute_parsimony_metrics(df)

    # Sort by acc_25 descending
    sorted_models = sorted(metrics, key=lambda m: metrics[m]["acc_25"], reverse=True)

    # Best value per column; ``default`` keeps an empty metrics dict from
    # raising (the table then simply has no rows).
    best_validity = max((metrics[m]["validity"] for m in sorted_models), default=0.0)
    best_acc_25 = max((metrics[m]["acc_25"] for m in sorted_models), default=0.0)
    # For gap and bloat, lower is better (but only among models with non-zero validity)
    valid_models = [m for m in sorted_models if metrics[m]["validity"] > 0]
    best_gap = min((metrics[m]["gap"] for m in valid_models), default=0.0)
    best_bloat = min((metrics[m]["bloat_pct"] for m in valid_models), default=0.0)

    # Raw model ids used as a fallback label may contain LaTeX special
    # characters (most commonly "_"); curated display names are left as-is.
    latex_escapes = str.maketrans(
        {
            "\\": "\\textbackslash{}",
            "&": "\\&",
            "%": "\\%",
            "$": "\\$",
            "#": "\\#",
            "_": "\\_",
            "{": "\\{",
            "}": "\\}",
            "~": "\\textasciitilde{}",
            "^": "\\textasciicircum{}",
        }
    )

    def fmt_val(val: float, best: float, eligible: bool = True) -> str:
        """Format a fraction as a percentage, bold when it is an eligible best."""
        pct = f"{val*100:.1f}\\%"
        is_best = eligible and abs(val - best) < 0.001
        return f"\\textbf{{{pct}}}" if is_best else pct

    lines = [
        f"% {task.upper()} v1 Parsimony Gap Metrics",
        "\\begin{table}[t]",
        "\\centering",
        f"\\caption{{{task.upper()} v1 parsimony gap: validity vs budgeted accuracy.}}",
        f"\\label{{tab:{task}_parsimony}}",
        "\\small",
        "\\begin{tabular}{@{}lrrrr@{}}",
        "\\toprule",
        "Model & Validity & Acc@+25 & Gap & Bloat\\% \\\\",
        "\\midrule",
    ]

    for model in sorted_models:
        m = metrics[model]
        if model in MODEL_DISPLAY_NAMES:
            display = MODEL_DISPLAY_NAMES[model]
        else:
            display = str(model).translate(latex_escapes)
        has_valid = m["validity"] > 0
        validity_str = fmt_val(m["validity"], best_validity)
        acc_25_str = fmt_val(m["acc_25"], best_acc_25)
        gap_str = fmt_val(m["gap"], best_gap, eligible=has_valid)
        bloat_str = fmt_val(m["bloat_pct"], best_bloat, eligible=has_valid)
        lines.append(f"{display} & {validity_str} & {acc_25_str} & {gap_str} & {bloat_str} \\\\")

    lines.extend(["\\bottomrule", "\\end{tabular}", "\\end{table}"])

    table_content = "\n".join(lines)

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(table_content)

    print(f"Saved parsimony table to {output_path}")
    return table_content


def run_budget_analysis(
    task: str,
    records_path: Path,
    outdir: Path,
    figures_dir: Optional[Path] = None,
    tables_dir: Optional[Path] = None,
    verbose: bool = True,
) -> Dict[str, Any]:
    """
    Run the full budget curve analysis for one task.

    Loads the evaluation records, writes the budget-curve and AST-delta
    histogram figures (when matplotlib is available), writes the LaTeX
    parsimony table, and saves a JSON summary into ``outdir``.

    Args:
        task: 'fo' or 'ec'
        records_path: Path to eval_records.jsonl
        outdir: Output directory for artifacts
        figures_dir: Directory for PDF figures (default: figures/induction)
        tables_dir: Directory for LaTeX tables (default: tables/induction)
        verbose: Print progress

    Returns:
        Summary dict with key metrics, or an empty dict when pandas is missing
        or no records were loaded.
    """
    if not HAS_PANDAS:
        print("Error: pandas required for budget analysis")
        return {}

    # --- Load records -----------------------------------------------------
    if verbose:
        print(f"Loading records from {records_path}...")

    df = load_records([records_path])
    n_records = len(df)

    if n_records == 0:
        print(f"No records found in {records_path}")
        return {}

    if verbose:
        print(f"Loaded {n_records} records")
        print(f"Models: {df['model'].unique().tolist()}")

    # --- Prepare output directories ---------------------------------------
    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    figures_dir = Path("figures/induction") if figures_dir is None else Path(figures_dir)
    figures_dir.mkdir(parents=True, exist_ok=True)

    tables_dir = Path("tables/induction") if tables_dir is None else Path(tables_dir)
    tables_dir.mkdir(parents=True, exist_ok=True)

    # --- Figures (skipped entirely without matplotlib) --------------------
    if HAS_MATPLOTLIB:
        curve_pdf = figures_dir / f"{task}_budget_curve.pdf"
        plot_budget_curves(df, task, curve_pdf)

        hist_pdf = figures_dir / f"{task}_ast_delta_hist.pdf"
        plot_ast_delta_histogram(df, task, hist_pdf)

    # --- LaTeX table -------------------------------------------------------
    generate_parsimony_table(df, task, tables_dir / f"{task}_parsimony_gap.tex")

    # --- Summary -----------------------------------------------------------
    metrics = compute_parsimony_metrics(df)
    summary = {
        "task": task,
        "total_records": n_records,
        "models": list(metrics),
        "metrics": metrics,
    }

    summary_path = outdir / "budget_analysis_summary.json"
    with open(summary_path, "w") as f:
        json.dump(summary, f, indent=2)

    if not verbose:
        return summary

    print(f"\nSummary saved to {summary_path}")
    print("\nParsimony Gap Summary:")
    print("-" * 60)
    # Best Acc@+25 first; ties keep their original relative order.
    ranked = sorted(metrics.items(), key=lambda item: item[1]["acc_25"], reverse=True)
    for model, m in ranked:
        display = MODEL_DISPLAY_NAMES.get(model, model)
        print(
            f"  {display:20s}: Validity={m['validity']*100:5.1f}%, "
            f"Acc@+25={m['acc_25']*100:5.1f}%, Gap={m['gap']*100:5.1f}%"
        )

    return summary


def main(argv: Optional[List[str]] = None) -> int:
    """
    Command-line entry point.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 when the analysis produced a summary, 1 when
        ``run_budget_analysis`` returned an empty result (which it does after
        printing why it could not run).
    """
    parser = argparse.ArgumentParser(
        description="Generate budget curves and bloat histograms",
    )
    parser.add_argument(
        "--task", "-t", required=True, choices=["fo", "ec"], help="Task type: fo or ec"
    )
    parser.add_argument("--records", "-r", required=True, help="Path to eval_records.jsonl")
    parser.add_argument("--outdir", "-o", required=True, help="Output directory for artifacts")
    parser.add_argument("--figures-dir", help="Directory for PDF figures")
    parser.add_argument("--tables-dir", help="Directory for LaTeX tables")
    parser.add_argument("--quiet", "-q", action="store_true", help="Suppress progress output")

    args = parser.parse_args(argv)

    summary = run_budget_analysis(
        task=args.task,
        records_path=Path(args.records),
        outdir=Path(args.outdir),
        figures_dir=Path(args.figures_dir) if args.figures_dir else None,
        tables_dir=Path(args.tables_dir) if args.tables_dir else None,
        verbose=not args.quiet,
    )

    # Surface failure to the shell instead of always exiting with status 0.
    return 0 if summary else 1


if __name__ == "__main__":
    sys.exit(main())
