import jax
import jax.numpy as jnp
import chex
import numpy as np

from envs.taxi_gymnax import TaxiGymnax, TaxiParams, TaxiState
from datasets import TransitionData
from typing import List

def generate_taxi_gymnax_dataset(
        horizon,
        rng,
        size=5,
        n_passengers=1,
        img_size=40,
        allow_dropoff_anywhere=True,
        pixel_obs=True,
        random_dropoff_prob=0.5,
        exploring_starts=False
):
    """Generate one trajectory of transitions from the TaxiGymnax environment.

    The environment is reset once and then stepped ``horizon`` times inside
    ``jax.lax.scan`` under a scripted stochastic policy (see the nested
    ``policy`` function). The observation and ground-truth state reached after
    the last step are appended, so ``obs`` and ``state`` hold ``horizon + 1``
    entries while ``action``, ``reward``, ``done`` and ``is_first`` hold
    ``horizon`` entries.

    Args:
        horizon (int): Number of timesteps per trajectory. Must be a concrete
            Python int because it is used as the scan length.
        rng (jax.random.PRNGKey): Random number generator key.
        size (int): Size of the grid (5 or 10).
        n_passengers (int): Number of passengers (1-4 for size=5, 1-8 for size=10).
        img_size (int): Side length the rendered frames are resized to. Only
            used when ``pixel_obs`` is True.
        allow_dropoff_anywhere (bool): Forwarded to ``TaxiGymnax``.
        pixel_obs (bool): If True the environment is built with
            ``render_mode=1`` and observations are resized to
            ``(img_size, img_size, 3)``; otherwise ``render_mode=0`` is used
            and observations are returned without resizing. Must be a concrete
            Python bool because it drives Python-level branching.
        random_dropoff_prob (float): Currently not read by this function; the
            parameter is kept so existing positional callers keep working.
        exploring_starts (bool): Forwarded to ``TaxiGymnax``.

    Returns:
        tuple: ``(TransitionData, EnvConfig)`` containing the generated dataset
        and the environment configuration. ``EnvConfig.obs`` is the first
        observation of the returned dataset.
    """
    # NOTE: TaxiGymnax comes from the module-level import. The former local
    # re-import from ``src.envs`` used a different package path and could fail
    # (or load a second copy of the module) on an otherwise working setup.

    # Initialize environment
    render_mode = 1 if pixel_obs else 0
    env = TaxiGymnax(
        size=size,
        n_passengers=n_passengers,
        render_mode=render_mode,
        allow_dropoff_anywhere=allow_dropoff_anywhere,
        exploring_starts=exploring_starts,
    )
    params = env.default_params

    # One key for the reset, one to drive the whole rollout
    rng_reset, rng_step = jax.random.split(rng)

    # Reset environment to get initial observation and state
    obs, state = env.reset_env(rng_reset, params)

    def policy(state, key):
        """Jittable scripted policy: mixes random, interact and goal-seeking actions.

        With equal probability the policy returns (0) a uniformly random
        action, (1) the interact action, or (2) a greedy move toward a target.
        The target is the carried passenger's goal when a passenger is in the
        taxi, otherwise the position of a randomly chosen passenger.
        """
        # Split into four and skip the first sub-key so the sampled stream is
        # the same as when the first sub-key was handed back to the caller.
        _, key_action_type, key_passenger, key_random = jax.random.split(key, 4)

        taxi_pos = state.taxi_pos
        passengers = state.passengers
        passenger_in_taxi = state.passenger_in_taxi

        # Decision: what type of action to take?
        # 0: Random action
        # 1: Interact (pickup/dropoff)
        # 2: Move towards a passenger or goal
        action_type = jax.random.randint(key_action_type, (), 0, 3)

        # Candidate 0: uniformly random action
        random_action = jax.random.randint(key_random, (), 0, env.num_actions)

        # Candidate 1: interact (always action 5 in this environment)
        interact_action = jnp.array(5, dtype=jnp.int32)

        # Candidate 2: move toward a target. Both possible targets are computed
        # and one is selected with jnp.where, because Python conditionals on
        # traced values are not allowed.
        passenger_idx = jax.random.randint(key_passenger, (), 0, n_passengers)
        has_passenger = passenger_in_taxi >= 0

        # Clamp -1 ("no passenger") to 0 so the index is always valid; the
        # result is discarded by the jnp.where below in that case.
        carried_idx = jnp.maximum(0, passenger_in_taxi)
        goal_target = passengers[carried_idx, 2:4]      # goal of carried passenger
        passenger_target = passengers[passenger_idx, :2]  # position of a random passenger
        target_pos = jnp.where(has_passenger, goal_target, passenger_target)

        # Move along the axis with the larger remaining distance first
        delta = target_pos - taxi_pos
        horiz_diff = jnp.abs(delta[1])
        vert_diff = jnp.abs(delta[0])
        move_horiz_first = horiz_diff >= vert_diff

        horiz_action = jnp.where(delta[1] > 0,
                                 jnp.array(1, dtype=jnp.int32),  # Right
                                 jnp.array(2, dtype=jnp.int32))  # Left
        vert_action = jnp.where(delta[0] > 0,
                                jnp.array(4, dtype=jnp.int32),  # Down
                                jnp.array(3, dtype=jnp.int32))  # Up
        move_action = jnp.where(move_horiz_first, horiz_action, vert_action)

        # Already standing on the target: interact instead of moving
        at_target = jnp.all(taxi_pos == target_pos)
        move_action = jnp.where(at_target, interact_action, move_action)

        # Final action selection based on action_type
        return jax.lax.switch(
            action_type,
            [
                lambda: random_action,    # Random action
                lambda: interact_action,  # Interact
                lambda: move_action,      # Movement toward target
            ]
        )

    def step_env(carry, _):
        """Scan body: pick an action, step the environment, emit one transition."""
        last_obs, last_done, state, key = carry

        # Fresh sub-keys for the environment step and for the policy
        key, key_step, key_policy = jax.random.split(key, 3)
        action = policy(state, key_policy)

        next_obs, next_state, reward, done, info = env.step(key_step, state, action, params)
        done = done.astype(jnp.float32)
        transition_data = TransitionData(
            obs=last_obs,
            action=action,
            reward=reward,
            done=done,
            # is_first carries the previous step's done flag (0. on the first step)
            is_first=last_done,
            state=info['state']
        )

        return (next_obs, done, next_state, key), transition_data

    (last_obs, _, last_state, _), transition_data = jax.lax.scan(
        step_env,
        (obs, 0., state, rng_step),
        None,
        length=horizon
    )
    last_gt_state = env._get_state_obs(last_state)

    cpu_device = jax.devices('cpu')[0]

    def append_final_on_cpu(sequence, final):
        """Place both pieces on the CPU device and append ``final`` as the last entry."""
        return jnp.concatenate(
            [jax.device_put(sequence, cpu_device), jax.device_put(final[None], cpu_device)],
            axis=0,
        )

    data = transition_data.replace(
        obs=append_final_on_cpu(transition_data.obs, last_obs),
        state=append_final_on_cpu(transition_data.state, last_gt_state),
    )

    @chex.dataclass(frozen=True)
    class EnvConfig:
        # Environment metadata returned alongside the dataset
        obs: chex.Array
        n_actions: int
        size: int
        n_passengers: int
        n_depots: int
        state_names: List[str]

    # Resize rendered frames to img_size. Vector observations (pixel_obs=False)
    # are not images: jax.image.resize needs a target shape of the same rank as
    # its input, so resizing them unconditionally would raise.
    if pixel_obs:
        target_shape = data.obs.shape[0:1] + (img_size, img_size, 3)
        data = data.replace(obs=jax.image.resize(data.obs, target_shape, 'bilinear'))

    return data, EnvConfig(
        obs=data.obs[0],
        n_actions=env.num_actions,
        size=size,
        n_passengers=n_passengers,
        n_depots=env.num_depots,
        state_names=env.state_names
    )

def test_taxi_gymnax_dataset():
    """Generate a small TaxiGymnax dataset, save a preview image and print a summary.

    A single 20-step trajectory is generated through ``jax.vmap`` (batch of
    one key). The first frames are drawn on a 2x5 grid and written to
    ``taxi_gymnax_dataset_test.png``; dataset shapes and the initial state are
    then printed to stdout. Returns nothing.
    """
    import matplotlib
    matplotlib.use('Agg')  # Use Agg backend to avoid display issues
    import matplotlib.pyplot as plt

    # Generate a sample dataset
    horizon = 20
    rng = jax.random.key(0)
    random_dropoff_prob = 0.3  # Use a custom value for testing
    dataset, env_config = jax.vmap(generate_taxi_gymnax_dataset, in_axes=(None, 0, None, None, None, None, None, None))(
        horizon, jax.random.split(rng, 1), 5, 2, 150, True, True, random_dropoff_prob
    )

    # Create a figure with subplots to show the sequence
    fig, axes = plt.subplots(2, 5, figsize=(20, 8))
    axes = axes.ravel()

    # Plot first 10 timesteps (one per subplot)
    for t in range(min(10, horizon)):
        # Convert observation to numpy and plot
        obs = np.array(dataset.obs[0, t])
        axes[t].imshow(obs)
        axes[t].set_title(f'Step {t}\nAction: {dataset.action[0, t]}\nReward: {dataset.reward[0, t]}')
        axes[t].axis('off')

    # Save the figure, then release it so repeated calls do not accumulate
    # open figures (create_taxi_gymnax_gifs closes its figures the same way).
    fig.tight_layout()
    fig.savefig('taxi_gymnax_dataset_test.png')
    plt.close(fig)
    print("Test visualization saved as 'taxi_gymnax_dataset_test.png'")

    # Print some information about the dataset
    print("\nDataset information:")
    print(f"Random dropoff probability: {random_dropoff_prob}")
    print(f"Number of actions: {env_config.n_actions}")
    print(f"Grid size: {env_config.size}")
    print(f"Number of passengers: {env_config.n_passengers}")
    print(f"Observation shape: {dataset.obs.shape}")
    print(f"Action shape: {dataset.action.shape}")
    print(f"Reward shape: {dataset.reward.shape}")
    print(f"State shape: {dataset.state.shape}")

    # Print initial state information.
    # NOTE(review): assumes the state vector is laid out as
    # [taxi position (2), 4 values per passenger ..., passenger-in-taxi (1)]
    # -- confirm against the environment's state observation.
    print("\nInitial state:")
    initial_state = dataset.state[0, 0]
    taxi_pos = initial_state[:2]
    passenger_info = initial_state[2:-1].reshape(-1, 4)
    passenger_in_taxi = initial_state[-1]
    print(f"Taxi position: {taxi_pos}")
    for i, passenger in enumerate(passenger_info):
        print(f"Passenger {i} - Current pos: {passenger[:2]}, Goal: {passenger[2:4]}")
    print(f"Passenger in taxi: {passenger_in_taxi}")

def create_taxi_gymnax_gifs(num_episodes=3, horizon=20, size=5, n_passengers=2, random_dropoff_prob=0.3, output_file="taxi_episodes.gif"):
    """Render TaxiGymnax rollouts to PNG frames and, when possible, one animated GIF.

    Every step of every episode is drawn with a title naming the episode,
    step, action and reward, and saved as ``taxi_frames/frame_NNN.png``. The
    frames are then stitched into ``output_file`` with ``imageio``; any failure
    in that stage is reported on stdout and the PNG frames are left in place.

    Args:
        num_episodes: Number of episodes to generate and visualize
        horizon: Number of timesteps per episode
        size: Grid size for the environment (5 or 10)
        n_passengers: Number of passengers in the environment
        random_dropoff_prob: Value forwarded to ``generate_taxi_gymnax_dataset``
        output_file: Filename for the output GIF

    Returns:
        tuple: the batched ``(dataset, env_config)`` produced by the vmapped
        ``generate_taxi_gymnax_dataset`` call.
    """
    import os

    # Headless-friendly configuration.
    # NOTE(review): jax is already imported at module level, so JAX_PLATFORMS
    # may be set too late to have any effect here -- confirm.
    os.environ['JAX_PLATFORMS'] = 'cpu'
    os.environ['MPLBACKEND'] = 'Agg'

    import matplotlib
    # Non-interactive backend: no display is required to save figures
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    import numpy as np
    import jax

    # --- Roll out one trajectory per episode key ---
    print(f"Generating {num_episodes} episodes with horizon {horizon}...")
    episode_keys = jax.random.split(jax.random.key(0), num_episodes)
    batched_generate = jax.vmap(
        generate_taxi_gymnax_dataset,
        in_axes=(None, 0, None, None, None, None, None, None),
    )
    dataset, env_config = batched_generate(
        horizon, episode_keys, size, n_passengers, 40, True, True, random_dropoff_prob
    )

    # Host-side copies are easier to index frame by frame
    obs_np = np.array(dataset.obs)
    actions_np = np.array(dataset.action)
    rewards_np = np.array(dataset.reward)

    # Human-readable label for each action id
    action_names = ["NOOP", "RIGHT", "LEFT", "UP", "DOWN", "INTERACT"]

    # --- Write one PNG per (episode, step) ---
    frames_dir = "taxi_frames"
    os.makedirs(frames_dir, exist_ok=True)
    print(f"Saving frames to {frames_dir}/...")

    total_frames = num_episodes * horizon
    for frame_idx in range(total_frames):
        # Frames are numbered episode-major: all steps of episode 0 first
        episode, step = divmod(frame_idx, horizon)

        fig, ax = plt.subplots(figsize=(6, 6))
        ax.imshow(obs_np[episode, step])

        action = int(actions_np[episode, step])
        reward = float(rewards_np[episode, step])

        # Fall back to a generic label for ids beyond the known names
        if action < len(action_names):
            action_txt = action_names[action]
        else:
            action_txt = f"ACTION {action}"
        ax.set_title(f"Episode {episode+1}, Step {step+1}\nAction: {action_txt}, Reward: {reward:.1f}")
        ax.axis('off')

        fig.savefig(os.path.join(frames_dir, f"frame_{frame_idx:03d}.png"), bbox_inches='tight')
        print(f"Saved frame {frame_idx+1}/{num_episodes*horizon}")

        # Release the figure so memory does not grow with the frame count
        plt.close(fig)

    print(f"All frames saved successfully to {frames_dir}/")

    # --- Best-effort GIF assembly (imageio may be missing) ---
    try:
        import imageio
        print("Attempting to create GIF from frames...")

        # Read the PNGs back in the order they were written
        frames = [
            imageio.imread(os.path.join(frames_dir, f"frame_{i:03d}.png"))
            for i in range(total_frames)
        ]

        # Duration is given in milliseconds (500ms = 2fps)
        imageio.mimsave(output_file, frames, duration=500)
        print(f"GIF saved to {output_file}")
    except Exception as e:
        print(f"Failed to create GIF: {e}")
        print(f"Individual frames are available in {frames_dir}/")

    return dataset, env_config

if __name__ == "__main__":
    import argparse

    # Map every accepted --test value to the routine it runs. The same table
    # feeds argparse's `choices`, so an unknown or misspelled value is rejected
    # with a usage message instead of silently doing nothing.
    runners = {
        "taxi_gymnax": test_taxi_gymnax_dataset,
        "taxi_gymnax_gif": lambda: create_taxi_gymnax_gifs(
            num_episodes=2, horizon=20, output_file="taxi_episodes.gif"
        ),
    }

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test",
        type=str,
        required=True,
        choices=sorted(runners),
        help="Test which dataset generator",
    )
    args = parser.parse_args()

    runners[args.test]()