import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import chex
from datasets.dataset_generators import generate_taxi_dataset
from datasets import TransitionData

def render_taxi_with_failure(img, taxi_pos, passenger_positions, passenger_colors, taxi_passenger, failure_type, grid_size=8, image_size=40):
    '''
    Overlay a failure marker on an already-rendered taxi frame.

    Only pixels around the centre of the taxi cell are modified; the rest of
    ``img`` is returned unchanged.

    Args:
        img: Image indexed as ``img[row, col]`` -> RGB triple. The marker is
            written as ``uint8`` colours, so ``img`` is presumably a ``uint8``
            array of shape (image_size, image_size, 3) -- TODO confirm.
        taxi_pos: ``(x, y)`` taxi position, normalised so that
            ``x * (grid_size - 1)`` is the grid column (``y`` likewise the row).
        passenger_positions, passenger_colors, taxi_passenger: Not used here;
            kept for interface compatibility.
        failure_type: 0 (no failure), 1 (pickup failed), 2 (dropoff failed).
            Out-of-range values are clamped by ``jax.lax.switch``.
        grid_size: Number of cells along each side of the grid.
        image_size: Side length of ``img`` in pixels.

    Returns:
        ``img`` unchanged for failure_type 0, otherwise ``img`` with a
        yellow X outlined in black (pickup failed) or a magenta X outlined in
        white (dropoff failed) drawn at the centre of the taxi cell.
    '''
    yellow = jnp.array([255, 255, 0], dtype=jnp.uint8)
    black = jnp.array([0, 0, 0], dtype=jnp.uint8)
    magenta = jnp.array([255, 0, 255], dtype=jnp.uint8)
    white = jnp.array([255, 255, 255], dtype=jnp.uint8)

    return jax.lax.switch(
        failure_type,
        [
            lambda: img,  # No failure
            lambda: _draw_failure_x(img, taxi_pos, grid_size, image_size, yellow, black),  # Pickup failed
            lambda: _draw_failure_x(img, taxi_pos, grid_size, image_size, magenta, white),  # Dropoff failed
        ],
    )


def _draw_failure_x(img, taxi_pos, grid_size, image_size, main_color, outline_color, x_size=5):
    '''
    Draw an X of half-width ``x_size`` pixels centred on the taxi's grid cell.

    Both diagonals are drawn in ``main_color``. Near the centre
    (|offset| < 0.7 * x_size) every diagonal pixel also gets ``outline_color``
    neighbours to its side and below it, giving a thicker, high-contrast core.
    Pixel indices are clipped to the image, so marks near the border are
    squashed onto the edge rather than raising.
    '''
    x, y = taxi_pos
    cell_size = image_size / grid_size

    # Round (not floor) to the nearest cell: a normalised coordinate such as
    # k / (grid_size - 1) can multiply back to k - eps, which floor would send
    # to the previous cell. Clip keeps the cell index on the grid.
    grid_x = jnp.clip(jnp.round(x * (grid_size - 1)), 0, grid_size - 1).astype(jnp.int32)
    grid_y = jnp.clip(jnp.round(y * (grid_size - 1)), 0, grid_size - 1).astype(jnp.int32)

    # Pixel centre of that cell.
    center_x = jnp.round((grid_x + 0.5) * cell_size).astype(jnp.int32)
    center_y = jnp.round((grid_y + 0.5) * cell_size).astype(jnp.int32)

    last_px = image_size - 1
    thick_limit = x_size * 0.7

    def draw_step(canvas, i):
        # One row of the X: main diagonal (col grows with row) and anti-diagonal.
        row = jnp.clip(center_y + i, 0, last_px)
        col_main = jnp.clip(center_x + i, 0, last_px)
        col_anti = jnp.clip(center_x - i, 0, last_px)
        canvas = canvas.at[row, col_main].set(main_color)
        canvas = canvas.at[row, col_anti].set(main_color)

        # Weight is highest at the centre and 0 towards the tips; a 0/1 mask is
        # used instead of a Python `if` so the step stays traceable.
        weight = jnp.maximum(0.0, 1.0 - jnp.abs(i) / thick_limit)
        thick = (weight > 0.001).astype(jnp.int32)

        # Outline pixels: beside each diagonal pixel, then below it. When
        # thick == 0 they coincide with the diagonal pixel and are left as-is.
        row_below = jnp.clip(center_y + i + thick, 0, last_px)
        outline_pixels = (
            (row, jnp.clip(center_x + i + thick, 0, last_px)),
            (row, jnp.clip(center_x - i - thick, 0, last_px)),
            (row_below, col_main),
            (row_below, col_anti),
        )
        for r, c in outline_pixels:
            canvas = canvas.at[r, c].set(jnp.where(thick > 0, outline_color, canvas[r, c]))
        return canvas, None

    img, _ = jax.lax.scan(draw_step, img, jnp.arange(-x_size, x_size + 1))
    return img

def generate_taxi_suff_dataset(horizon, rng, grid_size=8, n_passengers=3, img_size=40, custom_actions=None, 
                             initial_taxi_pos=None, initial_passenger_positions=None, success_rate=0.7):
    """Generate a taxi episode whose observations also show failed pickups/dropoffs.

    The episode comes from ``generate_taxi_dataset`` (all arguments are passed
    straight through). For every transition ``i`` the observation
    ``obs[i + 1]`` is re-rendered with a failure marker (via
    ``render_taxi_with_failure``) when ``state[i]``, ``state[i + 1]`` and
    ``action[i]`` indicate:

    * action 5 (pickup) with the taxi empty before and after, or
    * action 6 (dropoff) with the taxi empty, or carrying a passenger that is
      still carried afterwards.

    ``obs[0]`` has no incoming transition and is never marked; all other
    fields of the dataset are returned unchanged.

    State layout as read here (assumed from the slicing below -- confirm
    against ``generate_taxi_dataset``): ``state[:2]`` taxi position,
    ``state[2:2 + 2 * n_passengers]`` passenger positions, ``state[-1]`` the
    carried-passenger index divided by ``n_passengers`` (0 = empty).

    Returns:
        ``(TransitionData, env_config)`` with only ``obs`` modified; ``obs``
        keeps the same length as the generated dataset.
    """
    dataset, env_config = generate_taxi_dataset(
        horizon, rng, grid_size, n_passengers, img_size, custom_actions,
        initial_taxi_pos, initial_passenger_positions, success_rate
    )

    def carried_passenger(state):
        # Round instead of truncating so float error (e.g. 0.99999994) cannot
        # turn passenger k into k - 1.
        return jnp.round(state[-1] * n_passengers).astype(jnp.int32)

    def mark_transition(i):
        curr_state = dataset.state[i]
        next_state = dataset.state[i + 1]
        action = dataset.action[i]
        obs = dataset.obs[i + 1]  # observation produced by `action`

        taxi_pos = curr_state[:2]
        passenger_positions = curr_state[2:2 + n_passengers * 2].reshape(n_passengers, 2)
        carried_curr = carried_passenger(curr_state)
        carried_next = carried_passenger(next_state)

        unchanged = carried_curr == carried_next
        was_empty = carried_curr == 0
        pickup_failed = (action == 5) & unchanged & was_empty
        dropoff_failed = (action == 6) & ((unchanged & (carried_curr > 0)) | was_empty)
        # 0 = no failure, 1 = pickup failed, 2 = dropoff failed (pickup checked first).
        failure_type = jnp.where(pickup_failed, 1, jnp.where(dropoff_failed, 2, 0))

        def rerender():
            # Round when converting to uint8: obs * 255 is often k - eps for a
            # pixel that was originally k, and truncation would darken it.
            img_u8 = jnp.clip(jnp.round(obs * 255), 0, 255).astype(jnp.uint8)
            marked = render_taxi_with_failure(
                img_u8, taxi_pos, passenger_positions, jnp.arange(n_passengers),
                carried_curr, failure_type, grid_size, img_size,
            )
            return marked.astype(jnp.float32) / 255.0

        # Frames without a failure are returned untouched (the uint8 round trip
        # is lossy, so we only take it when a marker is drawn).
        return jax.lax.cond(failure_type > 0, rerender, lambda: obs)

    # Derive the transition count from obs rather than `horizon` so every
    # transition, including the final one, is checked and the output has
    # exactly as many frames as dataset.obs.
    n_transitions = dataset.obs.shape[0] - 1
    marked_obs = jax.vmap(mark_transition)(jnp.arange(n_transitions))
    new_observations = jnp.concatenate([dataset.obs[:1], marked_obs], axis=0)

    return TransitionData(
        obs=new_observations,
        action=dataset.action,
        reward=dataset.reward,
        done=dataset.done,
        is_first=dataset.is_first,
        state=dataset.state
    ), env_config

def test_taxi_suff_rendering():
    """Visualize the taxi environment with failure signals.

    Generates a short scripted episode, plots each observation titled with the
    action at that step, and saves the grid to ``taxi_suff_test.png``.
    """
    horizon = 10
    rng = jax.random.key(0)

    # Specific actions to demonstrate pickup and dropoff failures
    # with clear taxi movement
    actions = jnp.array([
        1,  # Move right (step 0)
        3,  # Move down (step 1)
        5,  # Try pickup (should fail - no passenger) (step 2)
        1,  # Move right (step 3)
        3,  # Move down (step 4)
        5,  # Try pickup (step 5)
        2,  # Move left (step 6)
        4,  # Move up (step 7)
        6,  # Try dropoff (step 8)
        0,  # No-op (step 9)
    ])

    # Initial positions chosen for a clear visualization
    grid_size = 8
    initial_taxi_pos = jnp.array([0, 0], dtype=jnp.int32)
    initial_passenger_positions = jnp.array([
        [1, 2],  # First passenger at (1,2)
        [4, 4]   # Second passenger at (4,4)
    ], dtype=jnp.int32)

    # Batched over two RNG keys; only the first episode is plotted.
    dataset, _ = jax.vmap(generate_taxi_suff_dataset, in_axes=(None, 0, None, None, None, None, None, None))(
        horizon, jax.random.split(rng), grid_size, 2, 150, actions, initial_taxi_pos, initial_passenger_positions
    )

    action_desc = {
        0: "No-op",
        1: "Move right",
        2: "Move left",
        3: "Move down",
        4: "Move up",
        5: "Try pickup",
        6: "Try dropoff"
    }
    # Pickup failures are drawn as a yellow X with a black outline.
    extra_info_by_step = {
        2: "\n(Pickup fails - yellow X)",
        8: "\n(Dropoff at empty spot)",
    }

    fig, axes = plt.subplots(2, 5, figsize=(20, 8))
    axes = axes.ravel()

    for t in range(min(horizon, 10)):
        ax = axes[t]
        ax.imshow(np.array(dataset.obs[0, t]))
        extra_info = extra_info_by_step.get(t, "")
        action_int = int(actions[t])  # JAX scalar -> Python int for dict lookup
        ax.set_title(f'Step {t}\n{action_desc[action_int]}{extra_info}')
        ax.axis('off')

    plt.tight_layout()
    plt.savefig('taxi_suff_test.png')
    # Close so repeated runs (e.g. --test all) don't accumulate open figures.
    plt.close(fig)
    print("Test visualization saved as 'taxi_suff_test.png'")

def test_taxi_suff_pickup_failure():
    """Visualize pickup failures.

    Runs a scripted episode where the taxi tries to pick up on empty cells,
    plots each observation and saves the row to
    ``taxi_suff_pickup_failure_test.png``.
    """
    horizon = 5
    rng = jax.random.key(0)

    # Actions to demonstrate pickup failure
    actions = jnp.array([
        1,  # Move right (step 0)
        3,  # Move down (step 1)
        5,  # Try pickup (step 2) - fails, no passenger here
        2,  # Move left (step 3)
        5,  # Try pickup (step 4) - fails, no passenger here
    ])

    grid_size = 8
    initial_taxi_pos = jnp.array([0, 0], dtype=jnp.int32)
    initial_passenger_positions = jnp.array([
        [4, 4],  # First passenger at (4,4) - far from taxi path
        [6, 6]   # Second passenger at (6,6) - far from taxi path
    ], dtype=jnp.int32)

    # Batched over two RNG keys; only the first episode is plotted.
    dataset, _ = jax.vmap(generate_taxi_suff_dataset, in_axes=(None, 0, None, None, None, None, None, None))(
        horizon, jax.random.split(rng), grid_size, 2, 150, actions, initial_taxi_pos, initial_passenger_positions
    )

    action_desc = {
        0: "No-op",
        1: "Move right",
        2: "Move left",
        3: "Move down",
        4: "Move up",
        5: "Try pickup",
        6: "Try dropoff"
    }
    # Pickup failures are drawn as a yellow X with a black outline.
    pickup_fail_steps = (2, 4)

    fig, axes = plt.subplots(1, 5, figsize=(20, 4))
    axes = axes.ravel()

    for t in range(horizon):
        ax = axes[t]
        ax.imshow(np.array(dataset.obs[0, t]))
        extra_info = "\n(Pickup fails - yellow X)" if t in pickup_fail_steps else ""
        action_int = int(actions[t])  # JAX scalar -> Python int for dict lookup
        ax.set_title(f'Step {t}\n{action_desc[action_int]}{extra_info}')
        ax.axis('off')

    plt.tight_layout()
    plt.savefig('taxi_suff_pickup_failure_test.png')
    # Close so repeated runs (e.g. --test all) don't accumulate open figures.
    plt.close(fig)
    print("Test visualization saved as 'taxi_suff_pickup_failure_test.png'")

def test_taxi_suff_dropoff_failure():
    """Visualize dropoff failures.

    Runs a scripted episode with a dropoff attempted on an empty taxi, followed
    by a pickup and a dropoff, plots each observation and saves the row to
    ``taxi_suff_dropoff_failure_test.png``.
    """
    horizon = 7
    rng = jax.random.key(0)

    # Actions to demonstrate dropoff failure
    actions = jnp.array([
        1,  # Move right (step 0)
        3,  # Move down (step 1)
        1,  # Move right (step 2)
        3,  # Move down (step 3)
        6,  # Try dropoff (step 4) - fails, no passenger in taxi
        5,  # Try pickup (step 5) - pickup passenger at (2,2)
        6,  # Try dropoff (step 6) - succeeds
    ])

    grid_size = 8
    initial_taxi_pos = jnp.array([0, 0], dtype=jnp.int32)
    initial_passenger_positions = jnp.array([
        [2, 2],  # First passenger at (2,2) - on taxi path
        [6, 6]   # Second passenger at (6,6) - far from taxi path
    ], dtype=jnp.int32)

    # Batched over two RNG keys; only the first episode is plotted.
    dataset, _ = jax.vmap(generate_taxi_suff_dataset, in_axes=(None, 0, None, None, None, None, None, None))(
        horizon, jax.random.split(rng), grid_size, 2, 150, actions, initial_taxi_pos, initial_passenger_positions
    )

    action_desc = {
        0: "No-op",
        1: "Move right",
        2: "Move left",
        3: "Move down",
        4: "Move up",
        5: "Try pickup",
        6: "Try dropoff"
    }
    # Dropoff failures are drawn as a magenta X with a white outline.
    extra_info_by_step = {
        4: "\n(Dropoff fails - magenta X)",
        5: "\n(Pickup succeeds)",
        6: "\n(Dropoff succeeds)",
    }

    fig, axes = plt.subplots(1, 7, figsize=(24, 4))
    axes = axes.ravel()

    for t in range(horizon):
        ax = axes[t]
        ax.imshow(np.array(dataset.obs[0, t]))
        extra_info = extra_info_by_step.get(t, "")
        action_int = int(actions[t])  # JAX scalar -> Python int for dict lookup
        ax.set_title(f'Step {t}\n{action_desc[action_int]}{extra_info}')
        ax.axis('off')

    plt.tight_layout()
    plt.savefig('taxi_suff_dropoff_failure_test.png')
    # Close so repeated runs (e.g. --test all) don't accumulate open figures.
    plt.close(fig)
    print("Test visualization saved as 'taxi_suff_dropoff_failure_test.png'")

def test_taxi_suff_sequence():
    """Visualize a complete sequence with both pickup and dropoff interactions.

    Runs a scripted episode mixing failed and successful pickups/dropoffs,
    plots each observation and saves the grid to
    ``taxi_suff_sequence_test.png``.
    """
    horizon = 12
    rng = jax.random.key(0)

    # Actions to demonstrate full sequence with successful and failed actions
    actions = jnp.array([
        1,  # Move right (step 0)
        3,  # Move down (step 1)
        6,  # Try dropoff (step 2) - fails, no passenger in taxi
        1,  # Move right (step 3)
        5,  # Try pickup (step 4) - fails, no passenger here
        1,  # Move right (step 5)
        3,  # Move down (step 6)
        5,  # Try pickup (step 7) - succeeds at passenger position
        2,  # Move left (step 8)
        2,  # Move left (step 9)
        4,  # Move up (step 10)
        6,  # Try dropoff (step 11) - succeeds
    ])

    # Initial positions chosen for a clear visualization
    grid_size = 8
    initial_taxi_pos = jnp.array([0, 0], dtype=jnp.int32)
    initial_passenger_positions = jnp.array([
        [3, 2],  # First passenger at (3,2)
        [6, 6]   # Second passenger at (6,6)
    ], dtype=jnp.int32)

    # Batched over two RNG keys; only the first episode is plotted.
    dataset, _ = jax.vmap(generate_taxi_suff_dataset, in_axes=(None, 0, None, None, None, None, None, None))(
        horizon, jax.random.split(rng), grid_size, 2, 150, actions, initial_taxi_pos, initial_passenger_positions
    )

    action_desc = {
        0: "No-op",
        1: "Move right",
        2: "Move left",
        3: "Move down",
        4: "Move up",
        5: "Try pickup",
        6: "Try dropoff"
    }
    # Dropoff failures: magenta X (white outline); pickup failures: yellow X
    # (black outline).
    extra_info_by_step = {
        2: "\n(Dropoff fails - magenta X)",
        4: "\n(Pickup fails - yellow X)",
        7: "\n(Pickup succeeds)",
        11: "\n(Dropoff succeeds)",
    }

    fig, axes = plt.subplots(2, 6, figsize=(24, 8))
    axes = axes.ravel()

    for t in range(horizon):
        ax = axes[t]
        ax.imshow(np.array(dataset.obs[0, t]))
        extra_info = extra_info_by_step.get(t, "")
        action_int = int(actions[t])  # JAX scalar -> Python int for dict lookup
        ax.set_title(f'Step {t}\n{action_desc[action_int]}{extra_info}')
        ax.axis('off')

    plt.tight_layout()
    plt.savefig('taxi_suff_sequence_test.png')
    # Close so repeated runs (e.g. --test all) don't accumulate open figures.
    plt.close(fig)
    print("Test visualization saved as 'taxi_suff_sequence_test.png'")

if __name__ == '__main__':
    import argparse

    # CLI choice -> visualizations to run, in execution order.
    runners = {
        'basic': (test_taxi_suff_rendering,),
        'pickup': (test_taxi_suff_pickup_failure,),
        'dropoff': (test_taxi_suff_dropoff_failure,),
        'sequence': (test_taxi_suff_sequence,),
    }
    runners['all'] = (
        runners['basic'] + runners['pickup'] + runners['dropoff'] + runners['sequence']
    )

    cli = argparse.ArgumentParser()
    cli.add_argument('--test', type=str, choices=['all', 'basic', 'pickup', 'dropoff', 'sequence'], required=True, help='Run the test visualization')
    selected = cli.parse_args().test

    for run in runners[selected]:
        run()