# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Wrapper for creating the ant environment."""

import math
import numpy as np
import mujoco_py
import os

from gym import utils
from gym.envs.mujoco import mujoco_env
from d4rl_ext.locomotion import mujoco_goal_env

from d4rl_ext.locomotion import goal_reaching_env
from d4rl_ext.locomotion import maze_env
from d4rl_ext import offline_env
from d4rl_ext.locomotion import wrappers

# Directory of MuJoCo XML assets located next to the ``mujoco_goal_env``
# module (the ant model below is loaded from here).
GYM_ASSETS_DIR = os.path.join(
    os.path.dirname(mujoco_goal_env.__file__),
    "assets",
)


class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    """Basic ant locomotion environment.

    The observation is built from the ant's joint positions and velocities
    (the first 15 ``qpos`` / 14 ``qvel`` entries, which belong to the ant),
    optionally followed by the centre-of-mass positions and velocities of
    named bodies.
    """

    FILE = os.path.join(GYM_ASSETS_DIR, "ant.xml")

    # First mujoco_py release whose API exposes the PyMjData-backed simulator
    # as ``self.sim`` for getting/setting positions and velocities.
    # See https://github.com/openai/mujoco-py/issues/80.
    _SIM_API_MIN_VERSION = (1, 50)

    def __init__(
        self,
        file_path=None,
        expose_all_qpos=False,
        expose_body_coms=None,
        expose_body_comvels=None,
        non_zero_reset=False,
    ):
        """Load the ant model and initialise the MuJoCo environment.

        Args:
            file_path: MJCF model to load; defaults to ``self.FILE``.
            expose_all_qpos: if True, the observation includes ``qpos[0:2]``
                (the x/y position returned by ``get_xy``); otherwise those
                two entries are dropped.
            expose_body_coms: optional iterable of body names whose
                centre-of-mass positions are appended to the observation.
            expose_body_comvels: optional iterable of body names whose
                centre-of-mass velocities are appended to the observation.
            non_zero_reset: if True, ``reset_model`` moves the ant to
                ``self._get_reset_location()``. That method is not defined on
                this class; a subclass or mixin must provide it.
        """
        if file_path is None:
            file_path = self.FILE

        self._expose_all_qpos = expose_all_qpos
        self._expose_body_coms = expose_body_coms
        self._expose_body_comvels = expose_body_comvels
        # Observation index ranges of each exposed body, recorded by
        # ``_get_obs`` the first time that body is appended.
        self._body_com_indices = {}
        self._body_comvel_indices = {}

        self._non_zero_reset = non_zero_reset

        # The attributes above must exist before this call because _get_obs
        # reads them; MujocoEnv.__init__ presumably builds the observation
        # space by stepping the env (TODO confirm). 5 is the frame_skip.
        mujoco_env.MujocoEnv.__init__(self, file_path, 5)
        utils.EzPickle.__init__(self)

    @staticmethod
    def _version_tuple(version):
        """Return the leading integer components of a dotted version string.

        For example ``"2.1.2.14"`` -> ``(2, 1, 2, 14)`` and ``"1.50.1.68"``
        -> ``(1, 50, 1, 68)``. Non-digit suffixes inside a component are
        dropped, and parsing stops at the first component without a leading
        digit. The resulting tuples compare numerically, unlike raw strings
        (where e.g. ``"1.6" >= "1.50"`` would be True).
        """
        parts = []
        for component in str(version).split("."):
            digits = ""
            for char in component:
                if not char.isdecimal():
                    break
                digits += char
            if not digits:
                break
            parts.append(int(digits))
        return tuple(parts)

    @property
    def physics(self):
        """Object whose ``.data`` exposes ``qpos`` / ``qvel`` for this env.

        On mujoco_py >= 1.50 this is ``self.sim``; on older versions it is
        ``self.model``. The version is compared numerically, not as a string.
        """
        version = self._version_tuple(mujoco_py.get_version())
        if version >= self._SIM_API_MIN_VERSION:
            return self.sim
        return self.model

    def _step(self, a):
        # Legacy alias; presumably kept for older gym versions that
        # dispatched to ``_step``.
        return self.step(a)

    def step(self, a):
        """Apply action ``a`` for ``frame_skip`` simulation steps.

        The reward is the sum of four terms:

        - forward reward: torso x displacement divided by ``self.dt``;
        - control cost: ``0.5 * sum(a ** 2)``, subtracted;
        - contact cost: ``0.5e-3 * sum(clip(cfrc_ext, -1, 1) ** 2)``,
          subtracted;
        - constant survival bonus: ``1.0``.

        The episode is done when the state vector contains a non-finite value
        or ``state[2]`` (presumably the torso height) leaves ``[0.2, 1.0]``.

        Returns:
            ``(obs, reward, done, info)``, where ``info`` holds the individual
            reward terms (costs reported as negative values).
        """
        xposbefore = self.get_body_com("torso")[0]
        self.do_simulation(a, self.frame_skip)
        xposafter = self.get_body_com("torso")[0]
        forward_reward = (xposafter - xposbefore) / self.dt
        ctrl_cost = 0.5 * np.square(a).sum()
        contact_cost = (
            0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
        )
        survive_reward = 1.0
        reward = forward_reward - ctrl_cost - contact_cost + survive_reward
        state = self.state_vector()
        notdone = np.isfinite(state).all() and state[2] >= 0.2 and state[2] <= 1.0
        done = not notdone
        ob = self._get_obs()
        return (
            ob,
            reward,
            done,
            dict(
                reward_forward=forward_reward,
                reward_ctrl=-ctrl_cost,
                reward_contact=-contact_cost,
                reward_survive=survive_reward,
            ),
        )

    def _get_obs(self):
        """Build the flat observation vector.

        The vector is laid out as ``qpos[0:15]`` (or ``qpos[2:15]`` when
        ``expose_all_qpos`` is False), then ``qvel[0:14]``, then each
        requested body COM and each requested body COM velocity, in order.
        The first time a body is appended, its index range is recorded in
        ``_body_com_indices`` or ``_body_comvel_indices``.
        """
        # No cfrc observation.
        if self._expose_all_qpos:
            obs = np.concatenate(
                [
                    self.physics.data.qpos.flat[:15],  # Ensures only ant obs.
                    self.physics.data.qvel.flat[:14],
                ]
            )
        else:
            obs = np.concatenate(
                [
                    self.physics.data.qpos.flat[2:15],
                    self.physics.data.qvel.flat[:14],
                ]
            )

        if self._expose_body_coms is not None:
            for name in self._expose_body_coms:
                com = self.get_body_com(name)
                if name not in self._body_com_indices:
                    indices = range(len(obs), len(obs) + len(com))
                    self._body_com_indices[name] = indices
                obs = np.concatenate([obs, com])

        if self._expose_body_comvels is not None:
            for name in self._expose_body_comvels:
                comvel = self.get_body_comvel(name)
                if name not in self._body_comvel_indices:
                    indices = range(len(obs), len(obs) + len(comvel))
                    self._body_comvel_indices[name] = indices
                obs = np.concatenate([obs, comvel])
        return obs

    def reset_model(self):
        """Reset the ant near its initial pose and return the observation.

        Positions get uniform noise in [-0.1, 0.1] and velocities get Gaussian
        noise with std 0.1. Entries beyond the ant's (``qpos[15:]``,
        ``qvel[14:]``) are restored to their initial position with zero
        velocity.
        """
        qpos = self.init_qpos + self.np_random.uniform(
            size=self.model.nq, low=-0.1, high=0.1
        )
        qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1

        if self._non_zero_reset:
            # Reset to a (possibly) non-zero x/y location supplied by
            # ``_get_reset_location`` (must be provided by a subclass/mixin).
            reset_location = self._get_reset_location()
            qpos[:2] = reset_location

        # Set everything other than ant to original position and 0 velocity.
        qpos[15:] = self.init_qpos[15:]
        qvel[14:] = 0.0
        self.set_state(qpos, qvel)
        return self._get_obs()

    def viewer_setup(self):
        """Place the viewer camera at half the model's spatial extent."""
        self.viewer.cam.distance = self.model.stat.extent * 0.5

    def get_xy(self):
        """Return the ant's x/y position, ``qpos[0:2]``.

        NOTE(review): this is a slice of ``physics.data.qpos``, so it may
        alias simulator memory and change after later steps. Copy it if you
        need a snapshot.
        """
        return self.physics.data.qpos[:2]

    def set_xy(self, xy):
        """Teleport the ant to ``xy``, keeping all other positions and velocities."""
        qpos = np.copy(self.physics.data.qpos)
        qpos[0] = xy[0]
        qpos[1] = xy[1]
        qvel = self.physics.data.qvel
        self.set_state(qpos, qvel)


class GoalReachingAntEnv(goal_reaching_env.GoalReachingEnv, AntEnv):
    """Ant locomotion environment rewarded for reaching goals.

    Goal sampling and the goal-based reward come from ``GoalReachingEnv``.
    The body, dynamics and observations come from ``AntEnv``, with no extra
    body COM positions or velocities exposed.
    """

    BASE_ENV = AntEnv

    def __init__(
        self,
        goal_sampler=goal_reaching_env.disk_goal_sampler,
        file_path=None,
        expose_all_qpos=False,
        non_zero_reset=False,
        eval=False,
        reward_type="dense",
        **kwargs
    ):
        """Initialise the goal-reaching logic, then the ant itself.

        Args:
            goal_sampler: goal sampling callable, forwarded to GoalReachingEnv.
            file_path: MJCF model path forwarded to AntEnv (None = default).
            expose_all_qpos: forwarded to AntEnv.
            non_zero_reset: forwarded to AntEnv.
            eval: forwarded to GoalReachingEnv.
            reward_type: forwarded to GoalReachingEnv (e.g. "dense").
            **kwargs: accepted for call-site compatibility and ignored.
        """
        # Keep this order: the goal state presumably has to exist before
        # AntEnv.__init__ constructs the MuJoCo env (TODO confirm).
        goal_reaching_env.GoalReachingEnv.__init__(
            self, goal_sampler, eval=eval, reward_type=reward_type
        )

        ant_options = dict(
            file_path=file_path,
            expose_all_qpos=expose_all_qpos,
            expose_body_coms=None,
            expose_body_comvels=None,
            non_zero_reset=non_zero_reset,
        )
        AntEnv.__init__(self, **ant_options)


class AntMazeEnv(maze_env.MazeEnv, GoalReachingAntEnv, offline_env.OfflineEnv):
    """Ant navigating a maze, with offline-dataset (OfflineEnv) support."""

    LOCOMOTION_ENV = GoalReachingAntEnv

    def __init__(
        self,
        goal_sampler=None,
        expose_all_qpos=True,
        reward_type="dense",
        v2_resets=False,
        *args,
        **kwargs
    ):
        """Build the maze and the ant inside it, then fix an evaluation target.

        Args:
            goal_sampler: callable taking ``np_rand`` and returning a goal.
                Defaults to ``maze_env.MazeEnv.goal_sampler`` bound to this env.
            expose_all_qpos: forwarded to ``MazeEnv.__init__``.
            reward_type: forwarded to ``MazeEnv.__init__``.
            v2_resets: if True, ``reset`` re-samples the evaluation target
                before every episode; otherwise the target set here is kept.
            *args: forwarded to ``MazeEnv.__init__``.
            **kwargs: forwarded to both ``MazeEnv.__init__`` and
                ``OfflineEnv.__init__``.
        """
        # Set before any base-class initialisation so that ``reset`` cannot
        # hit a missing attribute if it runs during construction.
        self.v2_resets = v2_resets

        if goal_sampler is None:
            def goal_sampler(np_rand):
                return maze_env.MazeEnv.goal_sampler(self, np_rand)

        maze_env.MazeEnv.__init__(
            self,
            *args,
            manual_collision=False,
            goal_sampler=goal_sampler,
            expose_all_qpos=expose_all_qpos,
            reward_type=reward_type,
            **kwargs
        )
        offline_env.OfflineEnv.__init__(self, **kwargs)

        # We set the target goal here for evaluation.
        self.set_target()

    def reset(self):
        """Reset the env, re-sampling the evaluation target first if ``v2_resets``."""
        if self.v2_resets:
            # antmaze-v0 and -v1 set the evaluation target once per seed, so
            # each run evaluated a single goal and results had high variance.
            # Covering every goal that way needs roughly 50-100 seeds, which
            # may be computationally infeasible. The v2 environments use the
            # same offline dataset as v0 but randomise the goal at every
            # rollout instead. A few seeds evaluated over 100-200 episodes
            # then give a valid estimate of an algorithm's performance.
            self.set_target()
        return super().reset()

    def set_target(self, target_location=None):
        """Set the goal via ``set_target_goal``; None lets it choose a target."""
        return self.set_target_goal(target_location)

    def seed(self, seed=0):
        """Seed the underlying MuJoCo env and return its result.

        The result is forwarded so gym-style callers and wrappers receive the
        list of seeds used. Note the default of ``0``: a call without an
        argument seeds deterministically.
        """
        return mujoco_env.MujocoEnv.seed(self, seed)
        mujoco_env.MujocoEnv.seed(self, seed)


def make_ant_maze_env(**kwargs):
    """Construct an ``AntMazeEnv`` from ``kwargs``, wrapped in ``NormalizedBoxEnv``."""
    return wrappers.NormalizedBoxEnv(AntMazeEnv(**kwargs))
