from environments.kitchen import reward_utils
from environments.kitchen.kitchen_base import KitchenSingleTaskEnv


class KitchenBottomBurnerOnEnvV0(KitchenSingleTaskEnv):
    """Single-task kitchen environment for the "bottom burner-on" task.

    Class attributes:
        TASK_NAME: Task identifier string for this environment.
        BONUS_THRESH: Goal-distance threshold; once ``dists["goal"]`` drops
            below it, the reward is pinned to its maximum of 0.0.
    """

    TASK_NAME = "bottom burner-on"
    BONUS_THRESH = 0.3

    @classmethod
    def _compute_reward(cls, obs_dict, dists):
        """Return the shaped reward for the current step.

        The reward is the Hamacher product of a goal-progress term and a
        hand-reach term, shifted down by 1.0 so that it is non-positive
        (noted by the original author as a [-1, 0] reward). When the goal
        distance is below ``BONUS_THRESH`` the reward is exactly 0.0.

        Design history: an earlier variant used the weighted sum
        ``0.6 * in_place + 0.4 * reach`` with a bonus reward of 1.0, and
        gating the reach term by gripper closure
        (``1.0 - dists["gripper"]``) was tried and left disabled.

        Args:
            obs_dict: Observation dictionary; not read by this method.
            dists: Mapping that must provide the keys ``"goal"``,
                ``"goal_init"``, ``"hand"`` and ``"hand_init"``.
                Presumably current and initial distances -- confirm
                against the base class that builds it.

        Returns:
            ``0.0`` if ``dists["goal"] < BONUS_THRESH``; otherwise
            ``hamacher_product(in_place, reach) - 1.0``.
        """
        # Goal-progress term: per the original notes, a value in [0, 1]
        # where larger is better (smaller distance to the goal).
        goal_dist = dists["goal"]
        goal_margin = abs(dists["goal_init"] - cls.BONUS_THRESH)
        goal_term = reward_utils.tolerance(
            goal_dist,
            bounds=(0, cls.BONUS_THRESH),
            margin=goal_margin,
            sigmoid="long_tail",
        )

        # Hand-reach term: same [0, 1] convention, measured against a
        # fixed radius around the handle.
        reach_radius = 0.08
        hand_margin = abs(dists["hand_init"] - reach_radius)
        hand_term = reward_utils.tolerance(
            dists["hand"],
            bounds=(0, reach_radius),
            margin=hand_margin,
            sigmoid="gaussian",
        )

        # Combine both terms, then shift so the best possible value is 0.
        shaped = reward_utils.hamacher_product(goal_term, hand_term) - 1.0

        # The shaped value is always computed first (so any helper error
        # still surfaces); the bonus region then overrides it with 0.0.
        return 0.0 if goal_dist < cls.BONUS_THRESH else shaped
