#!/usr/bin/env python3
"""
Within-Problem Bloat Control Analysis

Addresses potential selection bias in Table 3 (FullObs held-out generalization vs. bloat)
by computing within-problem comparisons across models. The CI task (task="ci") is
also supported.

For each problem with ≥2 train-correct predictions (pooled across models):
1. Compare holdout generalization of the shortest vs. longest valid formula.
2. Compare compact vs. bloated formulas (when both exist for the same problem).
Both comparisons are also repeated after excluding predictions that exactly match
the gold formula.

This controls for instance difficulty, since every comparison is between solutions
to the SAME problem.

Usage:
    python -m concept_synth.analysis.within_problem_bloat_control \
        --task fullobs --benchmark v1 \
        --out artifacts/analysis/v1/
"""

import argparse
import json
import os
import sys
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from scipy import stats

# Bootstrap path: make the `concept_synth` package importable.
# If the normal import fails (e.g. the file is run outside an environment where
# the package is on sys.path), locate the package directory on disk and add its
# parent to sys.path before retrying the import.
try:
    from concept_synth.bootstrap import add_repo_root
except ModuleNotFoundError:
    # Walk upward from this file until a directory named "concept_synth" is found.
    _path = os.path.abspath(__file__)
    while True:
        parent = os.path.dirname(_path)
        if os.path.basename(_path) == "concept_synth":
            # Prepend the package's parent so `import concept_synth` resolves here.
            if parent not in sys.path:
                sys.path.insert(0, parent)
            break
        if parent == _path:
            # Reached the filesystem root without finding the package; stop and
            # let the import below raise ModuleNotFoundError.
            break
        _path = parent
    from concept_synth.bootstrap import add_repo_root

# add_repo_root is project code not shown here; presumably it puts the
# repository root on sys.path — TODO confirm.
add_repo_root(__file__)


# =============================================================================
# Data Classes
# =============================================================================


@dataclass
class ProblemRow:
    """A single train-correct prediction for a problem.

    Built by ``build_problem_rows`` from one holdout result dict and grouped by
    ``instance_id`` for the within-problem comparisons.
    """

    instance_id: str  # problem identifier; rows are grouped by this
    model: str  # model that produced the prediction
    band: str  # band label copied from the holdout result (used for stratified stats)
    pred_ast: int  # AST size of the predicted formula, as reported in the holdout results
    gold_ast: int  # AST size of the gold formula
    ast_delta: int  # copied from the result; presumably pred_ast - gold_ast — TODO confirm
    is_compact: bool  # pred_ast <= gold_ast + threshold_delta
    is_bloated: bool  # complement of is_compact
    holdout_exact_match_rate: float  # holdout generalization; reports multiply by 100 for %
    is_exact_gold: bool = False  # True if pred_sexpr == gold_sexpr (exact string match)


@dataclass
class WithinProblemComparison:
    """Within-problem comparison results.

    One instance per problem with at least two train-correct predictions.
    "gen" values are ``ProblemRow.holdout_exact_match_rate``.
    """

    instance_id: str
    band: str  # band of the problem's first row
    n_valid: int  # number of train-correct predictions for the problem

    # Short vs Long comparison
    short_ast: int  # pred_ast of the shortest prediction
    long_ast: int  # pred_ast of the longest prediction
    short_model: str
    long_model: str
    short_gen: float
    long_gen: float
    delta_short_minus_long: float  # short_gen - long_gen

    # Compact vs Bloated comparison (if both exist)
    has_compact: bool
    has_bloated: bool
    n_compact: int
    n_bloated: int
    compact_gen_mean: Optional[float]  # None unless both groups are present
    bloated_gen_mean: Optional[float]  # None unless both groups are present
    delta_compact_minus_bloated: Optional[float]  # compact_gen_mean - bloated_gen_mean, or None


@dataclass
class SummaryStats:
    """Summary statistics for a comparison type.

    For compact-vs-bloated comparisons the ``*_short_gen`` / ``*_long_gen``
    fields hold the compact / bloated means respectively.
    """

    comparison_type: str  # "short_long" or "compact_bloated"
    n_problems: int  # number of problems contributing a delta
    mean_short_gen: float  # mean short (or compact) holdout generalization
    mean_long_gen: float  # mean long (or bloated) holdout generalization
    mean_delta: float
    median_delta: float
    std_delta: float  # population standard deviation (np.std default, ddof=0)
    frac_delta_pos: float  # fraction of deltas > 1e-9
    frac_delta_zero: float  # fraction of deltas with |delta| <= 1e-9 (ties)
    frac_delta_neg: float  # fraction of deltas < -1e-9
    ci_lower: float  # 95% bootstrap CI
    ci_upper: float
    pval_delta_pos: float = 1.0  # one-sided binomial (sign) test p-value for H0: P(delta > 0 | delta != 0) <= 0.5; ties are excluded


# =============================================================================
# Core Analysis Functions
# =============================================================================


def load_holdout_results(holdout_path: Path) -> List[Dict[str, Any]]:
    """
    Load per-prediction holdout results from a JSON file.

    The file is normally a JSON object with a ``"results"`` list; a bare
    top-level JSON list of result dicts is accepted as well.

    Args:
        holdout_path: Path to the holdout JSON file.

    Returns:
        The list of result dicts ([] if the object has no ``"results"`` key).
    """
    with open(holdout_path, encoding="utf-8") as f:
        data = json.load(f)
    # Accept a bare list as well as the {"results": [...]} wrapper; the original
    # `.get` call would raise AttributeError on a list.
    if isinstance(data, list):
        return data
    return data.get("results", [])


def load_eval_records_formulas(eval_records_path: Path) -> Dict[Tuple[str, str], Tuple[str, str]]:
    """
    Load eval records (JSONL) and extract gold/pred formula strings.

    Each non-blank line is one JSON record. ``gold_sexpr`` and ``pred_sexpr``
    are read from its ``metadata`` object. Records without an ``instance_id`` or
    ``model`` are skipped, and later records overwrite earlier ones with the same
    ``(instance_id, model)`` key.

    Returns:
        Dict mapping (instance_id, model) -> (gold_sexpr, pred_sexpr). Missing or
        null formula fields become "".
    """
    formulas: Dict[Tuple[str, str], Tuple[str, str]] = {}
    with open(eval_records_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines instead of failing in json.loads("").
                continue
            r = json.loads(line)
            instance_id = r.get("instance_id", "")
            model = r.get("model", "")
            # `or` also covers keys that are present but null.
            metadata = r.get("metadata") or {}
            gold_sexpr = metadata.get("gold_sexpr") or ""
            pred_sexpr = metadata.get("pred_sexpr") or ""
            if instance_id and model:
                formulas[(instance_id, model)] = (gold_sexpr, pred_sexpr)
    return formulas


def build_problem_rows(
    results: List[Dict[str, Any]], 
    threshold_delta: int = 1,
    formula_lookup: Optional[Dict[Tuple[str, str], Tuple[str, str]]] = None,
) -> Dict[str, List[ProblemRow]]:
    """
    Build per-problem rows of train-correct predictions.

    Args:
        results: List of holdout result dicts. Fields read: ``train_correct``,
            ``instance_id``, ``model``, ``band``, ``pred_ast``, ``gold_ast``,
            ``ast_delta`` and ``holdout_exact_match_rate``.
        threshold_delta: Threshold for compact vs bloated (default: 1).
            compact = pred_ast <= gold_ast + threshold_delta; bloated otherwise.
        formula_lookup: Optional dict mapping (instance_id, model) ->
            (gold_sexpr, pred_sexpr), used to flag exact gold matches.

    Returns:
        Dict mapping instance_id -> list of ProblemRow. Results that are not
        train-correct, or that lack ``pred_ast``, ``gold_ast`` or ``ast_delta``,
        are skipped.
    """
    by_problem: Dict[str, List[ProblemRow]] = defaultdict(list)

    for r in results:
        # Only include train-correct predictions
        if not r.get("train_correct", False):
            continue

        pred_ast = r.get("pred_ast")
        gold_ast = r.get("gold_ast")
        ast_delta = r.get("ast_delta")

        # Skip rows missing AST info. gold_ast is required too: without it the
        # compact/bloated split is undefined. Previously a missing value
        # silently became 0 (marking every prediction bloated), and a null
        # value raised TypeError.
        if pred_ast is None or gold_ast is None or ast_delta is None:
            continue

        instance_id = r.get("instance_id", "")
        model = r.get("model", "")

        is_compact = pred_ast <= gold_ast + threshold_delta

        # Check for exact gold match using formula strings
        is_exact_gold = False
        if formula_lookup:
            gold_sexpr, pred_sexpr = formula_lookup.get((instance_id, model), ("", ""))
            # Exact string match (no normalization, so conservative). bool() keeps
            # the field a real bool instead of "" / None for empty formulas.
            is_exact_gold = bool(gold_sexpr) and gold_sexpr == pred_sexpr

        by_problem[instance_id].append(
            ProblemRow(
                instance_id=instance_id,
                model=model,
                band=r.get("band", ""),
                pred_ast=pred_ast,
                gold_ast=gold_ast,
                ast_delta=ast_delta,
                is_compact=is_compact,
                is_bloated=not is_compact,
                holdout_exact_match_rate=r.get("holdout_exact_match_rate", 0.0),
                is_exact_gold=is_exact_gold,
            )
        )

    return by_problem


@dataclass
class ExcludeGoldComparison:
    """Within-problem comparison excluding likely-gold formulas.

    Like WithinProblemComparison, but predictions flagged ``is_exact_gold``
    (exact string match with the gold s-expression) are left out. Each
    comparison is optional: check ``has_short_long`` / ``has_compact_bloated``
    before reading the corresponding values, which are None otherwise.
    """

    instance_id: str
    band: str
    
    # Short vs Long (excluding gold)
    has_short_long: bool  # True when at least 2 non-gold predictions exist
    n_non_gold: int  # number of predictions not flagged is_exact_gold
    short_gen: Optional[float]
    long_gen: Optional[float]
    delta_short_long: Optional[float]  # short_gen - long_gen
    
    # Compact vs Bloated (excluding gold)
    has_compact_bloated: bool  # True when both a non-gold compact and a bloated prediction exist
    n_compact_non_gold: int
    n_bloated: int  # number of bloated predictions counted for this problem
    compact_gen_mean: Optional[float]
    bloated_gen_mean: Optional[float]
    delta_compact_bloated: Optional[float]  # compact_gen_mean - bloated_gen_mean


def compute_within_problem_comparisons(
    by_problem: Dict[str, List[ProblemRow]], require_both_groups: bool = False
) -> List[WithinProblemComparison]:
    """
    Build one within-problem comparison per problem with >= 2 train-correct rows.

    For every such problem, the shortest and longest predictions (by pred_ast)
    are paired. When the problem has at least one compact AND one bloated
    prediction, the two groups' mean holdout generalization is compared as well.

    Args:
        by_problem: instance_id -> train-correct ProblemRows for that problem.
        require_both_groups: Drop problems lacking either a compact or a
            bloated prediction.

    Returns:
        WithinProblemComparison objects in by_problem iteration order.
    """
    out: List[WithinProblemComparison] = []

    for pid, preds in by_problem.items():
        if len(preds) < 2:
            continue

        compact = [p for p in preds if p.is_compact]
        bloated = [p for p in preds if p.is_bloated]
        both_present = bool(compact) and bool(bloated)

        if require_both_groups and not both_present:
            continue

        # Ties on pred_ast: the earliest row is "shortest", the latest is "longest".
        short_row = min(preds, key=lambda p: p.pred_ast)
        long_row = max(reversed(preds), key=lambda p: p.pred_ast)

        if both_present:
            compact_mean = np.mean([p.holdout_exact_match_rate for p in compact])
            bloated_mean = np.mean([p.holdout_exact_match_rate for p in bloated])
            cb_delta = compact_mean - bloated_mean
        else:
            compact_mean = bloated_mean = cb_delta = None

        out.append(
            WithinProblemComparison(
                instance_id=pid,
                band=preds[0].band,
                n_valid=len(preds),
                short_ast=short_row.pred_ast,
                long_ast=long_row.pred_ast,
                short_model=short_row.model,
                long_model=long_row.model,
                short_gen=short_row.holdout_exact_match_rate,
                long_gen=long_row.holdout_exact_match_rate,
                delta_short_minus_long=(
                    short_row.holdout_exact_match_rate - long_row.holdout_exact_match_rate
                ),
                has_compact=bool(compact),
                has_bloated=bool(bloated),
                n_compact=len(compact),
                n_bloated=len(bloated),
                compact_gen_mean=compact_mean,
                bloated_gen_mean=bloated_mean,
                delta_compact_minus_bloated=cb_delta,
            )
        )

    return out


def compute_exclude_gold_comparisons(
    by_problem: Dict[str, List[ProblemRow]]
) -> List[ExcludeGoldComparison]:
    """
    Compute within-problem comparisons EXCLUDING exact-gold formulas.

    This controls for the concern that the gold formula itself (which has
    perfect holdout generalization by construction) might be driving the effect.

    Gold formulas are identified by ``ProblemRow.is_exact_gold`` (exact string
    match pred_sexpr == gold_sexpr). They are removed from BOTH the short/long
    pool and the compact/bloated groups.

    A problem is included when at least one comparison is defined:
      - short vs long: >= 2 non-gold rows (shortest vs longest by pred_ast);
      - compact vs bloated: >= 1 non-gold compact and >= 1 non-gold bloated row.

    Returns:
        List of ExcludeGoldComparison, in by_problem iteration order.
    """
    comparisons: List[ExcludeGoldComparison] = []

    for instance_id, rows in by_problem.items():
        if not rows:
            # Nothing to compare and no band to report.
            continue
        band = rows[0].band

        # Filter out exact-gold formulas
        non_gold_rows = [r for r in rows if not r.is_exact_gold]

        # Short vs Long (excluding gold)
        has_short_long = len(non_gold_rows) >= 2
        short_gen = None
        long_gen = None
        delta_short_long = None

        if has_short_long:
            sorted_by_ast = sorted(non_gold_rows, key=lambda r: r.pred_ast)
            short_gen = sorted_by_ast[0].holdout_exact_match_rate
            long_gen = sorted_by_ast[-1].holdout_exact_match_rate
            delta_short_long = short_gen - long_gen

        # Compact vs Bloated (excluding gold from both groups). An exact gold match
        # presumably has pred_ast == gold_ast and is therefore compact whenever
        # threshold_delta >= 0, but with a negative threshold it would be bloated,
        # so it is filtered here too.
        compact_non_gold = [r for r in non_gold_rows if r.is_compact]
        bloated_rows = [r for r in non_gold_rows if r.is_bloated]

        has_compact_bloated = len(compact_non_gold) > 0 and len(bloated_rows) > 0
        compact_gen_mean = None
        bloated_gen_mean = None
        delta_compact_bloated = None

        if has_compact_bloated:
            compact_gen_mean = np.mean([r.holdout_exact_match_rate for r in compact_non_gold])
            bloated_gen_mean = np.mean([r.holdout_exact_match_rate for r in bloated_rows])
            delta_compact_bloated = compact_gen_mean - bloated_gen_mean

        # Only include if we have at least one valid comparison
        if has_short_long or has_compact_bloated:
            comparisons.append(
                ExcludeGoldComparison(
                    instance_id=instance_id,
                    band=band,
                    has_short_long=has_short_long,
                    n_non_gold=len(non_gold_rows),
                    short_gen=short_gen,
                    long_gen=long_gen,
                    delta_short_long=delta_short_long,
                    has_compact_bloated=has_compact_bloated,
                    n_compact_non_gold=len(compact_non_gold),
                    n_bloated=len(bloated_rows),
                    compact_gen_mean=compact_gen_mean,
                    bloated_gen_mean=bloated_gen_mean,
                    delta_compact_bloated=delta_compact_bloated,
                )
            )

    return comparisons


def bootstrap_ci(
    values: List[float], n_boot: int = 2000, seed: int = 0, alpha: float = 0.05
) -> Tuple[float, float]:
    """
    Percentile bootstrap confidence interval for the mean of ``values``.

    Draws ``n_boot`` resamples (with replacement, same size as ``values``) from
    ``np.random.default_rng(seed)``, so the interval is reproducible per seed.

    Args:
        values: Sample whose mean is bootstrapped.
        n_boot: Number of resamples.
        seed: Seed for the NumPy Generator.
        alpha: Two-sided miss rate; 0.05 yields a 95% interval.

    Returns:
        (lower, upper): the alpha/2 and 1 - alpha/2 percentiles of the
        resampled means, or (0.0, 0.0) for empty input.
    """
    if len(values) == 0:
        return (0.0, 0.0)

    generator = np.random.default_rng(seed)
    sample = np.array(values)
    size = len(sample)

    # One resample per iteration, drawn in sequence from the same Generator.
    # Drawing all resamples in a single call could change the random stream.
    resampled_means = np.array(
        [generator.choice(sample, size=size, replace=True).mean() for _ in range(n_boot)]
    )

    return (
        np.percentile(resampled_means, 100 * alpha / 2),
        np.percentile(resampled_means, 100 * (1 - alpha / 2)),
    )


def _summarize_paired_gens(
    comparison_type: str,
    deltas: List[float],
    short_gens: List[float],
    long_gens: List[float],
    n_boot: int,
    seed: int,
) -> SummaryStats:
    """
    Aggregate per-problem paired generalization values into a SummaryStats.

    Shared by ``compute_summary_stats`` and ``compute_exclude_gold_summary_stats``
    so both report identically defined statistics.

    Args:
        comparison_type: Label stored on the result.
        deltas: Per-problem (short/compact - long/bloated) differences.
        short_gens: Per-problem short/compact generalization values.
        long_gens: Per-problem long/bloated generalization values.
        n_boot: Number of bootstrap resamples for the CI.
        seed: Bootstrap random seed.

    Returns:
        SummaryStats. When ``deltas`` is empty, all statistics are 0.0 and the
        p-value keeps its default of 1.0.
    """
    n_problems = len(deltas)
    if n_problems == 0:
        return SummaryStats(
            comparison_type=comparison_type,
            n_problems=0,
            mean_short_gen=0.0,
            mean_long_gen=0.0,
            mean_delta=0.0,
            median_delta=0.0,
            std_delta=0.0,
            frac_delta_pos=0.0,
            frac_delta_zero=0.0,
            frac_delta_neg=0.0,
            ci_lower=0.0,
            ci_upper=0.0,
        )

    # Deltas within eps of zero are treated as ties.
    eps = 1e-9
    n_pos = sum(1 for d in deltas if d > eps)
    n_zero = sum(1 for d in deltas if abs(d) <= eps)
    n_neg = sum(1 for d in deltas if d < -eps)

    # Bootstrap CI of the mean delta
    ci_lower, ci_upper = bootstrap_ci(deltas, n_boot=n_boot, seed=seed)

    # Sign test: one-sided binomial test of H0: P(delta > 0 | delta != 0) <= 0.5
    # vs H1: > 0.5, with ties excluded (scipy >= 1.7 provides binomtest).
    n_nonzero = n_pos + n_neg
    if n_nonzero > 0:
        pval_pos = stats.binomtest(n_pos, n_nonzero, 0.5, alternative="greater").pvalue
    else:
        pval_pos = 1.0

    return SummaryStats(
        comparison_type=comparison_type,
        n_problems=n_problems,
        mean_short_gen=np.mean(short_gens),
        mean_long_gen=np.mean(long_gens),
        mean_delta=np.mean(deltas),
        median_delta=np.median(deltas),
        std_delta=np.std(deltas),  # population std (ddof=0)
        frac_delta_pos=n_pos / n_problems,
        frac_delta_zero=n_zero / n_problems,
        frac_delta_neg=n_neg / n_problems,
        ci_lower=ci_lower,
        ci_upper=ci_upper,
        pval_delta_pos=pval_pos,
    )


def compute_summary_stats(
    comparisons: List[WithinProblemComparison],
    comparison_type: str,
    n_boot: int = 2000,
    seed: int = 0,
) -> SummaryStats:
    """
    Compute summary statistics for a comparison type.

    Args:
        comparisons: List of WithinProblemComparison objects
        comparison_type: "short_long", or "compact_bloated" (any other value is
            treated as "compact_bloated")
        n_boot: Number of bootstrap samples
        seed: Random seed

    Returns:
        SummaryStats object. For "compact_bloated", only problems where both
        groups exist (delta_compact_minus_bloated is not None) contribute.
    """
    if comparison_type == "short_long":
        valid = list(comparisons)
        deltas = [c.delta_short_minus_long for c in valid]
        short_gens = [c.short_gen for c in valid]
        long_gens = [c.long_gen for c in valid]
    else:  # compact_bloated
        valid = [c for c in comparisons if c.delta_compact_minus_bloated is not None]
        deltas = [c.delta_compact_minus_bloated for c in valid]
        short_gens = [c.compact_gen_mean for c in valid]
        long_gens = [c.bloated_gen_mean for c in valid]

    return _summarize_paired_gens(
        comparison_type, deltas, short_gens, long_gens, n_boot=n_boot, seed=seed
    )


def compute_exclude_gold_summary_stats(
    comparisons: List[ExcludeGoldComparison],
    comparison_type: str,
    n_boot: int = 2000,
    seed: int = 0,
) -> SummaryStats:
    """
    Compute summary statistics for exclude-gold comparisons.

    Args:
        comparisons: List of ExcludeGoldComparison objects
        comparison_type: "short_long", or "compact_bloated" (any other value is
            treated as "compact_bloated")
        n_boot: Number of bootstrap samples
        seed: Random seed

    Returns:
        SummaryStats over problems where the requested comparison is defined
        (``has_short_long`` / ``has_compact_bloated``).
    """
    if comparison_type == "short_long":
        valid = [c for c in comparisons if c.has_short_long]
        deltas = [c.delta_short_long for c in valid]
        short_gens = [c.short_gen for c in valid]
        long_gens = [c.long_gen for c in valid]
    else:  # compact_bloated
        valid = [c for c in comparisons if c.has_compact_bloated]
        deltas = [c.delta_compact_bloated for c in valid]
        short_gens = [c.compact_gen_mean for c in valid]
        long_gens = [c.bloated_gen_mean for c in valid]

    return _summarize_paired_gens(
        comparison_type, deltas, short_gens, long_gens, n_boot=n_boot, seed=seed
    )


def compute_band_stratified_stats(
    comparisons: List[WithinProblemComparison],
    comparison_type: str,
    n_boot: int = 2000,
    seed: int = 0,
) -> Dict[str, SummaryStats]:
    """
    Summarize comparisons separately for each band.

    Every band is summarized with ``compute_summary_stats`` using the same
    ``n_boot`` and ``seed``. The result maps band -> SummaryStats, with keys
    inserted in sorted band order.
    """
    grouped: Dict[str, List[WithinProblemComparison]] = defaultdict(list)
    for comp in comparisons:
        grouped[comp.band].append(comp)

    # The result is not bound to a local named `stats`, so the module-level
    # scipy `stats` is never shadowed here.
    return {
        band: compute_summary_stats(grouped[band], comparison_type, n_boot=n_boot, seed=seed)
        for band in sorted(grouped)
    }


def run_fixed_effects_regression(
    by_problem: Dict[str, List[ProblemRow]]
) -> Optional[Dict[str, Any]]:
    """
    Fit an OLS model with model and problem fixed effects:
        holdout_exact_match_rate ~ ast_delta + C(model) + C(problem_id)

    Returns:
        Dict with the ``ast_delta`` coefficient, its standard error and p-value,
        R^2, the observation count and the numbers of problems/models. Returns
        None in any of these cases:
        - pandas or statsmodels cannot be imported;
        - there are fewer than 10 rows;
        - there are fewer than 2 problems or 2 models;
        - the fit raises (a warning is printed).
    """
    try:
        import pandas as pd
        import statsmodels.formula.api as smf
    except ImportError:
        return None

    records = [
        {
            "instance_id": pid,
            "model": row.model,
            "ast_delta": row.ast_delta,
            "holdout_exact": row.holdout_exact_match_rate,
        }
        for pid, problem_rows in by_problem.items()
        for row in problem_rows
    ]
    if len(records) < 10:
        return None

    frame = pd.DataFrame(records)
    n_problems = frame["instance_id"].nunique()
    n_models = frame["model"].nunique()

    # Each fixed-effect factor needs at least two levels to be estimable.
    if n_problems < 2 or n_models < 2:
        return None

    try:
        fit = smf.ols(
            "holdout_exact ~ ast_delta + C(model) + C(instance_id)", data=frame
        ).fit()
        return {
            "ast_delta_coef": fit.params.get("ast_delta", 0.0),
            "ast_delta_se": fit.bse.get("ast_delta", 0.0),
            "ast_delta_pval": fit.pvalues.get("ast_delta", 1.0),
            "r_squared": fit.rsquared,
            "n_obs": int(fit.nobs),
            "n_problems": n_problems,
            "n_models": n_models,
        }
    except Exception as e:
        # Best-effort: a failed fit should not abort the whole analysis.
        print(f"  Warning: Regression failed: {e}")
        return None


# =============================================================================
# Output Generation
# =============================================================================


def generate_csv(
    comparisons: List[WithinProblemComparison], output_path: Path
) -> None:
    """
    Write one CSV row per within-problem comparison to ``output_path``.

    Generalization values and deltas are written with 4 decimals. The
    compact/bloated value columns are left empty when that comparison is
    undefined for the problem. Parent directories are created if needed.
    """
    import csv

    header = [
        "instance_id",
        "band",
        "n_valid",
        "short_ast",
        "long_ast",
        "short_model",
        "long_model",
        "short_gen",
        "long_gen",
        "delta_short_long",
        "n_compact",
        "n_bloated",
        "compact_gen_mean",
        "bloated_gen_mean",
        "delta_compact_bloated",
    ]

    def opt4(value: Optional[float]) -> str:
        """Four-decimal string, or "" when the value is undefined."""
        return "" if value is None else f"{value:.4f}"

    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, "w", newline="") as handle:
        out = csv.writer(handle)
        out.writerow(header)
        out.writerows(
            [
                comp.instance_id,
                comp.band,
                comp.n_valid,
                comp.short_ast,
                comp.long_ast,
                comp.short_model,
                comp.long_model,
                f"{comp.short_gen:.4f}",
                f"{comp.long_gen:.4f}",
                f"{comp.delta_short_minus_long:.4f}",
                comp.n_compact,
                comp.n_bloated,
                opt4(comp.compact_gen_mean),
                opt4(comp.bloated_gen_mean),
                opt4(comp.delta_compact_minus_bloated),
            ]
            for comp in comparisons
        )

    print(f"  Saved CSV: {output_path}")


def generate_latex_table(
    short_long_stats: SummaryStats,
    compact_bloated_stats: SummaryStats,
    regression_results: Optional[Dict[str, Any]],
    output_path: Path,
    excl_gold_sl: Optional[SummaryStats] = None,
    excl_gold_cb: Optional[SummaryStats] = None,
    task: str = "fullobs",
) -> None:
    """
    Generate the appendix LaTeX table for the within-problem bloat control.

    Only the exclude-gold comparisons are tabulated. ``short_long_stats`` and
    ``compact_bloated_stats`` (which include exact-gold predictions) are
    accepted for interface compatibility but are not rendered.

    Args:
        short_long_stats: Unused (see above).
        compact_bloated_stats: Unused (see above).
        regression_results: Optional output of ``run_fixed_effects_regression``.
            When given, a ``\\caption*`` note with the ``ast_delta``
            coefficient is appended.
        output_path: Destination ``.tex`` file (parent dirs are created).
        excl_gold_sl: Exclude-gold short-vs-long summary; rendered if n_problems > 0.
        excl_gold_cb: Exclude-gold compact-vs-bloated summary; rendered if n_problems > 0.
        task: "fullobs" gives caption "FullObs" / label suffix "fo";
            anything else gives "CI" / "ci".
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Task-specific names
    task_display = "FullObs" if task == "fullobs" else "CI"
    label_suffix = "fo" if task == "fullobs" else "ci"

    def strip_neg_zero(text: str) -> str:
        """Drop the sign of a formatted number that rounds to zero ("-0.0" -> "0.0")."""
        if text.startswith("-") and float(text) == 0:
            return text[1:]
        return text

    def fmt_delta(val: float) -> str:
        """Format delta with an explicit sign and one decimal (never "-0.0")."""
        text = strip_neg_zero(f"{val:.1f}")
        return text if text.startswith("-") else f"+{text}"

    def fmt_int(val: float) -> str:
        """Format a CI bound with no decimals (never "-0")."""
        return strip_neg_zero(f"{val:.0f}")

    def fmt_pval(p: float) -> str:
        """Format p-value."""
        if p < 0.001:
            return "$<$0.001"
        elif p < 0.01:
            return f"{p:.3f}"
        else:
            return f"{p:.2f}"

    def fmt_row(label: str, s: SummaryStats) -> str:
        """One tabular row: n, mean gens, delta [CI], % delta>0, sign-test p."""
        ci_str = f"[{fmt_int(s.ci_lower * 100)}, {fmt_int(s.ci_upper * 100)}]"
        return (
            f"  {label} & {s.n_problems} & {s.mean_short_gen*100:.1f}\\% & "
            f"{s.mean_long_gen*100:.1f}\\% & {fmt_delta(s.mean_delta*100)} {ci_str} & "
            f"{s.frac_delta_pos*100:.0f}\\% & {fmt_pval(s.pval_delta_pos)} \\\\"
        )

    lines = []
    lines.append("% Within-problem bloat control (auto-generated)")
    lines.append("\\begin{table}[tb]")
    lines.append("\\centering")
    lines.append(f"\\caption{{\\textbf{{Within-problem bloat control ({task_display}).}}")
    lines.append("For problems with $\\geq 2$ train-correct predictions (excluding exact gold matches) across models,")
    lines.append("we compare holdout generalization of shortest vs longest formulas")
    lines.append("(and compact vs bloated when both exist). This controls for instance difficulty.")
    lines.append("$\\Delta$ = short/compact $-$ long/bloated; positive values indicate")
    lines.append("shorter formulas generalize better on the \\emph{same} problem.")

    # Collect negative fractions for caption (use excl_gold stats)
    neg_fracs = []
    if excl_gold_sl and excl_gold_sl.n_problems > 0:
        neg_fracs.append(f"Short--Long {excl_gold_sl.frac_delta_neg*100:.0f}\\%")
    if excl_gold_cb and excl_gold_cb.n_problems > 0:
        neg_fracs.append(f"Compact--Bloat {excl_gold_cb.frac_delta_neg*100:.0f}\\%")
    if neg_fracs:
        lines.append(f"Fraction $\\Delta < 0$: {', '.join(neg_fracs)}.}}")
    else:
        # Close the \caption{ opened above.
        lines.append("}")

    lines.append(f"\\label{{tab:within_problem_bloat_{label_suffix}}}")
    lines.append("\\small")
    lines.append("\\begin{tabular}{@{}l@{\\hspace{4pt}}c@{\\hspace{4pt}}c@{\\hspace{4pt}}c@{\\hspace{4pt}}c@{\\hspace{4pt}}c@{\\hspace{4pt}}c@{}}")
    lines.append("  \\toprule")
    lines.append("  & & \\multicolumn{2}{c}{Holdout Gen} & & & \\\\")
    lines.append("  \\cmidrule(lr){3-4}")
    lines.append("  Comparison & $n$ & Short & Long & $\\Delta$ [CI] & $\\Delta>0$ & $p$ \\\\")
    lines.append("  \\midrule")

    # Only exclude-gold rows are shown (rows including gold are skipped).
    if excl_gold_sl and excl_gold_sl.n_problems > 0:
        lines.append(fmt_row("Short--Long", excl_gold_sl))
    if excl_gold_cb and excl_gold_cb.n_problems > 0:
        lines.append(fmt_row("Compact--Bloat", excl_gold_cb))

    lines.append("  \\bottomrule")
    lines.append("\\end{tabular}")

    # Add regression note if available
    if regression_results:
        coef = regression_results["ast_delta_coef"]
        se = regression_results["ast_delta_se"]
        pval = regression_results["ast_delta_pval"]
        # Avoid reporting "p=0.000" for very small p-values.
        p_expr = "p<0.001" if pval < 0.001 else f"p={pval:.3f}"
        lines.append(
            f"\\caption*{{Fixed-effects regression: $\\beta_{{\\text{{AST}}\\Delta}}={coef:.4f}$ "
            f"(SE={se:.4f}, ${p_expr}$), controlling for model and problem.}}"
        )

    lines.append("\\end{table}")

    with open(output_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))

    print(f"  Saved LaTeX table: {output_path}")


def generate_histogram(
    comparisons: List[WithinProblemComparison], output_path: Path
) -> None:
    """
    Save side-by-side histograms of per-problem deltas (percentage points).

    Left panel: short - long for every comparison. Right panel: compact -
    bloated, only for problems where it is defined. Each non-empty panel gets
    a dashed red line at 0 and a solid blue line at the mean. An empty panel
    keeps only its axis labels. Skipped with a warning if matplotlib is not
    installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("  Warning: matplotlib not available, skipping histogram")
        return

    output_path.parent.mkdir(parents=True, exist_ok=True)

    deltas_sl = [c.delta_short_minus_long * 100 for c in comparisons]
    # Compact-Bloated delta (only where defined)
    deltas_cb = [
        c.delta_compact_minus_bloated * 100
        for c in comparisons
        if c.delta_compact_minus_bloated is not None
    ]

    def draw_panel(ax, deltas: List[float], xlabel: str, title: str) -> None:
        """Histogram with zero and mean markers; axis labels only when empty."""
        if deltas:
            # Guarded: np.mean([]) warns and yields NaN, and legend() warns when
            # no labelled artist exists.
            ax.hist(deltas, bins=20, edgecolor="black", alpha=0.7)
            ax.axvline(0, color="red", linestyle="--", linewidth=1.5)
            ax.axvline(np.mean(deltas), color="blue", linestyle="-", linewidth=1.5, label="Mean")
            ax.legend()
        ax.set_xlabel(xlabel)
        ax.set_ylabel("# Problems")
        ax.set_title(title)

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    try:
        draw_panel(axes[0], deltas_sl, "Δ Holdout Gen (Short − Long) [%]", "Short vs Long Formula")
        draw_panel(
            axes[1], deltas_cb, "Δ Holdout Gen (Compact − Bloated) [%]", "Compact vs Bloated Formula"
        )
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Close this specific figure even if drawing or saving fails.
        plt.close(fig)

    print(f"  Saved histogram: {output_path}")


def generate_summary_markdown(
    short_long_stats: SummaryStats,
    compact_bloated_stats: SummaryStats,
    band_stats_sl: Dict[str, SummaryStats],
    band_stats_cb: Dict[str, SummaryStats],
    regression_results: Optional[Dict[str, Any]],
    output_path: Path,
    task: str = "fullobs",
) -> None:
    """
    Generate the markdown summary of the within-problem analysis.

    Args:
        short_long_stats: Summary of shortest-vs-longest comparisons.
        compact_bloated_stats: Summary of compact-vs-bloated comparisons.
        band_stats_sl: Per-band short-vs-long summaries (band -> SummaryStats).
        band_stats_cb: Per-band compact-vs-bloated summaries.
        regression_results: Output of ``run_fixed_effects_regression``, or None.
        output_path: Destination ``.md`` file (parent dirs are created).
        task: "fullobs" (default) titles the report "FullObs"; any other value
            titles it "CI".

    The closing "Conclusion" is derived from the 95% bootstrap CIs instead of
    being asserted unconditionally. The file is written as UTF-8 because the
    text contains non-ASCII symbols (≥, Δ, −, β, ²).
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    task_display = "FullObs" if task == "fullobs" else "CI"

    def ci_direction(s: SummaryStats) -> str:
        """Classify by 95% CI of Δ: 'none', 'better', 'worse' or 'unclear'."""
        if s.n_problems == 0:
            return "none"
        if s.ci_lower > 0:
            return "better"
        if s.ci_upper < 0:
            return "worse"
        return "unclear"

    def band_lines(title: str, band_stats: Dict[str, SummaryStats]) -> List[str]:
        """Bullet list of per-band Δ with CI; bands with no problems are omitted."""
        if not band_stats:
            return []
        out = [f"### {title}"]
        for band, band_summary in band_stats.items():
            if band_summary.n_problems > 0:
                out.append(
                    f"- **{band}**: n={band_summary.n_problems}, "
                    f"Δ={band_summary.mean_delta*100:+.1f}% "
                    f"[{band_summary.ci_lower*100:.1f}, {band_summary.ci_upper*100:.1f}]"
                )
        out.append("")
        return out

    lines = []
    lines.append(f"# Within-Problem Bloat Control Analysis ({task_display})")
    lines.append("")
    lines.append("## Summary")
    lines.append("")
    lines.append(
        "This analysis controls for instance difficulty by comparing multiple "
        "train-correct solutions to the SAME problem."
    )
    lines.append("")

    # Short-Long summary
    sl = short_long_stats
    lines.append("### Short vs Long Comparison")
    lines.append(f"- **N problems with ≥2 valid predictions**: {sl.n_problems}")
    if sl.n_problems > 0:
        lines.append(f"- **Mean short gen**: {sl.mean_short_gen*100:.1f}%")
        lines.append(f"- **Mean long gen**: {sl.mean_long_gen*100:.1f}%")
        lines.append(
            f"- **Mean Δ (short−long)**: {sl.mean_delta*100:+.1f}% "
            f"[95% CI: {sl.ci_lower*100:.1f}%, {sl.ci_upper*100:.1f}%]"
        )
        lines.append(f"- **Median Δ**: {sl.median_delta*100:+.1f}%")
        lines.append(
            f"- **Δ > 0**: {sl.frac_delta_pos*100:.0f}%, "
            f"**Δ = 0**: {sl.frac_delta_zero*100:.0f}%, "
            f"**Δ < 0**: {sl.frac_delta_neg*100:.0f}%"
        )
    lines.append("")

    # Compact-Bloated summary
    cb = compact_bloated_stats
    lines.append("### Compact vs Bloated Comparison")
    lines.append(f"- **N problems with both compact and bloated**: {cb.n_problems}")
    if cb.n_problems > 0:
        lines.append(f"- **Mean compact gen**: {cb.mean_short_gen*100:.1f}%")
        lines.append(f"- **Mean bloated gen**: {cb.mean_long_gen*100:.1f}%")
        lines.append(
            f"- **Mean Δ (compact−bloated)**: {cb.mean_delta*100:+.1f}% "
            f"[95% CI: {cb.ci_lower*100:.1f}%, {cb.ci_upper*100:.1f}%]"
        )
        lines.append(f"- **Median Δ**: {cb.median_delta*100:+.1f}%")
        lines.append(
            f"- **Δ > 0**: {cb.frac_delta_pos*100:.0f}%, "
            f"**Δ = 0**: {cb.frac_delta_zero*100:.0f}%, "
            f"**Δ < 0**: {cb.frac_delta_neg*100:.0f}%"
        )
    lines.append("")

    # Band stratification
    lines.extend(band_lines("By Band (Short vs Long)", band_stats_sl))
    lines.extend(band_lines("By Band (Compact vs Bloated)", band_stats_cb))

    # Regression results
    if regression_results:
        lines.append("### Fixed-Effects Regression")
        lines.append("Model: `holdout_exact ~ ast_delta + C(model) + C(problem_id)`")
        lines.append(f"- **β(ast_delta)**: {regression_results['ast_delta_coef']:.4f}")
        lines.append(f"- **SE**: {regression_results['ast_delta_se']:.4f}")
        lines.append(f"- **p-value**: {regression_results['ast_delta_pval']:.4f}")
        lines.append(f"- **R²**: {regression_results['r_squared']:.4f}")
        lines.append(f"- **N obs**: {regression_results['n_obs']}")
        lines.append("")
        lines.append(
            "Interpretation: Each unit increase in AST delta is associated with a "
            f"{regression_results['ast_delta_coef']*100:.2f} percentage point change "
            "in holdout generalization, controlling for model and problem fixed effects."
        )
        lines.append("")

    # Conclusion is driven by the CIs so the report never claims more than the data show.
    lines.append("## Conclusion")
    lines.append("")
    sl_dir = ci_direction(sl)
    cb_dir = ci_direction(cb)
    if sl_dir == "better" and cb_dir in ("better", "none"):
        conclusion = (
            "The within-problem analysis confirms that shorter/compact formulas generalize "
            "better than longer/bloated formulas even when comparing solutions to the SAME problem. "
            "This rules out the alternative explanation that bloated formulas appear on harder problems."
        )
    elif "worse" in (sl_dir, cb_dir):
        conclusion = (
            "The within-problem analysis does NOT support the claim that shorter/compact formulas "
            "generalize better: for at least one comparison the 95% CI of Δ lies entirely below zero, "
            "i.e. shorter/compact formulas generalized worse on the SAME problem."
        )
    elif sl_dir == "none" and cb_dir == "none":
        conclusion = (
            "No problem had enough train-correct predictions for a within-problem comparison, "
            "so no conclusion can be drawn."
        )
    else:
        conclusion = (
            "The within-problem analysis is inconclusive: the 95% CI of Δ includes zero for at least "
            "one comparison, so it neither confirms that shorter/compact formulas generalize better on "
            "the SAME problem nor rules out that bloated formulas simply appear on harder problems."
        )
    lines.append(conclusion)

    with open(output_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))

    print(f"  Saved summary: {output_path}")


# =============================================================================
# Main Entry Point
# =============================================================================


def run_within_problem_bloat_control(
    task: str = "fullobs",
    benchmark: str = "v1",
    holdout_path: Optional[Path] = None,
    out_dir: Optional[Path] = None,
    threshold_delta: int = 1,
    require_both_groups: bool = False,
    n_boot: int = 2000,
    seed: int = 0,
) -> Dict[str, Any]:
    """
    Run within-problem bloat control analysis.

    Pipeline: load holdout results (plus eval records, when present, so that
    predictions identical to the gold formula can be flagged), group
    train-correct predictions by problem, compute within-problem comparisons,
    bootstrap summary statistics (overall, per band, and with exact-gold
    predictions excluded), fit the fixed-effects regression, and write the
    CSV, LaTeX table, histogram and Markdown summary under ``out_dir``.

    Default paths are relative to the current working directory.

    Args:
        task: Task name, "fullobs" or "ci". Any other value is only usable
            together with an explicit ``holdout_path``; no eval records are
            loaded for it.
        benchmark: Benchmark version (e.g. "v1"), used to build default paths.
        holdout_path: Path to holdout JSON file. Defaults to
            ``src/concept_synth/artifacts/analysis/<benchmark>/holdout/<fo|ci>_holdout.json``.
        out_dir: Output directory (``str`` or ``Path``). Defaults to
            ``src/concept_synth/artifacts/analysis/<benchmark>``.
        threshold_delta: Threshold for compact vs bloated (default: 1),
            forwarded to ``build_problem_rows``.
        require_both_groups: If True, only include problems with both compact
            and bloated formulas.
        n_boot: Number of bootstrap samples.
        seed: Random seed for the bootstrap.

    Returns:
        Dict with the short/long and compact/bloated summary stats, the
        band-stratified stats, the regression results, the exclude-gold stats,
        and the counts ``n_comparisons``, ``n_with_both`` and ``n_exact_gold``.

    Raises:
        ValueError: If ``holdout_path`` is None and ``task`` is not
            "fullobs" or "ci".
    """
    print(f"\n=== Within-Problem Bloat Control Analysis ({task} {benchmark}) ===\n")

    # Artifact file/dir stem per supported task (fullobs artifacts use "fo").
    task_abbrev = {"fullobs": "fo", "ci": "ci"}.get(task)
    analysis_dir = Path(f"src/concept_synth/artifacts/analysis/{benchmark}")

    # Auto-detect paths
    if holdout_path is None:
        if task_abbrev is None:
            raise ValueError(f"Unsupported task: {task}")
        holdout_path = analysis_dir / "holdout" / f"{task_abbrev}_holdout.json"

    # Accept str as well as Path: out_dir is joined with `/` below.
    out_dir = analysis_dir if out_dir is None else Path(out_dir)

    # Load data
    print(f"Loading holdout data from {holdout_path}...")
    results = load_holdout_results(holdout_path)
    print(f"  Loaded {len(results)} holdout results")

    # Load eval records for formula strings (to identify exact gold matches).
    # Note: this path depends only on task/benchmark, not on holdout_path.
    eval_records_path = (
        analysis_dir / task_abbrev / "eval_records.jsonl" if task_abbrev else None
    )

    formula_lookup = None
    if eval_records_path is not None and eval_records_path.exists():
        print(f"Loading eval records for formula lookup from {eval_records_path}...")
        formula_lookup = load_eval_records_formulas(eval_records_path)
        print(f"  Loaded {len(formula_lookup)} formula pairs")
    else:
        print("  Warning: eval records not found, cannot identify exact gold matches")

    # Build problem rows
    print(f"Building problem rows (threshold_delta={threshold_delta})...")
    by_problem = build_problem_rows(
        results, threshold_delta=threshold_delta, formula_lookup=formula_lookup
    )
    print(f"  {len(by_problem)} problems with at least 1 train-correct prediction")

    # Count exact gold matches
    n_exact_gold = sum(
        1 for rows in by_problem.values() for r in rows if r.is_exact_gold
    )
    print(f"  {n_exact_gold} predictions are exact gold matches")

    # Compute within-problem comparisons
    print("Computing within-problem comparisons...")
    comparisons = compute_within_problem_comparisons(
        by_problem, require_both_groups=require_both_groups
    )
    print(f"  {len(comparisons)} problems with ≥2 train-correct predictions")

    # Count problems with both compact and bloated
    n_with_both = sum(
        1 for c in comparisons if c.delta_compact_minus_bloated is not None
    )
    print(f"  {n_with_both} problems with both compact and bloated formulas")

    # Compute summary statistics
    print("Computing summary statistics...")
    short_long_stats = compute_summary_stats(
        comparisons, "short_long", n_boot=n_boot, seed=seed
    )
    compact_bloated_stats = compute_summary_stats(
        comparisons, "compact_bloated", n_boot=n_boot, seed=seed
    )

    # Band-stratified stats
    band_stats_sl = compute_band_stratified_stats(
        comparisons, "short_long", n_boot=n_boot, seed=seed
    )
    band_stats_cb = compute_band_stratified_stats(
        comparisons, "compact_bloated", n_boot=n_boot, seed=seed
    )

    # Compute exclude-gold comparisons
    print("Computing exclude-gold comparisons...")
    excl_gold_comparisons = compute_exclude_gold_comparisons(by_problem)
    n_excl_sl = sum(1 for c in excl_gold_comparisons if c.has_short_long)
    n_excl_cb = sum(1 for c in excl_gold_comparisons if c.has_compact_bloated)
    print(f"  {n_excl_sl} problems with ≥2 non-gold formulas (short-long)")
    print(f"  {n_excl_cb} problems with non-gold compact + bloated")

    excl_gold_sl_stats = compute_exclude_gold_summary_stats(
        excl_gold_comparisons, "short_long", n_boot=n_boot, seed=seed
    )
    excl_gold_cb_stats = compute_exclude_gold_summary_stats(
        excl_gold_comparisons, "compact_bloated", n_boot=n_boot, seed=seed
    )

    # Fixed-effects regression
    print("Running fixed-effects regression...")
    regression_results = run_fixed_effects_regression(by_problem)
    if regression_results:
        print(
            f"  β(ast_delta) = {regression_results['ast_delta_coef']:.4f} "
            f"(SE={regression_results['ast_delta_se']:.4f}, p={regression_results['ast_delta_pval']:.4f})"
        )
    else:
        print("  Regression skipped (statsmodels not available or insufficient data)")

    # Generate outputs (file names are prefixed with the task name)
    print("\nGenerating outputs...")

    # CSV
    csv_path = out_dir / f"{task}_within_problem_bloat_control.csv"
    generate_csv(comparisons, csv_path)

    # LaTeX table
    tables_dir = out_dir / "tables"
    latex_path = tables_dir / f"{task}_within_problem_bloat_control.tex"
    generate_latex_table(
        short_long_stats, compact_bloated_stats, regression_results, latex_path,
        excl_gold_sl=excl_gold_sl_stats, excl_gold_cb=excl_gold_cb_stats,
        task=task
    )

    # Histogram
    figs_dir = out_dir / "figs"
    hist_path = figs_dir / f"{task}_within_problem_delta_hist.pdf"
    generate_histogram(comparisons, hist_path)

    # Summary markdown
    summary_path = out_dir / f"{task}_within_problem_bloat_control_summary.md"
    generate_summary_markdown(
        short_long_stats,
        compact_bloated_stats,
        band_stats_sl,
        band_stats_cb,
        regression_results,
        summary_path,
    )

    # Print summary; mean/CI lines are only printed when n > 0 for every group.
    print("\n=== Summary ===")
    _print_delta_summary(
        f"Short vs Long: n={short_long_stats.n_problems}", short_long_stats
    )
    _print_delta_summary(
        f"\nCompact vs Bloated: n={compact_bloated_stats.n_problems}",
        compact_bloated_stats,
    )

    print("\n--- Excluding likely-gold formulas ---")
    _print_delta_summary(
        f"Short vs Long (excl. gold): n={excl_gold_sl_stats.n_problems}",
        excl_gold_sl_stats,
    )
    _print_delta_summary(
        f"\nCompact vs Bloated (excl. gold): n={excl_gold_cb_stats.n_problems}",
        excl_gold_cb_stats,
    )

    return {
        "short_long_stats": short_long_stats,
        "compact_bloated_stats": compact_bloated_stats,
        "band_stats_sl": band_stats_sl,
        "band_stats_cb": band_stats_cb,
        "regression_results": regression_results,
        "n_comparisons": len(comparisons),
        "excl_gold_sl_stats": excl_gold_sl_stats,
        "excl_gold_cb_stats": excl_gold_cb_stats,
        "n_with_both": n_with_both,
        "n_exact_gold": n_exact_gold,
    }


def _print_delta_summary(header: str, stats: Any) -> None:
    """Print ``header``, then (only if n_problems > 0) mean Δ with CI and the share of Δ > 0."""
    print(header)
    if stats.n_problems > 0:
        print(
            f"  Mean Δ = {stats.mean_delta*100:+.1f}% "
            f"[{stats.ci_lower*100:.1f}, {stats.ci_upper*100:.1f}]"
        )
        print(f"  Δ > 0: {stats.frac_delta_pos*100:.0f}%")


def main():
    """Parse command-line arguments and run the within-problem bloat control analysis.

    Every keyword of ``run_within_problem_bloat_control`` is exposed as a flag,
    including ``--require-both-groups`` (off by default, matching the function
    default).
    """
    parser = argparse.ArgumentParser(
        description="Within-problem bloat control analysis"
    )
    parser.add_argument(
        "--task",
        default="fullobs",
        choices=["fullobs", "ci"],
        help="Task to analyze (default: fullobs)",
    )
    parser.add_argument(
        "--benchmark", default="v1", help="Benchmark version (default: v1)"
    )
    parser.add_argument(
        "--holdout-path",
        type=Path,
        default=None,
        help="Path to holdout JSON file (auto-detected if not specified)",
    )
    parser.add_argument(
        "--out",
        type=Path,
        default=None,
        help="Output directory (auto-detected if not specified)",
    )
    parser.add_argument(
        "--threshold-delta",
        type=int,
        default=1,
        help="Threshold for compact vs bloated (default: 1)",
    )
    parser.add_argument(
        "--require-both-groups",
        action="store_true",
        help="Only include problems that have both compact and bloated formulas",
    )
    parser.add_argument(
        "--n-boot", type=int, default=2000, help="Number of bootstrap samples"
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")

    args = parser.parse_args()

    run_within_problem_bloat_control(
        task=args.task,
        benchmark=args.benchmark,
        holdout_path=args.holdout_path,
        out_dir=args.out,
        threshold_delta=args.threshold_delta,
        require_both_groups=args.require_both_groups,
        n_boot=args.n_boot,
        seed=args.seed,
    )


if __name__ == "__main__":
    # Script entry point; the module docstring shows the intended invocation:
    # `python -m concept_synth.analysis.within_problem_bloat_control ...`.
    main()
