"""Environments using kitchen and Franka robot."""

import os
import sys
class SuppressStdout:  # TODO: temporarily added here to suppress D4RL output
    """Context manager that redirects ``sys.stdout`` to ``os.devnull``.

    On entry the current ``sys.stdout`` is remembered and replaced by a
    write handle on ``os.devnull``; on exit the remembered stream is put
    back and the devnull handle is closed. Exceptions raised inside the
    ``with`` body are not suppressed.
    """

    def __enter__(self):
        """Swap ``sys.stdout`` for a devnull handle and return this manager."""
        self._prev_stdout = sys.stdout
        # Keep our own reference so __exit__ closes exactly the handle opened
        # here, even if the body rebinds sys.stdout to something else.
        self._devnull = open(os.devnull, "w")
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the previous ``sys.stdout`` and close the devnull handle."""
        # Restore before closing so stdout is usable again even if close()
        # raises.
        sys.stdout = self._prev_stdout
        self._devnull.close()


import mujoco_py
import numpy as np
from d4rl.kitchen.adept_envs.utils.configurable import configurable
from d4rl.kitchen.adept_envs.franka.kitchen_multitask_v0 import KitchenTaskRelaxV1
from d4rl.kitchen.kitchen_envs import OBS_ELEMENT_INDICES
from environments.kitchen.kitchen_base import KitchenSingleTaskEnv

# Maps each kitchen task element name to one site-name string.
# NOTE(review): presumably these are MuJoCo site names in the kitchen model,
# used to locate the element in world space -- confirm against the model XML.
OBS_ELEMENT_SITES = {
    "bottom burner": "knob2_site",
    "top burner": "knob4_site",
    "light switch": "light_site",
    "slide cabinet": "slide_site",
    "hinge cabinet": "hinge_site2",
    "microwave": "microhandle_site",
    "kettle": "kettle_site",
}

# Per-task value vector for the task's element, keyed "<element>-<action>".
# The vector length differs per element (1, 2 or 7 entries). Each entry here
# equals the OBS_ELEMENT_GOALS entry of the opposite action on the same
# element (e.g. INITS["microwave-open"] == GOALS["microwave-close"]).
# NOTE(review): presumably the element's joint values at episode start --
# confirm against the code that consumes this table.
OBS_ELEMENT_INITS = {
    "bottom burner-on": np.array([3.12877220e-05, -4.51199853e-05]),
    "bottom burner-off": np.array([-0.88, -0.01]),
    "top burner-on": np.array([6.28065475e-05, 4.04984708e-05]),
    "top burner-off": np.array([-0.92, -0.01]),
    "light switch-on": np.array([4.62730939e-04, -2.26906415e-04]),
    "light switch-off": np.array([-0.69, -0.05]),
    "slide cabinet-open": np.array([-4.65501369e-04]),
    "slide cabinet-close": np.array([0.37]),
    "hinge cabinet-open": np.array([-6.44129196e-03, -1.77048263e-03]),
    "hinge cabinet-close": np.array([0.0, 1.45]),
    "microwave-open": np.array([1.08009684e-03]),
    "microwave-close": np.array([-0.75]),
    "kettle-push": np.array(
        [
            -2.69397440e-01,
            3.50383255e-01,
            1.61944683e00,
            1.00618764e00,
            4.06395120e-03,
            -6.62095997e-03,
            -2.68278933e-04,
        ]
    ),
    "kettle-pull": np.array([-0.23, 0.75, 1.62, 0.99, 0.0, 0.0, -0.06]),
}

# Per-task value vector for the task's element, keyed "<element>-<action>"
# with the same keys as OBS_ELEMENT_INITS. Each entry here equals the
# OBS_ELEMENT_INITS entry of the opposite action on the same element
# (e.g. GOALS["slide cabinet-open"] == INITS["slide cabinet-close"]).
# NOTE(review): presumably the target joint values that define task
# completion -- confirm against the code that consumes this table.
OBS_ELEMENT_GOALS = {
    "bottom burner-on": np.array([-0.88, -0.01]),
    "bottom burner-off": np.array([3.12877220e-05, -4.51199853e-05]),
    "top burner-on": np.array([-0.92, -0.01]),
    "top burner-off": np.array([6.28065475e-05, 4.04984708e-05]),
    "light switch-on": np.array([-0.69, -0.05]),
    "light switch-off": np.array([4.62730939e-04, -2.26906415e-04]),
    "slide cabinet-open": np.array([0.37]),
    "slide cabinet-close": np.array([-4.65501369e-04]),
    "hinge cabinet-open": np.array([0.0, 1.45]),
    "hinge cabinet-close": np.array([-6.44129196e-03, -1.77048263e-03]),
    "microwave-open": np.array([-0.75]),
    "microwave-close": np.array([1.08009684e-03]),
    "kettle-push": np.array([-0.23, 0.75, 1.62, 0.99, 0.0, 0.0, -0.06]),
    "kettle-pull": np.array(
        [
            -2.69397440e-01,
            3.50383255e-01,
            1.61944683e00,
            1.00618764e00,
            4.06395120e-03,
            -6.62095997e-03,
            -2.68278933e-04,
        ]
    ),
}

# Names of the 15 kitchen-object joints: four knobs, the light switch, the
# slide door, two door hinges, the microwave joint and six kettle joints
# (Tx/Ty/Tz/Rx/Ry/Rz). Intended to be iterated by KitchenDenseEnv.__init__,
# which looks each name up in the sim model to read its range and type.
# NOTE(review): every name must exist in the loaded model -- confirm the
# kettle is modelled with these six separate joints in the XML in use.
OBJ_JNT_NAMES = (
    "knob1_joint",
    "knob2_joint",
    "knob3_joint",
    "knob4_joint",
    "lightswitch_joint",
    "slidedoor_joint",
    "leftdoorhinge",
    "rightdoorhinge",
    "micro0joint",
    "kettle0:Tx",
    "kettle0:Ty",
    "kettle0:Tz",
    "kettle0:Rx",
    "kettle0:Ry",
    "kettle0:Rz",
)

@configurable(pickleable=True)
class KitchenDenseEnv(KitchenSingleTaskEnv):
    """Single-task kitchen environment with a dense, shaped reward.

    On top of ``KitchenSingleTaskEnv`` this class computes a per-joint
    proximity threshold at construction time and uses it in
    ``_compute_reward`` to build a reward from the goal error, two
    proximity bonuses and the hand-to-object approach error.
    """

    # MAX_EPISODE_STEPS = 70
    TASK_NAME = None
    BONUS_THRESH = 0.3  # success when dists["goal"] is below this value

    def __init__(
        self,
        **kwargs
    ):
        """Build the parent env, then cache per-joint proximity thresholds.

        Args:
            **kwargs: forwarded unchanged to ``KitchenSingleTaskEnv``.
        """
        # The parent constructor is run with stdout silenced.
        with SuppressStdout():
            super().__init__(**kwargs)

        # configure env-objs
        # NOTE(review): if the parent constructor steps the environment,
        # `_compute_reward` would run before `self.obj` exists -- confirm.
        self.obj = {}
        obj_dof_ranges = []
        obj_dof_type = []
        for jnt_name in OBJ_JNT_NAMES:
            jnt_id = self.sim.model.joint_name2id(jnt_name)
            obj_dof_ranges.append(self.sim.model.jnt_range[jnt_id])
            # record joint type (later used to determine goal_th)
            obj_dof_type.append(self.sim.model.jnt_type[jnt_id])
        self.obj["dof_proximity"] = self.get_dof_proximity(obj_dof_ranges, obj_dof_type)

    def get_dof_proximity(self, obj_dof_ranges, obj_dof_type):
        """Get proximity of obj joints based on their joint type and ranges.

        For a hinge joint the threshold is ``0.15 * span`` when the span is
        below ``pi``, otherwise ``radians(15)``. For a slide joint it is
        ``0.15 * span`` when the span is below ``1.0``, otherwise ``0.05``.

        Args:
            obj_dof_ranges: per-joint ``(low, high)`` pairs.
            obj_dof_type: per-joint type codes, compared against
                ``self.sim.lib.mjtJoint``.

        Returns:
            A 1-D float array with one threshold per entry of ``obj_dof_type``.

        Raises:
            TypeError: if a joint is neither a hinge nor a slide joint.
        """
        small_angular_th = 0.15
        large_angular_th = np.radians(15)
        small_linear_th = 0.15
        large_linear_th = 0.05

        dof_prox = np.zeros(len(obj_dof_type))

        for i_dof, dof_type in enumerate(obj_dof_type):
            dof_span = obj_dof_ranges[i_dof][1] - obj_dof_ranges[i_dof][0]
            # pick proximity dist based on joint type and scale
            # NOTE(review): relies on the sim exposing `lib.mjtJoint` --
            # confirm for the simulation backend in use.
            if dof_type == self.sim.lib.mjtJoint.mjJNT_HINGE:
                dof_prox[i_dof] = small_angular_th * dof_span if dof_span < np.pi else large_angular_th
            elif dof_type == self.sim.lib.mjtJoint.mjJNT_SLIDE:
                dof_prox[i_dof] = small_linear_th * dof_span if dof_span < 1.0 else large_linear_th
            else:
                raise TypeError("Unsupported Joint Type")
        return dof_prox

    def step(self, action):
        """Advance one step and flatten reward diagnostics into ``env_info``.

        Args:
            action: forwarded unchanged to the parent ``step``.

        Returns:
            ``(obs, reward, done, env_info)``. ``env_info`` gains ``success``,
            ``timeout`` and several ``VIS:*`` scalars, and loses the
            ``obs_dict`` and ``rewards`` entries.
        """
        obs, reward, done, env_info = super().step(action)
        self._t += 1

        rewards = env_info["rewards"]
        success = rewards["success"]
        env_info["success"] = success
        env_info["timeout"] = False
        env_info["VIS:dist_goal"] = rewards["dist_goal"]
        env_info["VIS:dist_hand"] = rewards["dist_hand"]
        env_info["VIS:dist_left_pad_x"] = rewards["dist_left_pad_x"]
        env_info["VIS:dist_right_pad_x"] = rewards["dist_right_pad_x"]
        env_info["VIS:dist_hand_yz"] = rewards["dist_hand_yz"]
        # NOTE(review): this line previously read an undefined local
        # `ctrl_penalty` and raised NameError on every step. The raw penalty
        # is now taken from the reward dict when the parent provides one and
        # reported as 0.0 otherwise -- confirm the intended source.
        env_info["VIS:ctrl_penalty"] = float(rewards.get("ctrl_penalty", 0.0)) * self._control_penalty

        if self._terminate_on_success:
            done |= success
        if self._t >= self.MAX_EPISODE_STEPS:
            # Timeouts are only flagged; they do not end the episode here.
            # done = True
            env_info["timeout"] = True
        if self.initializing:
            done = False

        # remove dictionary from env_info since garage doesn't support it
        del env_info["obs_dict"]
        del env_info["rewards"]
        return obs, reward, done, env_info

    def _get_reward_n_score(self, obs_dict):
        """Extend the parent's reward dict with dense-reward terms.

        Args:
            obs_dict: observation dict; ``"qp"`` and ``"obj_qp"`` are read
                here, and the dict is passed on to ``_compute_reward``.

        Returns:
            ``(reward_dict, score)`` where ``score`` is ``float(success)``.
        """
        # The parent's score is discarded; it is replaced by the bonus below.
        reward_dict, _ = super()._get_reward_n_score(obs_dict)
        next_q_obs = obs_dict["qp"]
        next_obj_obs = obs_dict["obj_qp"]
        # Shift the element indices by len(qp) so they address obj_qp.
        idx_offset = len(next_q_obs)
        next_element = next_obj_obs[..., self._element_idx - idx_offset]
        gripper_pos = self._get_gripper_pos()
        left_pad, right_pad = self._get_pad_pos()
        obj_pos = self._get_obj_pos()
        # Zero the x component to compare positions in the y-z plane only.
        gripper_yz = gripper_pos + np.array([-gripper_pos[0], 0.0, 0.0])
        obj_yz = obj_pos + np.array([-obj_pos[0], 0.0, 0.0])
        dists = {
            "goal": np.linalg.norm(next_element - self._element_goal),
            "goal_init": self._dist_goal_init,
            "hand": np.linalg.norm(
                obj_pos
                # - self._get_hand_pos()
                - gripper_pos
            ),
            "hand_init": self._dist_hand_init,
            "gripper": self._get_gripper_dist(next_q_obs),
            "left_pad_x": left_pad[0] - obj_pos[0],
            "right_pad_x": right_pad[0] - obj_pos[0],
            "hand_yz": np.linalg.norm(gripper_yz - obj_yz, ord=2),
        }
        reward, success = self._compute_reward(obs_dict, dists)
        bonus = float(success)
        if self._sparse_reward:
            reward = bonus
        else:
            if self._terminate_on_success and success:
                reward = 1.0 * self.MAX_EPISODE_STEPS  ### To Do: Check max reward
            reward *= self._reward_scale
        reward_dict["bonus"] = bonus
        reward_dict["dist_goal"] = dists["goal"]
        reward_dict["dist_hand"] = dists["hand"]
        reward_dict["dist_left_pad_x"] = dists["left_pad_x"]
        reward_dict["dist_right_pad_x"] = dists["right_pad_x"]
        reward_dict["dist_hand_yz"] = dists["hand_yz"]
        reward_dict["r_total"] = reward
        reward_dict["success"] = success
        score = bonus
        return reward_dict, score

    def _get_obs(self):
        """Refresh ``self.obs_dict`` from the robot and build the observation.

        Returns:
            The concatenation of ``qp``, ``obj_qp`` and ``goal`` when
            ``self.goal_concat`` is truthy. Otherwise the method falls
            through and returns ``None``.
        """
        ### To Do: Add everything in get_obs_dict
        t, qp, qv, obj_qp, obj_qv = self.robot.get_obs(
            self, robot_noise_ratio=self.robot_noise_ratio)

        self.obs_dict = {}
        self.obs_dict['t'] = t
        self.obs_dict['qp'] = qp
        self.obs_dict['qv'] = qv
        self.obs_dict['obj_qp'] = obj_qp
        self.obs_dict['obj_qv'] = obj_qv
        self.obs_dict['goal'] = self.goal

        ### For dense reward calculation
        self.obs_dict['goal_err'] = (
            self.goal - obj_qp[..., self._element_idx - len(qp)]
        )  # mix of translational and rotational errors
        self.obs_dict['approach_err'] = (
            self._get_obj_pos()
            - self._get_hand_pos()
        )

        # NOTE(review): there is no return for the goal_concat=False case, so
        # callers get None -- confirm whether a fallback is wanted.
        if self.goal_concat:
            return np.concatenate([self.obs_dict['qp'], self.obs_dict['obj_qp'], self.obs_dict['goal']])

    def _compute_reward(self, obs_dict, dists):
        """Compute the dense reward and the success flag.

        The reward is ``rew_goal + 0.5 * rew_bonus + 0.5 * rew_approach``:
        the negated sum of absolute goal errors, a bonus of up to 2.0 that
        adds 1.0 for each of two proximity levels (5x and 1.67x the
        per-joint threshold) reached by every goal error, and the negated
        norm of the approach error.

        Args:
            obs_dict: must contain ``"goal_err"`` and ``"approach_err"``.
            dists: must contain ``"goal"``.

        Returns:
            ``(reward, success)`` where ``success`` is
            ``dists["goal"] < self.BONUS_THRESH``.
        """
        goal_dist = np.abs(obs_dict["goal_err"])
        rew_goal = -np.sum(goal_dist, axis=-1)
        # NOTE(review): goal_dist must broadcast against the per-joint
        # thresholds in self.obj["dof_proximity"] -- confirm the shapes agree.
        dof_proximity = self.obj["dof_proximity"]
        # np.prod replaces np.product, which was removed in NumPy 2.0.
        rew_bonus = (
            1.0 * np.prod(goal_dist < 5 * dof_proximity, axis=-1)
            + 1.0 * np.prod(goal_dist < 1.67 * dof_proximity, axis=-1)
        )
        rew_approach = -np.linalg.norm(obs_dict["approach_err"], axis=-1)
        reward = rew_goal + 0.5 * rew_bonus + 0.5 * rew_approach

        success = dists["goal"] < self.BONUS_THRESH

        return reward, success
