from manipulator_learning.sim.robots.robots import Gripper
from manipulator_learning.sim.utils import transformations as tf
from manipulator_learning.sim.utils.general import *


class FloatingGripper(Gripper):
    """
    The base class for a floating gripper in pybullet. State space includes pose, velocity, and
    acceleration of the actual gripper frame and the current position of the gripper joints. The action
    space, depending on control style, can be acceleration, velocity, or position of the gripper, and
    either a binary open or closed signal to the gripper, or a specific gripper joint position.

    The gripper body is held in place by a fixed pybullet constraint (created in ``reset``) whose target
    pose is moved on every ``apply_action`` call, so the gripper "floats" without an attached arm.
    """
    GRIPPER_OPTS = ['kuka', 'pr2']

    def __init__(self, pybullet_client, urdf_root, time_step=0.01, t_control_style='a', r_control_style='p',
                 g_control_style='b', max_trans_acc=.07, max_trans_vel=.7, max_rot_vel=10, gripper='kuka'):
        """
        :param pybullet_client: pybullet client/module used for every simulation call.
        :param urdf_root: Directory containing ``gripper_only.sdf`` (only used by the 'kuka' gripper).
        :param time_step: Control time step forwarded to the Gripper base class.
        :param t_control_style: Translation control style: 'a' (acceleration), 'v' (velocity) or 'p' (position).
        :param r_control_style: Rotation control style: 'a', 'v' or 'p'.
        :param g_control_style: Gripper control style: 'b' (binary open/close) or 'p' (joint position).
        :param max_trans_acc: Maximum translational acceleration forwarded to the base class.
        :param max_trans_vel: Maximum translational velocity forwarded to the base class.
        :param max_rot_vel: Maximum rotational velocity forwarded to the base class.
        :param gripper: Gripper model name, one of ``GRIPPER_OPTS``.
        :raises NotImplementedError: If ``gripper`` is not one of ``GRIPPER_OPTS``.
        """
        # Validate before setting anything else: otherwise g_maximums/g_force/_g_joint_ind stay unset
        # and the failure only surfaces later from inside reset().
        if gripper not in FloatingGripper.GRIPPER_OPTS:
            raise NotImplementedError('Gripper type must be one of: %s' % FloatingGripper.GRIPPER_OPTS)

        self._pb_client = pybullet_client
        self.urdf_root = urdf_root
        self.time_step = time_step
        self._step_counter = 0
        self.t_control_style = t_control_style
        self.r_control_style = r_control_style
        self.g_control_style = g_control_style
        self.gripper = gripper

        if self.gripper == 'kuka':
            self.g_maximums = [[1.3, 1.5], [1.7, 1.5]]  # fully open is number on left, fully closed on right
            self.g_force = 10
            self._g_joint_ind = [0, 3]
        elif self.gripper == 'pr2':
            self.g_maximums = [[.550569, 0], [.550569, 0]]
            self.g_force = 5
            self._g_joint_ind = []  # todo implement this if necessary
        self.constraint_id = None  # set in reset
        self.body_id = None  # set in reset
        self.gripper_joints = None  # set in reset
        self.reset(reload_urdf=True)  # ***this needs to be called to set some instance variables (incl. self.pose)

        #####################################################################################
        ##### This is where acceleration and velocity maximums can be set and modified ######
        #####################################################################################
        super().__init__(pybullet_client=self._pb_client, body=self.body_id,
                         cid=self.constraint_id, max_trans_acc=max_trans_acc,
                         max_rot_acc=25, max_trans_vel=max_trans_vel, max_rot_vel=max_rot_vel, time_step=time_step,
                         t_control_style=self.t_control_style,
                         r_control_style=self.r_control_style,
                         g_control_style=self.g_control_style,
                         g_joint_maximums=self.g_maximums, g_close_time=.1,
                         g_open_time=.05, init_pose=self.pose)

    def step(self, action):
        """
        Apply an action and advance the simulation by one step, as done in a gym style environment.

        :param action: Sequence ``(t_command, r_command, g_command)``; see ``apply_action`` for the
                       meaning of each element under the different control styles.
        """
        self.apply_action(action[0], action[1], action[2])
        self._pb_client.stepSimulation()
        # The returned observation is discarded here; callers can query receive_observation() themselves.
        self.receive_observation()
        self._step_counter += 1

    def apply_action(self, t_command, r_command, g_command):
        """
        Set the desired translational, rotational and gripper commands.
        Functionality depends on control styles of gripper.

        For acceleration (a) control, settings should be xyz floats for t and r
        For velocity (v) control, settings should be xyz floats for t and r
        For position (p) control, settings should be xyz float for t and wxyz quat for r,
            as well as float from 0 (open) to 1 (closed) for g_command
        For binary (b) control for the gripper joint positions (g_command), setting should be
            False (open) or True (closed)
        """
        # Position-style targets are handed to update_state; they stay None for other control styles.
        p_trans = None
        p_quat = None
        p_joint_pos = None

        if self.t_control_style == 'a':
            self.trans_acc = t_command
        elif self.t_control_style == 'v':
            self.trans_acc = 0
            self.trans_vel = t_command
        elif self.t_control_style == 'p':
            self.trans_acc = 0
            # trans_vel is not set here; presumably update_state derives it for position control
            p_trans = t_command

        if self.r_control_style == 'a':
            self.rot_acc = r_command
        elif self.r_control_style == 'v':
            self.rot_acc = 0
            self.rot_vel = r_command
        elif self.r_control_style == 'p':
            self.rot_acc = 0
            # rot_vel is not set here; presumably update_state derives it for position control
            p_quat = r_command

        # The gripper control styles are mutually exclusive.
        if self.g_control_style == 'b':
            self.g_command = g_command
        elif self.g_control_style == 'p':
            p_joint_pos = g_command

        # use Gripper method to actually update the state
        self.update_state(p_trans, p_quat, p_joint_pos)

        # update actual gripper pose in pybullet by moving the fixed constraint's target
        pose = self.get_des_pose_for_pybullet()
        self._pb_client.changeConstraint(self.constraint_id, pose[0], pose[1], maxForce=50)

        # update gripper joint poses in pybullet
        g_pos = self.get_gripper_joints_for_pybullet()
        for joint, target in zip(self.gripper_joints, g_pos):
            self._pb_client.setJointMotorControl2(self.body_id, joint,
                                                  controlMode=self._pb_client.POSITION_CONTROL,
                                                  targetPosition=target, force=self.g_force)

    def receive_observation(self):
        """
        Get the current data from the gripper (pose, velocity, acceleration, gripper joint positions).

        Note: for the 'pr2' gripper ``_g_joint_ind`` is currently empty, so the joint-position element
        is queried with no indices.
        """
        return self.get_pose(), self.get_vel(), self.get_acc(), \
               self.get_g_joint_pos(self._g_joint_ind)

    def receive_action(self):
        """
        Get the current desired data from the gripper (pose, velocity, acceleration, gripper command).
        Many of these values might be zero depending on the control method selected.
        """
        return self.pose, [self.trans_vel, self.rot_vel], [self.trans_acc, self.rot_acc], \
                self.g_command

    def reset(self, reload_urdf, init_pose=None):
        """
        Reset the gripper inside the simulation to any pose.

        :param reload_urdf: Whether or not to completely reload the urdf (and recreate the constraint).
                            Must be True on the first call, since body/constraint ids start as None.
        :param init_pose: The desired initial pose of the gripper, as a 3-float translation vector followed by a
                          4-float [xyzw] quaternion. Defaults to ([0, 0, 1], [0, 0, 0, 1]).
        :raises NotImplementedError: If reloading and ``self.gripper`` is not one of ``GRIPPER_OPTS``.
        """
        # default argument
        if init_pose is None:
            init_pose = ([0, 0, 1], [0, 0, 0, 1])

        if reload_urdf:
            if self.gripper == 'kuka':
                self.body_id = self._pb_client.loadSDF(self.urdf_root + "/gripper_only.sdf")[0]
                # parentFramePosition must be a 3-vector (it was previously passed as a 4-element list,
                # which pybullet ignores in release builds and asserts on in debug builds).
                self.constraint_id = self._pb_client.createConstraint(self.body_id, -1, -1, -1,
                                                                      self._pb_client.JOINT_FIXED,
                                                                      [0, 0, 0], [0, 0, 0], [0, 0, 1])
                self.gripper_joints = [0, 3]  # joints that actually grip
            elif self.gripper == 'pr2':
                # loaded by bare filename, i.e. resolved via pybullet's search path rather than urdf_root
                self.body_id = self._pb_client.loadURDF("pr2_gripper.urdf", 0.500000, 0.300006, 0.700000,
                                                        -0.000000, -0.000000, -0.000031, 1.000000)
                self.constraint_id = self._pb_client.createConstraint(self.body_id, -1, -1, -1,
                                                                      self._pb_client.JOINT_FIXED,
                                                                      [0, 0, 0], [0.2, 0, 0],
                                                                      [0.500000, 0.300006, 0.700000])
                self.gripper_joints = [0, 2]
            else:
                raise NotImplementedError('Gripper type must be one of: %s' % FloatingGripper.GRIPPER_OPTS)

        # set body pose (pybullet xyzw quaternion), and mirror it in self.pose as a 4x4 homogeneous matrix
        self._pb_client.resetBasePositionAndOrientation(self.body_id, init_pose[0], init_pose[1])
        self.pose = tf.quaternion_matrix(convert_quat_pb_to_tf(init_pose[1]))
        self.pose[:3, 3] = init_pose[0]

        if self.gripper == 'kuka':
            # set the outer joints to base position
            self._pb_client.setJointMotorControl2(self.body_id, 2, controlMode=self._pb_client.POSITION_CONTROL,
                                                  targetPosition=0, force=10)
            self._pb_client.setJointMotorControl2(self.body_id, 5, controlMode=self._pb_client.POSITION_CONTROL,
                                                  targetPosition=0, force=10)

        # set gripper positions to fully open (left value of each g_maximums pair). No force is given, so
        # pybullet's default motor force applies until apply_action sets force=self.g_force.
        for joint, (open_pos, _closed_pos) in zip(self.gripper_joints, self.g_maximums):
            self._pb_client.setJointMotorControl2(self.body_id, joint,
                                                  controlMode=self._pb_client.POSITION_CONTROL,
                                                  targetPosition=open_pos)
