#!/usr/bin/env python3
"""
Lift-Hard Breakdown Analysis

Computes accuracy breakdowns for lift-hard vs non-lift instances across tasks.
Lift-hard patterns are cross-relational (R/S) patterns that tend to be harder.

Outputs:
- CSV: lift_hard_breakdown.csv
- LaTeX: tab_lift_hard_breakdown.tex
"""

import argparse
import csv
import json
import os
import sys
from collections import defaultdict
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# Bootstrap path
# Make the `concept_synth` package importable when it is not already on
# sys.path, then hand over to the package's own `add_repo_root` helper.
try:
    from concept_synth.bootstrap import add_repo_root
except ModuleNotFoundError:
    # Walk upward from this file's absolute path, one component at a time,
    # looking for a path component literally named "concept_synth".
    _path = os.path.abspath(__file__)
    while True:
        parent = os.path.dirname(_path)
        if os.path.basename(_path) == "concept_synth":
            # The directory *containing* the package must be on sys.path for
            # `import concept_synth` to resolve; avoid inserting it twice.
            if parent not in sys.path:
                sys.path.insert(0, parent)
            break
        if parent == _path:
            # dirname() returned its input: no further parent to try. Give up
            # without touching sys.path; the import below then raises again.
            break
        _path = parent
    # Retry the import now that sys.path may have been extended.
    from concept_synth.bootstrap import add_repo_root
add_repo_root(__file__)

from concept_synth.analysis.schema import load_records

# Single ordered roster of (model identifier, short display label) pairs.
# Both public constants below are derived from it so they cannot drift apart.
_MODEL_ROSTER = (
    ("grok4", "Grok4"),
    ("gpt-5.2", "GPT-5.2"),
    ("grok4.1fast", "Grok4.1f"),
    ("gemini-3-pro-preview", "Gemini3"),
    ("deepseek-reasoner", "DSR"),
    ("claude-opus-4-5-20251101", "Opus4.5"),
    ("hermes4", "Hermes4"),
    ("gpt-4o", "GPT-4o"),
)

# Model identifier -> label shown in generated tables.
MODEL_DISPLAY_NAMES = dict(_MODEL_ROSTER)

# Order in which models appear in the CSV and LaTeX outputs.
MODEL_ORDER = [model_id for model_id, _label in _MODEL_ROSTER]


@dataclass
class LiftHardSliceMetrics:
    """Scores for one model on one task, restricted to a single slice.

    Attributes:
        task: Label of the task the slice belongs to.
        model: Model identifier.
        is_lift_hard: True for the lift-hard slice, False for the non-lift one.
        n_instances: Size of the slice (0 when the slice is empty).
        acc_all: Accuracy with denominator=all.
        acc_25: Accuracy @gold+25.
        coverage: Fraction with parsed formula.
    """

    task: str
    model: str
    is_lift_hard: bool
    n_instances: int
    acc_all: float
    acc_25: float
    coverage: float


@dataclass
class TaskLiftHardSummary:
    """Roll-up of the lift-hard / non-lift split for a single task.

    Attributes:
        task: Label of the task.
        n_lift_hard: Number of lift-hard instances.
        n_non_lift: Number of non-lift instances.
        has_lift_hard: False if n_lift_hard == 0.
        model_metrics: Nested mapping, model -> {"lift" / "nonlift" -> metrics}.
    """

    task: str
    n_lift_hard: int
    n_non_lift: int
    has_lift_hard: bool
    model_metrics: Dict[str, Dict[str, LiftHardSliceMetrics]]


def _build_slice_metrics(
    task_name: str, model: str, is_lift: bool, slice_df: Any
) -> LiftHardSliceMetrics:
    """Aggregate the rows of one (task, model, slice) selection into metrics.

    Args:
        task_name: Task label recorded on the result.
        model: Model identifier recorded on the result.
        is_lift: Whether ``slice_df`` is the lift-hard slice.
        slice_df: Rows belonging to the slice; must expose ``valid``,
            ``budget_ok_25`` and ``parse_ok`` columns when non-empty.

    Returns:
        Metrics for the slice; every rate is 0.0 when the slice is empty.
    """
    n = len(slice_df)
    if n == 0:
        # Rates are undefined for an empty slice; zeros keep the record shape
        # uniform and consumers use n_instances == 0 to detect this case.
        return LiftHardSliceMetrics(
            task=task_name,
            model=model,
            is_lift_hard=is_lift,
            n_instances=0,
            acc_all=0.0,
            acc_25=0.0,
            coverage=0.0,
        )

    # Cast to built-in float so downstream formatting/serialisation never
    # sees numpy scalar types.
    return LiftHardSliceMetrics(
        task=task_name,
        model=model,
        is_lift_hard=is_lift,
        n_instances=n,
        acc_all=float(slice_df["valid"].sum() / n),
        acc_25=float(slice_df["budget_ok_25"].sum() / n),
        coverage=float(slice_df["parse_ok"].sum() / n),
    )


def compute_lift_hard_breakdown(
    fo_records_path: Optional[Path],
    ci_records_path: Optional[Path],
    ec_records_path: Optional[Path],
    verbose: bool = True,
) -> Dict[str, TaskLiftHardSummary]:
    """
    Compute lift-hard breakdown for each task.

    A task is skipped (absent from the result) when its path is None, does not
    exist, or loads to an empty frame. Records that lack the
    ``is_lift_hard_gold`` column are treated as containing no lift-hard
    instances: every row lands in the "nonlift" slice.

    Args:
        fo_records_path: Path to FO eval_records.jsonl
        ci_records_path: Path to CI eval_records.jsonl
        ec_records_path: Path to EC eval_records.jsonl
        verbose: Print progress

    Returns:
        Dict mapping task name ("FullObs", "CI", "EC") to summary
    """
    results: Dict[str, TaskLiftHardSummary] = {}

    tasks = [
        ("FullObs", fo_records_path),
        ("CI", ci_records_path),
        ("EC", ec_records_path),
    ]

    for task_name, records_path in tasks:
        if records_path is None or not records_path.exists():
            if verbose:
                print(f"  [{task_name}] Skipping (no records)")
            continue

        # NOTE(review): load_records is assumed to return a pandas DataFrame
        # (it is used via .empty / .groupby / .iloc below) - confirm in schema.
        df = load_records([records_path])
        if df.empty:
            if verbose:
                print(f"  [{task_name}] Skipping (empty records)")
            continue

        # The column check must come *before* the groupby selection: selecting
        # a missing column raises KeyError, which would defeat the fallback.
        has_flag_column = "is_lift_hard_gold" in df.columns

        # Count lift-hard instances (per unique instance_id)
        if has_flag_column:
            instance_lift_hard = df.groupby("instance_id")["is_lift_hard_gold"].first()
            n_total = len(instance_lift_hard)
            n_lift_hard = int(instance_lift_hard.sum())
        else:
            n_total = int(df["instance_id"].nunique())
            n_lift_hard = 0
        n_non_lift = n_total - n_lift_hard

        if verbose:
            print(f"  [{task_name}] {n_lift_hard} lift-hard, {n_non_lift} non-lift instances")

        # Compute per-model metrics
        model_metrics: Dict[str, Dict[str, LiftHardSliceMetrics]] = {}

        for model in df["model"].unique():
            model_df = df[df["model"] == model]
            per_slice: Dict[str, LiftHardSliceMetrics] = {}

            for is_lift in (True, False):
                slice_name = "lift" if is_lift else "nonlift"

                if has_flag_column:
                    slice_df = model_df[model_df["is_lift_hard_gold"] == is_lift]
                elif is_lift:
                    # No lift-hard column: the lift slice is empty ...
                    slice_df = model_df.iloc[0:0]
                else:
                    # ... and every row counts as non-lift.
                    slice_df = model_df

                per_slice[slice_name] = _build_slice_metrics(
                    task_name, model, is_lift, slice_df
                )

            model_metrics[model] = per_slice

        results[task_name] = TaskLiftHardSummary(
            task=task_name,
            n_lift_hard=n_lift_hard,
            n_non_lift=n_non_lift,
            # n_lift_hard is a built-in int here, so this is a built-in bool.
            has_lift_hard=n_lift_hard > 0,
            model_metrics=model_metrics,
        )

    return results


def save_csv(summaries: Dict[str, TaskLiftHardSummary], output_path: Path) -> None:
    """Write one CSV row per (task, model, slice) to ``output_path``.

    Parent directories are created as needed. Only models listed in
    MODEL_ORDER are written, in that order; for each model the "lift" row
    precedes the "nonlift" row. Rates are formatted with four decimals.

    Args:
        summaries: Task name -> summary, as produced by the breakdown step.
        output_path: Destination CSV file (overwritten if present).
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    columns = ["task", "model", "slice", "n_instances", "acc_all", "acc_25", "coverage"]

    records = []
    for task_name, summary in summaries.items():
        per_model = summary.model_metrics
        for model in (m for m in MODEL_ORDER if m in per_model):
            by_slice = per_model[model]
            for slice_name in ("lift", "nonlift"):
                entry = by_slice.get(slice_name)
                if entry is None:
                    continue
                values = (
                    task_name,
                    model,
                    slice_name,
                    entry.n_instances,
                    f"{entry.acc_all:.4f}",
                    f"{entry.acc_25:.4f}",
                    f"{entry.coverage:.4f}",
                )
                records.append(dict(zip(columns, values)))

    # newline="" lets the csv module control line endings itself.
    with output_path.open("w", newline="") as handle:
        out = csv.DictWriter(handle, fieldnames=columns)
        out.writeheader()
        out.writerows(records)


def generate_latex_table(summaries: Dict[str, TaskLiftHardSummary], output_path: Path) -> None:
    """Render the lift-hard breakdown as a LaTeX ``table*`` at ``output_path``.

    The table has one row per model (MODEL_ORDER order, restricted to models
    present in at least one summary) and a Lift / Non column pair for each of
    FullObs, CI and EC, showing acc_25 as a percentage. Cell conventions:
    "---" when the task or model is missing or the slice is empty, and "N/A"
    in the Lift column when the task has no lift-hard instances at all.

    Args:
        summaries: Task name -> summary; missing tasks are tolerated.
        output_path: Destination .tex file (parent directories are created).
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    task_columns = ("FullObs", "CI", "EC")

    def percent_or_dash(slice_metrics):
        # Show acc_25 as a percentage, or a dash for an absent/empty slice.
        if slice_metrics is not None and slice_metrics.n_instances > 0:
            return f"{slice_metrics.acc_25*100:.1f}\\%"
        return "---"

    # Preamble: caption, label, column spec.
    # Format: Model | FO Lift | FO Non | CI Lift | CI Non | EC Lift | EC Non
    out = [
        "% Lift-Hard Breakdown Table (auto-generated)",
        "\\begin{table*}[t]",
        "\\centering",
        "\\caption{\\textbf{Lift-hard breakdown across tasks.}",
        "Lift-hard instances contain cross-relational patterns (using both R and S predicates)",
        "that empirically prove harder for models. We report Acc@+25 (budgeted accuracy)",
        "separately for lift-hard and non-lift instances. EC v1 contains no lift-hard instances",
        "by construction (N/A).}",
        "\\label{tab:lift_hard_breakdown}",
        "\\small",
        "\\begin{tabular}{@{}l|rr|rr|rr@{}}",
        "\\toprule",
        " & \\multicolumn{2}{c|}{FullObs} & \\multicolumn{2}{c|}{CI} & \\multicolumn{2}{c}{EC} \\\\",
    ]

    # Second header row carries the per-task instance counts (0 if no data).
    header_cells = ["Model"]
    for task_name in task_columns:
        task_summary = summaries.get(task_name)
        n_lift = task_summary.n_lift_hard if task_summary is not None else 0
        n_non = task_summary.n_non_lift if task_summary is not None else 0
        header_cells.append(f"Lift ({n_lift})")
        header_cells.append(f"Non ({n_non})")
    out.append(" & ".join(header_cells) + " \\\\")
    out.append("\\midrule")

    # Rows: known models that occur in any task, in canonical order.
    seen_models = {name for s in summaries.values() for name in s.model_metrics}
    for model in (m for m in MODEL_ORDER if m in seen_models):
        cells = [MODEL_DISPLAY_NAMES.get(model, model)]

        for task_name in task_columns:
            task_summary = summaries.get(task_name)
            if task_summary is None or model not in task_summary.model_metrics:
                cells += ["---", "---"]
                continue

            by_slice = task_summary.model_metrics[model]

            # Lift column: "N/A" marks a task with no lift-hard instances.
            if task_summary.has_lift_hard:
                cells.append(percent_or_dash(by_slice.get("lift")))
            else:
                cells.append("N/A")

            # Non-lift column
            cells.append(percent_or_dash(by_slice.get("nonlift")))

        out.append(" & ".join(cells) + " \\\\")

    out += ["\\bottomrule", "\\end{tabular}", "\\end{table*}"]

    with output_path.open("w") as handle:
        handle.write("\n".join(out))


def run_lift_hard_analysis(
    fo_records_path: Optional[Path],
    ci_records_path: Optional[Path],
    ec_records_path: Optional[Path],
    outdir: Path,
    tables_dir: Path,
    verbose: bool = True,
) -> Dict[str, Any]:
    """
    Run the full lift-hard pipeline: compute, write CSV, write LaTeX.

    Args:
        fo_records_path: Path to FO eval_records.jsonl
        ci_records_path: Path to CI eval_records.jsonl
        ec_records_path: Path to EC eval_records.jsonl
        outdir: Output directory for CSV (lift_hard_breakdown.csv)
        tables_dir: Directory for LaTeX tables (tab_lift_hard_breakdown.tex)
        verbose: Print progress

    Returns:
        Dict with a "tasks" entry (per-task instance counts and the
        has_lift_hard flag) and an "outputs" entry (paths of the written
        files, as strings).
    """
    # Progress reporter that becomes a no-op in quiet mode.
    report = print if verbose else (lambda *_args: None)

    report("[Lift-Hard Breakdown] Computing breakdown...")
    summaries = compute_lift_hard_breakdown(
        fo_records_path, ci_records_path, ec_records_path, verbose
    )

    # CSV output
    csv_path = outdir / "lift_hard_breakdown.csv"
    save_csv(summaries, csv_path)
    report(f"[Lift-Hard Breakdown] Saved CSV: {csv_path}")

    # LaTeX output
    tex_path = tables_dir / "tab_lift_hard_breakdown.tex"
    generate_latex_table(summaries, tex_path)
    report(f"[Lift-Hard Breakdown] Saved LaTeX: {tex_path}")

    # Condensed per-task overview for the caller.
    task_overview = {}
    for task, summary in summaries.items():
        task_overview[task] = {
            "n_lift_hard": summary.n_lift_hard,
            "n_non_lift": summary.n_non_lift,
            "has_lift_hard": summary.has_lift_hard,
        }

    written = {"csv": str(csv_path), "latex": str(tex_path)}
    return {"tasks": task_overview, "outputs": written}


def main():
    """Command-line entry point: parse options and run the analysis."""
    parser = argparse.ArgumentParser(
        description="Lift-Hard Breakdown Analysis",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # The three record inputs share the same shape; register them from a table.
    record_options = (
        (
            "--fo-records",
            "concept_synth/artifacts/analysis/v1/fo/eval_records.jsonl",
            "Path to FO eval_records.jsonl",
        ),
        (
            "--ci-records",
            "concept_synth/artifacts/analysis/v1/ci/eval_records.jsonl",
            "Path to CI eval_records.jsonl",
        ),
        (
            "--ec-records",
            "concept_synth/artifacts/analysis/v1/ec/eval_records.jsonl",
            "Path to EC eval_records.jsonl",
        ),
    )
    for flag, default_path, help_text in record_options:
        parser.add_argument(flag, default=default_path, help=help_text)

    parser.add_argument(
        "--outdir",
        "-o",
        default="concept_synth/artifacts/analysis/v1/lift_hard",
        help="Output directory for CSV",
    )
    parser.add_argument(
        "--tables-dir",
        default="concept_synth/paper/auto/tables",
        help="Directory for LaTeX tables",
    )
    parser.add_argument(
        "--quiet",
        "-q",
        action="store_true",
        help="Suppress progress output",
    )

    args = parser.parse_args()

    def optional_path(raw):
        # An empty value (e.g. --fo-records "") disables that input.
        if raw:
            return Path(raw)
        return None

    run_lift_hard_analysis(
        fo_records_path=optional_path(args.fo_records),
        ci_records_path=optional_path(args.ci_records),
        ec_records_path=optional_path(args.ec_records),
        outdir=Path(args.outdir),
        tables_dir=Path(args.tables_dir),
        verbose=not args.quiet,
    )


if __name__ == "__main__":
    main()
