"""
Judgment Operator for Overcooked.

JO_static: Deterministic, hand-coded repair rules
JO_guided: Static repair rules plus short repair sequences (L<=3)
JO_dynamic: Learned context-dependent repairs (uses precedent store)
"""

from dataclasses import dataclass, field
from typing import List, Dict, Optional, Any
from enum import Enum
from collections import defaultdict
import sys
from pathlib import Path
import numpy as np
import re

# Add src/jo to path for PrecedentStore
sys.path.insert(0, str(Path(__file__).parent.parent / "src" / "jo"))
from precedent_store import PrecedentStore, RetrievalHit

from .constraints import (
    ViolationDetector, TaskSpec, SymbolicState, Action, Violation, Severity
)


class OperatorOutcome(Enum):
    """Verdict the Judgment Operator reaches for one proposed action."""

    ALLOW = "allow"  # nothing violated; the proposal runs unchanged
    EDIT = "edit"    # a violation was found and a repair replaced the proposal
    BLOCK = "block"  # a violation was found and no repair could be produced


@dataclass
class OperatorDecision:
    """What the Judgment Operator decided for a single proposed action.

    As built by ``JudgmentOperator.project``, ``final_action`` is the action to
    execute: the proposal itself for ALLOW, the repair for EDIT, and
    ``Action.STAY`` for BLOCK.
    """
    outcome: OperatorOutcome
    original_action: Action
    final_action: Action
    violations: List[Violation]
    # Mechanism that produced the repair, e.g. "static_h1" / "static_t4"
    # (hand-coded rules), "precedent" (retrieved), or "block"; None on ALLOW.
    repair_source: Optional[str] = None
    # Human-readable reason for the edit or block.
    repair_explanation: Optional[str] = None
    # Optional short follow-up action sequence suggested after the repair
    # (project() trims it to at most 3 steps), plus its description.
    repair_seq: Optional[List[Action]] = None
    repair_seq_explanation: Optional[str] = None


@dataclass
class EpisodeMetrics:
    """Counters accumulated over one episode (outcome, violations, edits, stuckness)."""

    # --- Environment outcome ---
    total_steps: int = 0
    total_reward: float = 0.0
    soups_delivered: int = 0

    # --- Violations: constraint id -> number of occurrences ---
    violations_by_id: Dict[str, int] = field(default_factory=lambda: defaultdict(int))

    # --- Judgment Operator interventions ---
    # Edit type (e.g. "move_to_pot", "interact_pot") -> count.
    edits_by_type: Dict[str, int] = field(default_factory=lambda: defaultdict(int))
    total_edits: int = 0

    # --- Stuckness: lengths of repeated-action streaks ---
    repeated_action_streaks: List[int] = field(default_factory=list)
    max_streak: int = 0

    def violation_rate(self) -> float:
        """Violations per step; an episode with no steps counts as one step."""
        violation_count = sum(count for count in self.violations_by_id.values())
        steps = self.total_steps if self.total_steps >= 1 else 1
        return violation_count / steps

    def intervention_rate(self) -> float:
        """JO edits per step; an episode with no steps counts as one step."""
        steps = self.total_steps if self.total_steps >= 1 else 1
        return self.total_edits / steps


class JudgmentOperator:
    """
    Base Judgment Operator.

    Checks actions for violations and applies repairs.
    """

    def __init__(self, task_spec: TaskSpec, mode: str = "static",
                 precedent_store: Optional[PrecedentStore] = None,
                 retrieve_threshold: float = 0.55,
                 use_semantic: bool = True,
                 enable_theta_updates: bool = False,
                 theta_lambda: float = 1.0,
                 theta_eta: float = 0.01,
                 theta_max_norm: float = 5.0):
        """Configure the operator.

        Args:
            task_spec: Task constraints; also used to build the ViolationDetector.
            mode: Repair policy. "static" applies hand-coded rules, "guided"
                applies the same rules and always attaches repair sequences
                (L<=3), "dynamic" consults learned precedents first.
            precedent_store: Store to share across episodes. Kept only in
                "dynamic" mode; there a fresh store is built when omitted.
            retrieve_threshold: Minimum retrieval score for a precedent to be used.
            use_semantic: Forwarded to a newly built PrecedentStore (semantic
                similarity; requires sentence-transformers).
            enable_theta_updates: Enable online theta updates (P+Θ ablation).
            theta_lambda: Weight of the theta penalty in scoring.
            theta_eta: Learning rate for theta updates.
            theta_max_norm: Clip bound [-B, B] for theta.
        """
        self.task_spec = task_spec
        self.mode = mode
        self.detector = ViolationDetector(task_spec)
        self.retrieve_threshold = retrieve_threshold

        # Learnable scoring parameters (MVP).
        # theta layout: [f1_support, f2_frequency, f3_confidence, f4_risk]
        self.enable_theta_updates = enable_theta_updates
        self.theta_lambda = theta_lambda
        self.theta_eta = theta_eta
        self.theta_max_norm = theta_max_norm
        self.theta = np.zeros(4, dtype=np.float32)
        self.theta_updates_count = 0

        # Repeated-action (stuckness) tracking, maintained via _update_streak.
        self._last_action: Optional[Action] = None
        self._action_streak: int = 0

        # Only dynamic mode keeps a precedent store.
        if mode != "dynamic":
            store = None
        elif precedent_store is None:
            store = PrecedentStore(
                use_semantic=use_semantic,
                retrieve_threshold=retrieve_threshold,
                max_capacity=200,
            )
        else:
            store = precedent_store
        self.precedent_store = store

        # Learning signals awaiting record_repair_outcome(); a list so that
        # several agents' repairs can be queued.
        self._pending_learnings: List[Dict[str, Any]] = []
        # Retrieval bookkeeping; _retrieval_hit_rate is a raw hit count.
        self._retrieval_hit_rate: int = 0
        self._retrieval_attempts: int = 0

        # Per-episode context passed to the detector for T4/T5/T6.
        self._episode_context: Dict[str, Any] = {
            "action_history": [],                  # (agent_id, action) tuples
            "agent_step_counts": defaultdict(int),  # agent_id -> steps taken
        }

    def reset_episode(self):
        """Reset episode-level tracking for new episode."""
        self._episode_context = {
            "action_history": [],
            "agent_step_counts": defaultdict(int),
        }
        self._last_action = None
        self._action_streak = 0

    def project(self, state: SymbolicState, action: Action, agent_id: int) -> OperatorDecision:
        """Pass a proposed action through the Judgment Operator.

        The action is checked against the task constraints. A clean action is
        allowed unchanged; a violating action is replaced by a repair (EDIT)
        or, if no repair is found, by STAY (BLOCK).

        Args:
            state: Current symbolic state.
            action: Action proposed by the agent.
            agent_id: Index of the proposing agent.

        Returns:
            OperatorDecision with the outcome and the action to execute.
        """
        self._update_streak(action)

        # The detector sees the episode context *before* this step is recorded.
        context = self._episode_context
        violations = self.detector.check(state, action, agent_id, episode_context=context)
        context["action_history"].append((agent_id, action))
        context["agent_step_counts"][agent_id] += 1

        if not violations:
            return OperatorDecision(
                outcome=OperatorOutcome.ALLOW,
                original_action=action,
                final_action=action,
                violations=[],
            )

        (repair_action, repair_source, repair_explanation,
         hints, hint_explanation) = self._get_repair(state, action, agent_id, violations)

        if repair_action is None:
            # No repair available: neutralize the step.
            return OperatorDecision(
                outcome=OperatorOutcome.BLOCK,
                original_action=action,
                final_action=Action.STAY,
                violations=violations,
                repair_source="block",
                repair_explanation="No valid repair found",
            )

        return OperatorDecision(
            outcome=OperatorOutcome.EDIT,
            original_action=action,
            final_action=repair_action,
            violations=violations,
            repair_source=repair_source,
            repair_explanation=repair_explanation,
            repair_seq=hints[:3] if hints else None,  # JO-guided bound: L <= 3
            repair_seq_explanation=hint_explanation or None,
        )

    def _get_repair(self, state: SymbolicState, action: Action,
                    agent_id: int, violations: List[Violation]) -> tuple:
        """Dispatch to the repair strategy selected by ``self.mode``.

        "static" uses the hand-coded rules, "guided" uses the same rules and
        always requests repair sequences (L<=3). Any other mode goes to the
        precedent-based dynamic repair.

        Returns:
            (action, source, explanation, hints, hint_explanation); all None
            when no repair was found.
        """
        mode = self.mode
        if mode not in ("static", "guided"):
            return self._dynamic_repair(state, action, agent_id, violations)
        if mode == "guided":
            return self._static_repair(state, action, agent_id, violations, include_sequences=True)
        return self._static_repair(state, action, agent_id, violations)

    def _static_repair(self, state: SymbolicState, action: Action,
                       agent_id: int, violations: List[Violation],
                       include_sequences: bool = False) -> tuple:
        """Apply the deterministic, hand-coded repair rules.

        Violations are tried in the given order, and the first one whose rule
        yields a repair wins. ``action`` is accepted for signature parity with
        the other strategies but is not consulted.

        Args:
            include_sequences: JO-guided mode; ask every rule for a repair sequence.

        Returns:
            (action, source, explanation, hints, hint_explanation). Rules that
            return only (action, source, explanation) are padded with ``[]``
            and ``""``. Five None values when no rule applies.
        """
        for violation in violations:
            candidate = self._static_repair_for_violation(state, agent_id, violation, include_sequences)
            if candidate is None:
                continue
            # H1 (and guided-mode rules) already carry (hints, hint_explanation).
            if len(candidate) == 5:
                return candidate
            return candidate + ([], "")
        return (None,) * 5

    def _static_repair_for_violation(self, state: SymbolicState,
                                      agent_id: int, v: Violation,
                                      include_sequences: bool = False) -> Optional[tuple]:
        """Look up the hand-coded repair for a single violation.

        Args:
            state: Current symbolic state.
            agent_id: Agent whose proposal violated the constraint.
            v: Violation to repair; dispatched on ``v.constraint_id``.
            include_sequences: JO-guided mode. Rules that support it append
                (hints, hint_explanation) to their result.

        Returns:
            (action, source, explanation), or the 5-tuple with hints appended
            (H1 always returns the 5-tuple when it finds a repair). Returns None
            when no rule exists for the constraint or the rule finds nothing.
        """
        cid = v.constraint_id

        # H1 (holding an item while trying to pick): use the held item.
        # Its helper always supplies its own hints.
        if cid == "H1":
            return self._repair_h1(state, agent_id, v)

        # T2 (role violation) and T3 (wrong pot) never carry a sequence.
        if cid == "T2":
            return self._repair_t2(state, agent_id, v)
        if cid == "T3":
            return self._repair_t3(state, agent_id, v)

        # State-dependent helper + fixed guided sequence.
        helper_rules = {
            # R1: premature cook -> fetch more ingredients first.
            "R1": (self._repair_r1,
                   [Action.LEFT, Action.INTERACT, Action.RIGHT],
                   "Get more ingredients then return to pot"),
            # T1: wrong ingredient -> head for the correct dispenser.
            "T1": (self._repair_t1,
                   [Action.INTERACT, Action.UP, Action.RIGHT],
                   "Get correct ingredient"),
            # H3: premature plate pickup -> get ingredient, go to pot, add it.
            "H3": (self._repair_h3,
                   [Action.INTERACT, Action.RIGHT, Action.INTERACT],
                   "Get ingredient and add to pot first"),
        }
        if cid in helper_rules:
            helper, sequence, sequence_text = helper_rules[cid]
            repair = helper(state, agent_id, v)
            if repair and include_sequences:
                return repair + (sequence, sequence_text)
            return repair

        # Repairs that do not depend on the state at all.
        fixed_rules = {
            # T4: too many consecutive STAYs -> force a move (UP by default).
            "T4": ((Action.UP, "static_t4", "Break STAY streak with movement"),
                   ([Action.UP, Action.RIGHT, Action.INTERACT], "Move and interact")),
            # T5: not this agent's turn -> wait.
            "T5": ((Action.STAY, "static_t5", "Wait for other agent"), ([], "Wait")),
            # T6: agent exceeded its step limit -> stay put.
            "T6": ((Action.STAY, "static_t6", "Step limit reached"), ([], "Rest")),
        }
        if cid in fixed_rules:
            base, guided_extra = fixed_rules[cid]
            return base + guided_extra if include_sequences else base

        return None

    def _repair_h1(self, state: SymbolicState, agent_id: int,
                   v: Violation) -> Optional[tuple]:
        """
        Repair H1: the agent tries to pick something up while already holding an item.

        The repair puts the held item to use, trying in order:
        - onion/tomato: INTERACT with a faced pot that is idle, not full and
          allowed; otherwise step toward the pot from ``_find_valid_pot``.
        - plate: INTERACT with a faced ready pot; otherwise step toward the
          first ready pot for which ``_move_toward`` yields a move.
        - soup: INTERACT when facing "serve"; otherwise step toward the first
          serve location.
        - still unresolved: drop on a faced counter, or step toward the first
          counter.

        Returns:
            (action, "static_h1", explanation, hints, hint_explanation), where
            hints is a short follow-up suggestion (possibly empty); or None.
        """
        holding = state.agent_holdings.get(agent_id)
        pos = state.agent_positions[agent_id]
        orient = state.agent_orientations[agent_id]
        facing_pos = (pos[0] + orient[0], pos[1] + orient[1])
        facing_target = state.get_object_at(facing_pos)

        if holding in ["onion", "tomato"]:
            plan = self._h1_place_ingredient(state, agent_id, pos, holding, facing_target)
        elif holding == "plate":
            plan = self._h1_collect_soup(state, pos, facing_target)
        elif holding == "soup":
            plan = self._h1_deliver_soup(state, pos, facing_target)
        else:
            plan = None

        if plan is None:
            plan = self._h1_drop_on_counter(state, pos, facing_pos)
        if plan is None:
            return None

        action, explanation, hints, hint_explanation = plan
        return (action, "static_h1", explanation, hints, hint_explanation)

    def _h1_place_ingredient(self, state: SymbolicState, agent_id: int, pos: tuple,
                             holding: str, facing_target: Optional[str]) -> Optional[tuple]:
        """H1 plan for a held ingredient: (action, explanation, hints, hint_explanation) or None."""
        if facing_target and facing_target.startswith("pot"):
            pot = state.pots.get(facing_target)
            pot_accepts = (
                pot
                and not pot.get("is_cooking")
                and not pot.get("is_ready")
                and len(pot.get("ingredients", [])) < self.task_spec.required_ingredients
            )
            allowed_pots = self.task_spec.allowed_pots
            if pot_accepts and (allowed_pots is None or facing_target in allowed_pots):
                hints, hint_explanation = self._generate_soup_hints(state, agent_id, facing_target)
                return (Action.INTERACT, f"Place {holding} in {facing_target}",
                        hints, hint_explanation)

        # Not facing a usable pot: walk toward one.
        target_pot = self._find_valid_pot(state)
        if target_pot:
            step = self._move_toward(pos, state.pots[target_pot]["pos"])
            if step:
                return (step, f"Move toward {target_pot}", [Action.INTERACT],
                        f"Then place {holding} in {target_pot}")
        return None

    def _h1_collect_soup(self, state: SymbolicState, pos: tuple,
                         facing_target: Optional[str]) -> Optional[tuple]:
        """H1 plan for a held plate: (action, explanation, hints, hint_explanation) or None."""
        if facing_target and facing_target.startswith("pot"):
            pot = state.pots.get(facing_target)
            if pot and pot.get("is_ready"):
                hints, hint_explanation = [], ""
                if state.serve_locations:
                    # Follow-up: walk to the serve area and deliver.
                    hints = self._path_to_target(pos, state.serve_locations[0])
                    hints.append(Action.INTERACT)
                    hint_explanation = "Then deliver soup to serve area"
                return (Action.INTERACT, "Get soup from pot", hints, hint_explanation)

        for pot_id, pot in state.pots.items():
            if not pot.get("is_ready"):
                continue
            step = self._move_toward(pos, pot["pos"])
            if step:
                return (step, f"Move toward ready {pot_id}", [Action.INTERACT],
                        "Then get soup from pot")
        return None

    def _h1_deliver_soup(self, state: SymbolicState, pos: tuple,
                         facing_target: Optional[str]) -> Optional[tuple]:
        """H1 plan for held soup: (action, explanation, hints, hint_explanation) or None."""
        if facing_target == "serve":
            return (Action.INTERACT, "Deliver soup", [Action.UP, Action.INTERACT],
                    "Then start next soup")
        if state.serve_locations:
            step = self._move_toward(pos, state.serve_locations[0])
            if step:
                return (step, "Move toward serve area", [Action.INTERACT], "Then deliver soup")
        return None

    def _h1_drop_on_counter(self, state: SymbolicState, pos: tuple,
                            facing_pos: tuple) -> Optional[tuple]:
        """H1 last resort: put the item on a counter; plan tuple without hints, or None."""
        if not state.counters:
            return None
        if facing_pos in state.counters:
            return (Action.INTERACT, "Drop on counter", [], "")
        step = self._move_toward(pos, state.counters[0])
        if step:
            return (step, "Move toward counter to drop", [], "")
        return None

    def _generate_soup_hints(self, state: SymbolicState, agent_id: int,
                             pot_id: str) -> tuple:
        """Suggest follow-up steps for after one ingredient is placed into ``pot_id``.

        The hint directions are fixed lists, not derived from the layout.

        Returns:
            (hints, explanation) with at most 3 hints (JO-guided bound).
        """
        required = self.task_spec.required_ingredients
        already_in_pot = len(state.pots.get(pot_id, {}).get("ingredients", []))
        # Ingredients still missing once the held one has been placed.
        still_needed = required - already_in_pot - 1

        if still_needed > 0:
            ingredient = self.task_spec.required_ingredient or "onion"
            hints = [Action.LEFT, Action.INTERACT, Action.RIGHT, Action.INTERACT]
            explanation = f"Get more {ingredient} (need {still_needed} more)"
        elif still_needed == 0:
            # The pot becomes complete: wait for cooking, then fetch a plate.
            hints = [Action.STAY, Action.STAY, Action.LEFT, Action.INTERACT]
            explanation = "Wait for cooking, then get plate"
        else:
            hints = []
            explanation = ""

        return hints[:3], explanation

    def _path_to_target(self, current: tuple, target: tuple) -> List[Action]:
        """Greedily plan up to 4 moves from ``current`` toward ``target``.

        Each step asks ``_move_toward`` for the next move and advances a
        simulated position (UP decreases y, DOWN increases y). Planning stops
        early when ``_move_toward`` returns None.
        """
        step_delta = {
            Action.UP: (0, -1),
            Action.DOWN: (0, 1),
            Action.LEFT: (-1, 0),
            Action.RIGHT: (1, 0),
        }
        planned: List[Action] = []
        here = current
        for _ in range(4):
            step = self._move_toward(here, target)
            if step is None:
                break
            planned.append(step)
            if step in step_delta:
                dx, dy = step_delta[step]
                here = (here[0] + dx, here[1] + dy)
        return planned

    def _repair_r1(self, state: SymbolicState, agent_id: int,
                   v: Violation) -> Optional[tuple]:
        """
        Repair R1: cooking was attempted before the pot had enough ingredients.

        Strategy: gather another required ingredient (default "onion"). INTERACT
        when already adjacent to its dispenser, otherwise step toward it.
        Returns None when the dispenser is missing or no move is available.
        """
        pos = state.agent_positions[agent_id]
        ingredient = self.task_spec.required_ingredient or "onion"
        key = f"{ingredient}_dispenser"
        if key not in state.dispensers:
            return None

        dispenser = state.dispensers[key]
        if self._is_adjacent(pos, dispenser):
            return (Action.INTERACT, "static_r1", f"Pick up {ingredient}")

        step = self._move_toward(pos, dispenser)
        if not step:
            return None
        return (step, "static_r1", f"Move toward {ingredient} dispenser")

    def _repair_t1(self, state: SymbolicState, agent_id: int,
                   v: Violation) -> Optional[tuple]:
        """
        Repair T1: the agent is working with the wrong ingredient.

        Strategy: step toward the dispenser of ``task_spec.required_ingredient``.
        Returns None when the task names no ingredient, the dispenser is
        missing, or ``_move_toward`` yields no move.
        """
        ingredient = self.task_spec.required_ingredient
        if not ingredient:
            return None

        pos = state.agent_positions[agent_id]
        key = f"{ingredient}_dispenser"
        if key not in state.dispensers:
            return None

        step = self._move_toward(pos, state.dispensers[key])
        if not step:
            return None
        return (step, "static_t1", f"Move toward {ingredient} dispenser")

    def _repair_t2(self, state: SymbolicState, agent_id: int,
                   v: Violation) -> Optional[tuple]:
        """
        Repair T2: the agent acted outside its assigned role.

        Strategy: step toward the next role-appropriate target.
        - cook, empty-handed: the required-ingredient dispenser (default "onion").
        - cook, holding anything: the pot chosen by ``_find_valid_pot``.
        - deliver, empty-handed: the dish dispenser.
        - deliver, holding a plate: the first ready pot for which
          ``_move_toward`` yields a move.
        - deliver, holding soup: the first serve location.

        Returns (move, "static_t2", explanation), or None when the agent has no
        role, the role is unrecognized, or no move is available.
        """
        roles = self.task_spec.agent_roles
        role = roles.get(agent_id) if roles else None
        if not role:
            return None

        pos = state.agent_positions[agent_id]
        holding = state.agent_holdings.get(agent_id)

        def step_to(target, explanation):
            # One move toward target, wrapped as a T2 repair (None if no move).
            move = self._move_toward(pos, target)
            return (move, "static_t2", explanation) if move else None

        if role == "cook":
            if holding is None:
                ingredient = self.task_spec.required_ingredient or "onion"
                key = f"{ingredient}_dispenser"
                if key in state.dispensers:
                    return step_to(state.dispensers[key], f"Cook: go get {ingredient}")
                return None
            target_pot = self._find_valid_pot(state)
            if target_pot:
                return step_to(state.pots[target_pot]["pos"], f"Cook: go to {target_pot}")
            return None

        if role != "deliver":
            return None

        if holding is None:
            if "dish_dispenser" in state.dispensers:
                return step_to(state.dispensers["dish_dispenser"], "Deliver: go get plate")
            return None
        if holding == "plate":
            for pot_id, pot in state.pots.items():
                if pot.get("is_ready"):
                    repair = step_to(pot["pos"], f"Deliver: go to ready {pot_id}")
                    if repair is not None:
                        return repair
            return None
        if holding == "soup" and state.serve_locations:
            return step_to(state.serve_locations[0], "Deliver: go to serve area")
        return None

    def _repair_t3(self, state: SymbolicState, agent_id: int,
                   v: Violation) -> Optional[tuple]:
        """
        Repair T3: Wrong pot.

        Strategy: Go to allowed pot instead.
        """
        if not self.task_spec.allowed_pots:
            return None

        pos = state.agent_positions[agent_id]
        allowed_pot = self.task_spec.allowed_pots[0]

        if allowed_pot in state.pots:
            pot_pos = state.pots[allowed_pot]["pos"]
            move = self._move_toward(pos, pot_pos)
            if move:
                return (move, "static_t3", f"Go to {allowed_pot}")

        return None

    def _repair_h3(self, state: SymbolicState, agent_id: int,
                   v: Violation) -> Optional[tuple]:
        """
        Repair H3: a plate was picked up before any soup is underway.

        Strategy: keep making soup progress by fetching the required ingredient
        (default "onion"). INTERACT when adjacent to its dispenser, otherwise
        step toward it. Returns None when the dispenser is missing or no move
        is available.
        """
        pos = state.agent_positions[agent_id]
        ingredient = self.task_spec.required_ingredient or "onion"
        key = f"{ingredient}_dispenser"
        if key not in state.dispensers:
            return None

        dispenser = state.dispensers[key]
        if self._is_adjacent(pos, dispenser):
            return (Action.INTERACT, "static_h3", f"Pick up {ingredient} instead of plate")

        step = self._move_toward(pos, dispenser)
        if step:
            return (step, "static_h3", f"Move toward {ingredient} dispenser (no pot cooking yet)")
        return None

    def _dynamic_repair(self, state: SymbolicState, action: Action,
                        agent_id: int, violations: List[Violation]) -> tuple:
        """
        Dynamic repair: reuse a learned precedent when one matches well enough.

        Steps:
        1. Summarize the state (``_state_to_dict``) and describe the proposal
           and its violations as query text.
        2. If the precedent store is non-empty, retrieve the best match. A hit
           scoring at least ``retrieve_threshold`` whose action parses is used.
        3. Otherwise fall back to ``_static_repair`` (which may carry hints).
        4. Every repair returned (precedent, or non-None static) is queued in
           ``_pending_learnings`` for ``record_repair_outcome``.

        Returns: (action, source, explanation, hints, hint_explanation)
        """
        self._retrieval_attempts += 1

        state_dict = self._state_to_dict(state, agent_id, violations)
        query_parts = [f"proposed={action.name}"]
        query_parts.extend(f"violation={v.constraint_id}" for v in violations)
        action_text = " ".join(query_parts)

        precedent = self._try_precedent_repair(state_dict, action_text, violations)
        if precedent is not None:
            return precedent

        fallback = self._static_repair(state, action, agent_id, violations)
        if fallback[0] is not None:
            self._pending_learnings.append({
                "state_dict": state_dict,
                "violations": violations,
                "repair_action": fallback[0],
                "source": "static",
            })
        return fallback

    def _try_precedent_repair(self, state_dict: Dict[str, Any], action_text: str,
                              violations: List[Violation]) -> Optional[tuple]:
        """Return a 5-tuple repair from the precedent store, or None on a miss.

        ``_retrieval_hit_rate`` is incremented whenever the score clears the
        threshold, even if the stored action then fails to parse. Usable hits
        queue a learning signal.
        """
        store = self.precedent_store
        if not (store and store.size() > 0):
            return None

        hit = store.retrieve(state_dict, action_text)
        if not (hit and hit.score >= self.retrieve_threshold):
            return None

        self._retrieval_hit_rate += 1
        chosen = self._parse_retrieved_action(hit.action)
        if chosen is None:
            return None

        self._pending_learnings.append({
            "state_dict": state_dict,
            "violations": violations,
            "repair_action": chosen,
            "source": "precedent",
            "hit_score": hit.score,
        })
        # Precedents carry no repair sequence.
        return (chosen, "precedent",
                f"Retrieved (score={hit.score:.2f}): {hit.key}", [], "")

    def _state_to_dict(self, state: SymbolicState, agent_id: int,
                       violations: List[Violation]) -> Dict[str, Any]:
        """Convert symbolic state to dict for precedent matching."""
        holding = state.agent_holdings.get(agent_id)
        facing_pos = state.get_facing_pos(agent_id)
        facing = state.get_object_at(facing_pos)

        # Pot summary
        pot_summary = []
        for pot_id, pot in state.pots.items():
            n_ing = len(pot.get("ingredients", []))
            status = "ready" if pot.get("is_ready") else \
                     "cooking" if pot.get("is_cooking") else \
                     f"{n_ing}/3"
            pot_summary.append(f"{pot_id}:{status}")

        return {
            "scenario": "overcooked",
            "site": facing or "none",
            "intent": violations[0].constraint_id if violations else "none",
            "failure_mode": violations[0].evidence.get("noop_reason", "violation") if violations else "none",
            "holding": holding or "none",
            "facing": facing or "none",
            "pots": ",".join(pot_summary),
            "constraint_text": f"violated={[v.constraint_id for v in violations]}",
        }

    def _parse_retrieved_action(self, action_dict: Dict[str, Any]) -> Optional[Action]:
        """Recover an ``Action`` from a retrieved precedent's action payload.

        Tried in order:
        1. ``action_dict["action"]`` as an ``Action`` member name
           (case-insensitive) or as an ``Action`` instance;
        2. the first ``Action`` (in enum order) whose lower-cased name occurs
           as a substring of ``action_dict["text"]``.

        NOTE(review): step 2 is plain substring matching, so a name embedded in
        a longer word also matches.

        Returns None when neither step yields an action.
        """
        if "action" in action_dict:
            stored = action_dict["action"]
            if isinstance(stored, str):
                try:
                    return Action[stored.upper()]
                except KeyError:
                    pass  # unknown name: fall back to the free-text field
            elif isinstance(stored, Action):
                return stored

        haystack = action_dict.get("text", "").lower()
        return next((candidate for candidate in Action if candidate.name.lower() in haystack), None)

    def record_repair_outcome(self, success: bool, original_action: Optional[Action] = None,
                              violated_action: Optional[Action] = None):
        """
        Feed back whether the queued repair(s) worked, then clear the queue.

        For every queued learning signal:
          * pattern statistics are recorded in the precedent store (if any),
            for success and failure alike, keyed by
            ``_get_pattern_id(intent, source)``;
          * on success, repairs that came from the static rules are added to
            the precedent store as new precedents;
          * on success with theta updates enabled and ``original_action``
            given, ``_update_theta`` is called with the repair as x* and
            ``original_action`` as x~.

        Args:
            success: True if the repair led to valid state progression.
            original_action: The originally proposed action that violated (x~).
            violated_action: Accepted for interface compatibility; not used here.
        """
        if not self._pending_learnings:
            return

        for learning in self._pending_learnings:
            summary = learning["state_dict"]
            source = learning["source"]
            pattern_id = self._get_pattern_id(summary.get("intent", "unknown"), source)

            # Statistics are kept for both outcomes.
            if self.precedent_store is not None:
                self.precedent_store.record_pattern_outcome(pattern_id, success)

            # Only successful static repairs become new precedents.
            if success and self.precedent_store is not None and source == "static":
                repaired = learning["repair_action"]
                self.precedent_store.add(
                    site=summary.get("site", "unknown"),
                    intent=summary.get("intent", "unknown"),
                    failure_mode=summary.get("failure_mode", "violation"),
                    approved_action={
                        "action": repaired.name,
                        "text": f"repair={repaired.name}",
                    },
                    bad_action_text=f"violations={[v.constraint_id for v in learning['violations']]}",
                    state=summary,
                )

            # Theta: x* = the repair that worked, x~ = the proposal that violated.
            if success and self.enable_theta_updates and original_action is not None:
                self._update_theta(
                    repair_x_star=learning["repair_action"],
                    violated_x_tilde=original_action,
                    learning=learning,
                )

        self._pending_learnings = []

    # ========== Theta Helper Functions ==========

    def _get_pattern_id(self, failure_mode: str, repair_type: str) -> str:
        """Get canonical pattern identifier for tracking stats."""
        fm = (failure_mode or "unknown").lower().replace(" ", "_")
        rt = (repair_type or "unknown").lower().replace(" ", "_")
        return f"{fm}|{rt}"

    def _levenshtein_distance(self, s1: str, s2: str) -> int:
        """Compute Levenshtein edit distance between two strings."""
        if len(s1) < len(s2):
            s1, s2 = s2, s1
        if len(s2) == 0:
            return len(s1)
        prev_row = list(range(len(s2) + 1))
        for i, c1 in enumerate(s1):
            curr_row = [i + 1]
            for j, c2 in enumerate(s2):
                insertions = prev_row[j + 1] + 1
                deletions = curr_row[j] + 1
                substitutions = prev_row[j] + (c1 != c2)
                curr_row.append(min(insertions, deletions, substitutions))
            prev_row = curr_row
        return prev_row[-1]

    def _compute_edit_distance(self, candidate: Action, original: Action) -> float:
        """
        Compute normalized token-level Levenshtein edit distance.
        Returns value in [0, 1] where 0 = identical, 1 = completely different.
        """
        # For Overcooked actions, use action name as string
        s1 = candidate.name if hasattr(candidate, 'name') else str(candidate)
        s2 = original.name if hasattr(original, 'name') else str(original)

        if s1 == s2:
            return 0.0

        # Normalize by max length
        max_len = max(len(s1), len(s2))
        if max_len == 0:
            return 0.0

        dist = self._levenshtein_distance(s1, s2)
        return min(1.0, dist / max_len)

    def _compute_features(self, action: Action, candidate_meta: Dict[str, Any],
                          violations: List[Violation]) -> np.ndarray:
        """
        Compute feature vector f(x') in R^4 for theta scoring.

        Args:
            action: The candidate action
            candidate_meta: Per-candidate metadata with keys:
                - source: "original" | "precedent" | "static"
                - support: similarity score if from precedent (else 0)
                - pattern_id: pattern identifier for stats lookup
            violations: List of violations for severity computation

        Features:
        - f1: precedent_support (similarity score if from precedent, else 0)
        - f2: log(1 + pattern_count)
        - f3: pattern success rate with Beta(1,1) smoothing
        - f4: severity score (sum of violation severities, normalized)
        """
        # f1: Precedent support (candidate-specific)
        f1 = 0.0
        if candidate_meta.get("source") == "precedent":
            f1 = float(candidate_meta.get("support", 0.0))

        # f2, f3: Pattern statistics
        pattern_id = candidate_meta.get("pattern_id", "unknown|unknown")
        stats = {"count": 0, "success": 0, "fail": 0}
        if self.precedent_store is not None:
            stats = self.precedent_store.get_pattern_stats(pattern_id)

        f2 = np.log(1 + stats.get("count", 0))

        # f3: Success rate with Beta(1,1) smoothing
        success = stats.get("success", 0)
        fail = stats.get("fail", 0)
        f3 = (success + 1) / (success + fail + 2)

        # f4: Severity score (varies based on actual violations)
        f4 = 0.0
        if violations:
            severity_scores = []
            for v in violations:
                if hasattr(v, 'severity'):
                    if v.severity == Severity.BLOCK:
                        severity_scores.append(1.0)
                    elif v.severity == Severity.WARN:
                        severity_scores.append(0.3)
                    else:
                        severity_scores.append(0.1)
            if severity_scores:
                f4 = min(1.0, sum(severity_scores) / len(severity_scores))

        features = np.array([f1, f2, f3, f4], dtype=np.float32)

        # Track feature stats for sanity logging
        if not hasattr(self, '_feature_history'):
            self._feature_history = []
        self._feature_history.append(features.copy())

        return features

    def _compute_score(self, action: Action, original: Action, features: np.ndarray) -> float:
        """Compute score for argmin selection: edit_dist + lambda * theta.dot(features)."""
        edit_dist = self._compute_edit_distance(action, original)
        penalty = self.theta_lambda * np.dot(self.theta, features)
        return edit_dist + penalty

    def _update_theta(self, repair_x_star: Action, violated_x_tilde: Action,
                      learning: Dict[str, Any]) -> None:
        """
        Update theta using hinge ranking loss.

        We want score(x*) < score(x~), i.e., repair beats violated.
        loss = max(0, 1 + score(x*) - score(x~))
        If loss > 0: theta <- theta - eta * lambda * (f* - f~)

        Args:
            repair_x_star: The admissible repair action (x*)
            violated_x_tilde: The original action that violated (x~)
            learning: Learning context with source, violations, state_dict
        """
        violations = learning.get("violations", [])

        # Compute features for repair (x*) - has precedent/static support
        meta_star = {
            "source": learning.get("source", "static"),
            "support": learning.get("hit_score", 0.0),
            "pattern_id": self._get_pattern_id(
                learning["state_dict"].get("intent", "unknown"),
                learning.get("source", "static")
            ),
        }
        f_star = self._compute_features(repair_x_star, meta_star, violations)

        # Compute features for violated (x~) - original action, no support
        meta_tilde = {
            "source": "original",
            "support": 0.0,
            "pattern_id": self._get_pattern_id(
                learning["state_dict"].get("intent", "unknown"),
                "original"
            ),
        }
        f_tilde = self._compute_features(violated_x_tilde, meta_tilde, violations)

        # Compute scores (lower is better)
        s_star = self._compute_score(repair_x_star, violated_x_tilde, f_star)
        s_tilde = self._compute_score(violated_x_tilde, violated_x_tilde, f_tilde)

        # Hinge ranking loss: want s_star + 1 <= s_tilde
        loss = max(0.0, 1.0 + s_star - s_tilde)

        if loss > 0:
            # Gradient: d(loss)/d(theta) = lambda * (f_star - f_tilde)
            # Update: theta <- theta - eta * grad
            grad = self.theta_lambda * (f_star - f_tilde)
            self.theta = self.theta - self.theta_eta * grad
            self.theta = np.clip(self.theta, -self.theta_max_norm, self.theta_max_norm)
            self.theta_updates_count += 1

            # Log update for debugging
            if not hasattr(self, '_theta_update_history'):
                self._theta_update_history = []
            self._theta_update_history.append({
                "loss": loss,
                "f_star": f_star.tolist(),
                "f_tilde": f_tilde.tolist(),
                "theta_after": self.theta.tolist(),
            })

    def get_theta_info(self) -> Dict[str, Any]:
        """
        Snapshot theta, its hyper-parameters and diagnostic statistics.

        Always contains "theta", "theta_updates_count", "enable_theta_updates",
        "theta_lambda" and "theta_eta". Optionally adds:
          - "feature_stats": sample count plus mean/std of each of the four
            feature columns in ``self._feature_history`` (only if non-empty);
          - "pattern_stats_summary": number of patterns and count/success/fail
            totals from the precedent store (only if a store is attached).
        """
        info: Dict[str, Any] = {
            "theta": self.theta.tolist(),
            "theta_updates_count": self.theta_updates_count,
            "enable_theta_updates": self.enable_theta_updates,
            "theta_lambda": self.theta_lambda,
            "theta_eta": self.theta_eta,
        }

        history = getattr(self, '_feature_history', None)
        if history:
            matrix = np.array(history)
            feature_stats: Dict[str, Any] = {"n_samples": len(history)}
            for col, name in enumerate(("support", "frequency", "confidence", "severity")):
                column = matrix[:, col]
                prefix = f"f{col + 1}_{name}"
                feature_stats[prefix + "_mean"] = float(np.mean(column))
                feature_stats[prefix + "_std"] = float(np.std(column))
            info["feature_stats"] = feature_stats

        store = self.precedent_store
        if store is not None:
            per_pattern = list(store.get_all_pattern_stats().values())
            info["pattern_stats_summary"] = {
                "n_patterns": len(per_pattern),
                "total_count": sum(p.get("count", 0) for p in per_pattern),
                "total_success": sum(p.get("success", 0) for p in per_pattern),
                "total_fail": sum(p.get("fail", 0) for p in per_pattern),
            }

        return info

    def print_theta_summary(self) -> None:
        """Write a framed, human-readable theta report to stdout.

        The content comes from ``get_theta_info()``. Theta and the update count
        are always printed. The feature-distribution and pattern-stats sections
        are printed only when those entries are present.
        """
        info = self.get_theta_info()
        rule = "=" * 50

        for line in ("", rule, "THETA SUMMARY", rule):
            print(line)
        print(f"Theta: {info['theta']}")
        print("  [f1=support, f2=frequency, f3=confidence, f4=severity]")
        print(f"Updates: {info['theta_updates_count']}")

        if "feature_stats" in info:
            fs = info["feature_stats"]
            print()
            print(f"Feature Distribution ({fs['n_samples']} samples):")
            for idx, name in enumerate(("support", "frequency", "confidence", "severity"), start=1):
                label = f"f{idx} ({name}):"
                key = f"f{idx}_{name}"
                # Labels are left-aligned to 16 chars so the "mean=" columns line up.
                print(f"  {label:<16} mean={fs[key + '_mean']:.3f}, std={fs[key + '_std']:.3f}")

        if "pattern_stats_summary" in info:
            ps = info["pattern_stats_summary"]
            print()
            print("Pattern Stats:")
            print(f"  Patterns: {ps['n_patterns']}, Count: {ps['total_count']}")
            print(f"  Success: {ps['total_success']}, Fail: {ps['total_fail']}")

        print(rule)
        print()

    def get_retrieval_stats(self) -> Dict[str, Any]:
        """Report precedent-retrieval counters.

        Keys:
          - "attempts": ``self._retrieval_attempts``
          - "hits": ``self._retrieval_hit_rate``. NOTE(review): despite its
            name, this attribute is used here as a hit *count* (it is divided
            by attempts below).
          - "hit_rate": hits / attempts, with attempts floored at 1
          - "store_size": ``precedent_store.size()``, or 0 when the store is
            missing/falsy
        """
        attempts = self._retrieval_attempts
        hits = self._retrieval_hit_rate
        store = self.precedent_store
        return {
            "attempts": attempts,
            "hits": hits,
            "hit_rate": hits / max(1, attempts),
            "store_size": store.size() if store else 0,
        }

    # =========================================================================
    # Helper methods
    # =========================================================================

    def _find_valid_pot(self, state: SymbolicState) -> Optional[str]:
        """Find a pot that can accept ingredients and is allowed by task spec."""
        for pot_id, pot in state.pots.items():
            # Check task constraint
            if self.task_spec.allowed_pots and pot_id not in self.task_spec.allowed_pots:
                continue
            # Check if pot can accept
            if not pot.get("is_cooking") and not pot.get("is_ready"):
                if len(pot.get("ingredients", [])) < self.task_spec.required_ingredients:
                    return pot_id
        return None

    def _is_adjacent(self, pos1: tuple, pos2: tuple) -> bool:
        """Check if two positions are adjacent."""
        dx = abs(pos1[0] - pos2[0])
        dy = abs(pos1[1] - pos2[1])
        return (dx == 1 and dy == 0) or (dx == 0 and dy == 1)

    def _move_toward(self, current: tuple, target: tuple) -> Optional[Action]:
        """Get one-step greedy move toward target."""
        dx = target[0] - current[0]
        dy = target[1] - current[1]

        # Prefer larger delta
        if abs(dx) >= abs(dy):
            if dx > 0:
                return Action.RIGHT
            elif dx < 0:
                return Action.LEFT

        if dy > 0:
            return Action.DOWN
        elif dy < 0:
            return Action.UP

        # Already adjacent or at target
        return None

    def _update_streak(self, action: Action):
        """Track repeated action streaks for stuckness metrics."""
        if action == self._last_action:
            self._action_streak += 1
        else:
            self._action_streak = 1
            self._last_action = action

    def get_streak(self) -> int:
        """Return the length of the current run of identical consecutive actions (0 right after reset())."""
        streak = self._action_streak
        return streak

    def reset(self):
        """Clear per-episode state before a new episode starts.

        Resets the action-streak tracker. Also replaces the episode context, the
        action history plus per-agent step counts used by the T4/T5/T6
        constraints, with a fresh empty one.
        """
        self._last_action, self._action_streak = None, 0
        self._episode_context = dict(
            action_history=[],
            agent_step_counts=defaultdict(int),
        )


class NoOperator:
    """
    Pass-through baseline: never edits an action, only records violations.

    Every proposed action is returned unchanged with outcome ALLOW. The
    violation detector still runs, so violation metrics stay comparable with
    the repairing operators.
    """

    def __init__(self, task_spec: TaskSpec):
        self.task_spec = task_spec
        self.detector = ViolationDetector(task_spec)
        self._last_action = None
        self._action_streak = 0
        # Per-episode context read by the detector's T4/T5/T6 constraints.
        self._episode_context: Dict[str, Any] = self._new_episode_context()

    @staticmethod
    def _new_episode_context() -> Dict[str, Any]:
        """Return an empty episode context: action history plus per-agent step counts."""
        return {
            "action_history": [],
            "agent_step_counts": defaultdict(int),
        }

    def reset_episode(self):
        """Start a new episode: empty the context and clear the streak tracker."""
        self._episode_context = self._new_episode_context()
        self._last_action = None
        self._action_streak = 0

    def project(self, state: SymbolicState, action: Action, agent_id: int) -> OperatorDecision:
        """Return ``action`` unchanged (ALLOW) along with the violations it triggers.

        Violations are checked against the context as it was *before* this
        step. Afterwards the (agent_id, action) pair is appended to the history
        and the agent's step count is incremented.
        """
        self._update_streak(action)

        context = self._episode_context
        detected = self.detector.check(state, action, agent_id, episode_context=context)

        context["action_history"].append((agent_id, action))
        context["agent_step_counts"][agent_id] += 1

        return OperatorDecision(
            outcome=OperatorOutcome.ALLOW,
            original_action=action,
            final_action=action,
            violations=detected,
        )

    def _update_streak(self, action: Action):
        """Extend the run of identical consecutive actions, or start a new one."""
        if not action == self._last_action:
            self._last_action = action
            self._action_streak = 1
        else:
            self._action_streak += 1

    def get_streak(self) -> int:
        """Return the length of the current run of identical consecutive actions."""
        return self._action_streak

    def reset(self):
        """Clear the streak tracker and the T4/T5/T6 episode context."""
        self._last_action = None
        self._action_streak = 0
        self._episode_context = self._new_episode_context()


class RuleBasedOperator:
    """
    Runtime Backtracking Repair (RBR) baseline.

    Runs up to k=3 check-and-repair rounds on violation:
    1. Try rule-based repair (same as JO_static)
    2. If no rule applies: request LLM regeneration with violation feedback
       (not on the last round, and only when a regeneration function and an
       observation are available)
    3. No precedent storage (stateless across steps)

    When neither mechanism can produce a new candidate, the rounds stop early.
    Re-checking an unchanged action would fail in exactly the same way. If the
    candidate still violates after the rounds, it is replaced by
    ``Action.STAY`` (BLOCK).

    This baseline represents the Agent-C/shielding family without
    task-specific solvers - stateless retry with repair.
    """

    def __init__(self, task_spec: TaskSpec, llm_regenerate_fn=None, k: int = 3):
        """
        Args:
            task_spec: Task constraints
            llm_regenerate_fn: Function(observation, violation_feedback) -> Action
                               Called when rule-based repair fails
            k: Maximum check-and-repair rounds per step (default 3)
        """
        self.task_spec = task_spec
        self.detector = ViolationDetector(task_spec)
        self.llm_regenerate_fn = llm_regenerate_fn
        self.k = k

        # Retry statistics
        self.retries_per_step: List[int] = []
        self.steps_hitting_max_k: int = 0
        self.total_regenerations: int = 0
        # LLM regeneration calls that raised or returned None (best-effort path)
        self.failed_regenerations: int = 0

        # Episode context for T4/T5/T6 constraints
        self._episode_context: Dict[str, Any] = {
            "action_history": [],
            "agent_step_counts": defaultdict(int),
        }
        self._last_action: Optional[Action] = None
        self._action_streak: int = 0

    def reset_episode(self):
        """Reset episode-level tracking."""
        self._episode_context = {
            "action_history": [],
            "agent_step_counts": defaultdict(int),
        }
        self._last_action = None
        self._action_streak = 0

    def reset_stats(self):
        """Reset retry statistics (call between experiments)."""
        self.retries_per_step = []
        self.steps_hitting_max_k = 0
        self.total_regenerations = 0
        self.failed_regenerations = 0

    def project(self, state: SymbolicState, action: Action, agent_id: int,
                observation: Optional[str] = None) -> OperatorDecision:
        """
        Project an action through up to ``k`` check-and-repair rounds.

        Each round checks the current candidate. On a violation it first tries
        a static rule. If no rule applies, it asks ``llm_regenerate_fn`` for a
        new action, except on the last round or when the function or the
        observation is missing. If neither mechanism can produce a new
        candidate, the rounds stop early.

        Args:
            state: Current symbolic state
            action: Proposed action from agent
            agent_id: Which agent
            observation: Text observation (needed for LLM regeneration)

        Returns:
            ALLOW if ``action`` itself is clean; EDIT if a repaired candidate is
            clean; otherwise BLOCK with ``Action.STAY`` as the final action.

        Side effects: updates the streak tracker and retry statistics, then
        records the action that will actually be executed in the episode
        context. Every violation check in this step sees the context as it
        was before the step.
        """
        self._update_streak(action)

        current_action = action
        retries = 0
        # Violations of the most recent candidate that failed a check.
        last_violations: List[Violation] = []
        repair_source = None
        repair_explanation = None

        for attempt in range(self.k):
            violations = self.detector.check(
                state, current_action, agent_id,
                episode_context=self._episode_context
            )
            if not violations:
                break

            last_violations = violations
            retries += 1

            # 1) Rule-based static repair; the repaired action is re-checked next round.
            repair = self._static_repair(state, current_action, agent_id, violations)
            if repair is not None:
                current_action = repair
                repair_source = "static_rbr"
                repair_explanation = f"Rule-based repair (attempt {attempt + 1})"
                continue

            # 2) LLM regeneration with violation feedback.
            can_regenerate = attempt < self.k - 1 and self.llm_regenerate_fn and observation
            if not can_regenerate:
                # Nothing can change the candidate: more rounds would re-check
                # the same action and fail identically, inflating retry stats.
                break

            violation_feedback = self._format_violation_feedback(violations)
            try:
                regenerated = self.llm_regenerate_fn(observation, violation_feedback)
            except Exception:
                # Best-effort: a failed LLM call consumes this round; retry next round.
                self.failed_regenerations += 1
                continue
            if regenerated is None:
                self.failed_regenerations += 1
                continue

            current_action = regenerated
            repair_source = "llm_regeneration"
            repair_explanation = f"LLM regeneration (attempt {attempt + 1})"
            self.total_regenerations += 1

        # Record retry statistics
        self.retries_per_step.append(retries)
        if retries >= self.k:
            self.steps_hitting_max_k += 1

        # Final check BEFORE this step is added to the episode context, so
        # history/step-count constraints don't count the action twice. It also
        # covers a repair made in the last round, which the loop never checked.
        final_violations = self.detector.check(
            state, current_action, agent_id,
            episode_context=self._episode_context
        )

        if not final_violations:
            executed_action = current_action
            if current_action == action:
                decision = OperatorDecision(
                    outcome=OperatorOutcome.ALLOW,
                    original_action=action,
                    final_action=current_action,
                    violations=[],
                )
            else:
                decision = OperatorDecision(
                    outcome=OperatorOutcome.EDIT,
                    original_action=action,
                    final_action=current_action,
                    violations=last_violations,
                    repair_source=repair_source,
                    repair_explanation=repair_explanation,
                )
        else:
            # Still violating - fall back to STAY.
            executed_action = Action.STAY
            decision = OperatorDecision(
                outcome=OperatorOutcome.BLOCK,
                original_action=action,
                final_action=Action.STAY,
                violations=final_violations,
                repair_source="block_rbr",
                repair_explanation=f"No valid repair after {retries} attempts",
            )

        # Record the action that is actually executed (STAY on BLOCK).
        self._episode_context["action_history"].append((agent_id, executed_action))
        self._episode_context["agent_step_counts"][agent_id] += 1

        return decision

    def _static_repair(self, state: SymbolicState, action: Action,
                       agent_id: int, violations: List[Violation]) -> Optional[Action]:
        """
        Apply rule-based static repair (same logic as JudgmentOperator._static_repair).

        Returns the repair for the first violation that has a rule, or None if
        no rule applies to any of them.
        """
        for v in violations:
            repair = self._repair_for_violation(state, agent_id, v)
            if repair is not None:
                return repair
        return None

    def _repair_for_violation(self, state: SymbolicState, agent_id: int,
                              v: Violation) -> Optional[Action]:
        """Map one violation to a single corrective action (None if no rule applies)."""
        cid = v.constraint_id

        if cid == "H1":
            # H1: Holding item, trying to pick -> place held item
            return self._repair_h1_simple(state, agent_id)

        if cid in ("R1", "T1", "H3"):
            # R1: premature cook, T1: wrong ingredient, H3: premature plate
            # pickup -> head to the required ingredient's dispenser
            return self._move_toward_dispenser(state, agent_id)

        if cid == "T3":
            # T3: Wrong pot -> go to correct pot
            return self._move_toward_allowed_pot(state, agent_id)

        if cid == "T4":
            # T4: Too many STAYs -> force movement
            return Action.UP

        if cid in ("T5", "T6"):
            # T5/T6: Wrong turn or step limit -> STAY
            return Action.STAY

        return None

    def _repair_h1_simple(self, state: SymbolicState, agent_id: int) -> Optional[Action]:
        """Simple H1 repair: try to use or place the held item based on what is faced.

        Ingredient: INTERACT with a faced pot that is neither cooking nor
        ready, else move toward an allowed pot. Plate: INTERACT with a faced
        ready pot. Soup: INTERACT when facing "serve", else move toward the
        first serve location. Otherwise INTERACT with a faced counter, falling
        back to ``Action.UP``.
        """
        holding = state.agent_holdings.get(agent_id)
        pos = state.agent_positions[agent_id]
        orient = state.agent_orientations.get(agent_id, (0, 1))
        facing_pos = (pos[0] + orient[0], pos[1] + orient[1])
        facing = state.get_object_at(facing_pos)

        if holding in ["onion", "tomato"]:
            # Holding ingredient - place in pot if facing one
            if facing and facing.startswith("pot"):
                pot = state.pots.get(facing)
                if pot and not pot.get("is_cooking") and not pot.get("is_ready"):
                    return Action.INTERACT
            # Move toward pot
            return self._move_toward_any_pot(state, agent_id)

        elif holding == "plate":
            # Holding plate - get soup if pot is ready
            if facing and facing.startswith("pot"):
                pot = state.pots.get(facing)
                if pot and pot.get("is_ready"):
                    return Action.INTERACT

        elif holding == "soup":
            # Holding soup - deliver
            if facing == "serve":
                return Action.INTERACT
            # Move toward serve
            if state.serve_locations:
                return self._move_toward(pos, state.serve_locations[0])

        # Fallback: try to place on counter
        if state.counters and facing_pos in state.counters:
            return Action.INTERACT

        return Action.UP  # Default movement

    def _move_toward_dispenser(self, state: SymbolicState, agent_id: int) -> Optional[Action]:
        """Step toward the required ingredient's dispenser ("onion" if unspecified); None if absent."""
        pos = state.agent_positions[agent_id]
        required = self.task_spec.required_ingredient or "onion"
        dispenser_key = f"{required}_dispenser"

        if dispenser_key in state.dispensers:
            return self._move_toward(pos, state.dispensers[dispenser_key])
        return None

    def _move_toward_allowed_pot(self, state: SymbolicState, agent_id: int) -> Optional[Action]:
        """Step toward the first pot in ``task_spec.allowed_pots``; None if there is none."""
        if not self.task_spec.allowed_pots:
            return None
        pos = state.agent_positions[agent_id]
        allowed = self.task_spec.allowed_pots[0]
        if allowed in state.pots:
            return self._move_toward(pos, state.pots[allowed]["pos"])
        return None

    def _move_toward_any_pot(self, state: SymbolicState, agent_id: int) -> Optional[Action]:
        """Step toward the first permitted pot; None if no pot is permitted.

        An empty or None ``allowed_pots`` means every pot is permitted. This is
        the same convention ``_move_toward_allowed_pot`` uses.
        """
        allowed = self.task_spec.allowed_pots
        pos = state.agent_positions[agent_id]
        for pot_id, pot in state.pots.items():
            if not allowed or pot_id in allowed:
                return self._move_toward(pos, pot["pos"])
        return None

    def _move_toward(self, current: tuple, target: tuple) -> Optional[Action]:
        """One greedy step toward target; ties in |dx| == |dy| go vertical. None at target.

        NOTE(review): JudgmentOperator._move_toward breaks ties horizontally
        (``abs(dx) >= abs(dy)``). Confirm whether this baseline should match it.
        """
        dx = target[0] - current[0]
        dy = target[1] - current[1]

        if abs(dx) > abs(dy):
            return Action.RIGHT if dx > 0 else Action.LEFT
        elif dy != 0:
            return Action.DOWN if dy > 0 else Action.UP
        return None

    def _format_violation_feedback(self, violations: List[Violation]) -> str:
        """Format violations as feedback for LLM regeneration."""
        feedback_parts = ["Your action was invalid. Violations:"]
        for v in violations:
            feedback_parts.append(f"- {v.constraint_id}: {v.description}")
        feedback_parts.append("\nPlease choose a different action.")
        return "\n".join(feedback_parts)

    def _update_streak(self, action: Action):
        """Extend the run of identical consecutive proposed actions, or start a new one."""
        if action == self._last_action:
            self._action_streak += 1
        else:
            self._action_streak = 1
            self._last_action = action

    def get_streak(self) -> int:
        """Return the length of the current run of identical consecutive actions."""
        return self._action_streak

    def get_retry_stats(self) -> Dict[str, Any]:
        """Get retry statistics for reporting (zeros when no step was projected)."""
        if not self.retries_per_step:
            return {
                "mean_retries": 0.0,
                "max_retries": 0,
                "pct_hitting_max_k": 0.0,
                "total_regenerations": 0,
            }
        return {
            "mean_retries": sum(self.retries_per_step) / len(self.retries_per_step),
            "max_retries": max(self.retries_per_step),
            "pct_hitting_max_k": 100.0 * self.steps_hitting_max_k / len(self.retries_per_step),
            "total_regenerations": self.total_regenerations,
        }

    def reset(self):
        """Clear the streak tracker and the T4/T5/T6 episode context."""
        self._last_action = None
        self._action_streak = 0
        self._episode_context = {
            "action_history": [],
            "agent_step_counts": defaultdict(int),
        }
