import numpy as np
import torch
import torch.nn.functional as F
from env.sawyer.sawyer import SawyerEnv
from util.transform_utils import *
import os


class SawyerLiftObstacleEnv(SawyerEnv):
    """Sawyer task: grasp the cube and lift it to a target height above ``bin1``.

    The scene XML is chosen from ``kwargs["dr_eval"]``: when it equals ``True``
    the variant with distractor objects is loaded. Rewards are staged
    (reach -> grasp -> lift) and an episode succeeds and terminates once the
    grasped cube is within 0.05 of the target height.
    """

    def __init__(self, **kwargs):
        # Only a flag equal to True selects the distractor scene (explicit
        # equality kept on purpose so e.g. the string "False" does not count);
        # a missing key falls back to the plain scene instead of KeyError.
        use_distractors = kwargs.get("dr_eval", False) == True  # noqa: E712
        if use_distractors:
            xml_file = "sawyer_lift_obstacle_with_distractors.xml"
        else:
            xml_file = "sawyer_lift_obstacle.xml"
        super().__init__(xml_file, **kwargs)
        self._get_reference()

    @property
    def init_qpos(self):
        """Nominal initial arm joint positions (perturbed with noise on reset)."""
        return np.array([-0.0305, -0.7325, 0.03043, 1.16124, 1.87488, 0, 0])

    def _get_reference(self):
        """Cache MuJoCo ids of the cube body, geom and site."""
        super()._get_reference()

        self.cube_body_id = self.sim.model.body_name2id("cube")
        self.cube_geom_id = self.sim.model.geom_name2id("cube")
        self.cube_site_id = self.sim.model.site_name2id("cube")

    def _reset(self):
        """Place camera 2, re-initialise the arm joints and return the first observation."""
        self._set_camera_position(2, [1.16, 0., 2.85])  # default camera position

        # Alternative camera placements:
        # self._set_camera_position(2, [0.07, 0., 1.9]) # candidate 1 camera position
        # self._set_camera_rotation(2, [0.5, -0.2, 0.99])

        # self._set_camera_position(2, [0.3, -0.6, 2.35]) # candidate 2 camera position
        # self._set_camera_rotation(2, [0.35, -0.2, 0.99])

        # self._set_camera_position(2, [0.71, -0.39, 1.4]) # candidate 3 camera position
        # self._set_camera_rotation(2, [0, 0.1, 0])

        # Same noisy joint initialisation as the public helper below.
        self.initialize_joints()

        return self._get_obs()

    def initialize_joints(self):
        """Set arm joints to ``init_qpos`` plus Gaussian noise (std 0.02) and zero their velocities."""
        init_qpos = (
            self.init_qpos + self.np_random.randn(self.init_qpos.shape[0]) * 0.02
        )
        self.sim.data.qpos[self.ref_joint_pos_indexes] = init_qpos
        self.sim.data.qvel[self.ref_joint_vel_indexes] = 0.0
        self.sim.forward()

    @property
    def left_finger_geoms(self):
        """Geom names of the left gripper finger, used for grasp detection."""
        return ["l_finger_g0", "l_finger_g1", "l_fingertip_g0"]

    @property
    def right_finger_geoms(self):
        """Geom names of the right gripper finger, used for grasp detection."""
        return ["r_finger_g0", "r_finger_g1", "r_fingertip_g0"]

    @property
    def l_finger_geom_ids(self):
        """MuJoCo geom ids for ``left_finger_geoms`` (looked up on every access)."""
        return [self.sim.model.geom_name2id(name) for name in self.left_finger_geoms]

    @property
    def r_finger_geom_ids(self):
        """MuJoCo geom ids for ``right_finger_geoms`` (looked up on every access)."""
        return [self.sim.model.geom_name2id(name) for name in self.right_finger_geoms]

    @property
    def gripper_bodies(self):
        """Body names making up the gripper."""
        return [
            "clawGripper",
            "rightclaw",
            "leftclaw",
            "right_gripper_base",
            "right_gripper",
            "r_gripper_l_finger_tip",
            "r_gripper_r_finger_tip",
        ]

    @property
    def gripper_indicator_bodies(self):
        """Body names of the gripper's indicator copy."""
        return [
            "clawGripper_indicator",
            "rightclaw_indicator",
            "leftclaw_indicator",
            "right_gripper_base_indicator",
            "r_gripper_l_finger_tip_indicator",
            "r_gripper_r_finger_tip_indicator",
        ]

    @property
    def gripper_target_bodies(self):
        """Body names of the gripper's target-indicator copy."""
        return [
            "clawGripper_target",
            "rightclaw_target",
            "leftclaw_target",
            "right_gripper_base_target",
            "r_gripper_l_finger_tip_target",
            "r_gripper_r_finger_tip_target",
        ]

    def _is_cube_grasped(self):
        """Return True when the cube touches at least one left AND one right finger geom."""
        # Resolve the finger ids once instead of once per contact.
        left_ids = set(self.l_finger_geom_ids)
        right_ids = set(self.r_finger_geom_ids)
        touch_left = False
        touch_right = False
        for i in range(self.sim.data.ncon):
            contact = self.sim.data.contact[i]
            # Identify the geom on the other side of a cube contact, if any.
            if contact.geom1 == self.cube_geom_id:
                other = contact.geom2
            elif contact.geom2 == self.cube_geom_id:
                other = contact.geom1
            else:
                continue
            touch_left = touch_left or other in left_ids
            touch_right = touch_right or other in right_ids
            if touch_left and touch_right:
                break
        return touch_left and touch_right

    def compute_reward(self, action):
        """Compute the staged reach/grasp/lift reward and update success flags.

        The dense reward is the maximum of:
          * reach: ``(1 - tanh(10 * |cube - grip_site|)) * 0.1``
          * grasp: ``0.35`` while both fingers touch the cube, else ``0``
          * lift:  only while grasped, interpolates from 0.35 to 0.5 as the
            cube approaches ``bin1``'s z + 0.45.
        If the cube is grasped and within 0.05 of that height,
        ``self._kwargs["success_reward"]`` is added and ``self._success`` /
        ``self._terminal`` are set to True; otherwise ``self._success`` is False.

        Args:
            action: The action just applied (unused by this reward).

        Returns:
            ``(reward, info)`` where ``info`` holds the three reward components.
        """
        reach_mult = 0.1
        grasp_mult = 0.35
        lift_mult = 0.5

        gripper_site_pos = self.sim.data.get_site_xpos("grip_site")
        cube_pos = np.array(self.sim.data.body_xpos[self.cube_body_id])
        gripper_to_cube = np.linalg.norm(cube_pos - gripper_site_pos)
        reward_reach = (1 - np.tanh(10 * gripper_to_cube)) * reach_mult

        has_grasp = self._is_cube_grasped()
        reward_grasp = int(has_grasp) * grasp_mult

        object_z = cube_pos[2]
        reward_lift = 0.0
        lifted_to_target = False
        if has_grasp:
            z_target = self._get_pos("bin1")[2] + 0.45
            z_dist = np.maximum(z_target - object_z, 0.0)
            reward_lift = grasp_mult + (1 - np.tanh(15 * z_dist)) * (
                lift_mult - grasp_mult
            )
            lifted_to_target = np.abs(object_z - z_target) < 0.05

        reward = max(reward_reach, reward_grasp, reward_lift)
        info = dict(
            reward_reach=reward_reach,
            reward_grasp=reward_grasp,
            reward_lift=reward_lift,
        )

        if lifted_to_target:
            reward += self._kwargs["success_reward"]
            self._success = True
            self._terminal = True
        else:
            self._success = False

        return reward, info

    def get_env_image(self):
        """Render an RGB frame as float32 with a leading batch axis, channels first.

        Assumes ``render(mode='rgb_array')`` returns (H, W, C) — TODO confirm —
        giving a (1, C, H, W) array. When ``save_img_to_disk`` is set the array
        is saved to ``img_folder/img_<counter>.npy`` and that path is returned
        instead.
        """
        img = (super().render(mode='rgb_array')).astype('float32')
        img = img.transpose(2, 0, 1)
        img = np.expand_dims(img, axis=0)

        if self._kwargs['save_img_to_disk']:
            img_filepath = os.path.join(self.img_folder, 'img_{}.npy'.format(self.img_counter))
            np.save(img_filepath, img)
            self.img_counter += 1
            return img_filepath

        return img

    def _get_obs(self):
        """Extend the base observation with cube pose and gripper-to-cube offset.

        Adds ``cube_pos``, ``cube_quat`` (converted to xyzw), ``gripper_to_cube``
        and, unless ``obs_space == "state"``, an ``image`` entry.
        """
        di = super()._get_obs()
        cube_pos = np.array(self.sim.data.body_xpos[self.cube_body_id])
        cube_quat = convert_quat(
            np.array(self.sim.data.body_xquat[self.cube_body_id]), to="xyzw"
        )
        di["cube_pos"] = cube_pos
        di["cube_quat"] = cube_quat
        # NOTE(review): uses eef_site_id, whereas compute_reward uses the
        # "grip_site" site — confirm these refer to the same site.
        gripper_site_pos = np.array(self.sim.data.site_xpos[self.eef_site_id])
        di["gripper_to_cube"] = gripper_site_pos - cube_pos
        if self._kwargs["obs_space"] != "state":
            di['image'] = self.get_env_image()
        return di

    @property
    def static_bodies(self):
        """Names of bodies treated as static scene geometry."""
        return ["table", "bin1"]

    @property
    def static_geoms(self):
        """Extra static geom names (none for this task)."""
        return []

    @property
    def static_geom_ids(self):
        """Ids of every geom attached to one of ``static_bodies``, in geom-id order."""
        body_ids = {self.sim.model.body_name2id(name) for name in self.static_bodies}
        return [
            geom_id
            for geom_id, body_id in enumerate(self.sim.model.geom_bodyid)
            if body_id in body_ids
        ]

    @property
    def manipulation_geom(self):
        """Names of geoms the robot is meant to manipulate."""
        return ["cube"]

    @property
    def manipulation_geom_ids(self):
        """MuJoCo geom ids for ``manipulation_geom``."""
        return [self.sim.model.geom_name2id(name) for name in self.manipulation_geom]

    def _step(self, action, is_planner=False, is_mopa_rl=True, is_bc_policy=False):
        """Apply one control step and return ``(obs, reward, done, info)``.

        ``action[:robot_dof]`` is a joint-position delta added to the previous
        desired joint state; ``action[-1]`` is the gripper command. For MoPA-RL
        policy actions (``is_mopa_rl`` and neither ``is_planner`` nor
        ``is_bc_policy``) the delta is first multiplied by ``self._ac_scale``.
        In every case it is clipped to ``[-_ac_scale, _ac_scale]``.

        Args:
            action: Sequence of length ``self.dof``.
            is_planner: Action comes from the motion planner. Desired states are
                then chained from the previous step instead of re-read from qpos.
            is_mopa_rl: Enable the policy-action rescaling described above.
            is_bc_policy: Action comes from a behaviour-cloning policy (no rescaling).

        Raises:
            AssertionError: If ``len(action) != self.dof``.
        """
        assert len(action) == self.dof, "environment got invalid action dimension"

        # Outside planner mode (or on the first step) anchor the desired state
        # to the measured joint positions; planner steps chain desired states.
        if not is_planner or self._prev_state is None:
            self._prev_state = self.sim.data.qpos[self.ref_joint_pos_indexes].copy()

        if self._i_term is None:
            # One entry per DoF; np.zeros_like(<int>) would give a 0-d array.
            self._i_term = np.zeros(self.mujoco_robot.dof)

        arm_ac = action[: self.robot_dof]
        if is_mopa_rl and not is_planner and not is_bc_policy:
            arm_ac = arm_ac * self._ac_scale
        rescaled_ac = np.clip(arm_ac, -self._ac_scale, self._ac_scale)

        desired_state = self._prev_state + rescaled_ac
        gripper_action = self._gripper_format_action(np.array([action[-1]]))
        converted_action = np.concatenate([desired_state, gripper_action])

        n_inner_loop = int(self._frame_dt / self.dt)
        for _ in range(n_inner_loop):
            # Feed MuJoCo's bias forces (qfrc_bias) back as applied forces,
            # presumably to compensate gravity/Coriolis terms.
            self.sim.data.qfrc_applied[
                self.ref_joint_vel_indexes
            ] = self.sim.data.qfrc_bias[self.ref_joint_vel_indexes].copy()

            if self.use_robot_indicator:
                self.sim.data.qfrc_applied[
                    self.ref_indicator_joint_pos_indexes
                ] = self.sim.data.qfrc_bias[self.ref_indicator_joint_pos_indexes].copy()

            if self.use_target_robot_indicator:
                self.sim.data.qfrc_applied[
                    self.ref_target_indicator_joint_pos_indexes
                ] = self.sim.data.qfrc_bias[
                    self.ref_target_indicator_joint_pos_indexes
                ].copy()
            self._do_simulation(converted_action)

        self._prev_state = np.copy(desired_state)
        reward, info = self.compute_reward(action)

        return self._get_obs(), reward, self._terminal, info