#!/usr/bin/env python3
"""
Per-World FP/FN Profile Analysis

Computes detailed per-world mistake profiles for FO, CI, and EC tasks:
- For each model's formula and each world: FP/FN counts and rates
- Averaged per world and per band
- For CI: separate computation for YES and NO worlds
- NO margin = number of mismatched elements (FP + FN) on a NO world; 0 means the formula accidentally matches that NO world exactly, which fails the instance

Usage:
    python -m concept_synth.analysis.per_world_fpfn \
        --fo-dataset results/ad_benchmark/ad_benchmark_v1.yaml \
        --ci-dataset results/c_benchmark/c_benchmark_v1.yaml \
        --ec-dataset results/e_benchmark/e_benchmark_v1.yaml \
        --out artifacts/analysis/v1/per_world_fpfn/
"""

import argparse
import json
import os
import sys
from collections import defaultdict
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# Bootstrap path: make the ``concept_synth`` package importable when it is not
# already on sys.path. On ModuleNotFoundError, walk upward starting from this
# file's absolute path until a path component named "concept_synth" is found,
# insert that component's parent directory at the front of sys.path, and retry
# the import.
try:
    from concept_synth.bootstrap import add_repo_root
except ModuleNotFoundError:
    _path = os.path.abspath(__file__)
    while True:
        parent = os.path.dirname(_path)
        if os.path.basename(_path) == "concept_synth":
            # Found the package directory; its parent must be on sys.path.
            if parent not in sys.path:
                sys.path.insert(0, parent)
            break
        if parent == _path:
            # Reached the filesystem root without finding the package; the
            # retried import below will raise ModuleNotFoundError again.
            break
        _path = parent
    from concept_synth.bootstrap import add_repo_root
# NOTE(review): add_repo_root lives in concept_synth.bootstrap (not shown);
# presumably it adds the repository root to sys.path — confirm.
add_repo_root(__file__)

try:
    import pandas as pd

    HAS_PANDAS = True
except ImportError:
    HAS_PANDAS = False

from concept_synth.evaluate_ad import build_model_from_world
from concept_synth.fol.formulas import FOFormula
from concept_synth.fol.model import FiniteModel
from concept_synth.io_utils import load_from_yaml
from concept_synth.sexpr_parser import parse_sexpr_formula
from concept_synth.target import compute_target_extension

# =============================================================================
# Data Classes
# =============================================================================


@dataclass
class WorldMistakeProfile:
    """Confusion counts and rates for one predicted formula on one world.

    ``margin`` is only populated for CI NO worlds, where it holds the total
    number of mismatches (FP + FN); a margin of 0 means the prediction
    reproduces the NO world's gold extension exactly, which is undesirable.
    """

    # --- identity ---------------------------------------------------------
    world_id: str
    world_type: str  # "training" for FO/EC; "YES" / "NO" for CI
    domain_size: int

    # --- confusion-matrix counts -------------------------------------------
    fp_count: int  # predicted true, gold false
    fn_count: int  # gold true, predicted false
    tp_count: int  # predicted true, gold true
    tn_count: int  # predicted false, gold false

    # --- rates (normalized by domain size) ---------------------------------
    fp_rate: float
    fn_rate: float
    accuracy: float  # (tp_count + tn_count) / domain_size

    # --- CI NO worlds only -------------------------------------------------
    margin: Optional[int] = field(default=None)


@dataclass
class InstanceMistakeProfile:
    """One LLM answer to one problem, scored over every world of that problem.

    The mean_* fields are unweighted averages of the per-world values in
    ``world_profiles``; the CI-only fields keep their default of None for
    FO and EC instances.
    """

    # --- identification ----------------------------------------------------
    instance_id: str
    model: str
    band: str
    task: str  # "fo", "ci" or "ec"

    # --- outcome flags -----------------------------------------------------
    correct: bool
    parse_ok: bool

    # --- one entry per world -----------------------------------------------
    world_profiles: List[WorldMistakeProfile]

    # --- averages / totals over world_profiles -----------------------------
    mean_fp_rate: float
    mean_fn_rate: float
    mean_accuracy: float
    total_fp: int
    total_fn: int

    # --- CI only -----------------------------------------------------------
    yes_mean_fp_rate: Optional[float] = field(default=None)
    yes_mean_fn_rate: Optional[float] = field(default=None)
    no_mean_margin: Optional[float] = field(default=None)  # mean FP+FN over NO worlds
    no_exact_match_count: Optional[int] = field(default=None)  # NO worlds with 0 mismatches (bad)


@dataclass
class ModelMistakeAggregates:
    """Per-model roll-up of InstanceMistakeProfile records for one task.

    ``band_stats`` maps a band name to its per-band counters and means; the
    CI-only fields keep their default of None for other tasks.
    """

    # --- identification ----------------------------------------------------
    model: str
    task: str

    # --- instance counts ---------------------------------------------------
    total_instances: int
    parsed_instances: int
    correct_instances: int

    # --- mean rates across instances ---------------------------------------
    mean_fp_rate: float
    mean_fn_rate: float
    mean_accuracy: float

    # --- per-band breakdown ------------------------------------------------
    band_stats: Dict[str, Dict[str, float]]

    # --- CI only -----------------------------------------------------------
    yes_mean_fp_rate: Optional[float] = field(default=None)
    yes_mean_fn_rate: Optional[float] = field(default=None)
    no_mean_margin: Optional[float] = field(default=None)
    no_exact_match_rate: Optional[float] = field(default=None)  # share of NO worlds matched exactly


# =============================================================================
# Core Analysis Functions
# =============================================================================


def compute_world_profile(
    model: FiniteModel,
    domain: List[str],
    predicted_formula: FOFormula,
    gold_t_true: set,
    world_id: str,
    world_type: str = "training",
) -> WorldMistakeProfile:
    """
    Compute FP/FN profile for a single world.

    The predicted extension is ``compute_target_extension(model,
    predicted_formula).T_true``; domain elements outside it are predicted
    false. FP rate, FN rate and accuracy all use ``len(domain)`` as the
    denominator, on both the success and the failure path.

    If evaluating the formula raises, the world is scored as entirely wrong:
    every gold-true element is a false negative, every gold-false element a
    false positive, TP = TN = 0 and accuracy is 0.0.

    For an empty domain the rates are 0.0; a successful evaluation with no
    mismatches gets accuracy 1.0 (nothing was misclassified).

    Args:
        model: FiniteModel for the world
        domain: List of domain elements
        predicted_formula: Parsed predicted formula
        gold_t_true: Elements of the gold T_true (any iterable; converted to a set)
        world_id: World identifier
        world_type: 'training', 'YES', or 'NO'; ``margin`` is only set for 'NO'

    Returns:
        WorldMistakeProfile with detailed counts and rates. For NO worlds
        ``margin`` is FP + FN (0 means an exact match, which is bad for NO).
    """
    gold_t_true = set(gold_t_true)
    domain_set = set(domain)
    domain_size = len(domain)
    gold_t_false = domain_set - gold_t_true
    is_no_world = world_type == "NO"

    def _rate(count: int) -> float:
        # Shared denominator keeps rates comparable between success and failure paths.
        return count / domain_size if domain_size > 0 else 0.0

    try:
        pred_t_true = set(compute_target_extension(model, predicted_formula).T_true)
    except Exception:
        # Best-effort: any evaluation error scores the world as all wrong,
        # so fp + fn covers every element and the margin agrees with it.
        fp = len(gold_t_false)
        fn = len(gold_t_true)
        return WorldMistakeProfile(
            world_id=world_id,
            world_type=world_type,
            domain_size=domain_size,
            fp_count=fp,
            fn_count=fn,
            tp_count=0,
            tn_count=0,
            fp_rate=_rate(fp),
            fn_rate=_rate(fn),
            accuracy=0.0,
            margin=fp + fn if is_no_world else None,
        )

    # Confusion matrix
    pred_t_false = domain_set - pred_t_true
    tp = len(pred_t_true & gold_t_true)
    fp = len(pred_t_true - gold_t_true)
    fn = len(gold_t_true - pred_t_true)
    tn = len(pred_t_false & gold_t_false)

    if domain_size > 0:
        accuracy = (tp + tn) / domain_size
    else:
        # Empty domain: a prediction with no mismatches is vacuously perfect.
        accuracy = 1.0 if fp == 0 and fn == 0 else 0.0

    return WorldMistakeProfile(
        world_id=world_id,
        world_type=world_type,
        domain_size=domain_size,
        fp_count=fp,
        fn_count=fn,
        tp_count=tp,
        tn_count=tn,
        fp_rate=_rate(fp),
        fn_rate=_rate(fn),
        accuracy=accuracy,
        # Total mismatches on NO worlds; 0 means exact match (bad for NO).
        margin=fp + fn if is_no_world else None,
    )


def analyze_fo_instance(
    problem: Dict[str, Any], llm_result: Dict[str, Any]
) -> Optional[InstanceMistakeProfile]:
    """
    Score one LLM answer to an FO (AD) problem, world by world.

    Args:
        problem: Benchmark entry; reads ``problem`` (``instanceId``, ``worlds``)
            and ``problemDescription`` (``ad_band``).
        llm_result: One LLM answer; reads ``model`` and ``extractedFormula``.

    Returns:
        An InstanceMistakeProfile. A missing or unparsable formula yields a
        profile with ``parse_ok=False``, no world profiles and zeroed
        statistics. The instance is ``correct`` only when there is at least
        one world and every world is matched exactly. Never returns None.
    """
    prob = problem.get("problem", {})
    desc = problem.get("problemDescription", {})
    identity = {
        "instance_id": prob.get("instanceId", "unknown"),
        "model": llm_result.get("model", "unknown"),
        "band": desc.get("ad_band", "unknown"),
        "task": "fo",
    }

    source = llm_result.get("extractedFormula", "")
    parsed = False
    if source:
        try:
            formula = parse_sexpr_formula(source)
        except Exception:
            parsed = False
        else:
            parsed = True

    if not parsed:
        return InstanceMistakeProfile(
            **identity,
            correct=False,
            parse_ok=False,
            world_profiles=[],
            mean_fp_rate=0.0,
            mean_fn_rate=0.0,
            mean_accuracy=0.0,
            total_fp=0,
            total_fn=0,
        )

    profiles = []
    for world in prob.get("worlds", []):
        world_id = world.get("worldId", "")
        domain = world.get("domain", [])
        world_model = build_model_from_world(world)
        gold = set(world.get("targetExtension", {}).get("T_true", []))
        profiles.append(
            compute_world_profile(world_model, domain, formula, gold, world_id, "training")
        )

    count = len(profiles)
    if count:
        mean_fp_rate = sum(p.fp_rate for p in profiles) / count
        mean_fn_rate = sum(p.fn_rate for p in profiles) / count
        mean_accuracy = sum(p.accuracy for p in profiles) / count
        total_fp = sum(p.fp_count for p in profiles)
        total_fn = sum(p.fn_count for p in profiles)
    else:
        mean_fp_rate = mean_fn_rate = mean_accuracy = 0.0
        total_fp = total_fn = 0

    every_world_exact = all(p.accuracy >= 1.0 for p in profiles)

    return InstanceMistakeProfile(
        **identity,
        correct=every_world_exact and mean_accuracy == 1.0,
        parse_ok=True,
        world_profiles=profiles,
        mean_fp_rate=mean_fp_rate,
        mean_fn_rate=mean_fn_rate,
        mean_accuracy=mean_accuracy,
        total_fp=total_fp,
        total_fn=total_fn,
    )


def _infer_ci_world_type(world_dict: Dict[str, Any]) -> str:
    """
    Return the upper-cased CI world label for ``world_dict``.

    Sources, in priority order: a non-empty ``worldType``, a non-empty
    ``splitLabel``, then a ``yes``/``no`` prefix on ``worldId``
    (case-insensitive); the fallback is "YES". Both explicit labels are
    stripped and upper-cased so "no", " NO" and "NO" are treated alike.
    """
    for key in ("worldType", "splitLabel"):
        label = world_dict.get(key)
        if label:
            return str(label).strip().upper()
    world_id = str(world_dict.get("worldId") or "").lower()
    if world_id.startswith("yes"):
        return "YES"
    if world_id.startswith("no"):
        return "NO"
    return "YES"


def analyze_ci_instance(
    problem: Dict[str, Any], llm_result: Dict[str, Any]
) -> Optional[InstanceMistakeProfile]:
    """
    Analyze FP/FN profile for a single CI (C) instance.

    Separately tracks YES and NO worlds:
    - YES worlds: standard FP/FN (want exact match)
    - NO worlds: compute margin (mismatches; want > 0)

    A world is a NO world exactly when its inferred label (see
    ``_infer_ci_world_type``) is "NO"; every other label is scored as a YES
    world. This matches ``compute_world_profile``, which only sets a margin
    for "NO". The instance is correct when every YES world matches exactly
    and every NO world has at least one mismatch.

    Args:
        problem: Problem dict
        llm_result: LLM result dict

    Returns:
        InstanceMistakeProfile with CI-specific fields (never None)
    """
    prob = problem.get("problem", {})
    desc = problem.get("problemDescription", {})

    instance_id = prob.get("instanceId", "unknown")
    model_name = llm_result.get("model", "unknown")
    band = desc.get("c_band", "unknown")

    # Parse formula; a missing formula and a parse error are reported identically.
    extracted = llm_result.get("extractedFormula", "")
    parse_ok = False
    if extracted:
        try:
            formula = parse_sexpr_formula(extracted)
            parse_ok = True
        except Exception:
            parse_ok = False

    if not parse_ok:
        return InstanceMistakeProfile(
            instance_id=instance_id,
            model=model_name,
            band=band,
            task="ci",
            correct=False,
            parse_ok=False,
            world_profiles=[],
            mean_fp_rate=0.0,
            mean_fn_rate=0.0,
            mean_accuracy=0.0,
            total_fp=0,
            total_fn=0,
            yes_mean_fp_rate=0.0,
            yes_mean_fn_rate=0.0,
            no_mean_margin=0.0,
            no_exact_match_count=0,
        )

    # Analyze worlds by type
    world_profiles = []
    yes_profiles = []
    no_profiles = []

    for world_dict in prob.get("worlds", []):
        world_id = world_dict.get("worldId", "")
        world_type = _infer_ci_world_type(world_dict)
        domain = world_dict.get("domain", [])

        model = build_model_from_world(world_dict)

        target_ext = world_dict.get("targetExtension", {})
        gold_t_true = set(target_ext.get("T_true", []))

        profile = compute_world_profile(model, domain, formula, gold_t_true, world_id, world_type)
        world_profiles.append(profile)

        if world_type == "NO":
            no_profiles.append(profile)
        else:
            yes_profiles.append(profile)

    # Aggregate YES worlds
    if yes_profiles:
        yes_mean_fp_rate = sum(p.fp_rate for p in yes_profiles) / len(yes_profiles)
        yes_mean_fn_rate = sum(p.fn_rate for p in yes_profiles) / len(yes_profiles)
        yes_all_match = all(p.accuracy == 1.0 for p in yes_profiles)
    else:
        yes_mean_fp_rate = yes_mean_fn_rate = 0.0
        yes_all_match = True

    # Aggregate NO worlds (margin is set for every profile in this bucket)
    if no_profiles:
        no_mean_margin = sum(p.margin or 0 for p in no_profiles) / len(no_profiles)
        no_exact_match_count = sum(1 for p in no_profiles if p.margin == 0)
        no_all_avoid = all((p.margin or 0) > 0 for p in no_profiles)
    else:
        no_mean_margin = 0.0
        no_exact_match_count = 0
        no_all_avoid = True

    # Overall aggregates
    if world_profiles:
        n = len(world_profiles)
        mean_fp_rate = sum(p.fp_rate for p in world_profiles) / n
        mean_fn_rate = sum(p.fn_rate for p in world_profiles) / n
        mean_accuracy = sum(p.accuracy for p in world_profiles) / n
        total_fp = sum(p.fp_count for p in world_profiles)
        total_fn = sum(p.fn_count for p in world_profiles)
    else:
        mean_fp_rate = mean_fn_rate = mean_accuracy = 0.0
        total_fp = total_fn = 0

    return InstanceMistakeProfile(
        instance_id=instance_id,
        model=model_name,
        band=band,
        task="ci",
        correct=yes_all_match and no_all_avoid,
        parse_ok=True,
        world_profiles=world_profiles,
        mean_fp_rate=mean_fp_rate,
        mean_fn_rate=mean_fn_rate,
        mean_accuracy=mean_accuracy,
        total_fp=total_fp,
        total_fn=total_fn,
        yes_mean_fp_rate=yes_mean_fp_rate,
        yes_mean_fn_rate=yes_mean_fn_rate,
        no_mean_margin=no_mean_margin,
        no_exact_match_count=no_exact_match_count,
    )


def analyze_ec_instance(
    problem: Dict[str, Any], llm_result: Dict[str, Any]
) -> Optional[InstanceMistakeProfile]:
    """
    Analyze FP/FN profile for a single EC (E) instance.

    Each world is scored exactly like FO: the predicted formula is evaluated
    on the model returned by ``build_model_from_world`` and compared against
    the world's gold ``targetExtension.T_true`` (world type "training"). No
    completion search is performed here, so the profile does NOT reflect EC's
    existential-completion semantics. ``correct`` (mean accuracy of exactly
    1.0) is therefore only an approximation of EC validity, whose real check
    (Z3-based) happens elsewhere.

    Args:
        problem: Problem dict; reads ``problem`` (``instanceId``, ``worlds``)
            and ``problemDescription`` (``e_band``).
        llm_result: LLM result dict; reads ``model`` and ``extractedFormula``.

    Returns:
        InstanceMistakeProfile (never None). A missing or unparsable formula
        yields ``parse_ok=False`` with zeroed statistics.
    """
    prob = problem.get("problem", {})
    desc = problem.get("problemDescription", {})

    instance_id = prob.get("instanceId", "unknown")
    model_name = llm_result.get("model", "unknown")
    band = desc.get("e_band", "unknown")

    # Parse formula; a missing formula and a parse error are reported identically.
    extracted = llm_result.get("extractedFormula", "")
    parse_ok = False
    if extracted:
        try:
            formula = parse_sexpr_formula(extracted)
            parse_ok = True
        except Exception:
            parse_ok = False

    if not parse_ok:
        return InstanceMistakeProfile(
            instance_id=instance_id,
            model=model_name,
            band=band,
            task="ec",
            correct=False,
            parse_ok=False,
            world_profiles=[],
            mean_fp_rate=0.0,
            mean_fn_rate=0.0,
            mean_accuracy=0.0,
            total_fp=0,
            total_fn=0,
        )

    world_profiles = []
    for world_dict in prob.get("worlds", []):
        world_id = world_dict.get("worldId", "")
        domain = world_dict.get("domain", [])

        # NOTE(review): an earlier comment said this model uses "known facts
        # only" for EC; that depends on build_model_from_world (not shown) — confirm.
        model = build_model_from_world(world_dict)

        target_ext = world_dict.get("targetExtension", {})
        gold_t_true = set(target_ext.get("T_true", []))

        profile = compute_world_profile(model, domain, formula, gold_t_true, world_id, "training")
        world_profiles.append(profile)

    # Aggregate
    if world_profiles:
        n = len(world_profiles)
        mean_fp_rate = sum(p.fp_rate for p in world_profiles) / n
        mean_fn_rate = sum(p.fn_rate for p in world_profiles) / n
        mean_accuracy = sum(p.accuracy for p in world_profiles) / n
        total_fp = sum(p.fp_count for p in world_profiles)
        total_fn = sum(p.fn_count for p in world_profiles)
    else:
        mean_fp_rate = mean_fn_rate = mean_accuracy = 0.0
        total_fp = total_fn = 0

    return InstanceMistakeProfile(
        instance_id=instance_id,
        model=model_name,
        band=band,
        task="ec",
        correct=mean_accuracy == 1.0,  # Approximate: gold-target match, not completion validity
        parse_ok=True,
        world_profiles=world_profiles,
        mean_fp_rate=mean_fp_rate,
        mean_fn_rate=mean_fn_rate,
        mean_accuracy=mean_accuracy,
        total_fp=total_fp,
        total_fn=total_fn,
    )


# =============================================================================
# Aggregation Functions
# =============================================================================


def aggregate_model_profiles(
    profiles: List[InstanceMistakeProfile], task: str
) -> Dict[str, ModelMistakeAggregates]:
    """
    Aggregate instance profiles by model.

    Mean rates (overall and per band) are taken over parsed instances only;
    unparsed instances still count toward ``total_instances`` and per-band
    ``count``. For ``task == "ci"`` the YES/NO summaries are also filled in:

    - YES FP/FN: mean of the per-instance YES means over parsed CI instances.
    - NO margin: mean of the per-instance NO margins over parsed instances
      that contain at least one NO world. Instances without NO worlds are
      excluded because their placeholder margin of 0.0 would otherwise read
      as "exact match" and drag the average down.
    - NO exact-match rate: pooled over all NO worlds of parsed instances.

    Args:
        profiles: List of InstanceMistakeProfile
        task: 'fo', 'ci', or 'ec'

    Returns:
        Dict mapping model name to ModelMistakeAggregates
    """

    def _mean(values: List[float]) -> float:
        return sum(values) / len(values) if values else 0.0

    by_model = defaultdict(list)
    for p in profiles:
        by_model[p.model].append(p)

    result = {}
    for model, model_profiles in by_model.items():
        parsed_profiles = [p for p in model_profiles if p.parse_ok]
        correct = sum(1 for p in model_profiles if p.correct)

        # Mean rates (only for parsed instances)
        mean_fp = _mean([p.mean_fp_rate for p in parsed_profiles])
        mean_fn = _mean([p.mean_fn_rate for p in parsed_profiles])
        mean_acc = _mean([p.mean_accuracy for p in parsed_profiles])

        # Per-band breakdown: accumulate sums, then divide by parsed count
        band_stats = defaultdict(
            lambda: {
                "count": 0,
                "parsed": 0,
                "correct": 0,
                "mean_fp": 0.0,
                "mean_fn": 0.0,
                "mean_acc": 0.0,
            }
        )
        for p in model_profiles:
            stats = band_stats[p.band]
            stats["count"] += 1
            if p.parse_ok:
                stats["parsed"] += 1
                stats["mean_fp"] += p.mean_fp_rate
                stats["mean_fn"] += p.mean_fn_rate
                stats["mean_acc"] += p.mean_accuracy
            if p.correct:
                stats["correct"] += 1

        for stats in band_stats.values():
            if stats["parsed"] > 0:
                stats["mean_fp"] /= stats["parsed"]
                stats["mean_fn"] /= stats["parsed"]
                stats["mean_acc"] /= stats["parsed"]

        # CI-specific
        yes_mean_fp = yes_mean_fn = no_mean_margin = no_exact_rate = None
        if task == "ci":
            ci_parsed = [p for p in parsed_profiles if p.yes_mean_fp_rate is not None]
            if ci_parsed:
                yes_mean_fp = _mean([p.yes_mean_fp_rate for p in ci_parsed])
                yes_mean_fn = _mean([p.yes_mean_fn_rate or 0.0 for p in ci_parsed])

                no_world_counts = [
                    sum(1 for wp in p.world_profiles if wp.world_type == "NO")
                    for p in ci_parsed
                ]
                with_no_worlds = [
                    p for p, n_no in zip(ci_parsed, no_world_counts) if n_no > 0
                ]
                no_mean_margin = _mean([p.no_mean_margin or 0.0 for p in with_no_worlds])

                total_no_worlds = sum(no_world_counts)
                total_no_exact = sum(p.no_exact_match_count or 0 for p in ci_parsed)
                no_exact_rate = total_no_exact / total_no_worlds if total_no_worlds > 0 else 0.0

        result[model] = ModelMistakeAggregates(
            model=model,
            task=task,
            total_instances=len(model_profiles),
            parsed_instances=len(parsed_profiles),
            correct_instances=correct,
            mean_fp_rate=mean_fp,
            mean_fn_rate=mean_fn,
            mean_accuracy=mean_acc,
            band_stats=dict(band_stats),
            yes_mean_fp_rate=yes_mean_fp,
            yes_mean_fn_rate=yes_mean_fn,
            no_mean_margin=no_mean_margin,
            no_exact_match_rate=no_exact_rate,
        )

    return result


# =============================================================================
# Main Analysis Pipeline
# =============================================================================


def _analyze_task_dataset(
    dataset: Optional[Path],
    tag: str,
    scenario: str,
    analyzer: Any,
    task: str,
    out_path: Path,
    verbose: bool,
) -> Optional[Dict[str, Any]]:
    """
    Load one benchmark, profile every (problem, LLM result) pair, and save JSON.

    Only problems whose ``problem.scenario`` equals ``scenario`` are analyzed,
    using ``analyzer(problem, llm_result)``. The resulting profiles and their
    per-model aggregates are written to ``out_path``.

    Returns:
        ``{"profiles": [...], "aggregates": {...}}`` as plain dicts, or None
        when no dataset path was given or the path does not exist. A warning
        is printed to stderr in the latter case.
    """
    if not dataset:
        return None
    if not Path(dataset).exists():
        print(f"[{tag}] Warning: dataset not found, skipping: {dataset}", file=sys.stderr)
        return None

    if verbose:
        print(f"\n[{tag}] Loading {dataset}...")
    problems = load_from_yaml(str(dataset)) or []

    task_problems = [p for p in problems if p.get("problem", {}).get("scenario") == scenario]
    if verbose:
        print(f"[{tag}] Found {len(task_problems)} {scenario} problems")

    profiles = []
    for problem in task_problems:
        # `or []` also tolerates an explicit null llmResults entry.
        for llm_result in problem.get("llmResults") or []:
            profile = analyzer(problem, llm_result)
            if profile:
                profiles.append(profile)

    if verbose:
        print(f"[{tag}] Analyzed {len(profiles)} instance-model pairs")

    aggregates = aggregate_model_profiles(profiles, task)
    payload = {
        "profiles": [asdict(p) for p in profiles],
        "aggregates": {k: asdict(v) for k, v in aggregates.items()},
    }

    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, default=str)
    if verbose:
        print(f"[{tag}] Saved to {out_path}")
    return payload


def run_fpfn_analysis(
    fo_dataset: Optional[Path],
    ci_dataset: Optional[Path],
    ec_dataset: Optional[Path],
    outdir: Path,
    verbose: bool = True,
) -> Dict[str, Any]:
    """
    Run the full FP/FN analysis pipeline.

    For each task whose dataset path is given and exists, writes
    ``<task>_fpfn_profiles.json`` to ``outdir``, then generates the summary
    report and the LaTeX tables.

    Args:
        fo_dataset: Path to FO (AD) benchmark YAML
        ci_dataset: Path to CI (C) benchmark YAML
        ec_dataset: Path to EC (E) benchmark YAML
        outdir: Output directory (created if missing)
        verbose: Print progress

    Returns:
        Dict with keys "fo", "ci", "ec"; each value is the task's
        ``{"profiles", "aggregates"}`` dict, or None if it was skipped.
    """
    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    # (result key, dataset, log tag, scenario filter, per-instance analyzer)
    task_specs = (
        ("fo", fo_dataset, "FO", "AD", analyze_fo_instance),
        ("ci", ci_dataset, "CI", "C", analyze_ci_instance),
        ("ec", ec_dataset, "EC", "E", analyze_ec_instance),
    )

    results: Dict[str, Any] = {}
    for task, dataset, tag, scenario, analyzer in task_specs:
        results[task] = _analyze_task_dataset(
            dataset,
            tag,
            scenario,
            analyzer,
            task,
            outdir / f"{task}_fpfn_profiles.json",
            verbose,
        )

    # Generate summary report
    generate_fpfn_report(results, outdir, verbose)

    # Generate LaTeX tables
    generate_latex_tables(results, outdir, verbose)

    return results


def generate_latex_tables(results: Dict[str, Any], outdir: Path, verbose: bool = True) -> None:
    """
    Generate LaTeX tables for paper appendix.

    For each task key ("fo", "ci", "ec") whose entry in ``results`` has a
    non-empty ``"aggregates"`` mapping (model name -> aggregate dict), writes
    a booktabs-style table to ``outdir / "latex" / "<task>_fpfn.tex"``.

    Known model names are shown via MODEL_DISPLAY. Unknown names are
    truncated to 10 characters and LaTeX-escaped, so characters such as
    ``_``, ``&`` or ``%`` do not break compilation.

    Args:
        results: Output of run_fpfn_analysis (task key -> dict or None).
        outdir: Output directory; tables go into its ``latex`` subdirectory.
        verbose: Print the path of each generated file.
    """
    tables_dir = Path(outdir) / "latex"
    tables_dir.mkdir(parents=True, exist_ok=True)

    MODEL_DISPLAY = {
        "grok4": "Grok4",
        "gpt-5.2": "GPT-5.2",
        "grok4.1fast": "Grok4.1f",
        "gemini-3-pro-preview": "Gemini 3",
        "gemini-3-pro-pr": "Gemini 3",
        "deepseek-reasoner": "DSR",
        "deepseek-reason": "DSR",
        "claude-opus-4-5": "Opus 4.5",
        "claude-opus-4-5-20251101": "Opus 4.5",
        "hermes4": "Hermes4",
        "gpt-4o": "GPT-4o",
    }

    # Characters with special meaning in LaTeX text mode.
    latex_escapes = {
        "\\": r"\textbackslash{}",
        "&": r"\&",
        "%": r"\%",
        "$": r"\$",
        "#": r"\#",
        "_": r"\_",
        "{": r"\{",
        "}": r"\}",
        "~": r"\textasciitilde{}",
        "^": r"\textasciicircum{}",
    }

    def display_name(model: str) -> str:
        if model in MODEL_DISPLAY:
            return MODEL_DISPLAY[model]
        # Truncate before escaping so an escape sequence is never cut in half.
        return "".join(latex_escapes.get(ch, ch) for ch in model[:10])

    def fmt_pct(val: float, bold: bool = False) -> str:
        pct = f"{val*100:.1f}\\%"
        return f"\\textbf{{{pct}}}" if bold else pct

    def aggregates_for(task_key: str) -> Dict[str, Any]:
        entry = results.get(task_key)
        return (entry.get("aggregates") or {}) if entry else {}

    def write_table(filename: str, lines: List[str]) -> None:
        path = tables_dir / filename
        with open(path, "w", encoding="utf-8") as f:
            f.write("\n".join(lines) + "\n")
        if verbose:
            print(f"Generated {path}")

    def write_accuracy_table(
        task_key: str, comment: str, caption: str, label: str, filename: str
    ) -> None:
        # Shared layout for FO and EC: Parsed / Correct / FP% / FN% / Acc%.
        aggregates = aggregates_for(task_key)
        if not aggregates:
            return
        lines = [
            comment,
            "\\begin{table}[h]",
            "\\centering",
            caption,
            label,
            "\\small",
            "\\begin{tabular}{@{}lrrrrr@{}}",
            "\\toprule",
            "Model & Parsed & Correct & FP\\% & FN\\% & Acc\\% \\\\",
            "\\midrule",
        ]
        for model, agg in sorted(
            aggregates.items(), key=lambda x: x[1]["mean_accuracy"], reverse=True
        ):
            lines.append(
                f"{display_name(model)} & {agg['parsed_instances']} & {agg['correct_instances']} & "
                f"{fmt_pct(agg['mean_fp_rate'])} & {fmt_pct(agg['mean_fn_rate'])} & "
                f"{fmt_pct(agg['mean_accuracy'])} \\\\"
            )
        lines += ["\\bottomrule", "\\end{tabular}", "\\end{table}"]
        write_table(filename, lines)

    # FO table
    write_accuracy_table(
        "fo",
        "% FO Per-World FP/FN Table (auto-generated)",
        "\\caption{FullObs per-world FP/FN rates. Mean rates computed across parsed instances.}",
        "\\label{tab:fo_fpfn}",
        "fo_fpfn.tex",
    )

    # CI table
    ci_aggregates = aggregates_for("ci")
    if ci_aggregates:
        lines = [
            "% CI Per-World FP/FN Table (auto-generated)",
            "\\begin{table}[h]",
            "\\centering",
            "\\caption{CI per-world error rates. YES columns show FP/FN on positive examples;",
            "NO Margin shows mean mismatches on negative examples (higher = better);",
            "NO Exact shows fraction of NO worlds with accidental exact match (lower = better).}",
            "\\label{tab:ci_fpfn}",
            "\\small",
            # Six columns: Model, Correct, YES FP, YES FN, NO Margin, NO Exact.
            "\\begin{tabular}{@{}lrrrrr@{}}",
            "\\toprule",
            "Model & Correct & YES FP\\% & YES FN\\% & NO Margin & NO Exact\\% \\\\",
            "\\midrule",
        ]
        for model, agg in sorted(
            ci_aggregates.items(), key=lambda x: x[1]["correct_instances"], reverse=True
        ):
            no_exact = agg.get("no_exact_match_rate", 0) or 0
            lines.append(
                f"{display_name(model)} & {agg['correct_instances']} & "
                f"{fmt_pct(agg.get('yes_mean_fp_rate', 0) or 0)} & "
                f"{fmt_pct(agg.get('yes_mean_fn_rate', 0) or 0)} & "
                f"{agg.get('no_mean_margin', 0) or 0:.2f} & "
                f"{fmt_pct(no_exact)} \\\\"
            )
        lines += ["\\bottomrule", "\\end{tabular}", "\\end{table}"]
        write_table("ci_fpfn.tex", lines)

    # EC table
    write_accuracy_table(
        "ec",
        "% EC Per-World FP/FN Table (auto-generated)",
        "\\caption{EC per-world FP/FN rates (against gold target, not accounting for completion).}",
        "\\label{tab:ec_fpfn}",
        "ec_fpfn.tex",
    )


def format_aligned_table(
    headers: List[str], rows: List[List[str]], alignments: Optional[List[str]] = None
) -> str:
    """
    Format a GitHub-flavored markdown table with padded, aligned columns.

    Args:
        headers: Header strings, one per column.
        rows: Table rows; each row is a list of cells (converted with ``str``).
            Rows shorter than ``headers`` are padded with empty cells; cells
            beyond the last header are appended unpadded.
        alignments: 'l', 'r', or 'c' for each column (default: 'l' for the
            first column, 'r' for the rest). A list shorter than ``headers`` is
            extended with those defaults; any other value is treated as 'l'.

    Returns:
        Formatted markdown table string (no trailing newline), or "" when
        ``headers`` or ``rows`` is empty.
    """
    if not headers or not rows:
        return ""

    num_cols = len(headers)
    default_alignments = ["l"] + ["r"] * (num_cols - 1)
    if alignments is None:
        aligns = default_alignments
    else:
        # Fill in missing trailing alignments instead of raising IndexError.
        aligns = list(alignments[:num_cols]) + default_alignments[len(alignments) :]

    # Every column is at least 3 characters wide so each delimiter cell
    # contains a hyphen (":-:", "--:", "---"); GFM does not accept a delimiter
    # made only of colons, which a 1-2 character column would otherwise yield.
    col_widths = [max(3, len(str(h))) for h in headers]
    for row in rows:
        for i, cell in enumerate(list(row)[:num_cols]):
            col_widths[i] = max(col_widths[i], len(str(cell)))

    def pad(text: str, col: int) -> str:
        """Pad ``text`` to column ``col``'s width using its alignment."""
        width = col_widths[col]
        if aligns[col] == "r":
            return text.rjust(width)
        if aligns[col] == "c":
            return text.center(width)
        return text.ljust(width)

    def delimiter(col: int) -> str:
        """Build the header/body delimiter cell for column ``col``."""
        width = col_widths[col]
        if aligns[col] == "r":
            return "-" * (width - 1) + ":"
        if aligns[col] == "c":
            return ":" + "-" * (width - 2) + ":"
        return "-" * width

    def render(cells: List[str]) -> str:
        """Join already-padded cells into one markdown table line."""
        return "| " + " | ".join(cells) + " |"

    header_line = render([pad(str(h), i) for i, h in enumerate(headers)])
    sep_line = render([delimiter(i) for i in range(num_cols)])

    row_lines = []
    for row in rows:
        cells = [str(cell) for cell in row]
        # Short rows get empty cells so every row has the full column count.
        cells += [""] * (num_cols - len(cells))
        row_lines.append(
            render([pad(cell, i) if i < num_cols else cell for i, cell in enumerate(cells)])
        )

    return "\n".join([header_line, sep_line] + row_lines)


def _fpfn_md_pct(value: Optional[float]) -> str:
    """Render a 0-1 rate as a percentage with one decimal, or "N/A" when missing."""
    if value is None:
        return "N/A"
    return f"{value * 100:.1f}"


def _fpfn_md_num(value: Optional[float]) -> str:
    """Render a plain number with two decimals, or "N/A" when missing."""
    if value is None:
        return "N/A"
    return f"{value:.2f}"


def _fpfn_md_overall_table(task: str, aggregates: Dict[str, Any]) -> str:
    """Build the per-model "Overall Model Performance" markdown table for one task."""
    ordered = sorted(aggregates.items())
    if task == "ci":
        headers = [
            "Model",
            "Total",
            "Parsed",
            "Correct",
            "YES FP%",
            "YES FN%",
            "NO Margin",
            "NO Exact%",
        ]
        rows = [
            [
                model,
                str(agg["total_instances"]),
                str(agg["parsed_instances"]),
                str(agg["correct_instances"]),
                _fpfn_md_pct(agg.get("yes_mean_fp_rate")),
                _fpfn_md_pct(agg.get("yes_mean_fn_rate")),
                _fpfn_md_num(agg.get("no_mean_margin")),
                _fpfn_md_pct(agg.get("no_exact_match_rate")),
            ]
            for model, agg in ordered
        ]
    else:
        headers = ["Model", "Total", "Parsed", "Correct", "Mean FP%", "Mean FN%", "Mean Acc%"]
        rows = [
            [
                model,
                str(agg["total_instances"]),
                str(agg["parsed_instances"]),
                str(agg["correct_instances"]),
                _fpfn_md_pct(agg.get("mean_fp_rate")),
                _fpfn_md_pct(agg.get("mean_fn_rate")),
                _fpfn_md_pct(agg.get("mean_accuracy")),
            ]
            for model, agg in ordered
        ]
    alignments = ["l"] + ["r"] * (len(headers) - 1)
    return format_aligned_table(headers, rows, alignments)


def _fpfn_md_band_sections(aggregates: Dict[str, Any]) -> List[str]:
    """Build the markdown lines of the per-band breakdown (one table per band)."""
    all_bands = set()
    for agg in aggregates.values():
        # `or {}` also tolerates an explicit band_stats=None.
        all_bands.update((agg.get("band_stats") or {}).keys())

    if not all_bands:
        return ["No per-band data available.", ""]

    headers = ["Model", "Count", "Correct", "Mean FP%", "Mean FN%"]
    alignments = ["l", "r", "r", "r", "r"]
    lines: List[str] = []
    for band in sorted(all_bands):
        lines.append(f"#### Band: {band}")
        lines.append("")

        rows = []
        for model, agg in sorted(aggregates.items()):
            band_stats = (agg.get("band_stats") or {}).get(band)
            if not band_stats:
                continue
            rows.append(
                [
                    model,
                    str(band_stats.get("count", 0)),
                    str(band_stats.get("correct", 0)),
                    _fpfn_md_pct(band_stats.get("mean_fp")),
                    _fpfn_md_pct(band_stats.get("mean_fn")),
                ]
            )

        if rows:
            lines.append(format_aligned_table(headers, rows, alignments))
        lines.append("")
    return lines


def generate_fpfn_report(results: Dict[str, Any], outdir: Path, verbose: bool = True) -> None:
    """
    Generate a markdown summary report of FP/FN analysis.

    Writes ``fpfn_report.md`` (UTF-8) into ``outdir``, creating the directory
    if needed. The report starts with metric definitions, then for each of the
    FO, CI and EC tasks present in ``results`` it has an overall per-model
    table and one table per band.

    Args:
        results: Mapping from task key ("fo", "ci", "ec") to that task's
            results; only its ``"aggregates"`` entry (model name -> aggregate
            stats dict) is read. Tasks that are absent or None are skipped.
        outdir: Directory to write the report into.
        verbose: If True, print the path of the generated report.
    """
    task_names = {
        "fo": "FullObs (FO)",
        "ci": "Contrastive (CI)",
        "ec": "Existential Completion (EC)",
    }

    lines: List[str] = [
        "# Per-World FP/FN Profile Analysis",
        "",
        "This report summarizes per-world mistake profiles across models.",
        "",
        "## Metric Definitions",
        "",
        "For each world, we compare the predicted target extension against the gold target:",
        "",
        "- **FP (False Positive)**: Elements predicted as T(x)=true but gold is T(x)=false",
        "- **FN (False Negative)**: Elements with gold T(x)=true but predicted as T(x)=false",
        "- **FP%/FN%**: FP/FN count divided by domain size",
        "- **Accuracy**: (TP + TN) / domain size",
        "",
        "For CI (Contrastive):",
        "- **YES worlds**: Must exactly match gold (FP=0, FN=0)",
        "- **NO worlds**: Must NOT exactly match gold (FP+FN > 0)",
        "- **YES FP%/FN%**: Error rates on YES worlds only (lower = better)",
        "- **NO Margin**: Mean mismatches (FP+FN) on NO worlds (higher = better)",
        "- **NO Exact%**: Fraction of NO worlds with accidental exact match (lower = better)",
        "",
        "Cells shown as N/A have no value for that metric in the aggregates.",
        "",
    ]

    for task, task_name in task_names.items():
        task_results = results.get(task)
        if task_results is None:
            continue

        lines.append(f"## {task_name}")
        lines.append("")

        aggregates = task_results.get("aggregates", {})
        if not aggregates:
            lines.append("No data available.")
            lines.append("")
            continue

        # Full model names are kept: the table auto-sizes its columns, and
        # truncating could merge distinct models that share a long prefix.
        lines.append("### Overall Model Performance")
        lines.append("")
        lines.append(_fpfn_md_overall_table(task, aggregates))
        lines.append("")

        lines.append("### Per-Band Breakdown")
        lines.append("")
        lines.extend(_fpfn_md_band_sections(aggregates))
        lines.append("")

    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)
    report_path = outdir / "fpfn_report.md"
    report_path.write_text("\n".join(lines) + "\n", encoding="utf-8")

    if verbose:
        print(f"\nGenerated report: {report_path}")


# =============================================================================
# CLI
# =============================================================================


def main():
    """
    Command-line entry point: parse arguments and run the FP/FN analysis.

    At least one of ``--fo-dataset``, ``--ci-dataset`` or ``--ec-dataset`` must
    be given, and every dataset path given must exist. Otherwise argparse
    reports a usage error (exit status 2) before any analysis starts.
    """
    parser = argparse.ArgumentParser(
        description="Per-World FP/FN Profile Analysis",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    # Analyze all three tasks
    python -m concept_synth.analysis.per_world_fpfn \\
        --fo-dataset results/ad_benchmark/ad_benchmark_v1.yaml \\
        --ci-dataset results/c_benchmark/c_benchmark_v1.yaml \\
        --ec-dataset results/e_benchmark/e_benchmark_v1.yaml \\
        --out artifacts/analysis/v1/per_world_fpfn/
    
    # Analyze only FO
    python -m concept_synth.analysis.per_world_fpfn \\
        --fo-dataset results/ad_benchmark/ad_benchmark_v1.yaml \\
        --out artifacts/analysis/v1/per_world_fpfn/
        """,
    )

    parser.add_argument("--fo-dataset", help="Path to FO (AD) benchmark YAML")
    parser.add_argument("--ci-dataset", help="Path to CI (C) benchmark YAML")
    parser.add_argument("--ec-dataset", help="Path to EC (E) benchmark YAML")
    parser.add_argument("--out", "-o", required=True, help="Output directory for results")
    parser.add_argument("--quiet", "-q", action="store_true", help="Suppress progress output")

    args = parser.parse_args()

    fo_dataset = Path(args.fo_dataset) if args.fo_dataset else None
    ci_dataset = Path(args.ci_dataset) if args.ci_dataset else None
    ec_dataset = Path(args.ec_dataset) if args.ec_dataset else None
    given = (
        ("--fo-dataset", fo_dataset),
        ("--ci-dataset", ci_dataset),
        ("--ec-dataset", ec_dataset),
    )

    # Fail fast with a usage error rather than running an empty analysis or
    # discovering a bad path only after other tasks have been processed.
    if all(path is None for _, path in given):
        parser.error("at least one of --fo-dataset, --ci-dataset, --ec-dataset is required")
    for flag, path in given:
        if path is not None and not path.exists():
            parser.error(f"{flag}: file not found: {path}")

    run_fpfn_analysis(
        fo_dataset=fo_dataset,
        ci_dataset=ci_dataset,
        ec_dataset=ec_dataset,
        outdir=Path(args.out),
        verbose=not args.quiet,
    )


# Run the CLI when executed directly
# (e.g. ``python -m concept_synth.analysis.per_world_fpfn ...``).
if __name__ == "__main__":
    main()
