import os
from typing import Optional, Union
import numpy as np

import gymnasium as gym
import gymnasium_robotics
from gymnasium.wrappers import RescaleAction
from gymnasium.wrappers.pixel_observation import PixelObservationWrapper
from gymnasium.utils.ezpickle import EzPickle
from gymnasium.envs.registration import register
from gymnasium_robotics.envs.fetch import MujocoFetchEnv
from gymnasium_robotics.envs.fetch.reach import MujocoFetchReachEnv, MODEL_XML_PATH

from jax_rl import wrappers

class MujocoFetchReachCustomEnv(MujocoFetchReachEnv):
    """Fetch reach task with a configurable gripper height and goal offset.

    The task is built with no object, a blocked gripper and goals sampled "in
    the air". Unlike a fixed configuration, the caller chooses
    ``gripper_extra_height`` and ``target_offset``. ``target_offset`` is added
    to every sampled goal, including the no-object case.
    """

    def __init__(self, reward_type: str = "sparse", gripper_extra_height: float = 0.0,
                 target_offset: Union[float, np.ndarray] = 0.0, **kwargs):
        """Create the environment.

        Args:
            reward_type: Reward type forwarded to ``MujocoFetchEnv`` (this
                module registers it with "dense"; the default is "sparse").
            gripper_extra_height: Extra gripper height forwarded to
                ``MujocoFetchEnv``.
            target_offset: Offset added to every sampled goal; a scalar or a
                per-axis array.
            **kwargs: Additional keyword arguments forwarded to
                ``MujocoFetchEnv.__init__``.
        """
        initial_qpos = {
            "robot0:slide0": 0.4049,
            "robot0:slide1": 0.48,
            "robot0:slide2": 0.0,
        }
        # Call MujocoFetchEnv.__init__ directly (not super().__init__()) so the
        # caller-supplied gripper_extra_height / target_offset are forwarded as-is.
        MujocoFetchEnv.__init__(
            self,
            model_path=MODEL_XML_PATH,
            has_object=False,
            block_gripper=True,
            n_substeps=20,
            gripper_extra_height=gripper_extra_height,
            target_in_the_air=True,
            target_offset=target_offset,
            obj_range=0.15,
            target_range=0.15,
            distance_threshold=0.05,
            initial_qpos=initial_qpos,
            reward_type=reward_type,
            **kwargs,
        )
        # EzPickle re-invokes __init__ with exactly these arguments when the env
        # is pickled/copied, so every constructor argument must be recorded;
        # omitting any of them would rebuild the env with its default value.
        EzPickle.__init__(
            self,
            reward_type=reward_type,
            gripper_extra_height=gripper_extra_height,
            target_offset=target_offset,
            **kwargs,
        )

    def _sample_goal(self) -> np.ndarray:
        """Sample a new goal around the gripper's initial position.

        A per-axis uniform perturbation in ``[-target_range, target_range]`` is
        added to the initial gripper xyz position, followed by
        ``target_offset``. When an object is present, the goal height is reset
        to ``height_offset`` and, if ``target_in_the_air``, raised by up to
        0.45 with probability 0.5.

        Returns:
            A fresh copy of the sampled 3-vector goal.
        """
        goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(
            -self.target_range, self.target_range, size=3
        )
        goal += self.target_offset
        # has_object is passed as False in __init__, so this branch is inert
        # for this class; it is kept to mirror the object-task goal logic.
        if self.has_object:
            goal[2] = self.height_offset
            if self.target_in_the_air and self.np_random.uniform() < 0.5:
                goal[2] += self.np_random.uniform(0, 0.45)
        return goal.copy()

# Make the gymnasium_robotics dependency explicit so its environments are in
# the registry before the id below is overridden.
gym.register_envs(gymnasium_robotics)

# Constructor arguments for the overriding FetchReachDense-v2 spec.
# Per the original notes: gripper_extra_height 0.2 -> 0.1 and the goal
# z-offset 0.0 -> 0.1.
_FETCH_REACH_DENSE_KWARGS = {
    "reward_type": "dense",
    "gripper_extra_height": 0.1,
    "target_offset": np.array([0.0, 0.0, 0.1]),
}

# Override the existing FetchReachDense-v2 so gym.make() builds
# MujocoFetchReachCustomEnv. The entry point is a "module:attribute" string,
# so it must match this module's import path.
register(
    id="FetchReachDense-v2",
    entry_point="jax_rl.utils_fetch:MujocoFetchReachCustomEnv",
    max_episode_steps=50,
    kwargs=_FETCH_REACH_DENSE_KWARGS,
)

def make_env(env_name: str,
             seed: int,
             save_folder: Optional[str] = None,
             action_repeat: int = 1,
             frame_stack: int = 1,
             from_pixels: bool = False,
             image_size: int = 84) -> gym.Env:
    """Build a wrapped, seeded Gymnasium environment for Fetch tasks.

    Args:
        env_name: Gymnasium environment id; must already be registered.
        seed: Seed for the environment's internal RNG and for its action and
            observation spaces.
        save_folder: If given, wrap with ``wrappers.FetchVideoRecorder``
            saving into this folder.
        action_repeat: When > 1, repeat each action this many times.
        frame_stack: When > 1, stack this many consecutive observations.
        from_pixels: Request pixel observations (not supported yet).
        image_size: Pixel observation size; unused until pixel observations
            are supported.

    Returns:
        The wrapped environment.

    Raises:
        NotImplementedError: If ``env_name`` is not registered with Gymnasium
            or ``from_pixels`` is true.
    """
    # Local import: only needed here to build a Gymnasium-style seeded RNG.
    from gymnasium.utils import seeding

    # Reject unsupported configurations before creating the env, so a MuJoCo
    # env and its rgb_array renderer are not built and then leaked by the raise.
    # The registry is a dict keyed by env id, so this is a direct lookup.
    if env_name not in gym.envs.registry:
        # TODO: non-Gymnasium (e.g. DeepMind Control "domain-task") envs.
        raise NotImplementedError(
            f"Environment {env_name!r} is not registered with Gymnasium.")
    if from_pixels:
        # TODO: pixel observations (PixelObservationWrapper at image_size).
        raise NotImplementedError("Pixel observations are not supported yet.")

    env = gym.make(env_name, render_mode='rgb_array')

    # Gymnasium removed Env.seed(), so `seed` previously never reached the env
    # and its internal randomness was unseeded. Seed the base env's RNG directly
    # rather than calling reset(seed=...), which would pass an extra reset
    # through the wrappers added below.
    env.unwrapped.np_random, _ = seeding.np_random(seed)

    env = wrappers.FetchEpisodeMonitor(env)

    if action_repeat > 1:
        env = wrappers.FetchRepeatAction(env, action_repeat)

    # Normalize continuous (Box) actions to [-1, 1].
    if isinstance(env.action_space, gym.spaces.Box):
        env = RescaleAction(env, -1.0, 1.0)

    if save_folder is not None:
        env = wrappers.FetchVideoRecorder(env, save_folder=save_folder)

    env = wrappers.FetchObservation(env)

    if frame_stack > 1:
        env = wrappers.FetchFrameStack(env, num_stack=frame_stack)

    env.action_space.seed(seed)
    env.observation_space.seed(seed)

    return env
