# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper for creating the point environment."""

import math
import os

import mujoco_py
import numpy as np
from gym import utils
from gym.envs.mujoco import mujoco_env

from d4rl.locomotion import (goal_reaching_env, maze_env, mujoco_goal_env,
                             wrappers)

# Directory holding the model assets (e.g. point.xml), located next to this
# file after resolving symlinks in this file's path.
MY_ASSETS_DIR = os.path.join(os.path.split(os.path.realpath(__file__))[0], 'assets')


class PointEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    """MuJoCo point-mass environment driven by clipped (dx, dy) position steps.

    ``step`` reads ``self._maze_map`` and ``self._xy_to_rowcol``, and
    ``reset_model`` (with ``non_zero_reset``) calls
    ``self._get_reset_location``. None of these are defined here, so this class
    is presumably combined with ``maze_env.MazeEnv`` (as ``PointMazeEnv``
    does) -- TODO confirm.
    """
    FILE = os.path.join(MY_ASSETS_DIR, 'point.xml')

    def __init__(self,
                 file_path=None,
                 expose_all_qpos=False,
                 non_zero_reset=False):
        """Load the point model.

        Args:
          file_path: Model file path; defaults to ``FILE``.
          expose_all_qpos: If True, observations contain qpos[:3]; otherwise
            only qpos[2:3]. qvel[:3] is always appended.
          non_zero_reset: If True, ``reset_model`` places the point at
            ``self._get_reset_location()`` after the random perturbation.
        """
        if file_path is None:
            file_path = self.FILE

        # Stored before MujocoEnv.__init__, presumably because that
        # constructor may already query observations -- NOTE(review): confirm.
        self._expose_all_qpos = expose_all_qpos
        self._non_zero_reset = non_zero_reset

        # The second argument is presumably frame_skip (read back as
        # self.frame_skip in step()).
        mujoco_env.MujocoEnv.__init__(self, file_path, 1)
        # NOTE(review): EzPickle gets no constructor arguments, so an
        # unpickled env is presumably rebuilt with default options.
        utils.EzPickle.__init__(self)

    @property
    def physics(self):
        """Return the object whose ``.data`` holds qpos/qvel for this mujoco_py."""
        # mujoco_py >= 1.50 keeps the state on the MjSim (self.sim); older
        # versions used self.model. See
        # https://github.com/openai/mujoco-py/issues/80 for the API change.
        # NOTE(review): this is a lexicographic string comparison, not a
        # general version comparison; it only works for the version strings
        # mujoco_py has actually used.
        if mujoco_py.get_version() >= '1.50':
            return self.sim
        else:
            return self.model

    def _step(self, a):
        # Alias of step(); presumably kept for older gym APIs that call _step.
        return self.step(a)

    def step(self, action):
        """Move the point by a clipped, scaled (dx, dy) unless blocked by the maze.

        Args:
          action: Array whose first two entries are the x and y displacement
            commands. They are clipped to [-1, 1] and scaled by 0.1.

        Returns:
          ``(obs, 0, False, {})``. The reward is always 0 and the episode never
          terminates here.
        """
        action = np.clip(action, -1., 1.)
        action = 0.1 * action
        qpos = np.copy(self.physics.data.qpos)
        # Compute increment in each direction.
        dx = action[0]
        dy = action[1]
        # Collision detection: move only if the destination maze cell is not
        # 1 (presumably a wall). Otherwise the physics is not stepped at all.
        rowcol = self._xy_to_rowcol((qpos[0] + dx, qpos[1] + dy))
        if self._maze_map[rowcol[0]][rowcol[1]] != 1:
            # Ensure that the robot is within reasonable range.
            # NOTE(review): the -2 / "4 * size - 10" bounds appear tied to a
            # particular maze scaling -- confirm.
            qpos[0] = np.clip(qpos[0] + dx, -2,
                              4 * len(self._maze_map[0]) - 10)
            qpos[1] = np.clip(qpos[1] + dy, -2, 4 * len(self._maze_map) - 10)
            qvel = self.physics.data.qvel
            self.set_state(qpos, qvel)
            for _ in range(0, self.frame_skip):
                self.physics.step()
            self.clip_velocity()
        next_obs = self._get_obs()
        reward = 0
        done = False
        info = {}
        return next_obs, reward, done, info

    def clip_velocity(self):
        """Zero all joint velocities while keeping the current positions."""
        # Clipping to [-0, 0] maps every finite velocity component to 0.
        qvel = np.clip(self.physics.data.qvel, -0, 0)
        self.set_state(self.physics.data.qpos, qvel)

    def _get_obs(self):
        """Return qpos[:3] (or just qpos[2:3]) concatenated with qvel[:3]."""
        if self._expose_all_qpos:
            return np.concatenate([
                self.physics.data.qpos.flat[:3],  # Only point-relevant coords.
                self.physics.data.qvel.flat[:3]
            ])
        return np.concatenate([
            self.physics.data.qpos.flat[2:3], self.physics.data.qvel.flat[:3]
        ])

    def reset_model(self):
        """Randomly perturb the point's initial state and return the observation.

        qpos gets uniform noise in [-0.1, 0.1] and qvel gets Gaussian noise
        with std 0.1. Entries from index 2 onward are then restored to
        init_qpos, and their velocities are set to 0.
        """
        qpos = self.init_qpos + self.np_random.uniform(
            size=self.physics.model.nq, low=-.1, high=.1)
        qvel = self.init_qvel + self.np_random.randn(
            self.physics.model.nv) * .1

        if self._non_zero_reset:
            # Reset to a non-zero (x, y) location instead of the origin.
            reset_location = self._get_reset_location()
            qpos[:2] = reset_location

        # Set everything other than point to original position and 0 velocity.
        qpos[2:] = self.init_qpos[2:]
        qvel[2:] = 0.
        self.set_state(qpos, qvel)
        return self._get_obs()

    def viewer_setup(self):
        """Place the camera at (4, 4, 8) looking straight down (elevation -90)."""
        self.viewer.cam.lookat[0] = 4.
        self.viewer.cam.lookat[1] = 4.
        self.viewer.cam.lookat[2] = 8.
        self.viewer.cam.elevation = -90

    def get_xy(self):
        """Return a copy of the point's (x, y) position, qpos[:2].

        The copy keeps callers from holding a slice of the simulator's qpos
        buffer, which set_state/step later overwrite.
        """
        return np.copy(self.physics.data.qpos[:2])

    def set_xy(self, xy):
        """Set qpos[0:2] to ``xy``, keeping the other positions and velocities."""
        qpos = np.copy(self.physics.data.qpos)
        qpos[0] = xy[0]
        qpos[1] = xy[1]
        qvel = self.physics.data.qvel
        self.set_state(qpos, qvel)

    def reset_to_state(self, state):
        """Reset the simulator and load the point's first three qpos/qvel entries.

        Args:
          state: Sequence whose ``[0:3]`` entries are written to qpos[:3] and
            whose ``[3:6]`` entries are written to qvel[:3]. Values are cast to
            the observation space's dtype.

        Returns:
          The observation after the state has been set.
        """
        self.sim.reset()
        reset_state = np.array(state).astype(self.observation_space.dtype)
        # Copy so that init_qpos/init_qvel (read again by reset_model) are not
        # overwritten in place by the slice assignments below.
        qpos = np.copy(self.init_qpos)
        qvel = np.copy(self.init_qvel)
        qpos[:3] = reset_state[0:3]
        qvel[:3] = reset_state[3:6]
        self.set_state(qpos, qvel)
        return self._get_obs()


class GoalReachingPointEnv(goal_reaching_env.GoalReachingEnv, PointEnv):
    """Point locomotion rewarded for goal-reaching.

    The goal-reaching mixin is initialised first, then ``PointEnv``. Extra
    keyword arguments are accepted and ignored.
    """
    BASE_ENV = PointEnv

    def __init__(self,
                 goal_sampler=goal_reaching_env.disk_goal_sampler,
                 file_path=None,
                 expose_all_qpos=False,
                 non_zero_reset=False,
                 eval=False,
                 reward_type='dense',
                 **kwargs):
        goal_options = dict(eval=eval, reward_type=reward_type)
        point_options = dict(file_path=file_path,
                             expose_all_qpos=expose_all_qpos,
                             non_zero_reset=non_zero_reset)

        # Goal-reaching setup runs before PointEnv.__init__, which in turn
        # calls MujocoEnv.__init__.
        goal_reaching_env.GoalReachingEnv.__init__(self, goal_sampler,
                                                   **goal_options)
        PointEnv.__init__(self, **point_options)


# class GoalReachingPointDictEnv(goal_reaching_env.GoalReachingDictEnv, PointEnv):
#   """Point locomotion for goal reaching in a dictionary-compatible format."""
#   BASE_ENV = PointEnv

#   def __init__(self, goal_sampler=goal_reaching_env.disk_goal_sampler,
#                file_path=None,
#                expose_all_qpos=False):
#     goal_reaching_env.GoalReachingDictEnv.__init__(self, goal_sampler)
#     PointEnv.__init__(self,
#                     file_path=file_path,
#                     expose_all_qpos=expose_all_qpos)


class PointMazeEnv(maze_env.MazeEnv, GoalReachingPointEnv):
    """Point navigating a maze.

    Combines ``maze_env.MazeEnv`` with ``GoalReachingPointEnv``. The latter is
    also exposed as ``LOCOMOTION_ENV``.
    """
    LOCOMOTION_ENV = GoalReachingPointEnv

    def __init__(self,
                 goal_sampler=None,
                 expose_all_qpos=True,
                 *args,
                 **kwargs):
        """Build the maze environment and immediately set a target.

        Args:
          goal_sampler: Callable taking ``np_rand`` (presumably a random
            generator). Defaults to ``maze_env.MazeEnv.goal_sampler`` bound to
            this instance.
          expose_all_qpos: Forwarded to ``MazeEnv.__init__``.
          *args, **kwargs: Forwarded to ``MazeEnv.__init__``, which is also
            given ``manual_collision=False``.
        """
        if goal_sampler is None:
            def _maze_goal_sampler(np_rand):
                # Call MazeEnv's implementation explicitly rather than
                # whatever ``goal_sampler`` the MRO would resolve.
                return maze_env.MazeEnv.goal_sampler(self, np_rand)

            goal_sampler = _maze_goal_sampler

        maze_env.MazeEnv.__init__(self,
                                  *args,
                                  goal_sampler=goal_sampler,
                                  expose_all_qpos=expose_all_qpos,
                                  manual_collision=False,
                                  **kwargs)

        self.set_target()

    def set_target(self, target_location=None):
        """Forward ``target_location`` to ``set_target_goal`` and return its result."""
        return self.set_target_goal(target_location)


def create_goal_reaching_policy(obs_to_goal=lambda obs: obs[-2:],
                                obs_to_ori=lambda obs: obs[0]):
    """A hard-coded policy for reaching a goal position.

    Args:
      obs_to_goal: Callable mapping an observation to the goal's (x, y)
        offset. By default this is the last two observation entries.
      obs_to_ori: Callable mapping an observation to the agent's heading in
        radians. By default this is the first observation entry.

    Returns:
      ``policy_fn(obs)``, which returns ``np.array([velocity,
      angular_velocity])``:
        - ``velocity`` is in [-1, 1]; a negative value means driving in
          reverse because the goal is behind the agent.
        - ``angular_velocity`` is clipped to [-1, 1]; positive values turn
          counter-clockwise toward the goal.
        - When the goal offset is exactly (0, 0), both components are 0.
    """

    def policy_fn(obs):
        goal_x, goal_y = obs_to_goal(obs)
        goal_dist = np.linalg.norm([goal_x, goal_y])
        if goal_dist == 0:
            # Already at the goal. Below, speed would be 0 and time_left
            # would evaluate 0 / 0, making the angular velocity NaN.
            return np.zeros(2)
        goal_ori = np.arctan2(goal_y, goal_x)
        ori = obs_to_ori(obs)
        # Heading error wrapped into [0, 2*pi); small positive values mean
        # the goal is counter-clockwise of the heading.
        ori_diff = (goal_ori - ori) % (2 * np.pi)

        radius = goal_dist / 2. / max(0.1, np.abs(np.sin(ori_diff)))
        rotation_left = (2 * ori_diff) % np.pi
        circumference_left = max(goal_dist, radius * rotation_left)

        speed = min(circumference_left * 2., 1.0)
        velocity = speed
        # Goal is behind the agent: drive in reverse.
        if np.pi / 2 < ori_diff < 3 * np.pi / 2:
            velocity *= -1

        time_left = min(circumference_left / (speed * 0.2), 10.)
        # Signed turn that points the driving direction (front or back) at
        # the goal; positive means counter-clockwise.
        signed_ori_diff = ori_diff
        if signed_ori_diff >= 3 * np.pi / 2:
            # The goal is clockwise of the heading by 2*pi - ori_diff, so the
            # signed error must be negative, as in the other branches.
            signed_ori_diff = signed_ori_diff - 2 * np.pi
        elif np.pi / 2 < signed_ori_diff < 3 * np.pi / 2:
            # Driving in reverse: measure the error from the back direction.
            signed_ori_diff = signed_ori_diff - np.pi

        angular_velocity = signed_ori_diff / time_left
        angular_velocity = np.clip(angular_velocity * 2., -1., 1.)

        return np.array([velocity, angular_velocity])

    return policy_fn


def create_maze_navigation_policy(maze_env):
    """Creates a hard-coded policy to navigate a maze.

    Args:
      maze_env: Environment exposing ``_expose_all_qpos`` and
        ``create_navigation_policy``.

    Returns:
      Whatever ``maze_env.create_navigation_policy`` returns when given a
      goal-reaching function ``(obs, goal) -> action``.
    """
    # Either way the selected entry is qpos[2]: with expose_all_qpos the
    # observation starts with qpos[:3]; otherwise it starts with qpos[2:3].
    ori_index = 0
    if maze_env._expose_all_qpos:
        ori_index = 2

    def obs_to_ori(obs):
        return obs[ori_index]

    goal_reaching_policy = create_goal_reaching_policy(obs_to_ori=obs_to_ori)

    def goal_reaching_policy_fn(obs, goal):
        # The goal is appended after the observation, where the default
        # obs_to_goal (last two entries) reads it -- assumes a 2-D goal.
        return goal_reaching_policy(np.concatenate([obs, goal]))

    return maze_env.create_navigation_policy(goal_reaching_policy_fn)


def make_point_maze_env(**kwargs):
    """Construct ``PointMazeEnv(**kwargs)`` wrapped in ``wrappers.NormalizedBoxEnv``."""
    return wrappers.NormalizedBoxEnv(PointMazeEnv(**kwargs))
