"""Wrapper for creating the swimmer environment."""

import math
import os

import mujoco_py
import numpy as np
from gym import utils
from gym.envs.mujoco import mujoco_env

from d4rl import offline_env
from d4rl.locomotion import (goal_reaching_env, maze_env, mujoco_goal_env,
                             wrappers)

# Path to the 'assets' directory that sits next to the mujoco_goal_env
# module file; the swimmer model XML is looked up inside it.
MY_ASSETS_DIR = os.path.join(os.path.dirname(mujoco_goal_env.__file__),
                             'assets')


class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    """Basic swimmer locomotion environment.

    The per-step reward is the change of the first position coordinate
    divided by ``self.dt`` (forward progress), minus a small quadratic
    penalty on the action. ``step`` always reports ``done=False``.
    """
    FILE = os.path.join(MY_ASSETS_DIR, 'swimmer.xml')

    def __init__(self,
                 file_path=None,
                 expose_all_qpos=False,
                 non_zero_reset=False):
        """Load the MuJoCo model and initialise the environment.

        Args:
            file_path: Path of the model XML file; ``FILE`` is used when
                this is ``None``.
            expose_all_qpos: When true, observations contain the first five
                ``qpos`` entries; otherwise only entries 2, 3 and 4.
            non_zero_reset: When true, ``reset_model`` overwrites the first
                two ``qpos`` entries with ``self._get_reset_location()``.
        """
        if file_path is None:
            file_path = self.FILE

        # NOTE(review): these flags are assigned before the base initialiser
        # runs, presumably because it calls back into _get_obs/reset_model;
        # keep this ordering.
        self._expose_all_qpos = expose_all_qpos

        self._non_zero_reset = non_zero_reset

        mujoco_env.MujocoEnv.__init__(self,
                                      file_path,
                                      5,
                                      mujoco_bindings="mujoco_py")
        utils.EzPickle.__init__(self)

    @staticmethod
    def _version_tuple(version):
        """Return the leading numeric components of a dotted version string.

        For example ``'1.50.1.68'`` becomes ``(1, 50, 1, 68)``. Within each
        dot-separated piece only the leading ASCII digits are used, and
        parsing stops at the first piece that does not start with a digit.
        """
        numbers = []
        for piece in str(version).split('.'):
            digits = ''
            for char in piece:
                if char not in '0123456789':
                    break
                digits += char
            if not digits:
                break
            numbers.append(int(digits))
        return tuple(numbers)

    @property
    def physics(self):
        """Object whose ``data`` attribute exposes ``qpos`` and ``qvel``."""
        # Check mujoco version is at least 1.50 to pick the object that
        # carries the PyMjData used for getting and setting position/velocity.
        # Check https://github.com/openai/mujoco-py/issues/80 for updates to api.
        # Versions are compared as integer tuples: comparing the raw strings
        # is lexicographic and would, e.g., rank '1.100' below '1.50'.
        if self._version_tuple(mujoco_py.get_version()) >= (1, 50):
            return self.sim
        return self.model

    def _step(self, a):
        """Alias of ``step``."""
        return self.step(a)

    def step(self, a):
        """Apply action ``a`` for ``frame_skip`` simulation steps.

        Returns:
            A tuple ``(observation, reward, done, info)`` where ``done`` is
            always ``False`` and ``info`` holds the two reward components
            under the keys ``reward_fwd`` and ``reward_ctrl``.
        """
        ctrl_cost_coeff = 0.0001
        xposbefore = self.sim.data.qpos[0]
        self.do_simulation(a, self.frame_skip)
        xposafter = self.sim.data.qpos[0]
        # Forward reward: displacement of qpos[0] over the step, per unit time.
        reward_fwd = (xposafter - xposbefore) / self.dt
        # Control cost: scaled sum of squared action components.
        reward_ctrl = -ctrl_cost_coeff * np.square(a).sum()
        reward = reward_fwd + reward_ctrl
        ob = self._get_obs()
        return ob, reward, False, dict(reward_fwd=reward_fwd,
                                       reward_ctrl=reward_ctrl)

    def _get_obs(self):
        """Return the observation: selected ``qpos`` entries then ``qvel[:5]``.

        With ``expose_all_qpos`` the first five ``qpos`` entries are included;
        otherwise the first two are dropped and only entries 2..4 are used.
        """
        if self._expose_all_qpos:
            obs = np.concatenate([
                self.physics.data.qpos.flat[:5],  # Ensures only swimmer obs.
                self.physics.data.qvel.flat[:5],
            ])
        else:
            obs = np.concatenate([
                self.physics.data.qpos.flat[2:5],
                self.physics.data.qvel.flat[:5],
            ])

        return obs

    def reset_model(self):
        """Reset to a randomly perturbed initial state; return the observation.

        Positions receive uniform noise in [-0.1, 0.1] and velocities receive
        Gaussian noise scaled by 0.1. Entries from index 5 onwards are then
        restored to their initial position and zero velocity.
        """
        qpos = self.init_qpos + self.np_random.uniform(
            size=self.model.nq, low=-.1, high=.1)
        qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1

        if self._non_zero_reset:
            # Now the reset is supposed to be to a non-zero location.
            # NOTE(review): _get_reset_location is not defined in this class;
            # presumably a subclass or mixin provides it - confirm.
            reset_location = self._get_reset_location()
            qpos[:2] = reset_location

        # Set everything other than swimmer to original position and 0 velocity.
        qpos[5:] = self.init_qpos[5:]
        qvel[5:] = 0.
        self.set_state(qpos, qvel)
        return self._get_obs()

    def viewer_setup(self):
        """Point the viewer camera at (4, 4, 8) with an elevation of -90."""
        self.viewer.cam.lookat[0] = 4.
        self.viewer.cam.lookat[1] = 4.
        self.viewer.cam.lookat[2] = 8.
        self.viewer.cam.elevation = -90

    def get_xy(self):
        """Return the first two ``qpos`` entries (a slice, not a copy)."""
        return self.physics.data.qpos[:2]

    def set_xy(self, xy):
        """Set the first two ``qpos`` entries to ``xy``; keep everything else."""
        qpos = np.copy(self.physics.data.qpos)
        qpos[0] = xy[0]
        qpos[1] = xy[1]
        qvel = self.physics.data.qvel
        self.set_state(qpos, qvel)


class GoalReachingSwimmerEnv(goal_reaching_env.GoalReachingEnv, SwimmerEnv):
    """Swimmer whose reward comes from reaching a sampled goal."""
    BASE_ENV = SwimmerEnv

    def __init__(self,
                 goal_sampler=goal_reaching_env.disk_goal_sampler,
                 file_path=None,
                 expose_all_qpos=False,
                 non_zero_reset=False,
                 eval=False,
                 reward_type="dense",
                 **kwargs):
        """Initialise the goal-reaching part, then the swimmer part.

        Args:
            goal_sampler: Passed positionally to ``GoalReachingEnv.__init__``.
            file_path: Forwarded to ``SwimmerEnv.__init__``.
            expose_all_qpos: Forwarded to ``SwimmerEnv.__init__``.
            non_zero_reset: Forwarded to ``SwimmerEnv.__init__``.
            eval: Forwarded to ``GoalReachingEnv.__init__``.
            reward_type: Forwarded to ``GoalReachingEnv.__init__``.
            **kwargs: Accepted but not used by this initialiser.
        """
        # Goal-reaching state is set up first, mirroring the base-class order.
        goal_reaching_env.GoalReachingEnv.__init__(
            self, goal_sampler, eval=eval, reward_type=reward_type)

        swimmer_options = {
            'file_path': file_path,
            'expose_all_qpos': expose_all_qpos,
            'non_zero_reset': non_zero_reset,
        }
        SwimmerEnv.__init__(self, **swimmer_options)


class SwimmerMazeEnv(maze_env.MazeEnv, GoalReachingSwimmerEnv,
                     offline_env.OfflineEnv):
    """Goal-reaching swimmer placed inside a maze."""
    LOCOMOTION_ENV = GoalReachingSwimmerEnv

    def __init__(self,
                 goal_sampler=None,
                 expose_all_qpos=True,
                 reward_type='dense',
                 *args,
                 **kwargs):
        """Initialise the maze and offline-env parts, then pick a target.

        Args:
            goal_sampler: Callable taking a random generator; when ``None``
                a wrapper around ``maze_env.MazeEnv.goal_sampler`` bound to
                this instance is used.
            expose_all_qpos: Forwarded to ``MazeEnv.__init__``.
            reward_type: Forwarded to ``MazeEnv.__init__``.
            *args: Extra positional arguments for ``MazeEnv.__init__``.
            **kwargs: Forwarded to both ``MazeEnv.__init__`` and
                ``OfflineEnv.__init__``.
        """
        if goal_sampler is None:

            def _maze_goal_sampler(np_rand):
                # Call MazeEnv's sampler explicitly on this instance.
                return maze_env.MazeEnv.goal_sampler(self, np_rand)

            goal_sampler = _maze_goal_sampler

        maze_env.MazeEnv.__init__(
            self,
            *args,
            manual_collision=False,
            goal_sampler=goal_sampler,
            expose_all_qpos=expose_all_qpos,
            reward_type=reward_type,
            **kwargs)
        offline_env.OfflineEnv.__init__(self, **kwargs)

        # Choose an initial target once construction is complete.
        self.set_target()

    def set_target(self, target_location=None):
        """Delegate to ``set_target_goal`` and return its result."""
        result = self.set_target_goal(target_location)
        return result


def make_swimmer_maze_env(**kwargs):
    """Build a ``SwimmerMazeEnv`` from ``kwargs``, wrapped in ``NormalizedBoxEnv``."""
    return wrappers.NormalizedBoxEnv(SwimmerMazeEnv(**kwargs))
