from typing import Dict

import torch

from furniture_bench_api.skills.robot_skill import RobotSkill


class SetGripperStateSkill(RobotSkill):
    """Skill that opens or closes the gripper and detects when that is done.

    The requested state is read from ``inputs["gripper_percent_open"]``: a value
    above 0.5 requests an open gripper, anything else requests a closed one.
    The gripper reading is taken from ``current_pose[7]``.
    """

    # Inputs strictly above this value request an open gripper; otherwise close.
    _OPEN_THRESHOLD = 0.5
    # Gripper reading (current_pose[7]) that counts as fully open.
    _OPEN_GOAL_WIDTH = 0.07
    # Absolute tolerance around _OPEN_GOAL_WIDTH for the open goal.
    _OPEN_GOAL_ATOL = 0.0005
    # Number of consecutive readings used to decide the closing gripper has stopped.
    _STALL_WINDOW = 5
    # Max change in the reading across the window for the gripper to count as stopped.
    _STALL_EPS = 0.01

    def __init__(self, tolerance, sampling_interval):
        super().__init__(tolerance, sampling_interval)
        # Most recent gripper readings (plain floats) collected while closing.
        self.last_5_values = []

    def _wants_open(self, gripper_open: torch.Tensor) -> bool:
        """Return True when the requested gripper state is "open"."""
        return bool(gripper_open > self._OPEN_THRESHOLD)

    def map_to_gripper_action(self, gripper_open: torch.Tensor) -> torch.Tensor:
        """Map a requested open fraction to a 1-element gripper value.

        Returns ``[0.0]`` when opening is requested and ``[0.7]`` otherwise, as a
        float32 tensor on ``gripper_open``'s device. Not used by the other methods
        of this class; NOTE(review): presumably an alternative action convention
        for some caller — confirm before relying on it.
        """
        value = 0.0 if self._wants_open(gripper_open) else 0.7
        return torch.as_tensor([value], dtype=torch.float32, device=gripper_open.device)

    def compute_action(self, current_pose, inputs):
        """Return the base action with the gripper command (index 7) filled in.

        Index 7 is set to -1 when opening is requested and to +1 otherwise; all
        other entries come unchanged from ``self.default_action()``, moved to the
        device of ``inputs["gripper_percent_open"]``.
        """
        gripper_open = inputs["gripper_percent_open"]
        action = self.default_action().to(gripper_open.device)
        action[7] = -1.0 if self._wants_open(gripper_open) else 1.0
        return action

    def is_goal_reached(self, current_pose, inputs):
        """Return True (a plain bool) once the requested gripper state is reached.

        Opening: the reading ``current_pose[7]`` must be within 0.0005 of 0.07.
        Closing: the gripper may stop early on a grasped object, so the goal is
        "the reading has stopped changing" — the last 5 readings differ by less
        than 0.01 end to end.
        """
        gripper_open = inputs["gripper_percent_open"]
        if self._wants_open(gripper_open):
            # Forget any closing history so stale readings cannot satisfy a later close.
            self.last_5_values.clear()
            width = torch.as_tensor(current_pose[7])
            # Build the goal with the reading's own dtype/device: allclose rejects
            # mismatched dtypes.
            goal_state = torch.full_like(width, self._OPEN_GOAL_WIDTH)
            return bool(torch.allclose(width, goal_state, atol=self._OPEN_GOAL_ATOL))

        # Store a plain float: indexing a tensor yields a view, so if the caller
        # updates the pose tensor in place, stored entries would alias the newest value.
        self.last_5_values.append(float(current_pose[7]))
        del self.last_5_values[:-self._STALL_WINDOW]
        if len(self.last_5_values) < self._STALL_WINDOW:
            return False
        # abs(): a gripper still travelling open is moving and must not count as stopped.
        return abs(self.last_5_values[0] - self.last_5_values[-1]) < self._STALL_EPS
