from gym.envs.mujoco import ant
import numpy as np

class AntAngle(ant.AntEnv):
    """Ant environment rewarded for moving along a fixed direction.

    The ``reward_forward`` term of the parent reward is replaced by the
    x/y velocity (first two entries of ``qvel``) projected onto the unit
    vector ``(cos(angle), sin(angle))``.  The parent's termination signal is
    ignored: an episode ends only once ``max_episode_steps`` steps have been
    taken since the last ``reset``.
    """

    def __init__(self, angle=0.0, max_episode_steps=1000):
        """Create the environment.

        Args:
            angle: Target direction of travel.  It is fed to ``np.cos`` and
                ``np.sin``, so it is interpreted in radians.
            max_episode_steps: Number of steps after which ``step`` reports
                ``done``.
        """
        self._goal_index = 0
        self.angle = angle
        self.x_coef = np.cos(angle)
        self.y_coef = np.sin(angle)
        self._max_episode_steps = max_episode_steps
        # Everything ``step`` reads must exist before the parent constructor
        # runs: the classic gym MujocoEnv constructor is expected to call
        # ``step`` once while setting itself up.
        self.t = 0
        super(AntAngle, self).__init__()
        # Discard any step counted during construction so that a fresh
        # instance starts its first episode at t == 0 even without reset().
        self.t = 0

    def step(self, a):
        """Advance one step without terminating when the ant flips over.

        Args:
            a: Action forwarded unchanged to the parent ``step``.

        Returns:
            Tuple ``(obs, reward, done, info)``.  ``reward`` and ``done`` come
            from ``_get_reward``; ``info`` holds only ``x_velocity``,
            ``y_velocity`` and ``additional_r`` (the parent reward minus its
            ``reward_forward`` term).
        """
        # The parent's ``done`` flag is deliberately dropped.
        obs, r, _, info = super(AntAngle, self).step(a)
        self.t += 1
        # _get_obs() lays the observation out as [qpos, qvel, cfrc_ext], so
        # the velocities start right after the qpos entries.
        vel_start = self.data.qpos.size
        info_new = {
            'x_velocity': obs[vel_start],
            'y_velocity': obs[vel_start + 1],
            'additional_r': r - info['reward_forward'],
        }
        r_new, done = self._get_reward(info_new)
        return (obs, r_new, done, info_new)

    def _get_obs(self):
        """Return the observation, including the full ``qpos`` vector.

        Layout: all of ``qpos``, then all of ``qvel``, then ``cfrc_ext``
        clipped to ``[-1, 1]``, flattened into one array.
        """
        return np.concatenate([
            self.data.qpos.flat,
            self.data.qvel.flat,
            np.clip(self.data.cfrc_ext, -1, 1).flat,
        ])

    def _state(self):
        """Return the current observation (alias for ``_get_obs``)."""
        return self._get_obs()

    def _get_reward(self, info):
        """Compute the directional reward and the time-limit ``done`` flag.

        Args:
            info: Mapping with ``additional_r``, ``x_velocity`` and
                ``y_velocity`` entries.

        Returns:
            Tuple ``(reward, done)`` where ``reward`` is ``additional_r`` plus
            the velocity projected onto the target direction, and ``done`` is
            true once ``t`` has reached ``max_episode_steps``.
        """
        directional_velocity = (info['x_velocity'] * self.x_coef
                                + info['y_velocity'] * self.y_coef)
        reward = info['additional_r'] + directional_velocity
        done = (self.t >= self._max_episode_steps)
        return reward, done

    def reset(self, **kwargs):
        """Reset the step counter and the underlying simulation.

        Args:
            **kwargs: Forwarded unchanged to the parent ``reset`` (for gym
                versions that accept e.g. ``seed``); may be empty.

        Returns:
            Whatever the parent ``reset`` returns.
        """
        self.t = 0
        return super(AntAngle, self).reset(**kwargs)