import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import os
import argparse
from functools import partial
import navix as nx
from typing import Dict, List, Tuple, Callable, Optional, Union, Any
from tqdm import tqdm
from PIL import Image, ImageDraw, ImageFont
import chex

from policies.dfs import jax_dfs_path_planning as jax_dfs
from policies.dfs import DFSState


# Policy-level action constants. Every sub-policy in this file returns one of
# these values; `get_action_index` then translates it into an index of the
# environment's action set (COMPLETE or MINIGRID, see the maps below).
NOOP = 0
ROTATE_CW = 1
ROTATE_CCW = 2
FORWARD = 3
RIGHT = 4
BACKWARD = 5
LEFT = 6
PICKUP = 7
OPEN = 8
DONE = 9

# Human-readable name for each policy-level action constant.
ACTION_NAMES = {
    NOOP: "NOOP",
    ROTATE_CW: "ROTATE_CW",
    ROTATE_CCW: "ROTATE_CCW",
    FORWARD: "FORWARD",
    RIGHT: "RIGHT",
    BACKWARD: "BACKWARD",
    LEFT: "LEFT",
    PICKUP: "PICKUP",
    OPEN: "OPEN",
    DONE: "DONE"
}

# Lookup tables indexed by the policy action (0-9); each value is the index of
# the corresponding action in the target action set. They are jnp arrays so
# they can be indexed with traced values inside jit / lax control flow.

# COMPLETE action set: identity mapping (policy index == environment index).
COMPLETE_ACTION_MAP = jnp.array([
    0,  # NOOP -> 0
    1,  # ROTATE_CW -> 1
    2,  # ROTATE_CCW -> 2
    3,  # FORWARD -> 3
    4,  # RIGHT -> 4
    5,  # BACKWARD -> 5
    6,  # LEFT -> 6
    7,  # PICKUP -> 7
    8,  # OPEN -> 8
    9,  # DONE -> 9
])

# MINIGRID action set (ROTATE_CCW=0, ROTATE_CW=1, FORWARD=2, PICKUP=3,
# DROP=4, TOGGLE=5, DONE=6, per the labels used in run_episode). Policy
# actions with no MiniGrid counterpart (NOOP, RIGHT, BACKWARD, LEFT) fall back
# to FORWARD, and OPEN becomes TOGGLE.
# NOTE(review): NOOP -> FORWARD means a "do nothing" request moves the agent;
# confirm this is intended.
MINIGRID_ACTION_MAP = jnp.array([
    2,  # NOOP -> FORWARD (2)
    1,  # ROTATE_CW -> 1
    0,  # ROTATE_CCW -> 0
    2,  # FORWARD -> 2
    2,  # RIGHT -> FORWARD (2)
    2,  # BACKWARD -> FORWARD (2)
    2,  # LEFT -> FORWARD (2)
    3,  # PICKUP -> 3
    5,  # OPEN -> TOGGLE (5)
    6,  # DONE -> 6
])

# Heading indices (0=east, 1=south, 2=west, 3=north). Positions are
# (row, col), so SOUTH increases the row and EAST increases the column.
# Presumably these match Navix's direction encoding -- confirm if Navix changes.
EAST = 0
SOUTH = 1
WEST = 2
NORTH = 3

# (d_row, d_col) offset of a single step in each heading.
DIRECTION_VECTORS = {
    EAST: (0, 1),   # same row, next column
    SOUTH: (1, 0),  # next row, same column
    WEST: (0, -1),  # same row, previous column
    NORTH: (-1, 0)  # previous row, same column
}

def translate(position, direction):
    """Return the cell one step away from ``position`` in ``direction``.

    Args:
        position: (row, col) pair -- a tuple or an array.
        direction: Heading index, reduced modulo 4 (0=east, 1=south,
            2=west, 3=north). It is converted with ``int``, so it must be a
            concrete value (this helper cannot run on traced inputs).

    Returns:
        jnp array ``[row, col]`` of the neighbouring cell.
    """
    d_row, d_col = DIRECTION_VECTORS[int(direction) % 4]
    return jnp.array([position[0] + d_row, position[1] + d_col])

def get_manhattan_distance(pos1, pos2):
    """Return the L1 (Manhattan) distance between two (row, col) positions.

    Only the first two components of each position are used.
    """
    row_gap = jnp.abs(pos1[0] - pos2[0])
    col_gap = jnp.abs(pos1[1] - pos2[1])
    return row_gap + col_gap

def get_direction_to_target(agent_pos, agent_dir, target_pos, state=None):
    """
    Return the rotation needed to face the target, or -1 if already facing it.

    The desired heading follows the axis with the larger remaining distance
    (rows -> SOUTH/NORTH, columns -> EAST/WEST). On a tie the vertical heading
    is preferred, unless ``state`` is given and the vertical neighbour is a
    wall while the horizontal neighbour is free.

    Parameters:
        agent_pos: The agent's current (row, col) position
        agent_dir: The agent's current heading index
        target_pos: The (row, col) position to head towards
        state: Optional environment state, used only for the tie-break wall check

    Returns:
        -1 when already facing the desired heading, ROTATE_CCW for a
        counter-clockwise quarter turn, otherwise ROTATE_CW (this includes
        the 180-degree case).
    """
    def pick(pred, if_true, if_false):
        # Scalar select built on lax.cond so it also works under tracing.
        return jax.lax.cond(pred, lambda _: if_true, lambda _: if_false, operand=None)

    d_row = target_pos[0] - agent_pos[0]  # > 0: target lies south
    d_col = target_pos[1] - agent_pos[1]  # > 0: target lies east

    def tie_break_heading():
        vertical = pick(d_row > 0, SOUTH, NORTH)
        horizontal = pick(d_col > 0, EAST, WEST)
        if state is None:
            return vertical
        vertical_wall = is_wall_at_position(state, get_next_position(agent_pos, vertical))
        horizontal_wall = is_wall_at_position(state, get_next_position(agent_pos, horizontal))
        # Switch to horizontal only when vertical is blocked and horizontal is open;
        # every other combination keeps the vertical heading.
        use_horizontal = jnp.logical_and(vertical_wall, jnp.logical_not(horizontal_wall))
        return pick(use_horizontal, horizontal, vertical)

    def dominant_heading():
        return pick(
            jnp.abs(d_row) > jnp.abs(d_col),
            pick(d_row > 0, SOUTH, NORTH),
            pick(d_col > 0, EAST, WEST),
        )

    heading = pick(jnp.abs(d_row) == jnp.abs(d_col), tie_break_heading(), dominant_heading())

    turn = (heading - agent_dir) % 4
    return pick(turn == 0, -1, pick(turn == 3, ROTATE_CCW, ROTATE_CW))

def find_entity_position(state, entity_type):
    """Return the position of an ``entity_type`` entity in ``state``, or None.

    If the stored entry exposes ``.position`` directly, that value is returned
    as-is. Otherwise the entry is treated as a sequence and the first
    element's ``.position`` is returned. Returns None when the key is missing
    or the sequence is empty.

    NOTE(review): get_grid indexes ``entities['wall'].position`` as an (n, 2)
    array, so for batched entity containers this returns every position rather
    than the first one -- confirm callers expect that.
    """
    if entity_type not in state.entities:
        return None
    entity = state.entities[entity_type]
    if hasattr(entity, 'position'):
        return entity.position
    if len(entity) > 0:
        return entity[0].position
    return None

def can_move_in_direction(state, agent_pos, direction):
    """Return True if the cell one step from ``agent_pos`` in ``direction`` is enterable.

    A cell is rejected when its raw ``state.grid`` value is non-zero, or when a
    non-player entity with a falsy ``walkable`` flag sits exactly on it.
    Python control flow is used, so the inputs must be concrete (not traced).

    NOTE(review): for entity containers whose ``.position`` is a batch of
    positions (get_grid indexes ``entities['wall'].position`` as (n, 2)),
    ``jnp.array_equal`` sees a shape mismatch and returns False, so such
    entities never block here -- confirm whether that is intended.
    """
    target = translate(agent_pos, direction)

    if state.grid[tuple(target)] != 0:
        return False

    def blocks(entity):
        # `walkable` is only read when the position actually matches.
        return jnp.array_equal(entity.position, target) and not entity.walkable

    for name in state.entities:
        if name == 'player':
            continue
        entity = state.entities[name]
        if hasattr(entity, 'position'):
            if blocks(entity):
                return False
        elif len(entity) > 0:
            if any(hasattr(member, 'position') and blocks(member) for member in entity):
                return False

    return True

def find_path_around_walls(state, agent_pos, agent_dir, target_pos):
    """
    Greedy wall avoidance: turn toward the open neighbouring cell that most
    reduces the Manhattan distance to ``target_pos``.

    Candidate headings are tried in the order EAST, SOUTH, WEST, NORTH, and
    the earliest one wins a tie. Python control flow is used throughout, so
    the inputs must be concrete (not traced).

    Returns:
        None when the agent already faces the chosen heading, otherwise
        ROTATE_CW or ROTATE_CCW (ROTATE_CW also covers the 180-degree case).
        If no neighbouring cell is open, the result of
        ``get_direction_to_target`` is returned instead (which signals
        "already facing" with -1 rather than None).
    """
    start_distance = get_manhattan_distance(agent_pos, target_pos)

    scored = [
        (heading, start_distance - get_manhattan_distance(translate(agent_pos, heading), target_pos))
        for heading in (EAST, SOUTH, WEST, NORTH)
        if can_move_in_direction(state, agent_pos, heading)
    ]
    if not scored:
        return get_direction_to_target(agent_pos, agent_dir, target_pos)

    # max() keeps the first maximal entry -- the same tie-break as a stable
    # descending sort followed by taking the head.
    best_heading = max(scored, key=lambda entry: entry[1])[0]

    turn = (best_heading - agent_dir) % 4
    if turn == 0:
        return None
    return ROTATE_CCW if turn == 3 else ROTATE_CW

def goto_key_policy(state, policy_state):
    """Policy to navigate to and pick up the key.

    When the agent is orthogonally adjacent to the key (Manhattan distance
    <= 1), it turns to face the key and returns PICKUP. The cached path is
    then invalidated so that the next sub-policy plans its own path.
    Otherwise the agent follows a DFS path cached in ``policy_state`` that
    ends next to the key, replanning when that path is exhausted.

    Args:
        state: Environment state.
        policy_state: PolicyState holding the cached path and the index of
            the next waypoint.

    Returns:
        Tuple of (policy action constant, updated PolicyState).
    """
    player = state.entities['player'][0]
    agent_pos = player.position
    agent_dir = player.direction
    key_pos = state.entities['key'][0].position

    adjacent_to_key = get_manhattan_distance(agent_pos, key_pos) <= 1

    def handle_key_pickup(policy_state):
        key_dir = calculate_direction_to_entity(agent_pos, key_pos)
        # -1 means the key is not orthogonally adjacent (e.g. same cell).
        has_valid_dir = key_dir != -1

        def face_and_pickup():
            return jax.lax.cond(
                agent_dir == key_dir,
                lambda: PICKUP,
                lambda: get_rotation_action(agent_dir, key_dir),
            )

        action = jax.lax.cond(
            has_valid_dir,
            face_and_pickup,
            lambda: navigate_with_path(state, policy_state.path, policy_state.current_idx),
        )
        # The door sub-policy runs with this same PolicyState after the pickup.
        # Invalidate the key path (index past its end) so the door sub-policy
        # replans instead of walking the remainder of the key path. This mirrors
        # what goto_door_policy does once it reaches the door.
        new_idx = jnp.where(
            has_valid_dir,
            policy_state.path.path_length + 1,
            policy_state.current_idx,
        )
        return action, policy_state.replace(current_idx=new_idx)

    def not_adjacent_to_key(policy_state):
        # Advance to the next waypoint once the agent has reached the current one.
        policy_state = jax.lax.cond(
            jnp.array_equal(agent_pos, policy_state.path.path[policy_state.current_idx]),
            lambda: policy_state.replace(current_idx=policy_state.current_idx + 1),
            lambda: policy_state,
        )
        # Replan when the cached path is exhausted or invalid. Checking after the
        # advance (as goto_goal_policy does) keeps a cached index < path_length.
        policy_state = jax.lax.cond(
            policy_state.path.path_length > policy_state.current_idx,
            lambda: policy_state,  # use cached path
            lambda: policy_state.replace(
                path=get_dfs_path(state, key_pos, adjacent_only=True),
                current_idx=1,
            ),
        )
        return navigate_with_path(state, policy_state.path, policy_state.current_idx), policy_state

    return jax.lax.cond(
        adjacent_to_key,
        handle_key_pickup,
        not_adjacent_to_key,
        policy_state,
    )

def get_grid(state):
    """Build the binary occupancy grid used for DFS path planning.

    A cell is 1 (blocked) when any of these holds:
      * its raw ``state.grid`` value is negative;
      * a wall entity occupies it;
      * it holds the door while the door is closed.
    All other cells are 0 (free).

    Returns:
        int32 array with the same shape as ``state.grid``.
    """
    occupancy = (state.grid < 0).astype(jnp.int32)

    door = state.entities['door'][0]
    door_row, door_col = door.position[0], door.position[1]
    # A closed door blocks its cell; an open door leaves it unchanged.
    door_value = jnp.where(door.open, occupancy[door_row, door_col], 1)
    occupancy = occupancy.at[door_row, door_col].set(door_value)

    wall_cells = state.entities['wall'].position
    return occupancy.at[wall_cells[:, 0], wall_cells[:, 1]].set(1)

def goto_door_policy(state, policy_state):
    """Policy to navigate to and open the door.

    When orthogonally adjacent to the door (Manhattan distance <= 1), the
    agent turns to face the door and returns OPEN. The cached path is then
    invalidated so the goal sub-policy plans its own path. Otherwise the
    agent follows a DFS path cached in ``policy_state`` that ends next to the
    door, replanning when the path is exhausted.

    Args:
        state: Environment state.
        policy_state: PolicyState holding the cached path and the index of
            the next waypoint.

    Returns:
        Tuple of (policy action constant, updated PolicyState).
    """
    player = state.entities['player'][0]
    agent_pos = player.position
    agent_dir = player.direction
    door_pos = state.entities['door'][0].position

    adjacent_to_door = get_manhattan_distance(agent_pos, door_pos) <= 1

    def handle_door_opening(policy_state):
        # NOTE(review): door_dir is -1 only if the agent stands on the door cell,
        # which is not expected while the door is closed.
        door_dir = calculate_direction_to_entity(agent_pos, door_pos)
        action = jax.lax.cond(
            agent_dir == door_dir,
            lambda _: OPEN,
            lambda _: get_rotation_action(agent_dir, door_dir),
            operand=None
        )
        # Push the index past the end of the path so the next sub-policy replans.
        return action, policy_state.replace(current_idx=policy_state.path.path_length + 1)

    def not_adjacent_to_door(policy_state):
        # Advance to the next waypoint once the agent has reached the current one.
        policy_state = jax.lax.cond(
            jnp.array_equal(agent_pos, policy_state.path.path[policy_state.current_idx]),
            lambda: policy_state.replace(current_idx=policy_state.current_idx + 1),
            lambda: policy_state,
        )
        # Replan when the cached path is exhausted or was left stale by the
        # previous sub-policy. Checking after the advance (as goto_goal_policy
        # does) keeps a cached index < path_length when it is used below.
        policy_state = jax.lax.cond(
            policy_state.path.path_length > policy_state.current_idx,
            lambda: policy_state,  # use cached path
            lambda: policy_state.replace(
                path=get_dfs_path(state, door_pos, adjacent_only=True),
                current_idx=1,
            ),
        )
        return navigate_with_path(state, policy_state.path, policy_state.current_idx), policy_state

    return jax.lax.cond(
        adjacent_to_door,
        handle_door_opening,
        not_adjacent_to_door,
        policy_state,
    )

def goto_goal_policy(state, policy_state):
    """Policy to walk onto the goal cell.

    Once the agent stands on the goal, returns NOOP and sets the waypoint
    index to ``path_length``, marking the cached path as used up. Otherwise
    it follows a cached DFS path to the goal: it advances the waypoint index
    when a waypoint is reached and replans when the path is exhausted.

    Returns:
        Tuple of (policy action constant, updated PolicyState).
    """
    agent_pos = state.entities['player'][0].position
    goal_pos = state.entities['goal'][0].position

    def arrived(ps):
        return NOOP, ps.replace(current_idx=ps.path.path_length)

    def travel(ps):
        reached_waypoint = jnp.array_equal(agent_pos, ps.path.path[ps.current_idx])
        ps = jax.lax.cond(
            reached_waypoint,
            lambda: ps.replace(current_idx=ps.current_idx + 1),
            lambda: ps,
        )
        needs_plan = ps.path.path_length <= ps.current_idx
        ps = jax.lax.cond(
            needs_plan,
            lambda: ps.replace(path=get_dfs_path(state, goal_pos, adjacent_only=False), current_idx=1),
            lambda: ps,
        )
        return navigate_with_path(state, ps.path, ps.current_idx), ps

    return jax.lax.cond(
        jnp.array_equal(agent_pos, goal_pos),
        arrived,
        travel,
        policy_state,
    )

def calculate_direction_to_entity(agent_pos, entity_pos):
    """Return the heading from the agent to an orthogonally adjacent entity.

    Returns SOUTH/NORTH/EAST/WEST when ``entity_pos`` is exactly one cell
    away along a column or row. Returns -1 otherwise: same cell, diagonal, or
    farther away.
    """
    row_step = entity_pos[0] - agent_pos[0]
    col_step = entity_pos[1] - agent_pos[1]
    same_row = row_step == 0
    same_col = col_step == 0

    # Applied lowest-priority first so that later hits win, giving the
    # precedence SOUTH > NORTH > EAST > WEST (the cases are mutually exclusive anyway).
    checks = (
        (jnp.logical_and(same_row, col_step == -1), WEST),
        (jnp.logical_and(same_row, col_step == 1), EAST),
        (jnp.logical_and(same_col, row_step == -1), NORTH),
        (jnp.logical_and(same_col, row_step == 1), SOUTH),
    )
    heading = jnp.asarray(-1)
    for hit, direction in checks:
        heading = jnp.where(hit, direction, heading)
    return heading

def get_rotation_action(current_dir, target_dir):
    """Return the rotation that turns ``current_dir`` toward ``target_dir``.

    Only a counter-clockwise quarter turn ((target - current) % 4 == 3)
    yields ROTATE_CCW. Every other difference yields ROTATE_CW: a clockwise
    quarter turn, a half turn, and even zero. Callers should therefore check
    for "already facing" before calling.
    """
    counter_clockwise = (target_dir - current_dir) % 4 == 3
    return jnp.where(counter_clockwise, ROTATE_CCW, ROTATE_CW)

def is_wall_at_position(state, position):
    """Return whether ``position`` is blocked for the agent.

    A cell counts as blocked when any of these holds:
      * its raw ``state.grid`` value is non-zero;
      * a wall entity (``state.entities['wall']``) occupies it;
      * it holds the door while the door is closed.
    Wall entities are included so this check agrees with ``get_grid``, the
    occupancy grid the DFS planner uses, which marks those cells as obstacles.

    NOTE(review): get_grid treats only *negative* raw grid values as
    obstacles, while this check uses ``!= 0`` -- confirm which encoding is
    right for the raw grid.

    Args:
        state: Environment state.
        position: (row, col) cell to test.

    Returns:
        Scalar boolean jnp array.
    """
    position = jnp.asarray(position)
    grid_obstacle = state.grid[tuple(position)] != 0

    door = state.entities['door'][0]
    door_closed = jnp.logical_not(door.open)
    door_obstacle = jnp.logical_and(door_closed, jnp.array_equal(position, door.position))

    blocked = jnp.logical_or(grid_obstacle, door_obstacle)

    if 'wall' in state.entities:
        # Wall positions are an (n, 2) array of cells (same layout get_grid relies on).
        wall_cells = state.entities['wall'].position
        on_wall = jnp.any(jnp.all(wall_cells == position, axis=-1))
        blocked = jnp.logical_or(blocked, on_wall)

    return blocked

def get_next_position(pos, direction):
    """Return the cell one step from ``pos`` in ``direction``.

    ``direction`` is reduced modulo 4 (0=east, 1=south, 2=west, 3=north) and
    may be a traced value, so this helper works under jit / lax control flow.

    Returns:
        Length-2 jnp array ``[row, col]``.
    """
    step_table = jnp.array([
        [0, 1],   # EAST
        [1, 0],   # SOUTH
        [0, -1],  # WEST
        [-1, 0],  # NORTH
    ])
    offset = step_table[direction % 4]
    return jnp.array([pos[0] + offset[0], pos[1] + offset[1]])

def navigate_with_path(state, dfs_path, path_idx):
    """Return the action that moves the agent toward waypoint ``path_idx``.

    The heading comes from the offset to ``dfs_path.path[path_idx]``. A column
    offset takes precedence over a row offset. A zero offset keeps the current
    heading, which then yields FORWARD. Returns FORWARD when already facing
    the heading, otherwise the rotation from ``get_rotation_action``.
    """
    player = state.entities['player'][0]
    here = player.position
    facing = player.direction

    waypoint = dfs_path.path[path_idx]
    row_delta = waypoint[0] - here[0]
    col_delta = waypoint[1] - here[1]

    # Later assignments override earlier ones, giving the priority
    # EAST > WEST > SOUTH > NORTH > keep current heading.
    wanted = jnp.where(row_delta < 0, NORTH, facing)
    wanted = jnp.where(row_delta > 0, SOUTH, wanted)
    wanted = jnp.where(col_delta < 0, WEST, wanted)
    wanted = jnp.where(col_delta > 0, EAST, wanted)

    return jax.lax.cond(
        facing == wanted,
        lambda _: FORWARD,
        lambda _: get_rotation_action(facing, wanted),
        operand=None
    )

def get_dfs_path(state, target_pos, adjacent_only=False):
    """Plan a path from the agent's position to ``target_pos`` with DFS.

    The search runs on ``get_grid(state)`` (1 = blocked, 0 = free, closed
    door counted as blocked). ``adjacent_only`` is forwarded to ``jax_dfs``.
    The callers in this file pass True for the key and the door, and those
    paths are presumably meant to stop next to the target (confirm against
    ``jax_dfs``).

    Returns:
        The ``DFSState`` produced by ``jax_dfs``.
    """
    start = state.entities['player'][0].position
    occupancy = get_grid(state)
    return jax_dfs(occupancy, start, target_pos, adjacent_only=adjacent_only)

def navigate_toward_target(state, target_pos):
    """Return one action that moves the agent toward ``target_pos``.

    Uses ``get_direction_to_target`` (wall-aware on ties) to decide whether
    to turn. When already facing the chosen heading, returns FORWARD if the
    cell ahead is free according to ``is_wall_at_position``, else ROTATE_CW.
    """
    player = state.entities['player'][0]
    here, heading = player.position, player.direction

    turn = get_direction_to_target(here, heading, target_pos, state)

    ahead_blocked = is_wall_at_position(state, get_next_position(here, heading))
    straight_on = jnp.where(ahead_blocked, ROTATE_CW, FORWARD)

    # -1 from get_direction_to_target means "already facing the right way".
    return jnp.where(turn == -1, straight_on, turn)

@chex.dataclass
class PolicyState:
    """State threaded through the DoorKey policy between environment steps.

    Each sub-policy returns an updated copy alongside its action. The field
    order defines the constructor/pytree layout.
    """
    path : DFSState  # cached DFS plan for the current sub-goal; path.path[i] is the i-th (row, col) waypoint
    current_idx : int  # index of the next waypoint; values >= path.path_length make the sub-policies replan
    action_set_id : int = 0  # 0 for COMPLETE, 1 for MINIGRID

def get_action_set_from_id(action_set_id):
    """Return the action-set name for an action-set ID.

    Args:
        action_set_id: 0 selects COMPLETE; any other value selects MINIGRID.

    Returns:
        "COMPLETE" or "MINIGRID".
    """
    if action_set_id == 0:
        return "COMPLETE"
    return "MINIGRID"

def get_action_index(action_name, action_set_id=0):
    """Translate a policy action constant into an index of the target action set.

    Args:
        action_name: Policy action constant (0-9), or -1 as a "no action" signal.
        action_set_id: 0 for COMPLETE (identity mapping); any other value
            selects MINIGRID (looked up in MINIGRID_ACTION_MAP).

    Returns:
        The mapped action index, or -1 when ``action_name`` is -1.
    """
    def complete_index(_):
        return action_name

    def minigrid_index(_):
        return MINIGRID_ACTION_MAP[action_name]

    mapped = jax.lax.cond(action_set_id == 0, complete_index, minigrid_index, operand=None)
    # Pass the -1 sentinel through untouched regardless of the action set.
    return jnp.where(action_name == -1, -1, mapped)

def doorkey_policy(rng, state, policy_state):
    """
    Scripted DoorKey policy: fetch the key, open the door, then walk to the goal.

    Sub-policy selection, where a non-empty pocket counts as holding the key:
      * pocket empty               -> goto_key_policy
      * pocket full, door closed   -> goto_door_policy
      * pocket full, door open     -> goto_goal_policy

    Args:
        rng: Random key (not used; the policy is deterministic)
        state: Environment state
        policy_state: Policy state; its action_set_id selects the output action set

    Returns:
        Tuple of (action_index, updated_policy_state)
    """
    holding_key = state.entities['player'][0].pocket != nx.components.EMPTY_POCKET_ID
    door_is_open = state.entities['door'][0].open

    def with_key(ps):
        return jax.lax.cond(
            door_is_open,
            lambda p: goto_goal_policy(state, p),
            lambda p: goto_door_policy(state, p),
            ps,
        )

    def without_key(ps):
        return goto_key_policy(state, ps)

    semantic_action, next_policy_state = jax.lax.cond(
        holding_key, with_key, without_key, policy_state
    )

    # Map the semantic action onto the configured action set, using the
    # action_set_id of the incoming policy state.
    return get_action_index(semantic_action, policy_state.action_set_id), next_policy_state

def run_episode(env_name="DoorKey-8x8-v0", max_steps=100, seed=42, action_set="COMPLETE"):
    """Run an episode with the doorkey policy and return frames for visualization.

    The rollout runs entirely inside one ``jax.lax.while_loop``. The frames
    are labelled afterwards on the host.

    Args:
        env_name: Name of the environment to run (prefixed with "Navix-")
        max_steps: Maximum number of steps to run; must be at least 1
        seed: Random seed
        action_set: Which action set to use ("COMPLETE"; any other value
            selects MINIGRID)

    Returns:
        Tuple of (frames, total_reward). ``frames`` holds one labelled image
        per observation, including the initial one.

    Raises:
        ValueError: If ``max_steps`` is smaller than 1.
    """
    # The loop below always executes at least one step and writes observation
    # `step + 1`. With max_steps < 1 the buffers are too small, so steps are
    # silently dropped and the post-processing indexes past their end.
    if max_steps < 1:
        raise ValueError(f"max_steps must be at least 1, got {max_steps}")

    # Pick the environment action set and the matching labels for the frames.
    if action_set == "COMPLETE":
        env_action_set = nx.actions.COMPLETE_ACTION_SET
        action_set_id = 0
        # Policy constants map 1:1 onto the COMPLETE action set.
        action_names = dict(ACTION_NAMES)
    else:
        env_action_set = nx.actions.MINIGRID_ACTION_SET
        action_set_id = 1
        action_names = {
            0: "ROTATE_CCW",
            1: "ROTATE_CW",
            2: "FORWARD",
            3: "PICKUP",
            4: "DROP",
            5: "TOGGLE",
            6: "DONE"
        }

    env = nx.make(
        f"Navix-{env_name}",
        observation_fn=nx.observations.rgb,
        action_set=env_action_set
    )

    rng = jax.random.key(seed)
    rng, key_reset = jax.random.split(rng)
    initial_timestep = env.reset(key_reset)

    # Placeholder path with path_length 0. The sub-policies treat
    # `path_length > current_idx` as "cached path still valid", so the first
    # policy call always plans a fresh path.
    dummy_path = get_dfs_path(
        initial_timestep.state,
        initial_timestep.state.entities['player'][0].position,
    )
    dummy_path = dummy_path.replace(path_length=0)
    initial_policy_state = PolicyState(path=dummy_path, current_idx=0, action_set_id=action_set_id)

    def jax_rollout(env, rng, max_steps, initial_timestep, initial_policy_state):
        # Carry: (timestep, policy_state, rng, actions, rewards, observations, done, step_count).
        # Fixed-shape buffers are required inside while_loop.
        actions = jnp.zeros(max_steps, dtype=jnp.int32)
        rewards = jnp.zeros(max_steps, dtype=jnp.float32)
        observations = jnp.zeros((max_steps + 1, *initial_timestep.observation.shape), dtype=initial_timestep.observation.dtype)
        observations = observations.at[0].set(initial_timestep.observation)

        init_state = (
            initial_timestep,          # timestep
            initial_policy_state,      # policy_state
            rng,                       # rng
            actions,                   # actions
            rewards,                   # rewards
            observations,              # observations
            False,                     # done flag
            0                          # step counter
        )

        def _step(state):
            timestep, policy_state, rng, actions, rewards, observations, done, step_count = state

            rng, key_action = jax.random.split(rng)
            action, new_policy_state = doorkey_policy(key_action, timestep.state, policy_state)

            actions = actions.at[step_count].set(action)
            new_timestep = env.step(timestep, action)
            rewards = rewards.at[step_count].set(new_timestep.reward)
            observations = observations.at[step_count + 1].set(new_timestep.observation)

            # Stop on episode end or once the step budget is used up.
            new_done = new_timestep.is_done() | (step_count >= max_steps - 1)
            new_step_count = step_count + 1

            return (new_timestep, new_policy_state, rng, actions, rewards, observations, new_done, new_step_count)

        return jax.lax.while_loop(
            lambda state: ~state[6],  # continue until done=True
            _step,
            init_state
        )

    # The progress bar is visual feedback only. The context manager closes it
    # even if the rollout raises.
    with tqdm(total=max_steps, desc=f"Steps in {env_name}") as pbar:
        final_state = jax_rollout(env, rng, max_steps, initial_timestep, initial_policy_state)
        _, _, _, actions, rewards, observations, _, step_count = final_state

        np_actions = np.array(actions)
        np_rewards = np.array(rewards)
        np_observations = np.array(observations)
        total_steps = int(step_count)
        pbar.update(total_steps)

    # Label every observation; frame `step` shows the result of action `step - 1`.
    frames = []
    total_reward = 0
    for step in range(total_steps + 1):  # +1 to include the final observation
        obs = np_observations[step]
        if step == 0:
            labeled_frame = add_action_label(obs, None, reward=0.0, action_names=action_names)
        else:
            action_int = int(np_actions[step - 1])
            total_reward += np_rewards[step - 1]
            is_final = step == total_steps
            labeled_frame = add_action_label(obs, action_int, reward=total_reward, is_final=is_final, action_names=action_names)
        frames.append(labeled_frame)

    return frames, float(np.sum(np_rewards[:total_steps]))

def add_action_label(frame, action, reward=0.0, is_final=False, action_names=None):
    """Return ``frame`` with a caption strip appended below it.

    A 30-pixel black strip is added under the image, and a centred white
    caption with a black outline is drawn into it:

    * ``"INITIAL STATE | Reward: <r>"`` when ``action`` is ``None``;
    * ``"Action: FINAL STATE | Reward: <r>"`` when ``is_final`` is true;
    * ``"Action: <name> | Reward: <r>"`` otherwise. ``<name>`` is looked up
      in ``action_names`` and falls back to ``"Action <id>"``.

    Args:
        frame: Image as a NumPy or JAX array. RGB ``(H, W, 3)`` is the expected
            layout. Grayscale ``(H, W)`` / ``(H, W, 1)`` input is replicated to
            three channels, and an alpha channel ``(H, W, 4)`` is dropped.
            Floating-point frames are treated as ``[0, 1]`` and scaled to
            ``[0, 255]``. All values are clipped to the uint8 range before
            casting.
        action: Integer action id, or ``None`` for the initial frame.
        reward: Value shown after ``Reward:``, formatted with 2 decimals.
        is_final: When true, show ``FINAL STATE`` instead of the action name.
        action_names: Optional mapping from action id to display name.

    Returns:
        A new ``uint8`` NumPy array of shape ``(H + 30, W, 3)``.
    """
    padding = 30  # height of the caption strip in pixels

    # np.asarray converts NumPy and JAX arrays alike. Current JAX arrays no
    # longer expose ``device_buffer``, so probing for it misses them. The
    # caller's frame is never written to: the pixels are copied into a fresh
    # canvas below.
    image = np.asarray(frame)

    # Normalise the channel layout to RGB.
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[-1] == 1:
        image = np.repeat(image, 3, axis=-1)
    elif image.shape[-1] == 4:
        image = image[..., :3]

    # Normalise the value range to uint8. Clip before casting so that
    # out-of-range values saturate instead of wrapping around.
    if np.issubdtype(image.dtype, np.floating):
        image = np.clip(image * 255, 0, 255).astype(np.uint8)
    elif image.dtype != np.uint8:
        image = np.clip(image, 0, 255).astype(np.uint8)

    height, width = image.shape[:2]
    canvas = np.zeros((height + padding, width, 3), dtype=np.uint8)
    canvas[:height] = image

    img = Image.fromarray(canvas)
    draw = ImageDraw.Draw(img)

    try:
        font = ImageFont.truetype("arial.ttf", 12)
    except (OSError, ImportError):
        # OSError: font file not found. ImportError: Pillow was built without
        # FreeType. In both cases use Pillow's built-in bitmap font.
        font = ImageFont.load_default()

    if action is None:
        text = f"INITIAL STATE | Reward: {reward:.2f}"
    else:
        if action_names is None:
            action_name = f"Action {action}"
        else:
            action_name = action_names.get(action, f"Action {action}")
        status = "FINAL STATE" if is_final else action_name
        text = f"Action: {status} | Reward: {reward:.2f}"

    try:
        text_width = draw.textlength(text, font=font)  # Pillow >= 8.0
    except AttributeError:
        text_width = font.getsize(text)[0]  # older Pillow releases

    # Centre horizontally. Clamp to 0 so that a caption wider than the frame
    # is cut on the right rather than starting off-canvas on the left.
    text_x = max(0, int((img.width - text_width) // 2))
    text_position = (text_x, height + 5)

    outline_color = (0, 0, 0)
    text_color = (255, 255, 255)

    # Draw the caption offset diagonally in black first, to form an outline.
    for dx, dy in [(-1, -1), (-1, 1), (1, -1), (1, 1)]:
        draw.text((text_position[0] + dx, text_position[1] + dy), text, font=font, fill=outline_color)
    draw.text(text_position, text, font=font, fill=text_color)

    return np.array(img)

def create_gif(frames, filename='episode.gif', fps=5):
    """Render ``frames`` to an animated GIF at ``filename`` and return the path.

    The animation is saved with matplotlib's ``pillow`` writer. Progress is
    reported on a tqdm bar driven by ``Animation.save``'s ``progress_callback``.

    Args:
        frames: Non-empty sequence of image arrays. Values are clipped to
            ``[0, 255]`` and cast to uint8 before rendering. The GIF size is
            taken from the first frame.
        filename: Output path for the GIF.
        fps: Playback rate in frames per second. Must be positive.

    Returns:
        ``filename``.

    Raises:
        ValueError: If ``frames`` is empty or ``fps`` is not positive.
    """
    if len(frames) == 0:
        raise ValueError("create_gif requires at least one frame")
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    print(f"Creating GIF from {len(frames)} frames...")

    frames = [np.clip(frame, 0, 255).astype(np.uint8) for frame in frames]

    height, width = frames[0].shape[:2]
    dpi = 100
    # Size the figure in pixels and let the axes fill it edge to edge.
    # Default subplot margins would shrink and resample every frame, which
    # blurs the text labels.
    fig = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi)
    try:
        ax = fig.add_axes([0, 0, 1, 1])
        ax.axis('off')
        img = ax.imshow(frames[0], interpolation='nearest')

        def update(frame):
            img.set_array(frame)
            return [img]

        anim = FuncAnimation(fig, update, frames=frames, interval=1000 / fps, blit=True)

        with tqdm(total=100, desc="Saving GIF", unit="%") as pbar:
            def progress_callback(current_frame, total_frames):
                # Called after each frame is written. ``current_frame`` is
                # 0-based, and ``total_frames`` may be None if matplotlib
                # cannot determine the frame count.
                if not total_frames:
                    return
                progress = int(((current_frame + 1) / total_frames) * 100)
                pbar.update(max(0, progress - pbar.n))

            anim.save(filename, writer='pillow', fps=fps, dpi=dpi,
                      progress_callback=progress_callback)
            pbar.update(100 - pbar.n)  # make sure the bar finishes at 100%
    finally:
        # Release the figure even if rendering or saving fails.
        plt.close(fig)

    print(f"GIF saved to {filename}")
    return filename


def test_dfs_path_planning():
    """Exercise DFS path planning on a small hand-built maze.

    Runs a reference pure-Python DFS and the JAX planner ``jax_dfs`` on a 5x5
    grid (0 = free, 1 = wall) for two cases:

    1. a path from (0, 0) to the target cell (4, 4);
    2. a path from (0, 0) to any cell 4-adjacent to (4, 4).

    Both planners' paths are printed. Every Python path is then validated: it
    must start at the start cell, end on the goal (or next to it in case 2),
    move one 4-neighbour step at a time, and never enter a wall. Results are
    reported with ``print``; nothing is returned or raised.
    """
    print("Testing DFS path planning...")

    # Create a simple grid with walls (1) and free space (0)
    # 0 0 0 0 0
    # 0 1 1 1 0
    # 0 0 0 0 0
    # 1 1 1 1 0
    # 0 0 0 0 0
    grid = np.zeros((5, 5), dtype=np.int32)
    grid[1, 1:4] = 1  # Horizontal wall in row 1
    grid[3, 0:4] = 1  # Horizontal wall in row 3

    print("Grid representation (0=free, 1=wall):")
    for row in grid:
        print(" ".join(str(int(cell)) for cell in row))

    def manhattan(a, b):
        """Manhattan distance between two (row, col) positions."""
        return abs(int(a[0]) - int(b[0])) + abs(int(a[1]) - int(b[1]))

    def dfs(start_pos, target_pos, adjacent_only=False, max_steps=100):
        """Simple Python implementation of DFS path planning.

        Returns the cells from start to goal inclusive, as NumPy arrays, or
        ``None`` if the goal is not reached within ``max_steps`` expansions.
        With ``adjacent_only`` the goal is any cell 4-adjacent to
        ``target_pos``.
        """
        print(f"Running DFS from {start_pos} to {target_pos}")
        start_pos = np.array(start_pos)
        target_pos = np.array(target_pos)
        visited = np.zeros_like(grid, dtype=bool)
        stack = [(start_pos, [])]  # (current_pos, path_so_far)

        def is_target(pos):
            if adjacent_only:
                return manhattan(pos, target_pos) == 1
            return np.array_equal(pos, target_pos)

        def is_valid(pos):
            row, col = pos
            if not (0 <= row < grid.shape[0] and 0 <= col < grid.shape[1]):
                return False
            return grid[row, col] == 0

        steps = 0
        while stack and steps < max_steps:
            current, path = stack.pop()
            # A cell can be pushed several times before it is first expanded.
            # Skip the stale duplicates so they are neither re-expanded nor
            # counted against the step budget.
            if visited[tuple(current)]:
                continue
            steps += 1
            print(f"Step {steps}: Checking position {current}")

            if is_target(current):
                print(f"Found target at step {steps}!")
                return path + [current]  # Include current position in path

            visited[tuple(current)] = True
            new_path = path + [current]

            # Neighbours are pushed up, left, down, right, so "right" is
            # popped (explored) first.
            for d_row, d_col in ((-1, 0), (0, -1), (1, 0), (0, 1)):
                neighbor = np.array([current[0] + d_row, current[1] + d_col])
                if is_valid(neighbor) and not visited[tuple(neighbor)]:
                    stack.append((neighbor, new_path))
                    print(f"  Adding neighbor {neighbor} to stack")

        print(f"DFS terminated after {steps} steps without finding target")
        return None

    def print_path(path, description):
        print(f"\nPath from {start_pos} to {description}:")
        if path:
            for i, pos in enumerate(path):
                print(f"  Step {i}: {pos}")
        else:
            print("No path found.")

    def validate(path, start, target, adjacent_only):
        """Print PASS/FAIL diagnostics for ``path``; return True if it is valid."""
        if not path:
            print("\nTest FAILED: DFS did not find a path")
            return False
        print("\nTest PASSED: DFS found a path")

        valid_path = True
        for pos in path:
            if grid[tuple(pos)] != 0:
                valid_path = False
                print(f"Error: Path goes through wall at position {pos}")
        if valid_path:
            print("Path validation PASSED: Path avoids walls")
        else:
            print("Path validation FAILED: Path goes through walls")

        connected = True
        if not np.array_equal(path[0], start):
            connected = False
            print(f"Error: Path starts at {path[0]} instead of {start}")
        for prev, nxt in zip(path, path[1:]):
            if manhattan(prev, nxt) != 1:
                connected = False
                print(f"Error: Non-adjacent step from {prev} to {nxt}")
        if adjacent_only:
            end_ok = manhattan(path[-1], target) == 1
        else:
            end_ok = np.array_equal(path[-1], target)
        if not end_ok:
            connected = False
            print(f"Error: Path ends at {path[-1]}, which does not satisfy the goal {target}")
        if connected:
            print("Path validation PASSED: Path is a connected walk from start to goal")
        else:
            print("Path validation FAILED: Path is not a connected walk from start to goal")

        return valid_path and connected

    start_pos = [0, 0]
    target_pos = [4, 4]

    # Test case 1: Simple path from top-left to bottom-right
    path = dfs(start_pos, target_pos)
    jax_path = jax_dfs(jnp.asarray(grid), jnp.asarray(start_pos), jnp.asarray(target_pos))
    print(f'JAX Path: {jax_path.path[:jax_path.path_length]}')
    print_path(path, target_pos)
    validate(path, start_pos, target_pos, adjacent_only=False)

    # Test case 2: Path to a position adjacent to the target
    path = dfs(start_pos, target_pos, adjacent_only=True)
    # NOTE(review): this call is identical to the one in case 1. It does not
    # request adjacent-only planning, so the JAX path printed below is the
    # exact-target path. If ``jax_dfs`` supports an adjacency option, pass it
    # here.
    jax_path = jax_dfs(jnp.asarray(grid), jnp.asarray(start_pos), jnp.asarray(target_pos))
    print_path(path, f"adjacent to {target_pos}")
    print(f'JAX Path: {jax_path.path[:jax_path.path_length]}')
    validate(path, start_pos, target_pos, adjacent_only=True)

if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Run Navix DoorKey policy demonstrations")
    parser.add_argument("--output-dir", type=str, default="outputs/navix_demos",
                        help="Directory to save output GIFs")
    parser.add_argument("--env", type=str, default=None,
                        help="Specific environment to run (e.g., 'DoorKey-8x8-v0')")
    parser.add_argument("--all", action="store_true",
                        help="Run all environments, including larger ones")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducibility")
    parser.add_argument("--fps", type=int, default=5,
                        help="Frames per second for output GIFs")
    # Default None means "use each environment's preset". A numeric default
    # would silently override every preset below.
    parser.add_argument("--max-steps", type=int, default=None,
                        help="Maximum number of steps to run per episode "
                             "(default: the environment's preset, or 100 for "
                             "environments without a preset)")
    parser.add_argument("--test-dfs", action="store_true",
                        help="Run the DFS path planning test")
    parser.add_argument("--action-set", type=str, default="COMPLETE", choices=["COMPLETE", "MINIGRID"],
                        help="Action set to use (COMPLETE or MINIGRID)")
    args = parser.parse_args()

    if args.max_steps is not None and args.max_steps <= 0:
        parser.error("--max-steps must be a positive integer")
    if args.fps <= 0:
        parser.error("--fps must be a positive integer")

    # Run the DFS test if requested
    if args.test_dfs:
        test_dfs_path_planning()
        # SystemExit is always available. The ``exit`` helper is injected by
        # the ``site`` module and is missing under ``python -S``.
        raise SystemExit(0)

    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    # Per-environment presets, used when --max-steps is not given.
    default_envs = [
        {"name": "DoorKey-5x5-v0", "max_steps": 50},
        {"name": "DoorKey-8x8-v0", "max_steps": 100},
    ]

    larger_envs = [
        {"name": "DoorKey-16x16-v0", "max_steps": 250},
        {"name": "DoorKey-Random-8x8-v0", "max_steps": 100}
    ]

    fallback_max_steps = 100  # for environments without a preset

    # Select environments to run
    if args.env:
        presets = [env for env in default_envs + larger_envs if env["name"] == args.env]
        # Unknown names get a generic config with the fallback step budget.
        candidates = presets[:1] or [{"name": args.env, "max_steps": fallback_max_steps}]
    elif args.all:
        candidates = default_envs + larger_envs
    else:
        candidates = default_envs

    # Copy each config so the preset lists are never mutated. Apply the CLI
    # override only when the user gave one explicitly.
    envs_to_run = []
    for env in candidates:
        env_config = dict(env)
        if args.max_steps is not None:
            env_config["max_steps"] = args.max_steps
        envs_to_run.append(env_config)

    # Show overall progress
    for i, env_config in enumerate(envs_to_run):
        env_name = env_config["name"]
        max_steps = env_config["max_steps"]

        print(f"\nRunning demo for {env_name} [{i+1}/{len(envs_to_run)}]")
        frames, reward = run_episode(env_name=env_name, max_steps=max_steps, seed=args.seed, action_set=args.action_set)

        # Create GIF
        action_set_suffix = args.action_set.lower()
        gif_path = os.path.join(args.output_dir, f"{env_name}_smart_{action_set_suffix}.gif")
        create_gif(frames, filename=gif_path, fps=args.fps)

        print(f"Completed with reward: {reward}")
        print(f"Saved GIF to: {gif_path}")

    print("\nAll demos completed!")