# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Wrapper for creating the ant environment."""

import math
import numpy as np
import mujoco_py
import os

from gym import utils
from gym.envs.mujoco import mujoco_env
from d4rl.locomotion import mujoco_goal_env

from d4rl.locomotion import goal_reaching_env
from d4rl.locomotion import maze_env
from d4rl import offline_env
from d4rl.locomotion import wrappers

# Directory named 'assets' that sits next to the `mujoco_goal_env` module file.
GYM_ASSETS_DIR = os.path.join(os.path.dirname(mujoco_goal_env.__file__), 'assets')

class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
  """Basic ant locomotion environment.

  Observations are assembled from the simulator's `qpos` and `qvel` arrays,
  optionally followed by body centers of mass, body center-of-mass velocities
  and a 7-slot one-hot task encoding (see `_get_obs`).
  """
  FILE = os.path.join(GYM_ASSETS_DIR, 'ant.xml')

  def __init__(self, file_path=None, expose_all_qpos=False,
               expose_body_coms=None, expose_body_comvels=None, non_zero_reset=False,
               goal_conditioned=False, task_conditioned=False, task_idx=None):
    """Stores the observation options and initializes the MuJoCo env.

    Args:
      file_path: Path of the model XML file. Defaults to `FILE` when None.
      expose_all_qpos: If True, observations start at `qpos` index 0;
        otherwise the first two `qpos` entries are left out.
      expose_body_coms: Optional iterable of body names whose
        `get_body_com` value is appended to every observation.
      expose_body_comvels: Optional iterable of body names whose
        `get_body_comvel` value is appended to every observation.
      non_zero_reset: If True, `reset_model` overwrites the first two `qpos`
        entries with `self._get_reset_location()`. That method is not defined
        in this class; presumably a subclass or mixin provides it.
      goal_conditioned: Stored as `self.goal_conditioned`; not read by any
        method of this class.
      task_conditioned: If True, a 7-slot one-hot vector is appended to every
        observation.
      task_idx: Index of the slot set to 1.0 in that one-hot vector. None
        selects the last slot.
    """
    if file_path is None:
      file_path = self.FILE

    self._expose_all_qpos = expose_all_qpos
    self._expose_body_coms = expose_body_coms
    self._expose_body_comvels = expose_body_comvels
    # Filled lazily by `_get_obs`: body name -> range of observation indices.
    self._body_com_indices = {}
    self._body_comvel_indices = {}

    self._non_zero_reset = non_zero_reset
    self.goal_conditioned = goal_conditioned
    self.task_conditioned = task_conditioned
    self.task_idx = task_idx

    # These attributes must be set before MujocoEnv.__init__ runs.
    # NOTE(review): presumably because the base initializer steps the env and
    # therefore calls `_get_obs` -- confirm against the installed gym version.
    mujoco_env.MujocoEnv.__init__(self, file_path, 5)
    utils.EzPickle.__init__(self)

  @staticmethod
  def _numeric_version(version):
    """Returns the leading integer components of a dotted version string.

    Parsing stops at the first component that does not start with an ASCII
    digit, or right after a component with a non-digit suffix, e.g.
    '2.0.2.13' -> (2, 0, 2, 13) and '1.50rc1.3' -> (1, 50).
    """
    components = []
    for piece in str(version).split('.'):
      end = 0
      while end < len(piece) and piece[end] in '0123456789':
        end += 1
      if end == 0:
        break
      components.append(int(piece[:end]))
      if end < len(piece):
        break
    return tuple(components)

  @property
  def physics(self):
    """Object whose `data` attribute holds `qpos`/`qvel` for this mujoco_py.

    Returns `self.sim` for mujoco_py versions >= 1.50 and `self.model`
    otherwise.
    """
    # Check mujoco version is greater than version 1.50 to call correct physics
    # model containing PyMjData object for getting and setting position/velocity.
    # Check https://github.com/openai/mujoco-py/issues/80 for updates to api.
    version = mujoco_py.get_version()
    parsed = self._numeric_version(version)
    if parsed:
      # Compare numerically: a plain string comparison orders '10.0' before
      # '1.50' and would pick the wrong branch for such versions.
      is_new_api = parsed >= (1, 50)
    else:
      # Nothing numeric to parse; keep the historical string comparison.
      is_new_api = version >= '1.50'
    if is_new_api:
      return self.sim
    else:
      return self.model

  def _step(self, a):
    """Alias that forwards to `step`."""
    return self.step(a)

  def step(self, a):
    """Applies action `a` for `frame_skip` simulation steps.

    Args:
      a: Action passed to `do_simulation`; its squared sum sets the control
        cost.

    Returns:
      Tuple `(ob, reward, done, info)`. `reward` is forward progress of the
      "torso" body along x per unit `dt`, minus control and contact costs,
      plus a constant survive reward of 1.0. `info` holds the four reward
      terms under the keys `reward_forward`, `reward_ctrl`, `reward_contact`
      and `reward_survive`.
    """
    xposbefore = self.get_body_com("torso")[0]
    self.do_simulation(a, self.frame_skip)
    xposafter = self.get_body_com("torso")[0]
    forward_reward = (xposafter - xposbefore) / self.dt
    ctrl_cost = .5 * np.square(a).sum()
    # External contact forces are clipped to [-1, 1] before being squared.
    contact_cost = 0.5 * 1e-3 * np.sum(
        np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
    survive_reward = 1.0
    reward = forward_reward - ctrl_cost - contact_cost + survive_reward
    state = self.state_vector()
    # The episode continues only while the state is finite and state[2] lies
    # in [0.2, 1.0].
    # NOTE(review): state[2] is presumably the torso height -- confirm against
    # the layout returned by `state_vector`.
    notdone = np.isfinite(state).all() \
        and state[2] >= 0.2 and state[2] <= 1.0
    done = not notdone
    ob = self._get_obs()
    return ob, reward, done, dict(
        reward_forward=forward_reward,
        reward_ctrl=-ctrl_cost,
        reward_contact=-contact_cost,
        reward_survive=survive_reward)

  def _get_obs(self):
    """Builds the observation vector.

    Layout, in order:
      * `qpos[:15]` if `expose_all_qpos` else `qpos[2:15]`,
      * `qvel[:14]`,
      * `get_body_com(name)` for each name in `expose_body_coms`,
      * `get_body_comvel(name)` for each name in `expose_body_comvels`,
      * a 7-slot one-hot task vector when `task_conditioned` is set.

    The index range occupied by each exposed body is recorded in
    `_body_com_indices` / `_body_comvel_indices` the first time it is seen.
    """
    # No cfrc observation.
    if self._expose_all_qpos:
      obs = np.concatenate([
          self.physics.data.qpos.flat[:15],  # Ensures only ant obs.
          self.physics.data.qvel.flat[:14],
      ])
    else:
      obs = np.concatenate([
          self.physics.data.qpos.flat[2:15],
          self.physics.data.qvel.flat[:14],
      ])

    if self._expose_body_coms is not None:
      for name in self._expose_body_coms:
        com = self.get_body_com(name)
        if name not in self._body_com_indices:
          indices = range(len(obs), len(obs) + len(com))
          self._body_com_indices[name] = indices
        obs = np.concatenate([obs, com])

    if self._expose_body_comvels is not None:
      for name in self._expose_body_comvels:
        comvel = self.get_body_comvel(name)
        if name not in self._body_comvel_indices:
          indices = range(len(obs), len(obs) + len(comvel))
          self._body_comvel_indices[name] = indices
        obs = np.concatenate([obs, comvel])
    if self.task_conditioned:
      # obs = np.concatenate([obs, np.array([6,])], 0)
      one_hot = np.zeros([7,])
      if self.task_idx is None:
        # No explicit task: mark the last slot.
        one_hot[-1] = 1.0
      else:
        one_hot[self.task_idx] = 1.0
      obs = np.concatenate([obs, one_hot], 0)
    return obs

  def reset_model(self):
    """Resets the simulation to a perturbed initial state.

    Uniform noise in [-0.1, 0.1] is added to `init_qpos` and Gaussian noise
    scaled by 0.1 to `init_qvel`. Entries past the first 15 positions and 14
    velocities are restored to their initial position and zero velocity.

    Returns:
      The observation after the reset.
    """
    qpos = self.init_qpos + self.np_random.uniform(
        size=self.model.nq, low=-.1, high=.1)
    qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1

    if self._non_zero_reset:
      # Now the reset is supposed to be to a non-zero location.
      reset_location = self._get_reset_location()
      qpos[:2] = reset_location

    # Set everything other than ant to original position and 0 velocity.
    qpos[15:] = self.init_qpos[15:]
    qvel[14:] = 0.
    self.set_state(qpos, qvel)
    return self._get_obs()

  def viewer_setup(self):
    """Places the camera at half of the model's `stat.extent`."""
    self.viewer.cam.distance = self.model.stat.extent * 0.5

  def get_xy(self):
    """Returns the first two `qpos` entries.

    NOTE(review): the slice is not copied, so it may alias the simulator's
    own buffer -- confirm before mutating the result.
    """
    return self.physics.data.qpos[:2]

  def set_xy(self, xy):
    """Sets the first two `qpos` entries to `xy`, keeping `qvel` unchanged."""
    qpos = np.copy(self.physics.data.qpos)
    qpos[0] = xy[0]
    qpos[1] = xy[1]
    qvel = self.physics.data.qvel
    self.set_state(qpos, qvel)
  

class GoalReachingAntEnv(goal_reaching_env.GoalReachingEnv, AntEnv):
  """Ant locomotion rewarded for goal-reaching."""
  BASE_ENV = AntEnv

  def __init__(self, goal_sampler=goal_reaching_env.disk_goal_sampler,
               file_path=None,
               expose_all_qpos=False, non_zero_reset=False, eval=False, reward_type='dense',
               goal_conditioned=False, task_conditioned=False, task_idx=None, **kwargs):
    """Initializes the goal-reaching base first, then the ant environment.

    `goal_sampler`, `eval`, `reward_type`, `goal_conditioned` and `task_idx`
    are handed to `GoalReachingEnv.__init__`; the remaining named options go
    to `AntEnv.__init__`, which never exposes body COMs or COM velocities
    here. Any extra `**kwargs` are accepted and ignored.
    """
    goal_reaching_env.GoalReachingEnv.__init__(
        self,
        goal_sampler,
        eval=eval,
        reward_type=reward_type,
        goal_conditioned=goal_conditioned,
        task_idx=task_idx)

    ant_options = dict(
        file_path=file_path,
        expose_all_qpos=expose_all_qpos,
        expose_body_coms=None,
        expose_body_comvels=None,
        non_zero_reset=non_zero_reset,
        goal_conditioned=goal_conditioned,
        task_conditioned=task_conditioned,
        task_idx=task_idx,
    )
    AntEnv.__init__(self, **ant_options)

class AntMazeEnv(maze_env.MazeEnv, GoalReachingAntEnv, offline_env.OfflineEnv):
  """Ant navigating a maze."""
  LOCOMOTION_ENV = GoalReachingAntEnv

  def __init__(self, goal_sampler=None, expose_all_qpos=True,
               reward_type='dense', goal_conditioned=False,
               task_conditioned=False, task_idx=None,
               *args, **kwargs):
    """Initializes the maze, the offline-dataset base and the target goal.

    Args:
      goal_sampler: Callable taking a random-number generator. Defaults to
        `MazeEnv.goal_sampler` bound to this instance.
      expose_all_qpos: Forwarded to `MazeEnv.__init__`.
      reward_type: Forwarded to `MazeEnv.__init__`.
      goal_conditioned: Forwarded to `MazeEnv.__init__`. When truthy it
        forces `task_conditioned` to False.
      task_conditioned: Forwarded to `MazeEnv.__init__`; only honored when
        `goal_conditioned` is falsy.
      task_idx: Forwarded to `MazeEnv.__init__`.
      *args: Extra positional arguments for `MazeEnv.__init__`.
      **kwargs: Extra keyword arguments, passed to both `MazeEnv.__init__`
        and `OfflineEnv.__init__`.
    """
    if goal_sampler is None:
      goal_sampler = lambda np_rand: maze_env.MazeEnv.goal_sampler(self, np_rand)
    # The two conditioning modes are mutually exclusive and goal conditioning
    # takes precedence when both are requested.
    if goal_conditioned:
      task_conditioned = False
    elif task_conditioned:
      goal_conditioned = False
    maze_env.MazeEnv.__init__(
        self, *args, manual_collision=False,
        goal_sampler=goal_sampler,
        expose_all_qpos=expose_all_qpos,
        reward_type=reward_type,
        goal_conditioned=goal_conditioned,
        task_conditioned=task_conditioned,
        task_idx=task_idx,
        **kwargs)
    offline_env.OfflineEnv.__init__(self, **kwargs)
    # Set the target goal here for evaluation.
    self.set_target()

  def set_target(self, target_location=None):
    """Forwards `target_location` to `set_target_goal` and returns its result."""
    return self.set_target_goal(target_location)

  def seed(self, seed=0):
    """Seeds the environment through `MujocoEnv.seed`.

    The base implementation is called explicitly, so it is the one used
    regardless of the method resolution order of the mixed-in bases.

    Returns:
      Whatever `MujocoEnv.seed` returns.
    """
    return mujoco_env.MujocoEnv.seed(self, seed)

  # def step(self, action):
  #   next_obs, reward, done, info = super().step(action)
  #   print ('next_obs_shape: ', next_obs.shape)
  #   if self.goal_conditioned:
  #     return np.concatenate([next_obs, np.array([6,])], 0), reward, done, info
  #   else:
  #     return next_obs, reward, done, info

  # def reset_model(self):
  #   obs = super().reset_model()
  #   print ('obs_shape: ', obs.shape)
  #   if self.goal_conditioned:
  #     return np.concatenate([obs, np.array([6,])], 0)
  #   else:
  #     return obs

def make_ant_maze_env(**kwargs):
  """Builds an `AntMazeEnv` from `kwargs` and wraps it in `NormalizedBoxEnv`."""
  return wrappers.NormalizedBoxEnv(AntMazeEnv(**kwargs))
  
