import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import os
import navix as nx
from typing import Dict, Tuple, Any

# Action indices, listed in the order this script assumes for the
# COMPLETE_ACTION_SET (index 0 first).
# NOTE(review): confirm the environment is actually built with that action
# set; otherwise these indices may select different actions.
(
    NOOP,
    ROTATE_CW,
    ROTATE_CCW,
    FORWARD,
    RIGHT,
    BACKWARD,
    LEFT,
    PICKUP,
    OPEN,
    DONE,
) = range(10)

# Direction indices.
# NOTE(review): verify this ordering against the library's own direction
# encoding before relying on it.
NORTH, EAST, SOUTH, WEST = range(4)

def doorkey_simple_policy(key, obs, state=None, info=None):
    """
    Sample one action at random from a small, fixed list of actions.

    The policy does not look at the observation: ``obs``, ``state`` and
    ``info`` are accepted so the function can be called like a generic policy,
    but they are ignored. It is not an optimal policy; it only has to produce
    some activity so that an episode can be visualized.

    Args:
        key: JAX PRNG key used to draw the action.
        obs: Current observation (unused).
        state: Optional state (unused).
        info: Optional extra information (unused).

    Returns:
        A scalar integer JAX array holding the selected action index.
    """
    # FORWARD is listed four times out of seven, so it is drawn more often
    # than ROTATE_CW, PICKUP or OPEN.
    action_sequence = jnp.asarray(
        [FORWARD, ROTATE_CW, FORWARD, PICKUP, FORWARD, OPEN, FORWARD]
    )

    # Index a JAX array instead of a Python list: a list needs a concrete
    # Python integer, which is unavailable while the function is being traced
    # (e.g. under jax.jit or jax.vmap). The sampled index is unchanged.
    action_idx = jax.random.randint(key, (), 0, action_sequence.shape[0])
    return action_sequence[action_idx]

def run_episode(env_name="DoorKey-5x5-v0", max_steps=100, seed=42):
    """Run one episode with the demo policy and collect a frame per state.

    The returned frame list always starts with the state produced by
    ``env.reset`` and contains one further frame for every step taken, so the
    state reached by the last step is included whether the episode finished
    or simply ran out of ``max_steps``.

    Args:
        env_name: Environment name without the ``"Navix-"`` prefix, which is
            added before calling ``nx.make``.
        max_steps: Maximum number of environment steps to take.
        seed: Integer seed for the JAX PRNG key.

    Returns:
        A ``(frames, total_reward)`` tuple. ``frames`` is the list of values
        returned by ``env.render_rgb``; ``total_reward`` is the sum of
        ``timestep.reward`` over the steps taken (``0`` if none were taken).
    """
    # NOTE(review): the policy's action indices assume the COMPLETE_ACTION_SET
    # ordering; confirm the environment built here uses that action set.
    env = nx.make(f"Navix-{env_name}", observation_fn=nx.observations.rgb)

    # Split off a dedicated key for the reset so `rng` can keep being split
    # for the per-step action keys.
    rng = jax.random.key(seed)
    rng, key_reset = jax.random.split(rng)
    timestep = env.reset(key_reset)

    # Render the initial state up front; every later frame is rendered right
    # after the step that produced it.
    frames = [env.render_rgb(timestep.state)]
    total_reward = 0

    for _ in range(max_steps):
        rng, key_action = jax.random.split(rng)
        action = doorkey_simple_policy(key_action, timestep.observation)

        timestep = env.step(timestep, action)
        total_reward += timestep.reward
        frames.append(env.render_rgb(timestep.state))

        # Stop as soon as the environment reports the episode is over.
        if timestep.is_done():
            break

    return frames, total_reward

def create_gif(frames, filename='episode.gif', fps=5):
    """Create a GIF from a sequence of image frames.

    Args:
        frames: Iterable of image arrays. Frames with a floating-point dtype
            are scaled by 255, clipped to ``[0, 255]`` and cast to ``uint8``;
            all other frames are used unchanged.
        filename: Path of the GIF file to write.
        fps: Frames per second of the saved animation.

    Returns:
        ``filename``, after the GIF has been written.

    Raises:
        ValueError: If ``frames`` contains no frames.
    """
    # Materialize as numpy arrays so any iterable (list, generator, stacked
    # array) is accepted and dtype checks behave uniformly.
    frames = [np.asarray(frame) for frame in frames]
    if not frames:
        raise ValueError("create_gif requires at least one frame")

    # Any floating dtype (not just float32/float64) is converted to uint8.
    frames = [
        np.clip(frame * 255, 0, 255).astype(np.uint8)
        if np.issubdtype(frame.dtype, np.floating)
        else frame
        for frame in frames
    ]

    # Figure size in inches is the frame size in pixels divided by 100.
    height, width = frames[0].shape[0], frames[0].shape[1]
    fig, ax = plt.subplots(figsize=(width / 100, height / 100))
    try:
        ax.axis('off')

        # Create the image artist once; each animation step only swaps its data.
        img = ax.imshow(frames[0])

        def update(frame):
            img.set_array(frame)
            return [img]

        anim = FuncAnimation(fig, update, frames=frames, interval=1000/fps, blit=True)
        anim.save(filename, writer='pillow', fps=fps)
    finally:
        # Release the figure even when building or saving the animation fails.
        plt.close(fig)

    print(f"GIF saved to {filename}")
    return filename

if __name__ == "__main__":
    # Make sure the folder that receives the demo GIFs is present.
    os.makedirs("outputs/navix_demos", exist_ok=True)

    # One episode is recorded for each environment listed here.
    envs = ["DoorKey-5x5-v0", "DoorKey-8x8-v0"]

    for env_name in envs:
        print(f"Running demo for {env_name}")

        # Roll out the episode, then write its frames next to the other demos.
        frames, reward = run_episode(env_name=env_name, max_steps=100, seed=42)
        gif_path = f"outputs/navix_demos/{env_name}_simple.gif"
        create_gif(frames, filename=gif_path, fps=5)

        print(f"Completed with reward: {reward}")
        print(f"Saved GIF to: {gif_path}")

    print("All demos completed!")