import os
from typing import Optional, Union
import numpy as np

import gymnasium as gym
import manip_envs
from gymnasium.wrappers import RescaleAction

from jax_rl import wrappers


def make_env(env_name: str,
             seed: int,
             save_folder: Optional[str] = None,
             action_repeat: int = 1,
             frame_stack: int = 1,
             from_pixels: bool = False,
             image_size: int = 84) -> gym.Env:
    """Build a Gymnasium-registered environment wrapped for training.

    Wrapper stack, innermost first:
      ManipEpisodeMonitor
      -> ManipRepeatAction          (only if ``action_repeat > 1``)
      -> RescaleAction to [-1, 1]   (only for ``Box`` action spaces)
      -> ManipVideoRecorder         (only if ``save_folder`` is not None)
      -> ManipObservation
      -> ManipFrameStack            (only if ``frame_stack > 1``)

    Args:
        env_name: Id of an environment registered in ``gym.envs.registry``.
        seed: Seed applied to the action and observation spaces.
        save_folder: If not None, passed to ``ManipVideoRecorder`` as its
            ``save_folder``.
        action_repeat: Passed to ``ManipRepeatAction`` when greater than 1.
        frame_stack: Passed to ``ManipFrameStack`` as ``num_stack`` when
            greater than 1.
        from_pixels: Pixel observations; not supported yet (must be False).
        image_size: Intended pixel observation size; currently unused.

    Returns:
        The wrapped environment, created with ``render_mode='rgb_array'``.

    Raises:
        NotImplementedError: If ``from_pixels`` is True, or if ``env_name``
            is not registered with Gymnasium.
    """
    # Reject unsupported configurations before gym.make, so no environment
    # is constructed (and left unclosed) only to be thrown away.
    if from_pixels:
        # TODO: pixel observations (e.g. rendering image_size x image_size
        # frames through a pixel-observation wrapper) are not implemented.
        raise NotImplementedError(
            'from_pixels=True is not supported yet; only state observations '
            'are available.')

    # The registry is a dict keyed by env id, so a direct membership test
    # is equivalent to scanning every spec's id.
    if env_name not in gym.envs.registry:
        # TODO: non-Gymnasium backends (e.g. DeepMind Control) are not
        # supported yet.
        raise NotImplementedError(
            f'Environment {env_name!r} is not registered with Gymnasium.')

    env = gym.make(env_name, render_mode='rgb_array')
    env = wrappers.ManipEpisodeMonitor(env)

    if action_repeat > 1:
        env = wrappers.ManipRepeatAction(env, action_repeat)

    # Normalise continuous actions so the agent always acts in [-1, 1].
    if isinstance(env.action_space, gym.spaces.Box):
        env = RescaleAction(env, -1.0, 1.0)

    if save_folder is not None:
        env = wrappers.ManipVideoRecorder(env, save_folder=save_folder)

    env = wrappers.ManipObservation(env)

    if frame_stack > 1:
        env = wrappers.ManipFrameStack(env, num_stack=frame_stack)

    # Gymnasium has no Env.seed(); only the spaces' samplers are seeded here.
    # NOTE(review): the environment's own RNG is seeded through
    # env.reset(seed=...), presumably by the caller; confirm.
    env.action_space.seed(seed)
    env.observation_space.seed(seed)

    return env
