import numpy as np
import matplotlib.pyplot as plt
import zmq

from mpl_toolkits.mplot3d import Axes3D
from tqdm import tqdm

from copy import deepcopy as copy
from openteach.constants import *
from openteach.utils.timer import FrequencyTimer
from openteach.utils.network import ZMQKeypointSubscriber , ZMQKeypointPublisher
from openteach.utils.vectorops import *
from openteach.utils.files import *
from openteach.robot.bimanual_left import BimanualLeft
from scipy.spatial.transform import Rotation, Slerp
from .operator import Operator
from scipy.spatial.transform import Rotation as R
from numpy.linalg import pinv
import time


# Compact console output for the debug prints in this module: 2 decimals,
# no scientific notation for small values.
np.set_printoptions(suppress=True, precision=2)

def test_filter(save_data=False):
    """Record or replay a visual sanity check of a pose filter.

    With ``save_data=True``, random 7-D vectors are pushed through a
    ``CompStateFilter`` at ``VR_FREQ`` until Ctrl-C. The stacked
    ``(raw, filtered)`` pairs, shape ``(n, 2, 7)``, are then saved to
    ``all_poses.npy``.

    With ``save_data=False``, that file is loaded and one PNG per step is
    written to ``all_poses/state_XXX.png``. Each PNG plots the raw and filtered
    trajectory of every dimension up to that step.
    """
    import os

    if save_data:
        # NOTE(review): CompStateFilter is not defined in this file; it is
        # presumably provided by one of the star imports — confirm.
        state_filter = CompStateFilter(np.asarray([0.575466, -0.17820767, 0.23671454, -0.281564, -0.6797597, -0.6224841, 0.2667619]))
        timer = FrequencyTimer(VR_FREQ)

        # Collect samples in a list; concatenating onto an array every
        # iteration made recording quadratic in the number of samples.
        samples = []
        while True:
            try:
                timer.start_loop()

                rand_pos = np.random.randn(7)
                filtered_pos = state_filter(rand_pos)

                print('rand_pos: {} - filtered_pos: {}'.format(rand_pos, filtered_pos))

                samples.append(np.stack([rand_pos, filtered_pos], axis=0))
                print('all_poses shape: {}'.format(
                    (len(samples),) + samples[0].shape
                ))

                timer.end_loop()

            except KeyboardInterrupt:
                # An interrupt before the first sample must not crash here.
                if samples:
                    np.save('all_poses.npy', np.stack(samples, axis=0))
                break

    else:
        all_poses = np.load('all_poses.npy')
        os.makedirs('all_poses', exist_ok=True)
        with tqdm(total=len(all_poses)) as pbar:
            for i in range(len(all_poses)):
                rand_pos = all_poses[:i + 1, 0, :]
                filtered_pos = all_poses[:i + 1, 1, :]

                fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(10, 10))
                for j in range(filtered_pos.shape[1]):
                    ax = axs[j // 3, j % 3]
                    ax.plot(filtered_pos[:, j], label='Filtered')
                    ax.plot(rand_pos[:, j], label='Actual')
                    ax.set_title(f'{j}th Axes')
                    ax.legend()

                pbar.update(1)
                fig.savefig(f'all_poses/state_{str(i).zfill(3)}.png')
                # Release each frame's figure; otherwise one figure leaks per step.
                plt.close(fig)

# Smooths outgoing poses; the rotation part is filtered too (via Slerp) before the pose is sent.
class Filter:
    """Complementary low-pass filter for 6-D poses ``[x, y, z, rx, ry, rz]``.

    The position is blended linearly. The orientation (a rotation vector) is
    blended with spherical linear interpolation. ``comp_ratio`` is the weight
    kept from the previous state, so each call moves the state
    ``1 - comp_ratio`` of the way toward the new measurement.
    """

    def __init__(self, state, comp_ratio=0.6):
        """
        Args:
            state: initial pose, a length-6 array-like ``[x, y, z, rx, ry, rz]``.
            comp_ratio (float): weight of the previous state (0 = no smoothing).

        Raises:
            ValueError: if ``state`` is not a 6-element pose.
        """
        # Copies are stored, so later in-place edits of the caller's array
        # cannot silently change the filter state.
        self.pos_state, self.ori_state = self._split_pose(state)
        self.comp_ratio = comp_ratio

    @staticmethod
    def _split_pose(pose):
        """Return float copies of the translation and rotation-vector parts."""
        pose = np.array(pose, dtype=float)  # np.array always copies
        if pose.shape != (6,):
            raise ValueError(
                'Filter expects a 6-D pose [x, y, z, rx, ry, rz], got shape {}'.format(pose.shape))
        return pose[:3], pose[3:]

    def __call__(self, next_state):
        """Fold ``next_state`` into the filter and return the new 6-D pose."""
        next_pos, next_ori = self._split_pose(next_state)
        self.pos_state = self.pos_state * self.comp_ratio + next_pos * (1 - self.comp_ratio)
        ori_interp = Slerp([0, 1], Rotation.from_rotvec(
            np.stack([self.ori_state, next_ori], axis=0)))
        # t = 1 - comp_ratio mirrors the linear blend used for the position.
        self.ori_state = ori_interp([1 - self.comp_ratio])[0].as_rotvec()
        return np.concatenate([self.pos_state, self.ori_state])

class BimanualLeftArmOperator(Operator):
    """Teleoperation operator driving the left arm of a bimanual setup.

    Each ``_apply_retargeted_angles`` call does the following:

    - Reads the transformed hand frame and hand keypoints from ZMQ.
    - Expresses the hand's motion relative to the frame captured at the last
      teleop reset.
    - Applies that motion to the robot pose captured at the same reset.
    - Sends the resulting cartesian target to ``BimanualLeft``.

    Pinch gestures drive the controls. Ring- or middle-to-thumb toggles pause;
    pinky-to-thumb toggles the gripper.
    """

    # Constant axis-permutation matrices applied by conjugation to the relative
    # hand transform: _H_R_V for its rotation part, _H_T_V for its translation
    # part (presumably mapping the hand/VR frame onto the robot frame).
    _H_R_V = np.array([[0, 0, 1, 0],
                       [0, 1, 0, 0],
                       [-1, 0, 0, 0],
                       [0, 0, 0, 1]])
    _H_T_V = np.array([[0, 0, -1, 0],
                       [0, -1, 0, 0],
                       [-1, 0, 0, 0],
                       [0, 0, 0, 1]])

    def __init__(
        self,
        host, transformed_keypoints_port,
        moving_average_limit,
        allow_rotation=False,
        arm_type='main_arm',
        use_filter=False,
        arm_resolution_port=None,
        teleoperation_reset_port=None,
        gripper_port=None,
        gripper_rotate_port=None,
        cartesian_port=8116,
        joint_port=8117,
        cartesian_command_port=8121):
        """Create the ZMQ sockets, reset the left arm and capture its pose.

        Args:
            host: host address used for every subscriber and publisher.
            transformed_keypoints_port: port carrying both the
                ``transformed_hand_frame`` and ``transformed_hand_coords`` topics.
            moving_average_limit: stored on the instance; not read elsewhere in
                this class.
            allow_rotation: stored on the instance; not read elsewhere in this
                class.
            arm_type: stored on the instance; not read elsewhere in this class.
            use_filter: if True, every commanded pose is passed through a
                ``Filter`` (comp_ratio=0.8) before being sent.
            arm_resolution_port: port of the ``button`` topic that selects high
                or low translation resolution.
            teleoperation_reset_port: currently unused.
            gripper_port: port on which the gripper state is published.
            gripper_rotate_port: currently unused.
            cartesian_port: port for the measured cartesian pose.
            joint_port: port for the measured joint positions.
            cartesian_command_port: port for the commanded cartesian pose.
        """
        self.notify_component_start('Bimanual arm operator')
        self._transformed_arm_keypoint_subscriber = ZMQKeypointSubscriber(
            host=host,
            port=transformed_keypoints_port,
            topic='transformed_hand_frame'
        )
        self._transformed_hand_keypoint_subscriber = ZMQKeypointSubscriber(
            host=host,
            port=transformed_keypoints_port,
            topic='transformed_hand_coords'
        )
        self._robot = BimanualLeft(ip=LEFT_ARM_IP)
        self.robot.reset()
        self.arm_type = arm_type
        self.allow_rotation = allow_rotation
        self.resolution_scale = 1
        self.arm_teleop_state = ARM_TELEOP_STOP
        self.new_gripper_state = GRIPPER_OPEN

        self._arm_resolution_subscriber = ZMQKeypointSubscriber(
            host=host,
            port=arm_resolution_port,
            topic='button'
        )
        self.gripper_publisher = ZMQKeypointPublisher(host=host, port=gripper_port)
        self.cartesian_publisher = ZMQKeypointPublisher(host=host, port=cartesian_port)
        self.joint_publisher = ZMQKeypointPublisher(host=host, port=joint_port)
        self.cartesian_command_publisher = ZMQKeypointPublisher(
            host=host, port=cartesian_command_port)

        home_pose = np.array(self.robot.get_cartesian_position())
        self.gripper_flag = 1
        self.pause_flag = 1
        self.flag = 0
        self.H_HT_HI_n = np.zeros((4, 4))
        self.robot_init_H = self.robot_pose_aa_to_affine(home_pose)
        self.is_first_frame = True
        self.use_filter = use_filter
        if use_filter:
            robot_init_cart = self._homo2cart(self.robot_init_H)
            # _apply_retargeted_angles multiplies the translation by 1000
            # before filtering, so seed the filter in those same units.
            robot_init_cart[:3] = robot_init_cart[:3] * 1000
            self.comp_filter = Filter(robot_init_cart, comp_ratio=0.8)

        self._timer = FrequencyTimer(90)

        self.direction_counter = 0
        self.current_direction = 0

        # Moving average queues
        self.moving_Average_queue = []
        self.moving_average_limit = moving_average_limit

        self.hand_frames = []

        # Pinch-gesture debounce state used by the *_from_hand_keypoints methods.
        # gripper_cnt must exist before the first pinch increments it.
        self.gripper_cnt = 0
        self.prev_gripper_flag = 0
        self.prev_pause_flag = 0
        self.pause_cnt = 0
        self.gripper_correct_state = 1

    @property
    def timer(self):
        return self._timer

    @property
    def robot(self):
        return self._robot

    def return_real(self):
        return True

    @property
    def transformed_hand_keypoint_subscriber(self):
        return self._transformed_hand_keypoint_subscriber

    @property
    def transformed_arm_keypoint_subscriber(self):
        return self._transformed_arm_keypoint_subscriber

    def robot_pose_aa_to_affine(self, pose_aa: np.ndarray) -> np.ndarray:
        """Convert a robot pose in axis-angle format to an affine matrix.

        Args:
            pose_aa: ``[x, y, z, ax, ay, az]``, where ``(x, y, z)`` is the
                position (divided by ``SCALE_FACTOR``) and ``(ax, ay, az)`` is
                the axis-angle rotation in radians.

        Returns:
            np.ndarray: 4x4 affine matrix ``[[R, t], [0, 1]]``.
        """
        rotation = R.from_rotvec(pose_aa[3:]).as_matrix()
        translation = np.array(pose_aa[:3]) / SCALE_FACTOR
        return np.block([[rotation, translation[:, np.newaxis]],
                         [0, 0, 0, 1]])

    def _get_hand_frame(self):
        """Poll (non-blocking, at most 10 tries) for the transformed hand frame.

        Returns:
            The frame reshaped to ``(4, 3)`` (row 0 = translation, rows 1-3 =
            rotation, as consumed by ``_turn_frame_to_homo_mat``), or ``None``
            when no message was received.
        """
        data = None
        for _ in range(10):
            data = self.transformed_arm_keypoint_subscriber.recv_keypoints(flags=zmq.NOBLOCK)
            if data is not None:
                break
        if data is None:
            return None
        return np.asanyarray(data).reshape(4, 3)

    def _get_resolution_scale_mode(self):
        """Return the scalar resolution mode from the ``button`` topic.

        Called without ``zmq.NOBLOCK``, so presumably blocks until a message
        arrives.
        """
        data = self._arm_resolution_subscriber.recv_keypoints()
        return np.asanyarray(data).reshape(1)[0]  # must be a single value

    def _get_gripper_mode(self):
        # NOTE(review): self.gripper_subscriber is never created in __init__,
        # so calling this raises AttributeError.
        data = self.gripper_subscriber.recv_keypoints()
        return np.asanyarray(data).reshape(1)[0]  # must be a single value

    def _get_arm_teleop_state(self):
        # NOTE(review): self._arm_teleop_state_subscriber is never created in
        # __init__, so calling this raises AttributeError.
        reset_stat = self._arm_teleop_state_subscriber.recv_keypoints()
        return np.asanyarray(reset_stat).reshape(1)[0]  # must be a single value

    def _get_arm_teleop_state_from_hand_keypoints(self):
        """Return ``(pause_state, pause_status, pause_left)`` from the pause gesture."""
        pause_state, pause_status, pause_left = self.get_pause_state_from_hand_keypoints()
        pause_status = np.asanyarray(pause_status).reshape(1)[0]
        return pause_state, pause_status, pause_left

    def _clip_coords(self, coords):
        # TODO - clip the coordinates
        return coords

    def _turn_frame_to_homo_mat(self, frame):
        """Build a 4x4 homogeneous matrix from a ``(4, 3)`` hand frame.

        Row 0 is the translation. Rows 1-3 are transposed into the rotation
        block.
        """
        translation = frame[0]
        rotation = frame[1:]

        homo_mat = np.zeros((4, 4))
        homo_mat[:3, :3] = np.transpose(rotation)
        homo_mat[:3, 3] = translation
        homo_mat[3, 3] = 1
        return homo_mat

    def _homo2cart(self, homo_mat):
        """Convert a 4x4 homogeneous matrix to ``[x, y, z, rx, ry, rz]`` (rotvec, radians)."""
        translation = homo_mat[:3, 3]
        rotvec = Rotation.from_matrix(homo_mat[:3, :3]).as_rotvec(degrees=False)
        return np.concatenate([translation, rotvec], axis=0)

    def _get_scaled_cart_pose(self, moving_robot_homo_mat):
        """Scale the translation step toward ``moving_robot_homo_mat`` by ``resolution_scale``.

        The translation is measured from the robot's *current* pose, which is
        queried here. The target rotation is returned unchanged.
        """
        unscaled_cart_pose = self._homo2cart(moving_robot_homo_mat)

        current_pose = np.array(self.robot.get_cartesian_position())
        current_homo_mat = self.robot_pose_aa_to_affine(current_pose)
        current_cart_pose = self._homo2cart(current_homo_mat)

        diff_in_translation = unscaled_cart_pose[:3] - current_cart_pose[:3]
        scaled_diff_in_translation = diff_in_translation * self.resolution_scale

        scaled_cart_pose = np.zeros(6)
        scaled_cart_pose[3:] = unscaled_cart_pose[3:]  # rotation taken as-is
        scaled_cart_pose[:3] = current_cart_pose[:3] + scaled_diff_in_translation
        return scaled_cart_pose

    def _reset_teleop(self):
        """Re-anchor the hand and robot reference frames.

        Busy-waits until a hand frame is received, then stores it as
        ``hand_init_H`` and the robot's current pose as ``robot_init_H``.
        Returns the raw hand frame.
        """
        print('****** RESETTING TELEOP ****** ')

        first_hand_frame = self._get_hand_frame()
        while first_hand_frame is None:
            first_hand_frame = self._get_hand_frame()

        self.hand_init_H = self._turn_frame_to_homo_mat(first_hand_frame)
        self.hand_init_t = copy(self.hand_init_H[:3, 3])

        self.is_first_frame = False
        home_pose = np.array(self.robot.get_cartesian_position())
        self.robot_init_H = self.robot_pose_aa_to_affine(home_pose)
        print("Resetting complete")

        return first_hand_frame

    def get_gripper_state(self):
        # NOTE(review): self.gripper_subscriber is never created in __init__,
        # so calling this raises AttributeError.
        data = self.gripper_subscriber.recv_keypoints()
        return np.asanyarray(data).reshape(1)[0]  # must be a single value

    def get_gripper_state_new_from_hand_keypoints(self):
        """Toggle the gripper flag on a pinky-thumb pinch (distance < 0.04).

        The flag toggles only on the first frame of a pinch (``gripper_cnt == 1``).
        The counter resets when the fingers separate.

        Returns:
            tuple: ``(gripper_state, status, gripper_fr)``. ``status`` is True
            when the flag differs from ``prev_gripper_flag``.
        """
        transformed_hand_coords = self._transformed_hand_keypoint_subscriber.recv_keypoints()
        distance = np.linalg.norm(
            transformed_hand_coords[OCULUS_JOINTS['pinky'][-1]]
            - transformed_hand_coords[OCULUS_JOINTS['thumb'][-1]])
        thresh = 0.04
        # NOTE(review): gripper_fr is never set to True, so it never blocks
        # arm control in _apply_retargeted_angles. Confirm whether it was meant
        # to flag the toggle frame.
        gripper_fr = False
        if distance < thresh:
            self.gripper_cnt += 1
            if self.gripper_cnt == 1:
                self.prev_gripper_flag = self.gripper_flag
                self.gripper_flag = not self.gripper_flag
        else:
            self.gripper_cnt = 0
        gripper_state = np.asanyarray(self.gripper_flag).reshape(1)[0]
        # Keep a plain Python bool: callers test `status is True`.
        status = bool(gripper_state != self.prev_gripper_flag)
        return gripper_state, status, gripper_fr

    def get_gripper_state_from_hand_keypoints(self):
        """Older gripper gesture: toggle on a middle-thumb pinch (distance < 0.03).

        Returns ``(gripper_state, status, gripper_fl)``; ``gripper_fl`` is
        always True here.
        """
        transformed_hand_coords = self._transformed_hand_keypoint_subscriber.recv_keypoints()
        distance = np.linalg.norm(
            transformed_hand_coords[OCULUS_JOINTS['middle'][-1]]
            - transformed_hand_coords[OCULUS_JOINTS['thumb'][-1]])
        thresh = 0.03
        gripper_fl = True
        if distance < thresh:
            self.gripper_cnt += 1
            if self.gripper_cnt == 1:
                self.prev_gripper_flag = self.gripper_flag
                self.gripper_flag = not self.gripper_flag
        else:
            self.gripper_cnt = 0
        gripper_state = np.asanyarray(self.gripper_flag).reshape(1)[0]
        status = bool(gripper_state != self.prev_gripper_flag)
        return gripper_state, status, gripper_fl

    def get_pause_state_from_hand_keypoints(self):
        """Toggle the pause flag on a ring-thumb or middle-thumb pinch (distance < 0.04).

        The flag toggles only on the first frame of a pinch.

        Returns:
            tuple: ``(pause_state, pause_status, pause_left)``. ``pause_left``
            is always True.
        """
        transformed_hand_coords = self._transformed_hand_keypoint_subscriber.recv_keypoints()
        thumb_tip = transformed_hand_coords[OCULUS_JOINTS['thumb'][-1]]
        ring_distance = np.linalg.norm(transformed_hand_coords[OCULUS_JOINTS['ring'][-1]] - thumb_tip)
        middle_distance = np.linalg.norm(transformed_hand_coords[OCULUS_JOINTS['middle'][-1]] - thumb_tip)
        thresh = 0.04
        pause_left = True
        if ring_distance < thresh or middle_distance < thresh:
            self.pause_cnt += 1
            if self.pause_cnt == 1:
                self.prev_pause_flag = self.pause_flag
                self.pause_flag = not self.pause_flag
        else:
            self.pause_cnt = 0
        pause_state = np.asanyarray(self.pause_flag).reshape(1)[0]
        pause_status = bool(pause_state != self.prev_pause_flag)
        return pause_state, pause_status, pause_left

    def _apply_retargeted_angles(self, log=False):
        """Run one teleoperation step.

        The step proceeds as follows:

        1. Read the pause gesture; on the first frame or a STOP->CONT transition,
           re-anchor the hand and robot reference frames.
        2. Read the resolution mode; return early if no hand frame is available.
        3. Map the hand's motion since the reset onto the robot's reset pose.
        4. Scale, convert (x1000) and optionally filter the target pose.
        5. Handle the gripper gesture.
        6. Publish measured and commanded state.
        7. Command the arm while teleop is in CONT.
        """
        new_arm_teleop_state, _, _ = self._get_arm_teleop_state_from_hand_keypoints()
        if self.is_first_frame or (
                self.arm_teleop_state == ARM_TELEOP_STOP
                and new_arm_teleop_state == ARM_TELEOP_CONT):
            moving_hand_frame = self._reset_teleop()  # hand frame is fetched only once here
        else:
            moving_hand_frame = self._get_hand_frame()
        self.arm_teleop_state = new_arm_teleop_state

        arm_teleoperation_scale_mode = self._get_resolution_scale_mode()
        if arm_teleoperation_scale_mode == ARM_HIGH_RESOLUTION:
            self.resolution_scale = 1
        elif arm_teleoperation_scale_mode == ARM_LOW_RESOLUTION:
            self.resolution_scale = 0.6

        if moving_hand_frame is None:
            return  # Not in arm mode yet: return instead of blocking.

        self.hand_moving_H = self._turn_frame_to_homo_mat(moving_hand_frame)

        H_HI_HH = self.hand_init_H    # takes P_HI (initial hand) to P_HH (home hand)
        H_HT_HH = self.hand_moving_H  # takes P_HT to P_HH
        H_RI_RH = self.robot_init_H   # takes P_RI to P_RH

        # Hand motion since the last reset: takes P_HT to P_HI.
        H_HT_HI = np.linalg.pinv(H_HI_HH) @ H_HT_HH

        # Change of basis, applied separately to the rotation and translation.
        H_HT_HI_r = (pinv(self._H_R_V) @ H_HT_HI @ self._H_R_V)[:3, :3]
        H_HT_HI_t = (pinv(self._H_T_V) @ H_HT_HI @ self._H_T_V)[:3, 3]

        # The translation offset is added to the reset position. The relative
        # rotation is right-multiplied onto the reset orientation.
        target_translation = H_RI_RH[:3, 3] + H_HT_HI_t
        target_rotation = H_RI_RH[:3, :3] @ H_HT_HI_r
        self.robot_moving_H = np.block(
            [[target_rotation, target_translation.reshape(-1, 1)], [0, 0, 0, 1]])

        if log:
            print('** ROBOT MOVING H **\n{}\n** ROBOT INIT H **\n{}\n'.format(
                self.robot_moving_H, self.robot_init_H))
            print('** HAND MOVING H: **\n{}\n** HAND INIT H: **\n{} - HAND INIT T: {}'.format(
                self.hand_moving_H, self.hand_init_H, self.hand_init_t
            ))
            print('***** TRANSFORM MULT: ******\n{}'.format(
                H_HT_HI
            ))

            print('\n------------------------------------\n\n\n')

        # Use the resolution scale to get the final cart pose.
        final_pose = self._get_scaled_cart_pose(self.robot_moving_H)
        final_pose[0:3] = final_pose[0:3] * 1000

        if self.use_filter:
            final_pose = self.comp_filter(final_pose)

        gripper_state, status_change, gripper_flag = self.get_gripper_state_new_from_hand_keypoints()
        # Command the gripper only on the first frame of a new pinch.
        if self.gripper_cnt == 1 and status_change is True:
            self.gripper_correct_state = gripper_state
            print("Left Gripper State", self.gripper_correct_state)
            self.robot.set_gripper_state(self.gripper_correct_state * 800)
        self.gripper_publisher.pub_keypoints(self.gripper_correct_state, "gripper_left")

        position = self.robot.get_cartesian_position()
        joint_position = self.robot.get_joint_position()
        self.cartesian_publisher.pub_keypoints(position, "cartesian")
        self.joint_publisher.pub_keypoints(joint_position, "joint")
        self.cartesian_command_publisher.pub_keypoints(final_pose, "cartesian")

        if self.arm_teleop_state == ARM_TELEOP_CONT and not gripper_flag:
            self.robot.arm_control(final_pose)

       
        


