from manipulator_learning.sim.utils.general import *
from manipulator_learning.sim.utils import transformations as tf


class Gripper(object):
    """
    Store the state of a gripper and provide an interface for updating it.

    Either modify_pose (direct update) or update_state (update from internal state variables)
    should be called before stepSimulation every iteration. The actual pose can be found using
    get_pose(), whereas instance.pose stores the desired pose (as a 4x4 homogeneous
    transformation matrix in the global frame).

    Note that if position based control ('p') is being used, a given position is considered an
    UPDATE to the current pose (applied in the gripper frame), NOT an actual pose in the global
    frame.
    """
    def __init__(self, pybullet_client, body, cid, max_trans_acc, max_rot_acc, max_trans_vel,
                 max_rot_vel, time_step, t_control_style, r_control_style,
                 g_control_style, g_joint_maximums, g_close_time, g_open_time, init_pose=np.eye(4)):
        """
        Args:
            pybullet_client: pybullet client/module used for state queries.
            body: pybullet body id of the gripper.
            cid: pybullet constraint id (stored for the caller; not used in this class).
            max_trans_acc, max_rot_acc: acceleration limits (m/s^2, rad/s^2).
            max_trans_vel, max_rot_vel: velocity limits (m/s, rad/s).
            time_step: simulation time step in seconds.
            t_control_style, r_control_style: 'p', 'v' or 'a' (position/velocity/acceleration).
            g_control_style: 'p' (direct joint position) or 'b' (binary open/close command).
            g_joint_maximums: two (open_pos, closed_pos) pairs, one per gripper joint.
            g_close_time, g_open_time: seconds to fully close / open in binary mode.
            init_pose: initial 4x4 desired pose; copied so the caller's array (and the shared
                default) is never aliased by self.pose.
        """
        # user selected parameters -- non-private can be modified on the fly
        self._pb_client = pybullet_client
        self.body = body
        self.cid = cid
        self.max_trans_acc = max_trans_acc  # m/s^2
        self.max_rot_acc = max_rot_acc  # rad/s^2
        self.max_trans_vel = max_trans_vel  # m/s
        self.max_rot_vel = max_rot_vel  # rad/s
        self.time_step = time_step
        self.t_control_style = t_control_style
        self.r_control_style = r_control_style  # position good for mice, v/a good for joystick
        self.g_control_style = g_control_style
        self.g_maximums = g_joint_maximums
        # signed travel of each joint from open (index 0) to closed (index 1)
        self.g_range = (self.g_maximums[0][1] - self.g_maximums[0][0],
                        self.g_maximums[1][1] - self.g_maximums[1][0])
        self.g_close_time = g_close_time  # in seconds
        self.g_open_time = g_open_time  # in seconds

        # state variables -- all vel in m/s or rad/s, all acc in m/s^2 or rad/s^2
        # many of these are intended to be directly set by the user
        # copy so in-place edits of self.pose never leak into the default argument / caller array
        self.pose = np.array(init_pose, dtype=float)  # 4x4 transformation matrix
        self.trans_vel = np.zeros([3])  # CAN be directly modified, but risky since no limits enforced
        self.rot_vel = np.zeros([3])  # CAN be directly modified, but risky since no limits enforced
        self.trans_acc = np.zeros([3])  # directly modified
        self.rot_acc = np.zeros([3])  # directly modified
        self.g_command = False  # binary control only: False for open, True for closed; set by user
        self.g_joint_pos = 0  # 0 (open) .. 1 (closed); set directly by user in 'p' g_control_style
        self._cur_actual_vel = np.zeros([2, 3])  # for getting acceleration
        self._prev_actual_vel = np.zeros([2, 3])  # for getting acceleration

    @property
    def t_control_style(self):
        return self.__t_control_style

    @t_control_style.setter
    def t_control_style(self, style):
        control_opts = ['p', 'v', 'a']
        if style not in control_opts:
            raise ValueError("t_control_style must be one of: %s" % control_opts)
        self.__t_control_style = style

    @property
    def r_control_style(self):
        return self.__r_control_style

    @r_control_style.setter
    def r_control_style(self, style):
        control_opts = ['p', 'v', 'a']
        if style not in control_opts:
            raise ValueError("r_control_style must be one of: %s" % control_opts)
        self.__r_control_style = style

    @property
    def g_control_style(self):
        return self.__g_control_style

    @g_control_style.setter
    def g_control_style(self, style):
        control_opts = ['p', 'b']
        if style not in control_opts:
            raise ValueError("g_control_style must be one of: %s" % control_opts)
        self.__g_control_style = style

    def get_des_pose_for_pybullet(self):
        """ Get translation and quaternion (pybullet order) for updating the constraint in pybullet. """
        return self.pose[:3, 3], convert_quat_tf_to_pb(tf.quaternion_from_matrix(self.pose))

    def get_gripper_joints_for_pybullet(self):
        """ Map the normalized g_joint_pos (0 open .. 1 closed) to raw positions of both joints. """
        return (self.g_maximums[0][0] + self.g_range[0] * self.g_joint_pos,
                self.g_maximums[1][0] + self.g_range[1] * self.g_joint_pos)

    def modify_pose(self, trans, quat):
        """
        Directly update the desired pose of the gripper.

        The transform built from (trans, quat) is right-multiplied onto self.pose, i.e. the
        modification is expressed in the gripper's own frame. Note that quat is expected in
        wxyz (aka tf) form.
        """
        t_modify = tf.quaternion_matrix(quat)
        t_modify[:3, 3] = trans
        self.pose = np.dot(self.pose, t_modify)  # modification done in frame of gripper

    def update_state(self, trans=None, quat=None, joint_pos=None):
        """
        Update the desired pose and gripper opening from the internal state variables.

        The user supplies trans as [x, y, z] (required when t_control_style is 'p'), quat as
        [w, x, y, z] (required when r_control_style is 'p') and joint_pos in [0, 1] (required
        when g_control_style is 'p'). Acceleration and velocity limits are enforced for the
        'v'/'a' styles; in 'p' style the given increment is applied as-is.

        Should be called every iteration before the call to changeConstraint (and the same
        number of times as stepSimulation).

        Raises:
            ValueError: if an argument required by the active control style is missing.
        """
        # validate first so a missing argument cannot leave the state half-updated
        if self.t_control_style == 'p' and trans is None:
            raise ValueError("trans must be provided when t_control_style is 'p'")
        if self.r_control_style == 'p' and quat is None:
            raise ValueError("quat must be provided when r_control_style is 'p'")
        if self.g_control_style == 'p' and joint_pos is None:
            raise ValueError("joint_pos must be provided when g_control_style is 'p'")

        # enforce acceleration limits (each with its own limit)
        self.trans_acc = np.clip(self.trans_acc, -self.max_trans_acc, self.max_trans_acc)
        self.rot_acc = np.clip(self.rot_acc, -self.max_rot_acc, self.max_rot_acc)

        # translational velocity
        if self.t_control_style == 'a':
            self.trans_vel = self.trans_vel + self.trans_acc
        elif self.t_control_style == 'p':
            trans = np.asarray(trans, dtype=float)
            # in p control trans is an increment, so the implied velocity is increment / dt
            self.trans_vel = trans / self.time_step
            self.trans_acc = np.zeros([3])
        self.trans_vel = np.clip(self.trans_vel, -self.max_trans_vel, self.max_trans_vel)

        # rotational velocity (acceleration applied exactly once per step)
        if self.r_control_style == 'a':
            self.rot_vel = self.rot_vel + self.rot_acc
        elif self.r_control_style == 'p':
            eul = np.array(tf.euler_from_quaternion(quat, axes='sxyz'))
            self.rot_vel = eul / self.time_step
            self.rot_acc = np.zeros([3])
        self.rot_vel = np.clip(self.rot_vel, -self.max_rot_vel, self.max_rot_vel)

        # make updates based on speed/acc in timesteps, not in seconds
        if self.t_control_style == 'p':
            des_trans = trans
        else:
            des_trans = self.time_step * self.trans_vel

        if self.r_control_style == 'p':
            des_rot = quat
        else:
            des_rot_eul = self.time_step * self.rot_vel
            des_rot = tf.quaternion_from_euler(des_rot_eul[0], des_rot_eul[1],
                                               des_rot_eul[2], axes='sxyz')

        self.modify_pose(des_trans, des_rot)

        # update the gripper joint positions
        if self.g_control_style == 'b':
            if self.g_command:
                self.g_joint_pos += self.time_step / self.g_close_time
            else:
                self.g_joint_pos -= self.time_step / self.g_open_time
        elif self.g_control_style == 'p':
            self.g_joint_pos = joint_pos

        # ensure it never goes further than fully open or closed
        if self.g_joint_pos > 1:
            self.g_joint_pos = 1
        elif self.g_joint_pos < 0:
            self.g_joint_pos = 0

        # keep the previous measured velocity for finite-difference acceleration in get_acc
        self._prev_actual_vel = self._cur_actual_vel.copy()
        self._cur_actual_vel = np.array(self.get_vel())

    def get_pose(self):
        """
        Using pybullet frame info, grab the current global pose.

        Returns a 2-element object array: [position (3 values), orientation quaternion in
        pybullet order (4 values)]. Built explicitly because np.array on the ragged
        (3, 4) pair raises ValueError on NumPy >= 1.24.
        """
        pos, orn = self._pb_client.getBasePositionAndOrientation(self.body)
        pose = np.empty(2, dtype=object)
        pose[0] = pos
        pose[1] = orn
        return pose

    def get_vel(self):
        """ Using pybullet frame info, grab the current trans and rot velocity in frame of gripper.

        Returns a (2, 3) array: row 0 linear velocity, row 1 angular velocity.
        """
        pose = self.get_pose()
        world_r_gripper = tf.quaternion_matrix(convert_quat_pb_to_tf(pose[1]))[:3, :3]
        gripper_r_world = world_r_gripper.T  # inverse of a rotation matrix is its transpose
        lin_vel, ang_vel = self._pb_client.getBaseVelocity(self.body)
        return np.array([gripper_r_world.dot(np.asarray(lin_vel, dtype=float)),
                         gripper_r_world.dot(np.asarray(ang_vel, dtype=float))])

    def get_acc(self):
        """ Finite-difference acceleration in the gripper frame, against the velocity stored
        one update_state call earlier. """
        current_vel = np.array(self.get_vel())
        return (current_vel - self._prev_actual_vel) / self.time_step

    def get_g_joint_pos(self, grip_joint_ind):
        """ Using pybullet info, get the actual gripper joint positions normalized to 0 (open) .. 1 (closed).

        grip_joint_ind should be a list of joint indices to output; the i-th index is normalized
        with the i-th entry of g_joint_maximums. """
        joint_pos = []
        for i, joint_ind in enumerate(grip_joint_ind):
            raw_pos = self._pb_client.getJointState(self.body, joint_ind)[0]
            joint_pos.append((raw_pos - self.g_maximums[i][0]) / self.g_range[i])
        return np.array(joint_pos)
