from gym import spaces
import numpy as np
import mujoco_py
from environments.reacher.reacher_base import ReacherRandomizedEnv

# Number of ReacherMultistageEnv instances constructed so far.  Each new
# instance copies the current value into its own per-env task counter and then
# increments this global, so successive environments begin their task rotation
# at different offsets.
count = 0


class ReacherMultistageEnv(ReacherRandomizedEnv):
    """Reacher environment that cycles through three tasks on reset.

    Tasks 0 and 1 each consist of three goals (``goal_locations[task][stage]``)
    that are presented one stage at a time.  Task 2 has a single goal placed at
    the fingertip position observed right after reset.

    The task for an episode is ``self._count % self._num_tasks``; ``_count``
    is seeded from a module-level counter at construction and incremented on
    every reset.

    NOTE(review): observation slots 4 and 5 are treated throughout this class
    as the (x, y) goal position; that layout comes from the base class and
    should be confirmed there.
    """

    # Per-joint torque levels used to expand a discrete action index.
    _ACTION_LEVELS = (-0.5, 0.0, 0.5)

    def __init__(
        self, fixed_stage_lengths=False, sparse_tasks=None, include_task_id=False
    ):
        """Create the environment.

        Args:
            fixed_stage_lengths: If true, the stage advances every
                ``stage_length`` steps regardless of success; otherwise the
                stage advances when the base env reports success.
            sparse_tasks: Task ids whose reward is replaced by
                ``float(info["success"])`` and that receive no stage bonus.
                ``None`` (the default) means no sparse tasks.
            include_task_id: If true, ``split_observation`` writes the task id
                into observation slot 4; otherwise that slot is zeroed.
        """
        # All attributes are assigned before the parent constructor runs.
        # NOTE(review): presumably the base env calls step()/reset_model()
        # while constructing itself -- confirm against the base class.
        self.task_id = 0
        self._stage = 0
        self._initial_pos = None
        self._steps = 0
        self._stages_completed = [False, False, False]
        self.stage_length = 30
        self.fixed_stage_lengths = fixed_stage_lengths
        # The first-stage y coordinates differ slightly between tasks
        # (0.1401 vs 0.1402, -0.0501 vs -0.0502) so that a goal position
        # alone identifies its task; see ``_get_task_id``.
        self.goal_locations = [
            [[0, 0.1401], [-0.05, -0.0501], [-0.1, 0.08]],
            [[0, 0.1402], [-0.05, -0.0502], [0.1, 0.1]],
        ]
        # A fresh list per instance avoids sharing one mutable default object.
        self._sparse_tasks = [] if sparse_tasks is None else sparse_tasks
        self._include_task_id = include_task_id
        self._num_tasks = 3
        global count
        self._count = count
        count = count + 1
        super().__init__()

    @classmethod
    def disc2cont(cls, a):
        """Map a discrete action index to a 2-element continuous action.

        The index is decomposed as ``a = 3 * i + j`` and each of ``i`` and
        ``j`` selects one entry of ``(-0.5, 0.0, 0.5)``.
        """
        first, second = divmod(a, 3)
        return np.array([cls._ACTION_LEVELS[first], cls._ACTION_LEVELS[second]])

    @classmethod
    def get_n_discrete_actions(cls):
        """Return the number of discrete actions accepted by ``disc2cont``."""
        return len(cls._ACTION_LEVELS) ** 2

    @property
    def num_tasks(self):
        """Number of tasks the environment cycles through."""
        return self._num_tasks

    def set_goal(self, goal):
        """Write ``goal`` into the last two qpos entries of the simulator."""
        old_state = self.sim.get_state()
        qpos = old_state.qpos
        qpos[-2:] = goal
        new_state = mujoco_py.MjSimState(
            old_state.time, qpos, old_state.qvel, old_state.act, old_state.udd_state
        )
        self.sim.set_state(new_state)
        # Recompute derived quantities so the new goal is visible immediately.
        self.sim.forward()

    def _sample_joint(self, center):
        """Return ``center`` plus uniform noise in +/- 0.1 * pi / 2."""
        spread = 0.1 * np.pi / 2
        return center + self.np_random.uniform(low=-spread, high=spread)

    def reset_model(self):
        """Select the next task, randomise the arm pose and place the goal.

        Returns:
            The observation produced by ``self._get_obs()`` after the reset.
        """
        self.task_id = self._count % self._num_tasks
        self._count += 1
        self._stage = 0
        self._steps = 0
        self._stages_completed = [False, False, False]

        if self.task_id == 2:
            # Task 2 starts from either the task-0 or the task-1 pose with
            # equal probability.  The draw uses the env's own RNG (not the
            # global ``np.random``) so that seeding makes resets reproducible.
            pose_of_task = 0 if self.np_random.uniform() < 0.5 else 1
            # Placeholder; the real goal is set after the pose is applied.
            self.goal = [0, 0]
        else:
            pose_of_task = self.task_id
            self.goal = self.goal_locations[self.task_id][0]

        # Joint 1 is centred on pi/2 for the task-0 pose and on 0 for the
        # task-1 pose; joint 2 is always centred on pi/2.
        joint1 = self._sample_joint(np.pi / 2 if pose_of_task == 0 else 0.0)
        joint2 = self._sample_joint(np.pi / 2)

        qpos = np.array([joint1, joint2, 0.0, 0.0])
        qpos[-2:] = self.goal
        qvel = self.init_qvel + self.np_random.uniform(
            low=-0.005,
            high=0.005,
            size=self.model.nv,
        )
        qvel[-2:] = 0
        self.set_state(qpos, qvel)

        if self.task_id == 2:
            # Snapshot the fingertip position.  Copying matters: the value
            # returned by get_body_com may be a view into live simulator data,
            # which must neither be modified in place nor drift with the arm.
            self.goal = np.array(self.get_body_com("fingertip")[:2])
            if self._get_task_id(self.goal) < 2:
                # The sampled goal coincides with a task-0/1 goal; nudge it so
                # goal-based task lookup still reports task 2.
                self.goal = self.goal + 1e-5
            self.set_goal(self.goal)

        return self._get_obs()

    def step(self, a):
        """Advance one step and add stage bookkeeping to reward and info.

        Adds ``stages_completed``, ``reward_stage`` and ``task_id`` to the
        info dict returned by the base env.

        Returns:
            ``(observation, reward, done, info)``.
        """
        if self.fixed_stage_lengths:
            # The stage is determined by elapsed steps, capped at the last one.
            new_stage = min(self._steps // self.stage_length, 2)
            if new_stage != self._stage:
                self._stage = new_stage
                if self.task_id in (0, 1):
                    self.goal = self.goal_locations[self.task_id][self._stage]
                    self.set_goal(self.goal)

        ob, rew, done, infos = super().step(a)
        is_sparse = self.task_id in self._sparse_tasks
        if is_sparse:
            rew = float(infos["success"])

        staged_success = self.task_id != 2 and infos["success"]

        if self.fixed_stage_lengths:
            # Record which stages saw at least one success.
            if staged_success:
                self._stages_completed[self._stage] = True
            infos["stages_completed"] = np.sum(self._stages_completed)
        else:
            # A success moves on to the next stage's goal.
            if staged_success:
                self._stage += 1
                if self._stage > 2:
                    # Every stage solved: end the episode here.  This path
                    # adds no stage reward and does not count the step.
                    infos["reward_stage"] = 0
                    infos["stages_completed"] = self._stage
                    infos["task_id"] = self.task_id
                    return ob, rew, True, infos
                self.goal = self.goal_locations[self.task_id][self._stage]
                self.set_goal(self.goal)
            infos["stages_completed"] = self._stage

        if self.task_id == 2:
            # Single-goal task: success counts as all stages at once.
            self._stage = 2 * infos["success"]
            infos["stages_completed"] = 3 * infos["success"]

        # Stage-dependent reward offset: -0.5, -0.25, 0 for stages 0, 1, 2.
        reward_stage = 0 if is_sparse else 0.25 * (self._stage - 2)
        rew += reward_stage
        infos["reward_stage"] = reward_stage
        infos["task_id"] = self.task_id

        self._steps += 1

        return ob, rew, done, infos

    def _get_task_id(self, goal, return_stage=False):
        """Look up which task (or stage) a goal position belongs to.

        The goal is compared against every entry of ``goal_locations`` with a
        tolerance of 1e-6.

        Returns:
            The task index of the first match, or the stage index when
            ``return_stage`` is true.  With no match, returns 2 (task) or
            0 (stage).
        """
        goal_dist = np.linalg.norm(
            np.asarray(self.goal_locations) - np.asarray(goal),
            axis=-1,
        )
        # goal_dist has shape (task, stage); pick the matching axis.
        task_idx, stage_idx = np.where(goal_dist < 1e-6)
        matches = stage_idx if return_stage else task_idx

        if len(matches) > 0:
            return matches[0]
        return 0 if return_stage else 2

    def _apply_to_goal(self, observation, fn):
        """Apply ``fn`` to the goal slice (indices 4:6) of an observation.

        A batched observation (more than one dimension) yields a list with one
        result per row; a single observation yields a single result.
        """
        if len(observation.shape) > 1:
            goal_pos = self.to_numpy(observation[:, 4:6])
            return [fn(pos) for pos in goal_pos]
        return fn(self.to_numpy(observation[4:6]))

    def get_task_id(self, observation, return_stage=False):
        """Return the task id (or stage id) encoded by an observation's goal."""
        return self._apply_to_goal(
            observation, lambda pos: self._get_task_id(pos, return_stage)
        )

    def _get_stage_id(self, goal):
        """Return the second goal entry cast to ``int``."""
        return goal[1].astype(int)

    def get_stage_id(self, observation):
        """Return observation slot 5 as an integer, per row if batched.

        NOTE(review): presumably meant for observations already processed by
        ``split_observation``, where slot 5 holds the stage id -- confirm.
        """
        return self._apply_to_goal(observation, self._get_stage_id)

    def to_numpy(self, x):
        """Convert a torch tensor or array-like to a NumPy array."""
        # torch is imported lazily and is optional: plain NumPy inputs must
        # keep working when torch is not installed.
        try:
            import torch
        except ImportError:
            torch = None

        if torch is not None and isinstance(x, torch.Tensor):
            # detach() so tensors that require grad can be converted too.
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def split_observation(self, observation):
        """Split an observation into a task-agnostic part and task info.

        In a copy of the observation, slot 4 is replaced by the task id (or by
        0 when ``include_task_id`` is false) and slot 5 by the stage id.

        Returns:
            ``(obs_without_task, task_info)`` where ``task_info`` is the
            unmodified input observation.
        """
        obs_without_task = (
            observation.copy()
            if isinstance(observation, np.ndarray)
            else observation.clone()
        )

        obs_without_task[..., 4] = (
            self.get_task_id(observation) if self._include_task_id else 0
        )
        obs_without_task[..., 5] = self.get_task_id(observation, return_stage=True)

        # The original observation is returned untouched so that get_task_id
        # and friends can still be applied to it.
        task_info = observation
        return obs_without_task, task_info


class ReacherMultistageFixedTaskEnv(ReacherMultistageEnv):
    """Multistage reacher pinned to one task, tracking the episode return.

    The parent picks its task as ``_count % _num_tasks`` on every reset and
    then increments ``_count``.  This subclass sets ``_count`` to the requested
    task id and undoes the increment after each reset, so the same task is
    selected every episode.  A ``task_id`` outside ``range(_num_tasks)`` is
    therefore reduced modulo ``_num_tasks`` at the first reset.

    ``set_goal`` is inherited unchanged from the parent class.
    """

    def __init__(self, task_id, include_task_id=False):
        """Create an environment fixed to ``task_id``.

        Args:
            task_id: Task to run on every episode.
            include_task_id: Forwarded to the parent constructor.
        """
        # Set before the parent constructor runs.
        # NOTE(review): presumably the base env calls step() during
        # construction, which updates _ep_ret -- confirm against the base.
        self._ep_ret = 0
        super().__init__(include_task_id=include_task_id)
        self.task_id = task_id
        self._count = self.task_id

    @property
    def episode_total_reward(self):
        """Sum of rewards returned by ``step`` since the last reset."""
        return self._ep_ret

    def reset_model(self):
        """Reset as the parent does while keeping the task fixed."""
        self._ep_ret = 0
        ret = super().reset_model()
        # Undo the parent's increment so the next reset picks the same task.
        self._count -= 1
        return ret

    def step(self, a):
        """Step the parent env and accumulate the reward into the return."""
        obs, rew, done, info = super().step(a)
        self._ep_ret += rew
        return obs, rew, done, info
