#!/usr/bin/env python3
"""
Top ROI Analyses Driver Script

Runs all three top ROI analyses and produces paper-ready outputs:
1. Budget curves + bloat histograms (FO and EC)
2. CI failure decomposition
3. Structural breakdowns by family/subfamily

Usage:
    python -m concept_synth.analysis.run_top_roi_analyses \\
        --fo-dataset results/ad_benchmark/ad_benchmark_v1.yaml \\
        --ci-dataset results/c_benchmark/c_benchmark_v1.yaml \\
        --ec-dataset results/e_benchmark/e_benchmark_v1.yaml

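Outputs are written, by default, to:
    - concept_synth/artifacts/analysis/v1/{fo,ci,ec}/...
    - concept_synth/figures/induction/*.pdf
    - concept_synth/tables/induction/*.tex
    - concept_synth/paper/auto/{tables,figs}/ (extra v1 analyses only)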
"""

import argparse
import json
import os
import sys
from datetime import datetime
from pathlib import Path

# Bootstrap path: make the `concept_synth` package importable even when this
# file is executed directly rather than from an environment where the package
# is already on sys.path.
try:
    from concept_synth.bootstrap import add_repo_root
except ModuleNotFoundError:
    # Walk upward from this file's absolute path, one component at a time,
    # looking for a path component named "concept_synth".
    _path = os.path.abspath(__file__)
    while True:
        parent = os.path.dirname(_path)
        if os.path.basename(_path) == "concept_synth":
            # Put the directory that CONTAINS the package at the front of
            # sys.path (once) so `import concept_synth` can resolve.
            if parent not in sys.path:
                sys.path.insert(0, parent)
            break
        if parent == _path:
            # dirname() returned its input: filesystem root reached without
            # finding the package directory. Give up on the search.
            break
        _path = parent
    # Retry the import; if the search above failed this raises
    # ModuleNotFoundError again, now uncaught.
    from concept_synth.bootstrap import add_repo_root
# NOTE(review): presumably registers the repository root on sys.path based on
# this file's location -- confirm against concept_synth.bootstrap.
add_repo_root(__file__)


def _print_banner(title: str) -> None:
    """Print a blank line followed by *title* framed by two 70-char '=' rules."""
    print("\n" + "=" * 70)
    print(title)
    print("=" * 70)


def _print_parsimony_gap(label: str, budget_summary, top_n: int = 5) -> None:
    """Print the top *top_n* models of a budget summary, ranked by ``acc_25``.

    Does nothing when *budget_summary* is falsy or has no ``"metrics"`` key.
    Missing per-model metrics are reported as 0 instead of raising
    ``KeyError``, matching the default already used by the sort key.
    """
    if not (budget_summary and "metrics" in budget_summary):
        return

    print(f"\n{label} Parsimony Gap:")
    ranked = sorted(
        budget_summary["metrics"].items(),
        key=lambda item: item[1].get("acc_25", 0),
        reverse=True,
    )
    for model, m in ranked[:top_n]:
        print(
            f"  {model:25s}: Validity={m.get('validity', 0)*100:5.1f}%, "
            f"Acc@+25={m.get('acc_25', 0)*100:5.1f}%, Gap={m.get('gap', 0)*100:5.1f}%"
        )


def _print_ci_failure_modes(ci_summary, top_n: int = 5) -> None:
    """Print the top *top_n* models of a CI failure summary, ranked by ``correct``.

    Does nothing when *ci_summary* is falsy or has no
    ``"overall_distribution"`` key.
    """
    if not (ci_summary and "overall_distribution" in ci_summary):
        return

    print("\nCI Failure Modes:")
    ranked = sorted(
        ci_summary["overall_distribution"].items(),
        key=lambda item: item[1].get("correct", 0),
        reverse=True,
    )
    for model, d in ranked[:top_n]:
        print(
            f"  {model:25s}: Correct={d.get('correct', 0)*100:5.1f}%, "
            f"YES-fail={d.get('yes_fail', 0)*100:5.1f}%, "
            f"NO-fail={d.get('no_fail', 0)*100:5.1f}%"
        )


def run_all_analyses(
    fo_dataset: Path,
    ci_dataset: Path,
    ec_dataset: Path,
    base_outdir: Path = Path("concept_synth/artifacts/analysis/v1"),
    figures_dir: Path = Path("concept_synth/figures/induction"),
    tables_dir: Path = Path("concept_synth/tables/induction"),
    skip_dump: bool = False,
    extra_analyses_v1: bool = True,
    verbose: bool = True,
) -> dict:
    """
    Run all top ROI analyses.

    Steps, in order: (1) dump evaluation records for FO/CI/EC unless
    ``skip_dump`` is set, (2) budget-curve analysis for FO and EC,
    (3) CI failure decomposition, (4) structural breakdowns for FO/CI/EC,
    (5) optional extra v1 analyses, then the collected results are written to
    ``<base_outdir>/top_roi_summary.json``.

    Args:
        fo_dataset: Path to FO (AD) benchmark YAML
        ci_dataset: Path to CI (C) benchmark YAML
        ec_dataset: Path to EC (E) benchmark YAML
        base_outdir: Base output directory for artifacts
        figures_dir: Directory for PDF figures
        tables_dir: Directory for LaTeX tables
        skip_dump: Skip eval record dumping (use existing)
        extra_analyses_v1: Run additional v1 analyses (lift-hard, difficulty, bloat bins)
        verbose: Print progress

    Returns:
        Summary dict with keys ``"timestamp"``, ``"inputs"``, ``"outputs"``
        (paths of the eval-record files) and ``"summaries"`` (the return value
        of each analysis, keyed by analysis name).
    """
    # Project imports are deferred to call time so that importing this module
    # does not pull in every analysis module.
    from concept_synth.analysis.budget_curves import run_budget_analysis
    from concept_synth.analysis.ci_failure_decomposition import run_ci_failure_analysis
    from concept_synth.analysis.difficulty_validation import run_difficulty_validation
    from concept_synth.analysis.dump_eval_records import dump_records_for_task
    from concept_synth.analysis.generalization_bloat_bins import run_generalization_bins_analysis
    from concept_synth.analysis.lift_hard_breakdown import run_lift_hard_analysis
    from concept_synth.analysis.structural_breakdown import run_structural_analysis

    results = {
        "timestamp": datetime.now().isoformat(),
        "inputs": {
            "fo_dataset": str(fo_dataset),
            "ci_dataset": str(ci_dataset),
            "ec_dataset": str(ec_dataset),
        },
        "outputs": {},
        "summaries": {},
    }
    summaries = results["summaries"]

    # Accept plain strings as well as Path objects, then create the output dirs.
    base_outdir = Path(base_outdir)
    figures_dir = Path(figures_dir)
    tables_dir = Path(tables_dir)

    for d in (base_outdir, figures_dir, tables_dir):
        d.mkdir(parents=True, exist_ok=True)

    # Insertion order (fo, ci, ec) fixes the processing order of the loops below.
    datasets = {"fo": fo_dataset, "ci": ci_dataset, "ec": ec_dataset}
    records = {task: base_outdir / task / "eval_records.jsonl" for task in datasets}

    # =========================================================================
    # Step 1: Dump evaluation records
    # =========================================================================
    if verbose:
        _print_banner("STEP 1: Dumping Evaluation Records")

    if not skip_dump:
        for task, dataset in datasets.items():
            if verbose:
                print(f"\n[{task.upper()}] Dumping evaluation records...")
            dump_records_for_task(task, dataset, records[task], verbose=verbose)
    elif verbose:
        print("Skipping record dump (using existing files)")

    for task in datasets:
        results["outputs"][f"{task}_records"] = str(records[task])

    # =========================================================================
    # Step 2: Budget Curves Analysis (FO and EC)
    # =========================================================================
    if verbose:
        _print_banner("STEP 2: Budget Curves Analysis")

    for task in ("fo", "ec"):
        if verbose:
            print(f"\n[{task.upper()}] Running budget curve analysis...")
        summaries[f"{task}_budget"] = run_budget_analysis(
            task=task,
            records_path=records[task],
            outdir=base_outdir / task / "budget_curves",
            figures_dir=figures_dir,
            tables_dir=tables_dir,
            verbose=verbose,
        )

    # =========================================================================
    # Step 3: CI Failure Decomposition
    # =========================================================================
    if verbose:
        _print_banner("STEP 3: CI Failure Decomposition")

    summaries["ci_failure"] = run_ci_failure_analysis(
        records_path=records["ci"],
        outdir=base_outdir / "ci" / "failure_decomposition",
        figures_dir=figures_dir,
        tables_dir=tables_dir,
        verbose=verbose,
    )

    # =========================================================================
    # Step 4: Structural Breakdowns
    # =========================================================================
    if verbose:
        _print_banner("STEP 4: Structural Breakdowns")

    for task in datasets:
        if verbose:
            print(f"\n[{task.upper()}] Running structural analysis...")
        summaries[f"{task}_structural"] = run_structural_analysis(
            task=task,
            records_path=records[task],
            outdir=base_outdir / task / "structural",
            figures_dir=figures_dir,
            tables_dir=tables_dir,
            verbose=verbose,
        )

    # =========================================================================
    # Step 5: Extra V1 Analyses (optional)
    # =========================================================================
    if extra_analyses_v1:
        if verbose:
            _print_banner("STEP 5: Extra V1 Analyses")

        # Paper auto directories (fixed locations, independent of
        # figures_dir / tables_dir).
        paper_tables_dir = Path("concept_synth/paper/auto/tables")
        paper_figs_dir = Path("concept_synth/paper/auto/figs")
        paper_tables_dir.mkdir(parents=True, exist_ok=True)
        paper_figs_dir.mkdir(parents=True, exist_ok=True)

        # 5a: Lift-Hard Breakdown
        if verbose:
            print("\n[Extra] Running lift-hard breakdown analysis...")
        summaries["lift_hard"] = run_lift_hard_analysis(
            fo_records_path=records["fo"],
            ci_records_path=records["ci"],
            ec_records_path=records["ec"],
            outdir=base_outdir / "lift_hard",
            tables_dir=paper_tables_dir,
            verbose=verbose,
        )

        # 5b: Difficulty Validation
        if verbose:
            print("\n[Extra] Running difficulty validation analysis...")
        summaries["difficulty"] = run_difficulty_validation(
            fo_dataset=fo_dataset,
            ec_dataset=ec_dataset,
            outdir=base_outdir / "difficulty",
            tables_dir=paper_tables_dir,
            figures_dir=paper_figs_dir,
            verbose=verbose,
        )

        # 5c: Generalization vs Bloat Bins
        if verbose:
            print("\n[Extra] Running generalization vs bloat bins analysis...")
        fo_holdout_path = base_outdir / "holdout" / "fo_holdout.json"
        ci_holdout_path = base_outdir / "holdout" / "ci_holdout.json"

        # Holdout files are optional inputs: pass None when one is absent.
        summaries["generalization_bins"] = run_generalization_bins_analysis(
            fo_holdout_path=fo_holdout_path if fo_holdout_path.exists() else None,
            ci_holdout_path=ci_holdout_path if ci_holdout_path.exists() else None,
            outdir=base_outdir / "generalization_bins",
            tables_dir=paper_tables_dir,
            figures_dir=paper_figs_dir,
            verbose=verbose,
        )

    # =========================================================================
    # Save master summary
    # =========================================================================
    summary_path = base_outdir / "top_roi_summary.json"
    # default=str stringifies any value json cannot encode natively (e.g. Path).
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, default=str)

    # =========================================================================
    # Print final summary
    # =========================================================================
    if verbose:
        _print_banner("ANALYSIS COMPLETE")

        print("\nOutputs written to:")
        print(f"  Artifacts: {base_outdir}/")
        print(f"  Figures:   {figures_dir}/")
        print(f"  Tables:    {tables_dir}/")
        print(f"  Summary:   {summary_path}")

        print("\n" + "-" * 70)
        print("KEY RESULTS SUMMARY")
        print("-" * 70)

        _print_parsimony_gap("FO", summaries["fo_budget"])
        _print_parsimony_gap("EC", summaries["ec_budget"])
        _print_ci_failure_modes(summaries["ci_failure"])

    return results


def main(argv=None):
    """Command-line entry point: parse arguments and run all analyses.

    Args:
        argv: Optional list of argument strings (without the program name).
            When ``None`` (the default) ``argparse`` reads ``sys.argv[1:]``,
            so calling ``main()`` with no arguments behaves as before.

    Raises:
        SystemExit: Raised by ``argparse`` on ``--help`` (code 0) or on
            invalid / missing arguments (code 2).
    """
    parser = argparse.ArgumentParser(
        description="Run all top ROI analyses",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    # Run with default paths
    python -m concept_synth.analysis.run_top_roi_analyses \\
        --fo-dataset results/ad_benchmark/ad_benchmark_v1.yaml \\
        --ci-dataset results/c_benchmark/c_benchmark_v1.yaml \\
        --ec-dataset results/e_benchmark/e_benchmark_v1.yaml
    
    # Skip record dumping (use existing)
    python -m concept_synth.analysis.run_top_roi_analyses \\
        --fo-dataset ... --ci-dataset ... --ec-dataset ... \\
        --skip-dump
        """,
    )

    parser.add_argument("--fo-dataset", required=True, help="Path to FO (AD) benchmark YAML")
    parser.add_argument("--ci-dataset", required=True, help="Path to CI (C) benchmark YAML")
    parser.add_argument("--ec-dataset", required=True, help="Path to EC (E) benchmark YAML")
    parser.add_argument(
        "--outdir", default="concept_synth/artifacts/analysis/v1", help="Base output directory"
    )
    parser.add_argument(
        "--figures-dir", default="concept_synth/figures/induction", help="Directory for PDF figures"
    )
    parser.add_argument(
        "--tables-dir", default="concept_synth/tables/induction", help="Directory for LaTeX tables"
    )
    parser.add_argument("--skip-dump", action="store_true", help="Skip eval record dumping")
    # NOTE: with store_true and default=True this flag cannot change the value;
    # it is kept so existing command lines that pass it keep working. Use
    # --no-extra-analyses to turn the extra analyses off.
    parser.add_argument(
        "--extra-analyses-v1",
        action="store_true",
        default=True,
        help="Run extra v1 analyses (lift-hard, difficulty, bloat bins)",
    )
    parser.add_argument("--no-extra-analyses", action="store_true", help="Skip extra v1 analyses")
    parser.add_argument("--quiet", "-q", action="store_true", help="Suppress progress output")

    args = parser.parse_args(argv)

    run_all_analyses(
        fo_dataset=Path(args.fo_dataset),
        ci_dataset=Path(args.ci_dataset),
        ec_dataset=Path(args.ec_dataset),
        base_outdir=Path(args.outdir),
        figures_dir=Path(args.figures_dir),
        tables_dir=Path(args.tables_dir),
        skip_dump=args.skip_dump,
        # --no-extra-analyses wins over the (always-true) --extra-analyses-v1.
        extra_analyses_v1=args.extra_analyses_v1 and not args.no_extra_analyses,
        verbose=not args.quiet,
    )


# Run the CLI only when this file is executed as a script (or via
# `python -m`), not when it is imported as a module.
if __name__ == "__main__":
    main()
