import rospy
from pick_and_place_module.eef_control import MoveGroupControl
from pick_and_place_module.grasping import GripperInterface
from copy import deepcopy
from math import pi
from tf.transformations import euler_from_quaternion, quaternion_from_euler

def add_lists(list1, list2):
    """Return the element-wise sum of two sequences.

    Pairs elements positionally with ``zip``, so the result is as long as the
    shorter input.
    """
    summed = []
    for left, right in zip(list1, list2):
        summed.append(left + right)
    return summed

def subtract_lists(list1, list2):
    """Return ``list1 - list2`` element-wise.

    Pairs elements positionally with ``zip``, so the result is as long as the
    shorter input.
    """
    differences = []
    for minuend, subtrahend in zip(list1, list2):
        differences.append(minuend - subtrahend)
    return differences

class PickAndPlace:
    """Pick-and-place sequencer built on MoveGroupControl and GripperInterface.

    Poses are 6-element lists ``[x, y, z, roll, pitch, yaw]``. setPickPose and
    setDropPose add ``pi/4`` to the roll element before storing the pose.
    """

    def __init__(self, gripper_offset, intermediate_z_stop, speed):
        """Store motion parameters and create the arm and gripper interfaces.

        Args:
            gripper_offset: added to the target z for the final pick waypoint.
            intermediate_z_stop: z value used for the intermediate (transit)
                waypoints.
            speed: forwarded as ``MoveGroupControl(speed=speed)``.
        """
        self.gripper_offset = gripper_offset
        self.intermediate_z_stop = intermediate_z_stop
        self.pose0 = None  # pick pose, set by setPickPose()
        self.pose1 = None  # drop pose, set by setDropPose()
        self.gripper_pose = None  # two values passed to gripper.grasp() when picking
        self.gripper_force = None  # NOTE(review): not read anywhere in this class
        self.moveit_control = MoveGroupControl(speed=speed)
        self.gripper = GripperInterface()

    def setPickPose(self, x, y, z, roll, pitch, yaw):
        """Set the pick pose; ``pi/4`` is added to ``roll``."""
        self.pose0 = [x, y, z, roll + pi/4, pitch, yaw]

    def setDropPose(self, x, y, z, roll, pitch, yaw):
        """Set the drop pose; ``pi/4`` is added to ``roll``."""
        self.pose1 = [x, y, z, roll + pi/4, pitch, yaw]

    def setGripperPose(self, finger1, finger2):
        """Set the two arguments passed to ``gripper.grasp()`` by the cartesian pick methods."""
        self.gripper_pose = [finger1, finger2]

    def _require(self, *attr_names):
        """Raise ValueError naming every listed attribute that is still None.

        Called before any motion, so a missing pose fails fast. Without it, the
        sequence would abort midway, e.g. after the arm has descended or the
        gripper has closed.
        """
        missing = [name for name in attr_names if getattr(self, name) is None]
        if missing:
            raise ValueError(f"PickAndPlace: {', '.join(missing)} not set")

    def _go_through(self, waypoints):
        """Log each waypoint and move to it with ``go_to_pose_goal``."""
        for waypoint in waypoints:
            rospy.loginfo("Executing waypoint: %s", waypoint)
            self.moveit_control.go_to_pose_goal(*waypoint[:6])

    def _follow_cartesian(self, waypoints):
        """Log each waypoint and move to it as a one-point cartesian path."""
        for waypoint in waypoints:
            rospy.loginfo("Executing waypoint: %s", waypoint)
            self.moveit_control.follow_cartesian_path([waypoint])

    def generate_waypoints(self, destination_pose, action, axis=0):  # @
        '''
        Generated waypoints are for a particular application
        This is to be changed based on the application it is being used

        Args:
            destination_pose: 6-element pose list; it is copied, never modified.
            action: falsy -> pick sequence, truthy -> drop sequence.
            axis: accepted so callers may pass ``axis=...`` (as
                execute_pick_and_place does); it does not affect the result.

        Returns:
            Pick: [destination at intermediate_z_stop,
                   destination with z + 0.1,
                   destination with z + gripper_offset].
            Drop: [current x/y at intermediate_z_stop,
                   destination at intermediate_z_stop,
                   destination_pose itself (not a copy)].
            Every waypoint keeps destination_pose's orientation.
        '''
        waypoints = []

        if not action:
            above = deepcopy(destination_pose)
            above[2] = self.intermediate_z_stop
            waypoints.append(above)

            near = deepcopy(destination_pose)
            near[2] = near[2] + 0.1
            waypoints.append(near)

            grasp = deepcopy(destination_pose)
            grasp[2] = grasp[2] + self.gripper_offset
            waypoints.append(grasp)
        else:
            # First go to intermediate_z_stop at the current x/y, then
            # translate at that height before descending to the destination.
            current_pose = self.moveit_control.get_current_pose()
            lift = deepcopy(destination_pose)
            lift[0] = current_pose.position.x
            lift[1] = current_pose.position.y
            lift[2] = self.intermediate_z_stop
            waypoints.append(lift)

            above = deepcopy(destination_pose)
            above[2] = self.intermediate_z_stop
            waypoints.append(above)
            waypoints.append(destination_pose)
        return waypoints

    def execute_pick_and_place(self, gripper_force=5, axis=0):
        """Open the gripper, pick at pose0, drop at pose1, then go to the ready pose.

        Args:
            gripper_force: second argument of the closing ``grasp(0.01, ...)`` call.
            axis: forwarded to generate_waypoints(), which ignores it.

        Raises:
            ValueError: if pose0 or pose1 has not been set.
        """
        self._require("pose0", "pose1")
        self.gripper.grasp(0.1, 0)
        rospy.sleep(2)

        self._go_through(self.generate_waypoints(self.pose0, 0, axis=axis))
        self.gripper.grasp(0.01, gripper_force)
        rospy.sleep(1)

        self._go_through(self.generate_waypoints(self.pose1, 1, axis=axis))
        self.gripper.grasp(0.1, 0)
        rospy.sleep(2)

        self.go_to_ready_pose()

    def execute_cartesian_pick_and_place(self):
        """Pick at pose0 and drop at pose1, moving through one-point cartesian paths.

        Raises:
            ValueError: if pose0, pose1 or gripper_pose has not been set.
        """
        self._require("pose0", "pose1", "gripper_pose")

        waypoints = self.generate_waypoints(self.pose0, 0)
        rospy.loginfo("Generated waypoints for pick: %s", waypoints)
        self._follow_cartesian(waypoints)

        self.gripper.grasp(self.gripper_pose[0], self.gripper_pose[1])
        rospy.sleep(2)

        waypoints = self.generate_waypoints(self.pose1, 1)
        rospy.loginfo("Generated waypoints for drop: %s", waypoints)
        self._follow_cartesian(waypoints)

        self.gripper.grasp(0.05, 0.05)
        rospy.sleep(2)

    def go_to_ready_pose(self):
        """Move to a fixed 7-joint configuration, wait 2 s and print a notice."""
        self.moveit_control.go_to_joint_state(0.0, -0.785, 0.0, -2.356, 0.0, 1.571, 0.785)
        rospy.sleep(2)
        print("go to ready pose")

    def execute_cartesian_pick_up(self):
        """Grasp at pose0, then retreat to intermediate_z_stop above pose0's x/y.

        Raises:
            ValueError: if pose0 or gripper_pose has not been set.
        """
        self._require("pose0", "gripper_pose")
        self.gripper.grasp(0.05, 0.05)
        rospy.sleep(2)

        waypoints = self.generate_waypoints(self.pose0, 0)
        rospy.loginfo("Generated waypoints for pick: %s", waypoints)
        self._follow_cartesian(waypoints)

        self.gripper.grasp(self.gripper_pose[0], self.gripper_pose[1])
        rospy.sleep(1)

        # Retreat waypoint: pose0's x/y and orientation at the intermediate height.
        retreat = deepcopy(self.pose0)
        retreat[2] = self.intermediate_z_stop
        self._follow_cartesian([retreat])

        rospy.sleep(1)


import numpy as np

class PrimitiveSkill:
    """Parameterised manipulation primitives executed through MoveGroupControl.

    Poses are 6-element lists ``[x, y, z, roll, pitch, yaw]``. setPose0,
    setPose1 and setTargetPose add ``pi/4`` to the roll element before
    storing the pose. ``axis`` arguments are indices into that list
    (0 = x, 1 = y, 2 = z).
    """

    def __init__(self, gripper_offset=0.05, intermediate_z_stop=0.3, intermediate_distance=0.2, speed=0.13,
                 push_length=0.02, pull_length=0.01, sweep_count=3, sweep_width=0.03):
        """Store skill parameters and create the arm and gripper interfaces.

        Args:
            gripper_offset: added to the target z for the final approach waypoint.
            intermediate_z_stop: z value of the intermediate approach waypoint.
            intermediate_distance: back-off along ``axis`` for the intermediate
                waypoint.
            speed: passed positionally to ``MoveGroupControl``.
            push_length: forward displacement along ``axis`` for a push (mode 2).
            pull_length: backward displacement along ``axis`` for a pull
                (mode 4); execute_push also backs its target off by this amount.
            sweep_count: number of back-and-forth stroke pairs in a sweep (mode 5).
            sweep_width: stored but not read anywhere in this class.
        """
        self.gripper_offset = gripper_offset
        self.intermediate_z_stop = intermediate_z_stop
        self.intermediate_distance = intermediate_distance
        self.push_length = push_length
        self.pull_length = pull_length
        self.sweep_count = sweep_count
        self.sweep_width = sweep_width  # NOTE(review): unused; maybe intended as the sweep distance
        self.pose0 = None
        self.pose1 = None
        self.target_pose = None  # pose used by the execute_* skills, set by setTargetPose()

        self.cartesian = True  # NOTE(review): not read anywhere in this class
        self.waypoint_density = 5  # points per interpolated segment (endpoints included)
        self.moveit_control = MoveGroupControl(speed)
        self.gripper = GripperInterface()

    def getPose(self):
        """Print the current end-effector pose formatted as a ``setPose0(...)`` call.

        Reads the pose once and converts its quaternion with
        ``euler_from_quaternion``. It then applies fixed offsets
        (yaw + pi/4, pitch + pi, roll + pi) and prints the result.
        Returns None.
        """
        current = self.moveit_control.get_current_pose()
        position, orientation = current.position, current.orientation
        roll, pitch, yaw = euler_from_quaternion(
            [orientation.x, orientation.y, orientation.z, orientation.w])

        pose0_x, pose0_y, pose0_z = position.x, position.y, position.z
        pose0_yaw = yaw + np.pi/4
        pose0_pitch = pitch + np.pi
        pose0_roll = roll + np.pi
        # NOTE(review): the printed argument order is (yaw, pitch, roll),
        # while setPose0 takes (roll, pitch, yaw) and adds pi/4 to its roll
        # argument. Confirm this ordering and offset are intended.
        print(f"primitive_skill.setPose0({pose0_x}, {pose0_y}, {pose0_z}, {pose0_yaw}, {pose0_pitch}, {pose0_roll})")

    def setPose0(self, x, y, z, roll, pitch, yaw):
        """Set pose0 (used by execute_pick_and_place); ``pi/4`` is added to ``roll``."""
        self.pose0 = [x, y, z, roll + pi/4, pitch, yaw]

    def setPose1(self, x, y, z, roll, pitch, yaw):
        """Set pose1 (used by execute_pick_and_place); ``pi/4`` is added to ``roll``."""
        self.pose1 = [x, y, z, roll + pi/4, pitch, yaw]

    def setTargetPose(self, x, y, z, roll, pitch, yaw):
        """Set the target pose used by the single-skill methods; ``pi/4`` is added to ``roll``."""
        self.target_pose = [x, y, z, roll + pi/4, pitch, yaw]

    def current_pose_list(self, destination_pose):
        """Return a copy of destination_pose with x, y, z set to the current position.

        Orientation elements (indices 3..5) are kept from destination_pose.
        """
        position = self.moveit_control.get_current_pose().position
        pose = deepcopy(destination_pose)
        pose[0] = position.x
        pose[1] = position.y
        pose[2] = position.z
        return pose

    def interpolate_pose(self, start, end, steps):
        """Linear interpolation between start and end points.

        Returns ``steps`` evenly spaced points as plain lists, with both
        endpoints included. ``steps <= 0`` returns ``[]``. ``steps == 1``
        returns just ``end``; the general formula would divide by
        ``steps - 1 == 0`` there and yield NaNs.
        """
        if steps <= 0:
            return []
        if steps == 1:
            return [np.asarray(end, dtype=float).tolist()]
        start_array = np.array(start)
        delta = np.array(end) - start_array
        return [(start_array + delta * i / (steps - 1)).tolist() for i in range(steps)]

    def _intermediate_pose(self, destination_pose, axis, lateral_only):
        """Copy destination_pose, set z to intermediate_z_stop and back off along axis.

        The back-off is ``intermediate_distance`` subtracted at index ``axis``.
        With ``lateral_only`` it is skipped for ``axis >= 2``, so it cannot
        undo the intermediate_z_stop that was just assigned.
        """
        pose = deepcopy(destination_pose)
        pose[2] = self.intermediate_z_stop
        if not (lateral_only and axis >= 2):
            pose[axis] -= self.intermediate_distance
        return pose

    def _offset_pose(self, destination_pose):
        """Return a copy of destination_pose with gripper_offset added to z."""
        pose = deepcopy(destination_pose)
        pose[2] += self.gripper_offset
        return pose

    def generate_waypoints(self, destination_pose, mode, axis=0, distance = 3):  # @
        '''
        Generated waypoints are for a particular application
        This is to be changed based on the application it is being used

        Every mode starts from the current end-effector position, combined
        with destination_pose's orientation. destination_pose itself is
        never modified.

        Modes:
            0 pick, 1 place, 3 go to:
                interpolate to an intermediate pose, then to the destination
                with z + gripper_offset. The intermediate pose is at
                intermediate_z_stop, backed off by intermediate_distance
                along ``axis``; modes 0/1 back off only for axis 0/1.
            2 push:
                interpolate to the destination with z + gripper_offset and
                + push_length along ``axis``.
            4 pull:
                the same, with - pull_length along ``axis``.
            5 sweep:
                interpolate to the intermediate pose, then add sweep_count
                pairs of waypoints at destination +/- ``distance`` along
                ``axis``, with z + gripper_offset + 0.025.
            6:
                interpolate to the intermediate pose, then add one waypoint
                at destination + 0.15 along ``axis``.

        Unknown modes log a warning and return an empty list.
        '''
        print(f"mode:{mode}  axis:{axis}")

        if mode not in (0, 1, 2, 3, 4, 5, 6):
            rospy.logwarn("generate_waypoints: unknown mode %s, no waypoints generated", mode)
            return []

        density = self.waypoint_density
        start = self.current_pose_list(destination_pose)
        waypoints = []

        if mode in (0, 1, 3):  # pick / place / go to
            # NOTE(review): mode 3 backs off even for axis == 2, which lowers the
            # intermediate pose below intermediate_z_stop. Modes 0/1 skip that;
            # confirm mode 3 should differ.
            intermediate_pose = self._intermediate_pose(destination_pose, axis, lateral_only=(mode != 3))
            waypoints += self.interpolate_pose(start, intermediate_pose, density)
            waypoints += self.interpolate_pose(intermediate_pose, self._offset_pose(destination_pose), density)

        elif mode in (2, 4):  # push (+push_length) / pull (-pull_length) along axis
            target = self._offset_pose(destination_pose)
            if mode == 2:
                target[axis] += self.push_length
            else:
                target[axis] -= self.pull_length
            waypoints += self.interpolate_pose(start, target, density)

        elif mode == 5:  # sweep back and forth along axis
            # NOTE(review): ``distance`` defaults to 3 in pose units (presumably
            # metres), which looks far larger than intended.
            intermediate_pose = self._intermediate_pose(destination_pose, axis, lateral_only=False)
            waypoints += self.interpolate_pose(start, intermediate_pose, density)

            sweep_z_offset = self.gripper_offset + 0.025
            for _ in range(self.sweep_count):
                print(self.sweep_count)
                forward = deepcopy(destination_pose)
                forward[axis] += distance
                forward[2] += sweep_z_offset
                waypoints.append(forward)

                backward = deepcopy(destination_pose)
                backward[axis] -= distance
                backward[2] += sweep_z_offset
                waypoints.append(backward)

        else:  # mode 6: go to, then advance 0.15 along axis (no gripper_offset)
            intermediate_pose = self._intermediate_pose(destination_pose, axis, lateral_only=False)
            waypoints += self.interpolate_pose(start, intermediate_pose, density)

            advanced = deepcopy(destination_pose)
            advanced[axis] += 0.15
            waypoints.append(advanced)

        return waypoints

    def go_to_ready_pose(self):
        """Move to a fixed 7-joint configuration, wait 2 s and print a notice."""
        self.moveit_control.go_to_joint_state(0.0, -0.785, 0.0, -2.356, 0.0, 1.571, 0.785)
        rospy.sleep(2)
        print("go to ready pose")

    def go_to_pose_goal(self, waypoint):
        """Move to a 6-element ``[x, y, z, roll, pitch, yaw]`` waypoint."""
        self.moveit_control.go_to_pose_goal(*waypoint[:6])

    def _run_waypoints(self, waypoints, message):
        """Log each waypoint with ``message`` (a format containing one %s) and move to it."""
        for waypoint in waypoints:
            rospy.loginfo(message, waypoint)
            self.go_to_pose_goal(waypoint)

    def execute_pick_and_place(self, gripper_force=5, axis=0):
        """Open the gripper, pick at pose0, place at pose1, release and go to the ready pose."""
        self.gripper.grasp(0.1, 0)
        rospy.sleep(2)

        self._run_waypoints(self.generate_waypoints(self.pose0, 0, axis=axis), "Executing waypoint: %s")
        self.gripper.grasp(0.01, gripper_force)
        rospy.sleep(1)

        self._run_waypoints(self.generate_waypoints(self.pose1, 1, axis=axis), "Executing waypoint: %s")
        self.gripper.grasp(0.1, 0)
        rospy.sleep(2)

        self.go_to_ready_pose()

    def execute_pick(self, gripper_force = 5, axis=0):
        """Open the gripper, approach target_pose (mode 0), close it and go to the ready pose."""
        self.gripper.grasp(0.1, 0)
        rospy.sleep(2)

        self._run_waypoints(self.generate_waypoints(self.target_pose, 0, axis=axis),
                            "Executing Pick waypoint: %s")

        self.gripper.grasp(0.005, gripper_force)
        rospy.sleep(3)
        self.go_to_ready_pose()

    def execute_place(self, axis=0):
        """Approach target_pose (mode 1), open the gripper and go to the ready pose."""
        self._run_waypoints(self.generate_waypoints(self.target_pose, 1, axis=axis),
                            "Executing Place waypoint: %s")

        self.gripper.grasp(0.1, 0)
        rospy.sleep(2)
        self.go_to_ready_pose()

    def execute_push(self, gripper_force = 5, axis = 0):
        """Approach a point pull_length behind target_pose along axis, then push forward.

        The approach uses mode 3 and the push uses mode 2. Afterwards the arm
        returns to the ready pose and the gripper opens. ``gripper_force`` is
        not used.
        """
        print(axis)
        rospy.sleep(1)

        # Work on a copy so repeated pushes do not shift self.target_pose.
        push_pose = deepcopy(self.target_pose)
        push_pose[axis] -= self.pull_length

        # Go to the contact point, then push.
        self._run_waypoints(self.generate_waypoints(push_pose, 3, axis=axis), "Executing Push waypoint: %s")
        self._run_waypoints(self.generate_waypoints(push_pose, 2, axis=axis), "Executing Push waypoint: %s")

        self.go_to_ready_pose()
        self.gripper.grasp(0.1, 0)
        rospy.sleep(1)

    def execute_pull(self, gripper_force = 5, axis=0):
        """Approach target_pose, grasp it, pull back along axis, release and go to ready."""
        self.gripper.grasp(0.1, 0)
        rospy.sleep(1)

        # Go to the grasping point.
        self._run_waypoints(self.generate_waypoints(self.target_pose, 3, axis=axis),
                            "Executing Pull waypoint: %s")

        self.gripper.grasp(0.01, gripper_force)
        rospy.sleep(2)

        # Pull along the same axis that was used for the approach.
        self._run_waypoints(self.generate_waypoints(self.target_pose, 4, axis=axis),
                            "Executing Pull waypoint: %s")
        self.gripper.grasp(0.1, 0)
        rospy.sleep(2)
        self.go_to_ready_pose()

    def execute_sweep(self, axis=0, distance = 3):
        """Run the mode-5 sweep around target_pose along axis by +/- ``distance``."""
        waypoints = self.generate_waypoints(self.target_pose, 5, axis=axis, distance=distance)
        self._run_waypoints(waypoints, "Executing Sweep waypoint: %s")

    def execute_rotate(self, gripper_force=5, axis=0):
        """Grasp at target_pose, turn the last joint by -pi/2, then lift 0.2 above target_pose."""
        move_group = self.moveit_control
        self.gripper.grasp(0.1, 0)
        rospy.sleep(2)

        # Go to the rotating part.
        self._run_waypoints(self.generate_waypoints(self.target_pose, 3, axis=axis),
                            "Executing Rotate waypoint: %s")

        self.gripper.grasp(0.01, gripper_force)
        rospy.sleep(2)

        # Rotate the part by turning the last joint.
        joints = move_group.get_current_joint_states()
        joints[-1] -= pi/2
        move_group.go_to_joint_state(*joints[:7])

        lifted = deepcopy(self.target_pose)
        lifted[2] += 0.2
        self.go_to_pose_goal(lifted)

    def execute_go(self, axis=0):
        """Move to target_pose via the mode-3 approach without touching the gripper."""
        self._run_waypoints(self.generate_waypoints(self.target_pose, 3, axis=axis),
                            "Executing Go waypoint: %s")

    def execute_gripper(self, width1, width2, force1, force2):
        """Call ``grasp(width1, force1)``, wait 5 s, then call ``grasp(width2, force2)``."""
        self.gripper.grasp(width1, force1)
        rospy.sleep(5)
        self.gripper.grasp(width2, force2)

    def current_pose(self):
        """Return the current end-effector orientation as Euler angles.

        The value is the tuple from ``euler_from_quaternion`` (default axes)
        for ``get_current_pose().orientation``; position is not included.
        """
        orientation = self.moveit_control.get_current_pose().orientation
        return euler_from_quaternion((
            orientation.x,
            orientation.y,
            orientation.z,
            orientation.w,
        ))