"""
Overcooked Constraint Definitions and Violation Detector

Design principles:
- All constraints are task-level (environment unchanged)
- Detector is a pure function of (state, action, agent_id, task_spec), plus an optional episode context for T4-T6
- H1 uses state-based predicates, not "step and check"
"""

from dataclasses import dataclass, field
from enum import Enum
from typing import List, Dict, Optional, Any


class Severity(Enum):
    """How strongly a reported violation should be treated.

    The detector in this file reports H1 as WARN and every other constraint
    as BLOCK; ESCALATE is defined but not assigned anywhere in this file.
    """
    WARN = "warn"
    BLOCK = "block"
    ESCALATE = "escalate"


class Action(Enum):
    """Discrete per-agent actions that the detector inspects.

    NOTE(review): the integer values presumably mirror the environment's
    action indices - confirm against the environment wrapper.
    """
    UP = 0
    DOWN = 1
    LEFT = 2
    RIGHT = 3
    INTERACT = 4
    STAY = 5


@dataclass
class TaskSpec:
    """Task-level rule configuration (the environment's physics are untouched).

    Optional constraints are off by default (None / False / 0).
    ``required_ingredients`` is always consulted by the premature-cook
    check (R1).
    """
    # T1: the only ingredient that may be picked up, e.g. "onion" (None = any)
    required_ingredient: Optional[str] = None
    # R1: ingredient count a pot needs before cooking may be started
    required_ingredients: int = 3
    # T2: agent id -> role name, e.g. {0: "cook", 1: "deliver"} (None = no roles)
    agent_roles: Optional[Dict[int, str]] = None
    # T3: ids of the pots that may be used, e.g. ["pot_1"] (None = any pot)
    allowed_pots: Optional[List[str]] = None
    # H3: forbid picking up a plate until some pot is cooking or ready
    enforce_plate_timing: bool = False

    # STRICT constraints for higher VR
    # T4: longest permitted run of STAY actions (0 turns the check off)
    max_consecutive_stays: int = 0
    # T5: the same agent may not act twice in a row
    require_alternating_agents: bool = False
    # T6: per-agent step budget for one episode (0 turns the check off)
    max_steps_per_agent: int = 0
    # H4: must pick up an item before moving away from a dispenser
    # NOTE(review): no check in this file reads this flag - confirm where H4 is enforced
    enforce_pickup_before_move: bool = False
    # T7: both agents must add to the pot
    # NOTE(review): no check in this file reads this flag - confirm where T7 is enforced
    require_both_agents_contribute: bool = False

    @classmethod
    def default(cls) -> "TaskSpec":
        """Baseline spec: every field keeps its default value."""
        baseline = cls()
        return baseline

    @classmethod
    def onion_only(cls, required_ingredients: int = 2, enforce_plate_timing: bool = False) -> "TaskSpec":
        """T1 spec: onion is the only permitted ingredient.

        Args:
            required_ingredients: Onions a pot needs before cooking (default 2,
                lower than the class default of 3 to make progress easier)
            enforce_plate_timing: If True, also switch on the H3 plate-timing rule
        """
        overrides = {
            "required_ingredient": "onion",
            "required_ingredients": required_ingredients,
            "enforce_plate_timing": enforce_plate_timing,
        }
        return cls(**overrides)

    @classmethod
    def role_constrained(cls) -> "TaskSpec":
        """T2 spec: agent 0 is the cook, agent 1 is the deliverer."""
        roles = {0: "cook", 1: "deliver"}
        return cls(agent_roles=roles)

    @classmethod
    def pot_1_only(cls) -> "TaskSpec":
        """T3 spec: only ``pot_1`` may be used."""
        permitted = ["pot_1"]
        return cls(allowed_pots=permitted)


@dataclass
class Violation:
    """A detected constraint violation."""
    constraint_id: str  # Constraint code; this file emits R1, T1-T6, H1 and H3
    severity: Severity  # How strongly the violation should be treated
    evidence: Dict[str, Any]  # Constraint-specific details; keys differ per constraint
    repair: "Repair"  # Suggested corrective action (forward reference: Repair is defined below)


@dataclass
class Repair:
    """A suggested repair action."""
    action: Action  # Action proposed in place of the violating one
    target: Optional[str] = None  # Object the action is aimed at, e.g. "pot_1", "onion_dispenser"
    explanation: str = ""  # Human-readable reason for the suggestion

    # For H1: multiple candidate repairs for JO_dynamic to learn from
    # (default_factory gives every Repair its own list)
    alternatives: List["Repair"] = field(default_factory=list)


@dataclass
class SymbolicState:
    """
    Symbolic snapshot of the kitchen (used instead of the 26-channel grid).
    Keeps constraint checks simple and reproducible.
    """
    # --- Agents -------------------------------------------------------------
    # agent id -> position, e.g. {0: (2, 1), 1: (3, 2)}
    agent_positions: Dict[int, tuple]
    # agent id -> facing direction as an offset, e.g. {0: (0, -1), 1: (1, 0)}
    agent_orientations: Dict[int, tuple]
    # agent id -> held item name, or None when empty-handed, e.g. {0: "onion", 1: None}
    agent_holdings: Dict[int, Optional[str]]

    # --- Pots ---------------------------------------------------------------
    # pot id -> info dict, e.g.
    # {"pot_1": {"pos": (2,0), "ingredients": ["onion", "onion"], "is_cooking": False, "is_ready": False}}
    pots: Dict[str, Dict[str, Any]]

    # --- Layout -------------------------------------------------------------
    # dispenser name -> position, e.g. {"onion_dispenser": (0, 1), "tomato_dispenser": (4, 1)}
    dispensers: Dict[str, tuple]
    # positions of counters that are currently empty
    counters: List[tuple]
    # positions where soup can be served
    serve_locations: List[tuple]

    def get_facing_pos(self, agent_id: int) -> tuple:
        """Return the cell directly in front of the agent (position + orientation)."""
        here = self.agent_positions[agent_id]
        heading = self.agent_orientations[agent_id]
        first = here[0] + heading[0]
        second = here[1] + heading[1]
        return (first, second)

    def get_object_at(self, pos: tuple) -> Optional[str]:
        """Name the object occupying ``pos``.

        Lookup priority: pot id, dispenser name, "serve", "counter".
        Returns None when nothing known sits at the position.
        """
        # Pots are searched before dispensers; both are (name, position) pairs.
        pot_sites = ((name, info["pos"]) for name, info in self.pots.items())
        for sites in (pot_sites, self.dispensers.items()):
            for name, where in sites:
                if where == pos:
                    return name

        if pos in self.serve_locations:
            return "serve"
        return "counter" if pos in self.counters else None


class ViolationDetector:
    """
    Pure function detector: (state, action, task_spec, agent_id) -> List[Violation]

    All constraints are task-level. Environment is unchanged.
    """

    def __init__(self, task_spec: TaskSpec) -> None:
        """Store the task specification that the individual checks read their settings from."""
        self.task_spec = task_spec

    def check(self,
              state: SymbolicState,
              action: Action,
              agent_id: int,
              action_target: Optional[str] = None,
              episode_context: Optional[Dict[str, Any]] = None) -> List[Violation]:
        """
        Run every implemented constraint check against one proposed action.

        Checks run in the fixed order R1, T1, T2, T3, H1, H3, T4, T5, T6 and
        the returned list keeps that order. The H4 and T7 flags on the task
        spec are not evaluated here.

        Args:
            state: Symbolic state observation
            action: Discrete action (UP/DOWN/LEFT/RIGHT/INTERACT/STAY)
            agent_id: Which agent (0 or 1)
            action_target: Optional explicit target (e.g., "pot_1") for INTERACT;
                forwarded to the R1 and T3 checks only
            episode_context: Optional context for the episode-level checks T4-T6
                (step counts, history); treated as empty when omitted

        Returns:
            List of Violation objects (possibly empty), each carrying
            constraint_id, severity, evidence and repair
        """
        context = episode_context or {}

        # Each check yields a Violation or None; the tuple fixes the report order.
        outcomes = (
            # Crisp, state-based rules
            self._check_r1_premature_cook(state, action, agent_id, action_target),
            self._check_t1_wrong_ingredient(state, action, agent_id),
            self._check_t2_role_violation(state, action, agent_id),
            self._check_t3_wrong_pot(state, action, agent_id, action_target),
            # Heuristic rules
            self._check_h1_recovery_choice(state, action, agent_id),
            self._check_h3_premature_plate(state, action, agent_id),
            # Strict, history-based rules
            self._check_t4_consecutive_stays(action, agent_id, context),
            self._check_t5_alternating_agents(agent_id, context),
            self._check_t6_max_steps(agent_id, context),
        )
        return [found for found in outcomes if found]

    # =========================================================================
    # STRICT CONSTRAINTS (T4, T5, T6) for higher baseline VR
    # =========================================================================

    def _check_t4_consecutive_stays(self, action: Action, agent_id: int,
                                     episode_context: Dict[str, Any]) -> Optional[Violation]:
        """T4: cap the number of STAY actions an agent may take in a row.

        Reads ``episode_context["action_history"]``, a sequence of
        (agent_id, action) tuples, oldest first. Other agents' entries are
        skipped when measuring this agent's current STAY streak.

        Returns a BLOCK violation when the proposed STAY would exceed the
        limit, otherwise None (also None when the limit is 0 or less).
        """
        limit = self.task_spec.max_consecutive_stays
        if limit <= 0 or action != Action.STAY:
            return None

        # Walk backwards through the log to measure the agent's trailing STAY run.
        log = episode_context.get("action_history", [])
        streak = 0
        for who, what in reversed(log):
            if who != agent_id:
                continue  # another agent's move does not interrupt the streak
            if what != Action.STAY:
                break  # this agent's latest non-STAY ends the streak
            streak += 1

        if streak < limit:
            return None

        details = {
            "agent": agent_id,
            "consecutive_stays": streak + 1,  # the proposed STAY is counted too
            "max_allowed": limit
        }
        suggestion = Repair(
            action=Action.INTERACT,
            explanation=f"Agent {agent_id} exceeded max consecutive STAY actions ({limit})"
        )
        return Violation(
            constraint_id="T4",
            severity=Severity.BLOCK,
            evidence=details,
            repair=suggestion
        )

    def _check_t5_alternating_agents(self, agent_id: int,
                                      episode_context: Dict[str, Any]) -> Optional[Violation]:
        """T5: an agent may not act right after its own non-STAY action.

        Only the most recent entry of ``episode_context["action_history"]``
        ((agent_id, action) tuples) is inspected. The action currently being
        proposed is not part of this check.

        Returns a BLOCK violation suggesting STAY, or None when the rule is
        disabled, the history is empty, or the last entry does not trigger it.
        """
        if not self.task_spec.require_alternating_agents:
            return None

        log = episode_context.get("action_history", [])
        if not log:
            return None

        prev_agent, prev_action = log[-1]
        repeated_turn = prev_agent == agent_id and prev_action != Action.STAY
        if not repeated_turn:
            return None

        details = {
            "agent": agent_id,
            "last_agent": prev_agent,
            "reason": "consecutive_same_agent"
        }
        suggestion = Repair(
            action=Action.STAY,
            explanation=f"Agent {agent_id} must wait - other agent's turn"
        )
        return Violation(
            constraint_id="T5",
            severity=Severity.BLOCK,
            evidence=details,
            repair=suggestion
        )

    def _check_t6_max_steps(self, agent_id: int,
                            episode_context: Dict[str, Any]) -> Optional[Violation]:
        """T6: enforce a per-agent step budget for the episode.

        Reads the agent's count from ``episode_context["agent_step_counts"]``
        (missing entries count as 0). Once the count has reached the budget,
        any further action is reported as a BLOCK violation with STAY as the
        suggested repair. A budget of 0 or less disables the check.
        """
        budget = self.task_spec.max_steps_per_agent
        if budget > 0:
            step_counts = episode_context.get("agent_step_counts", {})
            used = step_counts.get(agent_id, 0)
            if used >= budget:
                details = {
                    "agent": agent_id,
                    "steps_taken": used,
                    "max_allowed": budget
                }
                suggestion = Repair(
                    action=Action.STAY,
                    explanation=f"Agent {agent_id} exceeded step budget ({budget})"
                )
                return Violation(
                    constraint_id="T6",
                    severity=Severity.BLOCK,
                    evidence=details,
                    repair=suggestion
                )
        return None

    # =========================================================================
    # R1: Premature Cooking
    # =========================================================================

    def _check_r1_premature_cook(self, state: SymbolicState, action: Action,
                                  agent_id: int, action_target: Optional[str]) -> Optional[Violation]:
        """
        R1: starting to cook before the pot holds enough ingredients.

        Fires when an empty-handed agent INTERACTs with a pot that is neither
        cooking nor ready and holds fewer than ``required_ingredients`` items.
        The pot is ``action_target`` when given, otherwise whatever the agent
        is facing.

        Returns a BLOCK violation whose repair points at an ingredient
        dispenser, or None when the rule does not apply.
        """
        # An agent holding something would be adding to the pot, not starting it.
        if action != Action.INTERACT or state.agent_holdings.get(agent_id) is not None:
            return None

        faced_cell = state.get_facing_pos(agent_id)
        pot_id = action_target or state.get_object_at(faced_cell)
        if pot_id is None or not pot_id.startswith("pot"):
            return None

        pot = state.pots.get(pot_id)
        if pot is None:
            return None

        # A pot that is already cooking or finished cannot be started prematurely.
        if pot.get("is_cooking") or pot.get("is_ready"):
            return None

        needed = self.task_spec.required_ingredients
        have = len(pot.get("ingredients", []))
        if have >= needed:
            return None

        details = {
            "pot": pot_id,
            "ingredient_count": have,
            "action": "start_cook",
            "needed": needed
        }
        suggestion = Repair(
            action=Action.INTERACT,
            target=self._find_ingredient_dispenser(state),
            explanation=f"Add more ingredients first (have {have}, need {needed})"
        )
        return Violation(
            constraint_id="R1",
            severity=Severity.BLOCK,
            evidence=details,
            repair=suggestion
        )

    # =========================================================================
    # T1: Wrong Ingredient
    # =========================================================================

    def _check_t1_wrong_ingredient(self, state: SymbolicState, action: Action,
                                    agent_id: int) -> Optional[Violation]:
        """
        T1: Picking wrong ingredient when task specifies ingredient constraint.

        Trigger: an empty-handed INTERACT while facing an INGREDIENT dispenser
        (not the dish dispenser) whose ingredient differs from
        ``task_spec.required_ingredient``.

        The ingredient is derived from the dispenser name by dropping the
        "dispenser" token and any purely numeric tokens, so "onion_dispenser",
        "onion_dispenser_1" and "onion_dispenser_12" all map to "onion".

        Returns:
            A BLOCK violation whose repair targets the required ingredient's
            dispenser, or None when the rule is disabled or not triggered.
        """
        required = self.task_spec.required_ingredient
        if required is None:
            return None

        if action != Action.INTERACT:
            return None

        # Agent must not be holding anything (picking up)
        if state.agent_holdings.get(agent_id) is not None:
            return None

        facing_pos = state.get_facing_pos(agent_id)
        target = state.get_object_at(facing_pos)

        if target is None or "dispenser" not in target:
            return None

        # Only check INGREDIENT dispensers (onion, tomato), not dish dispenser
        if "dish" in target:
            return None  # Dish/plate dispenser is not an ingredient constraint

        # Token-based parsing: the previous chained .replace("_1", "").replace("_2", "")
        # only handled suffixes 1 and 2 and mangled others
        # (e.g. "onion_dispenser_12" became "onion2"), producing false violations.
        ingredient = "_".join(
            part for part in target.split("_")
            if part and part != "dispenser" and not part.isdigit()
        )

        if ingredient == required:
            return None

        # Violation: picking wrong ingredient
        return Violation(
            constraint_id="T1",
            severity=Severity.BLOCK,
            evidence={
                "picked": ingredient,
                "required": required,
                "dispenser": target
            },
            repair=Repair(
                action=Action.INTERACT,
                # Same helper the R1 and H3 repairs use to choose a dispenser
                target=self._find_ingredient_dispenser(state),
                explanation=f"Task requires {required}, not {ingredient}"
            )
        )

    # =========================================================================
    # T2: Role Violation
    # =========================================================================

    def _check_t2_role_violation(self, state: SymbolicState, action: Action,
                                  agent_id: int) -> Optional[Violation]:
        """
        T2: Agent doing action outside their assigned role.

        Roles:
        - "cook": can pick ingredients, add to pot, start cook
        - "deliver": can pick plates, get soup, deliver

        Only INTERACT actions are checked, against the object the agent is
        facing. Agents without an entry in ``task_spec.agent_roles`` and roles
        other than "cook"/"deliver" are never flagged.

        Returns:
            A BLOCK violation with STAY as the repair, or None.
        """
        if self.task_spec.agent_roles is None:
            return None

        role = self.task_spec.agent_roles.get(agent_id)
        if role is None:
            return None

        if action != Action.INTERACT:
            return None

        facing_pos = state.get_facing_pos(agent_id)
        target = state.get_object_at(facing_pos)
        holding = state.agent_holdings.get(agent_id)

        is_dispenser = bool(target) and "dispenser" in target
        # Substring match, as in the deliver branch and the H3 check, so that
        # numbered dish dispensers (e.g. "dish_dispenser_1") are recognised too.
        is_dish_dispenser = is_dispenser and "dish" in target

        violation_action = None

        if role == "cook":
            # Cook should NOT: pick plates, deliver soup
            if is_dish_dispenser and holding is None:
                violation_action = "pick_plate"
            elif target == "serve" and holding == "soup":
                violation_action = "deliver"

        elif role == "deliver":
            # Deliver should NOT: pick ingredients, add to pot, start cook
            if is_dispenser and not is_dish_dispenser and holding is None:
                violation_action = "pick_ingredient"
            elif target and target.startswith("pot"):
                pot = state.pots.get(target, {})
                if holding in ["onion", "tomato"]:
                    violation_action = "add_to_pot"
                elif holding is None and not pot.get("is_cooking") and not pot.get("is_ready"):
                    # Empty-handed INTERACT on an idle pot would start the cook
                    violation_action = "start_cook"

        if violation_action is None:
            return None

        return Violation(
            constraint_id="T2",
            severity=Severity.BLOCK,
            evidence={
                "agent": agent_id,
                "role": role,
                "attempted": violation_action,
                "target": target
            },
            repair=Repair(
                action=Action.STAY,
                explanation=f"Agent {agent_id} has role '{role}', cannot {violation_action}"
            )
        )

    # =========================================================================
    # T3: Wrong Pot
    # =========================================================================

    def _check_t3_wrong_pot(self, state: SymbolicState, action: Action,
                            agent_id: int, action_target: Optional[str]) -> Optional[Violation]:
        """
        T3: Using disallowed pot when task specifies pot constraint.

        Trigger: INTERACT with a pot (``action_target`` when given, otherwise
        the faced object) that is not listed in ``task_spec.allowed_pots``.

        Returns:
            A BLOCK violation, or None when the rule is disabled
            (``allowed_pots`` is None) or not triggered. The repair redirects
            to the first allowed pot; when ``allowed_pots`` is an empty list
            there is no pot to redirect to, so the repair is STAY.
        """
        allowed_pots = self.task_spec.allowed_pots
        if allowed_pots is None:
            return None

        if action != Action.INTERACT:
            return None

        facing_pos = state.get_facing_pos(agent_id)
        target = action_target or state.get_object_at(facing_pos)

        if target is None or not target.startswith("pot"):
            return None

        if target in allowed_pots:
            return None

        # Violation: using wrong pot
        if allowed_pots:
            allowed = allowed_pots[0]  # Default repair to first allowed
            repair = Repair(
                action=Action.INTERACT,
                target=allowed,
                explanation=f"Task requires using {allowed}, not {target}"
            )
        else:
            # An empty list forbids every pot; indexing [0] here used to raise IndexError.
            repair = Repair(
                action=Action.STAY,
                explanation=f"Task allows no pots; cannot use {target}"
            )

        return Violation(
            constraint_id="T3",
            severity=Severity.BLOCK,
            evidence={
                "used_pot": target,
                "allowed_pots": allowed_pots
            },
            repair=repair
        )

    # =========================================================================
    # H1: Recovery Choice (State-Based)
    # =========================================================================

    def _check_h1_recovery_choice(self, state: SymbolicState, action: Action,
                                   agent_id: int) -> Optional[Violation]:
        """
        H1: an INTERACT that is a guaranteed no-op while the agent holds an item.

        STATE-BASED predicates (nothing is simulated):
        - Holding anything + INTERACT at a dispenser -> cannot pick up while holding
        - Holding a non-ingredient + INTERACT at a pot -> cannot be added,
          except a plate at a ready pot (that picks up the soup)

        Counters are not examined by this check.

        Returns a WARN violation whose repair is the first candidate from
        ``_generate_h1_repairs`` (STAY when there is none) with the remaining
        candidates as alternatives for JO_dynamic to learn from; None when
        the action is not a recognised no-op.
        """
        held = state.agent_holdings.get(agent_id)

        # Empty hands or a non-INTERACT action can never match the predicates.
        if held is None or action != Action.INTERACT:
            return None

        faced_cell = state.get_facing_pos(agent_id)
        target = state.get_object_at(faced_cell)

        # A reason of None means "not a no-op".
        reason = None
        if target:
            if "dispenser" in target:
                reason = "cannot_pick_while_holding"

            if target.startswith("pot") and held not in ("onion", "tomato"):
                pot_info = state.pots.get(target, {})
                if held == "plate" and pot_info.get("is_ready"):
                    reason = None  # valid move: plating the finished soup
                else:
                    reason = f"cannot_add_{held}_to_pot"

        if reason is None:
            return None

        # Context-dependent candidates; the first becomes the primary repair.
        candidates = self._generate_h1_repairs(state, agent_id, held)
        if candidates:
            primary, *others = candidates
            primary_action, primary_target = primary.action, primary.target
        else:
            others = []
            primary_action, primary_target = Action.STAY, None

        details = {
            "holding": held,
            "target": target,
            "noop_reason": reason
        }
        suggestion = Repair(
            action=primary_action,
            target=primary_target,
            explanation="Action would no-op; better recovery exists",
            alternatives=others
        )
        return Violation(
            constraint_id="H1",
            severity=Severity.WARN,  # efficiency issue, hence milder than BLOCK
            evidence=details,
            repair=suggestion
        )

    def _generate_h1_repairs(self, state: SymbolicState, agent_id: int,
                              holding: str) -> List[Repair]:
        """
        Build the ordered list of repair candidates for an H1 violation.

        The first entry, when applicable, puts the held item to use
        (ingredient -> first idle pot with room, plate -> first ready pot,
        soup -> serve). A "drop on counter" entry is appended whenever the
        state lists at least one empty counter. The list may be empty.
        JO_dynamic learns which candidate is optimal per context.
        """
        candidates: List[Repair] = []

        if holding in ("onion", "tomato"):
            # First pot that is idle and still below the required ingredient count.
            needed = self.task_spec.required_ingredients
            for pot_name, pot_info in state.pots.items():
                busy = pot_info.get("is_cooking") or pot_info.get("is_ready")
                if busy or len(pot_info.get("ingredients", [])) >= needed:
                    continue
                candidates.append(Repair(
                    action=Action.INTERACT,
                    target=pot_name,
                    explanation=f"Add {holding} to {pot_name}"
                ))
                break
        elif holding == "plate":
            # First pot whose soup is ready to be plated.
            ready_pots = [name for name, info in state.pots.items() if info.get("is_ready")]
            for pot_name in ready_pots[:1]:
                candidates.append(Repair(
                    action=Action.INTERACT,
                    target=pot_name,
                    explanation=f"Get soup from {pot_name}"
                ))
        elif holding == "soup":
            candidates.append(Repair(
                action=Action.INTERACT,
                target="serve",
                explanation="Deliver soup"
            ))

        # Fallback: put the item down, provided an empty counter is listed.
        if state.counters:
            candidates.append(Repair(
                action=Action.INTERACT,
                target="counter",
                explanation="Drop item on counter"
            ))

        return candidates

    # =========================================================================
    # H3: Premature Plate Pickup
    # =========================================================================

    def _check_h3_premature_plate(self, state: SymbolicState, action: Action,
                                   agent_id: int) -> Optional[Violation]:
        """
        H3: no plate pickup until at least one pot is cooking or ready.

        Picking a plate early is allowed by the environment but inefficient;
        this rule creates controlled violation pressure so that JO can steer
        agents back to adding ingredients first.

        Trigger: empty-handed INTERACT while facing an object whose name
        contains "dish", with no pot cooking or ready. Active only when
        ``task_spec.enforce_plate_timing`` is set.

        Returns a BLOCK violation whose repair points at an ingredient
        dispenser, or None.
        """
        if not self.task_spec.enforce_plate_timing or action != Action.INTERACT:
            return None

        # An agent already holding something cannot pick up a plate.
        if state.agent_holdings.get(agent_id) is not None:
            return None

        faced_cell = state.get_facing_pos(agent_id)
        target = state.get_object_at(faced_cell)
        if target is None or "dish" not in target:
            return None

        # One active pot is enough to make the plate pickup legitimate.
        if any(info.get("is_cooking") or info.get("is_ready") for info in state.pots.values()):
            return None

        # Per-pot snapshot recorded as evidence.
        pot_summary = {}
        for pot_name, info in state.pots.items():
            pot_summary[pot_name] = {
                "cooking": info.get("is_cooking"),
                "ready": info.get("is_ready"),
                "ingredients": len(info.get("ingredients", [])),
            }

        details = {
            "action": "pick_plate",
            "reason": "no_pot_cooking_or_ready",
            "pots": pot_summary
        }
        suggestion = Repair(
            action=Action.INTERACT,
            target=self._find_ingredient_dispenser(state),
            explanation="Don't pick plate until pot is cooking. Get more ingredients first."
        )
        return Violation(
            constraint_id="H3",
            severity=Severity.BLOCK,
            evidence=details,
            repair=suggestion
        )

    # =========================================================================
    # Helpers
    # =========================================================================

    def _find_ingredient_dispenser(self, state: SymbolicState) -> Optional[str]:
        """Find appropriate ingredient dispenser for repair."""
        if self.task_spec.required_ingredient:
            return f"{self.task_spec.required_ingredient}_dispenser"
        # Default to onion
        if "onion_dispenser" in state.dispensers:
            return "onion_dispenser"
        return list(state.dispensers.keys())[0] if state.dispensers else None
