#!/usr/bin/env python3
"""
CI Failure Decomposition Analysis

Produces:
1. Stacked bar charts showing fail_mode distribution per model
2. Tables by band (core vs lift_mix)
3. Breakdown by is_lift_hard_gold

Usage:
    python -m concept_synth.analysis.ci_failure_decomposition \
        --records artifacts/analysis/v1/ci/eval_records.jsonl \
        --outdir artifacts/analysis/v1/ci/failure_decomposition
"""

import argparse
import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional

# Bootstrap path: make sure the `concept_synth` package can be imported even
# when it is not already on sys.path.
try:
    from concept_synth.bootstrap import add_repo_root
except ModuleNotFoundError:
    # Walk upwards from this file until a path component named "concept_synth"
    # is found, then put that component's parent directory at the front of
    # sys.path so the package import below can succeed.
    _path = os.path.abspath(__file__)
    while True:
        parent = os.path.dirname(_path)
        if os.path.basename(_path) == "concept_synth":
            if parent not in sys.path:
                sys.path.insert(0, parent)
            break
        # dirname() returned its own input: the filesystem root was reached
        # without finding the package directory, so give up searching.
        if parent == _path:
            break
        _path = parent
    # Retry the import; if the directory was not found above this raises
    # ModuleNotFoundError again and the script aborts.
    from concept_synth.bootstrap import add_repo_root
# NOTE(review): presumably registers the repository root on sys.path —
# confirm against concept_synth.bootstrap.
add_repo_root(__file__)

import numpy as np

# matplotlib is optional: when it is missing HAS_MATPLOTLIB is False and the
# plotting functions in this module return without producing any output.
try:
    import matplotlib

    # Select the non-interactive Agg backend before pyplot is imported.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    HAS_MATPLOTLIB = True
except ImportError:
    HAS_MATPLOTLIB = False

# pandas is optional at import time: when it is missing HAS_PANDAS is False and
# run_ci_failure_analysis prints an error and returns an empty dict.
try:
    import pandas as pd

    HAS_PANDAS = True
except ImportError:
    HAS_PANDAS = False

from concept_synth.analysis.schema import load_records

# Model display names: maps the model identifier found in the records to the
# short label used on chart axes, in LaTeX tables and in the console summary.
# Identifiers missing from this mapping are displayed as-is.
MODEL_DISPLAY_NAMES = {
    "grok4": "Grok4",
    "gpt-5.2": "GPT-5.2",
    "grok4.1fast": "Grok4.1f",
    "gemini-3-pro-preview": "Gemini 3",
    "deepseek-reasoner": "DSR",
    "claude-opus-4-5-20251101": "Opus 4.5",
    "hermes4": "Hermes4",
    "gpt-4o": "GPT-4o",
}

# Preferred ordering of models in charts and tables. Note that the plotting
# and table functions keep only the models listed here whenever at least one
# of them is present in the data; models outside this list are shown only
# when none of the listed models appear.
MODEL_ORDER = [
    "grok4",
    "gpt-5.2",
    "grok4.1fast",
    "gemini-3-pro-preview",
    "deepseek-reasoner",
    "claude-opus-4-5-20251101",
    "hermes4",
    "gpt-4o",
]

# Failure mode colors: bar fill color (hex) for each fail_mode value.
FAIL_MODE_COLORS = {
    "correct": "#2ecc71",  # Green
    "no_fail": "#f39c12",  # Orange
    "yes_fail": "#e74c3c",  # Red
    "parse": "#9b59b6",  # Purple
    "missing": "#95a5a6",  # Gray
}

# The fail_mode values that are counted in distributions, in the order the
# segments are stacked in the bar charts (first entry at the bottom).
FAIL_MODE_ORDER = ["correct", "no_fail", "yes_fail", "parse", "missing"]
# Legend text for each fail_mode value.
FAIL_MODE_LABELS = {
    "correct": "Correct",
    "no_fail": "NO-fail",
    "yes_fail": "YES-fail",
    "parse": "Parse Error",
    "missing": "Missing",
}


def compute_failure_distribution(df: "pd.DataFrame") -> Dict[str, Dict[str, float]]:
    """
    Compute the share of each failure mode for every model in ``df``.

    Reads the ``model`` and ``fail_mode`` columns. Only the modes listed in
    FAIL_MODE_ORDER are counted, so a model's fractions sum to less than 1
    when some of its rows carry any other ``fail_mode`` value.

    Returns: model -> {fail_mode -> fraction}
    """
    distributions = {}

    for model_name in df["model"].unique():
        modes = df.loc[df["model"] == model_name, "fail_mode"]
        n_rows = len(modes)

        # A unique value that matches no row under == (such as NaN) has
        # nothing to normalise by, so it is left out of the result.
        if not n_rows:
            continue

        distributions[model_name] = {
            mode: (modes == mode).sum() / n_rows for mode in FAIL_MODE_ORDER
        }

    return distributions


def compute_failure_by_band(df: "pd.DataFrame") -> Dict[str, Dict[str, Dict[str, float]]]:
    """
    Split ``df`` on its ``band`` column and compute the per-model failure
    mode distribution inside each band.

    Returns: band -> model -> {fail_mode -> fraction}
    """
    return {
        band: compute_failure_distribution(df[df["band"] == band])
        for band in df["band"].unique()
    }


def compute_failure_by_lift_hard(df: "pd.DataFrame") -> Dict[bool, Dict[str, Dict[str, float]]]:
    """
    Compute the per-model failure mode distribution separately for rows whose
    ``is_lift_hard_gold`` column equals True and for rows where it equals
    False. A flag value with no matching rows is omitted from the result.

    Returns: is_lift_hard -> model -> {fail_mode -> fraction}
    """
    breakdown = {}

    for flag in (True, False):
        subset = df[df["is_lift_hard_gold"] == flag]
        if len(subset) == 0:
            continue
        breakdown[flag] = compute_failure_distribution(subset)

    return breakdown


def plot_failure_stacked_bars(
    df: "pd.DataFrame", output_path: Path, figsize: tuple = (12, 6)
) -> None:
    """
    Draw one stacked bar per model showing its failure mode percentages and
    save the figure to ``output_path``.

    Models are laid out in MODEL_ORDER; only when none of the models in the
    data appear in MODEL_ORDER are they instead sorted by descending
    "correct" fraction. Does nothing when matplotlib is unavailable.
    """
    if not HAS_MATPLOTLIB:
        return

    shares = compute_failure_distribution(df)

    ordered = [name for name in MODEL_ORDER if name in shares]
    if not ordered:
        ordered = sorted(shares, key=lambda name: shares[name].get("correct", 0), reverse=True)

    positions = np.arange(len(ordered))
    bar_width = 0.7

    fig, ax = plt.subplots(figsize=figsize)

    # Running top of the stack for every bar; each mode is drawn on top of it.
    stack_base = np.zeros(len(ordered))
    for mode in FAIL_MODE_ORDER:
        heights = [shares[name].get(mode, 0) * 100 for name in ordered]
        ax.bar(
            positions,
            heights,
            bar_width,
            label=FAIL_MODE_LABELS[mode],
            bottom=stack_base,
            color=FAIL_MODE_COLORS[mode],
        )
        stack_base = stack_base + heights

    tick_labels = [MODEL_DISPLAY_NAMES.get(name, name) for name in ordered]

    ax.set_xlabel("Model", fontsize=12)
    ax.set_ylabel("Percentage", fontsize=12)
    ax.set_title("CI v1: Failure Mode Distribution by Model", fontsize=14)
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels, rotation=45, ha="right")
    ax.legend(loc="upper right", fontsize=10)
    ax.set_ylim(0, 100)
    ax.grid(True, alpha=0.3, axis="y")

    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved failure mode plot to {output_path}")


def plot_failure_by_band(df: "pd.DataFrame", output_path: Path, figsize: tuple = (14, 6)) -> None:
    """
    Draw one stacked failure mode chart per band, side by side, and save the
    figure to ``output_path``.

    Bands are taken from the ``band`` column in sorted order. Within each
    panel models follow MODEL_ORDER, falling back to descending "correct"
    fraction only when none of the panel's models are in MODEL_ORDER. Does
    nothing when matplotlib is unavailable or when there are no bands.
    """
    if not HAS_MATPLOTLIB:
        return

    bands = sorted(df["band"].unique())
    if not bands:
        return

    fig, axes = plt.subplots(1, len(bands), figsize=figsize)
    # With a single column plt.subplots returns a bare Axes, not a sequence.
    if len(bands) == 1:
        axes = [axes]

    bar_width = 0.7

    for idx, (ax, band) in enumerate(zip(axes, bands)):
        is_first = idx == 0
        shares = compute_failure_distribution(df[df["band"] == band])

        ordered = [name for name in MODEL_ORDER if name in shares]
        if not ordered:
            ordered = sorted(
                shares, key=lambda name: shares[name].get("correct", 0), reverse=True
            )

        positions = np.arange(len(ordered))
        stack_base = np.zeros(len(ordered))

        for mode in FAIL_MODE_ORDER:
            heights = [shares[name].get(mode, 0) * 100 for name in ordered]
            ax.bar(
                positions,
                heights,
                bar_width,
                # Only the first panel carries labels; the shared legend is
                # built from that panel's handles below.
                label=FAIL_MODE_LABELS[mode] if is_first else "",
                bottom=stack_base,
                color=FAIL_MODE_COLORS[mode],
            )
            stack_base = stack_base + heights

        # Display names are cut to 8 characters to fit the narrower panels.
        short_names = [MODEL_DISPLAY_NAMES.get(name, name)[:8] for name in ordered]

        ax.set_title(f"Band: {band}", fontsize=12)
        ax.set_xticks(positions)
        ax.set_xticklabels(short_names, rotation=45, ha="right", fontsize=9)
        ax.set_ylim(0, 100)
        ax.grid(True, alpha=0.3, axis="y")

        if is_first:
            ax.set_ylabel("Percentage", fontsize=11)

    # Add legend
    legend_handles, legend_labels = axes[0].get_legend_handles_labels()
    fig.legend(
        legend_handles,
        legend_labels,
        loc="upper center",
        ncol=5,
        fontsize=10,
        bbox_to_anchor=(0.5, 1.02),
    )

    fig.suptitle("CI v1: Failure Modes by Band", fontsize=14, y=1.08)
    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved band failure plot to {output_path}")


def generate_failure_table(df: "pd.DataFrame", output_path: Path) -> str:
    """
    Generate LaTeX table for CI failure modes.

    One row is written per model with the percentage of each failure mode.
    In every column the best value is set in bold: the highest "correct"
    share and the lowest share for each of the failure columns (values within
    0.001 of the best are treated as ties and are all bold).

    Models follow MODEL_ORDER; only when none of the models in the data are
    in MODEL_ORDER are they sorted by descending "correct" fraction. When no
    model has any rows the table is still written, with a header and no body
    rows.

    Args:
        df: records with ``model`` and ``fail_mode`` columns.
        output_path: destination ``.tex`` file; its parent directory is
            created if needed.

    Returns:
        The LaTeX source that was written to ``output_path``.
    """
    dist = compute_failure_distribution(df)

    models = [m for m in MODEL_ORDER if m in dist]
    if not models:
        models = sorted(dist.keys(), key=lambda m: dist[m].get("correct", 0), reverse=True)

    # Column order must match the header row below.
    columns = ("correct", "yes_fail", "no_fail", "parse", "missing")

    # Best value per column: higher is better for "correct", lower for the
    # rest. `default=` keeps this from raising when there are no models.
    best = {}
    for col in columns:
        if col == "correct":
            best[col] = max((dist[m].get(col, 0) for m in models), default=0.0)
        else:
            best[col] = min((dist[m].get(col, 1) for m in models), default=1.0)

    def fmt_val(val: float, best_val: float) -> str:
        pct = f"{val*100:.1f}\\%"
        is_best = abs(val - best_val) < 0.001
        return f"\\textbf{{{pct}}}" if is_best else pct

    lines = [
        "% CI v1 Failure Mode Distribution",
        "\\begin{table}[t]",
        "\\centering",
        "\\caption{CI v1 failure mode distribution by model.}",
        "\\label{tab:ci_failure_modes}",
        "\\small",
        "\\begin{tabular}{@{}lrrrrr@{}}",
        "\\toprule",
        "Model & Correct & YES-fail & NO-fail & Parse & Missing \\\\",
        "\\midrule",
    ]

    for model in models:
        d = dist[model]
        # Escape underscores so raw model identifiers are valid in LaTeX text.
        display = str(MODEL_DISPLAY_NAMES.get(model, model)).replace("_", "\\_")
        cells = [fmt_val(d.get(col, 0), best[col]) for col in columns]
        lines.append(f"{display} & " + " & ".join(cells) + " \\\\")

    lines.append("\\bottomrule")
    lines.append("\\end{tabular}")
    lines.append("\\end{table}")

    table_content = "\n".join(lines)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Explicit encoding so the output does not depend on the platform locale.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(table_content)

    print(f"Saved failure table to {output_path}")
    return table_content


def generate_band_failure_table(df: "pd.DataFrame", output_path: Path) -> str:
    """
    Generate LaTeX table for CI failure modes by band.

    A single ``table`` environment is written containing one ``tabular`` per
    band (bands in sorted order). Inside each band the best value of every
    column is set in bold: the highest "correct" share and the lowest share
    for each failure column (values within 0.001 of the best count as ties).

    Models follow MODEL_ORDER; only when none of a band's models are in
    MODEL_ORDER are they sorted by descending "correct" fraction. A band
    without any model rows gets a header-only tabular.

    Args:
        df: records with ``band``, ``model`` and ``fail_mode`` columns.
        output_path: destination ``.tex`` file; its parent directory is
            created if needed.

    Returns:
        The LaTeX source that was written to ``output_path``.
    """
    by_band = compute_failure_by_band(df)
    bands = sorted(by_band.keys())

    # Column order must match the header row below.
    columns = ("correct", "yes_fail", "no_fail", "parse", "missing")

    def fmt_val(val: float, best_val: float) -> str:
        pct = f"{val*100:.1f}\\%"
        is_best = abs(val - best_val) < 0.001
        return f"\\textbf{{{pct}}}" if is_best else pct

    lines = [
        "% CI v1 Failure Modes by Band",
        "\\begin{table}[t]",
        "\\centering",
        "\\caption{CI v1 failure modes by band.}",
        "\\label{tab:ci_failure_by_band}",
        "\\small",
    ]

    for band in bands:
        # Escape underscores for LaTeX; str() tolerates non-string band values.
        band_escaped = str(band).replace("_", "\\_")
        lines.append(f"\\textbf{{Band: {band_escaped}}}")
        lines.append("")
        lines.append("\\begin{tabular}{@{}lrrrrr@{}}")
        lines.append("\\toprule")
        lines.append("Model & Correct & YES-fail & NO-fail & Parse & Missing \\\\")
        lines.append("\\midrule")

        dist = by_band[band]
        models = [m for m in MODEL_ORDER if m in dist]
        if not models:
            models = sorted(dist.keys(), key=lambda m: dist[m].get("correct", 0), reverse=True)

        # Best value per column for this band: higher is better for
        # "correct", lower for the rest. `default=` avoids a ValueError when
        # the band has no models.
        best = {}
        for col in columns:
            if col == "correct":
                best[col] = max((dist[m].get(col, 0) for m in models), default=0.0)
            else:
                best[col] = min((dist[m].get(col, 1) for m in models), default=1.0)

        for model in models:
            d = dist[model]
            # Escape underscores so raw model identifiers are valid in LaTeX.
            display = str(MODEL_DISPLAY_NAMES.get(model, model)).replace("_", "\\_")
            cells = [fmt_val(d.get(col, 0), best[col]) for col in columns]
            lines.append(f"{display} & " + " & ".join(cells) + " \\\\")

        lines.append("\\bottomrule")
        lines.append("\\end{tabular}")
        lines.append("")

    lines.append("\\end{table}")

    table_content = "\n".join(lines)

    # Create the destination directory so the function works when called on
    # its own, matching generate_failure_table.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Explicit encoding so the output does not depend on the platform locale.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(table_content)

    print(f"Saved band failure table to {output_path}")
    return table_content


def run_ci_failure_analysis(
    records_path: Path,
    outdir: Path,
    figures_dir: Optional[Path] = None,
    tables_dir: Optional[Path] = None,
    verbose: bool = True,
) -> Dict[str, Any]:
    """
    Run full CI failure decomposition analysis.

    Loads the records, writes the two figures (when matplotlib is available)
    and the two LaTeX tables, then writes ``ci_failure_summary.json`` into
    ``outdir``.

    Args:
        records_path: path handed to ``load_records``.
        outdir: directory for the JSON summary; created if missing.
        figures_dir: directory for the PDF figures; defaults to
            ``figures/induction`` relative to the working directory.
        tables_dir: directory for the ``.tex`` tables; defaults to
            ``tables/induction`` relative to the working directory.
        verbose: print progress and the per-model overview.

    Returns:
        The summary dict that was written to JSON, or an empty dict when
        pandas is unavailable or no records were loaded.
    """
    if not HAS_PANDAS:
        print("Error: pandas required")
        return {}

    if verbose:
        print(f"Loading records from {records_path}...")

    records = load_records([records_path])
    n_records = len(records)

    if n_records == 0:
        print("No records found")
        return {}

    if verbose:
        print(f"Loaded {n_records} records")

    # Create directories
    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    figures_dir = Path("figures/induction") if figures_dir is None else Path(figures_dir)
    figures_dir.mkdir(parents=True, exist_ok=True)

    tables_dir = Path("tables/induction") if tables_dir is None else Path(tables_dir)
    tables_dir.mkdir(parents=True, exist_ok=True)

    # Generate plots
    if HAS_MATPLOTLIB:
        plot_failure_stacked_bars(records, figures_dir / "ci_failure_modes.pdf")
        plot_failure_by_band(records, figures_dir / "ci_failure_by_band.pdf")

    # Generate tables
    generate_failure_table(records, tables_dir / "ci_failure_modes.tex")
    generate_band_failure_table(records, tables_dir / "ci_failure_by_band.tex")

    # Compute summary
    overall = compute_failure_distribution(records)
    per_band = compute_failure_by_band(records)
    per_lift = compute_failure_by_lift_hard(records)

    summary = {
        "total_records": n_records,
        "overall_distribution": overall,
        "by_band": per_band,
        # Boolean keys are stringified before JSON serialisation.
        "by_lift_hard": {str(flag): value for flag, value in per_lift.items()},
    }

    summary_path = outdir / "ci_failure_summary.json"
    with open(summary_path, "w") as handle:
        json.dump(summary, handle, indent=2)

    if not verbose:
        return summary

    print(f"\nSummary saved to {summary_path}")
    print("\nOverall Failure Distribution:")
    print("-" * 60)

    ranked = sorted(overall.keys(), key=lambda m: overall[m].get("correct", 0), reverse=True)
    for model in ranked:
        shares = overall[model]
        display = MODEL_DISPLAY_NAMES.get(model, model)
        correct_pct = shares.get("correct", 0) * 100
        yes_fail_pct = shares.get("yes_fail", 0) * 100
        no_fail_pct = shares.get("no_fail", 0) * 100
        print(
            f"  {display:20s}: Correct={correct_pct:5.1f}%, "
            f"YES-fail={yes_fail_pct:5.1f}%, "
            f"NO-fail={no_fail_pct:5.1f}%"
        )

    return summary


def main():
    """
    Command-line entry point: parse the arguments and run the CI failure
    decomposition analysis with them.
    """
    cli = argparse.ArgumentParser(description="CI failure decomposition analysis")
    cli.add_argument("--records", "-r", required=True, help="Path to eval_records.jsonl")
    cli.add_argument("--outdir", "-o", required=True, help="Output directory")
    cli.add_argument("--figures-dir", help="Directory for PDF figures")
    cli.add_argument("--tables-dir", help="Directory for LaTeX tables")
    cli.add_argument("--quiet", "-q", action="store_true")

    opts = cli.parse_args()

    # Optional directories stay None when not given so the analysis applies
    # its own defaults.
    figures_dir = Path(opts.figures_dir) if opts.figures_dir else None
    tables_dir = Path(opts.tables_dir) if opts.tables_dir else None

    run_ci_failure_analysis(
        records_path=Path(opts.records),
        outdir=Path(opts.outdir),
        figures_dir=figures_dir,
        tables_dir=tables_dir,
        verbose=not opts.quiet,
    )


if __name__ == "__main__":
    main()
