import gym
import numpy as np
import copy


# Index layout of the earl_kitchen state vector: each kitchen component maps to
# the positions of its entries in the state. Insertion order matters, because
# the kitchen reward walks these components in order and penalises reaching
# towards the first unsolved one.
component_to_state_idx = {
    'arm': list(range(0, 9)),
    'burner0': list(range(9, 11)),
    'burner1': list(range(11, 13)),
    'burner2': list(range(13, 15)),
    'burner3': list(range(15, 17)),
    'light_switch': list(range(17, 19)),
    'slide_cabinet': [19],
    'hinge_cabinet': list(range(20, 22)),
    'microwave': [22],
}

class ConvertResetFreeEnvWrapper:
  """Expose a flat ``[observation, goal]`` array as a dict observation.

  The wrapped reset-free env returns one array whose first ``obs_dim``
  entries are the observation and whose remaining entries are the goal; this
  wrapper returns ``{obs_key: observation, goal_key: goal}`` instead.

  Args:
    env: wrapped env with ``step``/``reset`` and a Box ``observation_space``
      describing the concatenated array.
    obs_key: dict key for the observation part.
    goal_key: dict key for the goal part.
    obs_dim: length of the observation part. When ``None`` the flat space is
      split in half, i.e. observation and goal are assumed to have the same
      length.
  """
  def __init__(self, env, obs_key='observation', goal_key='goal', obs_dim=None):
    self._env = env
    self.obs_key = obs_key
    self.goal_key = goal_key
    if obs_dim is None:
      self.obs_dim = self._env.observation_space.shape[0] // 2
    else:
      self.obs_dim = obs_dim

  def __getattr__(self, name):
    # Only reached for attributes not found on the wrapper. Dunder lookups
    # fail normally; everything else is forwarded to the wrapped env, and a
    # missing attribute is reported as ValueError.
    if name.startswith('__'):
      raise AttributeError(name)
    try:
      return getattr(self._env, name)
    except AttributeError:
      raise ValueError(name)

  def _split(self, flat_obs):
    """Split a flat ``[observation, goal]`` sequence at ``obs_dim``."""
    return {self.obs_key: flat_obs[:self.obs_dim],
            self.goal_key: flat_obs[self.obs_dim:]}

  def step(self, action):
    obs, reward, done, info = self._env.step(action)
    return self._split(obs), reward, done, info

  def reset(self):
    return self._split(self._env.reset())

  def get_obs(self):
    """Return the inner env's current observation as a dict of numpy arrays.

    Uses ``self._env.env.get_obs()`` when that method exists and falls back
    to ``self._env.env._get_obs()`` otherwise. Errors raised *inside* the
    getter propagate instead of being silently swallowed.
    """
    inner = self._env.env
    try:
      getter = inner.get_obs
    except (AttributeError, ValueError):
      # ValueError: the wrappers in this file report missing attributes that way.
      getter = inner._get_obs
    # np.array copies each part, so the result does not alias the env's buffer.
    return {key: np.array(part) for key, part in self._split(getter()).items()}

  @property
  def observation_space(self):
    """Dict space whose Boxes are the flat space split at ``obs_dim``."""
    flat = self._env.observation_space
    return gym.spaces.Dict({
        self.obs_key: gym.spaces.Box(low=flat.low[:self.obs_dim], high=flat.high[:self.obs_dim]),
        self.goal_key: gym.spaces.Box(low=flat.low[self.obs_dim:], high=flat.high[self.obs_dim:]),
    })


class ConvertLEXAEnvWrapper:
  """Run a LEXA-style env as one long, reset-free episode.

  ``step`` raises ``done`` every ``reset_free_max_steps`` steps, and on every
  step once ``episode_horizon`` steps have elapsed; the wrapped env's own
  ``done`` is ignored. ``reset`` only really resets the wrapped env after the
  horizon has been reached; otherwise it returns the current observation.

  Args:
    env: wrapped env. ``reset`` reads ``env._env._get_obs()`` and converts it
      with ``env._get_obs(state)``, renaming the ``'state'`` key to ``'qpos'``.
    episode_horizon: number of steps after which ``reset`` really resets.
    reset_free_max_steps: period, in steps, of the intermediate ``done`` flag.
    goal_dim: number of leading entries of ``compute_reward``'s input that
      form the achieved goal (also the goal size in ``observation_space``).
    obs_dim: index at which the desired goal starts in ``compute_reward``'s
      input.
  """
  def __init__(self, env, episode_horizon, reset_free_max_steps=150, goal_dim=9, obs_dim=11):
    self._env = env
    self._episode_horizon = episode_horizon
    # Starts at 1 so the modulo check in step() fires after exactly
    # reset_free_max_steps steps.
    self._steps_since_reset = 1
    self.reset_free_max_steps = reset_free_max_steps
    self.goal_dim = goal_dim
    self.obs_dim = obs_dim

  @property
  def observation_space(self):
    """Dict space: the full ``'qpos'`` Box plus a ``'goal'`` Box of its first ``goal_dim`` dims."""
    qpos = self._env.observation_space['qpos']
    return gym.spaces.Dict({
        'qpos': gym.spaces.Box(low=qpos.low, high=qpos.high),
        'goal': gym.spaces.Box(low=qpos.low[:self.goal_dim], high=qpos.high[:self.goal_dim]),
    })

  @property
  def action_space(self):
    # Forward the wrapped env's action space unchanged.
    return self._env.action_space

  def reset(self):
    if self._steps_since_reset >= self._episode_horizon:
      self._steps_since_reset = 1
      return self._env.reset()
    # Horizon not reached yet: continue from the current state.
    state = self._env._env._get_obs()
    obs = self._env._get_obs(state)
    obs['qpos'] = obs.pop('state')
    return obs

  def step(self, action):
    obs, reward, _, info = self._env.step(action)
    # Short-circuits, so the modulo is only evaluated before the horizon.
    done = (self._steps_since_reset >= self._episode_horizon
            or self._steps_since_reset % self.reset_free_max_steps == 0)
    self._steps_since_reset += 1
    return obs, reward, done, info

  def __getattr__(self, name):
    # Only reached for attributes not found on the wrapper. Dunder lookups
    # fail normally; everything else is forwarded to the wrapped env, and a
    # missing attribute is reported as ValueError.
    if name.startswith('__'):
      raise AttributeError(name)
    try:
      return getattr(self._env, name)
    except AttributeError:
      raise ValueError(name)

  def compute_reward(self, obs_goal):
    """Return the robobin reward for a concatenated ``obs_goal`` vector.

    The achieved goal is ``obs_goal[:goal_dim]`` and the desired goal is
    ``obs_goal[obs_dim:]``. The reward is the negative sum of the distances
    over slices 3:6 and 6:9 (presumably the two object positions — confirm);
    slice 0:3 is not rewarded.
    """
    achieved_goal = obs_goal[:self.goal_dim]
    desired_goal = obs_goal[self.obs_dim:]
    obj1_distance = np.linalg.norm(achieved_goal[3:6] - desired_goal[3:6])
    obj2_distance = np.linalg.norm(achieved_goal[6:9] - desired_goal[6:9])
    return -obj1_distance - obj2_distance


class ConvertIBCEnvWrapper:
  """Re-key a goal-conditioned dict env to ``{obs_key, goal_key}`` observations.

  The wrapped env returns dicts with ``'observation'`` and ``'desired_goal'``
  entries; they are exposed under ``obs_key``/``goal_key``. ``step`` replaces
  the env reward with :meth:`compute_reward`, i.e. the env's success flag.

  Args:
    env: wrapped goal-conditioned env providing ``_is_success``.
    obs_key, goal_key: keys of the returned observation dict.
    obs_dim: length of the observation part of the concatenated
      ``[observation, goal]`` vector passed to ``compute_reward``.
    goal_dim: number of leading observation entries used as the achieved
      goal when ``if_object`` is False.
    if_object: when True, entries 3:6 of the observation are used as the
      achieved goal instead (presumably an object position — confirm).
  """
  def __init__(self, env, obs_key='observation', goal_key='goal', obs_dim=10, goal_dim=3, if_object=False):
    self._env = env
    self.obs_key = obs_key
    self.goal_key = goal_key
    self.obs_dim = obs_dim
    self.goal_dim = goal_dim
    self.if_object = if_object

  def __getattr__(self, name):
    # Only reached for attributes not found on the wrapper. Dunder lookups
    # fail normally; everything else is forwarded to the wrapped env, and a
    # missing attribute is reported as ValueError.
    if name.startswith('__'):
      raise AttributeError(name)
    try:
      return getattr(self._env, name)
    except AttributeError:
      raise ValueError(name)

  def step(self, action):
    obs, reward, done, info = self._env.step(action)
    obs = {self.obs_key: obs['observation'], self.goal_key: obs['desired_goal']}
    reward = self.compute_reward(np.concatenate((obs[self.obs_key], obs[self.goal_key])))
    return obs, reward, done, info

  def reset(self):
    obs = self._env.reset()
    return {self.obs_key: obs['observation'], self.goal_key: obs['desired_goal']}

  def get_obs(self):
    """Return the inner env's current observation, re-keyed, as numpy arrays.

    Uses ``self._env.env.get_obs()`` when that method exists and falls back
    to ``self._env.env._get_obs()`` otherwise. Errors raised *inside* the
    getter propagate instead of being silently swallowed.
    """
    inner = self._env.env
    try:
      getter = inner.get_obs
    except (AttributeError, ValueError):
      # ValueError: the wrappers in this file report missing attributes that way.
      getter = inner._get_obs
    obs = getter()
    return {self.obs_key: np.array(obs['observation']), self.goal_key: np.array(obs['desired_goal'])}

  def compute_reward(self, obs_goal):
    """Return ``env._is_success(achieved, desired)`` for a concatenated vector.

    The achieved goal is ``obs_goal[3:6]`` when ``if_object`` is set and
    ``obs_goal[:goal_dim]`` otherwise; the desired goal is
    ``obs_goal[obs_dim:]``. The env's scalar reward is not used.
    """
    if self.if_object:
      achieved_goal = obs_goal[3:6]
    else:
      achieved_goal = obs_goal[:self.goal_dim]
    desired_goal = obs_goal[self.obs_dim:]
    return self._env._is_success(achieved_goal, desired_goal)

  @property
  def observation_space(self):
    """Dict space built from the env's ``'observation'`` and ``'desired_goal'`` Boxes."""
    spaces = self._env.observation_space
    return gym.spaces.Dict({
        self.obs_key: gym.spaces.Box(low=spaces['observation'].low, high=spaces['observation'].high),
        self.goal_key: gym.spaces.Box(low=spaces['desired_goal'].low, high=spaces['desired_goal'].high),
    })

class ResetFreeDoneFlagEnvWrapper:
  """Emit ``done=True`` every ``n`` steps without actually resetting the env.

  ``reset()`` normally returns the wrapped env's current observation via
  ``get_obs()``. It only calls the real ``reset()`` once the wrapped env's
  ``_steps_since_reset`` has reached its ``_episode_horizon``; both
  attributes are read from the wrapped env through attribute forwarding.
  Step rewards are replaced by 0.
  """
  def __init__(self, env, n=200):
    self._env = env
    self.n = n
    self.n_cntr = 0

  def __getattr__(self, name):
    # Dunder lookups fail the normal way; other names are forwarded to the
    # wrapped env, and a missing one surfaces as ValueError.
    if name[:2] == '__':
      raise AttributeError(name)
    try:
      forwarded = getattr(self._env, name)
    except AttributeError:
      raise ValueError(name)
    return forwarded

  def step(self, action):
    self.n_cntr = self.n_cntr + 1
    obs, _env_reward, done, info = self._env.step(action)
    # The env reward is dropped: it is recomputed in the training loop.
    # Original PEG does not depend on it, whereas a Dreamer-style reward
    # estimator would.
    if self.n_cntr >= self.n:
      done = True
    return obs, 0, done, info

  def reset(self):
    self.n_cntr = 0
    horizon_reached = self._steps_since_reset >= self._episode_horizon
    # Past the long-episode horizon: truly reset. Otherwise keep the state.
    return self._env.reset() if horizon_reached else self._env.get_obs()

class EpisodicDoneFlagEnvWrapper:
  """Turn a reset-free env into an episodic one.

  ``step`` zeroes the reward and forces ``done=True`` from the ``n``-th step
  on; ``reset`` clears the step counter and really resets the wrapped env.
  """
  def __init__(self, env, n=200):
    self._env = env
    self.n = n
    self.n_cntr = 0

  def __getattr__(self, name):
    # Dunder lookups fail the normal way; other names are forwarded to the
    # wrapped env, and a missing one surfaces as ValueError.
    if name[:2] == '__':
      raise AttributeError(name)
    try:
      forwarded = getattr(self._env, name)
    except AttributeError:
      raise ValueError(name)
    return forwarded

  def step(self, action):
    self.n_cntr = self.n_cntr + 1
    obs, _env_reward, done, info = self._env.step(action)
    # The env reward is dropped: it is recomputed in the training loop.
    # Original PEG does not depend on it, whereas a Dreamer-style reward
    # estimator would.
    if self.n_cntr >= self.n:
      done = True
    return obs, 0, done, info

  def reset(self):
    self.n_cntr = 0
    fresh_obs = self._env.reset()
    return fresh_obs

class EvalEnvWrapper:
  """Evaluation wrapper that drives the env towards a chosen goal.

  - ``reset()`` resets the wrapped env, installs ``all_goals[goal_idx]`` via
    ``reset_goal`` and returns ``get_obs()``.
  - ``step()`` stores the step reward in ``info["metric_success/goal_<idx>"]``
    so per-goal success can be logged during evaluation.
  """
  def __init__(self, env, all_goals=None):
    self._env = env
    self.all_goals = all_goals
    self.goal_idx = 0

  def step(self, action):
    obs, reward, done, info = self._env.step(action)
    tagged_info = self.add_pertask_success(info, reward, self.goal_idx)
    return obs, reward, done, tagged_info

  def reset(self):
    env = self._env
    env.reset()
    env.reset_goal(self.all_goals[self.goal_idx])
    return env.get_obs()

  def add_pertask_success(self, info, reward, goal_idx):
    metric_key = f"metric_success/goal_{goal_idx}"
    info[metric_key] = reward
    return info

  def set_goal_idx(self, idx):
    self.goal_idx = idx

  def get_goal_idx(self):
    return self.goal_idx

  def __getattr__(self, name):
    # Dunder lookups fail the normal way; other names are forwarded to the
    # wrapped env, and a missing one surfaces as ValueError.
    if name[:2] == '__':
      raise AttributeError(name)
    try:
      forwarded = getattr(self._env, name)
    except AttributeError:
      raise ValueError(name)
    return forwarded

  def get_metrics_dict(self):
    # Template of the logged metrics, with a placeholder reward of 0.
    return self.add_pertask_success({}, reward=0, goal_idx=self.goal_idx)

class EvalEnvRNDGoalWrapper:
  """Evaluation wrapper that logs success without choosing a goal.

  The goal is left entirely to the wrapped env; this wrapper never sets one.
  ``step`` copies the step reward into ``info["metric_success/goal_0"]`` so
  it is logged under a single goal slot. ``reset`` resets the wrapped env and
  returns its ``get_obs()``.
  """
  def __init__(self, env):
    self._env = env

  def step(self, action):
    obs, reward, done, info = self._env.step(action)
    info = self.add_pertask_success(info, reward)
    return obs, reward, done, info

  def reset(self):
    self._env.reset()
    return self._env.get_obs()

  def add_pertask_success(self, info, reward):
    info["metric_success/goal_0"] = reward
    return info

  def __getattr__(self, name):
    # Only reached for attributes not found on the wrapper. Dunder lookups
    # fail normally; everything else is forwarded to the wrapped env, and a
    # missing attribute is reported as ValueError.
    if name.startswith('__'):
      raise AttributeError(name)
    try:
      return getattr(self._env, name)
    except AttributeError:
      raise ValueError(name)

  def get_metrics_dict(self):
    # Template of the logged metrics, with a placeholder reward of 0.
    return self.add_pertask_success({}, reward=0)

class TabletopLooseEvalEnvWrapper:
  """Evaluate a tabletop env with a looser success threshold.

  Success uses ``eval_distance_threshold`` (0.2 by default) instead of the
  env's own threshold. ``step`` replaces the env reward with
  :meth:`compute_reward` and returns an empty ``info`` dict.
  """
  def __init__(self, env, eval_distance_threshold=0.2):
    self._env = env
    self.eval_distance_threshold = eval_distance_threshold

  def __getattr__(self, name):
    # Only reached for attributes not found on the wrapper. Dunder lookups
    # fail normally; everything else is forwarded to the wrapped env, and a
    # missing attribute is reported as ValueError.
    if name.startswith('__'):
      raise AttributeError(name)
    try:
      return getattr(self._env, name)
    except AttributeError:
      raise ValueError(name)

  def is_successful(self, obs=None):
    """Return whether ``obs`` (default: the current observation) is a success.

    With ``_wide_init_distr`` set, ``obs[2:4]`` is compared against
    ``obs[8:-2]``; otherwise ``obs[:4]`` is compared against ``obs[6:-2]``.
    """
    if obs is None:
      obs = self._get_obs()

    if self._wide_init_distr:
      return np.linalg.norm(obs[2:4] - obs[8:-2]) <= self.eval_distance_threshold
    return np.linalg.norm(obs[:4] - obs[6:-2]) <= self.eval_distance_threshold

  def compute_reward(self, obs):
    """Return the reward for ``obs`` according to the env's ``_reward_type``.

    ``'sparse'``: 1.0 on success, else 0.0. ``'dense'``: the negative distance
    ``||obs[2:4] - obs[8:-2]||`` plus a Gaussian-shaped bonus for closeness to
    the goal, plus a gripper term based on ``||obs[:2] - obs[2:4]||``.

    Raises:
      NotImplementedError: for any other ``_reward_type`` (previously this
        surfaced as an ``UnboundLocalError``).
    """
    if self._reward_type == "sparse":
      return float(self.is_successful(obs=obs))
    if self._reward_type != "dense":
      raise NotImplementedError(f"unsupported reward type: {self._reward_type!r}")

    # remove gripper, attached object from reward computation
    reward = -np.linalg.norm(obs[2:4] - obs[8:-2])
    # Single object (obj_idx == 1): compares obs[2:4] with obs[8:10].
    obj_idx = 1
    reward += 2. * np.exp(
        -(np.linalg.norm(obs[2 * obj_idx:2 * obj_idx + 2] -
                         obs[2 * obj_idx + 6:2 * obj_idx + 8])**2) / 0.01)

    grip_to_object = 0.5 * np.linalg.norm(obs[:2] - obs[2:4])
    reward += -grip_to_object
    reward += 0.5 * np.exp(-(grip_to_object**2) / 0.01)
    return reward

  def step(self, action):
    # The env's reward and info are discarded; the reward is recomputed here.
    next_obs, _, done, _ = self._env.step(action)
    reward = self.compute_reward(next_obs)
    return next_obs, reward, done, {}

class DoorLooseEvalEnvWrapper:
  """Evaluate the door env with a configurable success threshold.

  Success means ``||obs[4:7] - obs[11:14]|| <= eval_distance_threshold``.
  The default of 0.08 matches the paper; the original codebase used 0.02.
  ``step`` replaces the env reward with the sparse :meth:`compute_reward` and
  returns an empty ``info`` dict.
  """
  def __init__(self, env, eval_distance_threshold=0.08):
    self._env = env
    self.eval_distance_threshold = eval_distance_threshold

  def __getattr__(self, name):
    # Only reached for attributes not found on the wrapper. Dunder lookups
    # fail normally; everything else is forwarded to the wrapped env, and a
    # missing attribute is reported as ValueError.
    if name.startswith('__'):
      raise AttributeError(name)
    try:
      return getattr(self._env, name)
    except AttributeError:
      raise ValueError(name)

  def is_successful(self, obs=None):
    """Return whether ``obs`` (default: the current observation) is a success."""
    if obs is None:
      obs = self._get_obs()
    return np.linalg.norm(obs[4:7] - obs[11:14]) <= self.eval_distance_threshold

  def compute_reward(self, obs):
    """Return 1.0/0.0 success for sparse rewards.

    Raises:
      NotImplementedError: for any ``_reward_type`` other than ``'sparse'``.
    """
    if self._reward_type == 'sparse':
      return float(self.is_successful(obs=obs))
    raise NotImplementedError(f"unsupported reward type: {self._reward_type!r}")

  def step(self, action):
    # The env's reward and info are discarded; the reward is recomputed here.
    next_obs, _, done, _ = self._env.step(action)
    reward = self.compute_reward(next_obs)
    return next_obs, reward, done, {}

class DoorLimitXYZEnvWrapper:
  """Keep the end effector inside an axis-aligned workspace box.

  Before each step, the xyz part of the action (``action[:3]``) is adjusted
  so that the predicted end-effector position
  ``eef + action[:3] * action_scale`` is clipped to
  ``[workspace_min, workspace_max]``. The env reward is replaced by the
  sparse :meth:`compute_reward`, and an empty ``info`` dict is returned.

  Args:
    env: wrapped env. ``step`` reads ``env.env.get_endeff_pos()`` through
      attribute forwarding.
    workspace_min, workspace_max: lower/upper xyz bounds, as scalars or
      arrays broadcastable against the end-effector position. Either bound
      may be ``None`` to leave that side open. When both are ``None`` the
      action is passed through unchanged; ``np.clip`` would otherwise raise.
    action_scale: assumed displacement per unit action (0.01, the previously
      hard-coded value) — TODO confirm against the env's action scaling.
  """
  def __init__(self, env, workspace_min=None, workspace_max=None, action_scale=0.01):
    self._env = env
    self.workspace_min = workspace_min
    self.workspace_max = workspace_max
    self.action_scale = action_scale

  def __getattr__(self, name):
    # Only reached for attributes not found on the wrapper. Dunder lookups
    # fail normally; everything else is forwarded to the wrapped env, and a
    # missing attribute is reported as ValueError.
    if name.startswith('__'):
      raise AttributeError(name)
    try:
      return getattr(self._env, name)
    except AttributeError:
      raise ValueError(name)

  def compute_reward(self, obs):
    """Return 1.0/0.0 success for sparse rewards.

    Raises:
      NotImplementedError: for any ``_reward_type`` other than ``'sparse'``.
    """
    if self._reward_type == 'sparse':
      return float(self.is_successful(obs=obs))
    raise NotImplementedError(f"unsupported reward type: {self._reward_type!r}")

  def _clip_to_workspace(self, action):
    """Return a copy of ``action`` whose xyz part keeps the eef inside the box."""
    action = np.array(action)  # copy; never mutate the caller's action
    curr_eef_state = self.env.get_endeff_pos()
    next_eef_state = np.clip(curr_eef_state + action[:3] * self.action_scale,
                             self.workspace_min, self.workspace_max)
    action[:3] = (next_eef_state - curr_eef_state) / self.action_scale
    return action

  def step(self, action):
    if self.workspace_min is not None or self.workspace_max is not None:
      action = self._clip_to_workspace(action)
    # The env's reward and info are discarded; the reward is recomputed here.
    next_obs, _, done, _ = self._env.step(action)
    reward = self.compute_reward(next_obs)
    return next_obs, reward, done, {}

class PegLooseEvalEnvWrapper:
  """Sparse-reward evaluation wrapper for the peg env.

  Success means ``||obs[4:7] - obs[11:14]|| <= TARGET_RADIUS``, where
  ``TARGET_RADIUS`` is read from the wrapped env via attribute forwarding.
  ``step`` replaces the env reward with :meth:`compute_reward` and returns an
  empty ``info`` dict.

  NOTE(review): ``eval_distance_threshold`` is stored but ``is_successful``
  compares against ``TARGET_RADIUS`` instead — confirm which is intended.
  """
  def __init__(self, env, eval_distance_threshold=0.05):
    self._env = env
    self.eval_distance_threshold = eval_distance_threshold

  def __getattr__(self, name):
    # Dunder lookups fail the normal way; other names are forwarded to the
    # wrapped env, and a missing one surfaces as ValueError.
    if name[:2] == '__':
      raise AttributeError(name)
    try:
      forwarded = getattr(self._env, name)
    except AttributeError:
      raise ValueError(name)
    return forwarded

  def is_successful(self, obs=None):
    state = self._get_obs() if obs is None else obs
    gap = state[4:7] - state[11:14]
    return np.linalg.norm(gap) <= self.TARGET_RADIUS

  def compute_reward(self, obs):
    if self._reward_type != 'sparse':
      raise NotImplementedError
    return float(self.is_successful(obs=obs))

  def step(self, action):
    # Env reward and info are discarded in favour of compute_reward().
    next_obs, _, done, _ = self._env.step(action)
    return next_obs, self.compute_reward(next_obs), done, {}

class PenLooseEvalEnvWrapper:
  """Evaluation wrapper with a configurable success radius (0.1 by default).

  Success compares the last four observation entries as two 2-d points,
  ``obs[-4:-2]`` vs ``obs[-2:]``. ``step`` replaces the env reward with
  :meth:`compute_reward` and returns an empty ``info`` dict.
  """
  def __init__(self, env, eval_distance_threshold=0.1):
    self._env = env
    self.eval_distance_threshold = eval_distance_threshold

  def __getattr__(self, name):
    # Dunder lookups fail the normal way; other names are forwarded to the
    # wrapped env, and a missing one surfaces as ValueError.
    if name[:2] == '__':
      raise AttributeError(name)
    try:
      forwarded = getattr(self._env, name)
    except AttributeError:
      raise ValueError(name)
    return forwarded

  def is_successful(self, obs=None):
    # Fall back to the env's current observation, trying _get_obs() first
    # and _get_observation() if that yields None.
    if obs is None:
      obs = self._get_obs()
      if obs is None:
        obs = self._get_observation()
    offset = np.array(obs[-4:-2]) - np.array(obs[-2:])
    return 1.0 if np.sqrt(np.sum(offset**2)) < self.eval_distance_threshold else 0.0

  def compute_reward(self, obs):
    # L1 distance between obs[28:30] and obs[30:32], minus an energy penalty
    # |<obs[8:16], obs[16:24]>| * _time_step; weights come from the env.
    distance_reward = -abs(obs[28] - obs[30]) - abs(obs[29] - obs[31])
    energy_reward = np.abs(np.dot(obs[8:16], obs[16:24])) * self._time_step
    return self._distance_weight * distance_reward - self._energy_weight * energy_reward

  def step(self, action):
    # Env reward and info are discarded in favour of compute_reward().
    next_obs, _, done, _ = self._env.step(action)
    return next_obs, self.compute_reward(next_obs), done, {}

class KitchenLooseEvalEnvWrapper:
  """Reward and success evaluation for the earl_kitchen env.

  Success is declared when the non-arm components of the state
  (``obs[9:23]``, see ``component_to_state_idx``) are within an L2 distance
  of 0.3 of the matching goal entries ``obs[32:46]``. ``step`` replaces the
  env reward with :meth:`compute_reward` and returns an empty ``info`` dict.
  """

  # Site whose position is used for the arm-reaching penalty of each component.
  _TASK_TO_SITE = {'microwave': 'microhandle_site',
                   'hinge_cabinet': 'hinge_site2',
                   'slide_cabinet': 'slide_site',
                   'burner0': 'knob1_site',
                   'burner1': 'knob2_site',
                   'burner2': 'knob3_site',
                   'burner3': 'knob4_site',
                   'light_switch': 'light_site',}

  def __init__(self, env):
    self._env = env

  def __getattr__(self, name):
    # Only reached for attributes not found on the wrapper. Dunder lookups
    # fail normally; everything else is forwarded to the wrapped env, and a
    # missing attribute is reported as ValueError.
    if name.startswith('__'):
      raise AttributeError(name)
    try:
      return getattr(self._env, name)
    except AttributeError:
      raise ValueError(name)

  def _get_reward_n_score(self, obs_dict):
    """Return ``(reward_dict, score)`` for a flat obs or a ``qp/obj_qp/goal`` dict.

    The reward is ``-10 * ||obs[9:23] - obs[32:46]||``. Each solved component
    adds +1. The first unsolved component (in ``component_to_state_idx``
    order) adds ``-0.5 *`` the distance from the mocap to that component's
    site. ``score`` is always 0.
    """
    if isinstance(obs_dict, dict):
      obs = np.append(np.append(obs_dict['qp'], obs_dict['obj_qp']), obs_dict['goal'])
    else:
      obs = obs_dict

    # The goal half of the vector is offset by 23 from the state half.
    reward = -10 * np.linalg.norm(obs[9:23] - obs[9+23:23+23])

    reaching_component = False
    for key, idxs in component_to_state_idx.items():
      if key == 'arm':
        continue
      cur_idxs = np.array(idxs)
      # A component counts as solved within 0.01 per coordinate it spans.
      if np.linalg.norm(obs[cur_idxs] - obs[cur_idxs + 23]) < len(idxs) * 0.01:
        reward += 1
      elif not reaching_component:
        reaching_component = True
        reward += -0.5 * np.linalg.norm(self.sim.data.mocap_pos[0] -
                                        self.sim.data.get_site_xpos(self._TASK_TO_SITE[key]))

    reward_dict = {'true_reward': reward, 'r_total': reward}
    return reward_dict, 0.

  def is_successful(self, obs=None):
    """Return whether ``obs`` (default: the current observation) is a success."""
    if obs is None:
      obs = self._get_obs()
    return bool(np.linalg.norm(obs[9:23] - obs[9+23:23+23]) <= 0.3)

  def compute_reward(self, obs):
    return self._get_reward_n_score(obs)[0]['r_total']

  def step(self, action):
    # The env's reward and info are discarded; the reward is recomputed here.
    next_obs, _, done, _ = self._env.step(action)
    reward = self.compute_reward(next_obs)
    return next_obs, reward, done, {}

class PointMaze2ResetfreeEnvWrapper:
  """Make an episodic point-maze env usable as a reset-free env.

  Adds ``is_successful``/``compute_reward`` and makes ``step`` always report
  ``done=False`` with the recomputed reward and an empty ``info`` dict.
  ``eval_distance_threshold``, ``_wide_init_distr`` and ``_reward_type`` are
  read from the wrapped env through attribute forwarding.

  NOTE(review): the observation slicing mirrors TabletopLooseEvalEnvWrapper;
  confirm it matches the point-maze observation layout.
  """
  def __init__(self, env):
    self._env = env

  def __getattr__(self, name):
    # Only reached for attributes not found on the wrapper. Dunder lookups
    # fail normally; everything else is forwarded to the wrapped env, and a
    # missing attribute is reported as ValueError.
    if name.startswith('__'):
      raise AttributeError(name)
    try:
      return getattr(self._env, name)
    except AttributeError:
      raise ValueError(name)

  def is_successful(self, obs=None):
    """Return whether ``obs`` (default: the current observation) is a success.

    With ``_wide_init_distr`` set, ``obs[2:4]`` is compared against
    ``obs[8:-2]``; otherwise ``obs[:4]`` is compared against ``obs[6:-2]``.
    """
    if obs is None:
      obs = self._get_obs()

    if self._wide_init_distr:
      return np.linalg.norm(obs[2:4] - obs[8:-2]) <= self.eval_distance_threshold
    return np.linalg.norm(obs[:4] - obs[6:-2]) <= self.eval_distance_threshold

  def compute_reward(self, obs):
    """Return the reward for ``obs`` according to ``_reward_type``.

    ``'sparse'``: 1.0 on success, else 0.0. ``'dense'``: the negative distance
    ``||obs[2:4] - obs[8:-2]||`` plus a Gaussian-shaped closeness bonus, plus
    a term based on ``||obs[:2] - obs[2:4]||``.

    Raises:
      NotImplementedError: for any other ``_reward_type`` (previously this
        surfaced as an ``UnboundLocalError``).
    """
    if self._reward_type == "sparse":
      return float(self.is_successful(obs=obs))
    if self._reward_type != "dense":
      raise NotImplementedError(f"unsupported reward type: {self._reward_type!r}")

    # remove gripper, attached object from reward computation
    reward = -np.linalg.norm(obs[2:4] - obs[8:-2])
    # Single object (obj_idx == 1): compares obs[2:4] with obs[8:10].
    obj_idx = 1
    reward += 2. * np.exp(
        -(np.linalg.norm(obs[2 * obj_idx:2 * obj_idx + 2] -
                         obs[2 * obj_idx + 6:2 * obj_idx + 8])**2) / 0.01)

    grip_to_object = 0.5 * np.linalg.norm(obs[:2] - obs[2:4])
    reward += -grip_to_object
    reward += 0.5 * np.exp(-(grip_to_object**2) / 0.01)
    return reward

  def step(self, action):
    # Never terminate: the reset-free setting decides when episodes end.
    next_obs, _, _, _ = self._env.step(action)
    reward = self.compute_reward(next_obs)
    return next_obs, reward, False, {}

class PointMazeResetfreeWrapper:
  """Reset-free adapter for the point-maze env.

  ``step`` always reports ``done=False``; ``get_obs`` builds a goal-env style
  dict from the env's ``s_xy``/``g_xy``; ``compute_reward`` splits a
  concatenated ``[achieved(2), desired]`` vector and defers to the env.
  """
  def __init__(self, env):
    self._env = env

  def __getattr__(self, name):
    # Dunder lookups fail the normal way; other names are forwarded to the
    # wrapped env, and a missing one surfaces as ValueError.
    if name[:2] == '__':
      raise AttributeError(name)
    try:
      forwarded = getattr(self._env, name)
    except AttributeError:
      raise ValueError(name)
    return forwarded

  def step(self, action):
    next_obs, reward, _, info = self._env.step(action)
    return next_obs, reward, False, info

  def get_obs(self):
    position, target = self.s_xy, self.g_xy
    # The achieved goal is the current position itself.
    return {
        'observation': position,
        'achieved_goal': position,
        'desired_goal': target,
    }

  def compute_reward(self, obs_goal):
    return self._env.compute_reward(obs_goal[:2], obs_goal[2:], info=None)

class AntUMazeResetfreeWrapper:
  """Reset-free adapter for the ant U-maze env with a reduced goal size.

  Every observation entry whose key does not contain ``'observation'`` (e.g.
  the achieved/desired goals) is truncated to its first ``goal_dim`` values.
  ``step`` reports a boolean success flag (env reward close to 0) as the
  reward and never terminates.

  Args:
    env: wrapped goal-conditioned env.
    goal_dim: number of leading goal coordinates to keep.
  """
  def __init__(self, env, goal_dim):
    self._env = env
    self.g_dim = goal_dim

  def __getattr__(self, name):
    # Only reached for attributes not found on the wrapper. Dunder lookups
    # fail normally; everything else is forwarded to the wrapped env, and a
    # missing attribute is reported as ValueError.
    if name.startswith('__'):
      raise AttributeError(name)
    try:
      return getattr(self._env, name)
    except AttributeError:
      raise ValueError(name)

  def _truncate_goal_entries(self, obs):
    """Cut every non-'observation' entry of ``obs`` to ``g_dim`` values, in place."""
    for key in list(obs):
      if 'observation' not in key:
        obs[key] = obs[key][:self.g_dim]
    return obs

  def step(self, action):
    next_obs, reward, _, info = self._env.step(action)
    # An env reward (close to) 0 is treated as success.
    success = np.allclose(0., reward)
    return self._truncate_goal_entries(next_obs), success, False, info

  def get_obs(self):
    s_xy = self.maze._get_obs()
    return {
        'observation': s_xy,
        'achieved_goal': s_xy[:self.g_dim],
        # Not truncated: PEG overrides the goal anyway.
        'desired_goal': self.g_xy,
    }

  def reset(self):
    return self._truncate_goal_entries(self._env.reset())

  def compute_reward(self, obs_goal):
    """Return whether the first and last ``g_dim`` entries of ``obs_goal`` match.

    Success is the env's ``compute_reward`` being (close to) 0.
    """
    achieved_goal = obs_goal[:self.g_dim]
    desired_goal = obs_goal[-self.g_dim:]
    reward = self._env.compute_reward(achieved_goal, desired_goal, info=None)
    return np.allclose(0., reward)

  @property
  def observation_space(self):
    """Dict space: full ``'observation'`` Box and a ``g_dim``-sized ``'desired_goal'`` Box."""
    spaces = self._env.observation_space
    return gym.spaces.Dict({
        'observation': gym.spaces.Box(low=spaces['observation'].low, high=spaces['observation'].high),
        'desired_goal': gym.spaces.Box(low=spaces['desired_goal'].low[:self.g_dim],
                                       high=spaces['desired_goal'].high[:self.g_dim]),
    })

class Episodic2Resetfree:
  """Bound episode length and count resets ("interventions").

  ``step`` forces ``done=True`` once ``episode_horizon`` steps have passed
  since the last ``reset``. Every ``reset`` call, including the first, counts
  as an intervention. Unknown attributes are forwarded to the wrapped env.
  """
  def __init__(self, env, episode_horizon):
    self._env = env
    self._episode_horizon = episode_horizon
    self._total_step_count = 0
    self._steps_since_reset = 0
    self._num_interventions = 0

  def reset(self):
    self._num_interventions += 1
    self._steps_since_reset = 0
    return self._env.reset()

  def step(self, action):
    obs, reward, done, info = self._env.step(action)

    self._total_step_count += 1
    self._steps_since_reset += 1

    if not done and self._steps_since_reset >= self._episode_horizon:
      done = True

    return obs, reward, done, info

  def is_successful(self, obs=None):
    """Delegate to the wrapped env's ``is_successful``; False if it has none."""
    try:
      env_is_successful = self._env.is_successful
    except (AttributeError, ValueError):
      # ValueError: the other wrappers in this file report a missing
      # attribute that way, which would make hasattr() raise.
      return False
    return env_is_successful(obs)

  @property
  def num_interventions(self):
    return self._num_interventions

  @property
  def total_steps(self):
    return self._total_step_count

  def __getattr__(self, name):
    # Only reached for attributes not found normally. Refuse dunders and
    # ``_env`` itself. On an instance whose __init__ has not run (e.g. while
    # copy/pickle rebuild it), forwarding them would recurse through
    # ``self._env`` forever. Forwarding dunders such as __deepcopy__ would
    # also make copy.deepcopy return a copy of the inner env.
    if name.startswith('__') or name == '_env':
      raise AttributeError(name)
    return getattr(self._env, name)
