import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import chex
import argparse
from omegaconf import OmegaConf as oc
from utils.printarr import printarr
from functools import partial
import matplotlib
# Use Agg backend to avoid display issues
matplotlib.use('Agg')

from datasets import TransitionData
from typing import List


@partial(jax.vmap, in_axes=(0,))
def render_state_to_image(state, image_size=40):
    """Render 2D points as small RGB gridworld images.

    Wrapped in ``jax.vmap`` over axis 0 of ``state``: pass an array of shape
    ``(batch, 2)`` and receive ``(batch, image_size, image_size, 3)``. Each image
    has a white background, grey grid lines every ``image_size // 8`` pixels, a
    black one-pixel frame, and a red disc of radius 1 at the point's location.

    Args:
        state (chex.Array): per-example coordinates ``(x, y)``, expected in
            ``[0, 1]``; pixel indices are clipped to the image.
        image_size (int): side length of the square output image in pixels.

    Returns:
        chex.Array: ``uint8`` RGB image of shape ``(image_size, image_size, 3)``
        per example.
    """
    grey = jnp.array([128, 128, 128], dtype=jnp.uint8)
    black = jnp.array([0, 0, 0], dtype=jnp.uint8)
    red = jnp.array([255, 0, 0], dtype=jnp.uint8)

    canvas = jnp.full((image_size, image_size, 3), 255, dtype=jnp.uint8)

    # Grey grid: seven interior lines along each axis, image_size // 8 apart.
    n_cells = 8
    spacing = image_size // n_cells
    line_idx = jnp.arange(1, n_cells) * spacing
    canvas = canvas.at[line_idx, :].set(grey)
    canvas = canvas.at[:, line_idx].set(grey)

    # Black frame on the outermost rows and columns, painted over the grid.
    edge_idx = jnp.array([0, image_size - 1])
    canvas = canvas.at[edge_idx, :].set(black)
    canvas = canvas.at[:, edge_idx].set(black)

    # Scale coordinates to integer pixel indices that stay inside the image.
    last = image_size - 1
    px = jnp.clip((state[0] * last).astype(jnp.int32), 0, last)
    py = jnp.clip((state[1] * last).astype(jnp.int32), 0, last)

    # Red disc of radius 1: state[0] picks the column, state[1] the row.
    radius = 1
    rows, cols = jnp.meshgrid(jnp.arange(image_size), jnp.arange(image_size), indexing="ij")
    inside = (cols - px) ** 2 + (rows - py) ** 2 <= radius ** 2
    return jnp.where(inside[..., None], red, canvas)

@partial(jax.vmap, in_axes=(None, 0, None))
def generate_2d_gridworld_dataset(horizon, rng, pixel_obs=True):
    """Generate random-walk trajectories in a continuous 2D gridworld.

    Wrapped in ``jax.vmap`` over ``rng`` (``horizon`` and ``pixel_obs`` are
    shared), so a batch of keys yields one trajectory per key.

    Each trajectory starts at a uniform point in ``[0, 1]^2`` and takes
    ``horizon`` uniformly sampled actions from {no-op, +x, -x, +y, -y}. Each
    action moves by ``1/25`` plus Gaussian noise of scale ``1/2500``, and
    positions are clipped to ``[0, 1]``. A transition is rewarded (and marked
    done) when its next state is strictly closer than ``1/25`` to ``(0.5, 0.5)``.

    Args:
        horizon (int): number of transitions per trajectory.
        rng (jax.random.PRNGKey): per-trajectory random key (the mapped axis).
        pixel_obs (bool): if True, observations are ``uint8`` RGB images from
            ``render_state_to_image`` (40x40 at its default size); otherwise
            the raw 2D state vectors.

    Returns:
        tuple: (TransitionData, EnvConfig). ``obs``/``state`` hold the state
        *before* each action; ``reward``/``done`` refer to the state after it.
    """
    step_size = 1 / 25
    n_actions = 5
    action_effects = jnp.array([(0, 0), (1, 0), (-1, 0), (0, 1), (0, -1)])
    rng_rollout, rng_state, rng_action = jax.random.split(rng, 3)
    state = jax.random.uniform(rng_state, (2,))
    actions = jax.random.randint(rng_action, (horizon,), 0, n_actions)
    goal = jnp.array((0.5, 0.5))

    def _step(carry, action):
        obs, key = carry
        key, key_noise = jax.random.split(key)
        next_obs = action_effects[action] * step_size + obs + jax.random.normal(key_noise, obs.shape) * (step_size / 100)
        next_obs = jnp.clip(next_obs, 0, 1)
        return (next_obs, key), next_obs

    _, states = jax.lax.scan(_step, (state, rng_rollout), actions)
    # Prepend the initial state so states has horizon + 1 entries.
    states = jnp.concatenate([state[None], states], axis=0)

    if pixel_obs:
        # render_state_to_image is itself vmapped over its leading axis, so it
        # must receive the whole (horizon, 2) array. Feeding it one (2,) row at
        # a time would make it map over the two coordinates as scalars.
        observations = render_state_to_image(states[:-1])
    else:
        observations = states[:-1]

    # Reward when the *next* state lies inside a radius-step_size circle around the goal.
    rewards = ((((states[1:] - goal) ** 2).sum(-1) - step_size ** 2) < 0).astype(jnp.float32)
    dones = rewards

    @chex.dataclass(frozen=True)
    class EnvConfig:
        obs: chex.Array  # example observation (the first one)
        n_actions: int

    return TransitionData(
        obs=observations,
        action=actions,
        reward=rewards,
        done=dones,
        is_first=jnp.zeros_like(rewards),
        state=states[:-1]
    ), EnvConfig(n_actions=n_actions, obs=observations[0])
    
def sample_l_shaped_initial_state(rng, thickness=0.3):
    """Draw one point uniformly from an L-shaped subset of the unit square.

    The L is the union of a vertical bar ``[0, thickness] x [0, 1]`` and a
    horizontal bar ``[thickness, 1] x [0, thickness]``. Their areas are
    ``thickness`` and ``thickness * (1 - thickness)``. Choosing the vertical
    bar with probability ``1 / (2 - thickness)`` therefore makes the draw
    uniform over the whole L.

    Args:
        rng (jax.random.PRNGKey): random key.
        thickness (float): width of both bars of the L.

    Returns:
        chex.Array: shape ``(2,)`` array ``[x, y]``.
    """
    rng, rng_choice = jax.random.split(rng)
    coin = jax.random.uniform(rng_choice)
    prob_vertical = 1.0 / (2 - thickness)

    def _draw(key, x_lo, x_hi, y_lo, y_hi):
        # Only the second and third subkeys are consumed.
        _, key_x, key_y = jax.random.split(key, 3)
        return jnp.array([
            jax.random.uniform(key_x, (), minval=x_lo, maxval=x_hi),
            jax.random.uniform(key_y, (), minval=y_lo, maxval=y_hi),
        ])

    return jax.lax.cond(
        coin < prob_vertical,
        lambda key: _draw(key, 0.0, thickness, 0.0, 1.0),
        lambda key: _draw(key, thickness, 1.0, 0.0, thickness),
        rng,
    )

def sample_complement_state(rng, thickness=0.3):
    """Draw one point uniformly from the square ``[thickness, 1] x [thickness, 1]``.

    This square is the part of the unit square not covered by the L-shaped
    region ``[0, thickness] x [0, 1]`` union ``[thickness, 1] x [0, thickness]``.

    Args:
        rng (jax.random.PRNGKey): random key.
        thickness (float): width of the L-shaped region being excluded.

    Returns:
        chex.Array: shape ``(2,)`` array ``[x, y]``.
    """
    # Three-way split; the first subkey is not used.
    _, key_x, key_y = jax.random.split(rng, 3)
    coords = [
        jax.random.uniform(key, (), minval=thickness, maxval=1.0)
        for key in (key_x, key_y)
    ]
    return jnp.array(coords)

# generate_l_shaped_dataset takes a 'thickness' parameter that sets the width of the L-shaped start region.
def generate_l_shaped_dataset(horizon, rng, thickness=0.3, pixel_obs=True, complement=False):
    """Generate one random-walk trajectory starting in (or outside) an L-shaped region.

    The start state comes from ``sample_l_shaped_initial_state``, or from
    ``sample_complement_state`` when ``complement`` is True. The walk then
    takes ``horizon`` uniformly sampled actions from {no-op, +x, -x, +y, -y}.
    Each action moves by ``1/25`` plus Gaussian noise of scale ``1/2500``, and
    positions are clipped to ``[0, 1]``. A transition is rewarded (and marked
    done) when its next state is strictly closer than ``1/25`` to ``(0.5, 0.5)``.

    Args:
        horizon (int): number of transitions.
        rng (jax.random.PRNGKey): random key.
        thickness (float): width of the L-shaped start region.
        pixel_obs (bool): if True, observations are images from
            ``render_state_to_image`` scaled to float ``[0, 1]``, with additive
            Gaussian pixel noise (std 0.05) and clipped back to ``[0, 1]``.
            Otherwise the raw 2D states are used.
        complement (bool): if True, start from the complement of the L region.

    Returns:
        tuple: (TransitionData, EnvConfig). ``obs``/``state`` hold the state
        *before* each action; ``reward``/``done`` refer to the state after it.
    """
    step_size = 1/25
    n_actions = 5
    moves = jnp.array([(0,0), (1,0), (-1,0), (0,1), (0,-1)])
    goal = jnp.array((0.5, 0.5))

    # Five-way split; the first subkey is not used.
    _, key_rollout, key_action, key_init, key_pixel = jax.random.split(rng, 5)

    sampler = sample_complement_state if complement else sample_l_shaped_initial_state
    start = sampler(key_init, thickness)
    actions = jax.random.randint(key_action, (horizon,), 0, n_actions)

    def _advance(carry, action):
        pos, key = carry
        key, key_noise = jax.random.split(key)
        jitter = jax.random.normal(key_noise, pos.shape) * step_size/100
        nxt = jnp.clip(moves[action] * step_size + pos + jitter, 0, 1)
        return (nxt, key), nxt

    _, next_states = jax.lax.scan(_advance, (start, key_rollout), actions)

    # Rewards are computed from the post-action states only.
    sq_dist = ((next_states - goal[None]) ** 2).sum(-1)
    rewards = (sq_dist - step_size**2 < 0).astype(jnp.float32)
    dones = rewards

    states = jnp.concatenate([start[None], next_states], axis=0)
    prev_states = states[:-1]

    if not pixel_obs:
        observations = prev_states
    else:
        frames = render_state_to_image(prev_states).astype(jnp.float32) / 255.0
        pixel_noise = jax.random.normal(key_pixel, frames.shape) * 0.05
        observations = jnp.clip(frames + pixel_noise, 0.0, 1.0)

    @chex.dataclass(frozen=True)
    class EnvConfig:
        obs: chex.Array  # example observation (the first one)
        n_actions: int

    transitions = TransitionData(
        obs=observations,
        action=actions,
        reward=rewards,
        done=dones,
        is_first=jnp.zeros_like(dones),
        state=prev_states
    )
    return transitions, EnvConfig(n_actions=n_actions, obs=observations[0])
    

def render_objects_to_image(
        objects,
        image_size=40
):
    '''Render coloured shapes onto a gridworld background.

    Args:
        objects: (n_objects, 3) array of (x, y, shape_type). x selects the
            column and y the row; both are expected in [0, 1], and pixel
            indices are clipped to the image. shape_type 0 draws a red circle,
            1 a green square, and any other value a triangle. The colour is
            looked up with shape_type cast to int, with index 2 = blue;
            shape_type is expected to be 0, 1 or 2.
        image_size: side length of the square output image in pixels.

    Returns:
        uint8 RGB image of shape (image_size, image_size, 3). Objects are drawn
        in row order, so later objects paint over earlier ones.
    '''
    grey = jnp.array([128, 128, 128], dtype=jnp.uint8)
    black = jnp.array([0, 0, 0], dtype=jnp.uint8)

    # White RGB background.
    img = jnp.full((image_size, image_size, 3), 255, dtype=jnp.uint8)

    num_cells = 8
    cell_size = image_size // num_cells

    # Grey grid lines (rows, then columns).
    for i in range(1, num_cells):
        img = img.at[i * cell_size, :].set(grey)
    for i in range(1, num_cells):
        img = img.at[:, i * cell_size].set(grey)

    # Black one-pixel frame.
    img = img.at[0, :].set(black)
    img = img.at[image_size - 1, :].set(black)
    img = img.at[:, 0].set(black)
    img = img.at[:, image_size - 1].set(black)

    colors = jnp.array([
        [255, 0, 0],    # Red circle
        [0, 255, 0],    # Green square
        [0, 0, 255]     # Blue triangle
    ], dtype=jnp.uint8)

    # Pixel index grids are the same for every object, so build them once.
    yy, xx = jnp.meshgrid(jnp.arange(image_size), jnp.arange(image_size), indexing="ij")
    r = 2     # circle radius in pixels
    size = 2  # half-extent of square / triangle in pixels

    def render_object(img, obj):
        x, y, shape_type = obj[0], obj[1], obj[2]
        x_pixel = jnp.clip((x * (image_size - 1)).astype(jnp.int32), 0, image_size - 1)
        y_pixel = jnp.clip((y * (image_size - 1)).astype(jnp.int32), 0, image_size - 1)
        dx = xx - x_pixel
        dy = yy - y_pixel

        shape_mask = jax.lax.cond(
            shape_type == 0,
            lambda _: dx**2 + dy**2 <= r**2,
            lambda _: jax.lax.cond(
                shape_type == 1,
                lambda _: (jnp.abs(dx) <= size) & (jnp.abs(dy) <= size),
                lambda _: (dy >= -size) & (jnp.abs(dx) <= size) & (dy <= size - jnp.abs(dx)),
                None
            ),
            None
        )
        img = jnp.where(shape_mask[..., None], colors[shape_type.astype(jnp.int32)], img)
        # Only the final image is used. Emitting no per-step output avoids
        # stacking an (n_objects, H, W, 3) history that would be discarded.
        return img, None

    img, _ = jax.lax.scan(render_object, img, objects)

    return img

def generate_multi_object_dataset(
        horizon,
        rng,
        n_objects=3,
        img_size=40,
        pixel_obs=True
):
    """Generate one random trajectory in which several shapes are pushed around.

    Objects start at uniform positions in ``[0.1, 0.9]^2``. Their shape types
    cycle 0, 1, 2, 0, ... (circle, square, triangle). There are
    ``4 * n_objects + 1`` actions, sampled uniformly. Action 0 is a no-op.
    Action ``a > 0`` moves object ``(a - 1) // 4`` in direction
    ``(a - 1) % 4`` of (+x, -x, +y, -y), adding
    ``effect * (step_size + N(0, 1) * step_size / 100)`` and clipping to
    ``[0, 1]``. A transition is rewarded (and marked done) when, after the
    step, the mean distance of all objects to ``(0.5, 0.5)`` is below
    ``step_size``.

    Args:
        horizon (int): number of transitions.
        rng (jax.random.PRNGKey): random key.
        n_objects (int): number of objects.
        img_size (int): side length of rendered images.
        pixel_obs (bool): if True, observations are images rendered with
            ``render_objects_to_image`` as float32 in ``[0, 1]``. Otherwise
            they are the flattened ``(x, y, shape)`` state vectors.

    Returns:
        tuple: (TransitionData, EnvConfig). ``EnvConfig.state_names`` labels
        the columns of the flattened state.
    """
    step_size = 1 / 25
    # Four directions per object plus one shared no-op action.
    n_actions = n_objects * 4 + 1
    action_effects = jnp.array([(0, 0)] + [(1, 0), (-1, 0), (0, 1), (0, -1)] * n_objects)

    # Five-way split; only the first three subkeys are used.
    rng_rollout, rng_state, rng_action, _, _ = jax.random.split(rng, 5)

    # Start positions are kept away from the edges.
    positions = jax.random.uniform(rng_state, (n_objects, 2), minval=0.1, maxval=0.9)
    n_shapes = 3
    shape_indices = jnp.arange(n_objects) % n_shapes
    states = jnp.concatenate([positions, shape_indices[:, None]], axis=1)

    actions = jax.random.randint(rng_action, (horizon,), 0, n_actions)
    goal = jnp.array((0.5, 0.5))

    def _step(carry, action):
        objects, key = carry
        key, key_noise = jax.random.split(key)

        def update_object():
            obj_idx = (action - 1) // 4
            moved = objects[obj_idx, :2] + action_effects[action] * (
                step_size + jax.random.normal(key_noise, (2,)) * (step_size / 100))
            return objects.at[obj_idx, :2].set(jnp.clip(moved, 0, 1))

        # Action 0 is a no-op; every other action moves exactly one object.
        new_objects = jax.lax.cond(action == 0, lambda: objects, update_object)
        return (new_objects, key), new_objects

    _, states_history = jax.lax.scan(_step, (states, rng_rollout), actions)
    states_history = jnp.concatenate([states[None], states_history], axis=0)
    flattened_states = states_history.reshape(states_history.shape[0], -1)

    # NOTE(review): obs and state cover all horizon + 1 configurations, while
    # action/reward/done/is_first have length horizon. The sibling generators
    # drop the final configuration. Kept as-is because callers may rely on
    # it — confirm which is intended.
    if pixel_obs:
        observations = jax.vmap(render_objects_to_image, in_axes=(0, None))(states_history, img_size)
        observations = observations.astype(jnp.float32) / 255.0
    else:
        observations = flattened_states

    # Reward when the mean object-to-goal distance after the step is below step_size.
    distances = jnp.sqrt(((states_history[1:, :, :2] - goal[None, None]) ** 2).sum(-1))
    rewards = (distances.mean(axis=1) < step_size).astype(jnp.float32)
    dones = rewards

    @chex.dataclass(frozen=True)
    class EnvConfig:
        obs: chex.Array  # example observation (the first one)
        n_actions: int
        state_names: List[str]  # one label per flattened state column

    state_names = [
        f"object_{i}_{field}" for i in range(n_objects) for field in ("x", "y", "shape")
    ]

    return TransitionData(
        obs=observations,
        action=actions,
        reward=rewards,
        done=dones,
        is_first=jnp.zeros_like(rewards),
        state=flattened_states
    ), EnvConfig(n_actions=n_actions, obs=observations[0], state_names=state_names)

def generate_multi_object_dataset_with_selection(horizon, rng, n_objects=3, img_size=40):
    """Generate one trajectory where a "cursor" selects which object to move.

    Actions (sampled uniformly):
        * 0 is a no-op.
        * 1-4 move the selected object by (+x, -x, +y, -y) * ``step_size``
          plus Gaussian noise of scale ``step_size / 100``; its position is
          then clipped to ``[0, 1]``.
        * 5 selects the previous object and 6 the next one, wrapping around.

    Observations are images from ``render_objects_to_image`` with a black ring
    around the selected object, as float32 in ``[0, 1]``. A transition is
    rewarded (and marked done) when, after the step, the mean distance of all
    objects to ``(0.5, 0.5)`` is below ``step_size``.

    Args:
        horizon (int): number of transitions.
        rng (jax.random.PRNGKey): random key.
        n_objects (int): number of objects.
        img_size (int): side length of rendered images.

    Returns:
        tuple: (TransitionData, EnvConfig). The state is the flattened
        ``(x, y, shape)`` of every object plus the selected index scaled by
        ``1 / max(n_objects - 1, 1)``.
    """
    step_size = 1 / 25
    n_actions = 7
    action_effects = jnp.array([(0, 0), (1, 0), (-1, 0), (0, 1), (0, -1)])

    # The fourth subkey seeds the initial selection. Sampling the selection
    # from rng_action, which also draws the actions, would reuse a PRNG key and
    # correlate the two samples.
    rng_rollout, rng_state, rng_action, rng_select, _ = jax.random.split(rng, 5)

    positions = jax.random.uniform(rng_state, (n_objects, 2))
    n_shapes = 3
    shape_indices = jnp.arange(n_objects) % n_shapes
    states = jnp.concatenate([positions, shape_indices[:, None]], axis=1)

    selected_idx = jax.random.randint(rng_select, (), 0, n_objects)

    actions = jax.random.randint(rng_action, (horizon,), 0, n_actions)
    goal = jnp.array((0.5, 0.5))

    def _step(carry, action):
        objects, selected_idx, key = carry
        key, key_noise = jax.random.split(key)

        # Actions 5 / 6 cycle the selection backward / forward.
        new_selected_idx = jax.lax.cond(
            action == 5,
            lambda: (selected_idx - 1) % n_objects,
            lambda: jax.lax.cond(
                action == 6,
                lambda: (selected_idx + 1) % n_objects,
                lambda: selected_idx
            )
        )

        # Actions 1-4 move the selected object.
        new_objects = jax.lax.cond(
            jnp.logical_and(action >= 1, action <= 4),
            lambda: objects.at[new_selected_idx, :2].set(
                objects[new_selected_idx, :2] + action_effects[action] * step_size +
                jax.random.normal(key_noise, (2,)) * (step_size / 100)
            ),
            lambda: objects
        )

        # Only the selected object can have moved, so clip just its position.
        new_objects = new_objects.at[new_selected_idx, :2].set(
            jnp.clip(new_objects[new_selected_idx, :2], 0, 1)
        )

        return (new_objects, new_selected_idx, key), (new_objects, new_selected_idx)

    _, (states_history, selected_indices) = jax.lax.scan(
        _step,
        (states, selected_idx, rng_rollout),
        actions
    )

    # Prepend the initial configuration and selection.
    states_history = jnp.concatenate([states[None], states_history], axis=0)
    selected_indices = jnp.concatenate([selected_idx[None], selected_indices], axis=0)

    def render_with_selection(objects, selected_idx):
        img = render_objects_to_image(objects, img_size)
        x, y = objects[selected_idx, :2]
        x_pixel = jnp.clip((x * (img_size - 1)).astype(jnp.int32), 0, img_size - 1)
        y_pixel = jnp.clip((y * (img_size - 1)).astype(jnp.int32), 0, img_size - 1)

        # Black ring with radius in (2, 4] pixels around the selected object.
        ys = jnp.arange(img_size)
        xs = jnp.arange(img_size)
        yy, xx = jnp.meshgrid(ys, xs, indexing="ij")
        sq_dist = (xx - x_pixel)**2 + (yy - y_pixel)**2
        outline_mask = (sq_dist <= 4**2) & (sq_dist > 2**2)

        return jnp.where(outline_mask[..., None], jnp.array([0, 0, 0], dtype=jnp.uint8), img)

    observations = jax.vmap(render_with_selection)(states_history[:-1], selected_indices[:-1])
    observations = observations.astype(jnp.float32) / 255.0

    distances = jnp.sqrt(((states_history[1:, :, :2] - goal[None, None]) ** 2).sum(-1))
    mean_distances = distances.mean(axis=1)
    rewards = (mean_distances < step_size).astype(jnp.float32)
    dones = rewards

    @chex.dataclass(frozen=True)
    class EnvConfig:
        obs: chex.Array  # example observation (the first one)
        n_actions: int

    # Flatten objects and append the normalised selection. Guard the divisor
    # so that n_objects == 1 yields 0 instead of 0 / 0 = NaN.
    flattened_states = states_history[:-1].reshape(states_history[:-1].shape[0], -1)
    selection_scale = max(n_objects - 1, 1)
    states_with_selection = jnp.concatenate(
        [flattened_states, selected_indices[:-1, None] / selection_scale], axis=1
    )

    return TransitionData(
        obs=observations,
        action=actions,
        reward=rewards,
        done=dones,
        is_first=jnp.zeros_like(rewards),
        state=states_with_selection
    ), EnvConfig(n_actions=n_actions, obs=observations[0])

def render_taxi_to_image(taxi_pos, passenger_positions, passenger_colors, taxi_passenger, grid_size=8, image_size=40):
    '''Render a taxi-world state as an RGB image.

    Args:
        taxi_pos: (2,) array of (x, y)
        passenger_positions: (n_passengers, 2) array of (x, y) positions
        passenger_colors: (n_passengers,) array of indices into a 6-colour palette
        taxi_passenger: 1-based index of the passenger riding in the taxi (0 if empty)
        grid_size: number of cells per side of the grid (N x N)
        image_size: side length of the rendered image in pixels

    Positions map to cells via int(p * (grid_size - 1)), clipped to the grid,
    so x and y presumably lie in [0, 1] — confirm against callers. x selects
    the column and y the row.

    Returns:
        uint8 RGB image of shape (image_size, image_size, 3).
        * A passenger not in the taxi fills its whole cell with its colour.
        * The riding passenger is drawn as a small disc in the cell of its own
          stored position.
        * The taxi is drawn as a black one-pixel outline of its cell.
    '''
    grey = jnp.array([128, 128, 128], dtype=jnp.uint8)
    black = jnp.array([0, 0, 0], dtype=jnp.uint8)

    # White RGB background.
    img = jnp.full((image_size, image_size, 3), 255, dtype=jnp.uint8)

    cell_size = image_size / grid_size

    # Grey grid lines at the pixel positions of the normalised coordinates
    # 0/N, 1/N, ..., N/N (N = grid_size).
    for i in range(grid_size + 1):
        pos = jnp.array(i / grid_size)
        pixel_pos = jnp.clip((pos * (image_size - 1)).astype(jnp.int32), 0, image_size - 1)
        img = img.at[pixel_pos, :].set(grey)
        img = img.at[:, pixel_pos].set(grey)

    # Black one-pixel frame.
    img = img.at[0, :].set(black)
    img = img.at[image_size - 1, :].set(black)
    img = img.at[:, 0].set(black)
    img = img.at[:, image_size - 1].set(black)

    # Passenger colours (excluding white and black).
    colors = jnp.array([
        [255, 0, 0],    # Red
        [0, 255, 0],    # Green
        [0, 0, 255],    # Blue
        [255, 255, 0],  # Yellow
        [255, 0, 255],  # Magenta
        [0, 255, 255],  # Cyan
    ], dtype=jnp.uint8)

    # Row / column pixel index grids shared by every mask below.
    yy, xx = jnp.meshgrid(jnp.arange(image_size), jnp.arange(image_size), indexing="ij")

    def to_cell(pos):
        """Map an (x, y) position to integer (column, row) cell indices in [0, grid_size - 1]."""
        x, y = pos
        x_cell = jnp.clip((x * (grid_size - 1)).astype(jnp.int32), 0, grid_size - 1)
        y_cell = jnp.clip((y * (grid_size - 1)).astype(jnp.int32), 0, grid_size - 1)
        return x_cell, y_cell

    def square_mask(x_start, y_start, width):
        """Mask of pixels with x_start <= col < x_start + width and y_start <= row < y_start + width."""
        return (xx >= x_start) & (xx < x_start + width) & \
               (yy >= y_start) & (yy < y_start + width)

    def render_passenger(img, data):
        pos, color_idx, passenger_idx = data
        x_cell, y_cell = to_cell(pos)
        cell_mask = square_mask(x_cell * cell_size, y_cell * cell_size, cell_size)
        is_in_taxi = passenger_idx + 1 == taxi_passenger
        # A passenger riding in the taxi is drawn separately below.
        img = jax.lax.cond(
            is_in_taxi,
            lambda: img,
            lambda: jnp.where(cell_mask[..., None], colors[color_idx], img)
        )
        # Only the final image is used, so emit no per-step output.
        return img, None

    img, _ = jax.lax.scan(
        render_passenger,
        img,
        (passenger_positions, passenger_colors, jnp.arange(len(passenger_positions)))
    )

    def render_passenger_in_taxi(img, pos, color_idx):
        x_cell, y_cell = to_cell(pos)
        x_start = x_cell * cell_size
        y_start = y_cell * cell_size
        cell_mask = square_mask(x_start, y_start, cell_size)

        center_x = x_start + cell_size // 2
        center_y = y_start + cell_size // 2
        radius = cell_size // 3
        circle_mask = ((xx - center_x)**2 + (yy - center_y)**2 <= radius**2) & cell_mask
        return jnp.where(circle_mask[..., None], colors[color_idx], img)

    # taxi_passenger is 1-based; the index is only used when taxi_passenger > 0.
    rider = taxi_passenger - 1
    img = jax.lax.cond(
        taxi_passenger > 0,
        lambda: render_passenger_in_taxi(img, passenger_positions[rider], passenger_colors[rider]),
        lambda: img
    )

    # Taxi: black outline of its cell (pixel start rounded up).
    x_cell, y_cell = to_cell(taxi_pos)
    x_start = jnp.ceil(x_cell * cell_size)
    y_start = jnp.ceil(y_cell * cell_size)
    width = jnp.ceil(cell_size)
    cell_mask = square_mask(x_start, y_start, width)
    border_mask = ((xx == x_start) | (xx == x_start + width - 1) |
                   (yy == y_start) | (yy == y_start + width - 1)) & cell_mask
    img = jnp.where(border_mask[..., None], black, img)

    return img

def generate_taxi_dataset(horizon, rng, grid_size=8, n_passengers=3, img_size=40, custom_actions=None,
                         initial_taxi_pos=None, initial_passenger_positions=None, success_rate=0.7):
    """Generate one rendered trajectory of a discrete grid "taxi" environment.

    The taxi and the passengers occupy integer cells of a ``grid_size x grid_size``
    grid. The initial configuration is one of three cases, drawn uniformly:

    * case 0 -- no passenger on the taxi's cell,
    * case 1 -- one passenger on the taxi's cell but not inside the taxi,
    * case 2 -- one passenger inside the taxi (it shares the taxi's cell).

    Actions: 0 no-op; 1/2 move +1/-1 along axis 0; 3/4 move +1/-1 along axis 1
    (movement is clipped to the grid); 5 pickup; 6 dropoff. A pickup only
    succeeds with an empty taxi standing on a passenger's cell; a dropoff only
    succeeds when no other passenger occupies the taxi's cell.

    Args:
        horizon: Number of steps to roll out; ``horizon + 1`` states are returned.
        rng: PRNG key.
        grid_size: Number of cells along each grid axis.
        n_passengers: Number of passengers; must equal
            ``len(initial_passenger_positions)`` when that is given.
        img_size: Image size passed to ``render_taxi_to_image``.
        custom_actions: Optional integer array; step ``t`` applies
            ``custom_actions[t]``. When ``None``, actions come from a scripted
            mixture of random / move-towards-passenger / pickup / dropoff policies.
        initial_taxi_pos: Optional ``(2,)`` integer cell for the taxi; random if ``None``.
        initial_passenger_positions: Optional ``(n_passengers, 2)`` integer cells.
            Cases 0 and 1 use them unchanged; case 2 puts passenger 0 in the taxi
            and therefore moves it onto the taxi's cell.
        success_rate: Currently unused; kept for backward compatibility.

    Returns:
        ``(TransitionData, EnvConfig)``. ``obs`` holds the rendered states scaled
        to [0, 1] with Gaussian pixel noise (std 0.05) and clipped, ``state`` the
        flattened normalized taxi/passenger cells plus ``taxi_passenger / n_passengers``,
        and ``reward``/``done``/``is_first`` are all zeros.
    """
    n_actions = 7
    # (axis 0, axis 1) displacement for actions 0-4.
    action_effects = jnp.array([(0, 0), (1, 0), (-1, 0), (0, 1), (0, -1)])

    # One independent key per random decision (never reuse a key for two draws).
    rng_rollout, rng_taxi, rng_carried, rng_passenger, rng_noise, rng_case = jax.random.split(rng, 6)

    # ------------------------------------------------------------------ #
    # Initial state
    # ------------------------------------------------------------------ #
    # 0: passengers elsewhere, 1: one passenger on the taxi's cell, 2: one passenger in the taxi.
    case_idx = jax.random.randint(rng_case, (), 0, 3)

    if initial_taxi_pos is None:
        taxi_pos = jax.random.randint(rng_taxi, (2,), 0, grid_size)
    else:
        taxi_pos = jnp.asarray(initial_taxi_pos)

    # Passenger placed on (case 1) or inside (case 2) the taxi when positions are random.
    carried_idx = jax.random.randint(rng_carried, (), 0, n_passengers)
    no_passenger = jnp.array(0, dtype=jnp.int32)

    def sample_passenger_positions(skip_idx):
        """Put passenger ``skip_idx`` on the taxi's cell and the others on distinct free cells.

        ``skip_idx = -1`` places every passenger away from the taxi.
        """
        ids = jnp.arange(n_passengers)
        positions = jnp.where(
            (ids == skip_idx)[:, None],
            taxi_pos[None],
            jnp.zeros((n_passengers, 2), dtype=taxi_pos.dtype),
        )
        used = jnp.zeros((grid_size, grid_size), dtype=jnp.bool_)
        used = used.at[taxi_pos[0], taxi_pos[1]].set(True)

        def place(carry, i):
            placed = init_passenger_at_different_pos(*carry, i, grid_size)
            return jax.lax.cond(i == skip_idx, lambda: carry, lambda: placed), None

        (positions, _, _), _ = jax.lax.scan(place, (positions, used, rng_passenger), ids)
        return positions

    if initial_passenger_positions is None:
        init_branches = [
            lambda: (sample_passenger_positions(-1), no_passenger),
            lambda: (sample_passenger_positions(carried_idx), no_passenger),
            lambda: (sample_passenger_positions(carried_idx), (carried_idx + 1).astype(jnp.int32)),
        ]
    else:
        given_positions = jnp.asarray(initial_passenger_positions).astype(taxi_pos.dtype)
        init_branches = [
            lambda: (given_positions, no_passenger),
            lambda: (given_positions, no_passenger),
            # A passenger inside the taxi must share the taxi's cell, otherwise the
            # first step would teleport it; passenger 0 is the one placed inside.
            lambda: (given_positions.at[0].set(taxi_pos), jnp.array(1, dtype=jnp.int32)),
        ]
    passenger_positions, taxi_passenger = jax.lax.switch(case_idx, init_branches)

    passenger_colors = jnp.arange(n_passengers)

    # ------------------------------------------------------------------ #
    # Dynamics
    # ------------------------------------------------------------------ #
    def apply_action(taxi_pos, passenger_positions, taxi_passenger, action):
        """Return ``(taxi_pos, passenger_positions, taxi_passenger)`` after one action."""
        passenger_positions = passenger_positions.astype(taxi_pos.dtype)
        passenger_ids = jnp.arange(n_passengers) + 1  # taxi_passenger is 1-based; 0 = empty

        # Movement (actions 1-4), clipped to the grid. The table index is clamped so
        # non-movement actions never read past the table; the where discards them.
        is_move = (action >= 1) & (action <= 4)
        moved = jnp.clip(taxi_pos + action_effects[jnp.clip(action, 0, 4)], 0, grid_size - 1)
        new_taxi_pos = jnp.where(is_move, moved, taxi_pos).astype(taxi_pos.dtype)

        at_taxi = jnp.all(passenger_positions == new_taxi_pos, axis=1)

        # Pickup (action 5): an empty taxi takes a passenger on its cell; if several
        # share the cell, the highest-indexed one is taken.
        pickup_id = jnp.max(jnp.where(at_taxi, passenger_ids, 0), initial=0)
        after_pickup = jnp.where((taxi_passenger == 0) & (pickup_id > 0), pickup_id, taxi_passenger)

        # Dropoff (action 6): refused while any *other* passenger is on the cell.
        blocked = jnp.any(at_taxi & (passenger_ids != taxi_passenger))
        after_dropoff = jnp.where((taxi_passenger > 0) & ~blocked, 0, taxi_passenger)

        new_taxi_passenger = jnp.where(
            action == 5, after_pickup, jnp.where(action == 6, after_dropoff, taxi_passenger)
        )

        # The carried passenger travels with the taxi.
        new_passenger_positions = jnp.where(
            (passenger_ids == new_taxi_passenger)[:, None],
            new_taxi_pos[None],
            passenger_positions,
        )
        return new_taxi_pos, new_passenger_positions, new_taxi_passenger

    def sample_action(taxi_pos, passenger_positions, taxi_passenger, rng):
        """Draw an action from a uniform mixture of four scripted policies.

        0: uniformly random action; 1: move towards a random passenger;
        2: pickup if possible, else move towards; 3: dropoff if possible, else random.
        """
        rng_policy, rng_passenger, rng_action = jax.random.split(rng, 3)
        # Separate key for the "already at target" move so it is independent of
        # the passenger choice drawn from rng_passenger.
        rng_move = jax.random.fold_in(rng_passenger, 1)

        passenger_ids = jnp.arange(n_passengers) + 1
        at_taxi = jnp.all(passenger_positions == taxi_pos, axis=1)
        can_pickup = jnp.any(at_taxi) & (taxi_passenger == 0)
        can_dropoff = (taxi_passenger != 0) & ~jnp.any(at_taxi & (passenger_ids != taxi_passenger))

        random_action = jax.random.randint(rng_action, (), 0, n_actions)

        # Greedy step towards a random passenger along the larger-distance axis.
        target_pos = passenger_positions[jax.random.randint(rng_passenger, (), 0, n_passengers)]
        dx = target_pos[0] - taxi_pos[0]
        dy = target_pos[1] - taxi_pos[1]
        greedy_move = jnp.where(
            jnp.abs(dx) > jnp.abs(dy),
            jnp.where(dx > 0, 1, 2),
            jnp.where(dy > 0, 3, 4),
        )
        # Already on the target: pick a random movement (excluding no-op).
        move_towards_action = jnp.where(
            jnp.all(taxi_pos == target_pos),
            jax.random.randint(rng_move, (), 1, 5),
            greedy_move,
        )

        candidates = jnp.stack([
            random_action,
            move_towards_action,
            jnp.where(can_pickup, 5, move_towards_action),
            jnp.where(can_dropoff, 6, random_action),
        ])
        policy_selector = jax.random.randint(rng_policy, (), 0, 4)
        return candidates[policy_selector]

    # ------------------------------------------------------------------ #
    # Rollout
    # ------------------------------------------------------------------ #
    if custom_actions is not None:
        action_table = jnp.asarray(custom_actions)

        def step_with_custom_actions(state, t):
            taxi_pos, passenger_positions, taxi_passenger = state
            action = action_table[t]
            next_state = apply_action(taxi_pos, passenger_positions, taxi_passenger, action)
            return next_state, (*next_state, action)

        _, (taxi_history, passenger_positions_history, taxi_passenger_history, _) = jax.lax.scan(
            step_with_custom_actions,
            (taxi_pos, passenger_positions, taxi_passenger),
            jnp.arange(horizon),
        )
        actions = custom_actions
    else:
        def step_with_sampling(state, _):
            taxi_pos, passenger_positions, taxi_passenger, rng = state
            rng, rng_step, _rng_unused = jax.random.split(rng, 3)
            passenger_positions = passenger_positions.astype(taxi_pos.dtype)
            action = sample_action(taxi_pos, passenger_positions, taxi_passenger, rng_step)
            next_state = apply_action(taxi_pos, passenger_positions, taxi_passenger, action)
            return (*next_state, rng), (*next_state, action)

        _, (taxi_history, passenger_positions_history, taxi_passenger_history, actions) = jax.lax.scan(
            step_with_sampling,
            (taxi_pos, passenger_positions, taxi_passenger, rng_rollout),
            None,
            length=horizon,
        )

    # Prepend the initial state so histories hold horizon + 1 states.
    taxi_history = jnp.concatenate([taxi_pos[None], taxi_history], axis=0)
    passenger_positions_history = jnp.concatenate([passenger_positions[None], passenger_positions_history], axis=0)
    taxi_passenger_history = jnp.concatenate([taxi_passenger[None], taxi_passenger_history], axis=0)

    # Render states as images
    def render_state(taxi_pos, passenger_positions, passenger_colors, taxi_passenger):
        # Convert integer cells to [0, 1] coordinates for the renderer.
        taxi_pos_norm = taxi_pos.astype(jnp.float32) / (grid_size - 1)
        passenger_positions_norm = passenger_positions.astype(jnp.float32) / (grid_size - 1)
        return render_taxi_to_image(taxi_pos_norm, passenger_positions_norm, passenger_colors, taxi_passenger, grid_size, img_size)

    observations = jax.vmap(render_state, in_axes=(0, 0, None, 0))(taxi_history, passenger_positions_history, passenger_colors, taxi_passenger_history)
    observations = observations.astype(jnp.float32) / 255.0

    # Add pixel noise to observations
    noise = jax.random.normal(rng_noise, observations.shape) * 0.05  # 5% noise
    observations = jnp.clip(observations + noise, 0.0, 1.0)

    # No reward/termination signal is modelled yet.
    rewards = jnp.zeros(horizon+1)
    dones = jnp.zeros(horizon+1, dtype=jnp.bool_)

    @chex.dataclass(frozen=True)
    class EnvConfig:
        obs: chex.Array
        n_actions: int
        grid_size: int

    # Normalize states only at the end
    normalized_taxi_history = taxi_history.astype(jnp.float32) / (grid_size - 1)
    normalized_passenger_history = passenger_positions_history.astype(jnp.float32) / (grid_size - 1)

    # Flatten state information: [taxi (2), passengers (2 * n), carried passenger / n]
    flattened_states = jnp.concatenate([
        normalized_taxi_history,
        normalized_passenger_history.reshape(horizon+1, -1),
        taxi_passenger_history[:, None] / n_passengers
    ], axis=1)

    config = EnvConfig(obs=observations[0], n_actions=n_actions, grid_size=grid_size)

    return TransitionData(
        obs=observations,
        action=actions,
        reward=rewards,
        done=dones,
        is_first=jnp.zeros_like(rewards),
        state=flattened_states
    ), config

def init_passenger_at_different_pos(passenger_positions, used_positions, rng, i, grid_size):
    """Place passenger ``i`` on a uniformly random unused grid cell.

    Args:
        passenger_positions: Integer array of passenger cells; row ``i`` is overwritten.
        used_positions: Boolean occupancy mask; the chosen cell is marked ``True``.
        rng: PRNG key; it is split, one half is consumed and the other returned.
        i: Index of the passenger to place.
        grid_size: Not read by the body (the mask's shape defines the grid).

    Returns:
        Tuple ``(passenger_positions, used_positions, rng)``. When no cell is free
        the passenger is placed at ``(-1, -1)``.
    """
    next_rng, sample_key = jax.random.split(rng)

    # Uniform distribution over free cells, expressed as log-probabilities.
    free_cells = ~used_positions
    free_weights = free_cells.astype(jnp.float32)
    log_probs = jnp.log(free_weights / free_weights.sum()).flatten()
    flat_choice = jax.random.categorical(sample_key, log_probs)
    row, col = jnp.unravel_index(flat_choice, free_cells.shape)

    has_free_cell = free_cells.sum() != 0
    chosen = jax.lax.cond(
        has_free_cell,
        lambda: jnp.array((row, col)),
        lambda: jnp.array([-1, -1]),
    )

    used_positions = used_positions.at[chosen[0], chosen[1]].set(True)
    passenger_positions = passenger_positions.at[i].set(chosen)
    return passenger_positions, used_positions, next_rng

def test_multi_object_rendering():
    """Visualize a multi-object rollout and save it as ``multi_object_test.png``.

    Calls ``generate_multi_object_dataset(horizon, key, 3, 150)`` vmapped over two
    split keys, plots the first 10 frames of the first trajectory, and prints
    the resulting array shapes.
    """
    import matplotlib.pyplot as plt

    # Generate a sample dataset with 3 objects
    horizon = 10
    rng = jax.random.key(0)
    dataset, env_config = jax.vmap(generate_multi_object_dataset, in_axes=(None, 0, None, None))(horizon, jax.random.split(rng), 3, 150)

    # Create a figure with subplots to show the sequence
    fig, axes = plt.subplots(2, 5, figsize=(20, 8))
    axes = axes.ravel()

    # Plot each timestep
    for t in range(min(horizon, 10)):
        obs = np.array(dataset.obs[0, t])
        axes[t].imshow(obs)
        axes[t].set_title(f'Step {t}\nAction: {dataset.action[0,t]}')
        axes[t].axis('off')

    # Save the figure, then release it: pyplot keeps figures alive until closed.
    plt.tight_layout()
    plt.savefig('multi_object_test.png')
    plt.close(fig)
    print("Test visualization saved as 'multi_object_test.png'")

    # Print some information about the dataset
    print("\nDataset information:")
    print(f"Number of actions: {env_config.n_actions}")
    print(f"Observation shape: {dataset.obs.shape}")
    print(f"Action shape: {dataset.action.shape}")
    print(f"Reward shape: {dataset.reward.shape}")
    print(f"State shape: {dataset.state.shape}")

def test_multi_object_rendering_with_selection():
    """Visualize a multi-object rollout with selection; saves ``multi_object_test_with_selection.png``.

    Calls ``generate_multi_object_dataset_with_selection(horizon, key, 3, 150)``
    vmapped over two split keys, plots the first 10 frames of the first
    trajectory, and prints the resulting array shapes.
    """
    import matplotlib.pyplot as plt

    # Generate a sample dataset with 3 objects
    horizon = 10
    rng = jax.random.key(0)
    dataset, env_config = jax.vmap(generate_multi_object_dataset_with_selection, in_axes=(None, 0, None, None))(horizon, jax.random.split(rng), 3, 150)

    # Create a figure with subplots to show the sequence
    fig, axes = plt.subplots(2, 5, figsize=(20, 8))
    axes = axes.ravel()

    # Plot each timestep
    for t in range(min(horizon, 10)):
        obs = np.array(dataset.obs[0, t])
        axes[t].imshow(obs)
        axes[t].set_title(f'Step {t}\nAction: {dataset.action[0,t]}')
        axes[t].axis('off')

    # Save the figure, then release it: pyplot keeps figures alive until closed.
    plt.tight_layout()
    plt.savefig('multi_object_test_with_selection.png')
    plt.close(fig)
    print("Test visualization saved as 'multi_object_test_with_selection.png'")

    # Print some information about the dataset
    print("\nDataset information:")
    print(f"Number of actions: {env_config.n_actions}")
    print(f"Observation shape: {dataset.obs.shape}")
    print(f"Action shape: {dataset.action.shape}")
    print(f"Reward shape: {dataset.reward.shape}")
    print(f"State shape: {dataset.state.shape}")

def test_taxi_rendering():
    """Visualize a scripted taxi pickup/dropoff episode and save ``taxi_test.png``.

    The taxi starts at (0, 0) with passengers at (2, 2) and (6, 6) and follows a
    fixed action script: drive to the first passenger, pick it up, move, drop it
    off. ``generate_taxi_dataset`` still draws its initial case at random, so in
    some samples the first passenger already starts inside the taxi.
    """
    import matplotlib.pyplot as plt

    horizon = 10
    rng = jax.random.key(0)

    # 0: no-op, 1-4: move (right, left, down, up), 5: pickup, 6: dropoff
    actions = jnp.array([
        1,  # Move right
        1,  # Move right
        3,  # Move down
        3,  # Move down
        5,  # Pickup first passenger
        4,  # Move up
        4,  # Move up
        2,  # Move left
        6,  # Dropoff first passenger
        3,  # Move down
    ])

    grid_size = 8
    # Place taxi at (0,0)
    initial_taxi_pos = jnp.array([0, 0], dtype=jnp.int32)
    # Place two passengers at different positions
    initial_passenger_positions = jnp.array([
        [2, 2],  # First passenger at (2,2)
        [6, 6]   # Second passenger at (6,6)
    ], dtype=jnp.int32)

    # Generate dataset with specific actions and initial positions
    dataset, env_config = jax.vmap(generate_taxi_dataset, in_axes=(None, 0, None, None, None, None, None, None))(
        horizon, jax.random.split(rng), grid_size, 2, 150, actions, initial_taxi_pos, initial_passenger_positions
    )

    # Create a figure with subplots to show the sequence
    fig, axes = plt.subplots(2, 5, figsize=(20, 8))
    axes = axes.ravel()

    # Plot each timestep
    for t in range(min(horizon, 10)):
        obs = np.array(dataset.obs[0, t])
        axes[t].imshow(obs)
        axes[t].set_title(f'Step {t}\nAction: {actions[t]}')
        axes[t].axis('off')

    print(dataset.state[0])
    # Save the figure, then release it: pyplot keeps figures alive until closed.
    plt.tight_layout()
    plt.savefig('taxi_test.png')
    plt.close(fig)
    print("Test visualization saved as 'taxi_test.png'")

    # Print some information about the dataset
    print("\nDataset information:")
    print(f"Number of actions: {env_config.n_actions}")
    print(f"Observation shape: {dataset.obs.shape}")
    print(f"Action shape: {dataset.action.shape}")
    print(f"Reward shape: {dataset.reward.shape}")
    print(f"State shape: {dataset.state.shape}")

def test_taxi_passenger_overlap():
    """Visualize a loaded taxi overlapping another passenger; saves ``taxi_passenger_overlap_test.png``.

    The taxi starts at (0, 0), picks up the passenger at (2, 2), then moves onto
    the second passenger's cell (3, 3). A dropoff there should be refused
    because the cell is occupied. The taxi then steps off the cell and drops
    off successfully.
    """
    import matplotlib.pyplot as plt

    horizon = 10
    rng = jax.random.key(0)

    # 0: no-op, 1-4: move (right, left, down, up), 5: pickup, 6: dropoff
    actions = jnp.array([
        1,  # Move right -> (1, 0)
        1,  # Move right -> (2, 0)
        3,  # Move down  -> (2, 1)
        3,  # Move down  -> (2, 2), first passenger's cell
        5,  # Pickup first passenger
        1,  # Move right -> (3, 2)
        3,  # Move down  -> (3, 3), overlapping the second passenger while loaded
        6,  # Try to dropoff on the occupied cell (should fail)
        1,  # Move right -> (4, 3), leaving the shared cell
        6,  # Dropoff first passenger (should succeed)
    ])

    grid_size = 8
    # Place taxi at (0,0)
    initial_taxi_pos = jnp.array([0, 0], dtype=jnp.int32)
    # Place two passengers at (2,2) and (3,3)
    initial_passenger_positions = jnp.array([
        [2, 2],  # First passenger at (2,2)
        [3, 3]   # Second passenger at (3,3)
    ], dtype=jnp.int32)

    # Generate dataset with specific actions and initial positions
    dataset, env_config = jax.vmap(generate_taxi_dataset, in_axes=(None, 0, None, None, None, None, None, None))(
        horizon, jax.random.split(rng), grid_size, 2, 150, actions, initial_taxi_pos, initial_passenger_positions
    )

    # Create a figure with subplots to show the sequence
    fig, axes = plt.subplots(2, 5, figsize=(20, 8))
    axes = axes.ravel()

    # Plot each timestep
    for t in range(min(horizon, 10)):
        obs = np.array(dataset.obs[0, t])
        axes[t].imshow(obs)
        axes[t].set_title(f'Step {t}\nAction: {actions[t]}')
        axes[t].axis('off')

    # Save the figure, then release it: pyplot keeps figures alive until closed.
    plt.tight_layout()
    plt.savefig('taxi_passenger_overlap_test.png')
    plt.close(fig)
    print("Test visualization saved as 'taxi_passenger_overlap_test.png'")

    # Print some information about the dataset
    print("\nDataset information:")
    print(f"Number of actions: {env_config.n_actions}")
    print(f"Observation shape: {dataset.obs.shape}")
    print(f"Action shape: {dataset.action.shape}")
    print(f"Reward shape: {dataset.reward.shape}")
    print(f"State shape: {dataset.state.shape}")

def test_taxi_dropoff_occupied():
    """Visualize a refused dropoff on an occupied cell; saves ``taxi_dropoff_occupied_test.png``.

    The taxi starts at (0, 0), picks up the passenger at (2, 2), drives to the
    second passenger's cell (4, 4) and tries to drop off there, which should fail.
    """
    import matplotlib.pyplot as plt

    horizon = 10
    rng = jax.random.key(0)

    # 0: no-op, 1-4: move (right, left, down, up), 5: pickup, 6: dropoff
    actions = jnp.array([
        1,  # Move right
        1,  # Move right
        3,  # Move down
        3,  # Move down
        5,  # Pickup first passenger
        1,  # Move right
        1,  # Move right
        3,  # Move down
        3,  # Move down
        6,  # Try to dropoff at second passenger's position (should fail)
    ])

    grid_size = 8
    # Place taxi at (0,0)
    initial_taxi_pos = jnp.array([0, 0], dtype=jnp.int32)
    # Place two passengers at different positions
    initial_passenger_positions = jnp.array([
        [2, 2],  # First passenger at (2,2)
        [4, 4]   # Second passenger at (4,4)
    ], dtype=jnp.int32)

    # Generate dataset with specific actions and initial positions
    dataset, env_config = jax.vmap(generate_taxi_dataset, in_axes=(None, 0, None, None, None, None, None, None))(
        horizon, jax.random.split(rng), grid_size, 2, 150, actions, initial_taxi_pos, initial_passenger_positions
    )

    # Create a figure with subplots to show the sequence
    fig, axes = plt.subplots(2, 5, figsize=(20, 8))
    axes = axes.ravel()

    # Plot each timestep
    for t in range(min(horizon, 10)):
        obs = np.array(dataset.obs[0, t])
        axes[t].imshow(obs)
        axes[t].set_title(f'Step {t}\nAction: {actions[t]}')
        axes[t].axis('off')
    print(dataset.state[0])
    # Save the figure, then release it: pyplot keeps figures alive until closed.
    plt.tight_layout()
    plt.savefig('taxi_dropoff_occupied_test.png')
    plt.close(fig)
    print("Test visualization saved as 'taxi_dropoff_occupied_test.png'")

    # Print some information about the dataset
    print("\nDataset information:")
    print(f"Number of actions: {env_config.n_actions}")
    print(f"Observation shape: {dataset.obs.shape}")
    print(f"Action shape: {dataset.action.shape}")
    print(f"Reward shape: {dataset.reward.shape}")
    print(f"State shape: {dataset.state.shape}")

def test_taxi_random_positions(n_samples=3, seed=0):
    """Visualize taxi rollouts that start from generator-chosen initial positions.

    Runs ``generate_taxi_dataset`` once per sample with a fixed 10-step action
    sequence, passing ``None`` for the initial taxi and passenger positions so
    the generator chooses them. It plots the first five frames of every rollout
    (one row per sample) to ``taxi_random_positions_test.png`` and prints the
    dataset shapes plus each rollout's first state.

    Args:
        n_samples (int): Number of independent rollouts / figure rows.
            Defaults to 3 (the previously hard-coded value).
        seed (int): Seed passed to ``jax.random.key``. Defaults to 0.
    """
    import matplotlib.pyplot as plt

    horizon = 10
    n_cols = min(horizon, 5)  # only the first five frames are plotted
    rng = jax.random.key(seed)

    # Create specific actions to demonstrate pickup and dropoff
    # 0: no-op, 1-4: move (right, left, down, up), 5: pickup, 6: dropoff
    actions = jnp.array([
        1,  # Move right
        1,  # Move right
        3,  # Move down
        3,  # Move down
        5,  # Pickup first passenger
        4,  # Move up
        4,  # Move up
        2,  # Move left
        6,  # Dropoff first passenger
        3,  # Move down
    ])

    # Map only over the per-sample keys; every other argument is shared.
    # The trailing None, None leave the initial positions to the generator.
    dataset, env_config = jax.vmap(generate_taxi_dataset, in_axes=(None, 0, None, None, None, None, None, None))(
        horizon, jax.random.split(rng, n_samples), 8, 2, 150, actions, None, None
    )

    # squeeze=False keeps `axes` 2-D even when n_samples == 1.
    fig, axes = plt.subplots(
        n_samples, n_cols, figsize=(5 * n_cols, 5 * n_samples), squeeze=False
    )

    # Plot each timestep for each sample
    for s in range(n_samples):
        for t in range(n_cols):
            ax = axes[s, t]
            ax.imshow(np.array(dataset.obs[s, t]))
            ax.set_title(f'Sample {s+1}\nStep {t}\nAction: {actions[t]}')
            ax.axis('off')

    # Save the figure, then release it so repeated calls don't accumulate figures.
    fig.tight_layout()
    fig.savefig('taxi_random_positions_test.png')
    plt.close(fig)
    print("Test visualization saved as 'taxi_random_positions_test.png'")

    # Print some information about the dataset
    print("\nDataset information:")
    print(f"Number of actions: {env_config.n_actions}")
    print(f"Observation shape: {dataset.obs.shape}")
    print(f"Action shape: {dataset.action.shape}")
    print(f"Reward shape: {dataset.reward.shape}")
    print(f"State shape: {dataset.state.shape}")

    # dataset.state is indexed [sample, t, ...]; t = 0 is the initial state.
    print("\nInitial states for each sample:")
    for s in range(n_samples):
        print(f"\nSample {s+1}:")
        print(f"Initial state: {dataset.state[s, 0]}")

def test_taxi_initial_states():
    """Visualize three taxi rollouts, one figure row per labelled initial-state case.

    Each case calls ``generate_taxi_dataset`` (3 passengers, ``None`` for the
    initial positions) with a fixed 5-step action sequence. It plots every frame
    of batch element 0 to ``taxi_initial_states_test.png`` and prints that
    element's initial state scaled by ``grid_size - 1`` and rounded to integers.
    """
    import matplotlib.pyplot as plt

    # Generate a sample dataset for each initial state type
    horizon = 5
    grid_size = 8
    rng = jax.random.key(0)

    # Create specific actions to demonstrate the environment
    actions = jnp.array([
        1,  # Move right
        3,  # Move down
        5,  # Pickup (for cases where passenger is not in taxi)
        4,  # Move up
        6,  # Dropoff
    ])

    # (row title prefix, printed heading) for each case.
    # NOTE(review): no argument below selects the case type; the three calls
    # differ only in their PRNG key. These labels are not enforced here, so
    # confirm how generate_taxi_dataset chooses the initial-state case.
    cases = [
        ("Case 0: Different", "Case 0 - Different Positions:"),
        ("Case 1: Same Position", "Case 1 - Same Position:"),
        ("Case 2: In Taxi", "Case 2 - Passenger in Taxi:"),
    ]

    # Each case gets its own key. Reusing one key with identical arguments (as
    # before) makes every row the same rollout, assuming the generator draws its
    # randomness only from `rng`.
    case_keys = jax.random.split(rng, len(cases))
    generate = jax.vmap(generate_taxi_dataset, in_axes=(None, 0, None, None, None, None, None, None))
    datasets = [
        generate(horizon, jax.random.split(key), grid_size, 3, 150, actions, None, None)[0]
        for key in case_keys
    ]

    # Plot each case (row) at each timestep (column), using batch element 0.
    fig, axes = plt.subplots(
        len(cases), horizon, figsize=(5 * horizon, 5 * len(cases)), squeeze=False
    )
    for row, ((title, _), dataset) in enumerate(zip(cases, datasets)):
        for t in range(horizon):
            ax = axes[row, t]
            ax.imshow(np.array(dataset.obs[0, t]))
            ax.set_title(f'{title}\nStep {t}\nAction: {actions[t]}')
            ax.axis('off')

    # Save the figure, then release it so repeated calls don't accumulate figures.
    fig.tight_layout()
    fig.savefig('taxi_initial_states_test.png')
    plt.close(fig)
    print("Test visualization saved as 'taxi_initial_states_test.png'")

    # Print state information for each case.
    # dataset.state is indexed [sample, t, ...], so [0, 0] is batch element 0 at
    # t = 0 (the old code printed the whole trajectory). Round before casting:
    # plain astype(int) truncates, so e.g. 2.9999998 would become 2.
    for (_, heading), dataset in zip(cases, datasets):
        print(f"\n{heading}")
        initial_state = dataset.state[0, 0] * (grid_size - 1)
        print(f"Initial state: {jnp.round(initial_state).astype(int)}")


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--test", type=str, required=True, help="Test which dataset generator")
    args = parser.parse_args()

    if args.test == "multi_object":
        test_multi_object_rendering()
    elif args.test == "multi_object_selection":
        test_multi_object_rendering_with_selection()
    elif args.test == "taxi":
        test_taxi_rendering()