import functools
from typing import Optional, Union, Tuple, Any

import jax
import jax.numpy as jnp
import chex


from gymnax.wrappers.purerl import GymnaxWrapper
from gymnax.environments import environment, spaces

import seaborn as sns

from typing import Any, Dict, Tuple
import jax
from jax import Array
from flax import struct
from gymnax.environments.environment import (
    Environment as GymnaxEnv,
    EnvParams,
    EnvState,
)
from gymnax.environments.spaces import Discrete as GymnaxDiscrete, Box as GymnaxBox

from navix.environments.environment import Environment, Timestep
from functools import partial

# Import dependency graph modules
# from .navix_dependency_graph import (
#     create_environment_specific_dependency_graph,
#     generate_adjacency_matrix,
#     get_action_dependent_variables,
# )
# from .navix_dependency_graph import create_conditional_dependency_graph

class DreamerWrapper(GymnaxWrapper):
    """Gymnax wrapper whose ``step`` does not auto-reset on termination.

    ``step`` forwards straight to ``step_env`` and returns the resulting
    observation and state unchanged, even on the terminal step; resetting is
    left to the caller.
    """

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: environment.EnvState,
        action: Union[int, float],
        params: Optional[environment.EnvParams] = None,
    ) -> Tuple[chex.Array, environment.EnvState, float, bool, Any]:
        """Advance the wrapped environment by one step without resetting.

        Args:
            key: PRNG key passed through to ``step_env``.
            state: Current environment state.
            action: Action to apply.
            params: Environment params; ``self.default_params`` when omitted.

        Returns:
            ``(obs, state, reward, done, info)`` exactly as ``step_env`` produced them.
        """
        env_params = self.default_params if params is None else params
        # NOTE(review): ``step_env``/``default_params`` are not defined here;
        # presumably resolved on the wrapped env via GymnaxWrapper delegation.
        next_obs, next_state, reward, done, info = self.step_env(
            key, state, action, env_params
        )
        return next_obs, next_state, reward, done, info

def render_minatar(obs, colors):
    """Map a MinAtar channel observation to per-pixel colours.

    Args:
        obs: Observation whose LAST axis holds the ``C`` channels, e.g.
            ``(H, W, C)`` or with extra leading batch axes ``(..., H, W, C)``.
            Entries are expected to be 0/1 channel-activity flags.
        colors: Colour table indexed by channel id; row 0 is used where no
            channel is active and row ``c + 1`` for channel ``c``.

    Returns:
        ``colors[ids]`` where ``ids`` has shape ``obs.shape[:-1]``, i.e. an
        array of shape ``obs.shape[:-1] + colors.shape[1:]``. When several
        channels are active at one pixel the highest-indexed channel wins.
    """
    n_channels = obs.shape[-1]
    # Weight channel c by c + 1 so the max over channels yields the highest
    # active channel id, and 0 means "no channel active".
    channel_ids = jnp.arange(1, n_channels + 1)
    # Reduce over the last axis (not a hard-coded axis 2) so batched
    # observations with extra leading dimensions are rendered correctly.
    numerical_state = jnp.amax(obs * channel_ids, axis=-1).astype(jnp.int32)
    return colors[numerical_state]


class MinAtarPixel(GymnaxWrapper):
    """Render MinAtar channel observations as RGB images.

    Each observation channel gets a colour from seaborn's ``cubehelix``
    palette; pixels with no active channel are black. Both ``reset`` and
    ``step`` return the colourised observation.
    """

    def __init__(self, env):
        super().__init__(env)
        channel_count = env.obs_shape[-1]
        palette = sns.color_palette("cubehelix", channel_count)
        # Slot 0 of the colour table is "no channel active" -> black.
        palette.insert(0, (0, 0, 0))
        self.colors = jnp.array(list(palette))
        # Same spatial extent as the wrapped env, but 3 RGB channels.
        spatial_dims = env.obs_shape[:-1]
        self.obs_shape = (*spatial_dims, 3)

    def step(
        self,
        key: chex.PRNGKey,
        state: environment.EnvState,
        action: Union[int, float],
        params: Optional[environment.EnvParams] = None,
    ) -> Tuple[chex.Array, environment.EnvState, float, bool, Any]:
        """Step the wrapped env and colourise the returned observation."""
        raw_obs, next_state, reward, done, info = self._env.step(
            key, state, action, params
        )
        rgb_obs = render_minatar(raw_obs, self.colors)
        return rgb_obs, next_state, reward, done, info

    def reset(
            self,
            key: chex.PRNGKey,
            params: Optional[environment.EnvParams] = None
    ) -> Tuple[chex.Array, environment.EnvState]:
        """Reset the wrapped env and colourise the initial observation."""
        raw_obs, env_state = self._env.reset(key, params)
        rgb_obs = render_minatar(raw_obs, self.colors)
        return rgb_obs, env_state

    def observation_space(self, params: environment.EnvParams) -> spaces.Box:
        """RGB observation space in [0, 1] with shape ``self.obs_shape``."""
        return spaces.Box(0, 1, self.obs_shape)
    
class PixelNoise(GymnaxWrapper):
    """Add Gaussian pixel noise to observations and clip them to [0, 1].

    Noise with standard deviation ``noise_sigma`` is drawn from a fresh PRNG
    subkey on every ``reset`` and ``step``.
    """

    def __init__(self, env, noise_sigma, **kwargs):
        # Extra keyword arguments are accepted but ignored.
        super().__init__(env)
        self.noise_sigma = noise_sigma

    def _add_noise(self, noise_key, obs):
        """Return ``obs`` plus ``noise_sigma``-scaled normal noise, clipped to [0, 1]."""
        noisy = obs + jax.random.normal(noise_key, obs.shape) * self.noise_sigma
        return jnp.clip(noisy, 0., 1.)

    def step(
        self,
        key: chex.PRNGKey,
        state: environment.EnvState,
        action: Union[int, float],
        params: Optional[environment.EnvParams] = None,
    ) -> Tuple[chex.Array, environment.EnvState, float, bool, Any]:
        """Step the wrapped env, then perturb the observation with noise."""
        step_key, noise_key = jax.random.split(key)
        clean_obs, next_state, reward, done, info = self._env.step(
            step_key, state, action, params
        )
        return self._add_noise(noise_key, clean_obs), next_state, reward, done, info

    def reset(
            self,
            key: chex.PRNGKey,
            params: Optional[environment.EnvParams] = None
    ) -> Tuple[chex.Array, environment.EnvState]:
        """Reset the wrapped env, then perturb the initial observation with noise."""
        reset_key, noise_key = jax.random.split(key)
        clean_obs, env_state = self._env.reset(reset_key, params)
        return self._add_noise(noise_key, clean_obs), env_state

    def observation_space(self, params: environment.EnvParams) -> spaces.Box:
        """Box in [0, 1] (matching the clipping) with shape ``self.obs_shape``.

        NOTE(review): ``obs_shape`` is not set on this wrapper; presumably it is
        resolved on the wrapped env through GymnaxWrapper attribute delegation.
        """
        return spaces.Box(0, 1, self.obs_shape)

    
@struct.dataclass
class LogEnvState:
    """State carried by ``LogWrapper``: the wrapped env state plus episode stats.

    Field order matters: ``LogWrapper.reset`` constructs this positionally.
    """
    env_state: environment.EnvState  # state of the wrapped environment
    episode_returns: float  # running return of the current episode (zeroed when done)
    episode_lengths: int  # running step count of the current episode (zeroed when done)
    returned_episode_returns: float  # return of the most recently finished episode
    returned_episode_lengths: int  # length of the most recently finished episode


class LogWrapper(GymnaxWrapper):
    """Log the episode returns and lengths.

    Running totals are kept in a ``LogEnvState``. On the terminal step they are
    copied into the ``returned_episode_*`` fields and the running totals restart
    at zero. The latest finished-episode stats are also written into ``info``.
    """

    def __init__(self, env: environment.Environment):
        super().__init__(env)

    @partial(jax.jit, static_argnums=(0,))
    def reset(
        self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None
    ) -> Tuple[chex.Array, environment.EnvState]:
        """Reset the wrapped env and start every episode counter at zero."""
        obs, inner_state = self._env.reset(key, params)
        fresh = LogEnvState(
            env_state=inner_state,
            episode_returns=0.,
            episode_lengths=0,
            returned_episode_returns=0.,
            returned_episode_lengths=0,
        )
        return obs, fresh

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: environment.EnvState,
        action: Union[int, float],
        params: Optional[environment.EnvParams] = None,
    ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]:
        """Step the wrapped env and update the episode statistics.

        ``state`` is the ``LogEnvState`` returned by ``reset``/``step``. The
        wrapped env's ``info`` dict is updated in place with
        ``returned_episode_returns``, ``returned_episode_lengths`` and
        ``returned_episode`` (the ``done`` flag).
        """
        obs, inner_state, reward, done, info = self._env.step(
            key, state.env_state, action, params
        )
        running_return = state.episode_returns + reward
        running_length = state.episode_lengths + 1
        # 1 while the episode continues, 0 on the terminal step (done is 0/1).
        keep = 1 - done
        logged = LogEnvState(
            env_state=inner_state,
            episode_returns=running_return * keep,
            episode_lengths=running_length * keep,
            returned_episode_returns=state.returned_episode_returns * keep
            + running_return * done,
            returned_episode_lengths=state.returned_episode_lengths * keep
            + running_length * done,
        )
        info.update({
            "returned_episode_returns": logged.returned_episode_returns,
            "returned_episode_lengths": logged.returned_episode_lengths,
            "returned_episode": done,
        })
        return obs, logged, reward, done, info

@struct.dataclass
class GymnaxState(EnvState):
    """Gymnax-side state used by ``NavixToGymnax``: wraps a navix ``Timestep``."""
    timestep: Timestep  # full navix timestep (observation, state, reward, ...)
    time: Array  # step counter, copied from ``timestep.t`` by NavixToGymnax

from navix.components import HasColour, Openable, Directional

class NavixToGymnax(GymnaxEnv):
    """Expose a navix ``Environment`` through the gymnax ``Environment`` API.

    Besides ``reset``/``step``, the wrapper can flatten the navix entity state
    into one vector (``get_state``). It can also produce a list of names for
    that vector's elements (``get_state_names``). ``step`` reports the
    flattened *pre-step* state under ``info['state']``.
    """

    # Entity classes never included in the factored state.
    _EXCLUDED_ENTITIES = ('wall', 'floor', 'player')

    def __init__(self, env: Environment, autoreset: bool = True):
        # env: the wrapped navix environment.
        # autoreset: use navix's ``env.step`` if True, else the raw ``env._step``.
        self.env = env
        self.autoreset = autoreset

    @property
    def default_params(self) -> EnvParams:
        """Default gymnax params; the episode cap mirrors ``env.max_steps``."""
        return EnvParams(max_steps_in_episode=self.env.max_steps)

    @classmethod
    def wrap(cls, env: Environment) -> Tuple[GymnaxEnv, EnvParams]:
        """Wrap ``env`` (autoreset enabled) and return ``(wrapper, params)``."""
        return cls(env=env), EnvParams(max_steps_in_episode=env.max_steps)

    def action_space(self, params: Any):
        """Discrete space with one action per entry of ``env.action_set``."""
        return GymnaxDiscrete(len(self.env.action_set))

    def observation_space(self, params: Any):
        """Gymnax ``Box`` mirroring the navix observation space."""
        o_space = self.env.observation_space
        return GymnaxBox(
            low=o_space.minimum,
            high=o_space.maximum,
            shape=o_space.shape,
            dtype=o_space.dtype,
        )

    @property
    def obs_shape(self):
        """Shape of a single observation, taken from ``env.observation_space``."""
        # A property cannot receive extra arguments: the former
        # ``obs_shape(self, params)`` signature raised TypeError on every access.
        return self.env.observation_space.shape

    def reset(
        self, key: jax.Array, params: Optional[EnvParams] = None
    ) -> Tuple[Array, EnvState]:
        """Reset the navix env; ``params`` is accepted for API parity but unused."""
        timestep = self.env.reset(key)
        # Clear navix's info dict, as ``step`` does, so both return the same
        # state structure.
        timestep = timestep.replace(info={})
        return (
            timestep.observation,
            GymnaxState(time=timestep.t, timestep=timestep),
        )

    def _factored_entity_classes(self, entities):
        """Entity class names included in the factored state, sorted by name."""
        return [
            name for name in sorted(entities)
            if name not in self._EXCLUDED_ENTITIES
        ]

    @staticmethod
    def _entity_dynamic_state(entity):
        """Return ``entity.open`` (Openable), ``entity.direction`` (Directional), else None."""
        if isinstance(entity, Openable):
            return entity.open
        if isinstance(entity, Directional):
            return entity.direction
        return None

    def get_state_names(self, state):
        """Return one name per element of ``get_state(state)``.

        Layout: ``player_y, player_x, player_dir, player_pocket``. Then, for each
        non-excluded entity class in sorted order and each index ``i``:
        ``<cls>_y_i, <cls>_x_i, <cls>_colour_i``, plus ``<cls>_state_i`` for
        Openable/Directional entities.

        NOTE(review): this matches ``get_state`` when each class holds a single
        entity. For multi-entity classes ``get_state`` uses only ``position[0]``
        and groups colours/states per class rather than per entity. Confirm
        whether that layout is intended.
        """
        entities = state.entities
        names = ['player_y', 'player_x', 'player_dir', 'player_pocket']
        for entity_class in self._factored_entity_classes(entities):
            entity = entities[entity_class]
            has_dynamic_state = self._entity_dynamic_state(entity) is not None
            for i in range(entity.shape[0]):
                names.extend([f'{entity_class}_y_{i}', f'{entity_class}_x_{i}'])
                # get_state always emits a colour slot (zeros when the entity
                # has no HasColour component), so the name is always emitted too.
                names.append(f'{entity_class}_colour_{i}')
                if has_dynamic_state:
                    names.append(f'{entity_class}_state_{i}')
        return names

    def get_state(self, state):
        """Flatten the navix entity state into a single 1-D vector.

        The vector starts with the player's position, direction and pocket.
        Then, for each non-excluded entity class in sorted order, it appends the
        class's position, its colour (zeros when the entity has no colour) and,
        if present, its open/direction state.
        """
        entities = state.entities
        player = entities['player']
        factored_state = [
            jnp.concatenate(
                [player.position[0], player.direction, player.pocket], axis=-1
            )
        ]
        for entity_class in self._factored_entity_classes(entities):
            entity = entities[entity_class]
            colour = (
                entity.colour if isinstance(entity, HasColour)
                else jnp.zeros(entity.shape)
            )
            dynamic_state = self._entity_dynamic_state(entity)
            # When tag and position ranks differ, only ``position[0]`` is used.
            position = (
                entity.position[0]
                if len(entity.tag.shape) != len(entity.position.shape)
                else entity.position
            )
            parts = [position, colour]
            if dynamic_state is not None:
                parts.append(dynamic_state)
            factored_state.append(jnp.concatenate(parts, axis=-1).reshape(-1))
        return jnp.concatenate(factored_state, axis=-1)

    def step(
        self, key: Array, state: GymnaxState, action: jax.Array, params: EnvParams
    ) -> Tuple[Array, EnvState, Array, Array, Dict[str, Any]]:
        """Advance the navix env by one step.

        ``key`` and ``params`` are accepted for gymnax API parity but unused.
        Returns ``(obs, state, reward, done, info)``; ``info['state']`` holds the
        factored state from *before* the step.
        """
        factored_state = self.get_state(state.timestep.state)
        step_function = self.env.step if self.autoreset else self.env._step
        new_timestep = step_function(state.timestep, action)
        # Clear navix's info dict ("remove return"); only the factored state is reported.
        new_timestep = new_timestep.replace(info={})
        info = {'state': factored_state}
        return (
            new_timestep.observation,
            GymnaxState(time=new_timestep.t, timestep=new_timestep),
            new_timestep.reward,
            new_timestep.is_done(),
            info,
        )

# class NavixToGymnaxGraph(NavixToGymnax):
#     """
#     An extension of NavixToGymnax wrapper that includes dependency graph
#     information in the info dictionary returned by step().
#     """
    
#     def __init__(self, env: Environment, autoreset: bool = True, env_name: str = None):
#         super().__init__(env, autoreset)
#         self.env_name = env_name or self._infer_env_name()
#         # Create the global dependency graph for this environment
#         self.global_dependency_graph = create_environment_specific_dependency_graph(self.env_name)
#         self.global_adj_matrix, self.var_names = generate_adjacency_matrix(self.global_dependency_graph)
#         # Get action-specific variable dependencies
#         self.action_vars = get_action_dependent_variables(self.global_dependency_graph)
        
#     def _infer_env_name(self) -> str:
#         """Attempt to infer the environment name from the environment object."""
#         env_str = str(self.env)
#         if "DoorKey" in env_str:
#             return "DoorKey"
#         elif "FourRooms" in env_str:
#             return "FourRooms"
#         elif "MultiRoom" in env_str:
#             return "MultiRoom"
#         elif "Empty" in env_str:
#             return "Empty"
#         elif "KeyCorridor" in env_str:
#             return "KeyCorridor"
#         else:
#             # Default to DoorKey if we can't infer
#             return "DoorKey"
    
#     def _analyze_state_conditions(self, state) -> Dict[str, bool]:
#         """
#         Analyze the current state to determine conditions like:
#         - Whether player is near a key
#         - Whether player has a key
#         - Whether player is facing a door
        
#         Returns a dictionary of conditions.
#         """
#         conditions = {}
        
#         # Get entities
#         entities = state.entities if hasattr(state, 'entities') else state.timestep.state.entities
#         player = entities['player']
#         player_pos = player.position[0]
#         player_dir = player.direction
#         player_pocket = player.pocket
        
#         # Check if player has a key (non-zero pocket)
#         conditions['has_key'] = bool(jnp.any(player_pocket != 0))
        
#         # Check for key entities to determine if player is near a key
#         near_key = False
#         if 'key' in entities:
#             key = entities['key']
#             key_positions = key.position
#             # Check if player is adjacent to any key
#             for i in range(key.shape[0]):
#                 key_pos = key_positions[i]
#                 # Check if this key is at a position
#                 if jnp.all(key_pos != 0):  # Non-zero position means key exists
#                     # Check if player is adjacent to this key
#                     dist = jnp.abs(player_pos - key_pos).sum()
#                     if dist <= 1:  # Adjacent or same position
#                         near_key = True
#         conditions['near_key'] = near_key
        
#         # Check if player is facing a door
#         facing_door = False
#         if 'door' in entities:
#             door = entities['door']
#             door_positions = door.position
            
#             # Direction vectors for each direction (EAST, SOUTH, WEST, NORTH)
#             dir_vectors = jnp.array([
#                 [0, 1],  # EAST
#                 [1, 0],  # SOUTH
#                 [0, -1], # WEST
#                 [-1, 0]  # NORTH
#             ])
            
#             # Get player's direction vector
#             dir_idx = jnp.argmax(player_dir)
#             dir_vector = dir_vectors[dir_idx]
            
#             # Position player is facing
#             facing_pos = player_pos + dir_vector
            
#             # Check if facing position has a door
#             for i in range(door.shape[0]):
#                 door_pos = door_positions[i]
#                 if jnp.all(door_pos == facing_pos):
#                     facing_door = True
        
#         conditions['facing_door'] = facing_door
        
#         return conditions
    
#     def _get_dependency_info(self, state, action):
#         """
#         Generate dependency graph information based on the current state and action.
        
#         Returns a dictionary with dependency information.
#         """
#         # Analyze current state conditions
#         conditions = self._analyze_state_conditions(state)
        
#         # Create conditional dependency graph based on the conditions
#         conditional_graph, affected_edges = create_conditional_dependency_graph(
#             self.env_name, conditions
#         )
        
#         # Generate adjacency matrix for the conditional graph
#         conditional_adj_matrix, _ = generate_adjacency_matrix(conditional_graph)
        
#         # Get all variables affected by the current action
#         affected_vars = self.action_vars.get(action, [])
        
#         # Create action-specific adjacency matrix
#         action_matrix = jnp.zeros_like(conditional_adj_matrix)
        
#         # For each affected variable, copy its dependencies from the conditional graph
#         for var in affected_vars:
#             if var in self.var_names:
#                 var_idx = self.var_names.index(var)
#                 for parent_idx in range(len(self.var_names)):
#                     if conditional_adj_matrix[parent_idx, var_idx] == 1:
#                         action_matrix = action_matrix.at[parent_idx, var_idx].set(1)
        
#         # Return dependency information
#         return {
#             'global_adj_matrix': self.global_adj_matrix,
#             'conditional_adj_matrix': conditional_adj_matrix,
#             'action_adj_matrix': action_matrix,
#             'var_names': self.var_names,
#             'affected_vars': affected_vars,
#             'conditions': conditions
#         }
            
#     def step(
#         self, key: Array, state: GymnaxState, action: jax.Array, params: EnvParams
#     ) -> Tuple[Array, EnvState, Array, Array, Dict[str, Any]]:
#         # Get the factored state before step
#         factored_state = self.get_state(state.timestep.state)
        
#         # Call the step function from the base class
#         step_function = self.env.step if self.autoreset else self.env._step
#         new_timestep = step_function(state.timestep, action)
#         new_timestep = new_timestep.replace(info={})  # remove return
        
#         # Get dependency graph information
#         dependency_info = self._get_dependency_info(state.timestep.state, action)
#         print(dependency_info)
#         # Create info dictionary with both state and dependency information
#         info = {
#             **new_timestep.info,
#             'state': factored_state,
#             'dependency_graph': dependency_info
#         }
        
#         return (
#             new_timestep.observation,
#             GymnaxState(time=new_timestep.t, timestep=new_timestep),
#             new_timestep.reward,
#             new_timestep.is_done(),
#             info,
#         )

# if __name__=='__main__':
#     import gymnax
#     import matplotlib.pyplot as plt

#     env, env_params = gymnax.make('Breakout-MinAtar')
#     env = PixelNoise(MinAtarPixel(env), noise_sigma=5/255)
#     obs, state = jax.jit(env.reset)(jax.random.key(0), env_params)
#     plt.imshow(obs)
#     plt.show()

