import os
import joblib
from copy import deepcopy

import gym

try:
    from nle.env import tasks
    from nle.env.base import DUNGEON_SHAPE
except ModuleNotFoundError:
    pass

from .wrappers import ObjectiveWrapper, ActionSelectWrapper, EventWrapper
from .crafter_wrappers import CrafterRenderWrapper, ImageToPyTorch, ObservationDictWrapper, CrafterMonitorWrapper
from .atari_wrappers import make_atari, wrap_deepmind, wrap_pytorch, WarpFrame


# Maps the task suffix of a "nethack-<task>" env name to its NLE task class.
# `nle` is an optional dependency: if its guarded import at the top of this
# file failed, `tasks` is undefined. The table is then left empty so that the
# rest of this module (crafter, atari, ...) can still be imported and used.
try:
    NETHACK_ENVS = dict(
        staircase=tasks.NetHackStaircase,
        score=tasks.NetHackScore,
        pet=tasks.NetHackStaircasePet,
        oracle=tasks.NetHackOracle,
        gold=tasks.NetHackGold,
        eat=tasks.NetHackEat,
        scout=tasks.NetHackScout,
        challenge=tasks.NetHackChallenge
    )
except NameError:
    NETHACK_ENVS = {}

# Default constructor kwargs for each environment family. `make_env` copies
# these tables before using them, so the module-level dicts are not mutated.

# Intended for the NetHack task constructor.
# NOTE(review): `make_env` applies a flat table like this one only when the env
# name has no "-<variant>" suffix, but nethack names always carry one (the task
# is looked up via NETHACK_ENVS), so these values currently never reach the env.
# Confirm whether that is intended.
NETHACK_KWARGS = {
    "savedir": None,
    "archivefile": None,
}

# Crafter settings used for a plain "crafter" env name.
CRAFTER_KWARGS = {
    "size": (84, 84),
    "render_centering": False,
    "health_reward_coef": 0.0,
    "immortal": True,
    "idle_death": 100,
}

# Crafter settings that replace CRAFTER_KWARGS when the `crafter_original`
# option is enabled in `make_env`.
CRAFTER_ORIGINAL_KWARGS = {
    "size": (84, 84),
    "render_centering": False,
    "vanila": True,  # (sic) passed verbatim as a constructor keyword; do not rename here alone
}

# Sokoban settings; "sobokan" is the env-name prefix used throughout this module.
SOBOKAN_KWARGS = {
    "achievement_reward": True,
}

# MiniGrid kwargs are keyed by task variant: `make_env` starts from "default"
# and overlays the entry for the requested variant (e.g. "minigrid-keycorridor").
# The variant -> class table is built inside `make_env`, because gym_minigrid is
# only imported there.
MINIGRID_KWARGS = {
    "default": {"achievement_reward": True},
    "keycorridor": {"room_size": 5, "num_rows": 3},
    # Larger alternative: {"room_size": 5, "num_rows": 9, "num_nodes": 15}
    "distractions": {"room_size": 5, "num_rows": 5, "num_nodes": 8},
}


# Per-family kwargs tables, looked up by the prefix of the env name.
ENV_KWARGS = {
    "nethack": NETHACK_KWARGS,
    "crafter": CRAFTER_KWARGS,
    "atari": {},
    "sobokan": SOBOKAN_KWARGS,
    "minigrid": MINIGRID_KWARGS,
    "orbit": {},
}


def get_env(flags):
    """Return the environment family named by ``flags.env``.

    The family is everything before the first ``'-'`` (the whole string when
    there is no ``'-'``), e.g. ``"nethack-score"`` -> ``"nethack"``.
    """
    family, _separator, _variant = flags.env.partition('-')
    return family


def make_env(env_name, kwargs=None, flags=None):
    """Build an environment from a ``<family>[-<variant>]`` name and wrap it.

    ``env_name`` is split on ``'-'``: the first part selects the family
    (``nethack``, ``crafter``, ``atari``, ``sobokan``, ``minigrid``, ``orbit``)
    and, for families that need it, the second part selects the concrete task.

    Args:
        env_name: e.g. ``"nethack-score"``, ``"atari-Pong"``,
            ``"minigrid-keycorridor"`` or ``"crafter"``.
        kwargs: Optional overrides. Wrapper/config options (``env_id``,
            ``use_crafter_monitor``, ``num_objectives``, ``goal_generation``,
            ...) are popped out and consumed here. Everything left over is
            forwarded to the environment constructor. The caller's dict is
            never mutated.
        flags: Optional config object with a ``.get(key, default)`` method,
            consulted for any option that is not present in ``kwargs``.

    Returns:
        The constructed environment with the family-specific wrappers applied.

    Raises:
        NotImplementedError: if the family is not recognised.
        ImportError: if a nethack env is requested but ``nle`` is unavailable.
        ValueError: if the Crafter monitor is enabled without a ``savedir``.
    """
    # Private copy: consumed options are popped from it below.
    kwargs = {} if kwargs is None else deepcopy(kwargs)
    env_names = env_name.split('-')

    base_env = env_names[0]
    is_nethack = base_env == 'nethack'
    is_crafter = base_env == 'crafter'
    is_atari = base_env == 'atari'
    is_sobokan = base_env == 'sobokan'
    is_minigrid = base_env == 'minigrid'
    is_orbit = base_env == 'orbit'

    # Set directly only by families that build the env themselves (orbit);
    # every other family provides `env_cls`, which is instantiated further down.
    env = None

    if is_nethack:
        if not NETHACK_ENVS:
            raise ImportError(
                "NetHack environments require the `nle` package, "
                "which could not be imported"
            )
        env_cls = NETHACK_ENVS[env_names[1]]
    elif is_crafter:
        from crafter.env import Env as CrafterEnv
        env_cls = CrafterEnv
    elif is_atari:
        # Dedicated name on purpose: the lambda reads this variable when it is
        # *called* (late binding). Reusing `env_id` would let the later
        # `env_id = get_key("env_id")` replace the game id before the call.
        atari_game_id = env_names[1] + "NoFrameskip-v4"
        env_cls = lambda: wrap_deepmind(
            make_atari(atari_game_id),
            clip_rewards=False,
            frame_stack=True,
            scale=False,
        )
    elif is_sobokan:
        from gym_sokoban.envs.sokoban_env import SokobanEnv
        env_cls = SokobanEnv
    elif is_minigrid:
        import gym_minigrid.envs
        from gym_minigrid.wrappers import RGBImgPartialObsWrapper, ImgObsWrapper
        MINIGRID_ENVS = dict(
            keycorridor=gym_minigrid.envs.KeyCorridor,
            blockedunlockpickup=gym_minigrid.envs.BlockedUnlockPickup,
            distractions=gym_minigrid.envs.Distractions
        )
        env_cls = MINIGRID_ENVS[env_names[1]]
    elif is_orbit:
        from omni.isaac.kit import SimulationApp
        config = {"headless": True}
        # NOTE(review): `simulation_app` is not referenced again. It is created
        # before the orbit imports below, presumably because they need a running
        # app; confirm whether it must also be kept alive beyond this call.
        simulation_app = SimulationApp(config)  # noqa: F841
        import omni.isaac.orbit.environments  # noqa: F401
        from omni.isaac.orbit.utils.parse_cfg import parse_env_cfg

        task = {
            "cartpole": "Isaac-Cartpole-v0"
        }[env_names[1]]

        env_cfg = parse_env_cfg(task, use_gpu=False, num_envs=1)
        # Built here directly, so the `env_kwargs` resolved below are not used for orbit.
        env = gym.make(task, cfg=env_cfg, headless=True)

        print(env.action_space)
        print(env.observation_space)
    else:
        raise NotImplementedError(f'Unrecognized env: {base_env}')

    def get_key(key, default=None):
        """Pop ``key`` from kwargs if present, else read it from flags, else ``default``."""
        if key in kwargs:
            return kwargs.pop(key)
        if flags is not None:
            return flags.get(key, default)
        return default

    env_id = get_key("env_id")
    crafter_monitor = get_key("use_crafter_monitor", False)
    pred_supervised = get_key("pred_supervised", False)

    # NOTE(review): a flat table (e.g. NETHACK_KWARGS) is used only when the name
    # has no "-<variant>" part. With a variant, only per-variant tables (like
    # MINIGRID_KWARGS) contribute, so flat tables for families that always carry
    # a variant (nethack, atari) are silently ignored. Confirm this is intended.
    if len(env_names) == 1:
        env_kwargs = deepcopy(ENV_KWARGS[base_env])
    else:
        env_kwargs = deepcopy(ENV_KWARGS[base_env].get('default', {}))
        env_kwargs.update(deepcopy(ENV_KWARGS[base_env].get(env_names[1], {})))

    if is_crafter:
        if get_key("crafter_original", False):
            env_kwargs = deepcopy(CRAFTER_ORIGINAL_KWARGS)
        else:
            env_kwargs["static_environment"] = get_key("crafter_static", False)
            env_kwargs["repeat_deduction"] = get_key("crafter_repeat_deduction", 0.0)

    num_objectives = get_key("num_objectives")
    objective_selection_algo = get_key("objective_selection_algo")
    causal_graph_load_path = get_key("causal_graph_load_path")
    include_new_tasks = get_key("include_new_tasks", True)
    done_if_reward = get_key("done_if_reward", False)
    goal_generation = get_key("goal_generation")

    # Resolve the monitor directory *before* forwarding kwargs. Otherwise a
    # `savedir` passed in kwargs would reach the env constructor instead of
    # being used here.
    monitor_dir = None
    if is_crafter and crafter_monitor and env_id is not None:
        savedir = get_key("savedir")
        if savedir is None:
            raise ValueError(
                "use_crafter_monitor requires `savedir` to be set (in kwargs or flags)"
            )
        monitor_dir = os.path.join(savedir, "crafter_monitor")

    # Whatever was not consumed above goes to the environment constructor.
    env_kwargs.update(kwargs)

    if env is None:
        env = env_cls(**env_kwargs)

    if is_crafter:
        env = CrafterRenderWrapper(env)
        if monitor_dir is not None:
            os.makedirs(monitor_dir, exist_ok=True)
            env = CrafterMonitorWrapper(env, env_id, monitor_dir, save_freq=1, batch_size=100)
    if is_sobokan:
        env = WarpFrame(env, grayscale=False)
        env = ActionSelectWrapper(env)
    if is_minigrid:
        env = RGBImgPartialObsWrapper(env)
        env = ImgObsWrapper(env)
        env = WarpFrame(env, grayscale=False)
    if is_crafter or is_atari or is_sobokan or is_minigrid:
        env = ImageToPyTorch(env)
        env = ObservationDictWrapper(env)
    if pred_supervised:
        env = EventWrapper(env, env_name)
    if num_objectives is not None:
        if objective_selection_algo is not None:
            selection = objective_selection_algo
        elif causal_graph_load_path is not None:
            # SECURITY: joblib.load unpickles the file and can execute
            # arbitrary code; only point this option at trusted files.
            graph = joblib.load(causal_graph_load_path)
            selection = ('graph', {'graph': graph})
        else:
            selection = ('random', {})
        env = ObjectiveWrapper(env, num_objectives, include_new_tasks=include_new_tasks, objective_selection=selection, done_if_reward=done_if_reward)
    elif goal_generation:
        # NOTE(review): this passes the plain string "random", while the branch
        # above passes ('random', {}); confirm ObjectiveWrapper accepts both.
        env = ObjectiveWrapper(env, 2, include_new_tasks=False, objective_selection="random")

    return env
