import numpy as np
import torch
import torch.nn.functional as F
from env.sawyer.sawyer import SawyerEnv
import os

class SawyerAssemblyObstacleEnv(SawyerEnv):
    """Sawyer peg-in-hole assembly environment with an obstacle.

    The reward and termination logic is driven by the MuJoCo sites
    ``pegHead``, ``hole`` and ``hole_bottom`` defined in the scene XML.
    """

    def __init__(self, **kwargs):
        """Select the scene XML and initialise the base environment.

        Args:
            **kwargs: Environment configuration forwarded unchanged to
                ``SawyerEnv``. When ``dr_eval`` is truthy the scene that
                contains distractors is loaded; when it is falsy or absent
                the default scene is used.
        """
        if kwargs.get("dr_eval", False):
            xml_file = "sawyer_assembly_obstacle_with_distractors.xml"
        else:
            xml_file = "sawyer_assembly_obstacle.xml"
        super().__init__(xml_file, **kwargs)
        self._get_reference()

    def _get_reference(self):
        """Delegate to the base class; no task-specific references are added."""
        super()._get_reference()

    @property
    def dof(self):
        """Expected length of the action vector passed to ``_step``."""
        return 7

    @property
    def init_qpos(self):
        """Nominal initial joint configuration (a fresh array on every access)."""
        # Default configuration.
        return np.array([0.427, 0.13, 0.0557, 0.114, -0.0622, 0.0276, 0.00356])
        # Alternative with the robot's arm further away from the table:
        # np.array([0.5, 0.13, 0.0557, 0.114, -0.0622, 0.0276, 0.00356])

    def _reset(self):
        """Reset the arm to a noisy copy of ``init_qpos`` and return the observation.

        Gaussian noise with standard deviation 0.02 is added to every joint,
        joint velocities are zeroed, and the simulation is forwarded.

        A harder variant was previously sketched here as commented-out code:
        noise scaled by 0.1, resampled until ``self.get_contact_force()`` was
        below 100 so that the initial pose is collision free.

        Returns:
            The observation dictionary produced by ``_get_obs``.
        """
        nominal_qpos = self.init_qpos
        noise = self.np_random.randn(nominal_qpos.shape[0]) * 0.02
        self.sim.data.qpos[self.ref_joint_pos_indexes] = nominal_qpos + noise
        self.sim.data.qvel[self.ref_joint_vel_indexes] = 0.0
        self.sim.forward()

        return self._get_obs()

    def compute_reward(self, action):
        """Compute the shaped reward for the current simulator state.

        A reaching bonus of ``0.4 * (1 - tanh(15 * d))`` is granted while the
        peg head is closer than 0.3 to the ``hole`` site, where ``d`` is that
        distance. When the peg head is within 0.025 of ``hole_bottom`` the
        configured ``success_reward`` is added and the episode is flagged as
        successful and terminal.

        Args:
            action: Unused; accepted for interface compatibility.

        Returns:
            A ``(reward, info)`` tuple; ``info`` is always an empty dict.
        """
        info = {}
        peg_head_pos = self.sim.data.get_site_xpos("pegHead")
        hole_pos = self.sim.data.get_site_xpos("hole")
        hole_bottom_pos = self.sim.data.get_site_xpos("hole_bottom")

        dist_to_hole = np.linalg.norm(peg_head_pos - hole_pos)
        dist_to_hole_bottom = np.linalg.norm(peg_head_pos - hole_bottom_pos)

        reward_reach = 0
        if dist_to_hole < 0.3:
            reward_reach += 0.4 * (1 - np.tanh(15 * dist_to_hole))
        reward = 0 + reward_reach

        if dist_to_hole_bottom < 0.025:
            reward += self._kwargs["success_reward"]
            self._success = True
            self._terminal = True

        return reward, info

    def get_env_image(self):
        """Render the scene and return it as a batched float32 image.

        The rendered frame is cast to ``float32``, its last axis is moved to
        the front, and a leading axis of size 1 is added (i.e. an H x W x C
        render becomes 1 x C x H x W -- assumes ``render`` returns
        channels-last; TODO confirm against the base class).

        Returns:
            The image array, or -- when the ``save_img_to_disk`` option is
            set -- the path of the ``.npy`` file the image was written to.
        """
        img = super().render(mode='rgb_array').astype('float32')
        img = np.expand_dims(img.transpose(2, 0, 1), axis=0)

        if self._kwargs['save_img_to_disk']:
            # Persist the frame and hand back its location instead of the
            # array itself.
            img_filepath = os.path.join(
                self.img_folder, 'img_{}.npy'.format(self.img_counter)
            )
            np.save(img_filepath, img)
            self.img_counter += 1
            return img_filepath

        return img

    def _get_obs(self):
        """Extend the base observation with peg/hole state and, optionally, an image.

        Returns:
            The base-class observation dict with the extra keys ``hole``,
            ``pegHead``, ``pegEnd`` and ``peg_quat``; an ``image`` entry is
            added unless the ``obs_space`` option equals ``"state"``.
        """
        di = super()._get_obs()
        site_xpos = self.sim.data.get_site_xpos
        # NOTE(review): these site positions are stored without copying; if
        # the simulator returns views into its own buffers the stored values
        # will change on the next step -- confirm whether a copy is needed.
        di["hole"] = site_xpos("hole")
        di["pegHead"] = site_xpos("pegHead")
        di["pegEnd"] = site_xpos("pegEnd")
        di["peg_quat"] = self._get_quat("peg")
        if self._kwargs["obs_space"] != "state":
            di['image'] = self.get_env_image()
        return di

    @property
    def static_bodies(self):
        """Names of the bodies treated as static."""
        return ["table"]

    @property
    def manipulation_bodies(self):
        """Names of the bodies treated as manipulable."""
        return ["furniture", "0_part0", "1_part1", "4_part4", "2_part2"]

    def _geom_ids_of_bodies(self, body_names):
        """Return the ids of all geoms attached to any of the named bodies.

        Args:
            body_names: Iterable of body names present in the MuJoCo model.

        Returns:
            List of geom ids in ascending order.
        """
        model = self.sim.model
        body_ids = {model.body_name2id(name) for name in body_names}
        return [
            geom_id
            for geom_id, body_id in enumerate(model.geom_bodyid)
            if body_id in body_ids
        ]

    @property
    def manipulation_geom_ids(self):
        """Geom ids belonging to ``manipulation_bodies``."""
        return self._geom_ids_of_bodies(self.manipulation_bodies)

    @property
    def static_geoms(self):
        """Names of individually listed static geoms (none for this task)."""
        return []

    @property
    def static_geom_ids(self):
        """Geom ids belonging to ``static_bodies``."""
        return self._geom_ids_of_bodies(self.static_bodies)

    def _step(self, action, is_planner=False, is_mopa_rl=True, is_bc_policy=False):
        """Apply a joint-displacement action and advance the simulation.

        The first ``robot_dof`` entries of ``action`` are clipped to
        ``[-_ac_scale, _ac_scale]`` and added to the previous target joint
        state to form the new target, which is then tracked for
        ``int(_frame_dt / dt)`` simulator sub-steps.

        Args:
            action: Sequence of length ``self.dof``.
            is_planner: When true, the action is taken as-is (no rescaling)
                and the previously stored target state is reused as the
                base, if there is one.
            is_mopa_rl: When false, the action is taken as-is (no rescaling).
            is_bc_policy: When true, the action is taken as-is (no rescaling).

        Returns:
            ``(observation, reward, terminal, info)``.
        """
        assert len(action) == self.dof, "environment got invalid action dimension"

        if not is_planner or self._prev_state is None:
            self._prev_state = self.sim.data.qpos[self.ref_joint_pos_indexes].copy()

        if self._i_term is None:
            # NOTE(review): zeros_like on a scalar yields a 0-d array;
            # np.zeros(dof) may have been intended. Left unchanged because
            # the consumer of _i_term lives in the base class.
            self._i_term = np.zeros_like(self.mujoco_robot.dof)

        # Only plain MoPA-RL policy actions are multiplied by the action scale
        # before clipping; planner, BC-policy and non-MoPA actions are only
        # clipped.
        needs_rescaling = is_mopa_rl and not is_bc_policy and not is_planner
        if needs_rescaling:
            joint_action = action[: self.robot_dof] * self._ac_scale
        else:
            joint_action = action[: self.robot_dof]
        rescaled_ac = np.clip(joint_action, -self._ac_scale, self._ac_scale)
        desired_state = self._prev_state + rescaled_ac

        # NOTE(review): int() truncates, so a ratio that lands just below an
        # integer due to floating-point error loses one sub-step -- confirm.
        n_inner_loop = int(self._frame_dt / self.dt)
        qfrc_applied = self.sim.data.qfrc_applied
        qfrc_bias = self.sim.data.qfrc_bias
        for _ in range(n_inner_loop):
            # Feed the bias forces back as applied forces so they are
            # cancelled on the arm joints before tracking the target.
            arm_idx = self.ref_joint_vel_indexes
            qfrc_applied[arm_idx] = qfrc_bias[arm_idx].copy()

            # NOTE(review): the indicator robots are indexed with *pos*
            # indexes into force arrays, unlike the arm above -- confirm the
            # two index sets coincide for these joints.
            if self.use_robot_indicator:
                indicator_idx = self.ref_indicator_joint_pos_indexes
                qfrc_applied[indicator_idx] = qfrc_bias[indicator_idx].copy()

            if self.use_target_robot_indicator:
                target_idx = self.ref_target_indicator_joint_pos_indexes
                qfrc_applied[target_idx] = qfrc_bias[target_idx].copy()

            self._do_simulation(desired_state)

        self._prev_state = np.copy(desired_state)
        reward, info = self.compute_reward(action)

        return self._get_obs(), reward, self._terminal, info
