
from functools import partial
import gymnax
import gymnax.wrappers
import navix as nx
import envs.four_rooms_fixed   
import envs.uniform_doorkey
from envs.taxi_gymnax import TaxiGymnax
import chex
import jax

from collections.abc import Callable
from utils.printarr import printarr

from envs.gymnax_wrappers import DreamerWrapper, PixelNoise, MinAtarPixel, NavixToGymnax, LogWrapper
from envs.navix_to_gymnax_graph import NavixToGymnaxGraph

@chex.dataclass(frozen=True)
class EnvConfig:
    """Immutable bundle describing a built environment.

    Produced by the ``build_*`` functions in this module. ``obs``, ``action``,
    ``done``, ``reward`` and ``truncated`` hold template values taken from a
    single reset + step of the environment at build time, not live state.
    """
    obs: chex.Array  # observation returned by the build-time template step
    action: chex.Array  # action passed to the build-time template step
    done: chex.Array  # done flag returned by the build-time template step
    reward: chex.Array  # reward returned by the build-time template step
    env_reset: Callable  # wrapped env's reset with env_params bound via functools.partial
    env_step: Callable  # wrapped env's step with env_params bound via functools.partial
    jittable: bool  # every builder in this module sets this to True
    n_actions: int  # number of discrete actions of the base environment
    truncated: chex.Array  # builders in this module set this to the same value as ``done``
    n_envs: int  # copied from ``config.n_envs``
    base_env : gymnax.environments.environment.Environment  # env underneath the LogWrapper
    env_params : gymnax.EnvParams  # params bound into env_reset / env_step
    state_names: list[str]  # required (no default); builders must always supply it

def build_gymnax(env_name, config, rng=jax.random.PRNGKey(0)):
    """Build an :class:`EnvConfig` for a gymnax-registered environment.

    Args:
        env_name: Name passed to ``gymnax.make``.
        config: Experiment config. Reads ``env_params`` (forwarded to
            ``gymnax.make``), ``autoreset``, ``render`` (MinAtar only) and
            ``n_envs``.
        rng: PRNG key used for the build-time template reset/step and for
            sampling the template action.

    Returns:
        EnvConfig whose ``obs``/``action``/``done``/``reward`` come from one
        reset + step of the wrapped environment.
    """
    basic_env, env_params = gymnax.make(env_name, **config.env_params)

    if 'FourRooms' in env_name:
        env_params = env_params.replace(resample_goal_pos=True)

    # Without autoreset the environment is wrapped in DreamerWrapper.
    if not config.autoreset:
        basic_env = DreamerWrapper(basic_env)
    if 'MinAtar' in env_name and config.render:
        basic_env = PixelNoise(
            MinAtarPixel(basic_env),
            noise_sigma=config.env_params.get('noise_sigma', 0.),
        )

    env = LogWrapper(basic_env)

    # Independent keys for reset, action sampling and step.
    reset_key, action_key, step_key = jax.random.split(rng, 3)
    obs, env_state = env.reset(reset_key, env_params)
    # Sample a valid action; ``action_space().n`` is one past the last valid
    # discrete action and must not be fed to ``step``.
    action = env.action_space().sample(action_key)
    obs, env_state, reward, done, _ = env.step(step_key, env_state, action, env_params)
    printarr(obs)

    return EnvConfig(
        obs=obs,
        action=action,
        done=done,
        reward=reward,
        truncated=done,
        jittable=True,
        n_actions=basic_env.action_space().n,
        env_reset=partial(env.reset, params=env_params),
        env_step=partial(env.step, params=env_params),
        n_envs=config.n_envs,
        base_env=basic_env,
        env_params=env_params,
        # EnvConfig requires this field; this builder has no source of state
        # names, so it supplies an empty list.
        state_names=[],
    )


def build_navix(envname, config, rng=jax.random.PRNGKey(0)):
    """Build an :class:`EnvConfig` for a Navix environment.

    ``envname`` is prefixed with ``'Navix-'`` and passed to ``nx.make``.
    Config keys read: ``env_params`` (forwarded to ``nx.make``),
    ``action_space`` (``'full'`` selects the complete action set, anything
    else the default one), ``observation_space`` (key into the observation
    table below, default ``'symbolic'``), ``img_size`` (RGB only, default
    64), ``autoreset`` (default True) and ``n_envs``.

    ``rng`` is used only to sample the template action; the template
    reset/step use a fixed ``jax.random.key(10)``.
    """

    def scaled_image(observe, size=64):
        # Resize the rendered frame to size x size x 3 (bilinear) and divide by 255.
        def wrapped(state):
            frame = jax.image.resize(observe(state), (size, size, 3), method="bilinear")
            return frame / 255

        return wrapped

    def scaled_symbols(observe):
        # Divide the symbolic encoding by 10.
        def wrapped(state):
            return observe(state) / 10

        return wrapped

    observation_registry = dict(
        categorical=nx.observations.categorical,
        categorical_limited=nx.observations.categorical_first_person,
        symbolic=nx.observations.symbolic,
        symbolic_limited=nx.observations.symbolic_first_person,
        rgb=nx.observations.rgb,
        rgb_limited=nx.observations.rgb_first_person,
    )

    full_name = f'Navix-{envname}'
    print(f'Building {full_name} with {config.env_params}')

    wants_full_actions = config.get('action_space', 'reduced') == 'full'
    if wants_full_actions:
        actions = nx.actions.COMPLETE_ACTION_SET
    else:
        actions = nx.actions.DEFAULT_ACTION_SET
    obs_kind = config.get('observation_space', 'symbolic')

    navix_env = nx.make(
        full_name,
        **config.env_params,
        action_set=actions,
        observation_fn=observation_registry[obs_kind],
    )

    # Swap in a normalised observation function for RGB / symbolic variants.
    if 'rgb' in obs_kind:
        side = config.get('img_size', 64)
        navix_env = navix_env.replace(
            observation_fn=scaled_image(observation_registry[obs_kind], side),
            observation_space=nx.spaces.Continuous.create((side, side, 3), 0, 1),
        )
    elif 'symbolic' in obs_kind:
        navix_env = navix_env.replace(
            observation_fn=scaled_symbols(observation_registry[obs_kind])
        )

    graph_env = NavixToGymnaxGraph(navix_env, autoreset=config.get('autoreset', True))
    params = graph_env.default_params
    logged_env = LogWrapper(graph_env)

    sample_action = logged_env.action_space(params).sample(rng)
    # The template reset/step deliberately use a fixed key, not the caller's rng.
    probe_key = jax.random.key(10)
    obs, env_state = logged_env.reset(probe_key)
    obs, env_state, reward, done, _ = logged_env.step(
        probe_key, env_state, sample_action, params
    )

    return EnvConfig(
        obs=obs,
        action=sample_action,
        done=done,
        reward=reward,
        truncated=done,
        jittable=True,
        n_actions=graph_env.action_space(params).n,
        env_reset=partial(logged_env.reset, params=params),
        env_step=partial(logged_env.step, params=params),
        n_envs=config.n_envs,
        base_env=graph_env,
        env_params=params,
        state_names=graph_env.get_state_names(env_state.env_state.timestep.state),
    )


def build_taxi_gymnax(envname, config, rng=jax.random.PRNGKey(0)):
    """Build an :class:`EnvConfig` for the TaxiGymnax environment.

    Args:
        envname: Name suffix from ``config.env_name``; not used by this
            builder. TaxiGymnax is configured entirely from ``config``.
        config: Experiment config. Reads ``observation_space`` (``'rgb'``,
            the default, selects ``render_mode=1``, anything else 0),
            ``grid_size``, ``n_passengers``, ``allow_dropoff_anywhere``,
            ``exploring_starts``, ``img_size``, ``max_steps_in_episode`` and
            ``n_envs``.
        rng: PRNG key split into the keys for the template reset, action
            sample and step. The default reproduces ``PRNGKey(0)``.

    Returns:
        EnvConfig whose ``obs``/``action``/``done``/``reward`` come from one
        reset + step of the wrapped environment.
    """
    render_mode = 1 if config.get('observation_space', 'rgb') == 'rgb' else 0
    basic_env = TaxiGymnax(
        size=config.get('grid_size', 5),
        n_passengers=config.get('n_passengers', 1),
        render_mode=render_mode,
        allow_dropoff_anywhere=config.get('allow_dropoff_anywhere', True),
        exploring_starts=config.get('exploring_starts', False),
        img_size=config.get('img_size', 64),
        max_steps_in_episode=config.get('max_steps_in_episode', 200),
    )
    env_params = basic_env.default_params
    env = LogWrapper(basic_env)

    # Derive the template keys from the caller's rng so its seed is honoured.
    key_action, key_step, key_reset = jax.random.split(rng, 3)
    obs, env_state = env.reset(key_reset)
    action = env.action_space.sample(key_action)
    obs, env_state, reward, done, _ = env.step(key_step, env_state, action)

    return EnvConfig(
        obs=obs,
        action=action,
        done=done,
        reward=reward,
        truncated=done,
        jittable=True,
        n_actions=basic_env.num_actions,
        env_reset=partial(env.reset, params=env_params),
        env_step=partial(env.step, params=env_params),
        n_envs=config.n_envs,
        base_env=basic_env,
        env_params=env_params,
        state_names=basic_env.state_names,
    )

class EnvironmentBuilder:
    """Dispatch ``config.env_name`` to the matching environment builder.

    ``config.env_name`` has the form ``"<type>_<name>"``. ``<type>`` selects
    an entry of :attr:`ENV_BUILDERS`. ``<name>`` is everything after the
    first underscore and is passed to the builder.
    """

    # Maps the ``<type>`` prefix of ``config.env_name`` to its builder function.
    ENV_BUILDERS = {
        'gymnax': build_gymnax,
        'navix': build_navix,
        'taxi': build_taxi_gymnax,
    }

    @staticmethod
    def build(config, rng):
        """Build the environment described by ``config.env_name``.

        Raises:
            ValueError: if ``config.env_name`` contains no underscore.
            KeyError: if the ``<type>`` prefix has no registered builder.
        """
        env_type, sep, env_name = config.env_name.partition('_')
        if not sep:
            raise ValueError(
                f"env_name must look like '<type>_<name>', got {config.env_name!r}"
            )
        try:
            builder = EnvironmentBuilder.ENV_BUILDERS[env_type]
        except KeyError:
            raise KeyError(
                f"Unknown environment type {env_type!r}; "
                f"expected one of {sorted(EnvironmentBuilder.ENV_BUILDERS)}"
            ) from None
        return builder(env_name, config, rng=rng)
