import numpy as np
import matplotlib.pyplot as plt
import zmq

from mpl_toolkits.mplot3d import Axes3D
from tqdm import tqdm

from copy import deepcopy as copy
from openteach.constants import *
from openteach.utils.timer import FrequencyTimer
from openteach.utils.network import ZMQKeypointSubscriber
from openteach.utils.vectorops import *
from openteach.utils.files import *
from openteach.robot.franka import FrankaArm
from scipy.spatial.transform import Rotation, Slerp
from .operator import Operator


# Compact, fixed-point numpy printing for the pose/matrix debug prints in this module.
np.set_printoptions(suppress=True, precision=2)

def test_filter(save_data=False):
    """Record or visualise the behaviour of ``CompStateFilter`` on random poses.

    Args:
        save_data: If True, push random 7-D vectors through a ``CompStateFilter``
            at ``VR_FREQ`` until Ctrl-C, then save the stacked ``[raw, filtered]``
            pairs (shape ``(N, 2, 7)``) to ``all_poses.npy``. If False, load
            ``all_poses.npy`` and write one cumulative 3x3 plot per sample to
            ``all_poses/state_XXX.png``.
    """
    import os  # not imported explicitly at module level

    if save_data:
        comp_filter = CompStateFilter(np.asarray(
            [0.575466, -0.17820767, 0.23671454, -0.281564, -0.6797597, -0.6224841, 0.2667619]))
        timer = FrequencyTimer(VR_FREQ)

        # Collect samples in a list and stack once at the end; growing an array
        # with np.concatenate on every iteration is quadratic in the sample count.
        samples = []
        try:
            while True:
                timer.start_loop()

                rand_pos = np.random.randn(7)
                filtered_pos = comp_filter(rand_pos)
                print('rand_pos: {} - filtered_pos: {}'.format(rand_pos, filtered_pos))

                samples.append(np.stack([rand_pos, filtered_pos], axis=0))
                print('all_poses shape: {}'.format((len(samples),) + samples[-1].shape))

                timer.end_loop()
        except KeyboardInterrupt:
            pass

        # Guard against an interrupt that arrives before any sample was recorded.
        if samples:
            np.save('all_poses.npy', np.stack(samples, axis=0))
        else:
            print('No samples recorded; all_poses.npy not written.')
        return

    all_poses = np.load('all_poses.npy')
    os.makedirs('all_poses', exist_ok=True)  # savefig does not create directories
    with tqdm(total=len(all_poses)) as pbar:
        for i in range(len(all_poses)):
            # Cumulative history up to and including sample i.
            rand_pos = all_poses[:i + 1, 0, :]
            filtered_pos = all_poses[:i + 1, 1, :]

            fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(10, 10))
            for j in range(filtered_pos.shape[1]):
                ax = axs[j // 3, j % 3]
                ax.plot(filtered_pos[:, j], label='Filtered')
                ax.plot(rand_pos[:, j], label='Actual')
                ax.set_title(f'{j}th Axes')
                ax.legend()

            pbar.update(1)
            fig.savefig(os.path.join('all_poses', f'state_{str(i).zfill(3)}.png'))
            # Release each frame's figure so memory does not grow with the sample count.
            plt.close(fig)

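# Complementary filter for 7-D poses (xyz + quaternion); the rotation is filtered together with the position, right before the pose is sent to the robot.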
class CompStateFilter:
    """Complementary (exponential-smoothing) filter for 7-D poses.

    A pose is ``[x, y, z, qx, qy, qz, qw]``: a position followed by a quaternion
    in SciPy's scalar-last order (as consumed by ``Rotation.from_quat``). On each
    call the stored state moves a fraction ``1 - comp_ratio`` of the way towards
    the new measurement. Position moves by linear blending and orientation by
    spherical linear interpolation (SLERP). A larger ``comp_ratio`` therefore
    means heavier smoothing.
    """

    def __init__(self, state, comp_ratio=0.6):
        """
        Args:
            state: Initial pose, any array-like with at least 7 values
                (xyz + quaternion); only the first 7 are used.
            comp_ratio: Weight kept on the previous state; the SLERP step
                requires it to lie in [0, 1].
        """
        # Copy into a float array so the filter never aliases the caller's
        # buffer and plain lists/tuples are accepted as well as ndarrays.
        state = np.array(state, dtype=float)
        self.pos_state = state[:3]
        self.ori_state = state[3:7]
        self.comp_ratio = comp_ratio

    def __call__(self, next_state):
        """Blend ``next_state`` into the filter state and return the new 7-D pose."""
        next_state = np.asarray(next_state, dtype=float)
        alpha = 1 - self.comp_ratio  # fraction of the way to move towards next_state

        self.pos_state = self.pos_state * self.comp_ratio + next_state[:3] * alpha

        # Slerp at t=alpha: t=0 keeps the old orientation, t=1 jumps to the new one.
        key_rotations = Rotation.from_quat(
            np.stack([self.ori_state, next_state[3:7]], axis=0))
        ori_interp = Slerp([0, 1], key_rotations)
        self.ori_state = ori_interp([alpha])[0].as_quat()

        return np.concatenate([self.pos_state, self.ori_state])
    


class FrankaArmOperator(Operator):
    """Teleoperate a Franka arm from VR hand frames received over ZMQ.

    Each control step expresses the current hand frame relative to the hand
    frame captured at the last teleop reset. It maps that relative motion into
    the robot frame through a fixed hand-mount transform (``H_A_R``), starting
    from the robot pose captured at the same reset. The resulting cartesian
    pose (xyz + quaternion) is sent to ``FrankaArm.arm_control``.
    """

    def __init__(
        self,
        host,
        transformed_keypoints_port,
        moving_average_limit,
        allow_rotation=False,
        arm_type='main_arm',
        use_filter=False,
        arm_resolution_port=None,
        teleoperation_reset_port=None,
    ):
        """Open the ZMQ subscribers, connect to the robot and load the bounds config.

        Args:
            host: Host passed to every ``ZMQKeypointSubscriber``.
            transformed_keypoints_port: Port of the 'transformed_hand_coords' and
                'transformed_hand_frame' topics.
            moving_average_limit: Stored as ``self.moving_average_limit``; not
                used by the methods of this class.
            allow_rotation: If True, also store ``initial_quat`` and
                ``rotation_axis`` (not used by the methods of this class).
            arm_type: Stored as ``self.arm_type``.
            use_filter: If True, every outgoing pose is smoothed by a
                ``CompStateFilter`` (comp_ratio=0.8) seeded with the starting pose.
            arm_resolution_port: Port of the 'button' topic that selects the
                high/low translation resolution.
            teleoperation_reset_port: Port of the 'pause' topic carrying the
                teleop continue/stop state.
        """
        self.notify_component_start('franka arm operator')
        self._transformed_hand_keypoint_subscriber = ZMQKeypointSubscriber(
            host=host,
            port=transformed_keypoints_port,
            topic='transformed_hand_coords'
        )
        self._transformed_arm_keypoint_subscriber = ZMQKeypointSubscriber(
            host=host,
            port=transformed_keypoints_port,
            topic='transformed_hand_frame'
        )

        # Initializing the robot controller
        self._robot = FrankaArm()
        self.arm_type = arm_type
        self.allow_rotation = allow_rotation
        self.resolution_scale = 1  # Overwritten each step from the resolution button
        # Start paused; the first control step always resets anyway (is_first_frame).
        self.arm_teleop_state = ARM_TELEOP_STOP

        self._arm_resolution_subscriber = ZMQKeypointSubscriber(
            host=host,
            port=arm_resolution_port,
            topic='button'
        )

        self._arm_teleop_state_subscriber = ZMQKeypointSubscriber(
            host=host,
            port=teleoperation_reset_port,
            topic='pause'
        )

        # Robot pose as a 4x4 homogeneous matrix (it is fed to _homo2cart below).
        self.robot_init_H = self.robot.get_pose()['position']
        self.is_first_frame = True
        print('ROBOT INIT H: \n{}'.format(self.robot_init_H))

        self.use_filter = use_filter
        if use_filter:
            robot_init_cart = self._homo2cart(self.robot_init_H)
            self.comp_filter = CompStateFilter(robot_init_cart, comp_ratio=0.8)

        if allow_rotation:
            self.initial_quat = np.array(
                [-0.27686286, -0.66575766, -0.63895273,  0.26805457])
            self.rotation_axis = np.array([0, 0, 1])

        # Getting the bounds to perform linear transformation
        bounds_file = get_path_in_package(
            'components/operators/configs/franka.yaml')
        bounds_data = get_yaml_data(bounds_file)

        # Bounds for performing linear transformation
        self.corresponding_robot_axes = bounds_data['corresponding_robot_axes']
        self.franka_bounds = bounds_data['robot_bounds']
        self.wrist_bounds = bounds_data['wrist_bounds']

        # Matrices to reorient the end-effector rotation frames
        self.frame_realignment_matrix = np.array(
            bounds_data['frame_realignment_matrix']).reshape(3, 3)
        self.rotation_realignment_matrix = np.array(
            bounds_data['rotation_alignment_matrix']).reshape(3, 3)

        # Frequency timer
        self._timer = FrequencyTimer(VR_FREQ)

        self.direction_counter = 0
        self.current_direction = 0

        # Moving average queues
        self.moving_Average_queue = []
        self.moving_average_limit = moving_average_limit

        self.hand_frames = []

    @property
    def timer(self):
        return self._timer

    @property
    def robot(self):
        return self._robot

    @property
    def transformed_hand_keypoint_subscriber(self):
        return self._transformed_hand_keypoint_subscriber

    @property
    def transformed_arm_keypoint_subscriber(self):
        return self._transformed_arm_keypoint_subscriber

    def _get_hand_frame(self):
        """Poll (non-blocking, up to 10 tries) for the latest transformed hand frame.

        Returns:
            A (4, 3) array, or None if nothing was received. Row 0 is the
            translation and rows 1-3 are the frame axes (see
            ``_turn_frame_to_homo_mat``).
        """
        data = None
        for _ in range(10):
            data = self.transformed_arm_keypoint_subscriber.recv_keypoints(flags=zmq.NOBLOCK)
            if data is not None:
                break
        if data is None:
            return None
        return np.asanyarray(data).reshape(4, 3)

    def _get_resolution_scale_mode(self):
        """Return the resolution mode scalar from the 'button' topic.

        ``recv_keypoints`` is called without NOBLOCK, so this presumably blocks
        until a message arrives.
        """
        data = self._arm_resolution_subscriber.recv_keypoints()
        res_scale = np.asanyarray(data).reshape(1)[0]  # Make sure this data is one dimensional
        return res_scale

    def _get_arm_teleop_state(self):
        """Return the teleop state scalar (ARM_TELEOP_CONT / ARM_TELEOP_STOP) from 'pause'.

        Like ``_get_resolution_scale_mode``, this uses a call without NOBLOCK.
        """
        reset_stat = self._arm_teleop_state_subscriber.recv_keypoints()
        reset_stat = np.asanyarray(reset_stat).reshape(1)[0]  # Make sure this data is one dimensional
        return reset_stat

    def _clip_coords(self, coords):
        # TODO - clip the coordinates
        return coords

    def _round_displacement(self, displacement):
        """Zero out components whose magnitude is at most 0.015 (1.5e-2)."""
        return np.where(np.abs(displacement) * 1e2 > 1.5, displacement, 0)

    def _realign_frame(self, hand_frame):
        """Left-multiply ``hand_frame`` by the configured 3x3 frame realignment matrix."""
        return self.frame_realignment_matrix @ hand_frame

    def _turn_frame_to_homo_mat(self, frame):
        """Build a 4x4 homogeneous matrix from a (4, 3) hand frame.

        Row 0 of ``frame`` becomes the translation. Rows 1-3 (the axes) become
        the columns of the rotation block.
        """
        t = frame[0]
        R = frame[1:]

        homo_mat = np.zeros((4, 4))
        homo_mat[:3, :3] = np.transpose(R)
        homo_mat[:3, 3] = t
        homo_mat[3, 3] = 1

        return homo_mat

    def _homo2cart(self, homo_mat):
        """Convert a 4x4 homogeneous matrix to ``[x, y, z, qx, qy, qz, qw]``."""
        t = homo_mat[:3, 3]
        R = Rotation.from_matrix(
            homo_mat[:3, :3]).as_quat()

        cart = np.concatenate(
            [t, R], axis=0
        )

        return cart

    def _get_scaled_cart_pose(self, moving_robot_homo_mat):
        """Scale the translation step towards ``moving_robot_homo_mat`` by ``resolution_scale``.

        The returned 7-D pose has translation ``current + scale * (target - current)``,
        where ``current`` is the robot's live pose. Its rotation is the target
        rotation, unscaled.
        """
        # Get the cart pose without the scaling
        unscaled_cart_pose = self._homo2cart(moving_robot_homo_mat)

        # Get the current cart pose
        current_homo_mat = copy(self.robot.get_pose()['position'])
        current_cart_pose = self._homo2cart(current_homo_mat)

        # Get the difference in translation between these two cart poses
        diff_in_translation = unscaled_cart_pose[:3] - current_cart_pose[:3]
        scaled_diff_in_translation = diff_in_translation * self.resolution_scale

        scaled_cart_pose = np.zeros(7)
        scaled_cart_pose[3:] = unscaled_cart_pose[3:]  # Get the rotation directly
        scaled_cart_pose[:3] = current_cart_pose[:3] + scaled_diff_in_translation  # Get the scaled translation only

        return scaled_cart_pose

    def create_axs(self):
        """Create a new 10x10 figure with one 3D axis limited to [-1, 1] on X, Y and Z."""
        fig = plt.figure(figsize=(10,10))
        ax = fig.add_subplot(111, projection='3d')

        # For plotting we'll have y -> z, x -> x, z -> -y axis
        ax.set_xlim3d([-1, 1])
        ax.set_xlabel('X')

        ax.set_ylim3d([-1, 1])
        ax.set_ylabel('Y')

        ax.set_zlim3d([-1, 1])
        ax.set_zlabel('Z')

        return ax

    def plot_transform(self, transform_mult, plot_init=False, plot_curr=False):
        """Debug plot of a 4x4 transform's axes, saved to 'transform_mult.png'.

        Args:
            transform_mult: 4x4 homogeneous matrix whose rotation columns are
                drawn at its translation.
            plot_init: Also draw ``self.hand_init_H``'s axes (set by ``_reset_teleop``).
            plot_curr: Also draw ``self.hand_moving_H``'s axes at the transform's origin.
        """
        ax = self.create_axs()

        ax.plot3D(
            xs=[0,0], ys=[0,0], zs=[-1,1], color='k'
        )

        ax.plot3D(
            xs=[0,0], ys=[-1,1], zs=[0,0], color='b'
        )

        ax.plot3D(
            xs=[-1,1], ys=[0,0], zs=[0,0], color='r'
        )

        # Axes of transform_mult
        ax.quiver(
            transform_mult[0,3], transform_mult[1,3], transform_mult[2,3],
            transform_mult[0,0], transform_mult[1,0], transform_mult[2,0],
            color='r', length=0.3
        )
        ax.quiver(
            transform_mult[0,3], transform_mult[1,3], transform_mult[2,3],
            transform_mult[0,1], transform_mult[1,1], transform_mult[2,1],
            color='b', length=0.3
        )
        ax.quiver(
            transform_mult[0,3], transform_mult[1,3], transform_mult[2,3],
            transform_mult[0,2], transform_mult[1,2], transform_mult[2,2],
            color='k', length=0.3
        )

        # Axes of the current (moving) hand frame
        if plot_curr:
            ax.quiver(
                transform_mult[0,3], transform_mult[1,3], transform_mult[2,3],
                self.hand_moving_H[0,0], self.hand_moving_H[1,0], self.hand_moving_H[2,0],
                color='peru', length=0.3
            )
            ax.quiver(
                transform_mult[0,3], transform_mult[1,3], transform_mult[2,3],
                self.hand_moving_H[0,1], self.hand_moving_H[1,1], self.hand_moving_H[2,1],
                color='cyan', length=0.3
            )
            ax.quiver(
                transform_mult[0,3], transform_mult[1,3], transform_mult[2,3],
                self.hand_moving_H[0,2], self.hand_moving_H[1,2], self.hand_moving_H[2,2],
                color='slategrey', length=0.3
            )

        # Axes of the initial hand frame
        if plot_init:
            ax.quiver(
                [0,0], [0,0], [0,0],
                self.hand_init_H[0,0], self.hand_init_H[1,0], self.hand_init_H[2,0],
                color='green', length=0.3
            )
            ax.quiver(
                [0,0], [0,0], [0,0],
                self.hand_init_H[0,1], self.hand_init_H[1,1], self.hand_init_H[2,1],
                color='green', length=0.3
            )
            ax.quiver(
                [0,0], [0,0], [0,0],
                self.hand_init_H[0,2], self.hand_init_H[1,2], self.hand_init_H[2,2],
                color='green', length=0.3
            )

        print('** TRANSFORM: **\n{}'.format(transform_mult))
        print('** HAND INIT H: **\n{}'.format(self.hand_init_H))

        plt.draw()
        plt.savefig('transform_mult.png')
        plt.close()

    def get_knuckle_vectors(self):
        """Poll (non-blocking, up to 10 tries) for the index and pinky knuckle vectors.

        Returns:
            ``(index_vector, pinky_vector)`` as arrays, or None if either one
            was not received.

        NOTE(review): ``_index_knuckle_keypoint_subscriber`` and
        ``_pinky_knuckle_keypoint_subscriber`` are not created in ``__init__``,
        so this raises AttributeError unless they are attached elsewhere.
        """
        index_vector = pinky_vector = None
        for _ in range(10):
            index_vector = self._index_knuckle_keypoint_subscriber.recv_keypoints(flags=zmq.NOBLOCK)
            pinky_vector = self._pinky_knuckle_keypoint_subscriber.recv_keypoints(flags=zmq.NOBLOCK)
            if index_vector is not None and pinky_vector is not None:
                break
        # Both vectors are required; a missing one must not be wrapped into array(None).
        if index_vector is None or pinky_vector is None:
            return None
        return np.asanyarray(index_vector), np.asanyarray(pinky_vector)

    def plot_knuckle_vectors(self, index_vector, pinky_vector):
        """Debug plot of the two knuckle vectors from the origin, saved to 'knuckle_vectors.png'."""
        ax = self.create_axs()

        ax.plot3D(
            xs=[0,0], ys=[-1,1] , zs=[0,0], color='k'
        )

        ax.plot3D(
            xs=[0,0],ys=[0,0], zs=[-1,1], color='b'
        )

        ax.plot3D(
            xs=[-1,1], ys=[0,0], zs=[0,0], color='r'
        )

        x, y, z = [0, index_vector[0]], [0,index_vector[1]], [0, index_vector[2]]
        x1, y1, z1= [0, pinky_vector[0]], [0,pinky_vector[1]], [0, pinky_vector[2]]
        print("Index Knuckle Vector", [x,y,z])
        print("Pinky Knuckle Vector", [x1,y1,z1])
        ax.plot(x, y, z, color='black')
        ax.plot(x1, y1, z1, color='red')
        plt.draw()
        plt.savefig('knuckle_vectors.png')
        plt.close()

    def _reset_teleop(self):
        """Re-anchor teleoperation at the current robot pose and the next hand frame.

        Re-reads the robot's pose and spins until a hand frame arrives. That
        frame becomes the new reference. Returns the hand frame so the caller
        can use it as this step's moving frame.
        """
        print('****** RESETTING TELEOP ****** ')
        self.robot_init_H = self.robot.get_pose()['position']
        first_hand_frame = self._get_hand_frame()
        while first_hand_frame is None:
            first_hand_frame = self._get_hand_frame()
        self.hand_init_H = self._turn_frame_to_homo_mat(first_hand_frame)
        self.hand_init_t = copy(self.hand_init_H[:3, 3])

        self.is_first_frame = False

        return first_hand_frame

    def _apply_retargeted_angles(self, log=False):
        """Run one control step: read inputs, retarget the hand motion, command the arm.

        Args:
            log: If True, print the intermediate transforms.
        """
        # Reset on the first frame and on every STOP -> CONT transition of the teleop state.
        new_arm_teleop_state = self._get_arm_teleop_state()
        if self.is_first_frame or (self.arm_teleop_state == ARM_TELEOP_STOP and new_arm_teleop_state == ARM_TELEOP_CONT):
            moving_hand_frame = self._reset_teleop()  # Should get the moving hand frame only once
        else:
            moving_hand_frame = self._get_hand_frame()
        self.arm_teleop_state = new_arm_teleop_state

        arm_teleoperation_scale_mode = self._get_resolution_scale_mode()

        # Any other mode value leaves the previous resolution_scale unchanged.
        if arm_teleoperation_scale_mode == ARM_HIGH_RESOLUTION:
            self.resolution_scale = 1
        elif arm_teleoperation_scale_mode == ARM_LOW_RESOLUTION:
            self.resolution_scale = 0.6

        if moving_hand_frame is None:
            return  # It means we are not on the arm mode yet instead of blocking it is directly returning

        self.hand_moving_H = self._turn_frame_to_homo_mat(moving_hand_frame)

        # Transformation code
        H_HI_HH = copy(self.hand_init_H)  # Homo matrix that takes P_HI to P_HH - Point in Initial Hand Frame to Point in Home Hand Frame
        H_HT_HH = copy(self.hand_moving_H)  # Homo matrix that takes P_HT to P_HH
        H_RI_RH = copy(self.robot_init_H)  # Homo matrix that takes P_RI to P_RH
        H_A_R = np.array(
            [[1/np.sqrt(2), 1/np.sqrt(2), 0, 0],
             [-1/np.sqrt(2), 1/np.sqrt(2), 0, 0],
             [0, 0, 1, -0.06],  # The height of the allegro mount is 6cm
             [0, 0, 0, 1]])  # Rotation from allegro to robot

        H_HT_HI = np.linalg.pinv(H_HI_HH) @ H_HT_HH  # Homo matrix that takes P_HT to P_HI
        #H_HT_HI[:3, :3] = H_HT_HI[:3, :3].T # This is for converting left hand rotation to the right hand rotation (this compensates few problems in openteach/components/recorders/keypoint.py)
        # Conjugate the relative hand motion by the mount transform, then apply it
        # on top of the robot pose captured at the last reset.
        H_RT_RH = H_RI_RH @ H_A_R @ H_HT_HI @ np.linalg.pinv(H_A_R)  # Homo matrix that takes P_RT to P_RH

        self.robot_moving_H = copy(H_RT_RH)

        if log:
            print('** ROBOT MOVING H **\n{}\n** ROBOT INIT H **\n{}\n'.format(
                self.robot_moving_H, self.robot_init_H))
            print('** HAND MOVING H: **\n{}\n** HAND INIT H: **\n{} - HAND INIT T: {}'.format(
                self.hand_moving_H, self.hand_init_H, self.hand_init_t
            ))
            print('***** TRANSFORM MULT: ******\n{}'.format(
                H_HT_HI
            ))

            print('\n------------------------------------\n\n\n')

        # Use the resolution scale to get the final cart pose
        final_pose = self._get_scaled_cart_pose(self.robot_moving_H)

        if self.use_filter:
            final_pose = self.comp_filter(final_pose)

        self.robot.arm_control(final_pose)

    def stream(self):
        """Run the control loop until Ctrl-C, then stop every ZMQ subscriber."""
        self.notify_component_start('{} control'.format(self.robot.name))
        print("Start controlling the robot hand using the Oculus Headset.\n")

        try:
            while True:
                # Only run a control step once the robot reports joint positions.
                if self.robot.get_joint_position() is not None:
                    self.timer.start_loop()

                    # Retargeting function
                    self._apply_retargeted_angles(log=False)

                    self.timer.end_loop()
        except KeyboardInterrupt:
            pass
        finally:
            # Close every subscriber opened in __init__, even if the loop raised.
            for subscriber in (
                self._transformed_arm_keypoint_subscriber,
                self._transformed_hand_keypoint_subscriber,
                self._arm_resolution_subscriber,
                self._arm_teleop_state_subscriber,
            ):
                subscriber.stop()

        print('Stopping the teleoperator!')
