import os
import numpy as np
from gymnasium.utils.ezpickle import EzPickle
from gymnasium_robotics.envs.fetch import MujocoFetchEnv
from gymnasium_robotics.envs.fetch.fetch_env import goal_distance
from gymnasium_robotics.envs.fetch.reach import MujocoFetchReachEnv
from gymnasium_robotics.envs.fetch.push import MujocoFetchPushEnv
from gymnasium_robotics.utils import rotations


class CustomMujocoFetchPushEnv(MujocoFetchEnv, EzPickle):
    """Fetch push environment with a row of ``n_cubes`` cubes.

    The cubes start on a line at ``x = 1.25``, spaced ``0.07`` apart along
    ``y`` and centred on ``y = 0.75``. The goal of every cube is its own
    start position shifted by ``+0.1`` along ``x``.

    Goals are flat vectors of length ``3 * n_cubes``: the ``(x, y, z)``
    triple of cube ``i`` occupies entries ``3 * i`` to ``3 * i + 2``.
    The reward is sparse and depends on the cube that is farthest from its
    goal; success requires every cube to be within ``distance_threshold``.
    """

    def __init__(self, n_cubes=4, **kwargs):
        """Build the environment.

        Args:
            n_cubes: Number of cubes on the table; must be 2, 3 or 4. It
                selects the model file ``assets/custom_fetch_{n_cubes}.xml``
                located next to this module.
            **kwargs: Forwarded to ``MujocoFetchEnv.__init__``. They must not
                repeat the options that are fixed below (``reward_type``,
                ``model_path``, ``obj_range``, ...).
        """
        self.n_cubes = n_cubes
        assert self.n_cubes in [2, 3, 4], 'Too many or too few cubes!'

        initial_qpos = {
            "robot0:slide0": 0.405,
            "robot0:slide1": 0.48,
            "robot0:slide2": 0.0,
        }
        model_xml_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "assets",
            f"custom_fetch_{self.n_cubes}.xml",
        )

        # y coordinate of the first cube, chosen so the row is centred on 0.75.
        start_dist = 0.75 - ((self.n_cubes - 1) * 0.5 * 0.07)
        self.initial_cube_position = [
            np.array([1.25, start_dist + 0.07 * i]) for i in range(self.n_cubes)
        ]
        for i in range(self.n_cubes):
            # Free-joint qpos: (x, y, z) followed by a quaternion (w, x, y, z).
            initial_qpos[f'object{i}:joint'] = [
                1.25, self.initial_cube_position[i][1], 0.4, 1.0, 0.0, 0.0, 0.0
            ]

        MujocoFetchEnv.__init__(
            self,
            model_path=model_xml_path,
            has_object=True,
            block_gripper=True,
            n_substeps=20,
            gripper_extra_height=0.0,
            target_in_the_air=False,
            target_offset=0.0,
            obj_range=0.15,
            target_range=0.15,
            distance_threshold=0.05,
            initial_qpos=initial_qpos,
            reward_type="sparse",
            **kwargs,
        )
        # Record exactly the arguments this constructor received. Recording
        # ``reward_type`` here would make it arrive through ``**kwargs`` when
        # the object is rebuilt, clashing with the hard-coded value above, and
        # omitting ``n_cubes`` would rebuild the env with the default count.
        EzPickle.__init__(self, n_cubes=n_cubes, **kwargs)

    def generate_mujoco_observations(self):
        """Read gripper and cube state from the simulation.

        Returns:
            A pair ``(gripper, objects)`` where ``gripper`` is the tuple
            ``(grip_pos, gripper_state, grip_velp, gripper_vel)`` and
            ``objects`` is a list with one tuple per cube:
            ``(pos, rel_pos, rot, velp, velr)``. ``rel_pos`` and ``velp`` are
            expressed relative to the gripper; velocities are scaled by the
            duration of one environment step.
        """
        # positions
        grip_pos = self._utils.get_site_xpos(self.model, self.data, "robot0:grip")

        # Duration of one environment step.
        dt = self.n_substeps * self.model.opt.timestep
        grip_velp = (
            self._utils.get_site_xvelp(self.model, self.data, "robot0:grip") * dt
        )

        robot_qpos, robot_qvel = self._utils.robot_get_obs(
            self.model, self.data, self._model_names.joint_names
        )

        object_obs = []
        for i in range(self.n_cubes):
            site_name = f"object{i}"
            object_pos = self._utils.get_site_xpos(self.model, self.data, site_name)
            # rotations
            object_rot = rotations.mat2euler(
                self._utils.get_site_xmat(self.model, self.data, site_name)
            )
            # velocities
            object_velp = (
                self._utils.get_site_xvelp(self.model, self.data, site_name) * dt
            )
            object_velr = (
                self._utils.get_site_xvelr(self.model, self.data, site_name) * dt
            )
            # Express position and linear velocity relative to the gripper.
            object_rel_pos = object_pos - grip_pos
            object_velp -= grip_velp

            object_obs.append(
                (
                    object_pos.copy(),
                    object_rel_pos.copy(),
                    object_rot.copy(),
                    object_velp.copy(),
                    object_velr.copy(),
                )
            )

        gripper_state = robot_qpos[-2:]

        gripper_vel = (
            robot_qvel[-2:] * dt
        )  # change to a scalar if the gripper is made symmetric

        return (grip_pos, gripper_state, grip_velp, gripper_vel), object_obs

    def _reset_sim(self):
        """Restore the initial simulation state and re-place the cubes.

        Returns:
            ``True`` always.
        """
        self.data.time = self.initial_time
        self.data.qpos[:] = np.copy(self.initial_qpos)
        self.data.qvel[:] = np.copy(self.initial_qvel)
        if self.model.na != 0:
            # ``None`` is not a valid activation value; clear the actuator
            # activations instead.
            self.data.act[:] = 0.0

        # Put every cube back on its fixed start position (no randomisation).
        for i, cube_xy in enumerate(self.initial_cube_position):
            joint_name = f"object{i}:joint"
            object_qpos = self._utils.get_joint_qpos(
                self.model, self.data, joint_name
            )
            assert object_qpos.shape == (7,)
            object_qpos[:2] = cube_xy
            self._utils.set_joint_qpos(
                self.model, self.data, joint_name, object_qpos
            )

        self._mujoco.mj_forward(self.model, self.data)
        return True

    def _env_setup(self, initial_qpos):
        """Apply ``initial_qpos``, move the gripper to its start pose and
        cache the values needed for goal sampling.

        Args:
            initial_qpos: Mapping from joint name to the qpos value to set.
        """
        for name, value in initial_qpos.items():
            self._utils.set_joint_qpos(self.model, self.data, name, value)
        self._utils.reset_mocap_welds(self.model, self.data)
        self._mujoco.mj_forward(self.model, self.data)

        # Move end effector into position.
        gripper_target = np.array(
            [-0.688, 0.005, -0.341 + self.gripper_extra_height]
        ) + self._utils.get_site_xpos(self.model, self.data, "robot0:grip")
        gripper_rotation = np.array([1.0, 0.0, 1.0, 0.0])
        self._utils.set_mocap_pos(self.model, self.data, "robot0:mocap", gripper_target)
        self._utils.set_mocap_quat(
            self.model, self.data, "robot0:mocap", gripper_rotation
        )
        # Step the simulation so the arm can reach the mocap target.
        for _ in range(10):
            self._mujoco.mj_step(self.model, self.data, nstep=self.n_substeps)

        # Extract information for sampling goals.
        self.initial_gripper_xpos = self._utils.get_site_xpos(
            self.model, self.data, "robot0:grip"
        ).copy()
        # z coordinate of every cube site after the steps above; used as the
        # goal height in ``_sample_goal``.
        self.height_offsets = [
            self._utils.get_site_xpos(self.model, self.data, f"object{i}")[2]
            for i in range(self.n_cubes)
        ]

    def _get_obs(self):
        """Assemble the goal-conditioned observation dictionary.

        Returns:
            A dict with ``"observation"`` (gripper position, cube positions,
            cube positions relative to the gripper, gripper state, cube
            rotations, cube linear and angular velocities, gripper linear
            velocity, gripper finger velocities, concatenated in that order),
            ``"achieved_goal"`` (concatenated cube positions) and
            ``"desired_goal"`` (a copy of the current goal).
        """
        (
            grip_pos,
            gripper_state,
            grip_velp,
            gripper_vel,
        ), object_obs = self.generate_mujoco_observations()

        achieved_goal = np.concatenate(
            [np.squeeze(obj[0].copy()) for obj in object_obs]
        )

        obs = np.concatenate(
            [
                grip_pos,
                np.concatenate([obj[0].ravel() for obj in object_obs]),
                np.concatenate([obj[1].ravel() for obj in object_obs]),
                gripper_state,
                np.concatenate([obj[2].ravel() for obj in object_obs]),
                np.concatenate([obj[3].ravel() for obj in object_obs]),
                np.concatenate([obj[4].ravel() for obj in object_obs]),
                grip_velp,
                gripper_vel,
            ]
        )

        return {
            "observation": obs.copy(),
            "achieved_goal": achieved_goal.copy(),
            "desired_goal": self.goal.copy(),
        }

    def _sample_goal(self):
        """Return the (deterministic) goal vector of length ``3 * n_cubes``.

        Each cube's target is its start position moved ``0.1`` along ``x``,
        at the height recorded in ``_env_setup``.
        """
        goals = []
        for i in range(self.n_cubes):
            goal = np.zeros((3,))
            goal[:2] = self.initial_cube_position[i]
            goal[0] += 0.1
            goal[2] = self.height_offsets[i]
            goals.append(goal)
        return np.concatenate(goals)

    def _cube_goal_distances(self, achieved_goal, desired_goal):
        """Return the per-cube goal distances, one entry per cube.

        The cube triples are taken from the last axis, so both a single goal
        of shape ``(3 * n_cubes,)`` and a batch of shape
        ``(..., 3 * n_cubes)`` are supported. Each entry has the shape of the
        leading (batch) axes.
        """
        achieved_goal = np.asarray(achieved_goal)
        desired_goal = np.asarray(desired_goal)
        return [
            goal_distance(
                achieved_goal[..., 3 * i:3 * (i + 1)],
                desired_goal[..., 3 * i:3 * (i + 1)],
            )
            for i in range(self.n_cubes)
        ]

    def _is_success(self, achieved_goal, desired_goal):
        """Return ``1.0`` where every cube is within ``distance_threshold`` of
        its goal and ``0.0`` otherwise (``float32``; one value per batch entry
        when the goals are batched).
        """
        within = [
            (dist < self.distance_threshold).astype(np.float32)
            for dist in self._cube_goal_distances(achieved_goal, desired_goal)
        ]
        # Reduce over the cube axis only so batched inputs keep their shape.
        return np.prod(within, axis=0)

    def compute_reward(self, achieved_goal, goal, info):
        """Sparse reward based on the cube farthest from its goal.

        Args:
            achieved_goal: Cube positions, shape ``(..., 3 * n_cubes)``.
            goal: Desired cube positions, same shape as ``achieved_goal``.
            info: Unused.

        Returns:
            ``-1.0`` where the largest per-cube distance exceeds
            ``distance_threshold`` and ``0.0`` otherwise (``float32``; one
            value per batch entry when the goals are batched).
        """
        # Reduce over the cube axis only so batched inputs keep their shape.
        d = np.max(self._cube_goal_distances(achieved_goal, goal), axis=0)
        return -(d > self.distance_threshold).astype(np.float32)

    def _render_callback(self):
        """Move the ``target{i}`` sites onto the current goal positions."""
        # Visualize target.
        sites_offset = (self.data.site_xpos - self.model.site_pos).copy()
        for i in range(self.n_cubes):
            site_id = self._mujoco.mj_name2id(
                self.model, self._mujoco.mjtObj.mjOBJ_SITE, f"target{i}"
            )
            self.model.site_pos[site_id] = (
                self.goal[i * 3:(i + 1) * 3] - sites_offset[0]
            )
        self._mujoco.mj_forward(self.model, self.data)

    def render(self):
        """Render the environment, forcing ``render_mode`` to ``'human'``."""
        self.render_mode = 'human'
        return super().render()
