from .robot_bases import *
import numpy as np
import os
import cv2


class KinovaGen3(BaseRobot):
	"""
	Kinova Gen3 arm with a Robotiq 2F-85 gripper, simulated through a PyBullet client.

	The robot is driven in Cartesian space: actions are deltas applied to
	``desiredEndEffectorPos``, which is turned into joint targets by inverse
	kinematics (``invKin``) and applied with position control.
	"""

	def __init__(self, config=None):
		"""
		Create the robot from ``<config.mediaPath>/kinova_gen3/gen3_robotiq_2f_85.urdf``.

		:param config: configuration object. Although it defaults to None it is
			required: ``mediaPath``, ``robotName`` and ``robotScale`` are read here,
			and the other methods read many more attributes from it.
		"""
		model_file = os.path.join(config.mediaPath, 'kinova_gen3', 'gen3_robotiq_2f_85.urdf')
		BaseRobot.__init__(self, model_file=model_file, robot_name=config.robotName, scale=config.robotScale)
		self._p = None  # bullet client; presumably assigned outside this class before use -- confirm
		self.config = config

		self.numJoints = None  # presumably set when the model is loaded -- TODO confirm
		# Target end-effector position [x, y, z]; the RL planner's actions move it.
		self.desiredEndEffectorPos = [0.0, 0.0, 0.0]

		# Id of the debug line drawn by ray_test; None until it is first drawn.
		self.rayIDs = None

	def _eeOrientation(self):
		"""Return the fixed end-effector orientation quaternion used for every IK query."""
		return self._p.getQuaternionFromEuler([0, np.pi, np.pi / 2])

	def _resetJoints(self, jointPositions):
		"""
		Hard-reset joints 0 .. numJoints-7 (all but the last 6 joint indices) to
		``jointPositions`` and hold them there with position control.
		"""
		for jointIndex in range(self.numJoints - 6):
			self._p.resetJointState(self.robot_ids, jointIndex, jointPositions[jointIndex])
			self._p.setJointMotorControl2(self.robot_ids, jointIndex, self._p.POSITION_CONTROL,
				targetPosition=jointPositions[jointIndex],
				force=self.config.positionControlMaxForce)

	def robot_specific_reset(self, eePositionX, eePositionY, eePositionZ):
		"""
		Reset the robot so that the end effector is placed at (eePositionX, eePositionY, eePositionZ).

		The joints are first reset to the IK rest pose ``config.ik_rp``. Only then is
		inverse kinematics solved, because calling IK without that reset gives
		incomplete or unstable results.
		"""
		self._resetJoints(self.config.ik_rp)

		eePosition = [eePositionX, eePositionY, eePositionZ]
		self._resetJoints(self.invKin(eePosition, self._eeOrientation()))

		self.desiredEndEffectorPos = list(eePosition)

	def calc_state(self):
		"""Return ``{'eeState': ...}``, holding element 0 of ``getLinkState`` for the end-effector link."""
		linkState = self._p.getLinkState(self.robot_ids, self.config.endEffectorIndex)
		return {'eeState': linkState[0]}

	def applyAction(self, eeCommands, controlMethod):
		"""
		Shift the end-effector target by a Cartesian delta and command the joints toward it.

		:param eeCommands: (dx, dy, dz) added to ``desiredEndEffectorPos``. The
			resulting x and y are clipped to [xMin, xMax] and [yMin, yMax].
		:param controlMethod: only ``'position'`` is supported.
		:raises NotImplementedError: for any other control method. The target position
			has already been updated when this is raised.
		"""
		dx, dy, dz = eeCommands
		target = self.desiredEndEffectorPos
		target[0] = np.clip(target[0] + dx, a_min=self.config.xMin, a_max=self.config.xMax)
		target[1] = np.clip(target[1] + dy, a_min=self.config.yMin, a_max=self.config.yMax)
		# NOTE(review): z is not clipped, unlike x and y -- confirm this is intended.
		target[2] = target[2] + dz

		jointPositions = self.invKin(target, self._eeOrientation())

		if controlMethod != 'position':
			raise NotImplementedError('controlMethod %r is not supported' % (controlMethod,))

		# Arm joints 0 .. endEffectorIndex follow the IK solution.
		for i in range(self.config.endEffectorIndex + 1):
			self._p.setJointMotorControl2(bodyUniqueId=self.robot_ids, jointIndex=i,
				controlMode=self._p.POSITION_CONTROL,
				targetPosition=jointPositions[i], targetVelocity=0,
				force=self.config.positionControlMaxForce,
				positionGain=self.config.positionControlPositionGain,
				velocityGain=self.config.positionControlVelGain)

		# Gripper joints: all driven to position 0, each with its own force limit.
		fingerJointForces = (
			(13, self.config.positionControlMaxForce),
			(15, self.config.fingerAForce),
			(17, self.config.fingerBForce),
			(18, self.config.fingerTipForce),
			(20, self.config.fingerTipForce),
			(22, self.config.fingerTipForce),
		)
		for jointIndex, force in fingerJointForces:
			self._p.setJointMotorControl2(self.robot_ids, jointIndex, self._p.POSITION_CONTROL,
				targetPosition=0, force=force)

	def get_image(self, externalCamEyePosition, externalCamTargetPosition):
		"""
		Render the scene from an external camera and post-process the image.

		The render size is ``config.externalCamRenderSize`` as (height, width, channels).
		The RGB image is cropped to columns 12..86 and resized to
		(height, width) = (``img_dim[1]``, ``img_dim[2]``). When the render size asks for
		1 channel, the image is converted to grayscale with a trailing channel axis.

		:returns: uint8 array of shape (img_dim[1], img_dim[2], 3), or (img_dim[1], img_dim[2], 1).
		"""
		viewMatrix = self._p.computeViewMatrix(cameraEyePosition=externalCamEyePosition,
			cameraTargetPosition=externalCamTargetPosition,
			cameraUpVector=[0, 0, 1])

		# RealSense D455-like camera. The aspect ratio is kept here; the rendered
		# image is downsampled afterwards.
		projMatrix = self._p.computeProjectionMatrixFOV(
			fov=self.config.externalCamFov, aspect=self.config.externalCamAspect,
			nearVal=0.01, farVal=100)

		renderHeight = self.config.externalCamRenderSize[0]
		renderWidth = self.config.externalCamRenderSize[1]
		width, height, px, _, _ = self._p.getCameraImage(
			width=renderWidth, height=renderHeight, viewMatrix=viewMatrix,
			projectionMatrix=projMatrix, shadow=0,
			# If you load pretrained weights, keep TINY_RENDERER; otherwise the images will differ.
			renderer=self._p.ER_TINY_RENDERER,
			flags=self._p.ER_NO_SEGMENTATION_MASK)

		# px is RGBA. PyBullet builds without NumPy return it as a flat sequence,
		# so reshape explicitly (a no-op when it is already an (H, W, 4) array).
		rgba = np.reshape(np.array(px, dtype=np.uint8), (height, width, 4))
		img = rgba[:, :, :3]
		# NOTE(review): hard-coded column crop, independent of the render width -- confirm its origin.
		img = img[:, 12:87, :]
		img = cv2.resize(img, (self.config.img_dim[2], self.config.img_dim[1]))

		if self.config.externalCamRenderSize[2] == 1:  # grayscale requested
			# Add the channel axis to the *resized* image. Reshaping to the render
			# size fails whenever it differs from img_dim.
			img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis]

		return img

	def ray_test(self, objUidList):
		"""
		Cast a ray along z from just below the end-effector link down to ``config.tableZ``.

		When ``config.render`` is set, a debug line is drawn: to the hit point (hit colour)
		if the hit object is in ``objUidList``, otherwise over the full ray (miss colour).

		:returns: the object unique id reported by ``rayTest`` for the ray.
		"""
		eePosition = list(self._p.getLinkState(self.robot_ids, self.config.endEffectorIndex)[0])
		# Start 0.15 below the link so that the ray does not hit the end effector itself.
		eePosition[-1] = eePosition[-1] - 0.15
		rayTo = [eePosition[0], eePosition[1], self.config.tableZ]
		# Use this robot's bullet client (not the global module) so the ray is cast in the right simulation.
		hit = self._p.rayTest(eePosition, rayTo)[0]
		hitObjectUid = hit[0]

		if self.config.render:
			if self.rayIDs is None:  # create the line once, then keep replacing it
				self.rayIDs = self._p.addUserDebugLine(eePosition, rayTo, self.config.rayMissColor, lineWidth=5)

			if hitObjectUid in objUidList:
				self._p.addUserDebugLine(eePosition, hit[3], self.config.rayHitColor, lineWidth=5,
					replaceItemUniqueId=self.rayIDs)
			else:
				self._p.addUserDebugLine(eePosition, rayTo, self.config.rayMissColor, lineWidth=5,
					replaceItemUniqueId=self.rayIDs)
		return hitObjectUid

	def invKin(self, pos, orn):
		"""
		Solve inverse kinematics for ``config.endEffectorIndex`` at position ``pos``.

		- ``config.ik_useNullSpace`` passes joint limits, ranges and rest poses (``ik_ll``,
		  ``ik_ul``, ``ik_jr``, ``ik_rp``); otherwise ``ik_jd`` damping is used when an
		  orientation is given.
		- ``config.ik_useOrientation`` decides whether ``orn`` is passed at all.

		:returns: list of joint positions from ``calculateInverseKinematics``.

		NOTE(review): PyBullet returns one value per degree of freedom (fixed joints are
		skipped), while callers index the result by joint index. These only line up if
		no fixed joints precede the controlled ones -- confirm against the URDF.
		"""
		bodyId = self.robot_ids
		eeIndex = self.config.endEffectorIndex
		if self.config.ik_useNullSpace:
			nullSpaceArgs = dict(lowerLimits=self.config.ik_ll, upperLimits=self.config.ik_ul,
				jointRanges=self.config.ik_jr, restPoses=self.config.ik_rp)
			if self.config.ik_useOrientation:
				jointPositions = self._p.calculateInverseKinematics(bodyId, eeIndex, pos, orn,
					maxNumIterations=50, **nullSpaceArgs)
			else:
				jointPositions = self._p.calculateInverseKinematics(bodyId, eeIndex, pos, **nullSpaceArgs)
		elif self.config.ik_useOrientation:
			jointPositions = self._p.calculateInverseKinematics(bodyId, eeIndex, pos, orn,
				jointDamping=self.config.ik_jd)
		else:
			jointPositions = self._p.calculateInverseKinematics(bodyId, eeIndex, pos)

		return list(jointPositions)
