"""
Full experiment runner for Overcooked JO evaluation.

Supports: NO, JO_static, JO_dynamic, JO_guided and RBR conditions, with one (K1) or two (K2) LLM-controlled agents.
Collects: VR by constraint, edit rates, streak lengths, rewards.
"""

import json
from collections import defaultdict
from dataclasses import dataclass, field, asdict
from typing import Dict, List, Optional
from pathlib import Path
import time

from .adapter import create_adapter, OvercookedAdapter
from .constraints import TaskSpec, Action, SymbolicState
from .llm_agent import LLMAgent, MockLLMAgent, ReflexionAgent, CRITICAgent, LlamaGuardAgent, ConstrainedDecodingAgent
from .judgment_operator import JudgmentOperator, NoOperator, RuleBasedOperator, OperatorOutcome, EpisodeMetrics


@dataclass
class ExperimentConfig:
    """Configuration for an experiment run (one operator condition, many episodes).

    Field order is part of the interface (positional construction), so new
    fields must only be appended, always with a default.
    """
    name: str
    # Operator condition understood by run_experiment():
    #   "NO" (No Operator baseline), "JO_static", "JO_dynamic", "JO_guided",
    #   "RBR" (Runtime Backtracking Repair baseline).
    condition: str
    task_spec: TaskSpec
    layout: str = "counter_circuit"
    n_episodes: int = 30
    horizon: int = 100
    # Agent implementation: "llm", "reflexion", "critic", "llama_guard",
    # "constrained_decoding"; any other value (e.g. "mock_llm") selects MockLLMAgent.
    agent_type: str = "llm"
    model: str = "gpt-4o-mini"
    backend: str = "openai"  # "openai" or "anthropic"
    seed: int = 42  # Seed for MockLLMAgent (offset by agent_id per agent)
    use_nav_hints: bool = True  # If False, give simple goals without directions
    n_agents: int = 1  # Number of LLM-controlled agents (1=K1 single, 2=K2 dual)

    # Theta ablation settings (P vs P+Theta comparison); only forwarded to the
    # operator for the "JO_dynamic" condition.
    enable_theta_updates: bool = False  # If True, enables online theta learning
    theta_lambda: float = 1.0  # Weight for theta penalty term
    theta_eta: float = 0.01  # Learning rate for theta updates
    theta_max_norm: float = 5.0  # Clip bound [-B, B] for theta


@dataclass
class EpisodeResult:
    """Results from a single episode.

    Fields without a default are always set by the episode runners. The
    defaulted groups are optional: the H1-recovery and JO-guided metrics are
    only filled in by the single-agent runner, and the utility counters keep
    their defaults unless a runner sets them explicitly.
    """
    episode_id: int
    condition: str
    total_steps: int
    total_reward: float
    soups_delivered: int

    # Violation metrics
    violations_total: int
    violations_by_id: Dict[str, int]
    violation_rate: float

    # JO metrics (only for JO conditions)
    edits_total: int
    edits_by_type: Dict[str, int]
    intervention_rate: float

    # Stuckness metrics
    max_streak: int
    avg_streak: float

    # H1 Recovery metrics
    h1_fires: int = 0  # Number of H1 violations
    # Steps to escape after each H1 recovery period; a fresh list per instance
    # (previously defaulted to None despite the List[int] type).
    h1_recovery_times: List[int] = field(default_factory=list)
    h1_mean_recovery: float = 0.0  # Mean steps to recover from H1
    h1_escape_rate: float = 0.0  # % of H1 that escaped within N steps

    # JO-guided metrics (repair sequences after violations)
    repair_seqs_given: int = 0  # Number of repair sequences provided
    repair_steps_followed: int = 0  # Steps from sequences executed
    guided_soups: int = 0  # Soups delivered during guided recovery

    # Progress metrics (primary outcome - replaces soups)
    max_pot_fill: int = 0  # Highest #onions in any pot during episode
    cooking_started: bool = False  # Episode ever reached is_cooking=True
    time_to_fill: int = -1  # Steps to reach required onions in pot (-1 = never)
    h1_to_pot_success: int = -1  # Steps from H1 to successful pot interaction (-1 = never)

    # Utility metrics (shows progress even without task success)
    onion_pickups: int = 0  # Times agent picked up onion
    pot_deposits: int = 0  # Times onion deposited in pot
    unique_positions: int = 0  # Unique grid positions visited
    non_stay_actions: int = 0  # Actions that weren't STAY (throughput)

    # Timing
    duration_seconds: float = 0.0


@dataclass
class ExperimentResult:
    """Aggregated results from an experiment.

    Holds the configuration that produced the run, every per-episode result,
    and aggregate statistics computed by run_experiment().
    """
    config: ExperimentConfig
    episodes: List[EpisodeResult]

    # Aggregates
    mean_reward: float = 0.0  # Unweighted mean of per-episode total_reward
    mean_soups: float = 0.0  # Unweighted mean of per-episode soups_delivered
    mean_vr: float = 0.0  # Unweighted mean of per-episode violation_rate
    mean_ir: float = 0.0  # Unweighted mean of per-episode intervention_rate
    # Per-constraint violation counts pooled over all episodes, normalised by
    # run_experiment(); keys are constraint ids such as "H1".
    vr_by_constraint: Dict[str, float] = field(default_factory=dict)

    # RBR retry statistics (only populated for RBR condition), copied from
    # RuleBasedOperator.get_retry_stats().
    rbr_mean_retries: float = 0.0
    rbr_max_retries: int = 0
    # NOTE(review): the report prints this with a "%" suffix, so it is
    # presumably already a percentage (0-100) — confirm in RuleBasedOperator.
    rbr_pct_hitting_max_k: float = 0.0
    rbr_total_regenerations: int = 0


def run_episode(env, adapter: OvercookedAdapter, agent, operator,
                agent_id: int, horizon: int, episode_id: int,
                condition: str, task_spec: "TaskSpec" = None,
                verbose: bool = False, use_hints: bool = True) -> EpisodeResult:
    """Run one single-agent (K1) episode, projecting every action through ``operator``.

    Each step, the controlled player ``agent_id`` proposes an action (or replays
    a queued JO repair action). The operator projects it, and the final action
    is executed while the other player plays STAY. Violation, edit, streak,
    H1-recovery and pot-progress metrics are accumulated along the way.

    Args:
        env: Environment exposing ``reset()``, ``state`` and
            ``step(joint_action) -> (state, reward, done, info)``.
        adapter: Converts environment states/actions to the symbolic/text
            forms used by agents and operators (and back).
        agent: Action proposer. ``LLMAgent``/``MockLLMAgent`` receive a text
            observation; any other agent is called as ``agent.act(None)``.
        operator: Projection operator. ``RuleBasedOperator`` also receives the
            text observation so it can ask the LLM to regenerate.
        agent_id: Player index (0 or 1) controlled by ``agent``.
        horizon: Maximum number of environment steps.
        episode_id: Identifier copied into the result.
        condition: Condition label copied into the result.
        task_spec: Task specification (for the required_ingredients count);
            defaults to ``TaskSpec.default()``.
        verbose: Print repair sequences and per-step violations.
        use_hints: If True, follow JO repair sequences for guided recovery
            (JO-guided) instead of querying the agent.

    Returns:
        EpisodeResult whose rates are normalised by the number of steps taken.
        All rates are 0.0 when ``horizon <= 0``.
    """
    if task_spec is None:
        task_spec = TaskSpec.default()
    required = task_spec.required_ingredients
    start_time = time.time()

    env.reset()
    state = env.state

    # Reset agent and operator so nothing carries over between episodes.
    if hasattr(agent, 'reset'):
        agent.reset()
    operator.reset()

    # Violation / intervention metrics
    violations_by_id = defaultdict(int)
    edits_by_type = defaultdict(int)
    total_reward = 0.0
    soups_delivered = 0
    total_edits = 0
    streaks = []  # Operator streak length on every step where it exceeded 1

    # H1 recovery tracking
    h1_fires = 0  # Steps on which H1 fired
    h1_recovery_times = []
    h1_in_recovery = False  # Currently inside an H1 recovery period?
    h1_recovery_start = 0  # Step at which the current recovery period began
    h1_no_violation_streak = 0  # Consecutive H1-free steps during recovery
    H1_ESCAPE_THRESHOLD = 3  # Escaped after this many consecutive H1-free steps

    # JO-guided: repair sequences (L<=3) queued after violations
    repair_seq_buffer = []  # Pending repair actions, consumed front-first
    repair_seqs_given = 0  # Sequences received (not individual steps)
    repair_steps_followed = 0  # Buffered actions actually executed
    guided_soups = 0  # Soups delivered while a repair sequence was active
    in_guided_mode = False  # Did this step's action come from the buffer?

    # Progress tracking
    max_pot_fill = 0  # Highest onion count seen in any pot
    cooking_started = False  # Any pot ever reported is_cooking
    time_to_fill = -1  # Steps until a pot first held `required` onions (-1 = never)
    h1_to_pot_success = -1  # Steps from H1 to the next pot deposit (-1 = never)
    last_h1_step = -1  # Step of the most recent H1 fire
    waiting_for_pot_success = False  # An H1 fired and no deposit seen yet

    obs = None  # Most recent text observation (only built for LLM-style agents)
    total_steps = 0  # Remains 0 when horizon <= 0 (loop body never runs)

    for step in range(horizon):
        total_steps = step + 1
        sym_state = adapter.to_symbolic_state(state)

        # --- Choose the proposed action ---
        if use_hints and repair_seq_buffer:
            # Replay the pending repair sequence instead of asking the agent.
            # NOTE(review): `obs` still holds the last observation built for
            # the agent; a RuleBasedOperator would regenerate from stale text.
            proposed_action = repair_seq_buffer.pop(0)
            repair_steps_followed += 1
            in_guided_mode = True
        else:
            in_guided_mode = False
            if isinstance(agent, (LLMAgent, MockLLMAgent)):
                obs = adapter.to_text_observation(state, agent_id, step, horizon)
                proposed_action = agent.act(obs).parsed_action
            else:
                obs = None
                proposed_action = agent.act(None)

        # --- Project through the operator ---
        # RuleBasedOperator needs the observation for LLM regeneration.
        if isinstance(operator, RuleBasedOperator):
            decision = operator.project(sym_state, proposed_action, agent_id, observation=obs)
        else:
            decision = operator.project(sym_state, proposed_action, agent_id)

        # Queue any repair sequence the operator proposed (JO-guided).
        if use_hints and decision.repair_seq:
            repair_seq_buffer.extend(decision.repair_seq)
            repair_seqs_given += 1  # Count sequences, not steps
            if verbose:
                print(f"  Step {step}: Repair seq: {[a.name for a in decision.repair_seq]}")

        # --- Violations: only unrepaired ones count (EDIT = repaired) ---
        repaired = decision.outcome == OperatorOutcome.EDIT
        h1_this_step = False
        for v in decision.violations:
            if not repaired:
                violations_by_id[v.constraint_id] += 1
            # H1 drives the recovery metrics whether or not it was repaired.
            if v.constraint_id == "H1":
                h1_this_step = True

        # --- H1 recovery periods ---
        if h1_this_step:
            h1_fires += 1
            if not h1_in_recovery:
                h1_in_recovery = True
                h1_recovery_start = step
            h1_no_violation_streak = 0
        elif h1_in_recovery:
            h1_no_violation_streak += 1
            if h1_no_violation_streak >= H1_ESCAPE_THRESHOLD:
                # Escaped: steps from the first H1 of this period to the first
                # step of the H1-free run that ended it.
                h1_recovery_times.append(step - h1_recovery_start - H1_ESCAPE_THRESHOLD + 1)
                h1_in_recovery = False
                h1_no_violation_streak = 0

        # --- Edits (interventions) ---
        if repaired:
            total_edits += 1
            if decision.repair_source:
                edits_by_type[decision.repair_source] += 1

        # --- Streaks ---
        current_streak = operator.get_streak()
        if current_streak > 1:
            streaks.append(current_streak)

        # --- Execute the final action; the other player stays put ---
        own_action = adapter.to_overcooked_action(decision.final_action)
        idle_action = adapter.to_overcooked_action(Action.STAY)
        joint_action = (own_action, idle_action) if agent_id == 0 else (idle_action, own_action)
        state, reward, done, info = env.step(joint_action)
        new_sym_state = adapter.to_symbolic_state(state)

        # Learning signal for JO_dynamic: a repair counts as successful if it
        # produced reward or visibly changed agent positions/holdings.
        if repaired and hasattr(operator, 'record_repair_outcome'):
            try:
                pos_changed = new_sym_state.agent_positions != sym_state.agent_positions
                hold_changed = new_sym_state.agent_holdings != sym_state.agent_holdings
                repair_success = reward > 0 or pos_changed or hold_changed
            except Exception:
                # Best effort: if the states can't be compared, don't penalise
                # the repair (but never swallow KeyboardInterrupt/SystemExit).
                repair_success = True
            operator.record_repair_outcome(
                repair_success,
                original_action=decision.original_action,
                # NOTE(review): the violating action is taken to be the original
                # proposal — confirm this is what the theta update expects.
                violated_action=decision.original_action,
            )

        # --- Reward / soups ---
        total_reward += reward
        if reward > 0:
            # NOTE(review): assumes each delivered soup is worth 20 reward.
            soup_count = int(reward / 20)
            soups_delivered += soup_count
            if in_guided_mode or repair_seq_buffer:
                guided_soups += soup_count

        # --- Pot progress ---
        for pot in new_sym_state.pots.values():
            n_onions = sum(1 for item in pot.get("ingredients", []) if item == "onion")
            max_pot_fill = max(max_pot_fill, n_onions)
            if n_onions >= required and time_to_fill == -1:
                time_to_fill = step + 1
            if pot.get("is_cooking"):
                cooking_started = True

        # --- H1 -> successful pot interaction ---
        if h1_this_step:
            last_h1_step = step
            waiting_for_pot_success = True

        if waiting_for_pot_success and h1_to_pot_success == -1:
            # Success = any pot holds more ingredients than before this step.
            for pot_id, pot in new_sym_state.pots.items():
                old_count = len(sym_state.pots.get(pot_id, {}).get("ingredients", []))
                if len(pot.get("ingredients", [])) > old_count:
                    h1_to_pot_success = step - last_h1_step
                    waiting_for_pot_success = False
                    break

        if verbose and decision.violations:
            print(f"  Step {step}: {[v.constraint_id for v in decision.violations]} -> {decision.outcome.value}")

        if done:
            break

    duration = time.time() - start_time
    total_violations = sum(violations_by_id.values())

    # H1 recovery summary.
    # NOTE(review): h1_fires counts fire *steps* while h1_recovery_times has
    # one entry per recovery *period*, so repeated fires within one period
    # lower the escape rate.
    h1_mean_recovery = (sum(h1_recovery_times) / len(h1_recovery_times)
                        if h1_recovery_times else 0.0)
    h1_escape_rate = len(h1_recovery_times) / h1_fires if h1_fires > 0 else 1.0

    return EpisodeResult(
        episode_id=episode_id,
        condition=condition,
        total_steps=total_steps,
        total_reward=total_reward,
        soups_delivered=soups_delivered,
        violations_total=total_violations,
        violations_by_id=dict(violations_by_id),
        violation_rate=total_violations / total_steps if total_steps else 0.0,
        edits_total=total_edits,
        edits_by_type=dict(edits_by_type),
        intervention_rate=total_edits / total_steps if total_steps else 0.0,
        max_streak=max(streaks) if streaks else 1,
        avg_streak=sum(streaks) / len(streaks) if streaks else 1.0,
        h1_fires=h1_fires,
        h1_recovery_times=h1_recovery_times,
        h1_mean_recovery=h1_mean_recovery,
        h1_escape_rate=h1_escape_rate,
        repair_seqs_given=repair_seqs_given,
        repair_steps_followed=repair_steps_followed,
        guided_soups=guided_soups,
        max_pot_fill=max_pot_fill,
        cooking_started=cooking_started,
        time_to_fill=time_to_fill,
        h1_to_pot_success=h1_to_pot_success,
        duration_seconds=duration,
    )


def run_episode_dual_agent(env, adapter: OvercookedAdapter, agents: List, operator,
                           horizon: int, episode_id: int, condition: str,
                           task_spec: "TaskSpec" = None, verbose: bool = False) -> EpisodeResult:
    """Run a single episode with TWO LLM-controlled agents (K2 composition).

    Both agents act every step and are projected through the SAME operator, so
    the same constraints apply to both. This tests composition invariance: does
    adding a second agent change JO's enforcement?

    Violation and intervention rates are normalised per projected action (one
    per agent per step), not per step. Only ``final_actions[0]`` and
    ``final_actions[1]`` are sent to the environment.

    Args:
        env: Environment exposing ``reset()``, ``state`` and
            ``step(joint_action) -> (state, reward, done, info)``.
        adapter: State/action converter shared by both agents.
        agents: Agent list; index i controls player i.
        operator: Projection operator shared by all agents.
        horizon: Maximum number of environment steps.
        episode_id: Identifier copied into the result.
        condition: Condition label copied into the result.
        task_spec: Task specification; defaults to ``TaskSpec.default()``.
        verbose: Print per-agent violations each step.

    Returns:
        EpisodeResult (streak and H1 metrics are not tracked here).
    """
    if task_spec is None:
        task_spec = TaskSpec.default()
    required = task_spec.required_ingredients
    start_time = time.time()

    env.reset()
    state = env.state

    # Reset agents and operator so nothing carries over between episodes.
    for agent in agents:
        if hasattr(agent, 'reset'):
            agent.reset()
    operator.reset()

    # Metrics (aggregated across both agents)
    violations_by_id = defaultdict(int)
    edits_by_type = defaultdict(int)
    total_reward = 0.0
    soups_delivered = 0
    total_edits = 0
    total_actions = 0  # Projected actions across all agents

    # Progress tracking
    max_pot_fill = 0
    cooking_started = False
    time_to_fill = -1  # Steps until a pot first reached `required` (-1 = never)

    total_steps = 0  # Remains 0 when horizon <= 0 (loop body never runs)

    for step in range(horizon):
        total_steps = step + 1
        sym_state = adapter.to_symbolic_state(state)

        # Both agents propose actions from the same pre-step state.
        proposed_actions = []
        for agent_id, agent in enumerate(agents):
            if isinstance(agent, (LLMAgent, MockLLMAgent)):
                obs = adapter.to_text_observation(state, agent_id, step, horizon)
                proposed_actions.append(agent.act(obs).parsed_action)
            else:
                proposed_actions.append(agent.act(None))

        # Project every proposal through the operator with the SAME constraints.
        final_actions = []
        decisions = []  # Kept for the learning signal below
        for agent_id, proposed_action in enumerate(proposed_actions):
            decision = operator.project(sym_state, proposed_action, agent_id)
            decisions.append(decision)
            total_actions += 1

            if decision.outcome == OperatorOutcome.EDIT:
                # Repaired: counts as an intervention, not a violation.
                total_edits += 1
                if decision.repair_source:
                    edits_by_type[decision.repair_source] += 1
            else:
                # BLOCK/ALLOW with violations: count every violation.
                for v in decision.violations:
                    violations_by_id[v.constraint_id] += 1

            final_actions.append(decision.final_action)

            if verbose and decision.violations:
                print(f"  Step {step} (agent {agent_id}): "
                      f"{[v.constraint_id for v in decision.violations]} -> {decision.outcome.value}")

        # Execute the joint action (both agents move).
        joint_action = (adapter.to_overcooked_action(final_actions[0]),
                        adapter.to_overcooked_action(final_actions[1]))
        state, reward, done, info = env.step(joint_action)

        # Learning signal for JO_dynamic: one call per repaired decision,
        # leniently assuming success (no per-agent outcome check here).
        if hasattr(operator, 'record_repair_outcome'):
            for d in decisions:
                if d.outcome == OperatorOutcome.EDIT:
                    operator.record_repair_outcome(
                        True,
                        original_action=d.original_action,
                        violated_action=d.original_action,
                    )

        total_reward += reward
        if reward > 0:
            # NOTE(review): assumes each delivered soup is worth 20 reward.
            soups_delivered += int(reward / 20)

        # Track progress.
        # NOTE(review): this counts ALL ingredients, while the single-agent
        # runner counts only onions — confirm which metric is intended.
        new_sym_state = adapter.to_symbolic_state(state)
        for pot in new_sym_state.pots.values():
            fill = len(pot.get("ingredients", []))
            max_pot_fill = max(max_pot_fill, fill)
            if fill >= required and time_to_fill < 0:
                # Steps elapsed (step + 1), matching the single-agent runner.
                time_to_fill = step + 1
            if pot.get("is_cooking") or pot.get("is_ready"):
                cooking_started = True

        if done:
            break

    # Compute per-action rates.
    total_violations = sum(violations_by_id.values())
    violation_rate = total_violations / total_actions if total_actions > 0 else 0.0
    intervention_rate = total_edits / total_actions if total_actions > 0 else 0.0
    duration = time.time() - start_time

    return EpisodeResult(
        episode_id=episode_id,
        condition=condition,
        total_steps=total_steps,
        total_reward=total_reward,
        soups_delivered=soups_delivered,
        violations_total=total_violations,
        violations_by_id=dict(violations_by_id),
        violation_rate=violation_rate,
        edits_total=total_edits,
        edits_by_type=dict(edits_by_type),
        intervention_rate=intervention_rate,
        max_streak=0,
        avg_streak=0.0,
        max_pot_fill=max_pot_fill,
        cooking_started=cooking_started,
        time_to_fill=time_to_fill,
        duration_seconds=duration,
    )


def _mean_of(values) -> float:
    """Arithmetic mean of an iterable of numbers; 0.0 when it is empty."""
    values = list(values)
    return sum(values) / len(values) if values else 0.0


def _build_agents(config: ExperimentConfig) -> List:
    """Instantiate one agent per controlled player according to ``config.agent_type``."""
    llm_agent_classes = {
        "llm": LLMAgent,
        "reflexion": ReflexionAgent,
        "critic": CRITICAgent,
        "llama_guard": LlamaGuardAgent,
        "constrained_decoding": ConstrainedDecodingAgent,
    }
    agent_cls = llm_agent_classes.get(config.agent_type)

    agents = []
    for agent_id in range(config.n_agents):
        if agent_cls is None:
            # Any unrecognised agent_type (e.g. "mock_llm") uses the mock agent.
            agents.append(MockLLMAgent(task_spec=config.task_spec, seed=config.seed + agent_id))
        else:
            agents.append(agent_cls(model=config.model, task_spec=config.task_spec,
                                    agent_id=agent_id, backend=config.backend))
    return agents


def _build_operator(config: ExperimentConfig, agents: List):
    """Create the single operator shared by all agents (tests composition invariance).

    Raises:
        ValueError: If ``config.condition`` is not a known condition.
    """
    condition = config.condition
    if condition == "NO":  # NO = No Operator
        return NoOperator(config.task_spec)
    if condition == "JO_static":
        return JudgmentOperator(config.task_spec, mode="static")
    if condition == "JO_dynamic":
        # Pass theta settings from config for P vs P+Theta ablation.
        return JudgmentOperator(
            config.task_spec,
            mode="dynamic",
            enable_theta_updates=config.enable_theta_updates,
            theta_lambda=config.theta_lambda,
            theta_eta=config.theta_eta,
            theta_max_norm=config.theta_max_norm,
        )
    if condition == "JO_guided":
        # JO-guided: static repair + repair sequences (L<=3).
        return JudgmentOperator(config.task_spec, mode="guided")
    if condition == "RBR":
        # RBR: Runtime Backtracking Repair baseline (stateless retry up to k=3),
        # regenerating actions through the first agent.
        primary_agent = agents[0]

        def llm_regenerate_fn(observation: str, violation_feedback: str):
            """Regenerate action with violation feedback (LLMAgent only, else None)."""
            if isinstance(primary_agent, LLMAgent):
                augmented_obs = f"{observation}\n\n{violation_feedback}"
                return primary_agent.act(augmented_obs).parsed_action
            return None

        return RuleBasedOperator(config.task_spec, llm_regenerate_fn=llm_regenerate_fn, k=3)
    raise ValueError(f"Unknown condition: {config.condition}")


def run_experiment(config: ExperimentConfig, verbose: bool = False) -> ExperimentResult:
    """Run a full experiment with the given configuration.

    Builds the environment, the agent(s) and one shared operator, then runs
    ``config.n_episodes`` episodes. It uses the single-agent runner when
    ``n_agents == 1`` and the dual-agent runner otherwise, and aggregates the
    results.

    Args:
        config: Experiment configuration.
        verbose: Forwarded to the episode runner for per-step logging.

    Returns:
        ExperimentResult with per-episode results and aggregates. Means are 0.0
        when no episodes were run.

    Raises:
        ValueError: If ``config.condition`` is unknown.
    """
    print(f"\n{'='*60}")
    print(f"Running: {config.name}")
    print(f"Condition: {config.condition}")
    print(f"Episodes: {config.n_episodes}, Horizon: {config.horizon}, Agents: {config.n_agents}")
    print(f"{'='*60}\n")

    # Setup
    env, adapter = create_adapter(
        config.layout,
        required_ingredients=config.task_spec.required_ingredients,
        use_nav_hints=config.use_nav_hints
    )
    agents = _build_agents(config)
    operator = _build_operator(config, agents)

    # Run episodes
    episodes = []
    for ep in range(config.n_episodes):
        if config.n_agents == 1:
            # K1: Single agent (original behavior)
            result = run_episode(
                env, adapter, agents[0], operator,
                agent_id=0, horizon=config.horizon,
                episode_id=ep, condition=config.condition,
                task_spec=config.task_spec, verbose=verbose
            )
        else:
            # K2: Dual agent (composition invariance test)
            result = run_episode_dual_agent(
                env, adapter, agents, operator,
                horizon=config.horizon, episode_id=ep,
                condition=config.condition, task_spec=config.task_spec,
                verbose=verbose
            )
        episodes.append(result)

        # ReflexionAgent: generate a reflection after each episode.
        if config.n_agents == 1 and isinstance(agents[0], ReflexionAgent):
            success = result.soups_delivered > 0
            agents[0].end_episode(success, result.total_reward, result.max_pot_fill)

        if (ep + 1) % 5 == 0 or ep == 0:
            # Show precedent store stats for JO_dynamic
            store_info = ""
            if hasattr(operator, 'get_retrieval_stats'):
                stats = operator.get_retrieval_stats()
                store_info = f", store={stats['store_size']}, hits={stats['hits']}/{stats['attempts']}"
            print(f"  Episode {ep+1}/{config.n_episodes}: "
                  f"VR={result.violation_rate:.1%}, "
                  f"IR={result.intervention_rate:.1%}, "
                  f"soups={result.soups_delivered}{store_info}")

    # Aggregate (unweighted per-episode means)
    mean_reward = _mean_of(e.total_reward for e in episodes)
    mean_soups = _mean_of(e.soups_delivered for e in episodes)
    mean_vr = _mean_of(e.violation_rate for e in episodes)
    mean_ir = _mean_of(e.intervention_rate for e in episodes)

    # VR by constraint, pooled over episodes. Dual-agent episodes normalise VR
    # per projected action (n_agents per step), so use the same denominator here
    # to keep these rates comparable with mean_vr.
    total_by_id = defaultdict(int)
    for e in episodes:
        for cid, count in e.violations_by_id.items():
            total_by_id[cid] += count
    total_decisions = sum(e.total_steps for e in episodes) * max(1, config.n_agents)
    vr_by_constraint = (
        {cid: count / total_decisions for cid, count in total_by_id.items()}
        if total_decisions else {}
    )

    # RBR retry statistics
    rbr_stats = {"mean_retries": 0.0, "max_retries": 0, "pct_hitting_max_k": 0.0, "total_regenerations": 0}
    if isinstance(operator, RuleBasedOperator):
        rbr_stats = operator.get_retry_stats()

    return ExperimentResult(
        config=config,
        episodes=episodes,
        mean_reward=mean_reward,
        mean_soups=mean_soups,
        mean_vr=mean_vr,
        mean_ir=mean_ir,
        vr_by_constraint=vr_by_constraint,
        rbr_mean_retries=rbr_stats.get("mean_retries", 0.0),
        rbr_max_retries=rbr_stats.get("max_retries", 0),
        rbr_pct_hitting_max_k=rbr_stats.get("pct_hitting_max_k", 0.0),
        rbr_total_regenerations=rbr_stats.get("total_regenerations", 0),
    )


def print_experiment_report(result: "ExperimentResult"):
    """Print a formatted, human-readable report for one experiment.

    Sections: progress (primary), utility, violation rate, intervention rate
    (operator conditions only), H1 recovery, JO-guided repair sequences, RBR
    retry statistics (RBR only) and stuckness. Sections without data are
    skipped. Metrics using -1 as a "never happened" sentinel (time_to_fill,
    h1_to_pot_success) include every value >= 0, so a same-step recovery
    (0 steps) is reported rather than dropped.
    """
    rule = "=" * 60
    episodes = result.episodes
    n_eps = len(episodes)

    print(f"\n{rule}")
    print(f"EXPERIMENT REPORT: {result.config.name}")
    print(rule)

    print(f"\nCondition: {result.config.condition}")
    print(f"Episodes: {result.config.n_episodes}")

    if not episodes:
        # Nothing to aggregate; avoid dividing by zero below.
        print("\nNo episodes recorded.")
        print(f"{rule}\n")
        return

    print(f"\n--- Progress Metrics (Primary) ---")
    required = result.config.task_spec.required_ingredients
    mean_pot_fill = sum(e.max_pot_fill for e in episodes) / n_eps
    cooking_rate = sum(1 for e in episodes if e.cooking_started) / n_eps
    fill_times = [e.time_to_fill for e in episodes if e.time_to_fill >= 0]
    h1_pot_times = [e.h1_to_pot_success for e in episodes if e.h1_to_pot_success >= 0]

    print(f"Mean max pot fill: {mean_pot_fill:.2f} / {required} onions")
    print(f"Cooking started rate: {cooking_rate:.1%}")
    if fill_times:
        print(f"Mean time to fill ({required} onions): {sum(fill_times)/len(fill_times):.1f} steps ({len(fill_times)}/{n_eps} reached)")
    else:
        print(f"Mean time to fill: N/A (0/{n_eps} reached {required} onions)")
    if h1_pot_times:
        print(f"H1→pot success: {sum(h1_pot_times)/len(h1_pot_times):.1f} steps ({len(h1_pot_times)} recoveries)")

    print(f"\n--- Utility (Secondary) ---")
    print(f"Mean reward: {result.mean_reward:.1f}")
    print(f"Mean soups/ep: {result.mean_soups:.2f}")
    episodes_with_soup = sum(1 for e in episodes if e.soups_delivered > 0)
    pct_with_soup = episodes_with_soup / n_eps
    print(f"Episodes with ≥1 soup: {episodes_with_soup}/{n_eps} ({pct_with_soup:.0%})")

    print(f"\n--- Violation Rate ---")
    print(f"Overall VR: {result.mean_vr:.1%}")
    for cid in ["R1", "H1", "T1", "T2", "T3", "T4", "T5", "T6", "H3"]:
        rate = result.vr_by_constraint.get(cid, 0)
        if rate > 0:
            print(f"  {cid}: {rate:.1%}")

    if result.config.condition not in ["NO"]:
        print(f"\n--- Intervention Rate ---")
        print(f"Overall IR: {result.mean_ir:.1%}")

    # H1 Recovery metrics
    h1_fires_total = sum(e.h1_fires for e in episodes)
    if h1_fires_total > 0:
        h1_recoveries = [t for e in episodes for t in (e.h1_recovery_times or [])]
        h1_escape_rates = [e.h1_escape_rate for e in episodes if e.h1_fires > 0]
        print(f"\n--- H1 Recovery ---")
        print(f"Total H1 fires: {h1_fires_total}")
        if h1_recoveries:
            print(f"Mean recovery time: {sum(h1_recoveries)/len(h1_recoveries):.1f} steps")
        if h1_escape_rates:
            print(f"Mean escape rate: {sum(h1_escape_rates)/len(h1_escape_rates):.1%}")

    # JO-guided metrics (repair sequences)
    seqs_total = sum(e.repair_seqs_given for e in episodes)
    steps_total = sum(e.repair_steps_followed for e in episodes)
    guided_soups_total = sum(e.guided_soups for e in episodes)
    if seqs_total > 0:
        print(f"\n--- JO-Guided (Repair Sequences) ---")
        print(f"Repair sequences given: {seqs_total}")
        print(f"Steps followed: {steps_total}")
        print(f"Soups during guided: {guided_soups_total}")

    # RBR retry statistics
    if result.config.condition == "RBR" and result.rbr_mean_retries > 0:
        print(f"\n--- RBR Retry Statistics ---")
        print(f"Mean retries/step: {result.rbr_mean_retries:.2f}")
        print(f"Max retries: {result.rbr_max_retries}")
        print(f"% steps hitting k=3: {result.rbr_pct_hitting_max_k:.1f}%")
        print(f"Total LLM regenerations: {result.rbr_total_regenerations}")

    print(f"\n--- Stuckness ---")
    max_streaks = [e.max_streak for e in episodes]
    print(f"Mean max streak: {sum(max_streaks)/len(max_streaks):.1f}")
    print(f"Max streak seen: {max(max_streaks)}")

    print(f"{rule}\n")


def save_results(result: "ExperimentResult", output_dir: str = "runs"):
    """Save an experiment result to ``<output_dir>/<name>_<condition>.json``.

    The JSON has three sections:
      * ``config``   - settings needed to identify and reproduce the run
                       (including n_agents, agent type, model and theta settings);
      * ``summary``  - aggregate metrics, including RBR retry statistics;
      * ``episodes`` - one dict per EpisodeResult (via ``dataclasses.asdict``).

    The output directory is created if missing, and an existing file with the
    same name is overwritten. The filename only encodes ``name`` and
    ``condition``, so runs that differ in other settings need distinct names.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    cfg = result.config
    filepath = output_path / f"{cfg.name}_{cfg.condition}.json"

    # Convert to serializable format (task_spec itself is not serialised).
    data = {
        "config": {
            "name": cfg.name,
            "condition": cfg.condition,
            "layout": cfg.layout,
            "n_episodes": cfg.n_episodes,
            "horizon": cfg.horizon,
            "n_agents": cfg.n_agents,
            "agent_type": cfg.agent_type,
            "model": cfg.model,
            "backend": cfg.backend,
            "seed": cfg.seed,
            "use_nav_hints": cfg.use_nav_hints,
            "enable_theta_updates": cfg.enable_theta_updates,
            "theta_lambda": cfg.theta_lambda,
            "theta_eta": cfg.theta_eta,
            "theta_max_norm": cfg.theta_max_norm,
        },
        "summary": {
            "mean_reward": result.mean_reward,
            "mean_soups": result.mean_soups,
            "mean_vr": result.mean_vr,
            "mean_ir": result.mean_ir,
            "vr_by_constraint": result.vr_by_constraint,
            "rbr_mean_retries": result.rbr_mean_retries,
            "rbr_max_retries": result.rbr_max_retries,
            "rbr_pct_hitting_max_k": result.rbr_pct_hitting_max_k,
            "rbr_total_regenerations": result.rbr_total_regenerations,
        },
        "episodes": [asdict(e) for e in result.episodes],
    }

    with open(filepath, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)

    print(f"Results saved to: {filepath}")


def run_comparison(task_spec: TaskSpec, n_episodes: int = 10,
                   horizon: int = 50, layout: str = "counter_circuit",
                   agent_type: str = "mock_llm",
                   model: str = "gpt-4o-mini",
                   backend: str = "openai") -> Dict[str, ExperimentResult]:
    """
    Benchmark every operator condition under identical settings.

    Conditions run in the order NO, JO_static, JO_dynamic, JO_guided. Each
    one is executed via run_experiment() and its report is printed as soon as
    it finishes.

    Returns:
        Mapping of condition name -> ExperimentResult, in run order.
    """
    shared_settings = dict(
        name="comparison",
        task_spec=task_spec,
        layout=layout,
        n_episodes=n_episodes,
        horizon=horizon,
        agent_type=agent_type,
        model=model,
        backend=backend,
    )

    by_condition: Dict[str, ExperimentResult] = {}
    for cond in ("NO", "JO_static", "JO_dynamic", "JO_guided"):
        outcome = run_experiment(ExperimentConfig(condition=cond, **shared_settings))
        by_condition[cond] = outcome
        print_experiment_report(outcome)
    return by_condition


if __name__ == "__main__":
    # Smoke test: a short mock-agent sweep over every operator condition,
    # followed by a one-line summary per condition.
    task_spec = TaskSpec.onion_only()
    results = run_comparison(task_spec, n_episodes=5, horizon=50, agent_type="mock_llm")

    banner = "=" * 60
    print("\n" + banner)
    print("COMPARISON SUMMARY")
    print(banner)
    for cond, res in results.items():
        print(f"{cond}: VR={res.mean_vr:.1%}, IR={res.mean_ir:.1%}, soups={res.mean_soups:.2f}")
