import numpy as np
from gym.envs.mujoco import HumanoidEnv as HumanoidEnv_

def mass_center(model, sim):
    """Return the mass-weighted mean of ``sim.data.xipos``.

    ``model.body_mass`` supplies one weight per row of ``sim.data.xipos``;
    the result is the weighted row-sum divided by the total mass.
    """
    # Turn the per-body masses into a column so they broadcast across
    # the coordinate columns of the position table.
    masses = np.asanyarray(model.body_mass)[:, np.newaxis]
    positions = sim.data.xipos
    weighted_positions = masses * positions
    return weighted_positions.sum(axis=0) / masses.sum()

class HumanoidDirEnv(HumanoidEnv_):
    """Humanoid environment whose tasks are target heading angles.

    Every task is a dict ``{'goal': angle}`` with ``angle`` in radians.
    The step reward favours displacement of the centre of mass along the
    unit vector ``(cos(angle), sin(angle))``.  An episode is reported as
    done only once ``max_episode_steps`` steps have been taken since the
    last ``reset``.
    """

    def __init__(self, task=None, n_tasks=2, randomize_tasks=True, max_episode_steps=200, modify_init_state_dist=False, on_circle_init_state=False):
        """Create the environment and select task 0.

        Args:
            task: Indexable collection of task dicts (each with a 'goal'
                entry).  When ``None`` or ``{}``, ``n_tasks`` tasks are
                sampled with ``sample_tasks`` instead.
            n_tasks: Number of tasks to sample; also stored as
                ``num_tasks`` regardless of how ``tasks`` was obtained.
            randomize_tasks: If true, sampled goal angles are drawn
                uniformly at random; otherwise they are evenly spaced.
            max_episode_steps: Step count after which ``step`` reports
                ``done=True``.
            modify_init_state_dist: Accepted but not used by this class.
            on_circle_init_state: Accepted but not used by this class.
        """
        self.randomize_tasks = randomize_tasks
        self.tasks = self.sample_tasks(n_tasks) if task is None or task == {} else task
        self.reset_task(0)
        self._max_episode_steps = max_episode_steps
        # The counter and the goal must exist before the base constructor
        # runs, in case it invokes step() on this instance.
        self._step = 0
        self.num_tasks = n_tasks
        super(HumanoidDirEnv, self).__init__()
        # NOTE(review): some gym versions appear to call step() from the
        # base constructor -- confirm.  Resetting here keeps any such
        # probe step from counting against the first episode; it is a
        # no-op otherwise.
        self._step = 0

    def step(self, action):
        """Advance the simulation by one environment step.

        Returns:
            A tuple ``(observation, reward, done, info)`` where ``info``
            holds the individual reward terms.  ``done`` becomes true only
            when the step counter reaches ``max_episode_steps``.
        """
        # Copy the "before" position so it cannot be affected by the
        # simulation call that follows.
        pos_before = np.copy(mass_center(self.model, self.sim)[:2])
        self.do_simulation(action, self.frame_skip)
        pos_after = mass_center(self.model, self.sim)[:2]

        alive_bonus = 1.0
        data = self.sim.data
        goal_direction = np.array((np.cos(self._goal), np.sin(self._goal)))
        # Progress along the goal heading.  This term is added to the
        # reward; the displacement is divided by model.opt.timestep.
        displacement = pos_after - pos_before
        lin_vel_reward = 0.25 * np.sum(goal_direction * displacement) / self.model.opt.timestep
        quad_ctrl_cost = 0.1 * np.square(data.ctrl).sum()
        # The impact penalty is capped at 10.
        quad_impact_cost = min(.5e-6 * np.square(data.cfrc_ext).sum(), 10)
        reward = lin_vel_reward - quad_ctrl_cost - quad_impact_cost + alive_bonus

        # No state-based early termination: only the step budget ends
        # an episode.
        done = False
        self._step += 1
        if self._step >= self._max_episode_steps:
            done = True

        info = dict(reward_linvel=lin_vel_reward,
                    reward_quadctrl=-quad_ctrl_cost,
                    reward_alive=alive_bonus,
                    reward_impact=-quad_impact_cost)
        return self._get_obs(), reward, done, info

    def _get_obs(self):
        """Return the flat observation vector.

        Concatenates, in order: ``qpos`` without its first two entries,
        ``qvel``, ``cinert``, ``cvel``, ``qfrc_actuator`` and ``cfrc_ext``.
        """
        data = self.sim.data
        # qpos[0:2] is dropped -- presumably the root x/y position, so the
        # observation does not depend on absolute location (TODO confirm).
        return np.concatenate([data.qpos.flat[2:],
                               data.qvel.flat,
                               data.cinert.flat,
                               data.cvel.flat,
                               data.qfrc_actuator.flat,
                               data.cfrc_ext.flat])

    def get_all_task_idx(self):
        """Return the valid task indices as a ``range``."""
        return range(len(self.tasks))

    def reset(self):
        """Zero the episode step counter and reset the base environment."""
        self._step = 0
        return super().reset()

    def reset_task(self, idx):
        """Make ``self.tasks[idx]`` the active task and cache its goal angle."""
        self._goal_idx = idx
        self._task = self.tasks[idx]
        # Each task is parameterised by a single value stored under 'goal'.
        self._goal = self._task['goal']

    def sample_tasks(self, num_tasks):
        """Build ``num_tasks`` task dicts with goal angles in ``[0, 2*pi)``.

        Angles are drawn uniformly at random when ``self.randomize_tasks``
        is true; otherwise they are evenly spaced starting from 0.
        """
        if self.randomize_tasks:
            directions = np.random.uniform(0., 2.0 * np.pi, size=(num_tasks,))
        else:
            # Drop the final point so that 2*pi does not duplicate 0.
            directions = np.linspace(0, 2*np.pi, num_tasks+1)[:-1]
        return [{'goal': direction} for direction in directions]

    def set_all_goals(self, goals):
        """Replace every task's goal and re-select task 0.

        Args:
            goals: Sequence with exactly ``num_tasks`` entries; element 0
                of each entry is used as that task's goal angle.
        """
        assert self.num_tasks == len(goals), "expected one goal per task (num_tasks)"
        self.tasks = [{'goal': goal[0]} for goal in goals]
        self.reset_task(0)