import jax
import jax.numpy as jnp
import numpy as np
import os
import navix as nx
from src.envs.gymnax_wrappers import NavixToGymnax, PixelNoise
import inspect

def setup_environment(env_name, obs_type="rgb", img_size=64, noise_sigma=1e-3):
    """Build a Gymnax-compatible Navix environment.

    Args:
        env_name: Navix environment id without the ``"Navix-"`` prefix,
            e.g. ``"DoorKey-5x5-v0"``.
        obs_type: Which ``nx.observations`` function to use. One of
            ``"categorical"``, ``"symbolic"`` or ``"rgb"``.
        img_size: Side length of the square image that RGB observations are
            resized to. Only used when ``obs_type == "rgb"``.
        noise_sigma: Value forwarded to ``PixelNoise`` as ``noise_sigma``.
            Only used when ``obs_type == "rgb"``. The default keeps the
            previously hard-coded small amount of noise.

    Returns:
        The environment wrapped in ``NavixToGymnax`` and, for RGB
        observations, additionally wrapped in ``PixelNoise``.

    Raises:
        KeyError: If ``obs_type`` is not one of the supported names.
    """
    supported_obs_types = ("categorical", "symbolic", "rgb")
    # Fail fast with an actionable message. KeyError is kept because that is
    # what the original dictionary lookup raised for an unknown obs_type.
    if obs_type not in supported_obs_types:
        raise KeyError(
            f"Unknown obs_type {obs_type!r}; "
            f"expected one of {supported_obs_types}"
        )
    observation_fn = getattr(nx.observations, obs_type)

    # Create base environment
    env = nx.make(
        f"Navix-{env_name}",
        observation_fn=observation_fn,
        action_set=nx.actions.COMPLETE_ACTION_SET,
    )

    is_rgb = obs_type == "rgb"

    if is_rgb:
        def _normalized_rgb_fn(state):
            # Resize to a fixed square image, then divide by 255 to match
            # the 0..1 bounds of the observation space declared below.
            obs = observation_fn(state)
            obs = jax.image.resize(
                obs, (img_size, img_size, 3), method="bilinear"
            )
            return obs / 255

        env = env.replace(
            observation_fn=_normalized_rgb_fn,
            observation_space=nx.spaces.Continuous.create(
                (img_size, img_size, 3), 0, 1
            ),
        )

    # Convert to Gymnax-compatible environment
    env = NavixToGymnax(env)

    # Pixel noise is only applied to image observations.
    if is_rgb:
        env = PixelNoise(env, noise_sigma=noise_sigma)

    return env

def inspect_state_structure(env_name="DoorKey-5x5-v0"):
    """Reset an RGB environment once and print what its state looks like.

    The environment is created through ``setup_environment`` with RGB
    observations at 64x64 and reset with a fixed PRNG key (seed 42). The
    type and attribute listing of the returned state are printed, followed
    by the same for its ``timestep``, ``timestep.state`` and each item of
    ``timestep.state.entities`` when those attributes exist, and finally
    the observation type and shape.

    Args:
        env_name: Environment id passed on to ``setup_environment``.

    Returns:
        A tuple ``(env, env_state, obs)``.
    """

    def _report_entity(label, entity):
        # Describe one value of the entities mapping.
        print(f"\nEntity Type: {label}")
        print(f"Entity Class: {type(entity)}")

        # Strings, bytes and arrays have a length but are reported as a
        # single entity, as is anything without __len__.
        sized = hasattr(entity, '__len__')
        if not sized or isinstance(entity, (str, bytes, jnp.ndarray)):
            print("Single Entity")
            print(f"Entity Attributes: {dir(entity)}")
            return

        count = len(entity)
        print(f"Entity Length: {count}")
        if count > 0:
            first = entity[0]
            print(f"First Entity Type: {type(first)}")
            print(f"First Entity Attributes: {dir(first)}")

    def _report_state(root):
        # Walk root.timestep.state.entities, stopping at the first missing
        # attribute.
        if not hasattr(root, 'timestep'):
            return
        timestep = root.timestep
        print(f"\nTimestep Type: {type(timestep)}")
        print(f"Timestep Attributes: {dir(timestep)}")

        if not hasattr(timestep, 'state'):
            return
        state = timestep.state
        print(f"\nState Type: {type(state)}")
        print(f"State Attributes: {dir(state)}")

        if not hasattr(state, 'entities'):
            return
        entities = state.entities
        print(f"\nEntities Keys: {entities.keys()}")
        for label, entity in entities.items():
            _report_entity(label, entity)

    env = setup_environment(env_name, obs_type="rgb", img_size=64)

    # Fixed seed so repeated runs inspect the same initial state.
    obs, env_state = env.reset(jax.random.key(42))

    print(f"Environment State Type: {type(env_state)}")
    print(f"Environment State Attributes: {dir(env_state)}")

    _report_state(env_state)

    print(f"\nObservation Type: {type(obs)}")
    print(f"Observation Shape: {obs.shape}")

    return env, env_state, obs

def debug_navix():
    """Print an overview of a raw Navix environment and its first timestep.

    Creates ``Navix-DoorKey-8x8-v0`` with RGB observations, resets it with a
    fixed PRNG key (seed 42) and prints the public attribute names of the
    environment, of the timestep and of ``timestep.state`` (when present).
    It also prints the signature of the first render method found among
    ``render``, ``render_rgb`` and ``render_rgb_array``, and lists every
    public attribute whose name contains ``"render"``.

    Returns:
        A tuple ``(env, timestep)``.
    """

    def _public_names(obj):
        # Attribute names that do not start with an underscore.
        return [name for name in dir(obj) if not name.startswith('_')]

    print("Creating a simple Navix environment...")
    env = nx.make("Navix-DoorKey-8x8-v0", observation_fn=nx.observations.rgb)

    print("\nNavix environment methods:")
    env_methods = _public_names(env)
    print(env_methods)

    print("\nTimestep structure:")
    timestep = env.reset(jax.random.key(42))
    print(f"Timestep methods: {_public_names(timestep)}")
    print(f"Timestep type: {type(timestep)}")

    # Observation details, when the timestep exposes one.
    if hasattr(timestep, 'observation'):
        observation = timestep.observation
        print(f"\nObservation shape: {observation.shape}")
        print(f"Observation type: {type(observation)}")

    # State details, when the timestep exposes one.
    if hasattr(timestep, 'state'):
        state = timestep.state
        print(f"\nState type: {type(state)}")
        print(f"State methods: {_public_names(state)}")

    # Only the first matching render method is reported, in this order.
    render_probes = (
        ('render', "\nRender method exists. Signature:"),
        ('render_rgb', "\nRender_rgb method exists. Signature:"),
        ('render_rgb_array', "\nRender_rgb_array method exists. Signature:"),
    )
    for attr_name, banner in render_probes:
        if attr_name in env_methods:
            print(banner)
            print(inspect.signature(getattr(env, attr_name)))
            break

    # Everything render-related, regardless of the exact name.
    render_methods = list(filter(lambda name: 'render' in name, env_methods))
    print(f"\nAll render-related methods: {render_methods}")

    return env, timestep

# Script entry point: run the Navix exploration only when this file is
# executed directly. The results are bound to module-level names.
if __name__ == "__main__":
    env, timestep = debug_navix() 