#!/usr/bin/env python3
"""
Generalization vs Bloat Bins Analysis

Extends the compact vs bloated analysis to multiple AST-delta bins
(AST_delta = predicted AST - gold AST):
- Bin1: AST_delta <= +1 (compact)
- Bin2: +2..+25 (moderate)
- Bin3: > +25 (bloated)

Reuses the per-instance results of the existing holdout-world analysis (read from its JSON output; no new worlds are generated here) to measure generalization.

Outputs:
- CSV: generalization_vs_bloat_bins_{fullobs,ci}.csv
- LaTeX: tab_fullobs_generalization_vs_bloat_bins.tex, tab_ci_generalization_vs_bloat_bins.tex
- Figures: fig_fullobs_generalization_vs_bloat_bins.pdf, fig_ci_generalization_vs_bloat_bins.pdf
"""

import argparse
import csv
import json
import os
import random
import sys
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# Bootstrap path: make the `concept_synth` package importable when this file is
# executed directly as a script rather than imported from an installed package.
try:
    from concept_synth.bootstrap import add_repo_root
except ModuleNotFoundError:
    # Walk upward from this file's absolute path until a path component named
    # "concept_synth" is found, then prepend that component's parent directory
    # to sys.path (once) so the package import below can resolve.
    _path = os.path.abspath(__file__)
    while True:
        parent = os.path.dirname(_path)
        if os.path.basename(_path) == "concept_synth":
            if parent not in sys.path:
                sys.path.insert(0, parent)
            break
        if parent == _path:
            # Reached the filesystem root without finding "concept_synth";
            # stop searching and let the import below raise if it still fails.
            break
        _path = parent
    from concept_synth.bootstrap import add_repo_root
# NOTE(review): add_repo_root is defined in concept_synth.bootstrap (not visible
# here); presumably it puts the repository root on sys.path — confirm.
add_repo_root(__file__)

from concept_synth.io_utils import load_from_yaml

# Optional plotting support: figure generation is skipped when matplotlib is
# not installed (see HAS_MATPLOTLIB checks below).
try:
    import matplotlib

    # The Agg backend must be selected before pyplot is imported so figures
    # can be rendered without an interactive display.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
except ImportError:
    HAS_MATPLOTLIB = False
else:
    HAS_MATPLOTLIB = True

# Model display names and order
MODEL_DISPLAY_NAMES = {
    "grok4": "Grok4",
    "gpt-5.2": "GPT-5.2",
    "grok4.1fast": "Grok4.1f",
    "gemini-3-pro-preview": "Gemini3",
    "deepseek-reasoner": "DSR",
    "claude-opus-4-5-20251101": "Opus4.5",
    "hermes4": "Hermes4",
    "gpt-4o": "GPT-4o",
}

MODEL_ORDER = [
    "grok4",
    "gpt-5.2",
    "grok4.1fast",
    "gemini-3-pro-preview",
    "deepseek-reasoner",
    "claude-opus-4-5-20251101",
    "hermes4",
    "gpt-4o",
]

MODEL_COLORS = {
    "grok4": "#1f77b4",
    "gpt-5.2": "#ff7f0e",
    "grok4.1fast": "#2ca02c",
    "gemini-3-pro-preview": "#d62728",
    "deepseek-reasoner": "#9467bd",
    "claude-opus-4-5-20251101": "#8c564b",
    "hermes4": "#e377c2",
    "gpt-4o": "#7f7f7f",
}

# AST delta bins (3 bins: compact, moderate, extreme)
AST_BINS = [
    ("<=+1", -1000, 1),
    ("+2..+25", 2, 25),
    (">+25", 26, 10000),
]


@dataclass
class BinMetrics:
    """Aggregated holdout metrics for one model within one AST-delta bin."""

    # Bin label, as returned by get_bin_name.
    bin_name: str
    # Number of train-correct instances that fell into this bin.
    n_train_correct: int
    # Mean holdout exact-match rate over the bin's instances (a fraction).
    mean_holdout_exact: float
    # Mean holdout false-positive / false-negative rates; left as None when
    # no instance in the bin reported them.
    mean_holdout_fp: Optional[float] = field(default=None)
    mean_holdout_fn: Optional[float] = field(default=None)


@dataclass
class ModelBinResults:
    """All per-bin metrics for a single model on a single task."""

    # Model key as found in the holdout data (e.g. "grok4").
    model: str
    # Task tag (e.g. "fo" or "ci").
    task: str
    # Bin name -> metrics; bins without any instances are simply absent.
    bins: Dict[str, BinMetrics]


def get_bin_name(
    ast_delta: float,
    bins: Optional[List[Tuple[str, float, float]]] = None,
) -> str:
    """Return the name of the AST-delta bin that ``ast_delta`` falls into.

    ``bins`` is a list of ``(name, low, high)`` tuples in ascending, contiguous
    order (``AST_BINS`` by default). A value goes to the first bin whose upper
    bound ``high`` it does not exceed, so each bin covers ``(previous high,
    high]``. For integer deltas inside the configured range this matches the
    inclusive ``low <= x <= high`` test, and every other value is binned
    sensibly too:

    * values below the first bin's ``low`` go to the first bin;
    * non-integer values between bins (e.g. ``1.5``) go to the next bin up;
    * values above the last ``high`` go to the last (open-ended) bin.

    Args:
        ast_delta: Predicted AST minus gold AST.
        bins: Optional bin specification; defaults to ``AST_BINS``.

    Returns:
        The matching bin name.

    Raises:
        ValueError: If ``bins`` is empty.
    """
    spec = AST_BINS if bins is None else bins
    if not spec:
        raise ValueError("bins must contain at least one (name, low, high) entry")
    for name, _low, high in spec:
        if ast_delta <= high:
            return name
    # Above every upper bound: the last bin is treated as open-ended.
    return spec[-1][0]


def load_holdout_data(holdout_path: Path) -> Dict[str, Any]:
    """Load a precomputed holdout-analysis JSON file.

    Args:
        holdout_path: Path to the holdout JSON file.

    Returns:
        The parsed JSON document, or an empty dict when ``holdout_path`` does
        not exist (downstream code treats an empty dict as "no data").
    """
    if not holdout_path.exists():
        return {}

    # JSON text is UTF-8 by specification; decode explicitly so loading does
    # not depend on the platform's locale encoding.
    with open(holdout_path, encoding="utf-8") as f:
        return json.load(f)


def compute_bin_metrics_from_holdout(
    holdout_data: Dict[str, Any], task: str
) -> Dict[str, ModelBinResults]:
    """
    Aggregate per-instance holdout results into per-model, per-bin metrics.

    Instance records are taken from ``holdout_data["results"]``, or from
    ``holdout_data["instance_results"]`` when the "results" key is absent.
    A record is kept only if it has a non-empty ``model``, a truthy
    ``train_correct`` and a non-None ``ast_delta``; kept records are grouped
    by ``get_bin_name(ast_delta)``.

    Per bin, each mean averages only the non-None values of
    ``holdout_exact_match_rate``, ``holdout_mean_fp_rate`` and
    ``holdout_mean_fn_rate``. ``n_train_correct`` counts every kept record,
    whether or not it carries those values.

    NOTE(review): when no exact-match values are present the mean defaults to
    0.0 (FP/FN stay None), so downstream outputs cannot distinguish "no data"
    from a genuine 0% rate — confirm this is intended.

    Args:
        holdout_data: Parsed holdout JSON.
        task: Task tag stored on each result (e.g. "fo" or "ci").

    Returns:
        Model key -> ModelBinResults, with models and bins in first-seen
        order; an empty dict when there are no instance records.
    """
    fallback_records = holdout_data.get("instance_results", [])
    records = holdout_data.get("results", fallback_records)
    if not records:
        return {}

    def _mean(values: List[float], default: Optional[float]) -> Optional[float]:
        # Arithmetic mean of ``values``; ``default`` when the list is empty.
        return sum(values) / len(values) if values else default

    # model -> bin name -> list of (exact, fp, fn) triples
    grouped: Dict[str, Dict[str, List[Tuple[Any, Any, Any]]]] = defaultdict(
        lambda: defaultdict(list)
    )

    for record in records:
        model = record.get("model", "")
        # Generalization is only measured on train-correct predictions.
        if not model or not record.get("train_correct", False):
            continue
        ast_delta = record.get("ast_delta")
        if ast_delta is None:
            continue
        grouped[model][get_bin_name(ast_delta)].append(
            (
                record.get("holdout_exact_match_rate"),
                record.get("holdout_mean_fp_rate"),
                record.get("holdout_mean_fn_rate"),
            )
        )

    summary: Dict[str, ModelBinResults] = {}
    for model, per_bin in grouped.items():
        metrics_by_bin: Dict[str, BinMetrics] = {}
        for bin_name, triples in per_bin.items():
            exacts = [t[0] for t in triples if t[0] is not None]
            fps = [t[1] for t in triples if t[1] is not None]
            fns = [t[2] for t in triples if t[2] is not None]
            metrics_by_bin[bin_name] = BinMetrics(
                bin_name=bin_name,
                n_train_correct=len(triples),
                mean_holdout_exact=_mean(exacts, 0.0),
                mean_holdout_fp=_mean(fps, None),
                mean_holdout_fn=_mean(fns, None),
            )
        summary[model] = ModelBinResults(model=model, task=task, bins=metrics_by_bin)

    return summary


def save_csv(results: Dict[str, ModelBinResults], output_path: Path) -> None:
    """Write per-model, per-bin metrics to a CSV file.

    One row is written for every (model, bin) pair: models in ``MODEL_ORDER``
    (models missing from ``results`` are skipped) and bins in ``AST_BINS``
    order. A bin with no metrics gets ``n_train_correct=0`` and blank metric
    cells. Means are formatted with four decimals; a mean that is None is
    written as a blank cell, while a genuine 0.0 is written as ``0.0000``.

    Args:
        results: Mapping from model key to its per-bin results.
        output_path: Destination CSV path; parent directories are created.
    """

    def _fmt(value: Optional[float]) -> str:
        # Blank only for a missing value. A truthiness test would also blank a
        # legitimate 0.0 rate (e.g. a bin with no false positives at all).
        return "" if value is None else f"{value:.4f}"

    output_path.parent.mkdir(parents=True, exist_ok=True)

    fieldnames = [
        "model",
        "bin",
        "n_train_correct",
        "mean_holdout_exact",
        "mean_holdout_fp",
        "mean_holdout_fn",
    ]

    rows = []
    for model in MODEL_ORDER:
        if model not in results:
            continue

        model_results = results[model]

        for bin_name, _, _ in AST_BINS:
            bin_metrics = model_results.bins.get(bin_name)

            if bin_metrics is None:
                rows.append(
                    {
                        "model": model,
                        "bin": bin_name,
                        "n_train_correct": 0,
                        "mean_holdout_exact": "",
                        "mean_holdout_fp": "",
                        "mean_holdout_fn": "",
                    }
                )
            else:
                rows.append(
                    {
                        "model": model,
                        "bin": bin_name,
                        "n_train_correct": bin_metrics.n_train_correct,
                        "mean_holdout_exact": _fmt(bin_metrics.mean_holdout_exact),
                        "mean_holdout_fp": _fmt(bin_metrics.mean_holdout_fp),
                        "mean_holdout_fn": _fmt(bin_metrics.mean_holdout_fn),
                    }
                )

    # newline="" is required by the csv module so it controls line endings.
    with open(output_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def generate_latex_table(
    results: Dict[str, ModelBinResults], task_name: str, output_path: Path
) -> None:
    """Write a LaTeX table of holdout exact-match rate per AST-delta bin.

    Rows are the models of ``MODEL_ORDER`` present in ``results``, labelled
    with their display names. Each bin cell reads ``<rate %> (<n>)``, with the
    rate given to one decimal place. A bin without train-correct instances
    shows ``---`` instead. The column header and caption are written out
    literally for the three bins ``<=+1``, ``+2..+25`` and ``>+25``.

    Args:
        results: Mapping from model key to its per-bin results.
        task_name: Task display name, used in the caption and (lower-cased)
            in the table label.
        output_path: Destination .tex file; parent directories are created.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    def _format_cell(metrics: Any) -> str:
        # "---" marks a bin without any train-correct instances.
        if metrics is None or metrics.n_train_correct == 0:
            return "---"
        return f"{metrics.mean_holdout_exact * 100:.1f} ({metrics.n_train_correct})"

    preamble = [
        f"% {task_name} Generalization vs Bloat Bins Table (auto-generated)",
        "\\begin{table}[t]",
        "\\centering",
        f"\\caption{{\\textbf{{{task_name} generalization by AST delta bin.}}",
        "For train-correct predictions, we report holdout exact-match rate (\\%).",
        "Bins: $\\leq$+1 (compact), +2..+25 (moderate), $>$+25 (bloated).",
        "--- indicates no train-correct instances in that bin.}",
        f"\\label{{tab:{task_name.lower()}_generalization_bins}}",
        "\\small",
        "\\begin{tabular}{@{}lrrr@{}}",
        "\\toprule",
        "Model & $\\leq$+1 & +2..+25 & $>$+25 \\\\",
        "\\midrule",
    ]

    table_rows = []
    for model in MODEL_ORDER:
        if model not in results:
            continue
        model_bins = results[model].bins
        cells = [MODEL_DISPLAY_NAMES.get(model, model)]
        cells += [_format_cell(model_bins.get(bin_name)) for bin_name, _, _ in AST_BINS]
        table_rows.append(" & ".join(cells) + " \\\\")

    closing = ["\\bottomrule", "\\end{tabular}", "\\end{table}"]

    output_path.write_text("\n".join(preamble + table_rows + closing))


def generate_figure(
    results: Dict[str, ModelBinResults],
    task_name: str,
    output_path: Path,
    auto_scale_y: bool = False,
) -> None:
    """Generate line plot of generalization vs bin.

    Draws one line per model (in ``MODEL_ORDER``) of holdout exact-match rate
    (%) against the AST-delta bins of ``AST_BINS``. A bin without
    train-correct instances is plotted as NaN, so matplotlib leaves a gap
    there. No segment is drawn across it that would suggest data in that bin.
    Models with no data in any bin are omitted. Returns without writing
    anything when matplotlib is unavailable.

    Args:
        results: Model bin results
        task_name: Name of the task (FullObs, CI)
        output_path: Path to save the figure; parent directories are created
        auto_scale_y: If True, fit the Y axis to the data range (padded by 5
            points and clamped to 0-100); if False, use 0-100
    """
    if not HAS_MATPLOTLIB:
        return

    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Derive the x axis from AST_BINS so ticks, labels and data always line up
    # with the binning; "<=" is shown as the single glyph "≤".
    bin_names = [name for name, _, _ in AST_BINS]
    x = list(range(len(bin_names)))
    x_labels = [name.replace("<=", "≤") for name in bin_names]

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        all_y_values = []  # every plotted (non-missing) value, for auto-scaling

        for model in (m for m in MODEL_ORDER if m in results):
            model_bins = results[model].bins

            y = []
            for bin_name in bin_names:
                bin_metrics = model_bins.get(bin_name)
                if bin_metrics and bin_metrics.n_train_correct > 0:
                    y.append(bin_metrics.mean_holdout_exact * 100)
                else:
                    y.append(None)

            valid_y = [v for v in y if v is not None]
            if not valid_y:
                continue  # no data in any bin: nothing to draw for this model
            all_y_values.extend(valid_y)

            # NaN breaks the line, so missing bins appear as gaps instead of
            # being bridged by a straight segment between their neighbours.
            y_plot = [float("nan") if v is None else v for v in y]

            ax.plot(
                x,
                y_plot,
                "o-",
                color=MODEL_COLORS.get(model, "#333333"),
                label=MODEL_DISPLAY_NAMES.get(model, model),
                linewidth=2,
                markersize=6,
            )

        ax.set_xlabel("AST Delta Bin", fontsize=12)
        ax.set_ylabel("Holdout Exact Match (%)", fontsize=12)
        ax.set_title(f"{task_name}: Generalization by Formula Complexity", fontsize=14)
        ax.set_xticks(x)
        ax.set_xticklabels(x_labels)

        if auto_scale_y and all_y_values:
            # Auto-scale with some padding
            y_min = max(0, min(all_y_values) - 5)
            y_max = min(100, max(all_y_values) + 5)
            ax.set_ylim(y_min, y_max)
        else:
            ax.set_ylim(0, 100)

        # Only add a legend when at least one line was actually drawn.
        if all_y_values:
            ax.legend(loc="best", fontsize=9)
        ax.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release this specific figure, even if saving fails.
        plt.close(fig)


def _analyze_task(
    holdout_path: Optional[Path],
    *,
    task: str,
    display_name: str,
    file_tag: str,
    outdir: Path,
    tables_dir: Path,
    figures_dir: Path,
    outputs: Dict[str, str],
    auto_scale_y: bool,
    verbose: bool,
) -> Dict[str, Any]:
    """Run the bins analysis for one task and return its per-model summary.

    Writes the CSV, LaTeX table and (if matplotlib is available) figure for
    the task. Each output path is recorded in ``outputs`` under
    ``"<task>_csv"``, ``"<task>_latex"`` and ``"<task>_fig"``.

    Returns:
        ``{model: {bin_name: {"n": ..., "holdout_exact": ...}}}``, or ``{}``
        when the task is disabled, its holdout file is missing, or it yields
        no data.
    """
    if not holdout_path:
        return {}  # task disabled by the caller
    if not holdout_path.exists():
        # Report a wrong/missing path instead of silently producing nothing.
        if verbose:
            print(
                f"[Generalization Bins] Skipping {display_name}: "
                f"holdout file not found: {holdout_path}"
            )
        return {}

    if verbose:
        print(f"[Generalization Bins] Processing {display_name}...")

    bin_results = compute_bin_metrics_from_holdout(load_holdout_data(holdout_path), task)
    if not bin_results:
        if verbose:
            print(f"  No {display_name} holdout data found")
        return {}

    # Save CSV
    csv_path = outdir / f"generalization_vs_bloat_bins_{file_tag}.csv"
    save_csv(bin_results, csv_path)
    outputs[f"{task}_csv"] = str(csv_path)
    if verbose:
        print(f"  Saved CSV: {csv_path}")

    # Save LaTeX
    tex_path = tables_dir / f"tab_{file_tag}_generalization_vs_bloat_bins.tex"
    generate_latex_table(bin_results, display_name, tex_path)
    outputs[f"{task}_latex"] = str(tex_path)
    if verbose:
        print(f"  Saved LaTeX: {tex_path}")

    # Save figure
    if HAS_MATPLOTLIB:
        fig_path = figures_dir / f"fig_{file_tag}_generalization_vs_bloat_bins.pdf"
        generate_figure(bin_results, display_name, fig_path, auto_scale_y=auto_scale_y)
        outputs[f"{task}_fig"] = str(fig_path)
        if verbose:
            print(f"  Saved figure: {fig_path}")

    return {
        model: {
            bin_name: {"n": m.n_train_correct, "holdout_exact": m.mean_holdout_exact}
            for bin_name, m in r.bins.items()
        }
        for model, r in bin_results.items()
    }


def run_generalization_bins_analysis(
    fo_holdout_path: Optional[Path],
    ci_holdout_path: Optional[Path],
    outdir: Path,
    tables_dir: Path,
    figures_dir: Path,
    verbose: bool = True,
) -> Dict[str, Any]:
    """
    Run generalization vs bloat bins analysis.

    FullObs and CI are processed independently. A task is skipped when its
    path is falsy or the file does not exist; a missing file is reported when
    ``verbose`` is set.

    Args:
        fo_holdout_path: Path to FO holdout JSON
        ci_holdout_path: Path to CI holdout JSON
        outdir: Output directory for CSV
        tables_dir: Directory for LaTeX tables
        figures_dir: Directory for figures
        verbose: Print progress

    Returns:
        Summary dict with keys "fo" and "ci" (per-model, per-bin ``n`` and
        ``holdout_exact``; ``{}`` when a task was skipped) and "outputs"
        (written file paths keyed like ``"fo_csv"`` or ``"ci_fig"``).
    """
    results: Dict[str, Any] = {
        "fo": {},
        "ci": {},
        "outputs": {},
    }

    # (task key, display name, file tag, holdout path, auto-scale y).
    # CI auto-scales its Y axis since its values may be compressed.
    tasks = [
        ("fo", "FullObs", "fullobs", fo_holdout_path, False),
        ("ci", "CI", "ci", ci_holdout_path, True),
    ]

    for task, display_name, file_tag, holdout_path, auto_scale_y in tasks:
        results[task] = _analyze_task(
            holdout_path,
            task=task,
            display_name=display_name,
            file_tag=file_tag,
            outdir=outdir,
            tables_dir=tables_dir,
            figures_dir=figures_dir,
            outputs=results["outputs"],
            auto_scale_y=auto_scale_y,
            verbose=verbose,
        )

    return results


def main():
    """Command-line entry point: parse arguments and run the bins analysis."""
    parser = argparse.ArgumentParser(
        description="Generalization vs Bloat Bins Analysis",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # (flags, default, help) for every path-valued option, in --help order.
    path_options = [
        (
            ("--fo-holdout",),
            "concept_synth/artifacts/analysis/v1/holdout/fo_holdout.json",
            "Path to FO holdout JSON",
        ),
        (
            ("--ci-holdout",),
            "concept_synth/artifacts/analysis/v1/holdout/ci_holdout.json",
            "Path to CI holdout JSON",
        ),
        (
            ("--outdir", "-o"),
            "concept_synth/artifacts/analysis/v1/generalization_bins",
            "Output directory for CSV",
        ),
        (
            ("--tables-dir",),
            "concept_synth/paper/auto/tables",
            "Directory for LaTeX tables",
        ),
        (
            ("--figures-dir",),
            "concept_synth/paper/auto/figs",
            "Directory for figures",
        ),
    ]
    for flags, default, help_text in path_options:
        parser.add_argument(*flags, default=default, help=help_text)
    parser.add_argument("--quiet", "-q", action="store_true", help="Suppress progress output")

    args = parser.parse_args()

    def _optional_path(value):
        # An empty string disables the corresponding task.
        return Path(value) if value else None

    run_generalization_bins_analysis(
        fo_holdout_path=_optional_path(args.fo_holdout),
        ci_holdout_path=_optional_path(args.ci_holdout),
        outdir=Path(args.outdir),
        tables_dir=Path(args.tables_dir),
        figures_dir=Path(args.figures_dir),
        verbose=not args.quiet,
    )


if __name__ == "__main__":
    main()
