import gym
from gym.envs.mujoco import ant, mujoco_env
import os
import re
from gym import utils
import numpy as np

class AntCustomEnv(ant.AntEnv):
    """Ant environment rewarding per-step progress of the ant's xy position.

    Two reward modes are supported:

    * ``uniform=False`` (default): progress is the change in the projection of
      the xy position onto the unit vector at ``angle`` degrees, measured from
      the +x axis toward +y.
    * ``uniform=True``: progress is the change in Euclidean distance from
      ``self.origin``, so moving away in any direction is rewarded.

    Episodes end only after ``max_episode_len`` steps; the parent's own
    termination signal is discarded (see ``step``).

    Args:
        gear_ratio: Currently unused. The code that rewrote actuator gears in
            the model XML has been disabled, so gearing comes from
            ``ant_custom.xml`` as-is. Kept for interface compatibility.
        max_episode_len: Number of steps after which ``done`` becomes True.
        angle: Target direction in degrees; must satisfy ``0 <= angle < 360``.
        uniform: If True, reward distance from the origin instead of progress
            along ``angle``.

    Raises:
        ValueError: If ``angle`` is outside ``[0, 360)``.
    """

    def __init__(self, gear_ratio=30, max_episode_len=200, angle=0, uniform=False):
        self._max_episode_steps = max_episode_len
        self.t = 0  # steps taken in the current episode
        print('WARNING: GEAR RATIO FOR ANT_CUSTOM_ENV MAY HAVE BEEN SET INCORRECTLY!!!')
        # NOTE(review): _get_obs returns the full qpos, and _get_reward reads
        # obs[:2] as the xy position, so this class does NOT hide xy itself.
        # Confirm whether a wrapper elsewhere is what hides it.
        print('WARNING: HIDING XY COORDINATES!!!')
        print('The angle was set to {}.'.format(angle))
        # gear_ratio is ignored: the XML-rewriting code that applied it was
        # removed, so the model loads ant_custom.xml unchanged.
        self.origin = [0., 0.]
        self.distance = 0.  # "distance" value at the previous step (reward baseline)
        # Input direction range is [0, 360). Use an explicit check rather than
        # `assert`, which is stripped when Python runs with -O.
        if not 0 <= angle < 360:
            raise ValueError('angle must be in [0, 360), got {}'.format(angle))
        rad = (angle / 180.0) * np.pi
        self.direction = np.array([np.cos(rad), np.sin(rad)])
        self.uniform = uniform
        # All reward/episode state above is set before the parent constructor,
        # presumably because it may call step() once (TODO confirm for the gym
        # version in use).
        mujoco_env.MujocoEnv.__init__(self, 'ant_custom.xml', 5)
        # EzPickle rebuilds the env from these recorded arguments on unpickle;
        # passing none would silently restore the default angle/uniform/length.
        utils.EzPickle.__init__(
            self,
            gear_ratio=gear_ratio,
            max_episode_len=max_episode_len,
            angle=angle,
            uniform=uniform,
        )

    def step(self, a):
        """Advance the simulation one control step with action ``a``.

        Modified to not terminate when the ant jumps and flips over: the
        parent's reward and ``done`` flag are discarded and replaced by
        ``_get_reward``'s progress reward and step-limit ``done``.

        Returns:
            ``(obs, reward, done, info)``, where ``info`` is the parent's info
            dict with ``'additional_r'`` set to ``0.0``.
        """
        obs, _, _, info = super(AntCustomEnv, self).step(a)
        self.t += 1
        r, done = self._get_reward(obs)
        info['additional_r'] = 0.0
        return (obs, r, done, info)

    def _get_obs(self):
        """Return the concatenated qpos and qvel.

        Modified to include global x, y coordinates: the full qpos is kept.
        """
        return np.concatenate([
            self.data.qpos.flat,
            self.data.qvel.flat,
        ])

    def _state(self):
        """Alias for ``_get_obs``."""
        return self._get_obs()

    def _get_reward(self, obs):
        """Return ``(reward, done)`` for ``obs`` and update ``self.distance``.

        ``obs[:2]`` (the first two qpos entries) is treated as the xy position.
        The reward is 100 times the change in "distance" since the previous
        call. Distance is the Euclidean distance from ``self.origin`` when
        ``self.uniform`` is set, otherwise the projection onto
        ``self.direction``. ``done`` is True once ``self.t`` reaches the step
        limit.
        """
        done = (self.t >= self._max_episode_steps)
        if self.uniform:
            dist = np.linalg.norm(self.origin - obs[:2])
        else:
            dist = np.inner(obs[:2], self.direction)
        r = (dist - self.distance) * 100
        self.distance = dist
        return r, done

    def reset(self, **kwargs):
        """Reset the simulation, the step counter and the reward baseline.

        Any keyword arguments (e.g. ``seed``/``options`` on gym versions whose
        ``reset`` accepts them) are forwarded unchanged to the parent.

        Returns:
            Whatever the parent ``reset`` returns.
        """
        self.t = 0
        obs = super(AntCustomEnv, self).reset(**kwargs)
        self.origin = [0., 0.]
        # NOTE(review): if the parent's reset perturbs the initial pose, the ant
        # does not start exactly at the origin, so the first step's reward also
        # includes that initial offset. Confirm this is intended.
        self.distance = 0.
        return obs




if __name__ == '__main__':
    # Manual smoke test: drive the env with random actions and render each frame.
    env = AntCustomEnv()
    env.reset()
    try:
        while True:
            action = env.action_space.sample()
            obs, reward, done, info = env.step(action)
            env.render()
            # Start a new episode at the step limit; without this the step
            # counter keeps growing and `done` stays True forever.
            if done:
                env.reset()
    finally:
        # Release the environment's resources (e.g. the render window) when the
        # loop is interrupted (Ctrl+C) or fails.
        env.close()
