from gym.envs.mujoco import half_cheetah
import numpy as np

class CheetahJump(half_cheetah.HalfCheetahEnv):
    """HalfCheetah variant with a capped run reward plus a z-height bonus.

    Per step the reward is::

        additional_r + min(x_velocity, forward_vel_cap) + z_coef * z_position

    where ``additional_r`` is the parent reward minus its ``reward_run`` term.
    ``done`` becomes True once ``max_episode_steps`` steps have been taken
    since the last ``reset()``, or earlier if the parent env reports done.

    Args:
        z_coef: Weight of the z-position term; the default 0.0 disables it.
        max_episode_steps: Step count at which ``done`` becomes True.
        forward_vel_cap: Upper cap applied to the forward-velocity term.
    """

    def __init__(self, z_coef=0.0, max_episode_steps=200, forward_vel_cap=3.0):
        # Everything ``step`` reads is assigned before the parent constructor
        # runs -- presumably because the parent may already call ``step()``
        # during construction; keep this ordering.
        self._goal_index = 0  # not read in this class; kept for external users
        self.z_coef = z_coef
        self._max_episode_steps = max_episode_steps
        self.forward_vel_cap = forward_vel_cap
        self.t = 0  # steps taken since the last reset()
        # z observation captured in reset(); recorded but not used by the reward.
        self.init_z = 0.0
        super(CheetahJump, self).__init__()

    def step(self, a):
        """Advance one step and return ``(obs, reward, done, info)``.

        The parent reward is replaced by the reward from ``_get_reward``, and
        ``info`` is rebuilt to hold ``z_position``, ``x_velocity`` and
        ``additional_r`` (the parent reward minus its ``reward_run`` term).
        """
        obs, r, parent_done, info = super(CheetahJump, self).step(a)
        self.t += 1
        info_new = {
            # assumes the HalfCheetah observation layout where obs[0] is the
            # root z position and obs[8] the root x velocity -- TODO confirm
            'z_position': obs[0],
            'x_velocity': obs[8],
            # Everything in the parent reward except its run term.
            'additional_r': r - info['reward_run'],
        }
        r_new, time_up = self._get_reward(info_new)
        # Honour termination from the parent env as well as the step limit,
        # instead of silently discarding the parent's flag.
        done = parent_done or time_up
        return (obs, r_new, done, info_new)

    def _get_reward(self, info):
        """Compute ``(reward, time_up)`` from an info dict built by ``step``.

        Args:
            info: dict with ``additional_r``, ``x_velocity`` and
                ``z_position`` entries.

        Returns:
            ``reward`` = additional_r + min(x_velocity, forward_vel_cap)
            + z_coef * z_position, and ``time_up`` = whether ``self.t`` has
            reached ``self._max_episode_steps``.
        """
        # MOPO states the forward reward as max(v_x, 3); a cap, min(v_x, 3),
        # is assumed here instead.
        forward_reward = min(info['x_velocity'], self.forward_vel_cap)
        # NOTE(review): the absolute z is used; ``self.init_z`` (recorded in
        # reset) is not subtracted -- confirm whether that is intended.
        z_reward = self.z_coef * info['z_position']
        reward = info['additional_r'] + forward_reward + z_reward
        time_up = self.t >= self._max_episode_steps
        return reward, time_up

    def reset(self):
        """Reset the env and step counter; record the first z observation."""
        self.t = 0
        obs = super(CheetahJump, self).reset()
        self.init_z = obs[0]
        return obs