import jax
import jax.numpy as jnp
from jax import lax
import chex
from flax import struct
from gymnax.environments import environment
from gymnax.environments.spaces import Discrete, Box
from typing import Optional, Tuple, Dict, Any
from functools import partial
import numpy as np


@struct.dataclass
class TaxiState:
    """Environment state produced by `reset_env` and advanced by `step`."""
    taxi_pos: chex.Array  # (2,) int32 array: taxi cell as (row, col)
    passengers: chex.Array  # (n_passengers, 4) int32 array: [row, col, goal_row, goal_col] per passenger
    passenger_in_taxi: jnp.int32  # index of the passenger riding in the taxi, or -1 when the taxi is empty
    time: jnp.int32  # number of steps taken since the last reset
    key: chex.PRNGKey  # PRNG key stored at reset; the visible step logic does not consume it

@struct.dataclass
class TaxiParams:
    """Parameter bundle built by the environment constructor and exposed as its default params."""
    size: jnp.int32 = 5  # grid side length; step() bounds-checks taxi moves against it
    max_steps_in_episode: jnp.int32 = 200  # step() reports done once state.time reaches this value
    n_passengers: jnp.int32 = 1  # passenger count recorded by the constructor
    # NOTE(review): this default is False while TaxiGymnax.__init__ defaults the same
    # option to True; the constructor always passes its own value explicitly.
    allow_dropoff_anywhere: bool = False

class TaxiGymnax(environment.Environment):
    def __init__(
            self,
            size: int = 5,
            n_passengers: int = 1,
            render_mode: int = 0,
            allow_dropoff_anywhere: bool = True,
            exploring_starts: bool = False,
            img_size: int = 64,
            max_steps_in_episode: int = 200
        ):
        """Initialize the environment.

        Args:
            size: Size of the grid (5 or 10).
            n_passengers: Number of passengers (at most 4 for size=5, at most
                8 for size=10, i.e. no more than the number of depots).
            render_mode: 0 for vector observations, 1 for RGB rendering.
            allow_dropoff_anywhere: If True, allows dropping off passengers at
                any location.
            exploring_starts: If True, each passenger starts on a grid cell
                sampled uniformly among the cells not already taken by an
                earlier passenger, and the taxi starts either on a random
                passenger's cell or on a uniformly random cell. If False,
                passenger ``i`` starts on depot ``i % num_depots`` and the
                taxi starts on a uniformly random cell.
            img_size: Side length used for the RGB observation space.
            max_steps_in_episode: Step count at which an episode is reported
                as done even if the goal has not been reached.

        Raises:
            ValueError: If ``size`` is not 5 or 10, or if ``n_passengers`` is
                negative or exceeds the number of depots for that size.
        """
        super().__init__()
        self.size = size
        self.n_passengers = n_passengers
        self.params = TaxiParams(size=size, n_passengers=n_passengers, allow_dropoff_anywhere=allow_dropoff_anywhere, max_steps_in_episode=max_steps_in_episode)
        self.render_mode = render_mode
        self.allow_dropoff_anywhere = allow_dropoff_anywhere
        self.exploring_starts = exploring_starts
        self.max_steps_in_episode = max_steps_in_episode
        self.img_size = img_size

        # Names of the entries of the vector observation, in order:
        # taxi cell, then (row, col, goal id) per passenger, then the carried passenger.
        state_names = ['taxi_y', 'taxi_x']
        for i in range(n_passengers):
            state_names.extend([f'passenger_{i}_y', f'passenger_{i}_x', f'passenger_{i}_goal_id'])
        self.state_names = state_names + ['passenger_in_taxi']

        # Wall grid of size (rows*2+1, cols*2+1) to match the original implementation:
        # cell (r, c) sits at index (2r+1, 2c+1) and the entries between cells hold walls.
        # It starts all-False (no walls); the size-specific walls are added below.
        self.grid = jnp.zeros((size*2+1, size*2+1), dtype=jnp.bool_)

        # Define depot locations and walls based on size
        if size == 5:
            # Set depot locations
            self.depot_locs = jnp.array([
                [0, 0],  # R: Red depot
                [0, 4],  # G: Green depot
                [4, 0],  # Y: Yellow depot
                [4, 3],  # B: Blue depot
            ])

            self.depot_colors = jnp.array([
                [1.0, 0.0, 0.0],  # red
                [0.0, 1.0, 0.0],  # green
                [1.0, 1.0, 0.0],  # yellow
                [0.0, 0.0, 1.0],  # blue
            ])

            # Set walls to match original implementation exactly
            # In the original implementation: self.grid[1:4, 4] = 1
            self.grid = self.grid.at[1:4, 4].set(True)
            # In the original implementation: self.grid[7:10, 2] = 1
            self.grid = self.grid.at[7:10, 2].set(True)
            # In the original implementation: self.grid[7:10, 6] = 1
            self.grid = self.grid.at[7:10, 6].set(True)

        elif size == 10:
            # Set depot locations
            self.depot_locs = jnp.array([
                [0, 0],   # R: Red depot
                [0, 9],   # G: Green depot
                [9, 0],   # Y: Yellow depot
                [9, 6],   # B: Blue depot
                [3, 3],   # Gray depot
                [4, 6],   # Magenta depot
                [0, 8],   # Cyan depot
                [9, 9],   # Orange depot
            ])

            self.depot_colors = jnp.array([
                [1.0, 0.0, 0.0],  # red
                [0.0, 1.0, 0.0],  # green
                [1.0, 1.0, 0.0],  # yellow
                [0.0, 0.0, 1.0],  # blue
                [0.5, 0.5, 0.5],  # light gray
                [1.0, 0.0, 1.0],  # magenta
                [0.0, 1.0, 1.0],  # cyan
                [1.0, 0.5, 0.0],  # orange
            ])
            # Set walls to match original implementation exactly
            self.grid = self.grid.at[1:8, 6].set(True)
            self.grid = self.grid.at[13:20, 2].set(True)
            self.grid = self.grid.at[13:20, 8].set(True)
            self.grid = self.grid.at[5:12, 12].set(True)
            self.grid = self.grid.at[1:8, 16].set(True)
            self.grid = self.grid.at[13:20, 16].set(True)
        else:
            raise ValueError(f"Size must be 5 or 10, got {size}")

        self.num_depots = len(self.depot_locs)

        # Each passenger needs its own goal depot, so the depot count caps the passenger count.
        max_passengers = self.num_depots
        if n_passengers < 0 or n_passengers > max_passengers:
            raise ValueError(f"n_passengers must be between 0 and {max_passengers} for size {size}")

        self._num_actions = 6  # noop, right, left, up, down, interact
        self._action_space = Discrete(self._num_actions)
        # The vector observation has one entry per state name
        # (2 + 3 * n_passengers + 1), and its last entry, passenger_in_taxi,
        # is -1 when the taxi is empty, hence the lower bound of -1.
        self._observation_space = Box(
            low=-1,
            high=max(self.size, self.num_depots),
            shape=(len(self.state_names),),
            dtype=jnp.int32
        )
        self._state_space = {
            "taxi_pos": Box(low=0, high=self.size-1, shape=(2,), dtype=jnp.int32),
            "passengers": Box(low=0, high=self.size-1, shape=(self.n_passengers, 4), dtype=jnp.int32),
            # One value per passenger index plus the "no passenger" case (stored as -1 in TaxiState)
            "passenger_in_taxi": Discrete(self.n_passengers + 1),
            "time": Discrete(1000),
        }

    def has_wall(self, position, offset=(0, 0)):
        """Report whether the wall grid is set next to a cell.

        Args:
            position: ``(row, col)`` of a grid cell.
            offset: ``(d_row, d_col)`` added to the cell's wall-grid index
                ``(2 * row + 1, 2 * col + 1)`` to pick the entry to inspect.

        Returns:
            True when the inspected index lies outside ``self.grid`` or the
            grid entry at that index is set; otherwise the grid entry itself.
        """
        cell_row, cell_col = position
        step_row, step_col = offset
        n_rows, n_cols = self.grid.shape

        # Cell centres live on odd indices of the wall grid.
        grid_r = 2 * cell_row + 1 + step_row
        grid_c = 2 * cell_col + 1 + step_col

        inside = (grid_r >= 0) & (grid_r < n_rows) & (grid_c >= 0) & (grid_c < n_cols)

        # Selecting with where keeps this usable under JAX tracing;
        # anything beyond the grid counts as a wall.
        return jnp.where(inside, self.grid[grid_r, grid_c], jnp.array(True))

    @property
    def default_params(self) -> TaxiParams:
        """Return the `TaxiParams` instance built by the constructor."""
        return self.params

    @property
    def num_actions(self) -> int:
        """Return the number of discrete actions set by the constructor."""
        return self._num_actions

    def action_space(self, params: Optional[TaxiParams] = None) -> Discrete:
        """Return a new `Discrete` space with `self._num_actions` actions.

        `params` is accepted but not read.
        """
        return Discrete(self._num_actions)

    def observation_space(self, params: Optional[TaxiParams] = None) -> Box:
        """Get the observation space for the configured render mode.

        Args:
            params: Accepted for interface compatibility; the space depends
                only on attributes fixed at construction time.

        Returns:
            For ``render_mode == 1`` a float32 ``Box`` in ``[0, 1]`` of shape
            ``(img_size, img_size, 3)``. Otherwise an int32 ``Box`` matching
            the vector built by ``_get_state_obs``: taxi ``(row, col)``, then
            ``(row, col, goal_id)`` per passenger, then ``passenger_in_taxi``.
        """
        # render_mode is a plain Python value, so an ordinary branch is safe here.
        if self.render_mode == 1:
            return Box(low=0, high=1, shape=(self.img_size, self.img_size, 3), dtype=jnp.float32)

        # 2 taxi coordinates + 3 entries per passenger + 1 carried-passenger entry.
        vector_len = 2 + 3 * self.n_passengers + 1
        # passenger_in_taxi is -1 when the taxi is empty, hence the lower bound of -1.
        return Box(low=-1, high=max(self.size, self.num_depots), shape=(vector_len,), dtype=jnp.int32)

    def state_space(self, params: Optional[TaxiParams] = None) -> Dict[str, Box]:
        """Return the per-field state space dictionary built by the constructor.

        `params` is accepted but not read.
        NOTE(review): the constructor stores `Discrete` spaces for some keys,
        so the `Dict[str, Box]` annotation is narrower than the actual values.
        """
        return self._state_space

    def reset_env(self, key: chex.PRNGKey, params: Optional[TaxiParams] = None) -> Tuple[chex.Array, TaxiState]:
        """Sample a fresh initial state and return its observation.

        Passengers are initialised one after another inside a ``lax.scan``.
        Without exploring starts, passenger ``i`` starts on depot
        ``i % num_depots``; with exploring starts, it starts on a uniformly
        sampled cell not used by an earlier passenger. Each passenger then
        draws a goal depot that no earlier passenger was given and, without
        exploring starts, that differs from its own start depot. If no depot
        satisfies those constraints, the "not already assigned" requirement
        is dropped for that passenger so the draw stays well-defined.

        Args:
            key: PRNG key; split into taxi, passenger and stored-state keys.
            params: Accepted for interface compatibility; not read.

        Returns:
            ``(obs, state)`` where ``obs`` is the RGB render when
            ``render_mode == 1`` and the vector observation otherwise.
        """
        # Split key for different random operations
        key_taxi, key_passengers, key_step = jax.random.split(key, 3)

        # Initialize passengers and the bookkeeping masks carried through the scan
        passengers = jnp.zeros((self.n_passengers, 4), dtype=jnp.int32)
        assigned_goal_depots = jnp.zeros((self.num_depots,), dtype=jnp.bool_)
        assigned_start_depots = jnp.zeros((self.num_depots,), dtype=jnp.bool_)
        assigned_passenger_start_positions = jnp.zeros((self.size ** 2,), dtype=jnp.bool_)

        def init_passenger(carry, i):
            passengers, key, assigned_goal_depots, assigned_start_depots, assigned_passenger_start_positions = carry
            key, subkey_pos, subkey_goal = jax.random.split(key, 3)

            if not self.exploring_starts:
                # Passenger i starts on depot i (wrapping around the depot list)
                pos_idx = i % self.num_depots
                assigned_start_depots = assigned_start_depots.at[pos_idx].set(True)
                passengers = passengers.at[i, :2].set(self.depot_locs[pos_idx])
                # Goal must differ from the start depot and be unused so far
                not_own_depot = jnp.arange(self.num_depots) != pos_idx
                available_mask = not_own_depot & ~assigned_goal_depots
                fallback_mask = not_own_depot
            else:
                # Uniform draw over the cells no earlier passenger started on
                pos_probs = (~assigned_passenger_start_positions).astype(jnp.float32) / jnp.sum(~assigned_passenger_start_positions)
                sampled_pos = jax.random.categorical(subkey_pos, jnp.log(pos_probs))
                passengers = passengers.at[i, :2].set(jnp.unravel_index(sampled_pos, (self.size, self.size)))
                assigned_passenger_start_positions = assigned_passenger_start_positions.at[sampled_pos].set(True)
                available_mask = ~assigned_goal_depots
                fallback_mask = jnp.ones_like(assigned_goal_depots)

            # With as many passengers as depots, the only unassigned goal depot
            # can be this passenger's own start depot, leaving nothing
            # available. Dividing by a zero count would then feed NaN logits to
            # the sampler, so fall back to the relaxed mask in that case only.
            available_mask = jnp.where(jnp.any(available_mask), available_mask, fallback_mask)

            # Uniform draw over the admissible goal depots (log(0) = -inf excludes the rest)
            goal_probs = available_mask.astype(jnp.float32) / jnp.sum(available_mask)
            goal_idx = jax.random.categorical(subkey_goal, jnp.log(goal_probs))
            # Set goal position
            passengers = passengers.at[i, 2:].set(self.depot_locs[goal_idx])
            assigned_goal_depots = assigned_goal_depots.at[goal_idx].set(True)

            return (passengers, key, assigned_goal_depots, assigned_start_depots, assigned_passenger_start_positions), None

        (passengers, _, assigned_goal_depots, assigned_start_depots, assigned_passenger_start_positions), _ = jax.lax.scan(
            init_passenger,
            (passengers, key_passengers, assigned_goal_depots, assigned_start_depots, assigned_passenger_start_positions),
            jnp.arange(self.n_passengers)
        )

        if self.exploring_starts:
            # With equal probability, put the taxi on a random passenger's cell
            # or on a uniformly random cell.
            rng_taxi, rng_case = jax.random.split(key_taxi)
            case = jax.random.randint(rng_case, shape=(), minval=0, maxval=2)

            # NOTE: rng_taxi feeds both draws below, but only one of the two
            # results is ever used for a given `case`.
            random_passenger = jax.random.choice(rng_taxi, jnp.arange(self.n_passengers))
            taxi_pos = jax.lax.cond(
                case == 0,
                lambda: passengers[random_passenger, :2],
                lambda: jax.random.randint(rng_taxi, shape=(2,), minval=0, maxval=self.size)
            )
        else:
            taxi_pos = jax.random.randint(key_taxi, shape=(2,), minval=0, maxval=self.size)

        # Initialize state
        state = TaxiState(
            taxi_pos=taxi_pos.astype(jnp.int32),
            passengers=passengers,
            passenger_in_taxi=jnp.array(-1, dtype=jnp.int32),  # No passenger in taxi initially
            time=jnp.array(0, dtype=jnp.int32),
            key=key_step,
        )

        # Get observation (render_mode is a Python value, so a plain branch is fine)
        if self.render_mode == 1:
            obs = self._render_rgb(state)
        else:
            obs = self._get_state_obs(state)

        return obs, state

    def step(
        self,
        key: chex.PRNGKey,
        state: TaxiState,
        action: int,
        params: Optional[TaxiParams] = None,
    ) -> Tuple[chex.Array, TaxiState, float, bool, Dict[Any, Any]]:
        """Perform one timestep of the environment.

        Actions: 0 noop, 1 right, 2 left, 3 up, 4 down; any value of 5 or
        more is treated as "interact" (pick up when the taxi is empty, drop
        off otherwise). A move is applied only if the target cell is inside
        the grid and no wall lies in the way; a carried passenger moves with
        the taxi.

        Args:
            key: PRNG key; not used by the transition.
            state: Current environment state.
            action: Action id as described above.
            params: Environment parameters; defaults to ``self.params``.

        Returns:
            ``(obs, state, reward, done, info)`` where ``reward`` is 1.0 when
            ``_check_goal`` holds for the new state and 0.0 otherwise,
            ``done`` is set when the goal is reached or ``state.time`` has
            reached ``params.max_steps_in_episode``, and ``info['state']`` is
            the vector observation of the state *before* the transition.
        """
        params = self.params if params is None else params

        # Recorded before the transition, so it describes the incoming state
        info = {'state': self._get_state_obs(state)}
        taxi_pos = state.taxi_pos

        # (d_row, d_col) per action id - must match the order of the original implementation
        action_offsets = jnp.array([
            [0, 0],   # noop (0)
            [0, 1],   # right (1)
            [0, -1],  # left (2)
            [-1, 0],  # up (3)
            [1, 0],   # down (4)
            [0, 0],   # interact (5) - this won't be used for movement
        ])

        def move_taxi():
            offset = jax.lax.dynamic_slice_in_dim(action_offsets, action, 1)[0]

            # A move needs both an in-bounds target and no wall in between
            blocked = self.has_wall(taxi_pos, offset)
            new_pos = taxi_pos + offset
            in_bounds = jnp.logical_and(
                jnp.all(new_pos >= 0),
                jnp.all(new_pos < params.size)
            )
            new_taxi_pos = jnp.where(
                jnp.logical_and(in_bounds, ~blocked),
                new_pos,
                taxi_pos
            )

            # A carried passenger shares the taxi's cell
            passengers = jax.lax.cond(
                state.passenger_in_taxi > -1,
                lambda: state.passengers.at[state.passenger_in_taxi, :2].set(new_taxi_pos),
                lambda: state.passengers
            )
            return state.replace(taxi_pos=new_taxi_pos, passengers=passengers)

        def handle_movement():
            # Only move if action is not 0 (NOOP)
            return jax.lax.cond(
                action > 0,
                move_taxi,
                lambda: state
            )

        def handle_interaction():
            # Empty taxi tries a pickup; an occupied taxi tries a dropoff
            return jax.lax.cond(
                state.passenger_in_taxi == -1,
                lambda: self._pickup_passenger(state, taxi_pos),
                lambda: self._dropoff_passenger(state, taxi_pos)
            )

        # lax.cond keeps the branch choice traceable when `action` is a JAX value
        state = jax.lax.cond(
            action < 5,
            handle_movement,
            handle_interaction
        )

        # Update time
        state = state.replace(time=state.time + 1)

        # Check if goal is reached (all passengers at their destinations)
        goal_reached = self._check_goal(state)
        reward = jnp.float32(goal_reached)  # 1.0 if goal reached, 0.0 otherwise

        # Check for episode termination (either goal reached or max steps)
        done = goal_reached | (state.time >= params.max_steps_in_episode)

        # Get observation (render_mode is a Python value, so a plain branch is fine)
        if self.render_mode == 1:
            obs = self._render_rgb(state)
        else:
            obs = self._get_state_obs(state)

        return obs, state, reward, done, info

    def get_obs(self, state: TaxiState, params: Optional[TaxiParams] = None) -> chex.Array:
        """Get observation from state.

        Delegates to `self._get_obs_impl`, passing `params` or, when it is
        None, `self.default_params`.
        NOTE(review): `_get_obs_impl` is not defined in the portion of the
        class reviewed here - confirm it exists, since `reset_env` and `step`
        build observations via `_get_state_obs` / `_render_rgb` instead.
        """
        return self._get_obs_impl(state, params if params is not None else self.default_params)

    def _get_state_obs(self, state: TaxiState) -> chex.Array:
        """Build the flat integer observation vector for ``state``.

        Layout: ``[taxi_row, taxi_col]``, then ``[row, col, goal_id]`` for
        every passenger, then ``passenger_in_taxi``. ``goal_id`` is the index
        in ``self.depot_locs`` of the first depot whose location equals the
        passenger's goal cell.
        """
        current_cells = state.passengers[:, :2]
        goal_cells = state.passengers[:, 2:4]

        # (n_passengers, num_depots) table: does passenger p's goal sit on depot d?
        goal_on_depot = jnp.all(
            goal_cells[:, None, :] == self.depot_locs[None, :, :],
            axis=-1,
        )
        goal_ids = jnp.argmax(goal_on_depot, axis=1)

        per_passenger = jnp.concatenate([current_cells, goal_ids[:, None]], axis=1)
        carried = jnp.array([state.passenger_in_taxi], dtype=jnp.int32)

        return jnp.concatenate([state.taxi_pos, per_passenger.reshape(-1), carried])

    def _render_rgb(self, state: TaxiState) -> chex.Array:
        """Render the environment as an RGB array in a JAX-compatible way."""
        # Define constants
        cell_width = 11
        wall_width = 1
        border_widths = (2, 1)  # (pad_top_left, pad_bot_right)
        depot_width = 2
        character_width = 7
        
        # Calculate grid dimensions
        rows, cols = self.size, self.size
        grid_height = rows * cell_width + (rows + 1) * wall_width
        grid_width = cols * cell_width + (cols + 1) * wall_width
        
        # Create a white background
        grid = jnp.ones((grid_height, grid_width, 3))
        
        # Define a function to get colors instead of using a dictionary
        def get_color(name):
            # List all colors as jax.lax.switch cases
            return jax.lax.switch(
                jnp.array([
                    name == 'red', 
                    name == 'blue', 
                    name == 'yellow', 
                    name == 'green',
                    name == 'almost_black',
                    name == 'dimgray',
                    name == 'white',
                    name == 'cyan'
                ]).argmax(),
                [
                    lambda: jnp.array([1.0, 0.0, 0.0]),  # red
                    lambda: jnp.array([0.0, 0.0, 1.0]),  # blue
                    lambda: jnp.array([1.0, 1.0, 0.0]),  # yellow
                    lambda: jnp.array([0.0, 1.0, 0.0]),  # green
                    lambda: jnp.array([0.2, 0.2, 0.2]),  # almost_black
                    lambda: jnp.array([0.4, 0.4, 0.4]),  # dimgray
                    lambda: jnp.array([1.0, 1.0, 1.0]),  # white
                    lambda: jnp.array([0.0, 1.0, 1.0]),  # cyan
                ]
            )
        
        # Define depot colors array to avoid indexing issues
        depot_colors_array = self.depot_colors
        
        # ----- DRAW GRID WALLS -----
        # Create meshgrids for vertex positions
        rows_range = jnp.arange(rows + 1)
        cols_range = jnp.arange(cols + 1)
        row_positions = rows_range * (cell_width + wall_width)
        col_positions = cols_range * (cell_width + wall_width)
        
        # Function to draw vertices (intersection points)
        def draw_vertex(grid, r, c):
            row_pos = r * (cell_width + wall_width)
            col_pos = c * (cell_width + wall_width)
            
            # Create mask for this vertex
            row_mask = (jnp.arange(grid_height) >= row_pos) & (jnp.arange(grid_height) < row_pos + wall_width)
            col_mask = (jnp.arange(grid_width) >= col_pos) & (jnp.arange(grid_width) < col_pos + wall_width)
            
            vertex_mask = jnp.expand_dims(row_mask, 1) & jnp.expand_dims(col_mask, 0)
            vertex_mask = jnp.expand_dims(vertex_mask, 2)
            
            return grid * (1 - vertex_mask) + get_color('almost_black') * vertex_mask
        
        # Draw all vertices using scan
        def draw_vertices_row(carry, r):
            grid_r = carry
            
            def draw_vertices_col(carry, c):
                return draw_vertex(carry, r, c), None
            
            grid_r, _ = jax.lax.scan(draw_vertices_col, grid_r, jnp.arange(cols + 1))
            return grid_r, None
        
        grid, _ = jax.lax.scan(draw_vertices_row, grid, jnp.arange(rows + 1))
        
        # Function to check if there's a horizontal wall
        def has_horizontal_wall(r, c):
            # Get indices in the wall grid
            r_idx, c_idx = 2*r, 2*c+1
            
            # Check boundary walls
            is_boundary = (r == 0) | (r == rows)
            
            # Check internal walls - using JAX-compatible approach without if statements
            # Compute the internal wall value only if indices are valid
            valid_indices = (r_idx < self.grid.shape[0]) & (c_idx < self.grid.shape[1])
            internal_wall = jnp.where(valid_indices, self.grid[r_idx, c_idx], False)
            
            return is_boundary | internal_wall
        
        # Function to draw a horizontal wall
        def draw_horizontal_wall(grid, r, c):
            row_pos = r * (cell_width + wall_width)
            col_pos = c * (cell_width + wall_width) + wall_width
            
            # Create mask for this wall section
            row_mask = (jnp.arange(grid_height) >= row_pos) & (jnp.arange(grid_height) < row_pos + wall_width)
            col_mask = (jnp.arange(grid_width) >= col_pos) & (jnp.arange(grid_width) < col_pos + cell_width)
            
            wall_mask = jnp.expand_dims(row_mask, 1) & jnp.expand_dims(col_mask, 0)
            wall_mask = jnp.expand_dims(wall_mask, 2)
            
            # Draw wall conditionally
            should_draw = has_horizontal_wall(r, c)
            return jnp.where(
                should_draw,
                grid * (1 - wall_mask) + get_color('almost_black') * wall_mask,
                grid
            )
        
        # Draw horizontal walls using nested scan
        def draw_h_walls_row(carry, r):
            grid_r = carry
            
            def draw_h_walls_col(carry, c):
                return draw_horizontal_wall(carry, r, c), None
            
            grid_r, _ = jax.lax.scan(draw_h_walls_col, grid_r, jnp.arange(cols))
            return grid_r, None
        
        grid, _ = jax.lax.scan(draw_h_walls_row, grid, jnp.arange(rows + 1))
        
        # Function to check if there's a vertical wall
        def has_vertical_wall(r, c):
            # Get indices in the wall grid
            r_idx, c_idx = 2*r+1, 2*c
            
            # Check boundary walls
            is_boundary = (c == 0) | (c == cols)
            
            # Check internal walls - using JAX-compatible approach without if statements
            # Compute the internal wall value only if indices are valid
            valid_indices = (r_idx < self.grid.shape[0]) & (c_idx < self.grid.shape[1])
            internal_wall = jnp.where(valid_indices, self.grid[r_idx, c_idx], False)
            
            return is_boundary | internal_wall
        
        # Function to draw a vertical wall
        def draw_vertical_wall(grid, r, c):
            row_pos = r * (cell_width + wall_width) + wall_width
            col_pos = c * (cell_width + wall_width)
            
            # Create mask for this wall section
            row_mask = (jnp.arange(grid_height) >= row_pos) & (jnp.arange(grid_height) < row_pos + cell_width)
            col_mask = (jnp.arange(grid_width) >= col_pos) & (jnp.arange(grid_width) < col_pos + wall_width)
            
            wall_mask = jnp.expand_dims(row_mask, 1) & jnp.expand_dims(col_mask, 0)
            wall_mask = jnp.expand_dims(wall_mask, 2)
            
            # Draw wall conditionally
            should_draw = has_vertical_wall(r, c)
            return jnp.where(
                should_draw,
                grid * (1 - wall_mask) + get_color('almost_black') * wall_mask,
                grid
            )
        
        # Draw vertical walls using nested scan
        def draw_v_walls_row(carry, r):
            grid_r = carry
            
            def draw_v_walls_col(carry, c):
                return draw_vertical_wall(carry, r, c), None
            
            grid_r, _ = jax.lax.scan(draw_v_walls_col, grid_r, jnp.arange(cols + 1))
            return grid_r, None
        
        grid, _ = jax.lax.scan(draw_v_walls_row, grid, jnp.arange(rows))
        
        # ----- DRAW DEPOTS -----
        # Function to create a depot patch
        def create_depot_patch():
            # Create a blank patch of the exact final size needed
            patch = jnp.zeros((cell_width, cell_width), dtype=jnp.float32)
            
            # Define the border width
            border = 1
            corner = depot_width // 2
            
            # Add borders
            # Top and bottom borders
            patch = patch.at[0:border, :].set(1.0)
            patch = patch.at[-border:, :].set(1.0)
            
            # Left and right borders
            patch = patch.at[:, 0:border].set(1.0)
            patch = patch.at[:, -border:].set(1.0)
            
            return patch
        
        # Create depot patch once (same shape used for all depots)
        depot_patch = create_depot_patch()
        
        # Function to draw a depot
        def draw_depot(grid, depot_idx):
            # depot
            
            row, col = state.passengers[depot_idx, 2], state.passengers[depot_idx, 3]
            
            top = wall_width + row * (cell_width + wall_width)
            left = wall_width + col * (cell_width + wall_width)
            
            # Create masks for this position using dynamic updates
            # Instead of using .at[] with dynamic indices, use dynamic_update_slice
            
            # Get color using array indexing instead of dictionary
            color_idx = jnp.all(self.depot_locs == state.passengers[depot_idx, 2:4][None], axis=1).argmax()
            depot_color = self.depot_colors[color_idx]
            
            # Create a mask of the same shape as grid
            # We'll update a region with 1s where the depot should be drawn
            mask = jnp.zeros((grid_height, grid_width), dtype=jnp.float32)
            
            # Use dynamic_update_slice to place the depot patch
            mask = jax.lax.dynamic_update_slice(
                mask,
                depot_patch,
                (top, left)
            )
            
            # Expand mask for broadcasting
            mask_3d = jnp.expand_dims(mask > 0, 2)
            
            # Apply color
            return grid * (1 - mask_3d) + depot_color * mask_3d
        
        # Draw first n_passengers depots
        def draw_depots_loop(i, grid):
            return jax.lax.cond(
                i < self.n_passengers,
                lambda: draw_depot(grid, i),
                lambda: grid
            )
        
        grid = jax.lax.fori_loop(0, 4, draw_depots_loop, grid)  # Limit to maximum of 4 depots
        
        # ----- DRAW PASSENGERS -----
        # Function to create passenger patch
        def create_passenger_patch():
            # Create a blank patch
            passenger_patch = jnp.zeros((cell_width, cell_width), dtype=jnp.float32)
            agent_width = character_width
            
            # Create a circle mask
            y_coords, x_coords = jnp.mgrid[:cell_width, :cell_width]
            center = cell_width // 2
            circle_mask = ((y_coords - center)**2 + (x_coords - center)**2) <= (agent_width/2)**2
            
            # Create X mark using diagonals - using boolean arrays
            diag1 = jnp.eye(cell_width, dtype=bool)
            diag2 = jnp.fliplr(jnp.eye(cell_width, dtype=bool))
            # Boolean OR is valid
            x_mask = jnp.logical_or(diag1, diag2) & circle_mask
            
            # Set circle as the base with x_mask overlay
            passenger_patch = circle_mask.astype(jnp.float32)
            
            return passenger_patch, x_mask.astype(jnp.float32)
        
        # Create passenger patches once (same shape used for all passengers)
        passenger_patch, passenger_x_mask = create_passenger_patch()
        
        # Function to draw a passenger
        def draw_passenger(grid, p_idx):
            # Only draw if passenger is not in taxi
            should_draw = state.passenger_in_taxi != p_idx
            
            def do_draw():
                p_row, p_col = state.passengers[p_idx, 0], state.passengers[p_idx, 1]
                color_idx = jnp.all(self.depot_locs == state.passengers[p_idx, 2:4][None], axis=1).argmax()
                passenger_color = self.depot_colors[color_idx]
                top = wall_width + p_row * (cell_width + wall_width)
                left = wall_width + p_col * (cell_width + wall_width)
                
                # Create masks using dynamic update slice
                # Create empty masks of the same shape as grid
                patch_mask = jnp.zeros((grid_height, grid_width), dtype=jnp.float32)
                x_mask = jnp.zeros((grid_height, grid_width), dtype=jnp.float32)
                
                # Use dynamic_update_slice to place the passenger patches
                patch_mask = jax.lax.dynamic_update_slice(
                    patch_mask,
                    passenger_patch,
                    (top, left)
                )
                
                x_mask = jax.lax.dynamic_update_slice(
                    x_mask,
                    passenger_x_mask,
                    (top, left)
                )
                
                # Create 3D masks for color application
                patch_mask_3d = jnp.expand_dims(patch_mask > 0, 2)
                x_mask_3d = jnp.expand_dims(x_mask > 0, 2)
                
                # Apply cyan base
                grid_with_base = grid * (1 - patch_mask_3d) + passenger_color * patch_mask_3d
                
                # Apply X marks with dimgray
                return grid_with_base * (1 - x_mask_3d) + (get_color('dimgray') / 4) * x_mask_3d
            
            return jax.lax.cond(should_draw, do_draw, lambda: grid)
        
        # Draw all passengers
        def draw_passengers_loop(i, grid):
            return jax.lax.cond(
                i < self.n_passengers,
                lambda: draw_passenger(grid, i),
                lambda: grid
            )
        
        grid = jax.lax.fori_loop(0, self.n_passengers, draw_passengers_loop, grid)
        
        # ----- DRAW TAXI -----
        # Function to create taxi patch
        def create_taxi_patches():
            # Create outline patch (hollow taxi)
            taxi_outline = jnp.zeros((cell_width, cell_width), dtype=jnp.float32)
            
            # Set borders
            taxi_outline = taxi_outline.at[:depot_width//2+1, :].set(1.0)
            taxi_outline = taxi_outline.at[-depot_width//2-1:, :].set(1.0)
            taxi_outline = taxi_outline.at[:, :depot_width//2+1].set(1.0)
            taxi_outline = taxi_outline.at[:, -depot_width//2-1:].set(1.0)
            
            # Create clean corners (make them match the original)
            taxi_outline = taxi_outline.at[:depot_width//2, :].set(0.0)
            taxi_outline = taxi_outline.at[-depot_width//2:, :].set(0.0)
            taxi_outline = taxi_outline.at[:, :depot_width//2].set(0.0)
            taxi_outline = taxi_outline.at[:, -depot_width//2:].set(0.0)
            
            

            # Create passenger patch (plus sign)
            passenger_patch = jnp.zeros((cell_width, cell_width), dtype=jnp.float32)
            start_idx = (cell_width - character_width) // 2
            end_idx = start_idx + character_width
            
            # Create a center square for cyan
            center_mask = jnp.zeros((cell_width, cell_width), dtype=jnp.float32)
            center_mask = center_mask.at[start_idx:end_idx, start_idx:end_idx].set(1.0)
            
            # Create a plus sign
            plus_mask = jnp.zeros((cell_width, cell_width), dtype=jnp.float32)
            center = cell_width // 2
            plus_mask = plus_mask.at[start_idx:end_idx, center].set(1.0)
            plus_mask = plus_mask.at[center, start_idx:end_idx].set(1.0)
            
            return taxi_outline, center_mask, plus_mask
        
        # Create taxi patches once
        taxi_outline, passenger_center, passenger_plus = create_taxi_patches()
        
        # Function to draw the taxi
        def draw_taxi(grid):
            row, col = state.taxi_pos
            top = wall_width + row * (cell_width + wall_width)
            left = wall_width + col * (cell_width + wall_width)
            
            # Create empty masks of the same shape as grid
            outline_mask = jnp.zeros((grid_height, grid_width), dtype=jnp.float32)
            center_mask = jnp.zeros((grid_height, grid_width), dtype=jnp.float32)
            plus_mask = jnp.zeros((grid_height, grid_width), dtype=jnp.float32)
            
            # Use dynamic_update_slice to place the taxi patches
            outline_mask = jax.lax.dynamic_update_slice(
                outline_mask,
                taxi_outline,
                (top, left)
            )
            
            center_mask = jax.lax.dynamic_update_slice(
                center_mask,
                passenger_center,
                (top, left)
            )
            
            plus_mask = jax.lax.dynamic_update_slice(
                plus_mask,
                passenger_plus,
                (top, left)
            )
            
            # Create 3D masks
            outline_mask_3d = jnp.expand_dims(outline_mask > 0, 2)
            center_mask_3d = jnp.expand_dims(center_mask > 0, 2)
            plus_mask_3d = jnp.expand_dims(plus_mask > 0, 2)
            
            # Draw taxi outline
            grid = grid * (1 - outline_mask_3d) + (get_color('dimgray') / 4) * outline_mask_3d
            
            # Check if passenger in taxi
            has_passenger = state.passenger_in_taxi >= 0
            
            # get color of passenger
            color_idx = jnp.all(self.depot_locs == state.passengers[state.passenger_in_taxi, 2:4][None], axis=1).argmax()
            passenger_color = self.depot_colors[color_idx]

            def draw_passenger_in_taxi():
                # First add the cyan center
                grid_with_center = grid * (1 - center_mask_3d) + passenger_color * center_mask_3d
                # Then add the plus mark
                return grid_with_center * (1 - plus_mask_3d) + (get_color('dimgray') / 4) * plus_mask_3d
            
            return jax.lax.cond(has_passenger, draw_passenger_in_taxi, lambda: grid)
        
        grid = draw_taxi(grid)
        
        # ----- ADD BORDER -----
        # Function to add border
        def add_border(grid):
            pad_top, pad_bot = border_widths
            padded_height = grid_height + pad_top + pad_bot
            padded_width = grid_width + pad_top + pad_bot
            
            # Create a white border
            bordered_grid = jnp.ones((padded_height, padded_width, 3))
            
            # Check if passenger in taxi
            has_passenger = state.passenger_in_taxi >= 0
            color_idx = jnp.all(self.depot_locs == state.passengers[state.passenger_in_taxi, 2:4][None], axis=1).argmax()
            passenger_color = self.depot_colors[color_idx]
            def add_colored_border():
                passenger_idx = state.passenger_in_taxi
                # Use modulo to get the depot color index directly
                depot_color = passenger_color
                
                # Create a checkerboard pattern
                dash_width = 4
                y_coords, x_coords = jnp.mgrid[:padded_height, :padded_width]
                checkerboard = ((y_coords // dash_width) % 2) ^ ((x_coords // dash_width) % 2)
                checkerboard_mask = jnp.expand_dims(checkerboard, 2)
                
                # Apply colored border
                bordered_grid = jnp.ones((padded_height, padded_width, 3)) * (1 - checkerboard_mask)
                bordered_grid = bordered_grid + jnp.expand_dims(depot_color, (0, 1)) * checkerboard_mask
                
                # Insert original grid in center
                mask = jnp.zeros((padded_height, padded_width))
                mask = mask.at[pad_top:pad_top+grid_height, pad_top:pad_top+grid_width].set(1)
                mask = jnp.expand_dims(mask, 2)
                
                return bordered_grid * (1 - mask) + jnp.pad(grid, ((pad_top, pad_bot), (pad_top, pad_bot), (0, 0))) * mask
            
            def add_white_border():
                # Just insert original grid in center
                return bordered_grid.at[pad_top:pad_top+grid_height, pad_top:pad_top+grid_width].set(grid)
            
            return jax.lax.cond(has_passenger, add_colored_border, add_white_border)
        
        grid = add_border(grid)
        
        # ----- RESIZE IF NEEDED -----
        target_height, target_width = 64, 64
        
        # Check if resize needed
        def resize_grid():
            # Calculate padding
            pad_height = target_height - grid.shape[0]
            pad_width = target_width - grid.shape[1]
            
            pad_top = pad_height // 2
            pad_bot = pad_height - pad_top
            pad_left = pad_width // 2
            pad_right = pad_width - pad_left
            
            pad_width = ((max(0, pad_top), max(0, pad_bot)), 
                          (max(0, pad_left), max(0, pad_right)), 
                          (0, 0))
            
            return jnp.pad(grid, pad_width, mode='constant', constant_values=1.0)
        
        grid = jax.lax.cond(
            (grid.shape[0] != target_height) | (grid.shape[1] != target_width),
            lambda: resize_grid(),
            lambda: grid
        )
        grid = jax.image.resize(grid, (self.img_size, self.img_size, 3), method='bilinear')
        return grid

    def _pickup_passenger(self, state: TaxiState, taxi_pos: chex.Array) -> TaxiState:
        """Mark a passenger standing on the taxi's cell as riding in the taxi.

        Args:
            state: Current environment state.
            taxi_pos: Taxi position, compared against columns 0-1 of every
                row of ``state.passengers``.

        Returns:
            A copy of ``state`` whose ``passenger_in_taxi`` is the index of a
            passenger located at ``taxi_pos``, or -1 when nobody is there.
            If several passengers share the cell, the highest index wins.
            The previous ``passenger_in_taxi`` value is overwritten.
        """
        # Row-wise comparison of each passenger's current cell with the taxi's.
        shares_cell = (state.passengers[:, :2] == taxi_pos).all(axis=1)
        # Matching rows keep their own index, all others become the -1
        # sentinel, so the maximum is the largest matching index (or -1).
        candidate_ids = jnp.where(shares_cell, jnp.arange(self.n_passengers), -1)
        return state.replace(passenger_in_taxi=candidate_ids.max())

    def _dropoff_passenger(self, state: TaxiState, taxi_pos: chex.Array) -> TaxiState:
        """Release the carried passenger when a drop-off is permitted here.

        A drop-off is permitted when ``taxi_pos`` equals the carried
        passenger's goal (columns 2-3 of its row in ``state.passengers``) or
        when ``self.allow_dropoff_anywhere`` is set.

        Args:
            state: Current environment state.
            taxi_pos: Taxi position to compare with the passenger's goal.

        Returns:
            A copy of ``state`` with ``passenger_in_taxi`` set to -1 if the
            drop-off happened, otherwise left at its current value. The
            ``passengers`` array itself is not modified by this method.
        """
        carried_idx = state.passenger_in_taxi
        goal_pos = state.passengers[carried_idx, 2:4]

        # Either the taxi is on the goal cell, or drop-offs are unrestricted.
        can_release = jnp.logical_or(
            jnp.all(taxi_pos == goal_pos),
            self.allow_dropoff_anywhere,
        )

        # Only the "in taxi" marker changes; -1 means the taxi is empty.
        return state.replace(
            passenger_in_taxi=jnp.where(can_release, -1, carried_idx)
        )

    def _check_goal(self, state: TaxiState) -> bool:
        """Report whether the episode goal has been reached.

        The goal holds when the taxi is empty (``passenger_in_taxi == -1``)
        and every row of ``state.passengers`` has its current position
        (columns 0-1) equal to its destination (columns 2-3).

        Args:
            state: Current environment state.

        Returns:
            Boolean scalar that is true only when both conditions hold.
        """
        current_pos = state.passengers[:, :2]
        goal_pos = state.passengers[:, 2:4]

        # A single reduction over all coordinates of all passengers.
        everyone_delivered = jnp.all(current_pos == goal_pos)
        taxi_is_empty = state.passenger_in_taxi == -1

        return taxi_is_empty & everyone_delivered

    

def sampling_policy(
    state: TaxiState,
    key: chex.PRNGKey,
    params: Optional[TaxiParams] = None,
    random_dropoff_prob: float = 0.5,
    target_pos: Optional[chex.Array] = None,
):
    """Sampling policy that fetches a random passenger and delivers them.

    With an empty taxi the policy heads for a uniformly chosen passenger that
    is not yet at its destination (or takes a uniformly random action when no
    such passenger exists). With a passenger on board it heads either for the
    passenger's destination or, with probability ``random_dropoff_prob``, for
    a randomly chosen depot.

    Args:
        state: Current taxi state.
        key: PRNGKey for random sampling.
        params: Environment parameters; defaults to ``TaxiParams()``.
        random_dropoff_prob: Probability of heading to a random depot instead
            of the passenger's destination.
        target_pos: Optional target position that overrides every other
            decision; the taxi simply moves towards it.

    Returns:
        action: The action to take (0-5).
        new_key: Updated PRNGKey.
    """
    # Use default params if none provided
    if params is None:
        params = TaxiParams()

    # Generate keys for random sampling
    key, subkey1, subkey2, subkey3 = jax.random.split(key, 4)

    # If target position is provided, move towards it regardless of other considerations
    if target_pos is not None:
        return move_to_position(state, target_pos, params.size), key

    # NOTE: this is a Python-level branch, so ``state.passenger_in_taxi`` must
    # be a concrete value here (the function cannot be traced through it).
    if state.passenger_in_taxi == -1:
        # Passengers whose current position differs from their destination
        passengers_not_at_dest = ~jnp.all(
            state.passengers[:, :2] == state.passengers[:, 2:4], axis=1
        )
        has_valid_passengers = jnp.any(passengers_not_at_dest)

        def pick_random_passenger():
            # lax.cond traces both branches, so this code must stay valid even
            # when no passenger is eligible. Sampling from masked logits keeps
            # a static shape and never raises on an all-masked input, unlike
            # drawing from a boolean-filtered (possibly empty) index array.
            logits = jnp.where(passengers_not_at_dest, 0.0, -jnp.inf)
            passenger_idx = jax.random.categorical(subkey1, logits)
            # Move towards the chosen passenger's current position
            return move_to_position(
                state, state.passengers[passenger_idx, :2], params.size
            )

        def random_action():
            return jax.random.randint(subkey1, shape=(), minval=0, maxval=6)

        # Fall back to a random action when every passenger is already delivered
        action = jax.lax.cond(
            has_valid_passengers,
            pick_random_passenger,
            random_action,
        )
    else:
        # We have a passenger in taxi, choose destination
        passenger_idx = state.passenger_in_taxi

        # Decide between the intended destination and a random depot
        use_random_target = jax.random.uniform(subkey2) < random_dropoff_prob

        # Intended destination is stored in columns 2-3 of the passenger row
        intended_dest = state.passengers[passenger_idx, 2:4]

        def get_random_depot():
            # Build an environment only to read its depot layout
            env = TaxiGymnax(size=params.size, n_passengers=params.n_passengers)
            num_depots = env.num_depots

            # Choose a random depot index
            random_idx = jax.random.randint(
                subkey3, shape=(), minval=0, maxval=num_depots
            )

            # dynamic_slice accepts a traced index; [0] drops the leading axis
            return jax.lax.dynamic_slice(env.depot_locs, (random_idx, 0), (1, 2))[0]

        # Choose target position based on random_dropoff_prob
        target_pos = jax.lax.cond(
            use_random_target,
            get_random_depot,
            lambda: intended_dest,
        )

        # Move towards target position
        action = move_to_position(state, target_pos, params.size)

    return action, key

def move_to_position(state: TaxiState, target_pos: chex.Array, env_size: int = 5) -> jnp.int32:
    """Pick the next action that steers the taxi towards ``target_pos``.

    Action codes used here: 0 noop, 1 right, 2 left, 3 up, 4 down, 5 interact.

    When the taxi already stands on the target the result is interact (5).
    Otherwise the axis with the larger absolute gap is tried first (ties go
    to the horizontal axis); if that move is blocked according to
    ``check_wall_for_action`` the other axis is tried, and if both are
    blocked the result is noop (0).

    Args:
        state: Current taxi state.
        target_pos: Target position [row, col].
        env_size: Size of the environment grid.

    Returns:
        action: The action to take (0-5).
    """
    noop = jnp.int32(0)
    interact = jnp.int32(5)

    def steer():
        gap = target_pos - state.taxi_pos
        row_gap, col_gap = gap[0], gap[1]

        # Candidate move along each axis; a zero gap falls back to left / up.
        h_action = jnp.where(col_gap > 0, jnp.int32(1), jnp.int32(2))
        v_action = jnp.where(row_gap > 0, jnp.int32(4), jnp.int32(3))

        blocked_h = check_wall_for_action(state, h_action, env_size)
        blocked_v = check_wall_for_action(state, v_action, env_size)

        # Preferred axis first, then the other axis, then give up with noop.
        via_horizontal = jnp.where(
            blocked_h, jnp.where(blocked_v, noop, v_action), h_action
        )
        via_vertical = jnp.where(
            blocked_v, jnp.where(blocked_h, noop, h_action), v_action
        )

        # The larger gap decides the leading axis; an axis with zero gap
        # hands the lead over to the other one.
        horizontal_major = jnp.abs(col_gap) >= jnp.abs(row_gap)
        lead_horizontal = jnp.where(horizontal_major, col_gap != 0, row_gap == 0)
        return jnp.where(lead_horizontal, via_horizontal, via_vertical)

    # Interact covers both pickup and drop-off once the taxi is on the target.
    on_target = jnp.all(state.taxi_pos == target_pos)
    return jax.lax.cond(on_target, lambda: interact, steer)

def check_wall_for_action(state: TaxiState, action: jnp.int32, env_size: int = 5) -> jnp.bool_:
    """Tell whether the taxi is blocked when taking ``action`` from its cell.

    A move counts as blocked when the destination cell lies outside the
    ``env_size`` x ``env_size`` grid, or when ``TaxiGymnax.has_wall`` reports
    a wall for the taxi position and the action's offset.

    Args:
        state: Current taxi state.
        action: Action to check (1-4); 0 and 5 map to a zero offset.
        env_size: Size of the environment grid.

    Returns:
        has_wall: Whether there is a wall (or the grid edge) in that direction.
    """
    # (row, col) displacement per action code; the comment on the original
    # table says this order must match the step function.
    displacement_table = jnp.array((
        (0, 0),    # 0: noop
        (0, 1),    # 1: right
        (0, -1),   # 2: left
        (-1, 0),   # 3: up
        (1, 0),    # 4: down
        (0, 0),    # 5: interact
    ))

    # Dynamic lookup so that ``action`` may be a traced value.
    step = jax.lax.dynamic_index_in_dim(
        displacement_table, action, axis=0, keepdims=False
    )

    # A throwaway environment of the right size supplies the wall layout.
    blocked_by_wall = TaxiGymnax(size=env_size).has_wall(state.taxi_pos, step)

    # The destination must have both coordinates inside [0, env_size).
    destination = state.taxi_pos + step
    inside_grid = jnp.all((destination >= 0) & (destination < env_size))

    return jnp.logical_or(jnp.logical_not(inside_grid), blocked_by_wall)