from ..scene_abstract import SingleRobotEmptyScene

from ..env_bases import BaseEnv
from Envs.pybullet.kuka.robot_manipulators import *
import cv2
import sounddevice as sd
from cfg import main_config
from Envs.audioLoader import audioLoader
import pandas as pd
import os


class RLEnvRSI2(BaseEnv):
	"""Kuka tabletop environment in which the goal object is specified by a sound.

	At the first step of each episode a goal object is drawn and its MFCC feature
	is returned as ``goal_sound`` in every observation; ``current_sound`` is the
	feature of the object currently hit by the robot's ray test (zeros when
	nothing is hit). The step reward is always 0 (see ``rewards``). At test time
	(``config.RLTrain`` false) the number of steps spent over the goal object
	(``approach``) or over a single non-goal object (``avoid``) is reported in
	``info['goal_area_count']`` on the last step of an episode.
	"""

	# Placement tweaks used by randomization() when no object is hidden:
	# object name -> (euler orientation, x offset, y offset, z drop below config.objZ)
	_OBJ_POSE_TWEAKS = {
		'capsule': ([0, 1.57, 0], 0., 0., 0.02),
		'teddy': ([0, 3.14, 1.57], 0.05, 0., 0.01),
		'bunny': ([1.57, 1.57, 1.57], 0.02, 0.07, 0.1),
	}
	# objects not listed above: upright, anchored at the random spot, at config.objZ
	_DEFAULT_POSE_TWEAK = ([0, 0, 0], 0., 0., 0.)

	def __init__(self):
		self.config = main_config()
		# audioLoader instance; created lazily in envReset() unless already provided
		# (on Linux the parent process loads it, see shmem_vec_env.py)
		self.audio = None
		self.robot = Kuka(config=self.config)
		self.robotID = None

		d = {
			'image': gym.spaces.Box(low=0, high=255, shape=self.config.img_dim, dtype='uint8'),
			'goal_sound': gym.spaces.Box(low=-np.inf, high=np.inf, shape=self.config.sound_dim, dtype=np.float32),
			'current_sound': gym.spaces.Box(low=-np.inf, high=np.inf, shape=self.config.sound_dim, dtype=np.float32),
			'robot_pose': gym.spaces.Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.float32),
			'ground_truth': gym.spaces.Box(low=0, high=4, shape=(1,), dtype=np.int32)
		}
		self.observation_space = gym.spaces.Dict(d)
		self.maxSteps = self.config.RLEnvMaxSteps

		# continuous action in [-1, 1]^RLActionDim; step() uses the first two entries
		high = np.ones(self.config.RLActionDim)
		self.action_space = gym.spaces.Box(-high, high, dtype=np.float32)

		BaseEnv.__init__(self, self.robot,
						 config=self.config,
						 render=self.config.render,
						 action_space=self.action_space,
						 observation_space=self.observation_space)

		# models
		self.objList = self.config.objList
		self.config.taskNum = len(self.objList)
		self.objUidList = []  # PyBullet Uid for each object, same order as self.objList

		# each spot associates with a list [x,y,yaw]. The order will be same with self.objList
		# [[x,y,yaw],[x,y,yaw],...]
		self.objPoseList = []  # objects position and orientation information
		self.tableUid = None

		self.workSpaceDebugLine = []  # debug-line ids outlining the workspace (render only)
		self.eePositionDebugLine = None

		self.hideObjIdx = []  # indices of hidden objects, drawn in 'random' hide mode

		self.goalObjIdx = None
		self.goal_sound = None
		self.ground_truth = None  # the ground truth label for the goal sound
		self.goal_audio = None
		self.rayTestResult = None

		# per-class totals of config.soundSource['size'] (presumably the number of
		# sound samples per object class -- TODO confirm) and their running sum,
		# used for the deterministic test-time goal schedule
		self.size_per_class = np.zeros((len(self.config.objList),), dtype=np.int64)
		for key in self.config.soundSource['size']:
			self.size_per_class = self.size_per_class + self.config.soundSource['size'][key]
		self.size_per_class_cumsum = np.cumsum(self.size_per_class)

		self.goal_area_count = 0  # test-time count of steps spent in the goal area

		if self.config.record:
			self.episodeRecorder = Recorder()  # used for recording data
			self.episodeRecorder.saveTo = self.config.recordSaveDir
			if self.config.loadAction:
				# set the source file BEFORE loading so loadActions() knows what to read
				# NOTE(review): Recorder in this module defines no loadActions(); confirm where it comes from
				self.episodeRecorder.loadFrom = self.config.loadActionFile
				self.episodeRecorder.loadActions()

	def saveEpisodeImage(self, image):
		"""Save ``image`` as a JPEG every ``config.episodeImgSaveInterval`` episodes.

		The image is resized to ``config.episodeImgSize`` (height, width, channels);
		3-channel images are converted from RGB to BGR for OpenCV before being
		written to ``config.episodeImgSaveDir`` as ``<givenSeed>out<envStepCounter>.jpg``.
		Nothing is saved when the interval is <= 0.
		"""
		interval = self.config.episodeImgSaveInterval
		if interval <= 0 or self.episodeCounter % interval != 0:
			return
		imgSize = self.config.episodeImgSize
		# cv2.resize takes (width, height)
		imgSave = cv2.resize(image, (imgSize[1], imgSize[0]))
		if imgSize[2] == 3:
			imgSave = cv2.cvtColor(imgSave, cv2.COLOR_RGB2BGR)
		fileName = str(self.givenSeed) + 'out' + str(self.envStepCounter) + '.jpg'
		cv2.imwrite(os.path.join(self.config.episodeImgSaveDir, fileName), imgSave)

	def create_single_player_scene(self, bullet_client):
		"""
		Setup physics engine and simulation
		:param bullet_client: PyBullet client the scene is created on
		:return: a SingleRobotEmptyScene with gravity (0, 0, -9.8), this env's
			time step, config.frameSkip and config.render
		"""
		return SingleRobotEmptyScene(bullet_client, gravity=(0, 0, -9.8),
									 timestep=self.timeStep, frame_skip=self.config.frameSkip,
									 render=self.config.render)

	def randomization(self):
		"""Place the objects in a row at a random spot and move the end effector randomly.

		Objects are laid out along +y from a random (x, y) anchor, spaced by
		``config.objInterval``. In 'random'/'fix' hide modes, hidden objects are
		parked at x = 1.0; in 'none' mode per-object pose tweaks are applied.
		The simulation is stepped once at the end so the ray test sees the new poses.

		:raises NotImplementedError: for an unknown ``config.hideObj['mode']``
		"""
		mode = self.config.hideObj['mode']
		randomx = self.np_random.uniform(self.config.xMin + 0.05, self.config.xMax - 0.05)
		# the upper y bound is pulled in further, presumably to leave room for the
		# objects stacked along +y by objInterval -- TODO confirm
		randomy = self.np_random.uniform(self.config.yMin + 0.05, self.config.yMax - 0.45)

		if mode in ('random', 'fix'):
			if mode == 'random':
				self.hideObjIdx = self.np_random.choice(self.config.taskNum, size=self.config.hideObj['hideNum'], replace=False)
				hidden = self.hideObjIdx
			else:
				hidden = self.config.hideObj['hideIdx']
			orn = self._p.getQuaternionFromEuler([0, 0, 0])  # do not randomize the orientation
			for i, uid in enumerate(self.objUidList):
				x = 1. if i in hidden else randomx  # hidden objects are moved out to x = 1.0
				self._p.resetBasePositionAndOrientation(uid,
														[x, randomy + i * self.config.objInterval, self.config.objZ],
														orn)

		elif mode == 'none':
			for i, uid in enumerate(self.objUidList):
				euler, dx, dy, zDrop = self._OBJ_POSE_TWEAKS.get(self.config.objList[i], self._DEFAULT_POSE_TWEAK)
				orn = self._p.getQuaternionFromEuler(euler)
				x = randomx + dx
				y = randomy + dy
				self._p.resetBasePositionAndOrientation(uid,
														[x, y + i * self.config.objInterval, self.config.objZ - zDrop],
														orn)

		else:
			raise NotImplementedError

		eePositionX = self.np_random.uniform(self.config.xMin + 0.05, self.config.xMax - 0.05)
		eePositionY = self.np_random.uniform(self.config.yMin + 0.05, self.config.yMax - 0.05)
		self.robot.robot_specific_reset(eePositionX, eePositionY, self.config.endEffectorHeight)
		self._p.stepSimulation()  # refresh the simulator. Needed for the ray test

	def _contact_rays(self):
		"""Return one bool per entry of objUidList: True where the robot's ray test hits that object."""
		hitUid = self.robot.ray_test(self.objUidList)
		return [bool(hitUid == uid) for uid in self.objUidList]

	def get_positive_negative(self, get_negative=True, generate_audio=True):
		"""Return the sound the agent currently hears and, optionally, a mismatched one.

		:param get_negative: also sample a "negative" sound feature
		:param generate_audio: compute the positive feature (when an object is hit
			it is also computed whenever config.render is set)
		:return: (sound_positive, sound_negative, ground_truth, positive_audio)
			- sound_positive: MFCC feature of the hit object, zeros when nothing is
			  hit, or None when it was not generated
			- sound_negative: feature of a randomly drawn object; zeros when the draw
			  equals the hit object; None when get_negative is false
			- ground_truth: np.int32 index of the hit object, or taskNum when nothing is hit
			- positive_audio: audio returned by genSoundFeat for the hit object, else None
		"""
		contactRays = self._contact_rays()
		positive_audio = None
		sound_negative = None
		sound_positive = None
		if not any(contactRays):  # the end effector hits nothing, no sound is given
			if generate_audio:
				sound_positive = np.zeros(shape=self.config.sound_dim)
			ground_truth = np.int32(self.config.taskNum)  # extra "nothing" class
			if get_negative:
				objIndx = self.np_random.randint(0, self.config.taskNum)
				sound_negative, _ = self.audio.genSoundFeat(objIndx=objIndx, featType='MFCC',
															rand_fn=self.np_random.randint)

		else:  # the agent sees an object in self.config.objList
			# ray_test yields a single hit id, so at most one entry is True
			objIndx_positive = np.argmax(contactRays)
			ground_truth = np.int32(objIndx_positive)

			if generate_audio or self.config.render:
				sound_positive, positive_audio = self.audio.genSoundFeat(objIndx=objIndx_positive, featType='MFCC',
																		 rand_fn=self.np_random.randint)
			if get_negative:
				objIndx_negative = self.np_random.randint(0, self.config.taskNum)
				if objIndx_positive == objIndx_negative:
					sound_negative = np.zeros(shape=self.config.sound_dim)
				else:
					sound_negative, _ = self.audio.genSoundFeat(objIndx=objIndx_negative, featType='MFCC',
																rand_fn=self.np_random.randint)

		return sound_positive, sound_negative, ground_truth, positive_audio

	def envReset(self):
		"""Reset the environment and return the first observation dict.

		On the first call (``self.robot.robot_ids`` is None) this loads the audio
		data if needed, the Kuka robot, the table and every object in
		``config.objList``, outlines the workspace when rendering, and places
		everything via ``randomization()``. Afterwards, when ``config.ifReset`` is
		set and ``episodeCounter > 0``, the scene is re-randomized.
		"""
		if self.robot.robot_ids is None:
			# load sound
			# if os is Linux, we already load the data in the parent process. See shmem_vec_env.py
			if self.audio is None:
				self.audio = audioLoader(config=self.config)
			# load robot
			objects = self._p.loadSDF(os.path.join(pybullet_data.getDataPath(), "kuka_iiwa/kuka_with_gripper2.sdf"))
			self.robot.robot_ids = objects[0]
			self.robotID = self.robot.robot_ids

			# anchor the robot
			self._p.resetBasePositionAndOrientation(self.robotID, [-0.100000, 0.000000, 0.070000],
													[0.000000, 0.000000, 0.000000, 1.000000])

			self.robot.numJoints = self._p.getNumJoints(self.robotID)
			# Kuka iiwa 7 DOF + 5 DOF tool
			# jointIndex  jointName         jointType
			# 0           J1 (iiwa)         R
			# 1           J2 (iiwa)         R
			# 2           J3 (iiwa)         R
			# 3           J4 (iiwa)         R
			# 4           J5 (iiwa)         R
			# 5           J6 (iiwa)         R
			# 6           J7 (iiwa)         R  EE position
			# 7           gripper base      R  EE orientation
			# 8           left finger       R
			# 10          left finger tip   R
			# 11          right finger      R
			# 13          right finger tip  R
			#
			# When the gripper is perpendicular to the table:
			# gripper base and left/right finger tips joint has distance in z axis of 0.21m
			# finger tip joints and table top has 0.025 m distance in z axis
			# table top has z coordinate of -0.195, gripper base has z coordinate of 0.04m,
			# end effector has z coordinate of 0.065m
			# The table is -0.2~1.2m in x, and -0.5~0.5 in y

			# load table and obj
			self.tableUid = self._p.loadURDF(os.path.join(pybullet_data.getDataPath(), "table/table.urdf"),
											 [0.5, 0.0, self.config.tableZ],
											 [0.0, 0.0, 0.0, 1.0])

			for objName in self.config.objList[:self.config.taskNum]:
				objPath = os.path.join(self.config.mediaPath, 'objects', objName + '.urdf')
				self.objUidList.append(self._p.loadURDF(objPath))

			if self.config.render:
				# outline the workspace rectangle at object height; draw on this env's
				# client (self._p) so the lines appear in the same server as the objects
				xMin, xMax = self.config.xMin, self.config.xMax
				yMin, yMax = self.config.yMin, self.config.yMax
				z = self.config.objZ
				edges = [
					([xMin, yMin, z], [xMax, yMin, z]),
					([xMin, yMax, z], [xMax, yMax, z]),
					([xMax, yMin, z], [xMax, yMax, z]),
					([xMin, yMin, z], [xMin, yMax, z]),
				]
				for start, end in edges:
					self.workSpaceDebugLine.append(
						self._p.addUserDebugLine(start, end, self.config.rayMissColor, lineWidth=5))

			self.robot._p = self._p
			self.randomization()

		if self.config.ifReset and self.episodeCounter > 0:
			self.randomization()

		return self.gen_obs()[0]

	def _choose_goal_obj_idx(self):
		"""Pick the goal object index for a new episode according to ``config.hideObj['mode']``.

		Hidden objects are never chosen. With no hidden objects the goal is uniform
		random during training or rendering; otherwise it follows the episode
		counter so class c is the goal for ``size_per_class[c]`` consecutive
		counter values.

		:raises NotImplementedError: for an unknown hide mode
		"""
		mode = self.config.hideObj['mode']
		if mode == 'random':
			prob = np.ones((self.config.taskNum,)) / (self.config.taskNum - self.config.hideObj['hideNum'])
			prob[self.hideObjIdx] = 0.
			return self.np_random.choice(self.config.taskNum, replace=False, p=prob)

		if mode == 'fix':
			prob = np.ones((self.config.taskNum,)) / (self.config.taskNum - len(self.config.hideObj['hideIdx']))
			prob[self.config.hideObj['hideIdx']] = 0.
			return self.np_random.choice(self.config.taskNum, replace=False, p=prob)

		if mode == 'none':  # all objects are present
			if self.config.RLTrain or self.config.render:
				return self.np_random.randint(0, self.config.taskNum)
			# deterministic test schedule: the goal is the number of classes whose
			# cumulative size has already been covered by episodeCounter
			# NOTE(review): once episodeCounter >= size_per_class_cumsum[-1] this yields taskNum (out of range)
			idx = np.where(self.size_per_class_cumsum <= self.episodeCounter)[0]
			return 0 if len(idx) == 0 else int(idx.max() + 1)

		raise NotImplementedError

	def gen_obs(self):
		"""Build the observation for the current step.

		On the first step of an episode (``envStepCounter == 0``) the goal object
		and its sound are also drawn; when rendering or testing the goal index is
		printed, and its audio is played when rendering.

		:return: (obs, s, sound_positive, sound_negative, sound_positive_ground_truth)
			where ``s`` is ``self.robot.calc_state()`` and the sound entries come
			from ``get_positive_negative``.
		"""
		image = self.robot.get_image()
		self.saveEpisodeImage(image)
		s = self.robot.calc_state()

		# at training time, generate triplets from time to time
		get_negative = self.config.RLTrain and self.np_random.rand() > self.config.pretextModelUpdatePairProb

		# sound_positive: the current sound heard by the agent
		# sound_negative: the sound that is not the current heard sound
		# sound_positive_ground_truth: the ground truth label for the sound heard by the agent
		sound_positive, sound_negative, sound_positive_ground_truth, _ = self.get_positive_negative(get_negative)

		if self.envStepCounter == 0:
			self.goalObjIdx = self._choose_goal_obj_idx()
			self.goal_sound, self.goal_audio = self.audio.genSoundFeat(objIndx=self.goalObjIdx, featType='MFCC',
																	   rand_fn=self.np_random.randint)
			self.ground_truth = np.int32(self.goalObjIdx)
			if self.config.render or not self.config.RLTrain:
				if self.goal_audio is not None and self.config.render:
					sd.play(self.goal_audio, self.audio.fs)
				print('Goal object is', self.goalObjIdx)

		# Observations are dictionaries containing:
		# - an image (partially observable view of the environment), channels first
		# - the goal sound feature and the currently heard sound feature
		# - the end-effector x/y position and the goal label
		obs = {
			'image': np.transpose(image, (2, 0, 1)),
			'goal_sound': self.goal_sound,
			'current_sound': sound_positive,
			'robot_pose': np.array([s['eeState'][0], s['eeState'][1]]),
			'ground_truth': self.ground_truth,
		}

		return obs, s, sound_positive, sound_negative, sound_positive_ground_truth

	def step(self, action):
		"""Apply one action, advance the simulation and return (obs, reward, done, info).

		:param action: array-like; entries 0 and 1 are clipped to [-1, 1] and scaled
			by 0.02 into x/y end-effector displacements. Ignored when
			``config.RLManualControl`` is set: the arrow keys then move by 0.02.
		:return: ``info`` holds the triplet (image, sound_positive, sound_negative,
			ground_truth) when a negative sound was sampled, and at test time
			``goal_area_count`` on the final step of the episode.
		"""
		action = np.array(action)
		if self.config.RLManualControl:
			dx, dy = 0.0, 0.0
			# query this env's physics client rather than the module-level default
			keys = self._p.getKeyboardEvents()
			if p.B3G_UP_ARROW in keys:
				dx = -0.02
			if p.B3G_DOWN_ARROW in keys:
				dx = 0.02
			if p.B3G_LEFT_ARROW in keys:
				dy = -0.02
			if p.B3G_RIGHT_ARROW in keys:
				dy = 0.02
		else:
			dv = 0.02
			dx = float(np.clip(action[0], -1, +1)) * dv  # move in x direction
			dy = float(np.clip(action[1], -1, +1)) * dv  # move in y direction
		act = [dx, dy, 0.]  # the end effector never moves in z

		self.robot.applyAction(act, controlMethod=self.config.RLRobotControl)
		self.scene.global_step()
		self.envStepCounter += 1

		obs, s, sound_positive, sound_negative, sound_positive_ground_truth = self.gen_obs()

		infoDict = {}

		self.reward = self.rewards()
		self.episodeReward = self.episodeReward + self.reward
		self.done = self.termination(s)

		if sound_negative is not None:
			infoDict['image'] = obs['image']
			infoDict['sound_positive'] = sound_positive.astype(np.float32)
			infoDict['sound_negative'] = sound_negative.astype(np.float32)
			infoDict['ground_truth'] = np.array([sound_positive_ground_truth], dtype=np.int32)

		# at test time, perform rayTest and put performance into the infoDict
		if not self.config.RLTrain:
			contactRays = self._contact_rays()
			if self.config.RLTask == 'approach':
				# count steps spent over the goal object
				if contactRays[self.goalObjIdx]:
					self.goal_area_count = self.goal_area_count + 1
			elif self.config.RLTask == 'avoid':
				# count steps spent over exactly one object that is not the goal
				if sum(contactRays) == 1 and not contactRays[self.goalObjIdx]:
					self.goal_area_count = self.goal_area_count + 1
			else:
				raise NotImplementedError
			if self.done:
				infoDict['goal_area_count'] = self.goal_area_count
				print('goal area count', self.goal_area_count)
				self.goal_area_count = 0

		return obs, self.reward, self.done, infoDict  # reset will be called if done

	def termination(self, s):
		"""Return True once the episode has run for ``maxSteps`` steps (``s`` is unused)."""
		return self.envStepCounter >= self.maxSteps

	def rewards(self, *args):
		"""Return the step reward, which is always 0.0 in this environment.

		NOTE(review): an earlier note described a shaped reward (l2 distance from the
		gripper base to the object plus a Lorentzian rho-function); it is not implemented.
		"""
		return 0.


class Recorder(object):
	"""Accumulate per-episode data and write it to CSV files.

	The list attributes are filled by the caller. ``saveEpisode`` writes them to
	``<saveTo>/ep<episodeInitNum + episodeCounter>/`` and then clears them.
	"""
	def __init__(self):
		self.clear()
		self.episodeInitNum = 0  # offset added to the episode counter in directory names
		self.saveTo = None  # root directory for saved episodes; must be set before saveEpisode()

	def saveEpisode(self, episodeCounter):
		"""Write the buffered data for one episode to CSV, then clear the buffers.

		Writes ``position.csv`` (columns x, y, taken from the first two entries of
		each ``positionList`` item) and ``results.csv`` (target block, reward,
		num timesteps in area, final_dist). An empty episode produces header-only files.

		:param episodeCounter: episode number; the output folder is
			``ep<episodeInitNum + episodeCounter>`` under ``saveTo`` (created if missing)
		"""
		savePath = os.path.join(self.saveTo, 'ep' + str(self.episodeInitNum + episodeCounter))
		os.makedirs(savePath, exist_ok=True)  # no check-then-create race

		positions = np.asarray(self.positionList)
		if positions.size == 0:
			# an empty list gives shape (0,), which cannot be column-sliced
			positions = np.empty((0, 2))
		position = pd.DataFrame({'x': positions[:, 0], 'y': positions[:, 1]})

		# episode info
		reward = pd.DataFrame({'target block': np.array(self.choice), 'reward': np.array(self.reward),
							   'num timesteps in area': np.array(self.in_area), 'final_dist': self.final_dist})

		position.to_csv(os.path.join(savePath, 'position.csv'), index=False)
		reward.to_csv(os.path.join(savePath, 'results.csv'), index=False)

		print("csv written")
		self.clear()

	def clear(self):
		"""Empty all per-episode buffers."""
		self.positionList = []  # [x, y] entries (noted as recorded in calc_state())
		self.reward = []  # records Episode Reward
		self.in_area = []  # number of time steps in the vicinity of the target block
		self.choice = []  # target block: 0, 1, 2, 3
		self.final_dist = []
