#!/usr/bin/env python3
"""
Structural Breakdown Analysis

Analyzes performance by:
1. family_id_gold (formula family)
2. is_lift_hard_gold (lift-hard patterns)
3. gold_ast and quantifier_depth_gold bins
4. Top subfamilies by frequency

Usage:
    python -m concept_synth.analysis.structural_breakdown \
        --task fo \
        --records artifacts/analysis/v1/fo/eval_records.jsonl \
        --outdir artifacts/analysis/v1/fo/structural
"""

import argparse
import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# Bootstrap path
# Make the ``concept_synth`` package importable even when this file is run
# directly as a script (i.e. the package's parent directory is not yet on sys.path).
try:
    from concept_synth.bootstrap import add_repo_root
except ModuleNotFoundError:
    # Walk upwards from this file until a path component named "concept_synth"
    # is found, then put that component's parent directory on sys.path.
    _path = os.path.abspath(__file__)
    while True:
        parent = os.path.dirname(_path)
        if os.path.basename(_path) == "concept_synth":
            if parent not in sys.path:
                sys.path.insert(0, parent)
            break
        if parent == _path:
            # dirname() no longer changes the path: filesystem root reached
            # without finding the package directory.
            break
        _path = parent
    # Retry the import; if the walk above found nothing this raises again.
    from concept_synth.bootstrap import add_repo_root
# NOTE(review): add_repo_root is defined elsewhere; presumably it adds the
# repository root to sys.path based on this file's location -- confirm.
add_repo_root(__file__)

import numpy as np

try:
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    HAS_MATPLOTLIB = True
except ImportError:
    HAS_MATPLOTLIB = False

try:
    import pandas as pd

    HAS_PANDAS = True
except ImportError:
    HAS_PANDAS = False

from concept_synth.analysis.schema import load_records

# Short human-readable label for each model id. Used for plot legends and
# tick labels, the LaTeX table header, and the console summary; ids that are
# missing here fall back to the raw model id at the call sites.
MODEL_DISPLAY_NAMES = {
    "grok4": "Grok4",
    "gpt-5.2": "GPT-5.2",
    "grok4.1fast": "Grok4.1f",
    "gemini-3-pro-preview": "Gemini 3",
    "deepseek-reasoner": "DSR",
    "claude-opus-4-5-20251101": "Opus 4.5",
    "hermes4": "Hermes4",
    "gpt-4o": "GPT-4o",
}

# Family display names (abbreviate OTHER)
# Looked up when rendering the LaTeX family table; families without an entry
# are shown under their raw family id.
FAMILY_DISPLAY_NAMES = {
    "OTHER": "oth",
}

# Canonical model ordering for plots and tables. The plotting and table code
# selects models by iterating this list, so a model id that is not listed
# here does not appear in those outputs.
MODEL_ORDER = [
    "grok4",
    "gpt-5.2",
    "grok4.1fast",
    "gemini-3-pro-preview",
    "deepseek-reasoner",
    "claude-opus-4-5-20251101",
    "hermes4",
    "gpt-4o",
]

# Bar color (hex RGB) per model id, used by the family bar plot.
MODEL_COLORS = {
    "grok4": "#1f77b4",
    "gpt-5.2": "#ff7f0e",
    "grok4.1fast": "#2ca02c",
    "gemini-3-pro-preview": "#d62728",
    "deepseek-reasoner": "#9467bd",
    "claude-opus-4-5-20251101": "#8c564b",
    "hermes4": "#e377c2",
    "gpt-4o": "#7f7f7f",
}


def compute_metrics_by_group(
    df: "pd.DataFrame", group_col: str
) -> Dict[str, Dict[str, Dict[str, float]]]:
    """
    Summarize per-model outcome rates within each value of ``group_col``.

    Rows whose group value is missing are ignored. For every remaining
    (group, model) pair the rates are fractions of *all* rows of that pair:

    - ``acc_25``: sum of ``budget_ok_25`` / rows
    - ``validity``: sum of ``valid`` / rows
    - ``bloat_pct``: rows that are valid and have ``ast_delta`` > 25
      (a missing ``ast_delta`` counts as 0) / rows
    - ``count``: number of rows

    Returns: str(group_value) -> model -> {acc_25, validity, bloat_pct, count}
    """
    metrics_by_group: Dict[str, Dict[str, Dict[str, float]]] = {}

    for key in df[group_col].unique():
        # A missing group value cannot be attributed to any bucket.
        if pd.isna(key):
            continue

        rows_in_group = df[df[group_col] == key]
        per_model: Dict[str, Dict[str, float]] = {}
        metrics_by_group[str(key)] = per_model

        for model_name in rows_in_group["model"].unique():
            rows = rows_in_group[rows_in_group["model"] == model_name]
            n_rows = len(rows)
            if n_rows == 0:
                continue

            # Element-wise comparison on purpose: only literal True counts.
            is_valid = rows["valid"] == True  # noqa: E712
            is_oversized = rows["ast_delta"].fillna(0) > 25

            per_model[model_name] = {
                "acc_25": rows["budget_ok_25"].sum() / n_rows,
                "validity": rows["valid"].sum() / n_rows,
                "bloat_pct": (is_valid & is_oversized).sum() / n_rows,
                "count": n_rows,
            }

    return metrics_by_group


def compute_family_metrics(df: "pd.DataFrame") -> Dict[str, Dict[str, Dict[str, float]]]:
    """Break the per-model metrics down by the ``family_id_gold`` column."""
    grouping_column = "family_id_gold"
    return compute_metrics_by_group(df, grouping_column)


def compute_lift_hard_metrics(df: "pd.DataFrame") -> Dict[str, Dict[str, Dict[str, float]]]:
    """Break the per-model metrics down by the ``is_lift_hard_gold`` column."""
    grouping_column = "is_lift_hard_gold"
    return compute_metrics_by_group(df, grouping_column)


def compute_ast_bin_metrics(
    df: "pd.DataFrame", bins: Optional[List[int]] = None
) -> Dict[str, Dict[str, Dict[str, float]]]:
    """
    Compute metrics grouped by bins of the ``gold_ast`` column.

    Args:
        df: Evaluation records. Must contain ``gold_ast`` plus the columns
            read by ``compute_metrics_by_group``. The frame is not modified;
            binning is done on a copy.
        bins: Ascending bin edges handed to ``pd.cut``. When omitted, the
            edges ``[0, 15, 20, 25, 30, 100]`` are used.

    Returns: "<lo>-<hi>" bin label -> model -> {acc_25, validity, bloat_pct, count}

    Note: ``pd.cut`` is called with its defaults, i.e. right-closed
    intervals ``(lo, hi]``. A ``gold_ast`` equal to the first edge, or
    outside the overall edge range, gets no bin (NaN) and is therefore left
    out of the result because ``compute_metrics_by_group`` skips NaN groups.
    """
    # None sentinel instead of a list literal in the signature, so the
    # default edges are never a shared mutable object across calls.
    edges = [0, 15, 20, 25, 30, 100] if bins is None else bins
    labels = [f"{lo}-{hi}" for lo, hi in zip(edges, edges[1:])]

    binned = df.copy()
    binned["ast_bin"] = pd.cut(binned["gold_ast"], bins=edges, labels=labels)
    return compute_metrics_by_group(binned, "ast_bin")


def compute_qd_metrics(df: "pd.DataFrame") -> Dict[str, Dict[str, Dict[str, float]]]:
    """Break the per-model metrics down by the ``quantifier_depth_gold`` column."""
    grouping_column = "quantifier_depth_gold"
    return compute_metrics_by_group(df, grouping_column)


def get_top_subfamilies(df: "pd.DataFrame", n: int = 10) -> List[Tuple[str, int]]:
    """Return the ``n`` most frequent ``subfamily_key_gold`` values.

    Each entry is ``(str(key), count)``, ordered as produced by
    ``Series.value_counts`` (most frequent first).
    """
    frequency = df["subfamily_key_gold"].value_counts()
    leaders = frequency.iloc[:n]

    ranked: List[Tuple[str, int]] = []
    for key, occurrences in zip(leaders.index, leaders.tolist()):
        ranked.append((str(key), int(occurrences)))
    return ranked


def compute_subfamily_metrics(
    df: "pd.DataFrame", top_n: int = 10
) -> Dict[str, Dict[str, Dict[str, float]]]:
    """Compute per-model metrics for the ``top_n`` most frequent subfamilies.

    Rows whose ``subfamily_key_gold`` is not among the top keys are dropped
    before grouping.
    """
    ranked = get_top_subfamilies(df, top_n)
    wanted_keys = [key for key, _count in ranked]

    # Keep only the rows belonging to one of the most frequent subfamilies.
    is_top = df["subfamily_key_gold"].isin(wanted_keys)
    return compute_metrics_by_group(df[is_top], "subfamily_key_gold")


def plot_family_bars(
    df: "pd.DataFrame", task: str, output_path: Path, figsize: tuple = (14, 8)
) -> None:
    """
    Plot Acc@+25 per formula family as grouped bars, one bar per model.

    Args:
        df: Evaluation records passed to ``compute_family_metrics``.
        task: Task id; upper-cased for the figure title.
        output_path: File to write the figure to. Its parent directory is
            created if it does not exist.
        figsize: Figure size forwarded to ``plt.subplots``.

    Returns nothing. The function is a no-op when matplotlib is unavailable,
    when there are no families, or when none of the models in ``MODEL_ORDER``
    appears in the data. A (family, model) pair without data is drawn as a
    zero-height bar.
    """
    if not HAS_MATPLOTLIB:
        return

    family_metrics = compute_family_metrics(df)
    families = sorted(family_metrics.keys())
    if not families:
        return

    models = [m for m in MODEL_ORDER if any(m in family_metrics[f] for f in families)]
    n_models = len(models)
    if n_models == 0:
        return

    # Mirror generate_family_table: make sure the destination exists so a
    # direct call with a fresh directory does not fail inside savefig.
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=figsize)
    try:
        x = np.arange(len(families))
        width = 0.8 / n_models  # all bars of one family share 80% of a slot

        for i, model in enumerate(models):
            values = [
                family_metrics[family].get(model, {}).get("acc_25", 0) * 100
                for family in families
            ]
            # Center the group of bars on the family's tick position.
            offset = (i - n_models / 2 + 0.5) * width
            ax.bar(
                x + offset,
                values,
                width,
                label=MODEL_DISPLAY_NAMES.get(model, model),
                color=MODEL_COLORS.get(model),
            )

        ax.set_xlabel("Formula Family", fontsize=12)
        ax.set_ylabel("Acc@+25 (%)", fontsize=12)
        ax.set_title(f"{task.upper()} v1: Accuracy by Formula Family", fontsize=14)
        ax.set_xticks(x)
        ax.set_xticklabels(families, rotation=45, ha="right")
        ax.legend(loc="upper right", fontsize=9, ncol=2)
        ax.grid(True, alpha=0.3, axis="y")
        ax.set_ylim(0, 100)

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release this specific figure, even if drawing/saving failed.
        plt.close(fig)

    print(f"Saved family bar plot to {output_path}")


def plot_lift_hard_comparison(
    df: "pd.DataFrame", task: str, output_path: Path, figsize: tuple = (10, 6)
) -> None:
    """
    Plot Acc@+25 on lift-hard vs non-lift instances, two bars per model.

    Args:
        df: Evaluation records passed to ``compute_lift_hard_metrics``.
        task: Task id; upper-cased for the figure title and used in the
            "insufficient data" message.
        output_path: File to write the figure to. Its parent directory is
            created if it does not exist.
        figsize: Figure size forwarded to ``plt.subplots``.

    Returns nothing. The function is a no-op when matplotlib is unavailable,
    when the metrics lack either the "True" or the "False" group (a message
    is printed in that case), or when no model from ``MODEL_ORDER`` is
    present. A model missing from one group is drawn with a zero-height bar
    for that group.
    """
    if not HAS_MATPLOTLIB:
        return

    lift_metrics = compute_lift_hard_metrics(df)

    if "True" not in lift_metrics or "False" not in lift_metrics:
        print(f"Insufficient lift-hard data for {task}")
        return

    # Both groups are guaranteed to exist past this point.
    lift_group = lift_metrics["True"]
    nonlift_group = lift_metrics["False"]

    models = [m for m in MODEL_ORDER if m in lift_group or m in nonlift_group]
    if not models:
        return

    # Mirror generate_family_table: make sure the destination exists so a
    # direct call with a fresh directory does not fail inside savefig.
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=figsize)
    try:
        x = np.arange(len(models))
        width = 0.35

        lift_vals = [lift_group.get(m, {}).get("acc_25", 0) * 100 for m in models]
        nonlift_vals = [nonlift_group.get(m, {}).get("acc_25", 0) * 100 for m in models]

        ax.bar(x - width / 2, nonlift_vals, width, label="Non-Lift", color="#3498db")
        ax.bar(x + width / 2, lift_vals, width, label="Lift-Hard", color="#e74c3c")

        ax.set_xlabel("Model", fontsize=12)
        ax.set_ylabel("Acc@+25 (%)", fontsize=12)
        ax.set_title(f"{task.upper()} v1: Lift-Hard vs Non-Lift Accuracy", fontsize=14)
        ax.set_xticks(x)
        ax.set_xticklabels(
            [MODEL_DISPLAY_NAMES.get(m, m) for m in models], rotation=45, ha="right"
        )
        ax.legend(loc="upper right")
        ax.grid(True, alpha=0.3, axis="y")
        ax.set_ylim(0, 100)

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release this specific figure, even if drawing/saving failed.
        plt.close(fig)

    print(f"Saved lift-hard plot to {output_path}")

def _latex_escape(text: str) -> str:
    """Escape characters that LaTeX treats specially in running text."""
    return text.translate(
        str.maketrans(
            {
                "\\": r"\textbackslash{}",
                "&": r"\&",
                "%": r"\%",
                "$": r"\$",
                "#": r"\#",
                "_": r"\_",
                "{": r"\{",
                "}": r"\}",
                "~": r"\textasciitilde{}",
                "^": r"\textasciicircum{}",
            }
        )
    )


def generate_family_table(df: "pd.DataFrame", task: str, output_path: Path) -> str:
    """
    Generate a LaTeX table of Acc@+25 per formula family and write it to disk.

    Args:
        df: Evaluation records passed to ``compute_family_metrics``.
        task: Task id; used (upper-cased) in the caption and verbatim in the
            ``\\label``.
        output_path: Destination ``.tex`` file. Its parent directory is
            created if needed.

    Returns:
        The table source that was written.

    Columns are the first five models of ``MODEL_ORDER`` that occur in the
    data. Per row, the best non-zero accuracy (within 0.001) is set in bold;
    a model without data for a family is shown as ``--``. Raw family ids are
    LaTeX-escaped so ids containing e.g. ``_`` or ``&`` do not break the
    table; entries from ``FAMILY_DISPLAY_NAMES`` are emitted verbatim.
    """
    family_metrics = compute_family_metrics(df)
    families = sorted(family_metrics.keys())

    # Get models present
    all_models = set()
    for f in families:
        all_models.update(family_metrics[f].keys())
    models = [m for m in MODEL_ORDER if m in all_models][:5]  # Top 5 models

    # Find best value for each family (across models)
    best_by_family = {}
    for family in families:
        accs = [family_metrics[family].get(m, {}).get("acc_25", 0) for m in models]
        best_by_family[family] = max(accs) if accs else 0

    lines = []
    lines.append(f"% {task.upper()} v1 Accuracy by Family")
    lines.append("\\begin{table}[t]")
    lines.append("\\centering")
    lines.append(f"\\caption{{{task.upper()} v1 Acc@+25 by formula family.}}")
    lines.append(f"\\label{{tab:{task}_by_family}}")
    lines.append("\\small")

    # Header
    header_models = " & ".join([MODEL_DISPLAY_NAMES.get(m, m)[:8] for m in models])
    lines.append(f"\\begin{{tabular}}{{@{{}}l{'r' * len(models)}@{{}}}}")
    lines.append("\\toprule")
    lines.append(f"Family & {header_models} \\\\")
    lines.append("\\midrule")

    for family in families:
        # Curated display names are trusted as-is; raw ids come from the data
        # and must be escaped before being placed in a table cell.
        if family in FAMILY_DISPLAY_NAMES:
            family_display = FAMILY_DISPLAY_NAMES[family]
        else:
            family_display = _latex_escape(family)

        vals = []
        best = best_by_family[family]
        for model in models:
            if model in family_metrics[family]:
                acc = family_metrics[family][model]["acc_25"]
                pct = f"{acc*100:.1f}\\%"
                # Bold if best for this family
                if abs(acc - best) < 0.001 and best > 0:
                    pct = f"\\textbf{{{pct}}}"
                vals.append(pct)
            else:
                vals.append("--")
        lines.append(f"{family_display} & {' & '.join(vals)} \\\\")

    lines.append("\\bottomrule")
    lines.append("\\end{tabular}")
    lines.append("\\end{table}")

    table_content = "\n".join(lines)

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Explicit encoding so the output does not depend on the platform locale.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(table_content)

    print(f"Saved family table to {output_path}")
    return table_content


def run_structural_analysis(
    task: str,
    records_path: Path,
    outdir: Path,
    figures_dir: Optional[Path] = None,
    tables_dir: Optional[Path] = None,
    verbose: bool = True,
) -> Dict[str, Any]:
    """
    Run the complete structural breakdown for one task.

    Loads the records, writes the plots (when matplotlib is available) and
    the LaTeX family table, computes every metric breakdown, and stores them
    in ``<outdir>/structural_summary.json``.

    Args:
        task: Task id, used in output file names and in the summary.
        records_path: Path of the records file handed to ``load_records``.
        outdir: Directory for the JSON summary (created if missing).
        figures_dir: Directory for figures; ``figures/induction`` if None.
        tables_dir: Directory for tables; ``tables/induction`` if None.
        verbose: Print progress and a console summary when True.

    Returns:
        The summary dict, or ``{}`` when pandas is missing or no records
        were loaded.
    """
    if not HAS_PANDAS:
        print("Error: pandas required")
        return {}

    if verbose:
        print(f"Loading records from {records_path}...")
    df = load_records([records_path])

    if len(df) == 0:
        print("No records found")
        return {}
    if verbose:
        print(f"Loaded {len(df)} records")

    # Resolve and create the three output locations.
    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    figures_dir = Path("figures/induction") if figures_dir is None else Path(figures_dir)
    figures_dir.mkdir(parents=True, exist_ok=True)

    tables_dir = Path("tables/induction") if tables_dir is None else Path(tables_dir)
    tables_dir.mkdir(parents=True, exist_ok=True)

    # Figures first, then the LaTeX table.
    if HAS_MATPLOTLIB:
        plot_family_bars(df, task, figures_dir / f"{task}_family_bars.pdf")
        plot_lift_hard_comparison(df, task, figures_dir / f"{task}_lift_hard.pdf")
    generate_family_table(df, task, tables_dir / f"{task}_by_family.tex")

    # Every breakdown that goes into the JSON summary.
    family_metrics = compute_family_metrics(df)
    lift_metrics = compute_lift_hard_metrics(df)
    summary = {
        "task": task,
        "total_records": len(df),
        "by_family": family_metrics,
        "by_lift_hard": lift_metrics,
        "by_ast_bin": compute_ast_bin_metrics(df),
        "by_qd": compute_qd_metrics(df),
        "top_subfamilies": get_top_subfamilies(df),
        "by_subfamily": compute_subfamily_metrics(df),
    }

    summary_path = outdir / "structural_summary.json"
    with open(summary_path, "w") as f:
        json.dump(summary, f, indent=2)

    if not verbose:
        return summary

    print(f"\nSummary saved to {summary_path}")

    print("\nFamily Distribution:")
    for family in sorted(family_metrics):
        counts = [stats["count"] for stats in family_metrics[family].values()]
        avg_count = sum(counts) / len(counts) if counts else 0
        print(f"  {family}: ~{avg_count:.0f} instances per model")

    print("\nLift-Hard Breakdown:")
    for flag, label in (("False", "Non-Lift"), ("True", "Lift-Hard")):
        if flag not in lift_metrics:
            continue
        print(f"  {label}:")
        # Five best models of this group, highest Acc@+25 first.
        ranked = sorted(
            lift_metrics[flag].items(),
            key=lambda item: item[1]["acc_25"],
            reverse=True,
        )
        for model, stats in ranked[:5]:
            display = MODEL_DISPLAY_NAMES.get(model, model)
            print(f"    {display:20s}: Acc@+25={stats['acc_25']*100:5.1f}%")

    return summary


def main():
    """Parse the command line and run the structural analysis."""

    def as_optional_path(raw):
        # Unset/empty CLI values mean "use the default directory".
        return Path(raw) if raw else None

    cli = argparse.ArgumentParser(description="Structural breakdown analysis")
    cli.add_argument("--task", "-t", required=True, choices=["fo", "ci", "ec"], help="Task type")
    cli.add_argument("--records", "-r", required=True, help="Path to eval_records.jsonl")
    cli.add_argument("--outdir", "-o", required=True, help="Output directory")
    cli.add_argument("--figures-dir", help="Directory for PDF figures")
    cli.add_argument("--tables-dir", help="Directory for LaTeX tables")
    cli.add_argument("--quiet", "-q", action="store_true")
    options = cli.parse_args()

    run_structural_analysis(
        task=options.task,
        records_path=Path(options.records),
        outdir=Path(options.outdir),
        figures_dir=as_optional_path(options.figures_dir),
        tables_dir=as_optional_path(options.tables_dir),
        verbose=not options.quiet,
    )

# Run the CLI only when executed as a script (or via ``python -m``), not on import.
if __name__ == "__main__":
    main()
