import gymnasium as gym
import gin
import numpy as np


@gin.configurable
class FetchEnvWrapper(gym.Wrapper):
  """Wrapper around a gym-robotics Fetch task.

  The wrapper replaces the environment reward with a sparse success signal
  (see `step`) and can optionally override the object/goal placement after a
  reset with fixed offsets selected by `mode`:

    * `mode == -1`: the wrapped environment's own reset result is returned.
    * `mode == 0`: `_easy_mode_reset` is used.
    * `mode == 1`: `_medium_mode_reset` is used.
    * any other value: `_hard_mode_reset` is used.

  Subclasses provide the three `_*_mode_reset` hooks.

  NOTE(review): in modes other than -1, `reset` returns a flattened array
  while `step` returns the wrapped environment's observation unchanged, and
  `observation_space` is not overridden. Confirm this asymmetry is intended.
  """

  def __init__(self, env, mode=-1, distance_threshold=0.1, dense_reward=False):
    """Wraps `env`.

    Args:
      env: the environment to wrap.
      mode: reset preset selector, see the class docstring.
      distance_threshold: accepted for configuration compatibility; not used
        by this class.
      dense_reward: accepted for configuration compatibility; not used by
        this class (`step` always emits the sparse success reward).
    """
    super().__init__(env)
    self.mode = mode
    self.game_over = False

  def _flatten_obs(self, obs):
    """Concatenates 'observation', 'achieved_goal' and 'desired_goal' of `obs`."""
    return np.concatenate(
        [obs['observation'], obs['achieved_goal'], obs['desired_goal']])

  def _get_flat_observation_space(self):
    """Returns an unbounded float32 Box sized like the output of `_flatten_obs`."""
    spaces = self.observation_space
    # All three components are included so the length agrees with
    # `_flatten_obs`, which concatenates exactly these keys.
    flat_length = sum(
        spaces[key].shape[0]
        for key in ('observation', 'achieved_goal', 'desired_goal'))
    return gym.spaces.Box(
        low=-np.inf, high=np.inf, shape=(flat_length,), dtype=np.float32)

  def reset(self, seed=None, options=None):
    """Resets the wrapped environment and applies the preset chosen by `mode`.

    Args:
      seed: seed forwarded to the wrapped environment; when None, one is drawn
        from NumPy's global random state.
      options: forwarded unchanged to the wrapped environment.

    Returns:
      For `mode == -1`, whatever the wrapped environment's reset returns.
      Otherwise a `(flattened_observation, {})` tuple.
    """
    seed = np.random.randint(0, 1000000) if seed is None else seed
    start_state = super().reset(seed=seed, options=options)
    self.game_over = False
    if self.mode == -1:
      return start_state
    if self.mode == 0:
      return self._easy_mode_reset(), {}
    if self.mode == 1:
      return self._medium_mode_reset(), {}
    return self._hard_mode_reset(), {}

  def _wrapper_set_obj_pos(self, offset):
    """Moves joint "object0:joint" to the initial gripper xy plus `offset`.

    Args:
      offset: 2-tuple added to the first two components of the base
        environment's `initial_gripper_xpos`.
    """
    assert isinstance(offset, tuple) and len(offset) == 2
    # Environment attributes are read from the base environment explicitly
    # rather than through implicit wrapper attribute forwarding.
    base_env = self.unwrapped
    object_xpos = base_env.initial_gripper_xpos[:2] + np.array(offset)
    object_qpos = base_env._utils.get_joint_qpos(
        base_env.model, base_env.data, "object0:joint")
    assert object_qpos.shape == (7,)
    object_qpos[:2] = object_xpos
    base_env._utils.set_joint_qpos(
        base_env.model, base_env.data, "object0:joint", object_qpos)
    base_env._mujoco.mj_forward(base_env.model, base_env.data)

  def _wrapper_set_goal_pos(self, offset):
    """Sets the base environment's goal to the initial gripper position plus `offset`.

    When the base environment reports `has_object`, the goal's third
    component is first replaced by `height_offset` and `target_offset` is
    added, before `offset` is applied.

    Args:
      offset: 3-tuple added to the goal position.
    """
    assert isinstance(offset, tuple) and len(offset) == 3
    base_env = self.unwrapped
    goal = base_env.initial_gripper_xpos[:3].copy()
    if base_env.has_object:
      goal[2] = base_env.height_offset
      goal += base_env.target_offset
    goal += np.array(offset)
    base_env.goal = goal.copy()

  def _deterministic_reset(self, obj_offset, goal_offset):
    """Places object and goal at fixed offsets and returns the flat observation.

    Args:
      obj_offset: 2-tuple passed to `_wrapper_set_obj_pos`; only used when the
        base environment reports `has_object`.
      goal_offset: 3-tuple passed to `_wrapper_set_goal_pos`.

    Returns:
      The base environment's current observation, flattened by `_flatten_obs`.
    """
    base_env = self.unwrapped
    if base_env.has_object:
      self._wrapper_set_obj_pos(obj_offset)
    self._wrapper_set_goal_pos(goal_offset)
    return self._flatten_obs(base_env._get_obs())

  def _easy_mode_reset(self):
    """Hook for `mode == 0`; subclasses must override."""
    raise NotImplementedError('_easy_mode_reset must be implemented by a subclass')

  def _medium_mode_reset(self):
    """Hook for `mode == 1`; subclasses must override."""
    raise NotImplementedError('_medium_mode_reset must be implemented by a subclass')

  def _hard_mode_reset(self):
    """Hook for any other `mode` value; subclasses must override."""
    raise NotImplementedError('_hard_mode_reset must be implemented by a subclass')

  def step(self, action):
    """Steps the wrapped environment and substitutes a sparse success reward.

    The reward is `float(info["is_success"])`, and the episode is marked
    terminal (and `game_over` set) exactly when that reward equals 1.

    Returns:
      `(next_obs, reward, terminal, truncated, info)` where `next_obs` and
      `truncated` come from the wrapped environment unchanged.
    """
    next_obs, _, _, truncated, info = super().step(action)
    info = self.get_current_info(info)
    reward = float(info["is_success"])
    terminal = (reward == 1)
    if terminal:
      self.game_over = True

    return next_obs, reward, terminal, truncated, info

  def get_current_info(self, info=None):
    """Adds a placeholder "fetch_pos" entry to `info` (a new dict when None)."""
    if info is None:
      info = {}
    info["fetch_pos"] = (0, 0, 0)  # placeholder
    return info


class FetchReachEnvWrapper(FetchEnvWrapper):
  """Fixed goal-offset presets for a reach task; no object offset is supplied."""

  def _easy_mode_reset(self):
    """Returns the flat observation after applying the easy goal offset."""
    goal_shift = (-0.05, 0.05, 0.025)
    return self._deterministic_reset(obj_offset=None, goal_offset=goal_shift)

  def _medium_mode_reset(self):
    """Returns the flat observation after applying the medium goal offset."""
    goal_shift = (-0.1, 0.1, 0.5)
    return self._deterministic_reset(obj_offset=None, goal_offset=goal_shift)

  def _hard_mode_reset(self):
    """Returns the flat observation after applying the hard goal offset."""
    goal_shift = (-0.15, 0.15, 0.75)
    return self._deterministic_reset(obj_offset=None, goal_offset=goal_shift)


class FetchPushEnvWrapper(FetchEnvWrapper):
  """Fixed object/goal offset presets for a push task.

  All three presets use the same object offset and differ only in the goal
  offset.
  """

  def _easy_mode_reset(self):
    """Returns the flat observation after applying the easy offsets."""
    object_shift = (-0.1, 0.1)
    goal_shift = (-0.05, 0.05, 0.0)
    return self._deterministic_reset(
        obj_offset=object_shift, goal_offset=goal_shift)

  def _medium_mode_reset(self):
    """Returns the flat observation after applying the medium offsets."""
    object_shift = (-0.1, 0.1)
    goal_shift = (0.025, -0.025, 0.0)
    return self._deterministic_reset(
        obj_offset=object_shift, goal_offset=goal_shift)

  def _hard_mode_reset(self):
    """Returns the flat observation after applying the hard offsets."""
    object_shift = (-0.1, 0.1)
    goal_shift = (0.1, -0.1, 0.0)
    return self._deterministic_reset(
        obj_offset=object_shift, goal_offset=goal_shift)


class FetchSlideEnvWrapper(FetchEnvWrapper):
  """Fixed object/goal offset presets for a slide task.

  All three presets use the same object offset and differ only in the goal
  offset.
  """

  def _easy_mode_reset(self):
    """Returns the flat observation after applying the easy offsets."""
    object_shift = (-0.05, 0.05)
    goal_shift = (-0.25, -0.1, 0.0)
    return self._deterministic_reset(
        obj_offset=object_shift, goal_offset=goal_shift)

  def _medium_mode_reset(self):
    """Returns the flat observation after applying the medium offsets."""
    object_shift = (-0.05, 0.05)
    goal_shift = (-0.0, -0.1, 0.0)
    return self._deterministic_reset(
        obj_offset=object_shift, goal_offset=goal_shift)

  def _hard_mode_reset(self):
    """Returns the flat observation after applying the hard offsets."""
    object_shift = (-0.05, 0.05)
    goal_shift = (0.25, 0.1, 0.0)
    return self._deterministic_reset(
        obj_offset=object_shift, goal_offset=goal_shift)


class FetchPickAndPlaceEnvWrapper(FetchEnvWrapper):
  """Fixed object/goal offset presets for a pick-and-place task.

  The presets share the object offset and the first two goal-offset
  components; only the third goal-offset component changes.
  """

  def _easy_mode_reset(self):
    """Returns the flat observation after applying the easy offsets."""
    object_shift = (0.05, 0.05)
    goal_shift = (-0.05, -0.05, 0.1)
    return self._deterministic_reset(
        obj_offset=object_shift, goal_offset=goal_shift)

  def _medium_mode_reset(self):
    """Returns the flat observation after applying the medium offsets."""
    object_shift = (0.05, 0.05)
    goal_shift = (-0.05, -0.05, 0.2)
    return self._deterministic_reset(
        obj_offset=object_shift, goal_offset=goal_shift)

  def _hard_mode_reset(self):
    """Returns the flat observation after applying the hard offsets."""
    object_shift = (0.05, 0.05)
    goal_shift = (-0.05, -0.05, 0.4)
    return self._deterministic_reset(
        obj_offset=object_shift, goal_offset=goal_shift)