from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

import os

import numpy as np
from gym import utils
from gym.envs.mujoco import mujoco_env

# Camera attributes that Walker2dEnv.viewer_setup copies onto the viewer's
# camera: ndarray values are written in place, everything else is assigned.
DEFAULT_CAMERA_CONFIG = dict(
    trackbodyid=2,
    distance=4.0,
    lookat=np.array([0.0, 0.0, 1.15]),
    elevation=-20.0,
)


class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    """Walker2d MuJoCo environment loaded from ``walker2d.xml``.

    The observation replaces the first generalized position ``qpos[0]`` with
    its finite-difference rate of change over the last step, followed by the
    remaining positions ``qpos[1:]`` and all velocities ``qvel``.

    The per-step reward is ``ob[0]`` (that finite-difference rate) plus a
    quadratic control cost ``-0.001 * sum(action ** 2)`` plus a constant
    ``1.0``.

    Args:
        terminate_when_unhealthy: When true, ``done`` is reported as soon as
            ``is_healthy`` becomes false; when false, ``done`` is always
            ``False``.
    """

    def __init__(self, terminate_when_unhealthy=True):
        # Position snapshot used by _get_obs; created before the base-class
        # constructor runs so the attribute always exists on the instance.
        self.prev_qpos = None
        self._terminate_when_unhealthy = terminate_when_unhealthy
        # NOTE(review): the second argument (4) is presumably frame_skip;
        # confirm against the MujocoEnv signature in use.
        mujoco_env.MujocoEnv.__init__(self, 'walker2d.xml', 4)
        # Record the constructor argument so that pickling / copying rebuilds
        # the environment with the same termination setting instead of
        # silently falling back to the default.
        utils.EzPickle.__init__(
            self, terminate_when_unhealthy=terminate_when_unhealthy
        )

    @property
    def done(self):
        """Whether the episode should end on the current simulator state.

        Returns ``not self.is_healthy`` when termination on unhealthy states
        is enabled, otherwise ``False`` (``is_healthy`` is then not
        evaluated).
        """
        done = not self.is_healthy if self._terminate_when_unhealthy else False
        return done

    def step(self, action):
        """Advance the simulation by one environment step.

        Args:
            action: Control input passed to ``do_simulation``; it must be
                accepted by ``np.square`` because the control cost is
                computed from it.

        Returns:
            A tuple ``(ob, reward, done, info)`` where ``ob`` is the result
            of ``_get_obs``, ``done`` is the ``done`` property and ``info``
            is an empty dict.
        """
        # Snapshot the positions before simulating so that _get_obs can
        # difference the new qpos[0] against the pre-step value.
        self.prev_qpos = np.copy(self.sim.data.qpos.flat)
        self.do_simulation(action, self.frame_skip)
        ob = self._get_obs()

        reward_ctrl = -0.001 * np.square(action).sum()
        # The ob[2] penalty carries a weight of 0.0, so for finite
        # observations it contributes nothing; it is kept as written.
        reward_run = ob[0] - 0.0 * np.square(ob[2])
        # Constant +1.0 bonus every step (the healthy_reward property is not
        # consulted here).
        reward = reward_run + reward_ctrl + 1.0

        done = self.done
        return ob, reward, done, {}

    @property
    def healthy_reward(self):
        """Return 1.0 if healthy or if termination is enabled, else 0.0.

        This property is not used by ``step`` in this class.
        """
        return (
            float(self.is_healthy or self._terminate_when_unhealthy)
            * 1.0
        )

    @property
    def is_healthy(self):
        """Whether ``qpos[1]`` and ``qpos[2]`` lie inside the allowed ranges.

        ``qpos[1]`` (called ``z`` here) must satisfy ``0.8 < z < 2.0`` and
        ``qpos[2]`` (called ``angle`` here) must satisfy
        ``-1.0 < angle < 1.0``; both bounds are exclusive.
        """
        z, angle = self.sim.data.qpos[1:3]

        min_z, max_z = 0.8, 2.0
        min_angle, max_angle = -1.0, 1.0

        healthy_z = min_z < z < max_z
        healthy_angle = min_angle < angle < max_angle
        is_healthy = healthy_z and healthy_angle

        return is_healthy

    def _get_obs(self):
        """Build the observation vector from the current simulator state.

        Layout: ``[(qpos[0] - prev_qpos[0]) / dt, qpos[1:], qvel]``.
        Requires ``self.prev_qpos`` to have been set by ``step`` or one of
        the reset methods.
        """
        return np.concatenate([
            (self.sim.data.qpos.flat[:1] - self.prev_qpos[:1]) / self.dt,
            self.sim.data.qpos.flat[1:],
            self.sim.data.qvel.flat,
        ])

    def reset_model(self):
        """Reset to the initial state plus small Gaussian noise.

        Adds zero-mean normal noise with standard deviation 0.005 to
        ``init_qpos`` and ``init_qvel``, applies the state, and returns the
        resulting observation.
        """
        # NOTE(review): the noise comes from the global ``np.random`` state,
        # not from a per-environment generator -- confirm whether seeding
        # through the environment is expected to control it.
        qpos = self.init_qpos + np.random.normal(loc=0, scale=0.005, size=self.model.nq)
        qvel = self.init_qvel + np.random.normal(loc=0, scale=0.005, size=self.model.nv)
        self.set_state(qpos, qvel)
        # Start with prev_qpos equal to the current positions so the first
        # observation entry is zero right after a reset.
        self.prev_qpos = np.copy(self.sim.data.qpos.flat)
        return self._get_obs()

    def reset_model_to_certain_state(self, state):
        """Reset the simulator to an explicitly given state (no noise).

        Args:
            state: Sliceable sequence whose first ``self.model.nq`` entries
                are the generalized positions and whose remaining entries
                are the generalized velocities. The two slices are passed
                unchanged to ``set_state``.

        Returns:
            The observation for the new state; its first entry is zero
            because ``prev_qpos`` is reset to the new positions.
        """
        # Split at the model's position count, matching reset_model, rather
        # than at a hard-coded index.
        nq = self.model.nq
        qpos = state[:nq]
        qvel = state[nq:]
        self.set_state(qpos, qvel)
        self.prev_qpos = np.copy(self.sim.data.qpos.flat)
        return self._get_obs()

    def viewer_setup(self):
        """Apply ``DEFAULT_CAMERA_CONFIG`` to the viewer's camera.

        Array-valued settings are copied element-wise into the existing
        camera attribute; all other settings are assigned directly.
        """
        for key, value in DEFAULT_CAMERA_CONFIG.items():
            if isinstance(value, np.ndarray):
                getattr(self.viewer.cam, key)[:] = value
            else:
                setattr(self.viewer.cam, key, value)
