"""
This file builds the TurtleBot robot used by an environment that satisfies the OpenAI Gym standard.
"""
from Envs.pybullet.turtlebot.robot_bases import BaseRobot
import numpy as np
import os
import cv2
import pybullet as p


class TurtleBot(BaseRobot):
	"""
	This is a derived class for the TurtleBot (TurtleBot3 Waffle Pi model) used in the choosing task.

	Besides the pose accessors it provides:
	- apply_action(): several control methods (point following, absolute / incremental teleporting and
	  translational-velocity + heading control),
	- motorDriver(): a differential-drive model whose commands are delayed through actionBuf,
	- get_image() / ray_test(): the on-board camera image and a horizontal fan of rays,
	- calc_state() / get_measurement(): pose, collision, distance and angle information.
	"""
	def __init__(self, config=None):
		"""
		:param config: configuration object. This constructor reads mediaPath, robotName, robotScale,
			robotWheelDistance, robotWheelRadius, robotFov and numRays; the other methods read further
			fields. Although it defaults to None, a config is required.
		:raises ValueError: if config is None.
		"""
		if config is None:
			raise ValueError('TurtleBot requires a config object')

		model_file_path = os.path.join(config.mediaPath, 'turtlebotModel', 'turtlebot3_waffle_pi.urdf')

		super(TurtleBot, self).__init__(
			model_file=model_file_path,
			robot_name=config.robotName,
			scale=config.robotScale)
		self._p = None  # physics client; presumably assigned by BaseRobot when the robot is reset
		self.L = config.robotWheelDistance  # distance between the two wheels
		self.R = config.robotWheelRadius  # wheel radius
		self.config = config

		# robot
		# buffer used to simulate the delay happening in the actuator and the communication (see motorDriver)
		self.actionBuf = [[0., 0.]]
		# the altitude of all the objects including the robot, hard coded value. Every episode, the objects
		# and the robot are reset to a position with this altitude
		self.entityZ = 0.04
		# the RL planner's decision changes these values
		self.desiredPose = [0, 0, 0]  # [x, y, theta] target used by the 'setPoseInc' control method
		self.desiredTransVel = 0.0
		self.desiredRotVel = 0.0
		self.desiredRotPos = 0.0

		self.anglePassed = 0.  # yaw change since the last reset, wrapped once into [-pi, pi] by calc_state()
		self.initialPose = [0, 0, 0]  # [x, y, yaw] at the last reset

		# ray test
		self.rayIDs = []  # debug line ids, one per ray, created by the first rendered ray_test()
		self.rayAngles = np.deg2rad(np.linspace(-config.robotFov / 2., config.robotFov / 2., config.numRays))
		self.rayTest = None  # the results of the last ray test

	def get_pose(self):
		"""Get current robot position and orientation in quaternion [x,y,z,w]
		"""
		return self.robot_body.get_pose()

	def get_position(self):
		"""Get current robot position
		"""
		return self.robot_body.get_position()

	def get_orientation(self):
		"""Get current robot orientation in quaternion [x,y,z,w]
		"""
		return self.robot_body.get_orientation()

	def set_position(self, pos):
		"""
		Set the position of the robot body according to 'pos'. The orientation will not be changed
		:param pos: the desired position of the robot body
		"""
		self.robot_body.reset_position(pos)

	def set_orientation(self, orn):
		"""
		Set the orientation of the robot body according to 'orn'. The position will not be changed
		:param orn: the desired orientation of the robot body in quaternion [x,y,z,w]
		"""
		self.robot_body.set_orientation(orn)

	def set_pose(self, pos, orn):
		"""
		Set the position and the orientation of the robot body.
		:param pos: the desired position of the robot body
		:param orn: the desired orientation of the robot body in quaternion [x,y,z,w]
		"""
		self.robot_body.set_pose(pos, orn)

	def _get_scaled_position(self):
		'''Private method, please don't use this method outside
		Used for downscaling MJCF models
		'''
		return self.robot_body.get_position() / self.scale

	def __del__(self):
		pass

	def apply_action(self, action, controlMethod, **kwargs):
		"""
		Move the robot according to 'action' interpreted by 'controlMethod'.

		:param action: depends on controlMethod:
			- 'pointFollower': (x, y, theta) goal; the robot is driven with wheel commands (stepping the
			  simulation) until it is within 0.05 of (x, y), then rotated until its yaw is within 0.1 rad
			  of theta.
			- 'setPose': (x, y, theta) absolute pose; the robot is placed there directly.
			- 'setPoseInc': (dx, dy, dtheta) increment of self.desiredPose. x and y are clipped to
			  [-config.xyMax, config.xyMax]; if the new pose collides, only the rotation is applied.
			- 'rotPos': (dv, dtheta) increments of the desired translational velocity and of the desired
			  heading relative to the heading at reset; a P controller turns the heading error into a
			  rotational velocity.
			Angles are in radians.
		:param controlMethod: 'pointFollower', 'setPose', 'setPoseInc' or 'rotPos'. Any other value does nothing.
		:param kwargs: 'setPoseInc' requires currentPlan, objList and objInScene (see get_measurement).
		:return: the [left, right] wheel velocity command for 'rotPos', otherwise None.
		"""
		if controlMethod == 'pointFollower':
			self._point_follower(action)
		elif controlMethod == 'setPose':
			goal_x, goal_y, goal_z = action  # (x, y, theta), theta in radians
			self.set_pose([goal_x, goal_y, self.entityZ], self._p.getQuaternionFromEuler([0, 0, goal_z]))
		elif controlMethod == 'setPoseInc':
			self._set_pose_inc(action, kwargs['currentPlan'], kwargs['objList'], kwargs['objInScene'])
		elif controlMethod == 'rotPos':
			return self._rot_pos(action)
		return None

	def _point_follower(self, goal):
		"""
		'pointFollower' control method: drive the robot to a goal pose with closed-loop velocity commands.

		Phase 1 drives towards (x, y) until the distance is at most 0.05; phase 2 rotates in place until the
		yaw is within 0.1 rad of theta; finally a zero-velocity command is sent. Every iteration sends one
		motor command and steps the simulation once.
		NOTE(review): neither loop is bounded, so a goal the robot cannot reach never returns.
		:param goal: (x, y, theta), theta in radians.
		"""
		goal_x, goal_y, goal_z = goal
		body_xyz, rotation = self.get_odom()
		distance = np.sqrt((goal_x - body_xyz[0]) ** 2 + (goal_y - body_xyz[1]) ** 2)
		last_rotation = 0.

		while distance > 0.05:
			body_xyz, rotation = self.get_odom()
			x_start = body_xyz[0]
			y_start = body_xyz[1]
			path_angle = np.arctan2(goal_y - y_start, goal_x - x_start)

			# heuristics that shift path_angle and rotation by 2*pi near the +-pi boundary so the heading
			# error does not jump when either angle crosses it
			if path_angle < -np.pi / 4 or path_angle > np.pi / 4:
				if goal_y < 0 and y_start < goal_y:
					path_angle = -2 * np.pi + path_angle
				elif goal_y >= 0 and y_start > goal_y:
					path_angle = 2 * np.pi + path_angle
			if last_rotation > np.pi - 0.1 and rotation <= 0:
				rotation = 2 * np.pi + rotation
			elif last_rotation < -np.pi + 0.1 and rotation > 0:
				rotation = -2 * np.pi + rotation

			# P control on the heading error: the gain scales the whole error (path_angle - rotation),
			# not just the target angle
			angular_z = self.config.pointFollowerAngularGain * (path_angle - rotation)

			distance = np.sqrt((goal_x - x_start) ** 2 + (goal_y - y_start) ** 2)
			self.desiredTransVel = min(self.config.pointFollowerLinearGain * distance, 0.1)
			self.desiredRotVel = max(-1.5, min(angular_z, 1.5))

			last_rotation = rotation
			self.motorDriver()
			self._p.stepSimulation()

		_, rotation = self.get_odom()

		while abs(rotation - goal_z) > 0.1:
			_, rotation = self.get_odom()
			self.desiredTransVel = 0.0
			if goal_z >= 0:
				# positive rotational velocity while the yaw lies in [goal_z - pi, goal_z]
				turn_positive = goal_z - np.pi <= rotation <= goal_z
			else:
				# negative rotational velocity while the yaw lies in (goal_z, goal_z + pi]
				turn_positive = not (goal_z < rotation <= goal_z + np.pi)
			self.desiredRotVel = 0.5 if turn_positive else -0.5

			self.motorDriver()
			self._p.stepSimulation()

		# stop the robot
		self.desiredTransVel = 0.0
		self.desiredRotVel = 0.0
		self.motorDriver()
		self._p.stepSimulation()

	def _set_pose_inc(self, increment, currentPlan, objList, objInScene):
		"""
		'setPoseInc' control method: add an increment to self.desiredPose and place the robot there.

		x and y are clipped to [-config.xyMax, config.xyMax] and the yaw is wrapped once into [-pi, pi].
		If the planned pose collides with an object or a wall, the robot keeps its previous x and y and only
		the new yaw is applied.
		:param increment: (dx, dy, dtheta), dtheta in radians.
		:param currentPlan, objList, objInScene: see get_measurement.
		"""
		inc_x, inc_y, inc_z = increment
		xy_max = self.config.xyMax
		planned_x = float(np.clip(self.desiredPose[0] + inc_x, a_min=-xy_max, a_max=xy_max))
		planned_y = float(np.clip(self.desiredPose[1] + inc_y, a_min=-xy_max, a_max=xy_max))
		planned_yaw = self.desiredPose[2] + inc_z
		if planned_yaw > np.pi:
			planned_yaw -= 2 * np.pi
		elif planned_yaw < -np.pi:
			planned_yaw += 2 * np.pi

		# NOTE(review): this class does not define isCollide; it may come from BaseRobot. Otherwise fall
		# back to get_measurement, which takes the same arguments and returns the same triple.
		collision_check = getattr(self, 'isCollide', self.get_measurement)
		collide, _, _ = collision_check(
			new_obj=('robot', [planned_x, planned_y, planned_yaw]),
			currentPlan=currentPlan,
			objList=objList,
			objInScene=objInScene)

		if sum(collide) == 0:
			# collision free: accept the translation as well
			self.desiredPose[0] = planned_x
			self.desiredPose[1] = planned_y
		self.desiredPose[2] = planned_yaw
		self.set_pose([self.desiredPose[0], self.desiredPose[1], self.entityZ],
			self._p.getQuaternionFromEuler([0, 0, planned_yaw]))

	def _rot_pos(self, action):
		"""
		'rotPos' control method: update the desired translational velocity and heading, then drive the motors.

		:param action: (dv, dtheta) increments of desiredTransVel and desiredRotPos.
		:return: the [left, right] wheel velocity command from motorDriver().
		"""
		self.desiredTransVel = np.clip(self.desiredTransVel + action[0],
			self.config.robotMinTransVel, self.config.robotMaxTransVel)
		self.desiredRotPos = np.clip(self.desiredRotPos + action[1], -2 * np.pi, 2 * np.pi)

		# compare the direct heading error with the error against anglePassed shifted by 2*pi (towards the
		# sign of desiredRotPos) and use the one with the smaller magnitude
		shift = 2 * np.pi if self.desiredRotPos >= 0 else -2 * np.pi
		errors = [self.desiredRotPos - self.anglePassed, self.desiredRotPos - (self.anglePassed + shift)]
		error = errors[int(np.argmin([abs(errors[0]), abs(errors[1])]))]

		self.desiredRotVel = np.clip(self.config.rotPosPGain * error,
			-self.config.robotMaxRotVel, self.config.robotMaxRotVel)
		return self.motorDriver()

	def motorDriver(self):
		"""
		Convert the desired velocities into wheel velocities and apply them to the joints.

		The current [desiredTransVel, desiredRotVel] is appended to actionBuf and the oldest entry is
		executed, which simulates the actuator / communication delay (one call with the initial buffer).
		A unicycle model is used in the simulator; there is no need for it on the real robot.
		:return: [left, right] wheel velocity command in rad/s (assumes ordered_joints is [left, right]).
		"""
		self.actionBuf.append([self.desiredTransVel, self.desiredRotVel])
		delayed_vel = self.actionBuf.pop(0)

		# the desired wheel joint velocity in rad/s should be limited to 7.87 rad/s, since the max transition
		# velocity is 0.26 m/s and the radius of the wheel is 0.033. In the real experiment it is usually 7.7 rad/s
		right_wheel_vel = (2. * delayed_vel[0] + delayed_vel[1] * self.L) / (2. * self.R)
		left_wheel_vel = (2. * delayed_vel[0] - delayed_vel[1] * self.L) / (2. * self.R)
		right_wheel_vel = np.clip(right_wheel_vel, -7.7, 7.7)
		left_wheel_vel = np.clip(left_wheel_vel, -7.7, 7.7)

		realAction = [left_wheel_vel, right_wheel_vel]  # velocity commands for the left and the right motor

		for joint, velocity in zip(self.ordered_joints, realAction):
			joint.set_motor_velocity(velocity)

		return realAction

	def robot_specific_reset(self, robotx, roboty, robotYaw, np_random):
		"""
		Reset the joints and place the robot at (robotx, roboty, entityZ) with yaw robotYaw.

		Every joint is reset with np_random.uniform(-0.1, 0.1) and 0 (presumably position and velocity), and
		the controller state (desired pose / velocities / heading, anglePassed, initialPose) is reset.
		"""
		for joint in self.ordered_joints:
			joint.reset_joint_state(np_random.uniform(low=-0.1, high=0.1), 0)
		self.set_pose([robotx, roboty, self.entityZ], self._p.getQuaternionFromEuler([0, 0, robotYaw]))

		self.desiredPose = [robotx, roboty, robotYaw]  # [x, y, theta]
		self.desiredTransVel = 0.0
		self.desiredRotVel = 0.0
		self.desiredRotPos = 0.
		self.anglePassed = 0.
		self.initialPose = [robotx, roboty, robotYaw]

	def calc_state(self, objPoseList, objList, objInScene):
		"""
		Read the robot pose, update anglePassed and build the state dictionary.

		:param objPoseList: [[x, y, yaw], ...] object poses, passed to get_measurement as currentPlan.
		:param objList, objInScene: see get_measurement.
		:return: dict with 'xyz' (body position), 'yaw', 'collide', 'dist', 'ang' (from get_measurement)
			and 'motor' ([desiredTransVel, desiredRotPos]).
		"""
		body_xyz, yaw = self.get_odom()

		# yaw change since the last reset, wrapped once into [-pi, pi]
		self.anglePassed = yaw - self.initialPose[2]
		if self.anglePassed > np.pi:
			self.anglePassed = self.anglePassed - 2 * np.pi
		elif self.anglePassed < -np.pi:
			self.anglePassed = self.anglePassed + 2 * np.pi

		# check if the robot is colliding with any object or wall and the distance between them
		collide, disList, angList = self.get_measurement(
			new_obj=('robot', [body_xyz[0], body_xyz[1], yaw]),
			currentPlan=objPoseList,
			objList=objList,
			objInScene=objInScene)

		return {'xyz': body_xyz, 'yaw': yaw, 'collide': collide,
			'dist': disList, 'ang': angList, 'motor': [self.desiredTransVel, self.desiredRotPos]}

	def get_image(self):
		"""
		Render and preprocess an image from the robot's on-board camera.

		The camera is placed config.robotCamOffset in front of the eyes link along the current yaw and looks
		horizontally in the yaw direction (cameraFrame = worldFrame * Rz(90) * Ry(0) * Rx(-90), rpy = (90, 0, -90)).
		:return: the image rendered at config.robotCamRenderSize (height, width, channels), converted to
			grayscale when channels == 1, cropped to columns 12:87 and resized to
			(config.img_dim[1], config.img_dim[2]) as (height, width).
		"""
		eye_pos = self.eyes.get_position()
		_, yaw = self.get_odom()

		cam_offset = np.array([
			self.config.robotCamOffset * np.cos(yaw),
			self.config.robotCamOffset * np.sin(yaw),
			0])
		view_matrix = self._p.computeViewMatrix(
			cameraEyePosition=eye_pos + cam_offset,
			cameraTargetPosition=eye_pos + 2 * cam_offset,
			cameraUpVector=np.array([0, 0, 1]))

		# The TurtleBot camera resolution is 640*480, so the aspect ratio is kept here. Images received from
		# the real camera need to be downsampled first.
		proj_matrix = self._p.computeProjectionMatrixFOV(
			fov=self.config.robotFov, aspect=4. / 3.,  # presumably the real camera FOV is 62.2 degrees
			nearVal=0.01, farVal=100)

		(w, h, px, _, _) = self._p.getCameraImage(
			height=self.config.robotCamRenderSize[0],
			width=self.config.robotCamRenderSize[1], viewMatrix=view_matrix,
			projectionMatrix=proj_matrix, shadow=0, renderer=self._p.ER_TINY_RENDERER,
			flags=self._p.ER_NO_SEGMENTATION_MASK
		)
		# Depending on how pybullet was built, px is either an (h, w, 4) NumPy array or a flat RGBA sequence.
		# Reshaping to an (h, w, 4) uint8 array handles both and gives cv2 a dtype it accepts.
		rgba = np.reshape(np.asarray(px, dtype=np.uint8), (h, w, 4))
		img = rgba[:, :, :3]  # (height, width, channel)

		if self.config.robotCamRenderSize[2] == 1:  # grayscale image requested
			img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
			img = np.reshape(img, [h, w, 1])

		# crop and resize
		# NOTE(review): the crop columns are hard coded for one render width, and cv2.resize presumably drops
		# the trailing channel axis of a single-channel image, so a grayscale result may be 2-D — confirm.
		img = img[:, 12:87, :]
		img = cv2.resize(img, (self.config.img_dim[2], self.config.img_dim[1]))  # cv2.resize takes (width, height)

		return img

	def ray_test(self, objUidList):
		"""
		Cast a horizontal fan of rays from the camera position and record what each ray hits.

		config.numRays rays of length config.rayLen are spread over config.robotFov degrees around the current
		yaw. When config.render is set, the rays are also drawn as debug lines. A ray that hits a body in
		objUidList is drawn up to the hit point in config.rayHitColor; every other ray is drawn at full length
		in config.rayMissColor.
		:param objUidList: body unique ids that count as hits when drawing.
		:return: numpy array with the hit object unique id of every ray, as reported by rayTestBatch
			(presumably -1 when nothing is hit).
		"""
		eye_pos = self.eyes.get_position()
		_, yaw = self.get_odom()
		num_rays = self.config.numRays

		cam_offset = np.array([
			self.config.robotCamOffset * np.cos(yaw),
			self.config.robotCamOffset * np.sin(yaw),
			0])
		start_point = eye_pos + cam_offset

		ray_yaws = yaw + self.rayAngles
		ray_offsets = np.stack([
			self.config.rayLen * np.cos(ray_yaws),
			self.config.rayLen * np.sin(ray_yaws),
			np.zeros(num_rays)], axis=1)  # (num_rays, 3)
		ray_from = [start_point] * num_rays
		# one flat 3-vector per ray (instead of (3, 1) column slices), so pybullet never has to coerce
		# size-1 arrays into floats
		ray_to = list(start_point + ray_offsets)

		# use the robot's own physics client, like the rest of the class, not pybullet's default client
		results = self._p.rayTestBatch(ray_from, ray_to)
		self.rayTest = [result[0] for result in results]

		if self.config.render:
			for i, hit_uid in enumerate(self.rayTest):
				if len(self.rayIDs) < num_rays:  # create the debug lines that do not exist yet
					self.rayIDs.append(self._p.addUserDebugLine(ray_from[i], ray_to[i], self.config.rayMissColor))

				if hit_uid in objUidList:
					hit_position = results[i][3]
					self._p.addUserDebugLine(ray_from[i], hit_position, self.config.rayHitColor,
						replaceItemUniqueId=self.rayIDs[i])
				else:
					self._p.addUserDebugLine(ray_from[i], ray_to[i], self.config.rayMissColor,
						replaceItemUniqueId=self.rayIDs[i])
		return np.array(self.rayTest)

	def get_measurement(self, new_obj, currentPlan, objList, objInScene):
		"""
		Get the collision, distance and angle information between a robot pose and the objects / walls
		(used by calc_state()).

		:param new_obj: tuple ('robot', [robotx, roboty, robotyaw]); only the pose is used.
		:param currentPlan: [[x, y, yaw], ...] location of every object (e.g. obtained from randomization()),
			in the same order as objList.
		:param objList: per-object keys into config.objectsRadius and config.objectsExpandDistance.
		:param objInScene: indices of the objects that are in the scene; the other objects get 0 entries.
		:return: (collideList, disList, angList)
			collideList: a 0/1 flag per object followed by one flag for the walls.
			disList: per object, the center distance minus the summed radii and expand distances (negative
				when colliding), followed by four wall distances [disX, disY, 1.9 - disX - 2r, 1.9 - disY - 2r]
				with disX = abs(0.95 - x - r), disY = abs(0.95 - y - r), r = robotRadius + robotExpandDistance.
			angList: per object, the bearing relative to the robot yaw, wrapped once into [-pi, pi].
		"""
		_, new_pose = new_obj
		# robot radius plus its safety margin
		tot_dist = self.config.robotExpandDistance + self.config.robotRadius
		# the order of these lists follows objList; collideList additionally ends with the wall flag
		collideList = []  # 0/1 flags: is the robot colliding with each object / the walls
		disList = []  # distances between the robot and the objects or walls
		angList = []  # bearings of the objects relative to the robot yaw

		for i, value in enumerate(currentPlan):
			if i not in objInScene:
				# objects that are not in the scene just get 0
				disList.append(0)
				collideList.append(0)
				angList.append(0)
				continue

			obj = objList[i]
			distance = np.linalg.norm([new_pose[0] - value[0], new_pose[1] - value[1]])
			threshold = tot_dist + self.config.objectsRadius[obj] + self.config.objectsExpandDistance[obj]
			collideList.append(int(distance <= threshold))

			# distance - threshold is around 0.19~0.34 m; it is the distance from the robot camera to the
			# object surface. It is negative when colliding.
			disList.append(distance - threshold)

			# differences below 1e-4 are treated as 0; arctan2 keeps the angle within [-pi, pi]
			y_diff = (value[1] - new_pose[1]) if abs(value[1] - new_pose[1]) > 0.0001 else 0
			x_diff = (value[0] - new_pose[0]) if abs(value[0] - new_pose[0]) > 0.0001 else 0
			angle2target = np.arctan2(y_diff, x_diff) - new_pose[2]
			if angle2target > np.pi:  # so that the range is from -pi to pi
				angle2target -= 2 * np.pi
			elif angle2target < -np.pi:
				angle2target += 2 * np.pi
			angList.append(angle2target)

		# check against the walls at x = +-0.95 and y = +-0.95
		x, y = new_pose[0], new_pose[1]
		hits_wall = (x - tot_dist <= -0.95 or x + tot_dist >= 0.95 or
			y - tot_dist <= -0.95 or y + tot_dist >= 0.95)
		collideList.append(int(hits_wall))

		disX = abs(0.95 - x - tot_dist)
		disY = abs(0.95 - y - tot_dist)
		disList.extend([disX, disY, 1.9 - disX - 2 * tot_dist, 1.9 - disY - 2 * tot_dist])
		return collideList, disList, angList

	def get_odom(self):
		"""
		:return: (body_xyz, yaw) of the robot body.
		"""
		body_pose = self.robot_body.pose()
		_, _, yaw = body_pose.rpy()
		return body_pose.xyz(), yaw
