#!/usr/bin/env python3
"""
Difficulty Controls Validation Analysis

Validates that generation-time diagnostics (VS, kills) correlate with model success.
This empirically shows that our difficulty controls are meaningful.

Outputs:
- CSV: difficulty_correlations_{fullobs,ec}.csv
- LaTeX: tab_difficulty_correlations.tex
- Figures: fig_fullobs_difficulty_{vs_total,mean_killed}.pdf (FullObs only; no EC figure is generated)
"""

import argparse
import csv
import json
import os
import sys
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

# Bootstrap path: make the ``concept_synth`` package importable when this file
# is executed directly as a script rather than as an installed module.
try:
    from concept_synth.bootstrap import add_repo_root
except ModuleNotFoundError:
    # Walk upwards from this file until a directory named "concept_synth" is
    # found, then put its parent on sys.path so the package import can succeed.
    _path = os.path.abspath(__file__)
    while True:
        parent = os.path.dirname(_path)
        if os.path.basename(_path) == "concept_synth":
            if parent not in sys.path:
                sys.path.insert(0, parent)
            break
        if parent == _path:
            # Reached the filesystem root without finding the package directory.
            break
        _path = parent
    # Retry; this raises ModuleNotFoundError again if the walk found nothing.
    from concept_synth.bootstrap import add_repo_root
# NOTE(review): presumably registers the repository root on sys.path; the
# helper is defined outside this file -- confirm against concept_synth.bootstrap.
add_repo_root(__file__)

from concept_synth.io_utils import load_from_yaml

# scipy is optional: without it compute_correlations() returns no correlations.
try:
    from scipy import stats

    HAS_SCIPY = True
except ImportError:
    HAS_SCIPY = False

# matplotlib is optional: without it generate_quintile_figure() is a no-op.
try:
    import matplotlib

    # Select the Agg backend before pyplot is imported.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    HAS_MATPLOTLIB = True
except ImportError:
    HAS_MATPLOTLIB = False

# Short labels used for models in the LaTeX table and figure legends, keyed by
# the model identifier found in the datasets. Identifiers missing from this
# mapping are displayed unchanged.
MODEL_DISPLAY_NAMES = {
    "grok4": "Grok4",
    "gpt-5.2": "GPT-5.2",
    "grok4.1fast": "Grok4.1f",
    "gemini-3-pro-preview": "Gemini3",
    "deepseek-reasoner": "DSR",
    "claude-opus-4-5-20251101": "Opus4.5",
    "hermes4": "Hermes4",
    "gpt-4o": "GPT-4o",
}

# Row order for the LaTeX table and plotting order for the figures. This list
# also acts as a filter: models that are not listed here are left out of both.
MODEL_ORDER = [
    "grok4",
    "gpt-5.2",
    "grok4.1fast",
    "gemini-3-pro-preview",
    "deepseek-reasoner",
    "claude-opus-4-5-20251101",
    "hermes4",
    "gpt-4o",
]

# Line colors for plotting, keyed by model identifier. The figure code falls
# back to "#333333" for identifiers that are missing here.
MODEL_COLORS = {
    "grok4": "#1f77b4",
    "gpt-5.2": "#ff7f0e",
    "grok4.1fast": "#2ca02c",
    "gemini-3-pro-preview": "#d62728",
    "deepseek-reasoner": "#9467bd",
    "claude-opus-4-5-20251101": "#8c564b",
    "hermes4": "#e377c2",
    "gpt-4o": "#7f7f7f",
}


@dataclass
class DifficultyCorrelation:
    """Spearman correlation between one difficulty metric and one model's correctness.

    One record is produced per (model, metric) pair by ``compute_correlations``.
    """

    # Model identifier as it appears in the dataset's per-model results.
    model: str
    # Key of the difficulty metric that was correlated (e.g. "vs_total").
    metric_name: str
    # Spearman rank correlation between the metric values and 0/1 correctness.
    spearman_rho: float
    # p-value returned together with rho by scipy.stats.spearmanr.
    p_value: float
    # Number of instances with both a metric value and a result for the model.
    n_instances: int


@dataclass
class QuintileAccuracy:
    """Accuracy of one model within one quintile bin of a difficulty metric.

    Records are produced by ``compute_quintile_accuracies`` after sorting the
    model's instances by metric value.
    """

    # Bin number; 1 holds the lowest metric values, 5 the highest.
    quintile: int  # 1-5
    # Number of instances that fell into this bin.
    n_instances: int
    # Fraction of the bin's instances answered correctly (0.0 to 1.0).
    accuracy: float
    # Metric value of the first (lowest) instance in the sorted bin.
    metric_low: float
    # Metric value of the last (highest) instance in the sorted bin.
    metric_high: float


def extract_fo_difficulty_metrics(problems: List[Dict]) -> Dict[str, Dict[str, Any]]:
    """
    Collect the generation-time diagnostics stored on each FO problem.

    Each problem is expected to carry its identifier under
    ``problem.instanceId`` and its diagnostics under ``problemDescription``.
    Problems without an identifier are skipped; a diagnostic that is absent is
    reported as ``None``.

    Returns dict mapping instance_id to metrics dict.
    """
    # (output key, key read from "problemDescription"): VS metrics first,
    # then the kill metrics.
    field_sources = (
        ("vs_total", "ad_vs_total"),
        ("vs_tier1", "ad_vs_tier1"),
        ("vs_tier2", "ad_vs_tier2"),
        ("min_killed", "ad_min_killed_per_world"),
        ("mean_killed", "ad_mean_killed_per_world"),
        ("total_killed", "ad_total_killed"),
    )

    by_instance: Dict[str, Dict[str, Any]] = {}
    for entry in problems:
        description = entry.get("problemDescription", {})
        ident = entry.get("problem", {}).get("instanceId", "")
        if ident:
            by_instance[ident] = {
                out_key: description.get(src_key) for out_key, src_key in field_sources
            }
    return by_instance


def extract_ec_difficulty_metrics(problems: List[Dict]) -> Dict[str, Dict[str, Any]]:
    """
    Collect the difficulty diagnostics stored on each EC problem.

    Only ``gold_ast`` and ``gold_qd`` are read from ``problemDescription``;
    EC carries fewer diagnostics than FO. Problems without a
    ``problem.instanceId`` are skipped and a missing diagnostic becomes ``None``.

    Returns dict mapping instance_id to metrics dict.
    """
    # Extend this tuple if further EC-specific diagnostics become available.
    wanted = ("gold_ast", "gold_qd")

    collected: Dict[str, Dict[str, Any]] = {}
    for record in problems:
        description = record.get("problemDescription", {})
        key = record.get("problem", {}).get("instanceId", "")
        if not key:
            continue
        collected[key] = dict((name, description.get(name)) for name in wanted)
    return collected


def extract_model_results(
    problems: List[Dict], task: str, ec_ast_slack: int = 25
) -> Dict[str, Dict[str, bool]]:
    """
    Extract correctness results per model per instance.

    Args:
        problems: Problem records. Each is expected to carry its identifier
            under ``problem.instanceId`` and per-model outcomes under
            ``llmResults`` (entries with ``model`` and ``evaluation`` keys).
        task: ``"ec"`` additionally applies the AST-size budget check below;
            any other value (including ``"fo"``) uses ``evaluation.correct``
            as is.
        ec_ast_slack: For EC, a correct answer only counts when
            ``predicted_ast <= gold astSize + ec_ast_slack``.

    Returns:
        Dict mapping instance_id -> {model -> correct}, where ``correct`` is
        always a real ``bool``. Instances without any usable model result are
        absent from the mapping.
    """
    results: Dict[str, Dict[str, bool]] = {}

    for prob in problems:
        # `or {}` / `or []` also cover keys that are present but null, which
        # would otherwise raise on the chained lookups below.
        instance_id = (prob.get("problem") or {}).get("instanceId", "")
        if not instance_id:
            continue

        # The gold AST size is a property of the problem, so look it up once.
        gold_ast = None
        if task == "ec":
            hidden_target = (prob.get("problemDescription") or {}).get("hiddenTarget") or {}
            gold_ast = hidden_target.get("astSize")

        for llm_result in prob.get("llmResults") or []:
            model = llm_result.get("model", "")
            if not model:
                continue

            evaluation = llm_result.get("evaluation") or {}
            correct = bool(evaluation.get("correct", False))

            if task == "ec" and correct:
                # EC: valid AND within budget. The budget is only enforced
                # when both sizes are known (and non-zero).
                pred_ast = evaluation.get("predicted_ast")
                if gold_ast and pred_ast:
                    correct = bool(pred_ast <= gold_ast + ec_ast_slack)

            results.setdefault(instance_id, {})[model] = correct

    return results


def compute_correlations(
    difficulty_metrics: Dict[str, Dict[str, Any]],
    model_results: Dict[str, Dict[str, bool]],
    metric_name: str,
    min_instances: int = 10,
) -> Dict[str, DifficultyCorrelation]:
    """
    Compute Spearman correlation between a difficulty metric and correctness.

    For every model, instances that have both a non-``None`` value for
    ``metric_name`` and a result for that model are paired up; correctness is
    encoded as 0/1.

    Args:
        difficulty_metrics: instance_id -> {metric name -> value}.
        model_results: instance_id -> {model -> correct}.
        metric_name: Key of the metric to correlate.
        min_instances: Models with fewer paired instances are skipped.

    Returns:
        Dict mapping model -> DifficultyCorrelation, with models inserted in
        sorted order so that downstream output is reproducible. A model is
        omitted when it has too few pairs, when either series is constant
        (correlation undefined), or when scipy reports NaN. Empty when scipy
        is not installed.
    """
    if not HAS_SCIPY:
        return {}

    all_models = set()
    for per_model in model_results.values():
        all_models.update(per_model)

    correlations: Dict[str, DifficultyCorrelation] = {}

    # Iterating the bare set would make the result order (and hence the CSV
    # row order) change between runs because string hashing is randomized.
    for model in sorted(all_models, key=str):
        metric_values = []
        correct_values = []

        for instance_id, metrics in difficulty_metrics.items():
            metric_val = metrics.get(metric_name)
            if metric_val is None:
                continue

            outcomes = model_results.get(instance_id)
            if outcomes is None or model not in outcomes:
                continue

            metric_values.append(float(metric_val))
            correct_values.append(1 if outcomes[model] else 0)

        if len(metric_values) < min_instances:
            continue

        # A constant series has no ranks to correlate; skip it up front
        # instead of letting scipy return NaN.
        if len(set(metric_values)) <= 1 or len(set(correct_values)) <= 1:
            continue

        rho, p_value = stats.spearmanr(metric_values, correct_values)

        if np.isnan(rho):
            continue

        correlations[model] = DifficultyCorrelation(
            model=model,
            metric_name=metric_name,
            # Store plain Python floats rather than numpy scalars.
            spearman_rho=float(rho),
            p_value=float(p_value),
            n_instances=len(metric_values),
        )

    return correlations


def compute_quintile_accuracies(
    difficulty_metrics: Dict[str, Dict[str, Any]],
    model_results: Dict[str, Dict[str, bool]],
    metric_name: str,
    min_instances: int = 10,
) -> Dict[str, List[QuintileAccuracy]]:
    """
    Compute accuracy by quintile of a difficulty metric.

    For every model, the instances that have both a non-``None`` value for
    ``metric_name`` and a result for that model are sorted by metric value
    and cut into five consecutive bins of (near-)equal size. Instances with
    equal metric values may land in different bins.

    Args:
        difficulty_metrics: instance_id -> {metric name -> value}.
        model_results: instance_id -> {model -> correct}.
        metric_name: Key of the metric to bin on.
        min_instances: Models with fewer paired instances are skipped.

    Returns:
        Dict mapping model -> list of QuintileAccuracy ordered from the lowest
        to the highest metric bin, with models inserted in sorted order.
    """
    n_bins = 5
    quintiles: Dict[str, List[QuintileAccuracy]] = {}

    all_models = set()
    for per_model in model_results.values():
        all_models.update(per_model)

    # Sorted so the result order does not depend on string hash randomization.
    for model in sorted(all_models, key=str):
        # (metric value, correct) pairs for this model.
        pairs = []
        for instance_id, metrics in difficulty_metrics.items():
            metric_val = metrics.get(metric_name)
            if metric_val is None:
                continue

            outcomes = model_results.get(instance_id)
            if outcomes is None or model not in outcomes:
                continue

            pairs.append((float(metric_val), outcomes[model]))

        n = len(pairs)
        if n < min_instances:
            continue

        # Stable sort on the metric only, so ties keep their input order.
        pairs.sort(key=lambda pair: pair[0])

        model_quintiles = []
        for q in range(n_bins):
            # Balanced boundaries: bin sizes differ by at most one. Using a
            # fixed n // 5 width would dump the whole remainder into the last
            # bin (e.g. sizes 2,2,2,2,6 for n = 14).
            start = q * n // n_bins
            end = (q + 1) * n // n_bins

            bin_pairs = pairs[start:end]
            if not bin_pairs:
                # Only possible when n < n_bins (min_instances set very low).
                continue

            n_correct = sum(1 for _, correct in bin_pairs if correct)

            model_quintiles.append(
                QuintileAccuracy(
                    quintile=q + 1,
                    n_instances=len(bin_pairs),
                    accuracy=n_correct / len(bin_pairs),
                    metric_low=bin_pairs[0][0],
                    metric_high=bin_pairs[-1][0],
                )
            )

        quintiles[model] = model_quintiles

    return quintiles


def save_correlations_csv(
    correlations: Dict[str, Dict[str, DifficultyCorrelation]], output_path: Path
) -> None:
    """
    Save correlations to CSV.

    Args:
        correlations: metric name -> {model -> DifficultyCorrelation}.
        output_path: Destination file; missing parent directories are created
            and an existing file is overwritten.

    One row is written per (metric, model) pair with the columns
    ``metric, model, spearman_rho, p_value, n_instances``. Only the header is
    written when ``correlations`` is empty.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    fieldnames = ["metric", "model", "spearman_rho", "p_value", "n_instances"]

    rows = []
    for metric_name, model_corrs in correlations.items():
        for model, corr in model_corrs.items():
            p_value = corr.p_value
            rows.append(
                {
                    "metric": metric_name,
                    "model": model,
                    "spearman_rho": f"{corr.spearman_rho:.4f}",
                    # Fixed-point would print every p below 5e-5 as "0.0000";
                    # switch to scientific notation for small values so the
                    # magnitude is not lost.
                    "p_value": f"{p_value:.4f}" if p_value >= 1e-4 else f"{p_value:.3e}",
                    "n_instances": corr.n_instances,
                }
            )

    # newline="" lets the csv module control line endings itself.
    with open(output_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def generate_correlations_latex(
    fo_correlations: Dict[str, Dict[str, DifficultyCorrelation]],
    ec_correlations: Dict[str, Dict[str, DifficultyCorrelation]],
    output_path: Path,
) -> None:
    """
    Write the combined FullObs/EC correlation table as a LaTeX file.

    Args:
        fo_correlations: FullObs correlations, metric name -> {model -> record}.
        ec_correlations: EC correlations in the same layout.
        output_path: Destination ``.tex`` file; parent directories are created.

    Each row shows Spearman rho for the ``vs_total`` and ``mean_killed``
    metrics of both tasks, with ``*`` appended when p < 0.05. Missing entries
    are shown as ``---`` (FullObs) or ``N/A`` (EC). Only models listed in
    ``MODEL_ORDER`` that occur in at least one input get a row.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    def render_cell(table, metric, model, placeholder):
        # One table cell: rho to two decimals plus a significance star.
        entry = table.get(metric, {}).get(model)
        if not entry:
            return placeholder
        star = "*" if entry.p_value < 0.05 else ""
        return f"{entry.spearman_rho:.2f}{star}"

    header = [
        "% Difficulty Correlations Table (auto-generated)",
        "\\begin{table}[t]",
        "\\centering",
        "\\caption{\\textbf{Correlation between generation diagnostics and model success.}",
        "Spearman $\\rho$ between difficulty metrics and correctness (0/1).",
        "VS\\_final = number of hypotheses remaining after all worlds;",
        "Kill\\_mean = mean hypotheses killed per world.",
        "Negative correlations indicate harder instances (higher metric) have lower accuracy.}",
        "\\label{tab:difficulty_correlations}",
        "\\small",
        "\\begin{tabular}{@{}l|rr|rr@{}}",
        "\\toprule",
        " & \\multicolumn{2}{c|}{FullObs} & \\multicolumn{2}{c}{EC} \\\\",
        "Model & VS\\_final & Kill\\_mean & VS\\_final & Kill\\_mean \\\\",
        "\\midrule",
    ]
    footer = [
        "\\bottomrule",
        "\\multicolumn{5}{l}{\\footnotesize * $p < 0.05$}",
        "\\end{tabular}",
        "\\end{table}",
    ]

    # Every model that has at least one correlation in either task.
    seen_models = set()
    for per_metric in list(fo_correlations.values()) + list(ec_correlations.values()):
        seen_models.update(per_metric.keys())

    body = []
    for model in MODEL_ORDER:
        if model not in seen_models:
            continue
        cells = [
            MODEL_DISPLAY_NAMES.get(model, model),
            render_cell(fo_correlations, "vs_total", model, "---"),
            render_cell(fo_correlations, "mean_killed", model, "---"),
            # The EC metrics may not be available at all.
            render_cell(ec_correlations, "vs_total", model, "N/A"),
            render_cell(ec_correlations, "mean_killed", model, "N/A"),
        ]
        body.append(" & ".join(cells) + " \\\\")

    output_path.write_text("\n".join(header + body + footer))


def generate_quintile_figure(
    quintiles: Dict[str, List[QuintileAccuracy]],
    metric_name: str,
    task_name: str,
    output_path: Path,
) -> None:
    """
    Generate quintile accuracy figure.

    Draws one line per model (accuracy in percent against quintile 1-5) and
    saves the plot to ``output_path``.

    Args:
        quintiles: model -> list of QuintileAccuracy, lowest bin first.
        metric_name: Display name of the metric, used in axis label and title.
        task_name: Display name of the task, used in the title.
        output_path: Destination file; parent directories are created.

    Only models listed in ``MODEL_ORDER`` are drawn, and a model is skipped
    unless it has exactly five bins. The file is written even when no model
    qualifies. Does nothing when matplotlib is not installed.
    """
    if not HAS_MATPLOTLIB:
        return

    output_path.parent.mkdir(parents=True, exist_ok=True)

    x = [1, 2, 3, 4, 5]

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        plotted_any = False
        for model in MODEL_ORDER:
            model_quints = quintiles.get(model)
            # Anything other than five bins cannot be plotted against x.
            if not model_quints or len(model_quints) != len(x):
                continue

            y = [q.accuracy * 100 for q in model_quints]
            ax.plot(
                x,
                y,
                "o-",
                color=MODEL_COLORS.get(model, "#333333"),
                label=MODEL_DISPLAY_NAMES.get(model, model),
                linewidth=2,
                markersize=6,
            )
            plotted_any = True

        ax.set_xlabel(f"{metric_name} Quintile (low → high)", fontsize=12)
        ax.set_ylabel("Accuracy (%)", fontsize=12)
        ax.set_title(f"{task_name}: Accuracy by {metric_name} Quintile", fontsize=14)
        ax.set_xticks(x)
        ax.set_xticklabels(["Q1\n(easiest)", "Q2", "Q3", "Q4", "Q5\n(hardest)"])
        ax.set_ylim(0, 100)
        if plotted_any:
            # legend() warns when there are no labelled artists to show.
            ax.legend(loc="best", fontsize=9)
        ax.grid(True, alpha=0.3)

        # Work on this figure explicitly rather than pyplot's current figure.
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even when plotting or saving raised.
        plt.close(fig)


def run_difficulty_validation(
    fo_dataset: Optional[Path],
    ec_dataset: Optional[Path],
    outdir: Path,
    tables_dir: Path,
    figures_dir: Path,
    verbose: bool = True,
) -> Dict[str, Any]:
    """
    Run difficulty validation analysis.

    For each dataset that is given and exists, the problems are loaded,
    correlations are computed and written to CSV, and (FullObs only) quintile
    figures are drawn. The combined LaTeX table is always written, even when
    no dataset was processed.

    Args:
        fo_dataset: Path to FO benchmark YAML
        ec_dataset: Path to EC benchmark YAML
        outdir: Output directory for CSV
        tables_dir: Directory for LaTeX tables
        figures_dir: Directory for figures
        verbose: Print progress

    Returns:
        Summary dict with the keys ``"fo"`` and ``"ec"`` (each holding a
        ``"correlations"`` mapping once the dataset was processed) and
        ``"outputs"`` (paths of the files that were written).

    Conditions that silently degrade the outputs (a dataset path that does not
    exist, scipy or matplotlib not installed) are reported on stderr
    regardless of ``verbose``.
    """

    def _warn(message: str) -> None:
        # stderr, so warnings survive --quiet and stdout redirection.
        print(f"[Difficulty Validation] WARNING: {message}", file=sys.stderr)

    def _summarize(correlations: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
        # Plain-dict view of the correlation records for the summary.
        return {
            metric: {
                model: {"rho": c.spearman_rho, "p": c.p_value, "n": c.n_instances}
                for model, c in corrs.items()
            }
            for metric, corrs in correlations.items()
        }

    results: Dict[str, Any] = {
        "fo": {},
        "ec": {},
        "outputs": {},
    }

    if not HAS_SCIPY:
        _warn("scipy is not installed; no correlations will be computed")
    if not HAS_MATPLOTLIB:
        _warn("matplotlib is not installed; no figures will be generated")
    for label, dataset in (("FO", fo_dataset), ("EC", ec_dataset)):
        if dataset and not dataset.exists():
            _warn(f"{label} dataset not found, skipping: {dataset}")

    # =========================================================================
    # FullObs Analysis
    # =========================================================================
    fo_correlations = {}
    if fo_dataset and fo_dataset.exists():
        if verbose:
            print("[Difficulty Validation] Processing FullObs...")

        fo_problems = load_from_yaml(fo_dataset)
        fo_difficulty = extract_fo_difficulty_metrics(fo_problems)
        fo_results = extract_model_results(fo_problems, "fo")

        if verbose:
            print(f"  Found {len(fo_difficulty)} instances with difficulty metrics")

        # Compute correlations for available metrics
        for metric_name in ["vs_total", "mean_killed", "min_killed"]:
            corrs = compute_correlations(fo_difficulty, fo_results, metric_name)
            if corrs:
                fo_correlations[metric_name] = corrs

        results["fo"]["correlations"] = _summarize(fo_correlations)

        # Save FO correlations CSV
        csv_path = outdir / "difficulty_correlations_fullobs.csv"
        save_correlations_csv(fo_correlations, csv_path)
        results["outputs"]["fo_csv"] = str(csv_path)
        if verbose:
            print(f"  Saved CSV: {csv_path}")

        # Compute quintile accuracies and generate figures
        for metric_name, display_name in [("vs_total", "VS_final"), ("mean_killed", "Kill_mean")]:
            quintiles = compute_quintile_accuracies(fo_difficulty, fo_results, metric_name)

            if quintiles and HAS_MATPLOTLIB:
                fig_path = figures_dir / f"fig_fullobs_difficulty_{metric_name}.pdf"
                generate_quintile_figure(quintiles, display_name, "FullObs", fig_path)
                results["outputs"][f"fo_fig_{metric_name}"] = str(fig_path)
                if verbose:
                    print(f"  Saved figure: {fig_path}")

    # =========================================================================
    # EC Analysis
    # =========================================================================
    ec_correlations = {}
    if ec_dataset and ec_dataset.exists():
        if verbose:
            print("[Difficulty Validation] Processing EC...")

        ec_problems = load_from_yaml(ec_dataset)
        ec_difficulty = extract_ec_difficulty_metrics(ec_problems)
        ec_results = extract_model_results(ec_problems, "ec")

        if verbose:
            print(f"  Found {len(ec_difficulty)} instances")

        # EC has limited diagnostics - compute what's available
        for metric_name in ["gold_ast", "gold_qd"]:
            corrs = compute_correlations(ec_difficulty, ec_results, metric_name)
            if corrs:
                ec_correlations[metric_name] = corrs

        results["ec"]["correlations"] = _summarize(ec_correlations)

        # Save EC correlations CSV
        csv_path = outdir / "difficulty_correlations_ec.csv"
        save_correlations_csv(ec_correlations, csv_path)
        results["outputs"]["ec_csv"] = str(csv_path)
        if verbose:
            print(f"  Saved CSV: {csv_path}")

    # =========================================================================
    # Combined LaTeX Table
    # =========================================================================
    tex_path = tables_dir / "tab_difficulty_correlations.tex"
    generate_correlations_latex(fo_correlations, ec_correlations, tex_path)
    results["outputs"]["latex"] = str(tex_path)
    if verbose:
        print(f"[Difficulty Validation] Saved LaTeX: {tex_path}")

    return results


def main():
    """Parse the command-line options and run the difficulty validation analysis."""
    cli = argparse.ArgumentParser(
        description="Difficulty Controls Validation Analysis",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # (flags, keyword arguments) for each option, in the order shown by --help.
    option_specs = [
        (("--fo-dataset",), {"help": "Path to FO benchmark YAML"}),
        (("--ec-dataset",), {"help": "Path to EC benchmark YAML"}),
        (
            ("--outdir", "-o"),
            {
                "default": "concept_synth/artifacts/analysis/v1/difficulty",
                "help": "Output directory for CSV",
            },
        ),
        (
            ("--tables-dir",),
            {"default": "concept_synth/paper/auto/tables", "help": "Directory for LaTeX tables"},
        ),
        (
            ("--figures-dir",),
            {"default": "concept_synth/paper/auto/figs", "help": "Directory for figures"},
        ),
        (
            ("--quiet", "-q"),
            {"action": "store_true", "help": "Suppress progress output"},
        ),
    ]
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)

    opts = cli.parse_args()

    def optional_path(raw):
        # Dataset options that were omitted (or empty) stay None.
        return Path(raw) if raw else None

    run_difficulty_validation(
        fo_dataset=optional_path(opts.fo_dataset),
        ec_dataset=optional_path(opts.ec_dataset),
        outdir=Path(opts.outdir),
        tables_dir=Path(opts.tables_dir),
        figures_dir=Path(opts.figures_dir),
        verbose=not opts.quiet,
    )


if __name__ == "__main__":
    main()
