#!/usr/bin/env python3
"""
Held-Out Generalization Analysis

Evaluates model predictions on held-out worlds to measure generalization:

For FullObs (FO):
- Generate N=5 additional IID holdout worlds per instance (same sampler, no CEGIS)
- Label by gold φ*
- Compute holdout world accuracy and FP/FN rates
- Analyze as function of AST delta for valid vs invalid predictions

For CI:
- Generate M=3 YES holdouts labeled by gold φ*
- Generate N=2 NO holdouts by relabeling fresh worlds with a perturbed gold extension (trap formulas are accepted but not yet used)
- Evaluate holdout YES exact-match and holdout NO "avoid exact match" rate
- Analyze conditional on training correctness

Usage:
    python -m concept_synth.analysis.holdout_generalization \
        --fo-dataset results/ad_benchmark/ad_benchmark_v1.yaml \
        --ci-dataset results/c_benchmark/c_benchmark_v1.yaml \
        --out artifacts/analysis/v1/holdout/
"""

import argparse
import json
import os
import random
import sys
from collections import defaultdict
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# Bootstrap path
try:
    from concept_synth.bootstrap import add_repo_root
except ModuleNotFoundError:
    _path = os.path.abspath(__file__)
    while True:
        parent = os.path.dirname(_path)
        if os.path.basename(_path) == "concept_synth":
            if parent not in sys.path:
                sys.path.insert(0, parent)
            break
        if parent == _path:
            break
        _path = parent
    from concept_synth.bootstrap import add_repo_root
add_repo_root(__file__)

from concept_synth.analysis.per_world_fpfn import (
    WorldMistakeProfile,
    compute_world_profile,
    format_aligned_table,
)
from concept_synth.evaluate_ad import build_model_from_world
from concept_synth.fol.formulas import FOFormula
from concept_synth.fol.model import FiniteModel
from concept_synth.io_utils import load_from_yaml
from concept_synth.metrics import ast_size
from concept_synth.sexpr_parser import parse_sexpr_formula
from concept_synth.sexpr_printer import to_sexpr
from concept_synth.target import compute_target_extension
from concept_synth.worldgen import WorldGenConfig, generate_world, model_to_world_view

# =============================================================================
# Data Classes
# =============================================================================


@dataclass
class HoldoutWorldResult:
    """Scores of one predicted formula on one holdout world.

    Counts and rates compare the predicted T_true against the world's labeled
    T_true; rates are fractions of ``domain_size``. For ``holdout_no`` worlds
    the evaluators in this module set ``accuracy`` to 0.0 and fill ``margin``.
    """

    world_id: str
    world_type: str  # 'holdout_yes', 'holdout_no'
    domain_size: int  # number of domain elements in the world
    exact_match: bool  # predicted T_true equals the labeled T_true
    fp_count: int  # predicted true, labeled false
    fn_count: int  # labeled true, predicted false
    fp_rate: float  # fp_count / domain_size (0.0 for an empty domain)
    fn_rate: float  # fn_count / domain_size (0.0 for an empty domain)
    accuracy: float  # (TP + TN) / domain_size on YES worlds
    margin: Optional[int] = None  # For NO worlds: fp_count + fn_count


@dataclass
class HoldoutInstanceResult:
    """Holdout evaluation of one model's answer on one benchmark instance.

    Built by ``evaluate_holdout_fo`` (task "fo") and ``evaluate_holdout_ci``
    (task "ci"); the CI-specific fields keep their None default for FO.
    """

    instance_id: str
    model: str  # model name taken from the LLM result
    band: str  # difficulty band ("ad_band" for FO, "c_band" for CI)
    task: str  # "fo" or "ci"

    # Training performance
    train_correct: bool  # answer consistent with all training worlds (task-specific check)
    train_parse_ok: bool  # extracted formula parsed and its AST size was computed
    pred_ast: Optional[int]  # None when the prediction did not parse
    gold_ast: int
    ast_delta: Optional[int]  # pred_ast - gold_ast, or None without pred_ast

    # Holdout results
    holdout_worlds: List[HoldoutWorldResult]
    # Exact-match rate on holdout worlds; which world types count is decided
    # by the task-specific evaluator.
    holdout_exact_match_rate: float
    holdout_mean_accuracy: float
    holdout_mean_fp_rate: float
    holdout_mean_fn_rate: float

    # CI-specific
    holdout_yes_exact_rate: Optional[float] = None
    holdout_no_avoid_rate: Optional[float] = None  # Rate of avoiding exact match on NO
    holdout_no_mean_margin: Optional[float] = None  # mean fp+fn mismatches on NO worlds


@dataclass
class HoldoutModelAggregates:
    """Holdout results of one model, aggregated over the instances of one task."""

    model: str
    task: str  # "fo" or "ci"

    # Counts
    total_instances: int
    train_correct_count: int
    train_incorrect_count: int  # train-incorrect answers whose formula still parsed

    # Holdout metrics (overall): means of the per-instance values
    mean_holdout_exact_rate: float
    mean_holdout_accuracy: float
    mean_holdout_fp_rate: float
    mean_holdout_fn_rate: float

    # Conditional on training correctness (None when the group is empty)
    holdout_exact_given_train_correct: Optional[float]
    holdout_exact_given_train_incorrect: Optional[float]

    # By AST delta bins: bin name -> {"count", "mean_exact", "mean_acc"}
    holdout_by_ast_delta: Dict[str, Dict[str, float]]

    # By AST ratio terciles (for valid formulas only)
    # r = pred_ast / gold_ast
    # Keys "low"/"mid"/"high" plus "compact_vs_bloated", whose value nests one
    # more level ({"compact": {...}, "bloated": {...}}), hence Dict[str, Any].
    holdout_by_ast_ratio: Dict[str, Dict[str, Any]] = field(default_factory=dict)

    # CI-specific
    holdout_yes_exact_rate: Optional[float] = None
    holdout_no_avoid_rate: Optional[float] = None

# =============================================================================
# World Generation for Holdout
# =============================================================================


def generate_holdout_worlds_fo(
    gold_formula: FOFormula,
    n_worlds: int,
    domain_sizes: List[int],
    rng: random.Random,
    min_ratio: float = 0.15,
    max_ratio: float = 0.85,
    max_attempts: int = 50,
) -> List[Tuple[FiniteModel, Dict[str, Any]]]:
    """
    Sample IID holdout worlds labeled by the gold formula (no CEGIS).

    For each requested world a domain size is drawn once from
    ``domain_sizes``. Worlds are then resampled until the gold formula's true
    fraction lies in ``[min_ratio, max_ratio]``. A slot whose ``max_attempts``
    candidates all fail the balance check is skipped, so the result may hold
    fewer than ``n_worlds`` entries.

    Args:
        gold_formula: The gold formula φ* used to label each world.
        n_worlds: Number of holdout worlds requested.
        domain_sizes: Domain sizes to choose from.
        rng: Random source shared by size selection and world generation.
        min_ratio: Lower bound on the T_true fraction; also passed to the
            generator as the minimum unary-predicate true fraction.
        max_ratio: Upper bound on the T_true fraction; likewise passed on.
        max_attempts: Candidate worlds tried per requested world.

    Returns:
        ``(model, world_view)`` pairs, where ``world_view`` is named
        ``holdout_<index>``.
    """
    accepted: List[Tuple[FiniteModel, Dict[str, Any]]] = []

    for world_idx in range(n_worlds):
        size = rng.choice(domain_sizes)
        attempts_left = max_attempts

        while attempts_left > 0:
            attempts_left -= 1
            candidate, _ = generate_world(
                rng,
                WorldGenConfig(
                    domain_size=size,
                    mode="regular_outdegree",
                    out_degree_R=2,
                    out_degree_S=2,
                    min_unary_true_frac=min_ratio,
                    max_unary_true_frac=max_ratio,
                ),
            )
            extension = compute_target_extension(candidate, gold_formula)

            # Reject worlds where the gold labeling is too lopsided.
            frac = len(extension.T_true) / candidate.n if candidate.n > 0 else 0
            if not (min_ratio <= frac <= max_ratio):
                continue

            view = model_to_world_view(
                candidate, f"holdout_{world_idx}", list(extension.T_true)
            )
            accepted.append((candidate, view))
            break

    return accepted


def generate_holdout_worlds_ci(
    gold_formula: FOFormula,
    n_yes: int,
    n_no: int,
    domain_sizes: List[int],
    rng: random.Random,
    trap_formulas: Optional[List[FOFormula]] = None,
    min_ratio: float = 0.15,
    max_ratio: float = 0.85,
    max_attempts: int = 50,
) -> Tuple[List[Tuple[FiniteModel, Dict[str, Any]]], List[Tuple[FiniteModel, Dict[str, Any]]]]:
    """
    Generate holdout worlds for CI.

    YES holdouts are IID worlds labeled by the gold formula; they are produced
    by ``generate_holdout_worlds_fo`` with the same ``rng``.

    NO holdouts are fresh worlds labeled with a perturbation of the gold
    extension. Up to ``max(1, |D| // 4)`` elements are flipped true->false and
    as many false->true, so the gold formula never matches a NO world's labels.

    All random draws come from lists in a fixed (sorted) domain order, never
    from iterating a set. A fixed ``rng`` seed therefore reproduces the same
    worlds across processes, independent of string-hash randomization.

    Args:
        gold_formula: The gold formula φ*.
        n_yes: Number of YES holdout worlds.
        n_no: Number of NO holdout worlds.
        domain_sizes: Domain sizes to sample from.
        rng: Random number generator.
        trap_formulas: Currently unused; accepted for API compatibility.
        min_ratio: Minimum T_true fraction, applied to the perturbed labels of
            NO worlds and to the gold labels of YES worlds.
        max_ratio: Maximum T_true fraction, applied likewise.
        max_attempts: Candidate worlds tried per requested world.

    Returns:
        ``(yes_worlds, no_worlds)``. Each list may be shorter than requested
        when a world slot fails the balance check ``max_attempts`` times. NO
        world views are named ``holdout_no_<index>`` and tagged
        ``worldType = "NO"``.
    """
    yes_worlds = generate_holdout_worlds_fo(
        gold_formula, n_yes, domain_sizes, rng, min_ratio, max_ratio, max_attempts
    )

    no_worlds = []
    for i in range(n_no):
        domain_size = rng.choice(domain_sizes)

        for _ in range(max_attempts):
            cfg = WorldGenConfig(
                domain_size=domain_size,
                mode="regular_outdegree",
                out_degree_R=2,
                out_degree_S=2,
                min_unary_true_frac=min_ratio,
                max_unary_true_frac=max_ratio,
            )
            model, _ = generate_world(rng, cfg)

            gold_t_true = set(compute_target_extension(model, gold_formula).T_true)
            # Fixed element order: sampling from list(set(...)) would depend on
            # the per-process string hash seed and break reproducibility.
            domain = sorted(model.domain_constants(), key=str)
            n_to_flip = max(1, len(domain) // 4)

            true_elems = [c for c in domain if c in gold_t_true]
            false_elems = [c for c in domain if c not in gold_t_true]

            perturbed_t_true = set(gold_t_true)
            if true_elems:
                perturbed_t_true.difference_update(
                    rng.sample(true_elems, min(n_to_flip, len(true_elems)))
                )
            if false_elems:
                perturbed_t_true.update(
                    rng.sample(false_elems, min(n_to_flip, len(false_elems)))
                )

            # A NO world must differ from gold and still be balanced.
            if perturbed_t_true == gold_t_true:
                continue
            ratio = len(perturbed_t_true) / len(domain) if domain else 0
            if not (min_ratio <= ratio <= max_ratio):
                continue

            labels = [c for c in domain if c in perturbed_t_true]
            world_dict = model_to_world_view(model, f"holdout_no_{i}", labels)
            world_dict["worldType"] = "NO"
            no_worlds.append((model, world_dict))
            break

    return yes_worlds, no_worlds


# =============================================================================
# Holdout Evaluation
# =============================================================================


def evaluate_holdout_fo(
    problem: Dict[str, Any],
    llm_result: Dict[str, Any],
    n_holdout: int = 5,
    holdout_seed_offset: int = 10000,
) -> Optional[HoldoutInstanceResult]:
    """
    Evaluate a FO instance on IID holdout worlds.

    Training correctness means the predicted formula reproduces the stored
    ``targetExtension.T_true`` of every training world. Holdout worlds are
    then sampled with a reproducible per-instance seed and labeled by the
    gold formula. The prediction is scored on each world by exact match,
    FP/FN rates and accuracy.

    A missing, unparseable, or failing prediction is scored as predicting the
    empty set, and it never counts as an exact match.

    Args:
        problem: Problem dict; reads ``problem.instanceId``, ``problem.worlds``,
            ``problemDescription.ad_band`` and ``problemDescription.hiddenTarget``.
        llm_result: LLM result dict; reads ``model`` and ``extractedFormula``.
        n_holdout: Number of holdout worlds to generate.
        holdout_seed_offset: Mixed into the per-instance seed so different
            analyses can draw different holdout worlds.

    Returns:
        HoldoutInstanceResult, or None if the gold formula is missing or
        unparseable, or if no holdout world could be generated.
    """
    prob = problem.get("problem", {})
    desc = problem.get("problemDescription", {})
    hidden_target = desc.get("hiddenTarget", {})

    instance_id = prob.get("instanceId", "unknown")
    model_name = llm_result.get("model", "unknown")
    band = desc.get("ad_band", "unknown")

    gold_sexpr = hidden_target.get("formula", "")
    gold_ast = hidden_target.get("astSize", 0)
    if not gold_sexpr:
        return None
    try:
        gold_formula = parse_sexpr_formula(gold_sexpr)
    except Exception:
        return None

    extracted = llm_result.get("extractedFormula", "")
    pred_formula = None
    pred_ast = None
    train_parse_ok = False
    if extracted:
        try:
            pred_formula = parse_sexpr_formula(extracted)
            pred_ast = ast_size(pred_formula)
            train_parse_ok = True
        except Exception:
            pass  # best-effort: an unparseable answer is scored as no prediction

    ast_delta = pred_ast - gold_ast if pred_ast is not None else None

    def _predict(model: FiniteModel) -> Optional[set]:
        """Predicted T_true on ``model``, or None if no usable prediction exists."""
        if pred_formula is None:
            return None
        try:
            return set(compute_target_extension(model, pred_formula).T_true)
        except Exception:
            return None

    def _score_yes_world(
        world_dict: Dict[str, Any], pred_t_true: Optional[set]
    ) -> HoldoutWorldResult:
        """Score a gold-labeled world; ``None`` is scored as an empty prediction."""
        domain_size = len(world_dict.get("domain", []))
        gold_t_true = set(world_dict.get("targetExtension", {}).get("T_true", []))
        predicted = pred_t_true if pred_t_true is not None else set()
        fp = len(predicted - gold_t_true)
        fn = len(gold_t_true - predicted)
        tp = len(predicted & gold_t_true)
        tn = domain_size - tp - fp - fn
        return HoldoutWorldResult(
            world_id=world_dict.get("worldId", ""),
            world_type="holdout_yes",
            domain_size=domain_size,
            exact_match=pred_t_true is not None and pred_t_true == gold_t_true,
            fp_count=fp,
            fn_count=fn,
            fp_rate=fp / domain_size if domain_size > 0 else 0.0,
            fn_rate=fn / domain_size if domain_size > 0 else 0.0,
            accuracy=(tp + tn) / domain_size if domain_size > 0 else 0.0,
        )

    train_worlds = prob.get("worlds", [])

    # Training correctness: the prediction must reproduce every training label;
    # an evaluation failure counts as a mismatch.
    train_correct = False
    if pred_formula is not None:
        train_correct = True
        for world_dict in train_worlds:
            model = build_model_from_world(world_dict)
            gold_t_true = set(world_dict.get("targetExtension", {}).get("T_true", []))
            if _predict(model) != gold_t_true:
                train_correct = False
                break

    # random.Random hashes a str seed with SHA-512, so this is reproducible
    # across processes. Do not use hash(instance_id): str hashes are salted
    # per process (PYTHONHASHSEED), which makes the holdout worlds change
    # between runs.
    rng = random.Random(f"{instance_id}:{holdout_seed_offset}")

    # Domain sizes seen in training; sorted so rng.choice sees a fixed order.
    domain_sizes = sorted({w.get("domainSize", 7) for w in train_worlds}) or [7, 8, 9]

    holdout_worlds = generate_holdout_worlds_fo(gold_formula, n_holdout, domain_sizes, rng)
    if not holdout_worlds:
        return None

    holdout_results = [
        _score_yes_world(world_dict, _predict(model)) for model, world_dict in holdout_worlds
    ]
    n_worlds = len(holdout_results)

    return HoldoutInstanceResult(
        instance_id=instance_id,
        model=model_name,
        band=band,
        task="fo",
        train_correct=train_correct,
        train_parse_ok=train_parse_ok,
        pred_ast=pred_ast,
        gold_ast=gold_ast,
        ast_delta=ast_delta,
        holdout_worlds=holdout_results,
        holdout_exact_match_rate=sum(r.exact_match for r in holdout_results) / n_worlds,
        holdout_mean_accuracy=sum(r.accuracy for r in holdout_results) / n_worlds,
        holdout_mean_fp_rate=sum(r.fp_rate for r in holdout_results) / n_worlds,
        holdout_mean_fn_rate=sum(r.fn_rate for r in holdout_results) / n_worlds,
    )


def evaluate_holdout_ci(
    problem: Dict[str, Any],
    llm_result: Dict[str, Any],
    n_yes_holdout: int = 3,
    n_no_holdout: int = 2,
    holdout_seed_offset: int = 20000,
) -> Optional[HoldoutInstanceResult]:
    """
    Evaluate a CI instance on holdout worlds.

    Training correctness requires the predicted formula to:
    - reproduce the stored T_true of every YES training world, and
    - differ from it on every NO training world.

    Holdout YES worlds are labeled by the gold formula. Holdout NO worlds carry
    a perturbed gold labeling, so the prediction should *avoid* matching them.

    ``holdout_exact_match_rate``, ``holdout_mean_accuracy`` and the mean FP/FN
    rates are computed over YES holdout worlds only. On a NO world an exact
    match is a failure and mismatches are desired, so mixing the two would
    penalize a perfect prediction. NO worlds are summarized by
    ``holdout_no_avoid_rate`` and ``holdout_no_mean_margin``.

    Args:
        problem: Problem dict; reads ``problem.instanceId``, ``problem.worlds``,
            ``problemDescription.c_band`` and ``problemDescription.hiddenTarget``.
        llm_result: LLM result dict; reads ``model`` and ``extractedFormula``.
        n_yes_holdout: Number of YES holdout worlds.
        n_no_holdout: Number of NO holdout worlds.
        holdout_seed_offset: Mixed into the per-instance seed.

    Returns:
        HoldoutInstanceResult, or None in these cases:
        - the gold formula is missing or unparseable;
        - no holdout world could be generated.
        The YES / NO summary rates are None when no world of that type was
        generated.
    """
    prob = problem.get("problem", {})
    desc = problem.get("problemDescription", {})
    hidden_target = desc.get("hiddenTarget", {})

    instance_id = prob.get("instanceId", "unknown")
    model_name = llm_result.get("model", "unknown")
    band = desc.get("c_band", "unknown")

    gold_sexpr = hidden_target.get("formula", "")
    gold_ast = hidden_target.get("astSize", 0)
    if not gold_sexpr:
        return None
    try:
        gold_formula = parse_sexpr_formula(gold_sexpr)
    except Exception:
        return None

    extracted = llm_result.get("extractedFormula", "")
    pred_formula = None
    pred_ast = None
    train_parse_ok = False
    if extracted:
        try:
            pred_formula = parse_sexpr_formula(extracted)
            pred_ast = ast_size(pred_formula)
            train_parse_ok = True
        except Exception:
            pass  # best-effort: an unparseable answer is scored as no prediction

    ast_delta = pred_ast - gold_ast if pred_ast is not None else None

    def _predict(model: FiniteModel) -> Optional[set]:
        """Predicted T_true on ``model``, or None if no usable prediction exists."""
        if pred_formula is None:
            return None
        try:
            return set(compute_target_extension(model, pred_formula).T_true)
        except Exception:
            return None

    def _train_world_type(world_dict: Dict[str, Any]) -> str:
        """Label of a training world: worldType, else splitLabel, else id prefix."""
        world_type = world_dict.get("worldType")
        if world_type is not None:
            return world_type
        split_label = world_dict.get("splitLabel", "")
        if split_label:
            return split_label.upper()
        # Unlabeled worlds default to YES unless the id starts with "no".
        if world_dict.get("worldId", "").lower().startswith("no"):
            return "NO"
        return "YES"

    def _score_world(
        world_dict: Dict[str, Any], pred_t_true: Optional[set], world_type: str
    ) -> HoldoutWorldResult:
        """Score one holdout world; ``None`` is scored as an empty prediction."""
        domain_size = len(world_dict.get("domain", []))
        labeled = set(world_dict.get("targetExtension", {}).get("T_true", []))
        predicted = pred_t_true if pred_t_true is not None else set()
        fp = len(predicted - labeled)
        fn = len(labeled - predicted)
        is_no = world_type == "holdout_no"
        if is_no or domain_size == 0:
            accuracy = 0.0  # accuracy is not meaningful on NO worlds
        else:
            tp = len(predicted & labeled)
            tn = domain_size - tp - fp - fn
            accuracy = (tp + tn) / domain_size
        return HoldoutWorldResult(
            world_id=world_dict.get("worldId", ""),
            world_type=world_type,
            domain_size=domain_size,
            exact_match=pred_t_true is not None and pred_t_true == labeled,
            fp_count=fp,
            fn_count=fn,
            fp_rate=fp / domain_size if domain_size > 0 else 0.0,
            fn_rate=fn / domain_size if domain_size > 0 else 0.0,
            accuracy=accuracy,
            margin=fp + fn if is_no else None,
        )

    train_worlds = prob.get("worlds", [])

    # Training correctness. A failed evaluation (None) never matches a YES
    # world and trivially avoids a NO world.
    train_correct = False
    if pred_formula is not None:
        yes_all_match = True
        no_all_avoid = True
        for world_dict in train_worlds:
            model = build_model_from_world(world_dict)
            gold_t_true = set(world_dict.get("targetExtension", {}).get("T_true", []))
            pred_t_true = _predict(model)
            if _train_world_type(world_dict) == "YES":
                if pred_t_true != gold_t_true:
                    yes_all_match = False
            elif pred_t_true == gold_t_true:
                no_all_avoid = False
        train_correct = yes_all_match and no_all_avoid

    # random.Random hashes a str seed with SHA-512, so this is reproducible
    # across processes. Do not use hash(instance_id): str hashes are salted
    # per process (PYTHONHASHSEED).
    rng = random.Random(f"{instance_id}:{holdout_seed_offset}")

    # Domain sizes seen in training; sorted so rng.choice sees a fixed order.
    domain_sizes = sorted({w.get("domainSize", 7) for w in train_worlds}) or [7, 8, 9]

    yes_holdouts, no_holdouts = generate_holdout_worlds_ci(
        gold_formula, n_yes_holdout, n_no_holdout, domain_sizes, rng
    )
    # Same contract as the FO evaluator: nothing to evaluate -> no result,
    # rather than a zero score that would drag model averages down.
    if not yes_holdouts and not no_holdouts:
        return None

    yes_results = [
        _score_world(world_dict, _predict(model), "holdout_yes")
        for model, world_dict in yes_holdouts
    ]
    no_results = [
        _score_world(world_dict, _predict(model), "holdout_no")
        for model, world_dict in no_holdouts
    ]
    holdout_results = yes_results + no_results

    n_yes = len(yes_results)
    n_no = len(no_results)

    holdout_yes_exact = sum(r.exact_match for r in yes_results) / n_yes if n_yes else None
    holdout_no_avoid = sum(not r.exact_match for r in no_results) / n_no if n_no else None
    holdout_no_margin = sum(r.margin or 0 for r in no_results) / n_no if n_no else None

    # Overall generalization metrics: YES worlds only (see docstring).
    holdout_exact_rate = holdout_yes_exact if holdout_yes_exact is not None else 0.0
    holdout_mean_acc = sum(r.accuracy for r in yes_results) / n_yes if n_yes else 0.0
    holdout_mean_fp = sum(r.fp_rate for r in yes_results) / n_yes if n_yes else 0.0
    holdout_mean_fn = sum(r.fn_rate for r in yes_results) / n_yes if n_yes else 0.0

    return HoldoutInstanceResult(
        instance_id=instance_id,
        model=model_name,
        band=band,
        task="ci",
        train_correct=train_correct,
        train_parse_ok=train_parse_ok,
        pred_ast=pred_ast,
        gold_ast=gold_ast,
        ast_delta=ast_delta,
        holdout_worlds=holdout_results,
        holdout_exact_match_rate=holdout_exact_rate,
        holdout_mean_accuracy=holdout_mean_acc,
        holdout_mean_fp_rate=holdout_mean_fp,
        holdout_mean_fn_rate=holdout_mean_fn,
        holdout_yes_exact_rate=holdout_yes_exact,
        holdout_no_avoid_rate=holdout_no_avoid,
        holdout_no_mean_margin=holdout_no_margin,
    )


# =============================================================================
# Aggregation
# =============================================================================


def aggregate_holdout_results(
    results: List[HoldoutInstanceResult], task: str
) -> Dict[str, HoldoutModelAggregates]:
    """
    Aggregate per-instance holdout results by model.

    Args:
        results: Per (instance, model) holdout results.
        task: ``"fo"`` or ``"ci"``. The AST-ratio analyses score FO results by
            ``holdout_exact_match_rate`` and CI results by
            ``holdout_yes_exact_rate`` (a missing value counts as 0).

    Returns:
        Mapping from model name to its HoldoutModelAggregates.

        Filtering rules:
        - Only train-correct results with a known ``pred_ast`` and a positive
          ``gold_ast`` enter the AST-ratio analyses.
        - Results whose prediction did not parse are excluded from the
          "train incorrect" group.
        - CI rates are averaged over the instances that actually report them.
    """

    def _mean(values: List[float]) -> float:
        """Arithmetic mean, or 0.0 for an empty list."""
        return sum(values) / len(values) if values else 0.0

    def _ratio_exact(r: HoldoutInstanceResult) -> float:
        """Exact-match score used by the AST-ratio analyses."""
        if task == "fo":
            return r.holdout_exact_match_rate
        return r.holdout_yes_exact_rate or 0

    def _delta_bin(delta: int) -> str:
        """Name of the AST-delta (pred_ast - gold_ast) bin for ``delta``."""
        if delta < 0:
            return "delta<0"
        if delta == 0:
            return "delta=0"
        if delta <= 10:
            return "delta_1-10"
        if delta <= 25:
            return "delta_11-25"
        return "delta>25"

    by_model: Dict[str, List[HoldoutInstanceResult]] = defaultdict(list)
    for r in results:
        by_model[r.model].append(r)

    aggregates = {}
    for model, model_results in by_model.items():
        train_correct = [r for r in model_results if r.train_correct]
        train_incorrect = [r for r in model_results if not r.train_correct and r.train_parse_ok]

        # --- AST delta bins (fixed key order; empty bins omitted) ---
        ast_bins: Dict[str, List[HoldoutInstanceResult]] = {
            "delta<0": [],
            "delta=0": [],
            "delta_1-10": [],
            "delta_11-25": [],
            "delta>25": [],
        }
        for r in model_results:
            if r.ast_delta is not None:
                ast_bins[_delta_bin(r.ast_delta)].append(r)

        holdout_by_ast = {
            bin_name: {
                "count": len(members),
                "mean_exact": _mean([r.holdout_exact_match_rate for r in members]),
                "mean_acc": _mean([r.holdout_mean_accuracy for r in members]),
            }
            for bin_name, members in ast_bins.items()
            if members
        }

        # --- AST ratio r = pred_ast / gold_ast (train-correct formulas only) ---
        eligible = [
            r
            for r in model_results
            if r.train_correct and r.pred_ast is not None and r.gold_ast > 0
        ]
        valid_with_ratio = sorted(
            ((r.pred_ast / r.gold_ast, r) for r in eligible), key=lambda pair: pair[0]
        )

        holdout_by_ratio: Dict[str, Any] = {}
        n = len(valid_with_ratio)
        if n >= 3:
            # Rounded cut points keep the tercile sizes within one of each
            # other instead of piling the whole remainder into "high".
            cut_low, cut_high = round(n / 3), round(2 * n / 3)
            for name, tercile in (
                ("low", valid_with_ratio[:cut_low]),  # most compact relative to gold
                ("mid", valid_with_ratio[cut_low:cut_high]),
                ("high", valid_with_ratio[cut_high:]),  # most bloated relative to gold
            ):
                if not tercile:
                    continue
                ratios = [ratio for ratio, _ in tercile]
                holdout_by_ratio[name] = {
                    "count": len(tercile),
                    "min_ratio": min(ratios),
                    "max_ratio": max(ratios),
                    "mean_ratio": _mean(ratios),
                    "holdout_exact": _mean([_ratio_exact(r) for _, r in tercile]),
                }

        # Compact: r <= (N+1)/N with N = gold_ast, i.e. pred_ast <= gold_ast + 1;
        # bloated otherwise. Compared on integers to avoid float rounding.
        compact = [r for r in eligible if r.pred_ast <= r.gold_ast + 1]
        bloated = [r for r in eligible if r.pred_ast > r.gold_ast + 1]
        if compact or bloated:
            split: Dict[str, Dict[str, float]] = {}
            for label, group in (("compact", compact), ("bloated", bloated)):
                if group:
                    split[label] = {
                        "count": len(group),
                        "holdout_exact": _mean([_ratio_exact(r) for r in group]),
                    }
            holdout_by_ratio["compact_vs_bloated"] = split

        # --- CI-specific: each rate averaged over instances that report it ---
        holdout_yes_exact = None
        holdout_no_avoid = None
        if task == "ci":
            yes_rates = [
                r.holdout_yes_exact_rate
                for r in model_results
                if r.holdout_yes_exact_rate is not None
            ]
            no_rates = [
                r.holdout_no_avoid_rate
                for r in model_results
                if r.holdout_no_avoid_rate is not None
            ]
            holdout_yes_exact = _mean(yes_rates) if yes_rates else None
            holdout_no_avoid = _mean(no_rates) if no_rates else None

        aggregates[model] = HoldoutModelAggregates(
            model=model,
            task=task,
            total_instances=len(model_results),
            train_correct_count=len(train_correct),
            train_incorrect_count=len(train_incorrect),
            mean_holdout_exact_rate=_mean([r.holdout_exact_match_rate for r in model_results]),
            mean_holdout_accuracy=_mean([r.holdout_mean_accuracy for r in model_results]),
            mean_holdout_fp_rate=_mean([r.holdout_mean_fp_rate for r in model_results]),
            mean_holdout_fn_rate=_mean([r.holdout_mean_fn_rate for r in model_results]),
            holdout_exact_given_train_correct=(
                _mean([r.holdout_exact_match_rate for r in train_correct])
                if train_correct
                else None
            ),
            holdout_exact_given_train_incorrect=(
                _mean([r.holdout_exact_match_rate for r in train_incorrect])
                if train_incorrect
                else None
            ),
            holdout_by_ast_delta=holdout_by_ast,
            holdout_by_ast_ratio=holdout_by_ratio,
            holdout_yes_exact_rate=holdout_yes_exact,
            holdout_no_avoid_rate=holdout_no_avoid,
        )

    return aggregates


# =============================================================================
# Main Analysis Pipeline
# =============================================================================


def _holdout_analyze_task(
    tag: str,
    dataset: Optional[Path],
    scenario: str,
    evaluate,
    json_path: Path,
    verbose: bool,
) -> Optional[Dict[str, Any]]:
    """Evaluate one task's benchmark file, save its JSON output, and return it.

    Args:
        tag: Short task label ("FO" or "CI"). Used as the prefix of progress
            messages and, lower-cased, as the task name passed to
            ``aggregate_holdout_results``.
        dataset: Path to the benchmark YAML, or None/empty to skip the task.
        scenario: Value of ``problem["problem"]["scenario"]`` to keep
            (e.g. "AD" for FO, "C" for CI).
        evaluate: Callable ``(problem, llm_result) -> result`` returning a
            dataclass instance per instance-model pair, or a falsy value to
            skip that pair.
        json_path: File the ``{"results", "aggregates"}`` payload is written to.
        verbose: Print progress messages.

    Returns:
        The saved payload, or None if the task was skipped.
    """
    if not dataset:
        return None
    dataset = Path(dataset)
    if not dataset.exists():
        # A path was requested but is missing: report it instead of silently
        # producing an analysis without this task.
        print(f"[{tag}] WARNING: dataset not found, skipping: {dataset}", file=sys.stderr)
        return None

    if verbose:
        print(f"\n[{tag}] Loading {dataset}...")
    problems = load_from_yaml(str(dataset))

    task_problems = [
        p for p in problems if (p.get("problem") or {}).get("scenario") == scenario
    ]
    if verbose:
        print(f"[{tag}] Found {len(task_problems)} {scenario} problems")

    task_results = []
    for done, problem in enumerate(task_problems, start=1):
        for llm_result in problem.get("llmResults", []):
            result = evaluate(problem, llm_result)
            if result:
                task_results.append(result)

        if verbose and done % 50 == 0:
            print(f"[{tag}] Processed {done}/{len(task_problems)} problems...")

    if verbose:
        print(f"[{tag}] Analyzed {len(task_results)} instance-model pairs")

    aggregates = aggregate_holdout_results(task_results, tag.lower())
    payload = {
        "results": [asdict(r) for r in task_results],
        "aggregates": {k: asdict(v) for k, v in aggregates.items()},
    }

    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, default=str)
    if verbose:
        print(f"[{tag}] Saved to {json_path}")
    return payload


def run_holdout_analysis(
    fo_dataset: Optional[Path],
    ci_dataset: Optional[Path],
    outdir: Path,
    n_fo_holdout: int = 5,
    n_ci_yes_holdout: int = 3,
    n_ci_no_holdout: int = 2,
    verbose: bool = True,
) -> Dict[str, Any]:
    """
    Run the full holdout generalization analysis.

    Each task whose dataset path is given and exists is evaluated, its
    results are written to ``<outdir>/fo_holdout.json`` or
    ``<outdir>/ci_holdout.json``, and a markdown report is generated at the end.
    A dataset path that is given but missing triggers a warning on stderr and
    the task is skipped.

    Args:
        fo_dataset: Path to FO (AD) benchmark YAML
        ci_dataset: Path to CI (C) benchmark YAML
        outdir: Output directory (created if needed)
        n_fo_holdout: Number of FO holdout worlds per instance
        n_ci_yes_holdout: Number of CI YES holdout worlds
        n_ci_no_holdout: Number of CI NO holdout worlds
        verbose: Print progress

    Returns:
        Summary dict with keys:
          - "fo" / "ci": None if the task was skipped, else
            {"results": [...], "aggregates": {...}}
          - "config": the holdout counts used for this run
    """
    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    results: Dict[str, Any] = {"fo": None, "ci": None}
    # Record the holdout counts used for this run alongside the per-task results.
    results["config"] = {
        "n_fo_holdout": n_fo_holdout,
        "n_ci_yes_holdout": n_ci_yes_holdout,
        "n_ci_no_holdout": n_ci_no_holdout,
    }

    results["fo"] = _holdout_analyze_task(
        "FO",
        fo_dataset,
        "AD",
        lambda problem, llm_result: evaluate_holdout_fo(problem, llm_result, n_fo_holdout),
        outdir / "fo_holdout.json",
        verbose,
    )

    results["ci"] = _holdout_analyze_task(
        "CI",
        ci_dataset,
        "C",
        lambda problem, llm_result: evaluate_holdout_ci(
            problem, llm_result, n_ci_yes_holdout, n_ci_no_holdout
        ),
        outdir / "ci_holdout.json",
        verbose,
    )

    # Generate report
    generate_holdout_report(results, outdir, verbose)

    return results


def _holdout_report_pct(value: Optional[float], digits: int = 1) -> str:
    """Render a fraction as a percentage with ``digits`` decimals, or "-" if it is None."""
    if value is None:
        return "-"
    return f"{value * 100:.{digits}f}"


def _holdout_report_overall_lines(task: str, aggregates: Dict[str, Any]) -> List[str]:
    """Build the per-model "Overall Holdout Performance" table for one task.

    Rates that are None are rendered as "-" rather than 0.0, so that missing data
    is not confused with a measured 0%.
    """
    if task == "fo":
        rate_headers = ["Holdout Exact%", "Holdout Acc%"]
        rate_keys = ("mean_holdout_exact_rate", "mean_holdout_accuracy")
    else:  # CI
        rate_headers = ["YES Exact%", "NO Avoid%"]
        rate_keys = ("holdout_yes_exact_rate", "holdout_no_avoid_rate")

    headers = [
        "Model",
        "Total",
        "Train Correct",
        *rate_headers,
        "Exact|Correct",
        "Exact|Incorrect",
    ]
    alignments = ["l"] + ["r"] * (len(headers) - 1)

    rows = []
    for model, agg in sorted(aggregates.items()):
        rows.append(
            [
                model[:15],
                str(agg["total_instances"]),
                str(agg["train_correct_count"]),
                *(_holdout_report_pct(agg.get(key)) for key in rate_keys),
                _holdout_report_pct(agg.get("holdout_exact_given_train_correct")),
                _holdout_report_pct(agg.get("holdout_exact_given_train_incorrect")),
            ]
        )
    return [format_aligned_table(headers, rows, alignments)]


def _holdout_report_delta_lines(aggregates: Dict[str, Any]) -> List[str]:
    """Build the "Holdout Performance by AST Delta" section, ending with a blank line."""
    lines = [
        "### Holdout Performance by AST Delta",
        "",
        "AST Delta = predicted AST - gold AST",
        "",
    ]

    all_bins = set()
    for agg in aggregates.values():
        all_bins.update((agg.get("holdout_by_ast_delta") or {}).keys())

    if all_bins:
        canonical_order = ["delta<0", "delta=0", "delta_1-10", "delta_11-25", "delta>25"]
        # Known bins in their natural order, followed by any unrecognized bin
        # names so they are shown rather than silently dropped.
        bins_present = [b for b in canonical_order if b in all_bins]
        bins_present += sorted(all_bins.difference(canonical_order), key=str)

        headers = ["Model"] + bins_present
        alignments = ["l"] + ["r"] * len(bins_present)
        rows = []
        for model, agg in sorted(aggregates.items()):
            delta_data = agg.get("holdout_by_ast_delta") or {}
            row = [model[:15]]
            for bin_name in bins_present:
                bin_data = delta_data.get(bin_name)
                if bin_data:
                    row.append(
                        f"{bin_data.get('mean_exact', 0)*100:.0f}% (n={bin_data.get('count', 0)})"
                    )
                else:
                    row.append("-")
            rows.append(row)
        lines.append(format_aligned_table(headers, rows, alignments))
    else:
        lines.append("No AST delta data available.")

    lines.append("")
    return lines


def _holdout_report_ratio_lines(task: str, aggregates: Dict[str, Any]) -> List[str]:
    """Build the AST-ratio tercile table and the compact-vs-bloated table for one task."""
    lines = [
        "### Holdout Performance by AST Ratio (Valid Formulas Only)",
        "",
        "AST Ratio = predicted AST / gold AST (terciles among valid formulas)",
        "",
        "This analysis examines whether more compact or more bloated valid formulas",
        "generalize better to held-out worlds.",
        "",
    ]

    if not any(agg.get("holdout_by_ast_ratio") for agg in aggregates.values()):
        lines.append("No valid formulas with AST ratio data available.")
        return lines

    # State which holdout metric the tercile cells report.
    metric_name = "Holdout Exact%" if task == "fo" else "YES Exact%"
    lines.append(f"Cell values: {metric_name} (AST ratio range, n)")
    lines.append("")

    rows = []
    for model, agg in sorted(aggregates.items()):
        ratio_data = agg.get("holdout_by_ast_ratio") or {}
        row = [model[:15]]
        for tercile in ["low", "mid", "high"]:
            td = ratio_data.get(tercile, {})
            if td:
                ratio_range = f"r={td.get('min_ratio', 0):.2f}-{td.get('max_ratio', 0):.2f}"
                holdout_val = td.get("holdout_exact", 0) * 100
                row.append(f"{holdout_val:.1f}% ({ratio_range}, n={td.get('count', 0)})")
            else:
                row.append("-")
        rows.append(row)
    lines.append(
        format_aligned_table(
            ["Model", "Low (compact)", "Mid", "High (bloated)"], rows, ["l", "r", "r", "r"]
        )
    )

    lines += [
        "",
        "**Interpretation:** Lower AST ratio means the model's formula is more compact",
        "relative to gold. If low-ratio formulas have higher holdout accuracy, it suggests",
        "compact solutions generalize better (and bloat hurts generalization).",
    ]

    # Binary compact vs bloated table
    lines += [
        "",
        "### Compact vs Bloated (Binary Split)",
        "",
        "**Compact:** pred_AST ≤ gold_AST + 1 (i.e., r ≤ (N+1)/N where N = gold_AST)",
        "",
        "**Bloated:** pred_AST > gold_AST + 1 (i.e., r > (N+1)/N)",
        "",
    ]

    rows = []
    for model, agg in sorted(aggregates.items()):
        cvb = (agg.get("holdout_by_ast_ratio") or {}).get("compact_vs_bloated", {})
        cells = []
        values = []
        for side in ("compact", "bloated"):
            side_data = cvb.get(side, {})
            if side_data:
                val = side_data.get("holdout_exact", 0)
                cells.append(f"{val*100:.1f}% (n={side_data.get('count', 0)})")
                values.append(val)
            else:
                cells.append("-")
                values.append(None)

        if values[0] is not None and values[1] is not None:
            delta_str = f"{(values[0] - values[1]) * 100:+.1f}%"
        else:
            delta_str = "-"
        rows.append([model[:15], cells[0], cells[1], delta_str])

    lines.append(
        format_aligned_table(
            ["Model", "Compact (≤gold+1)", "Bloated (>gold+1)", "Δ (Compact - Bloated)"],
            rows,
            ["l", "r", "r", "r"],
        )
    )
    lines.append("")
    lines.append(
        "**Interpretation:** Positive Δ means compact formulas generalize better than bloated ones."
    )
    return lines


def generate_holdout_report(results: Dict[str, Any], outdir: Path, verbose: bool = True) -> None:
    """Generate the markdown report ``<outdir>/holdout_report.md`` for holdout analysis.

    Args:
        results: Dict with optional "fo" / "ci" entries (None, or a dict with an
            "aggregates" mapping of model -> aggregate dict). An optional "config"
            dict may provide "n_fo_holdout", "n_ci_yes_holdout" and
            "n_ci_no_holdout" for the methodology text; missing values fall back
            to 5, 3 and 2 respectively.
        outdir: Directory the report is written to (str or Path).
        verbose: Print the report path when done.

    The report is written as UTF-8 because it contains non-ASCII symbols
    (φ, ≤, Δ) that the platform's default encoding may not support.
    """
    config = results.get("config") or {}
    n_fo = config.get("n_fo_holdout", 5)
    n_ci_yes = config.get("n_ci_yes_holdout", 3)
    n_ci_no = config.get("n_ci_no_holdout", 2)

    lines = [
        "# Held-Out Generalization Analysis",
        "",
        "This report evaluates model predictions on held-out worlds to measure generalization.",
        "",
        "## Methodology",
        "",
        "**FullObs (FO):**",
        f"- Generate N={n_fo} IID holdout worlds per instance (same sampler as training, no CEGIS)",
        "- Label holdout worlds by gold formula φ*",
        "- Compute holdout exact-match rate and FP/FN rates",
        "",
        "**Contrastive (CI):**",
        f"- Generate M={n_ci_yes} YES holdout worlds labeled by gold φ*",
        f"- Generate N={n_ci_no} NO holdout worlds with perturbed labels",
        "- Evaluate YES exact-match and NO avoid-match rates",
        "",
    ]

    for task, task_name in (("fo", "FullObs (FO)"), ("ci", "Contrastive (CI)")):
        task_results = results.get(task)
        if task_results is None:
            continue

        lines += [f"## {task_name}", ""]

        aggregates = task_results.get("aggregates", {})
        if not aggregates:
            lines += ["No data available.", ""]
            continue

        lines += ["### Overall Holdout Performance", ""]
        lines += _holdout_report_overall_lines(task, aggregates)
        lines.append("")

        lines += _holdout_report_delta_lines(aggregates)
        lines += _holdout_report_ratio_lines(task, aggregates)

        lines += ["", ""]

    report_path = Path(outdir) / "holdout_report.md"
    report_path.write_text("\n".join(lines), encoding="utf-8")

    if verbose:
        print(f"\nGenerated report: {report_path}")


# =============================================================================
# CLI
# =============================================================================


def main():
    """Parse command-line arguments, validate them, and run the holdout analysis.

    Exits with status 2 (via ``parser.error``) if neither dataset is given, if
    a given dataset path does not exist, or if a holdout count is negative.
    Without this check, such runs would silently produce an empty report.
    """
    parser = argparse.ArgumentParser(
        description="Held-Out Generalization Analysis",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    # Analyze FO and CI
    python -m concept_synth.analysis.holdout_generalization \\
        --fo-dataset results/ad_benchmark/ad_benchmark_v1.yaml \\
        --ci-dataset results/c_benchmark/c_benchmark_v1.yaml \\
        --out artifacts/analysis/v1/holdout/
        """,
    )

    parser.add_argument("--fo-dataset", help="Path to FO (AD) benchmark YAML")
    parser.add_argument("--ci-dataset", help="Path to CI (C) benchmark YAML")
    parser.add_argument("--out", "-o", required=True, help="Output directory for results")
    parser.add_argument(
        "--n-fo-holdout", type=int, default=5, help="Number of FO holdout worlds (default: 5)"
    )
    parser.add_argument(
        "--n-ci-yes", type=int, default=3, help="Number of CI YES holdout worlds (default: 3)"
    )
    parser.add_argument(
        "--n-ci-no", type=int, default=2, help="Number of CI NO holdout worlds (default: 2)"
    )
    parser.add_argument("--quiet", "-q", action="store_true", help="Suppress progress output")

    args = parser.parse_args()

    if not args.fo_dataset and not args.ci_dataset:
        parser.error("at least one of --fo-dataset / --ci-dataset is required")
    for flag, path in (("--fo-dataset", args.fo_dataset), ("--ci-dataset", args.ci_dataset)):
        if path and not Path(path).exists():
            parser.error(f"{flag} not found: {path}")
    for flag, count in (
        ("--n-fo-holdout", args.n_fo_holdout),
        ("--n-ci-yes", args.n_ci_yes),
        ("--n-ci-no", args.n_ci_no),
    ):
        if count < 0:
            parser.error(f"{flag} must be non-negative (got {count})")

    run_holdout_analysis(
        fo_dataset=Path(args.fo_dataset) if args.fo_dataset else None,
        ci_dataset=Path(args.ci_dataset) if args.ci_dataset else None,
        outdir=Path(args.out),
        n_fo_holdout=args.n_fo_holdout,
        n_ci_yes_holdout=args.n_ci_yes,
        n_ci_no_holdout=args.n_ci_no,
        verbose=not args.quiet,
    )


# Script entry point; the module is also runnable via
# `python -m concept_synth.analysis.holdout_generalization` (see module docstring).
if __name__ == "__main__":
    main()
