import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import os
from functools import partial
import navix as nx
from typing import Dict, List, Tuple, Callable, Optional, Union
import flax.struct
import chex

# Action indices returned by the policies in this module. The ordering is
# assumed to mirror navix's COMPLETE_ACTION_SET.
# NOTE(review): confirm the environment is actually built with that action
# set — a different action set would give these integers other meanings.
NOOP = 0
ROTATE_CW = 1
ROTATE_CCW = 2
FORWARD = 3
RIGHT = 4
BACKWARD = 5
LEFT = 6
PICKUP = 7
OPEN = 8
DONE = 9

# Direction indices compared against the player's `direction` field.
# The rotation logic in this module relies on "+1 (mod 4)" being a clockwise
# quarter turn, which holds for the ordering below.
# NOTE(review): navix's own encoding appears to be EAST=0, SOUTH=1, WEST=2,
# NORTH=3 — confirm these values match the environment before relying on them.
NORTH = 0
EAST = 1
SOUTH = 2
WEST = 3

@flax.struct.dataclass
class PolicyState:
    """State maintained by the policy between decision steps.

    Every policy in this module returns an instance of this class next to
    the chosen action so the caller can pass it back in on the following
    step.
    """
    # Position the policy is currently steering towards (key, door or goal
    # position, or the agent's own position as a fallback).
    subgoal: chex.Array
    # Whether the agent has picked up the key. Policies in this module may
    # store the result of a pocket comparison here rather than a Python bool.
    has_key: bool = False
    # Whether the door has been opened.
    door_open: bool = False
    # Free-form storage; the policies in this module keep remembered entity
    # positions here under the keys 'key', 'door' and 'goal'.
    memory: Dict = flax.struct.field(default_factory=dict)

def get_entity_positions(state) -> Dict:
    """Extract positions of entities from the environment state.

    Args:
        state: Object exposing an ``entities`` mapping from entity-type name
            to either a single entity object (with a ``position`` attribute)
            or a Python list of such objects.

    Returns:
        Dict mapping each entity-type name to its positions with a leading
        "entity" axis. Entity types that are empty lists, or whose objects
        expose no ``position`` attribute, are omitted.
    """
    entities = {}
    for entity_type, entity_group in state.entities.items():
        if isinstance(entity_group, list):
            # A Python list has no `.position` attribute of its own, so the
            # positions have to be gathered from the individual members.
            per_entity = [
                jnp.atleast_2d(jnp.asarray(entity.position))
                for entity in entity_group
                if hasattr(entity, 'position')
            ]
            if per_entity:
                entities[entity_type] = jnp.concatenate(per_entity, axis=0)
            continue

        # Single entity object (not in a list)
        if not hasattr(entity_group, 'position'):
            continue

        positions = entity_group.position
        already_batched = isinstance(positions, list) or (
            hasattr(positions, 'shape') and len(positions.shape) > 1
        )
        if already_batched:
            # Multiple positions: keep them as they are.
            entities[entity_type] = positions
        else:
            # Single position: add the leading entity axis.
            entities[entity_type] = jnp.expand_dims(positions, 0)
    return entities

def extract_timestep_state(env_state):
    """
    Return the ``timestep.state`` object held by an environment state.

    Two layouts are recognised: one where the timestep sits under an
    ``env_state`` attribute (labelled NavixState) and one where it is a
    direct ``timestep`` attribute (labelled GymnaxState). The nested layout
    is checked first.

    Raises:
        ValueError: if neither attribute is present on ``env_state``.
    """
    # NavixState format: unwrap one level before reading the timestep.
    if hasattr(env_state, 'env_state'):
        wrapped = env_state.env_state
        return wrapped.timestep.state

    # GymnaxState format: the timestep is available directly.
    if hasattr(env_state, 'timestep'):
        return env_state.timestep.state

    # Unknown format
    raise ValueError(f"Unknown environment state format: {type(env_state)}")

def manhattan_distance(pos1, pos2):
    """Return the L1 (Manhattan) distance between two positions.

    Only the first two components of each position are used.
    """
    first_axis_gap = jnp.abs(pos1[0] - pos2[0])
    second_axis_gap = jnp.abs(pos1[1] - pos2[1])
    return first_axis_gap + second_axis_gap

def get_direction_to_target(agent_pos, agent_dir, target_pos):
    """
    Work out which rotation (if any) turns the agent towards the target.

    The target heading is chosen along the axis with the larger offset;
    ties (including a zero offset) fall to the horizontal axis.

    Returns:
        None when the agent already faces the chosen heading, ROTATE_CCW
        when a single counter-clockwise turn suffices, and ROTATE_CW
        otherwise (one clockwise turn, or the first of two for a U-turn).
    """
    # Offsets are (row, col): index 0 is the row, index 1 the column.
    row_offset = target_pos[0] - agent_pos[0]
    col_offset = target_pos[1] - agent_pos[1]

    # Pick the absolute heading along the dominant axis.
    if abs(row_offset) > abs(col_offset):
        heading = NORTH if row_offset < 0 else SOUTH
    else:
        heading = EAST if col_offset > 0 else WEST

    # Number of clockwise quarter turns separating the two headings.
    quarter_turns = (heading - agent_dir) % 4
    if quarter_turns == 0:
        return None  # Already facing the right direction
    if quarter_turns == 1:
        return ROTATE_CW
    if quarter_turns == 3:
        return ROTATE_CCW
    # Opposite heading: either rotation works, clockwise is used.
    return ROTATE_CW

def random_policy(key, obs, env_state, policy_state=None):
    """Baseline policy: sample one of the 10 actions uniformly at random.

    ``obs`` and ``env_state`` are ignored; ``policy_state`` is handed back
    untouched.
    """
    n_actions = 10  # Based on COMPLETE_ACTION_SET
    weights = jnp.ones((n_actions,))
    uniform_logits = jnp.log(weights / weights.sum())
    sampled = jax.random.categorical(key, uniform_logits)
    return sampled, policy_state

def navigation_policy(key, obs, env_state, policy_state=None):
    """
    Greedy policy that steers the agent towards ``policy_state.subgoal``.

    Args:
        key: JAX random key (unused by this policy)
        obs: Environment observation (unused by this policy)
        env_state: Environment state
        policy_state: Current policy state containing the subgoal; when
            omitted, the agent's current position is used as the subgoal
    Returns:
        action: NOOP once the subgoal is reached, a rotation while the agent
            is not facing the subgoal, FORWARD otherwise
        new_policy_state: The policy state that was used (never modified)
    """
    state = extract_timestep_state(env_state)
    agent = state.entities['player'][0]

    if policy_state is None:
        # No subgoal supplied: target the current cell.
        policy_state = PolicyState(subgoal=agent.position)

    here = agent.position
    heading = agent.direction
    destination = policy_state.subgoal

    # Nothing to do once the agent stands on the subgoal.
    if jnp.array_equal(here, destination):
        return NOOP, policy_state

    # Rotate until the agent faces the subgoal, then walk forward.
    turn = get_direction_to_target(here, heading, destination)
    chosen = FORWARD if turn is None else turn
    return chosen, policy_state

def key_collection_policy(key, obs, env_state, policy_state=None):
    """
    Policy to find and collect a key.

    Steps:
    1. Identify key position if not known
    2. Navigate towards the key position
    3. Pick up the key once the agent stands next to it and faces it
       (or, as before, once it stands on the key cell itself)

    Args:
        key: JAX random key (only used by the random fallback)
        obs: Environment observation
        env_state: Environment state
        policy_state: Current policy state; re-initialised when it is None
            or its memory has no 'key' entry
    Returns:
        action: Action to take
        new_policy_state: Policy state to carry to the next step
    """
    state = extract_timestep_state(env_state)

    # Get agent info
    agent = state.entities['player'][0]
    agent_pos = agent.position

    if policy_state is None or 'key' not in policy_state.memory:
        # Initialize policy state with key position as subgoal
        key_positions = get_entity_positions(state).get('key', None)
        if key_positions is None or len(key_positions) == 0:
            # No key found in entities, use random policy
            return random_policy(key, obs, env_state, policy_state)
        key_pos = key_positions[0]
        policy_state = PolicyState(
            subgoal=key_pos,
            memory={'key': key_pos}
        )

    # Check if the agent has the key (pocket is not empty)
    has_key = agent.pocket != nx.components.EMPTY_POCKET_ID
    if has_key:
        # Already have the key, update policy state
        return NOOP, PolicyState(
            subgoal=policy_state.subgoal,
            has_key=True,
            door_open=policy_state.door_open,
            memory=policy_state.memory
        )

    key_pos = policy_state.subgoal

    # Agent is on the key cell: pick up (kept from the original behaviour).
    if jnp.array_equal(agent_pos, key_pos):
        return PICKUP, policy_state

    # Agent is next to the key and already facing it: pick up instead of
    # walking into the key cell. This mirrors the "adjacent and facing" test
    # used for doors; for an adjacent target, get_direction_to_target returns
    # None exactly when the agent faces it.
    # NOTE(review): assumes the environment's pickup acts on the cell in
    # front of the agent and that the key cell cannot be entered — confirm.
    adjacent = manhattan_distance(agent_pos, key_pos) == 1
    if adjacent and get_direction_to_target(agent_pos, agent.direction, key_pos) is None:
        return PICKUP, policy_state

    # Navigate to the key
    nav_action, _ = navigation_policy(key, obs, env_state, policy_state)
    return nav_action, policy_state

def door_opening_policy(key, obs, env_state, policy_state=None):
    """
    Policy to find and open a door.

    Steps:
    1. Identify door position if not known
    2. Navigate towards the door position
    3. Open the door once the agent is within one cell of it and facing it

    Returns:
        (action, policy_state). NOOP with ``door_open=True`` is returned as
        soon as the door is observed to be open; a random action is returned
        when no door position can be determined.
    """
    state = extract_timestep_state(env_state)
    known_positions = get_entity_positions(state)

    agent = state.entities['player'][0]
    here = agent.position
    heading = agent.direction

    needs_init = policy_state is None or 'door' not in policy_state.memory
    if needs_init:
        candidates = known_positions.get('door', None)
        if candidates is None or len(candidates) == 0:
            # No door found in entities, use random policy
            return random_policy(key, obs, env_state, policy_state)
        first_door = candidates[0]
        policy_state = PolicyState(
            subgoal=first_door,
            has_key=agent.pocket != nx.components.EMPTY_POCKET_ID,
            memory={'door': first_door}
        )

    # Read the open flag from either a list of doors or a single door object.
    door_entity = state.entities['door']
    if isinstance(door_entity, list) and len(door_entity) > 0:
        is_open = door_entity[0].open
    elif hasattr(door_entity, 'open'):
        is_open = door_entity.open
    else:
        is_open = False

    if is_open:
        # Door is already open, report it through the policy state.
        return NOOP, PolicyState(
            subgoal=policy_state.subgoal,
            has_key=policy_state.has_key,
            door_open=True,
            memory=policy_state.memory
        )

    door_pos = policy_state.subgoal
    if manhattan_distance(here, door_pos) <= 1:
        # The agent faces the door when its heading points along the axis
        # on which the door lies ahead of it.
        facing_door = (
            (heading == NORTH and here[0] > door_pos[0])
            or (heading == SOUTH and here[0] < door_pos[0])
            or (heading == EAST and here[1] < door_pos[1])
            or (heading == WEST and here[1] > door_pos[1])
        )
        if facing_door:
            # Agent is at and facing the door, open it
            return OPEN, policy_state

    # Otherwise keep navigating towards the door.
    step_action, _ = navigation_policy(key, obs, env_state, policy_state)
    return step_action, policy_state

def goal_reaching_policy(key, obs, env_state, policy_state=None):
    """
    Policy to navigate to the goal position.

    The policy state is re-initialised with the first known goal position
    when it is None or its memory has no 'goal' entry; if no goal position
    is available a random action is returned instead.
    """
    state = extract_timestep_state(env_state)
    known_positions = get_entity_positions(state)

    needs_init = policy_state is None or 'goal' not in policy_state.memory
    if needs_init:
        candidates = known_positions.get('goal', None)
        if candidates is None or len(candidates) == 0:
            # No goal found in entities, use random policy
            return random_policy(key, obs, env_state, policy_state)
        first_goal = candidates[0]
        policy_state = PolicyState(
            subgoal=first_goal,
            memory={'goal': first_goal}
        )

    # Delegate the actual movement to the navigation policy.
    step_action, _ = navigation_policy(key, obs, env_state, policy_state)
    return step_action, policy_state

def doorkey_policy(key, obs, env_state, policy_state=None):
    """
    A complete policy for the DoorKey environment that:
    1. Finds and collects the key
    2. Finds and opens the door
    3. Navigates to the goal

    The phase is re-derived on every call from the environment state (is the
    pocket non-empty, is the door open), so the policy state only carries the
    remembered entity positions between steps.

    Args:
        key: JAX random key, forwarded to the phase policy
        obs: Environment observation, forwarded to the phase policy
        env_state: Environment state
        policy_state: Policy state from the previous step, or None on the
            first step (entity positions are then read from the state)
    Returns:
        action: Action chosen by the active phase policy
        new_policy_state: Policy state whose subgoal, has_key and door_open
            fields reflect the current phase
    """
    state = extract_timestep_state(env_state)

    # Get agent info
    agent = state.entities['player'][0]
    has_key = agent.pocket != nx.components.EMPTY_POCKET_ID

    # Get door info. `.get` keeps a state without any door entity from
    # raising KeyError; it then falls through to "not open" like any other
    # unrecognised door object.
    doors = state.entities.get('door')
    if isinstance(doors, list) and len(doors) > 0:
        door_open = doors[0].open
    elif hasattr(doors, 'open'):  # Single door object
        door_open = doors.open
    else:
        door_open = False

    if policy_state is None:
        # First call: remember the first known position of each entity type.
        entities = get_entity_positions(state)
        memory = {}
        for entity_name in ('key', 'door', 'goal'):
            positions = entities.get(entity_name, None)
            if positions is not None and len(positions) > 0:
                memory[entity_name] = positions[0]
    else:
        memory = policy_state.memory

    # Pick the active phase: key first, then door, then goal.
    if not has_key:
        target_name, phase_policy = 'key', key_collection_policy
    elif not door_open:
        target_name, phase_policy = 'door', door_opening_policy
    else:
        target_name, phase_policy = 'goal', goal_reaching_policy

    # Fall back to the agent's own position when the target is unknown. The
    # subgoal is refreshed on every call so the returned state never carries
    # a subgoal left over from an earlier phase.
    policy_state = PolicyState(
        subgoal=memory.get(target_name, agent.position),
        has_key=has_key,
        door_open=door_open,
        memory=memory
    )

    action, _ = phase_policy(key, obs, env_state, policy_state)
    return action, policy_state

def run_episode(env, policy_fn, max_steps=100, render=True, seed=0):
    """
    Run a complete episode with the given policy and environment.

    Args:
        env: Environment exposing ``reset(key) -> (obs, env_state)`` and
            ``step(key, env_state, action) -> (obs, env_state, reward, done, info)``
        policy_fn: Callable ``(key, obs, env_state, policy_state) ->
            (action, policy_state)``; it receives ``policy_state=None`` on
            the first step
        max_steps: Maximum number of environment steps
        render: Whether to record observations as frames
        seed: Seed for the JAX random key driving reset, policy and steps
    Returns:
        frames: If render=True, the initial observation followed by the
            observation after every step (so the last observation is always
            included, whether the episode ended or hit max_steps);
            otherwise an empty list
        rewards: Total reward accumulated during the episode
    """
    frames = []
    rng = jax.random.key(seed)
    rng, key_reset = jax.random.split(rng)

    # Reset environment
    obs, env_state = env.reset(key_reset)
    policy_state = None
    total_reward = 0

    if render:
        # Use the observation directly for RGB environments
        frames.append(obs)

    for _ in range(max_steps):
        # Get action from policy
        rng, key_act = jax.random.split(rng)
        action, policy_state = policy_fn(key_act, obs, env_state, policy_state)

        # Step environment
        rng, key_step = jax.random.split(rng)
        obs, env_state, reward, done, info = env.step(key_step, env_state, action)

        total_reward += reward

        if render:
            # Record the observation produced by this step, so a run that is
            # cut off by max_steps still ends with its latest observation.
            frames.append(obs)

        if done:
            break

    return frames, total_reward

def create_gif(frames, filename='episode.gif', fps=5):
    """Write a list of image frames to ``filename`` as an animated GIF.

    Frames with dtype float32 or float64 are scaled by 255, clipped to
    [0, 255] and cast to uint8; all other frames are used unchanged. The
    figure size is derived from the first frame's height and width.

    Returns:
        The ``filename`` that was written.
    """
    def _as_uint8(frame):
        # Only float32/float64 frames are rescaled.
        is_float = frame.dtype == np.float32 or frame.dtype == np.float64
        if is_float:
            return np.clip(frame * 255, 0, 255).astype(np.uint8)
        return frame

    frames = [_as_uint8(frame) for frame in frames]

    first_frame = frames[0]
    height, width = first_frame.shape[0], first_frame.shape[1]
    fig, ax = plt.subplots(figsize=(width / 100, height / 100))
    ax.axis('off')

    # The image artist is created once and updated in place for each frame.
    img = ax.imshow(first_frame)

    def update(frame):
        img.set_array(frame)
        return [img]

    anim = FuncAnimation(fig, update, frames=frames, interval=1000 / fps, blit=True)
    anim.save(filename, writer='pillow', fps=fps)
    plt.close(fig)

    print(f"GIF saved to {filename}")
    return filename