import numpy as np
import matplotlib.pyplot as plt
import zmq

from tqdm import tqdm

from copy import deepcopy as copy
from openteach.constants import *
from openteach.utils.timer import FrequencyTimer
from openteach.utils.network import ZMQKeypointSubscriber, ZMQKeypointPublisher
from openteach.utils.vectorops import *
from openteach.utils.files import *
from scipy.spatial.transform import Rotation, Slerp
from .operator import Operator
from .calibrators.allegro import OculusThumbBoundCalibrator

import robosuite.utils.transform_utils as T
from scipy.spatial.transform import Rotation as R

# Process-wide numpy print settings used by the debug prints in this module:
# two decimals, and small values shown in fixed-point rather than scientific notation.
np.set_printoptions(suppress=True, precision=2)

def test_filter(save_data=False):
	"""Record, or replay as plots, a visual sanity check of ``CompStateFilter``.

	save_data=True: feed random 7-D vectors through a CompStateFilter at VR_FREQ
	until Ctrl-C, printing each (raw, filtered) pair, then save them to
	``all_poses.npy`` as an array of shape (T, 2, 7) (index 0 raw, 1 filtered).

	save_data=False: load ``all_poses.npy`` and, for every prefix of the recording,
	save a 3x3 grid of per-dimension plots to ``all_poses/state_XXX.png``.
	"""
	import os  # local import: os is not imported explicitly at module level

	if save_data:
		pose_filter = CompStateFilter(np.asarray([0.575466, -0.17820767, 0.23671454, -0.281564, -0.6797597, -0.6224841, 0.2667619]))
		timer = FrequencyTimer(VR_FREQ)

		# Collect pairs in a list; re-concatenating an ndarray every tick is quadratic.
		pose_pairs = []
		while True:
			try:
				timer.start_loop()

				rand_pos = np.random.randn(7)
				filtered_pos = pose_filter(rand_pos)
				print('rand_pos: {} - filtered_pos: {}'.format(rand_pos, filtered_pos))

				pose_pairs.append(np.stack([rand_pos, filtered_pos], axis=0))
				print('all_poses shape: {}'.format((len(pose_pairs),) + pose_pairs[-1].shape))

				timer.end_loop()
			except KeyboardInterrupt:
				# Nothing to save if interrupted before the first sample was recorded.
				if pose_pairs:
					np.save('all_poses.npy', np.stack(pose_pairs, axis=0))
				break
	else:
		all_poses = np.load('all_poses.npy')
		os.makedirs('all_poses', exist_ok=True)
		for i in tqdm(range(len(all_poses))):
			rand_pos = all_poses[:i + 1, 0, :]
			filtered_pos = all_poses[:i + 1, 1, :]

			# One fresh figure per frame, closed after saving so figures do not pile up.
			fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(10, 10))
			for j in range(filtered_pos.shape[1]):
				ax = axs[j // 3, j % 3]
				ax.plot(filtered_pos[:, j], label='Filtered')
				ax.plot(rand_pos[:, j], label='Actual')
				ax.set_title(f'{j}th Axes')
				ax.legend()
			fig.savefig(os.path.join('all_poses', f'state_{i:03d}.png'))
			plt.close(fig)

# Complementary filter that smooths a pose before it is sent; position is blended linearly and rotation with slerp.
class CompStateFilter:
	"""Complementary filter over a 7-D pose ``[x, y, z, qx, qy, qz, qw]``.

	Each call blends the stored state toward the new sample. ``comp_ratio`` is the
	weight kept on the previous state: the position is a linear blend, and the
	orientation is spherically interpolated (slerp) a fraction ``1 - comp_ratio`` of
	the way from the stored quaternion to the new one.
	"""

	def __init__(self, state, comp_ratio=0.6):
		self.pos_state, self.ori_state = state[:3], state[3:7]
		self.comp_ratio = comp_ratio

	def __call__(self, next_state):
		"""Update the filter with ``next_state`` and return the filtered 7-D pose."""
		keep = self.comp_ratio
		blend = 1 - keep

		self.pos_state = self.pos_state * keep + next_state[:3] * blend

		key_rotations = Rotation.from_quat(np.stack([self.ori_state, next_state[3:7]], axis=0))
		slerp = Slerp([0, 1], key_rotations)
		self.ori_state = slerp([blend])[0].as_quat()

		return np.concatenate([self.pos_state, self.ori_state])
	


class LiberoSimOperator(Operator):
	def __init__(
		self,
		host,
		transformed_keypoints_port,
		stream_configs,
		stream_oculus,
		endeff_publish_port,
		endeffpossubscribeport,
		robotposesubscribeport,
		moving_average_limit,
		allow_rotation=False,
		use_filter=False,
		arm_resolution_port = None,
		teleoperation_reset_port = None,
	):
		"""Set up the LIBERO simulator operator's ZMQ streams, state and calibration.

		Args:
			host: host passed to every ZMQ subscriber/publisher created here.
			transformed_keypoints_port: port carrying the 'transformed_hand_coords'
				and 'transformed_hand_frame' topics.
			stream_configs: stored as ``self.stream_configs``.
			stream_oculus: stored as ``self._stream_oculus``.
			endeff_publish_port: port on which end-effector actions are published.
			endeffpossubscribeport: port carrying the simulator's 'endeff_coords'.
			robotposesubscribeport: port carrying 'robot_pose'.
			moving_average_limit: stored as ``self.moving_average_limit``.
			allow_rotation: when True, also store a fixed initial quaternion and a
				z rotation axis.
			use_filter: when True, smooth target poses with a CompStateFilter
				(comp_ratio=0.8).
			arm_resolution_port: port carrying the 'button' resolution messages.
			teleoperation_reset_port: accepted for interface compatibility; unused here.
		"""
		self.notify_component_start('libero operator')
		self._host, self._port = host, transformed_keypoints_port
		self._hand_transformed_keypoint_subscriber = ZMQKeypointSubscriber(
			host = self._host,
			port = self._port,
			topic = 'transformed_hand_coords'
		)
		self._arm_transformed_keypoint_subscriber = ZMQKeypointSubscriber(
			host=host,
			port=transformed_keypoints_port,
			topic='transformed_hand_frame'
		)

		# Teleop / gesture state
		self.allow_rotation = allow_rotation
		self.resolution_scale = 1 # NOTE: Get this from a socket
		self.arm_teleop_state = ARM_TELEOP_STOP # Start stopped; the first frame re-anchors via _reset_teleop
		self.gripper_correct_state = 0
		self.pause_flag = 0
		self.gripper_flag = 1
		self.prev_gripper_flag = 0
		self.prev_pause_flag = 0
		self.pause_cnt = 0
		# Incremented by get_gripper_state_from_hand_keypoints on a pinch, so it must
		# exist before the first pinch is seen.
		self.gripper_cnt = 0

		self.robot_information = dict()

		self._arm_resolution_subscriber = ZMQKeypointSubscriber(
			host = host,
			port = arm_resolution_port,
			topic = 'button'
		)

		self.end_eff_position_subscriber = ZMQKeypointSubscriber(
			host = host,
			port = endeffpossubscribeport,
			topic = 'endeff_coords'
		)

		self.end_eff_position_publisher = ZMQKeypointPublisher(
			host = host,
			port = endeff_publish_port
		)

		# robot pose subscriber
		self.robot_pose_subscriber = ZMQKeypointSubscriber(
			host = host,
			port = robotposesubscribeport,
			topic = 'robot_pose'
		)

		# Calibrating to get the thumb bounds
		self._calibrate_bounds()

		self._stream_oculus = stream_oculus
		self.stream_configs = stream_configs

		self.real = False
		self._robot = 'Libero_Sim'

		self.is_first_frame = True

		self.use_filter = use_filter
		if use_filter:
			# robot_init_H is only assigned in _reset_teleop, which has not run yet, so
			# seed the filter from the end-effector pose published on 'endeff_coords'
			# (entries 2: are the 7-D pose, as in _reset_teleop).
			robot_frame = self.end_eff_position_subscriber.recv_keypoints()
			robot_init_cart = self._homo2cart(self.cart2homo(np.asarray(robot_frame)[2:]))
			self.comp_filter = CompStateFilter(robot_init_cart, comp_ratio=0.8)

		if allow_rotation:
			self.initial_quat = np.array(
				[-0.27686286, -0.66575766, -0.63895273,  0.26805457])
			self.rotation_axis = np.array([0, 0, 1])

		# Frequency timer
		self._timer = FrequencyTimer(VR_FREQ)

		self.direction_counter = 0
		self.current_direction = 0

		# Moving average queues
		self.moving_Average_queue = []
		self.moving_average_limit = moving_average_limit

		self.hand_frames = []

		self.count = 0

	@property
	def timer(self):
		return self._timer

	@property
	def robot(self):
		return self._robot
	
	@property
	def transformed_hand_keypoint_subscriber(self):
		return self._hand_transformed_keypoint_subscriber
	
	@property
	def transformed_arm_keypoint_subscriber(self):
		return self._arm_transformed_keypoint_subscriber
	
	def _calibrate_bounds(self):
		"""Run the Oculus thumb-bound calibration and store the result on ``hand_thumb_bounds``."""
		self.notify_component_start('calibration')
		thumb_calibrator = OculusThumbBoundCalibrator(self._host, self._port)
		# Provides [thumb-index bounds, index-middle bounds, middle-ring-bounds]
		self.hand_thumb_bounds = thumb_calibrator.get_bounds()
		print(f'THUMB BOUNDS IN THE OPERATOR: {self.hand_thumb_bounds}')

	def _get_hand_frame(self):
		"""Poll the 'transformed_hand_frame' stream up to 10 times without blocking.

		Returns:
			The received frame reshaped to (4, 3), or None when none of the 10
			non-blocking reads produced data.
		"""
		for _ in range(10):
			frame_data = self.transformed_arm_keypoint_subscriber.recv_keypoints(flags=zmq.NOBLOCK)
			if frame_data is not None:
				break
		else:
			return None
		return np.asanyarray(frame_data).reshape(4, 3)
	
	def _get_resolution_scale_mode(self):
		data = self._arm_resolution_subscriber.recv_keypoints()
		res_scale = np.asanyarray(data).reshape(1)[0] # Make sure this data is one dimensional
		return res_scale
	

	def _turn_frame_to_homo_mat(self, frame):
		t = frame[0]
		R = frame[1:]

		homo_mat = np.zeros((4, 4))
		homo_mat[:3, :3] = np.transpose(R)
		homo_mat[:3, 3] = t
		homo_mat[3, 3] = 1

		return homo_mat

	def _homo2cart(self, homo_mat):
		# Here we will use the resolution scale to set the translation resolution
		t = homo_mat[:3, 3]
		R = Rotation.from_matrix(
			homo_mat[:3, :3]).as_quat()

		cart = np.concatenate(
			[t, R], axis=0
		)

		return cart
	
	def cart2homo(self, cart):
		homo=np.zeros((4,4))
		t = cart[0:3]
		R = Rotation.from_quat(cart[3:]).as_matrix()

		homo[0:3,3] = t
		homo[:3,:3] = R
		homo[3,:] = np.array([0,0,0,1])
		return homo

	
	def _get_scaled_cart_pose(self, moving_robot_homo_mat):
		# Get the cart pose without the scaling
		unscaled_cart_pose = self._homo2cart(moving_robot_homo_mat)

		# Get the current cart pose
		current_homo_mat = copy(self.robot.get_pose()['position'])
		current_cart_pose = self._homo2cart(current_homo_mat)

		# Get the difference in translation between these two cart poses
		diff_in_translation = unscaled_cart_pose[:3] - current_cart_pose[:3]
		scaled_diff_in_translation = diff_in_translation * self.resolution_scale
		
		scaled_cart_pose = np.zeros(7)
		scaled_cart_pose[3:] = unscaled_cart_pose[3:] # Get the rotation directly
		scaled_cart_pose[:3] = current_cart_pose[:3] + scaled_diff_in_translation # Get the scaled translation only

		return scaled_cart_pose
	 
	def return_real(self):
		return self.real
			

	def _reset_teleop(self):
		# Just updates the beginning position of the arm
		print('****** RESETTING TELEOP ****** ')
		self.robot_frame=self.end_eff_position_subscriber.recv_keypoints()
		self.robot_init_H=self.cart2homo(self.robot_frame[2:])
		self.robot_moving_H = copy(self.robot_init_H)

		first_hand_frame = self._get_hand_frame()
		while first_hand_frame is None:
			first_hand_frame = self._get_hand_frame()
		self.hand_init_H = self._turn_frame_to_homo_mat(first_hand_frame)
		self.hand_init_t = copy(self.hand_init_H[:3, 3])

		self.is_first_frame = False

		return first_hand_frame
	
	def _get_arm_teleop_state_from_hand_keypoints(self):
		pause_state ,pause_status,pause_right =self.get_pause_state_from_hand_keypoints()
		pause_status =np.asanyarray(pause_status).reshape(1)[0] 

		return pause_state,pause_status,pause_right
	
	def get_pause_state_from_hand_keypoints(self):
		"""Toggle the pause flag on a thumb-to-ring or thumb-to-middle pinch.

		A pinch means either fingertip is closer than 0.04 to the thumb tip (in the
		units of the transformed hand coordinates). ``pause_cnt`` counts consecutive
		pinched frames, and only the first one toggles ``pause_flag``.

		Returns:
			(pause_state, pause_status, pause_right):
			pause_state -- current pause flag as a numpy scalar.
			pause_status -- True when pause_state differs from prev_pause_flag. Because
				prev_pause_flag is only updated when the flag toggles, this stays True
				after the first toggle.
			pause_right -- always True.
		"""
		hand_coords = self.transformed_hand_keypoint_subscriber.recv_keypoints()
		thumb_tip = hand_coords[OCULUS_JOINTS['thumb'][-1]]
		ring_distance = np.linalg.norm(hand_coords[OCULUS_JOINTS['ring'][-1]] - thumb_tip)
		middle_distance = np.linalg.norm(hand_coords[OCULUS_JOINTS['middle'][-1]] - thumb_tip)
		thresh = 0.04

		is_pinching = ring_distance < thresh or middle_distance < thresh
		if not is_pinching:
			self.pause_cnt = 0
		else:
			self.pause_cnt += 1
			if self.pause_cnt == 1:
				self.prev_pause_flag = self.pause_flag
				self.pause_flag = not self.pause_flag

		pause_state = np.asanyarray(self.pause_flag).reshape(1)[0]
		pause_status = True if pause_state != self.prev_pause_flag else False
		pause_right = True
		return pause_state, pause_status, pause_right
	
	def get_gripper_state_from_hand_keypoints(self):
		"""Toggle the gripper flag on a thumb-to-pinky pinch.

		A pinch means the pinky tip is closer than 0.04 to the thumb tip (in the
		units of the transformed hand coordinates). ``gripper_cnt`` counts
		consecutive pinched frames, and only the first one toggles ``gripper_flag``.

		Returns:
			(gripper_state, status, toggled):
			gripper_state -- current gripper flag as a numpy scalar.
			status -- True when gripper_state differs from prev_gripper_flag. Because
				prev_gripper_flag is only updated when the flag toggles, this stays
				True after the first toggle.
			toggled -- True only on the frame where the flag was toggled.
		"""
		transformed_hand_coords = self.transformed_hand_keypoint_subscriber.recv_keypoints()
		pinky_distance = np.linalg.norm(
			transformed_hand_coords[OCULUS_JOINTS['pinky'][-1]] - transformed_hand_coords[OCULUS_JOINTS['thumb'][-1]])
		thresh = 0.04

		toggled = False
		if pinky_distance < thresh:
			# The counter may not exist yet if the very first frame is a pinch.
			self.gripper_cnt = getattr(self, 'gripper_cnt', 0) + 1
			if self.gripper_cnt == 1:
				self.prev_gripper_flag = self.gripper_flag
				self.gripper_flag = not self.gripper_flag
				toggled = True
		else:
			self.gripper_cnt = 0

		gripper_state = np.asanyarray(self.gripper_flag).reshape(1)[0]
		status = True if gripper_state != self.prev_gripper_flag else False
		return gripper_state, status, toggled

	def _apply_retargeted_angles(self, log=False):
		"""Retarget the current hand pose to a relative end-effector action and publish it.

		Steps:
		1. Read the pause gesture. On the first frame, or on an ARM_TELEOP_STOP ->
		   ARM_TELEOP_CONT transition, re-anchor the hand and robot frames with
		   ``_reset_teleop``; otherwise poll the latest hand frame.
		2. Read the pinky-pinch gripper gesture. When the latched state equals
		   GRIPPER_OPEN / GRIPPER_CLOSE, map it to -1 / 1.
		3. Express the hand motion relative to its anchor frame and change axes from
		   the VR frame to the robot frame. Apply it to the robot anchor pose to get
		   the absolute target ``H_RT_RH``, smoothed by ``comp_filter`` when
		   ``use_filter`` is set.
		4. Publish ``[50 * delta_xyz, 5 * delta_rotvec, gripper]`` relative to the
		   previous target. The action is passed through ``moving_average`` with a
		   window of 5 and published on the 'endeff_coords' topic.

		Returns early, publishing nothing, when no hand frame is available.

		Args:
			log: if True, print the intermediate transforms and rotation debug values.
		"""
		# --- Pause gesture / re-anchoring ---
		new_arm_teleop_state, _, _ = self._get_arm_teleop_state_from_hand_keypoints()
		if self.is_first_frame or (self.arm_teleop_state == ARM_TELEOP_STOP and new_arm_teleop_state == ARM_TELEOP_CONT):
			moving_hand_frame = self._reset_teleop() # Should get the moving hand frame only once
		else:
			moving_hand_frame = self._get_hand_frame()
		self.arm_teleop_state = new_arm_teleop_state

		# --- Gripper gesture ---
		gripper_state, status_change, _ = self.get_gripper_state_from_hand_keypoints()
		# Latch the gripper state only on the frame where the pinch toggled it.
		if self.gripper_cnt == 1 and status_change is True:
			self.gripper_correct_state = gripper_state
		if self.gripper_correct_state == GRIPPER_OPEN:
			gripper_state = -1
		elif self.gripper_correct_state == GRIPPER_CLOSE:
			gripper_state = 1

		if moving_hand_frame is None:
			return # Not in arm mode yet; return instead of blocking

		self.hand_moving_H = self._turn_frame_to_homo_mat(moving_hand_frame)

		# --- Hand motion -> absolute robot target ---
		H_HI_HH = copy(self.hand_init_H) # Takes P_HI to P_HH - point in Initial Hand frame to point in Home Hand frame
		H_HT_HH = copy(self.hand_moving_H) # Takes P_HT to P_HH
		H_RI_RH = copy(self.robot_init_H) # Takes P_RI to P_RH

		H_HT_HI = np.linalg.pinv(H_HI_HH) @ H_HT_HH # Takes P_HT to P_HI

		# Axis change between the VR hand frame and the robot frame. The same matrix
		# is used for both the rotation and the translation of the hand motion.
		H_R_V = np.array([
			[0, 0, 1, 0],
			[0, 1, 0, 0],
			[-1, 0, 0, 0],
			[0, 0, 0, 1],
		])
		H_HT_HI_robot = np.linalg.pinv(H_R_V) @ H_HT_HI @ H_R_V
		H_HT_HI_r = H_HT_HI_robot[:3, :3]
		H_HT_HI_t = H_HT_HI_robot[:3, 3]

		# Translation offsets the anchor position; rotation is composed on the right
		# of the anchor orientation.
		target_translation = H_RI_RH[:3, 3] + H_HT_HI_t
		target_rotation = H_RI_RH[:3, :3] @ H_HT_HI_r
		H_RT_RH = np.block(
			[[target_rotation, target_translation.reshape(-1, 1)], [0, 0, 0, 1]])

		if self.use_filter:
			# Smooth the absolute target before the relative action is derived from it.
			H_RT_RH = self.cart2homo(self.comp_filter(self._homo2cart(H_RT_RH)))

		# --- Relative action with respect to the previous target ---
		curr_robot_pose = self.robot_moving_H # previous target, used as the current robot pose
		translation_scale = 50.0
		rotation_scale = 5.0
		T_togo, R_togo = H_RT_RH[:3, 3], H_RT_RH[:3, :3]
		T_curr, R_curr = curr_robot_pose[:3, 3], curr_robot_pose[:3, :3]
		rel_pos = (T_togo - T_curr) * translation_scale
		rel_rot = np.linalg.pinv(R_curr) @ R_togo
		raw_axis_angle = Rotation.from_matrix(rel_rot).as_rotvec()
		rel_axis_angle = raw_axis_angle * rotation_scale

		if log:
			print(f"R curr: {R_curr}")
			print(f"R moving: {self.robot_moving_H[:3,:3]}")
			print(f"Rel axis angle: {raw_axis_angle}")
			print(f"Mod Rel axis angle: {rel_axis_angle}")
			print("#####################################################")

		self.robot_moving_H = copy(H_RT_RH)

		if log:
			print('** ROBOT MOVING H **\n{}\n** ROBOT INIT H **\n{}\n'.format(
				self.robot_moving_H, self.robot_init_H))
			print('** HAND MOVING H: **\n{}\n** HAND INIT H: **\n{} - HAND INIT T: {}'.format(
				self.hand_moving_H, self.hand_init_H, self.hand_init_t
			))
			print('***** TRANSFORM MULT: ******\n{}'.format(
				H_HT_HI
			))

			print('\n------------------------------------\n\n\n')

		action = np.concatenate([rel_pos, rel_axis_angle, [gripper_state]])

		# NOTE(review): the averaging window is fixed at 5; self.moving_average_limit
		# (constructor argument) is stored but not used here.
		averaged_action = moving_average(
			action,
			self.moving_Average_queue,
			5,
		)

		self.end_eff_position_publisher.pub_keypoints(averaged_action, "endeff_coords")

	