import numpy as np
from gym import utils
from gym.envs.mujoco import mujoco_env
from . import register_env

@register_env("reacher-goal-sparse")
class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    """Multi-task reacher with a sparse goal reward and image observations.

    Each task is a 2-D goal position.  ``step`` returns a grayscale rendering
    of the scene as the observation, and the distance part of the reward is
    zero unless the fingertip is within ``goal_radius`` of the current goal.
    """

    def __init__(self, n_tasks=50, randomize_tasks=True):
        """Build the simulation and sample the task goals.

        Args:
            n_tasks: Number of goals to generate with ``sample_tasks``.
            randomize_tasks: Accepted for interface compatibility; not used
                by this environment.
        """
        # NOTE: this seeds NumPy's *global* RNG, which both ``reset_model``
        # and ``sample_tasks`` draw from.
        np.random.seed(1)
        # ``goal_radius`` and ``_goal`` are read by ``step``; they are set
        # before the MujocoEnv constructor runs because that constructor is
        # expected to call ``step`` once -- TODO confirm for the gym version
        # in use.
        self.goal_radius = 0.03
        self._goal = [0., 0.]
        # Forward the constructor arguments so that EzPickle rebuilds a
        # pickled/copied environment with the same settings instead of the
        # defaults.
        utils.EzPickle.__init__(self, n_tasks, randomize_tasks)
        mujoco_env.MujocoEnv.__init__(self, "reacher.xml", 2)
        self.goals = self.sample_tasks(n_tasks)

    def step(self, a):
        """Advance the simulation by one control step.

        Args:
            a: Action passed to ``do_simulation``.

        Returns:
            Tuple ``(image, reward, done, info)``.  ``image`` comes from
            ``get_image``; ``reward`` is the sparsified distance reward plus
            the control cost ``-sum(a**2)``; ``done`` is always ``False``;
            ``info`` holds ``fingertip`` (a copy of the fingertip body
            position) and ``reward_ctrl``.
        """
        self.do_simulation(a, self.frame_skip)
        # Copy so the value stored in ``info`` does not alias simulator data.
        fingertip = np.copy(self.get_body_com("fingertip"))
        reward_ctrl = -np.square(a).sum()
        vec = fingertip[:2] - self._goal
        reward_dist = -np.linalg.norm(vec)
        sparse_reward = self.sparsify_rewards(reward_dist)
        reward = sparse_reward + reward_ctrl
        done = False
        # The observation handed back is the rendered image, not the
        # low-dimensional state from ``_get_obs``.
        image = self.get_image()
        return image, reward, done, dict(fingertip=fingertip, reward_ctrl=reward_ctrl)

    def sparsify_rewards(self, r):
        """Turn a dense (negative-distance) reward into a sparse one.

        Args:
            r: Dense reward, i.e. minus the fingertip-to-goal distance.

        Returns:
            ``0.`` when ``r < -goal_radius`` (outside the goal region),
            otherwise ``r + 0.2``.
        """
        if r < -self.goal_radius:
            sparse_r = 0.
        else:
            sparse_r = r + 0.2
        return sparse_r

    def viewer_setup(self):
        """Point the viewer camera at body id 0."""
        self.viewer.cam.trackbodyid = 0

    def reset_model(self):
        """Reset the simulation state for the current goal.

        Joint positions are ``init_qpos`` plus uniform noise in
        ``[-0.1, 0.1]``, with the last two entries overwritten by the current
        goal (presumably the target's x/y coordinates in ``reacher.xml`` --
        TODO confirm).  Velocities are ``init_qvel`` plus uniform noise in
        ``[-0.005, 0.005]``, with the last two entries zeroed.

        Returns:
            The low-dimensional observation from ``_get_obs``.
        """
        # NOTE(review): position noise uses the global ``np.random`` while
        # velocity noise uses ``self.np_random``; kept as-is so existing
        # random sequences are preserved.
        qpos = (
            np.random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos
        )
        qpos[-2:] = self._goal
        qvel = self.init_qvel + self.np_random.uniform(
            low=-0.005, high=0.005, size=self.model.nv
        )
        qvel[-2:] = 0
        self.set_state(qpos, qvel)
        return self._get_obs()

    def reward(self, info, goal):
        """Recompute the reward of a past transition for a given goal.

        Args:
            info: Mapping with ``"fingertip"`` and ``"reward_ctrl"`` entries,
                as produced by ``step``.
            goal: 2-D goal position to evaluate against.

        Returns:
            Tuple ``(reward, done)`` where ``done`` is always ``False``.
        """
        vec = info["fingertip"][:2] - goal
        reward_dist = -np.linalg.norm(vec)
        reward_ctrl = info["reward_ctrl"]
        reward_dist = self.sparsify_rewards(reward_dist)
        reward = reward_ctrl + reward_dist
        done = False
        return (reward, done)

    def initialize_camera(self):
        """Attach an offscreen render context with a preset camera to the sim."""
        # Imported here because the module does not import ``mujoco_py`` at
        # top level; without this the method raised ``NameError``.
        import mujoco_py

        sim = self.sim
        viewer = mujoco_py.MjRenderContextOffscreen(sim)
        camera = viewer.cam
        camera.type = 1  # presumably mjCAMERA_TRACKING -- TODO confirm
        camera.trackbodyid = 0
        camera.elevation = -20
        sim.add_render_context(viewer)

    def get_image(self, width=100, height=100, camera_name=None):
        """Render the scene and convert it to a single-channel image.

        Args:
            width: Render width passed to ``sim.render``.
            height: Render height passed to ``sim.render``.
            camera_name: Camera name passed to ``sim.render``.

        Returns:
            Array with a leading axis of size 1 holding the weighted sum of
            the three colour channels (assumes ``sim.render`` returns a
            channel-last ``(H, W, 3)`` array -- TODO confirm).
        """
        image = self.sim.render(
            width=width,
            height=height,
            camera_name=camera_name,
        )
        # Pixel values 98 and 99 are forced to white (presumably to flatten
        # the background -- TODO confirm).
        image[image==98] = 255
        image[image==99] = 255
        # Channel-last -> channel-first, then weighted RGB -> grayscale.
        image = np.transpose(image, (2, 0, 1))
        image = image[0] * 0.2989 + image[1] * 0.587 + image[2] * 0.114
        return np.expand_dims(image, axis=0)

    def _get_obs(self):
        """Return cos/sin of the first two joint angles and their velocities."""
        theta = self.sim.data.qpos.flat[:2]
        return np.concatenate(
            [np.cos(theta), np.sin(theta), self.sim.data.qvel.flat[:2]]
        )

    def reset_task(self, idx):
        """Select goal ``idx`` from ``self.goals`` and reset the environment."""
        self._goal = self.goals[idx]
        self.reset()

    def get_train_goals(self, n_train_tasks):
        """Return the first ``n_train_tasks`` goals."""
        return self.goals[:n_train_tasks]

    def get_all_task_idx(self):
        """Return a ``range`` over every task index."""
        return range(len(self.goals))

    def sample_tasks(self, n_tasks):
        """Generate ``n_tasks`` goals on a half-circle, in shuffled order.

        Goals are evenly spaced in angle over ``[0, pi]`` at a fixed radius of
        0.2 from the origin, then shuffled with the global NumPy RNG.

        Args:
            n_tasks: Number of goals to generate.

        Returns:
            List of ``[x, y]`` lists.
        """
        # A fixed, non-zero radius keeps every goal away from the origin so
        # that goals aren't too easy.
        radius = 0.2
        angles = np.linspace(0, np.pi, num=n_tasks)
        xs = radius * np.cos(angles)
        ys = radius * np.sin(angles)
        goals = np.stack([xs, ys], axis=1)
        np.random.shuffle(goals)
        return goals.tolist()
