from ..scene_abstract import SingleRobotEmptyScene
from ..env_bases import BaseEnv
from ..robot_manipulators import *
import cv2
import sounddevice as sd
from cfg import main_config
from Envs.audioLoader import audioLoader
import pandas as pd


class RLEnvRSI2(BaseEnv):
	"""
	Sound-conditioned reaching environment for a Kinova Gen3 arm simulated in PyBullet.

	A row of objects (one per entry of config.objList) is placed on a table. At the first
	step of each episode a goal object is chosen and an MFCC feature of its sound becomes
	``goal_sound``. Whenever the end effector's ray test hits an object, ``current_sound``
	is an MFCC feature of that object's sound; otherwise it is all zeros. ``rewards``
	always returns 0. At test time the number of steps spent over the goal object (task
	'approach') or over a single non-goal object (task 'avoid') is reported in the info dict.
	"""

	def __init__(self):

		self.config = main_config()
		self.audio = audioLoader(config=self.config)
		self.robot = KinovaGen3(config=self.config)
		self.robotID = None

		d = {
			'image': gym.spaces.Box(low=0, high=255, shape=self.config.img_dim, dtype='uint8'),
			'goal_sound': gym.spaces.Box(low=-np.inf, high=np.inf, shape=self.config.sound_dim, dtype=np.float32),
			'current_sound': gym.spaces.Box(low=-np.inf, high=np.inf, shape=self.config.sound_dim, dtype=np.float32),
			'robot_pose': gym.spaces.Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.float32),
			'ground_truth': gym.spaces.Box(low=0, high=4, shape=(1,), dtype=np.int32)
		}
		self.observation_space = gym.spaces.Dict(d)
		self.maxSteps = self.config.RLEnvMaxSteps

		# setup action space: every action component lies in [-1, 1]
		high = np.ones(self.config.RLActionDim)
		self.action_space = gym.spaces.Box(-high, high, dtype=np.float32)

		BaseEnv.__init__(self, self.robot,
						 config=self.config,
						 render=self.config.render,
						 action_space=self.action_space,
						 observation_space=self.observation_space)

		# models
		self.objList = self.config.objList
		self.config.taskNum = len(self.objList)
		self.objUidList = []  # PyBullet Uid for each object

		# each spot associates with a list [x,y,yaw]. The order will be same with self.objList
		# [[x,y,yaw],[x,y,yaw],...]
		self.objPoseList = []  # objects position and orientation information
		self.tableUid = None
		self.baseUid = None
		self.wallID = None

		self.texPath = os.path.join(self.config.commonMediaPath, 'texture')
		self.textureList = []  # PyBullet texture IDs filled by loadTex()

		self.workSpaceDebugLine = []
		self.eePositionDebugLine = None

		self.hideObjIdx = []

		self.goalObjIdx = None
		self.goal_sound = None
		self.ground_truth = None  # the ground truth label for the goal sound
		self.goal_audio = None
		self.rayTestResult = None
		self.externalCamEyePosition = None
		self.externalCamTargetPosition = None

		# per-class totals summed over every entry of config.soundSource['size']
		# (presumably per-class sample counts); used to walk through classes in order at test time
		self.size_per_class = np.zeros((len(self.config.objList),), dtype=np.int64)
		for key in self.config.soundSource['size']:
			self.size_per_class = self.size_per_class + self.config.soundSource['size'][key]
		self.size_per_class_cumsum = np.cumsum(self.size_per_class)

		self.goal_area_count = 0

		if self.config.record:
			self.episodeRecorder = Recorder()  # used for recording data
			self.episodeRecorder.saveTo = self.config.recordSaveDir

	def create_single_player_scene(self, bullet_client):
		"""
		Setup physics engine and simulation
		:param bullet_client: PyBullet client the scene is created in
		:return: a SingleRobotEmptyScene with gravity along -z
		"""

		return SingleRobotEmptyScene(bullet_client, gravity=(0, 0, -9.8),
									 timestep=self.timeStep, frame_skip=self.config.frameSkip,
									 render=self.config.render)

	def loadTex(self):
		"""
		Load up to config.numTexture randomly chosen texture files from self.texPath and
		append their PyBullet texture IDs to self.textureList.
		"""
		# Sort first: os.listdir order is filesystem-dependent, which would make the seeded
		# shuffle below pick different textures on different machines.
		texList = sorted(os.listdir(self.texPath))
		idx = np.arange(len(texList))
		self.np_random.shuffle(idx)

		# Slicing never asks for more textures than exist on disk (previously an IndexError).
		for i in idx[:self.config.numTexture]:
			texID = self._p.loadTexture(os.path.join(self.texPath, texList[i]))
			self.textureList.append(texID)

	def rand_rgb(self, val):
		"""
		Sample a random light color.

		:param val: lower bound for every channel, expected in [0, 1]
		:return: numpy array [r, g, b] with each channel drawn uniformly from [val, 1).
				With val=2/3 (as used by changeWallTexture) this guarantees r + g + b >= 2.
		"""
		rgb = self.np_random.rand(3)  # 3 numbers in [0, 1)
		return (1 - val) * rgb + val

	def changeWallTexture(self, wallID):
		"""Apply a random loaded texture and a random light tint to the given wall body."""
		texID = self.np_random.choice(self.textureList)
		r, g, b = self.rand_rgb(2. / 3)  # generate random rgb values
		self._p.changeVisualShape(wallID, -1, textureUniqueId=texID,
								  rgbaColor=[r, g, b, 1])

	def randomization(self):
		"""
		Randomize object placement, end-effector start position and (when
		config.domainRandomization is set) camera pose, wall texture and object texture.

		:raises NotImplementedError: for an unknown config.hideObj['mode']
		"""
		orn = self._p.getQuaternionFromEuler([0, 0, 0])  # do not randomize the orientation
		randomx = self.np_random.uniform(self.config.xMin + 0.05, self.config.xMax - 0.05)  # (0.35,0.5)
		randomy = self.np_random.uniform(self.config.yMin + 0.05, self.config.yMax - 0.45)  # (-0.3,-0.25)

		# Decide which objects are hidden. None means "no hiding" (mode 'none').
		mode = self.config.hideObj['mode']
		if mode == 'random':
			self.hideObjIdx = self.np_random.choice(self.config.taskNum, size=self.config.hideObj['hideNum'], replace=False)
			hidden = self.hideObjIdx
		elif mode == 'fix':
			hidden = self.config.hideObj['hideIdx']
		elif mode == 'none':
			hidden = None
		else:
			raise NotImplementedError

		# Objects are laid out in a row along y, spaced by config.objInterval.
		for i, uid in enumerate(self.objUidList):
			if hidden is not None:
				# hidden objects are parked at x = 1.0 (presumably outside the workspace)
				x = 1. if i in hidden else randomx
				y = randomy
			elif self.config.domainRandomization:
				x = randomx + self.np_random.uniform(-0.01, 0.01)
				y = randomy + self.np_random.uniform(-0.01, 0.01)
			else:
				x = randomx
				y = randomy
			self._p.resetBasePositionAndOrientation(uid,
													[x, y + i * self.config.objInterval, self.config.objZ],
													orn)

		eePositionX = self.np_random.uniform(self.config.xMin + 0.05, self.config.xMax - 0.05)
		eePositionY = self.np_random.uniform(self.config.yMin + 0.05, self.config.yMax - 0.05)
		self.robot.robot_specific_reset(eePositionX, eePositionY, self.config.endEffectorHeight)

		if self.config.domainRandomization:
			# change camera
			self.externalCamEyePosition = list(np.array(self.config.externalCamEyePosition) + self.np_random.uniform(-0.02, 0.02, 3))
			self.externalCamTargetPosition = list(np.array(self.config.externalCamTargetPosition) + self.np_random.uniform(-0.02, 0.02, 3))
			# change the texture of the background wall
			self.changeWallTexture(self.wallID)
			# change block texture (same texture for every object)
			texID = self.np_random.choice(self.textureList)
			for i in range(self.config.taskNum):
				self._p.changeVisualShape(self.objUidList[i], -1, textureUniqueId=texID,
										  rgbaColor=[1, 1, 1, 1])
		else:
			self.externalCamEyePosition = self.config.externalCamEyePosition
			self.externalCamTargetPosition = self.config.externalCamTargetPosition

		self._p.stepSimulation()  # refresh the simulator. Needed for the ray test

	def _contactRays(self):
		"""Return one bool per object in self.objUidList: True where the end effector's ray test hit it."""
		rayTest = self.robot.ray_test(self.objUidList)
		return [bool(rayTest == Uid) for Uid in self.objUidList]

	def get_positive_negative(self, get_negative=True):
		"""
		Generate the sound the agent currently hears and, optionally, a negative sound.

		:param get_negative: when True, also sample a sound that is not the one currently heard
		:return: (sound_positive, sound_negative, ground_truth, positive_audio)
			sound_positive: MFCC feature of the hit object's sound, or zeros if nothing is hit
			sound_negative: MFCC feature of a randomly chosen object's sound, zeros if that
				object is the one currently heard, or None when get_negative is False
			ground_truth: np.int32 index of the hit object, or config.taskNum if nothing is hit
			positive_audio: raw audio returned for the positive sound, or None if nothing is hit
		"""
		contactRays = self._contactRays()
		positive_audio = None
		sound_negative = None
		if not any(contactRays):  # the end effector hits nothing, no sound is given
			sound_positive = np.zeros(shape=self.config.sound_dim)
			ground_truth = np.int32(self.config.taskNum)  # label taskNum means "no object"
			if get_negative:
				objIndx = self.np_random.randint(0, self.config.taskNum)
				sound_negative, _ = self.audio.genSoundFeat(objIndx=objIndx, featType='MFCC',
															rand_fn=self.np_random.randint)

		else:  # the agent sees an object in self.config.objList
			objIndx_positive = np.argmax(contactRays)  # index of the (first) object matching the ray hit
			ground_truth = np.int32(objIndx_positive)
			sound_positive, positive_audio = self.audio.genSoundFeat(objIndx=objIndx_positive, featType='MFCC',
																	 rand_fn=self.np_random.randint)
			if get_negative:
				objIndx_negative = self.np_random.randint(0, self.config.taskNum)
				if objIndx_positive == objIndx_negative:
					# the sampled "negative" is the heard object itself: use silence instead
					sound_negative = np.zeros(shape=self.config.sound_dim)
				else:
					sound_negative, _ = self.audio.genSoundFeat(objIndx=objIndx_negative, featType='MFCC',
																rand_fn=self.np_random.randint)

		return sound_positive, sound_negative, ground_truth, positive_audio

	def envReset(self):
		"""
		Reset the environment and return the first observation dict.

		On the first call (robot not loaded yet) this loads the sound data (if not already
		loaded), robot, table, objects, robot base, optional background wall and textures,
		plus workspace debug lines when rendering, then randomizes the scene. Later calls
		re-randomize only when config.ifReset is set.
		"""
		if self.robot.robot_ids is None:
			# load sound
			# if os is Linux, we already load the data in the parent process. See shmem_vec_env.py
			if len(self.audio.words) == 0:
				self.audio.loadData()
			# load robot
			self.robot.load_model()
			self.robot._p = self._p
			self.robotID = self.robot.robot_ids

			# anchor the robot
			self._p.resetBasePositionAndOrientation(self.robot.robot_ids, [0.0, 0.0, 0.07], [0.0, 0.0, 0.0, 1.0])

			# after we get all the join info we set it to the number of non-fixed joint
			self.robot.numJoints = len(self.robot.jdict)

			# load table and obj
			self.tableUid = self._p.loadURDF(os.path.join(self.config.mediaPath, 'table', 'table.urdf'),
											 [0.66, -0.14, self.config.tableZ],
											 [0.0, 0.0, 0.0, 1.0])
			objPath = os.path.join(self.config.mediaPath, 'objects', 'key.urdf')
			for i in range(self.config.taskNum):
				self.objUidList.append(self._p.loadURDF(objPath, useFixedBase=1))

			# add a base for the robot
			visualID = self._p.createVisualShape(shapeType=self._p.GEOM_CYLINDER,
												 radius=0.08,
												 length=0.02,
												 visualFramePosition=[0, 0, 0],
												 visualFrameOrientation=self._p.getQuaternionFromEuler([0, 0, 0]),
												 rgbaColor=[0, 0, 0, 1]
												 )

			self.baseUid = self._p.createMultiBody(baseMass=0,
												   baseInertialFramePosition=[0, 0, 0],
												   baseVisualShapeIndex=visualID,
												   basePosition=[0, 0, -0.01],
												   baseOrientation=self._p.getQuaternionFromEuler([0, 0, 0]))

			if self.config.domainRandomization:
				# background wall whose texture is randomized in randomization()
				wall_visualID = self._p.createVisualShape(shapeType=self._p.GEOM_BOX,
														  halfExtents=[0.05, 1.5, 0.9],
														  visualFramePosition=[0, 0, 0],
														  visualFrameOrientation=self._p.getQuaternionFromEuler(
															  [0, 0, 0]),
														  rgbaColor=[1, 1, 1, 1]
														  )

				self.wallID = self._p.createMultiBody(baseMass=0,
													  baseInertialFramePosition=[0, 0, 0],
													  baseVisualShapeIndex=wall_visualID,
													  basePosition=[-0.3, -0.2, -0.5],
													  baseOrientation=self._p.getQuaternionFromEuler([0, 0, 0]))
				self.loadTex()

			if self.config.render:
				# Outline the workspace bounds (xMin..xMax, yMin..yMax) at object height.
				# Use self._p (not the module-level pybullet) so the lines go to this env's client.
				xMin, xMax = self.config.xMin, self.config.xMax
				yMin, yMax = self.config.yMin, self.config.yMax
				z = self.config.objZ
				edges = [([xMin, yMin, z], [xMax, yMin, z]),
						 ([xMin, yMax, z], [xMax, yMax, z]),
						 ([xMax, yMin, z], [xMax, yMax, z]),
						 ([xMin, yMin, z], [xMin, yMax, z])]
				for start, end in edges:
					self.workSpaceDebugLine.append(
						self._p.addUserDebugLine(start, end, self.config.rayMissColor, lineWidth=5))
				self.eePositionDebugLine = self._p.addUserDebugLine([0, 0, 0], [0, 0, 0], self.config.rayMissColor, lineWidth=5)

			self.randomization()

		if self.config.ifReset and self.episodeCounter > 0:
			self.randomization()

		ret = self.gen_obs()
		return ret[0]

	def saveEpisodeImage(self, image):
		"""Every config.episodeImgSaveInterval episodes, write the current camera image as a JPEG."""
		if self.config.episodeImgSaveInterval > 0 and self.episodeCounter % self.config.episodeImgSaveInterval == 0:
			# save the images (cv2.resize takes (width, height))
			imgSave = cv2.resize(image, (self.config.episodeImgSize[1], self.config.episodeImgSize[0]))
			if self.config.episodeImgSize[2] == 3:
				imgSave = cv2.cvtColor(imgSave, cv2.COLOR_RGB2BGR)  # OpenCV writes BGR
			fileName = str(self.givenSeed) + 'out' + str(self.envStepCounter) + '.jpg'
			cv2.imwrite(os.path.join(self.config.episodeImgSaveDir, fileName), imgSave)

	def gen_obs(self):
		"""
		Build the observation for the current simulator state.

		At the first step of an episode (envStepCounter == 0) this also picks the goal object
		and its goal sound.

		:return: (obs, s, sound_positive, sound_negative, sound_positive_ground_truth) where
			obs is the observation dict, s is the robot state from robot.calc_state(), and the
			remaining items come from get_positive_negative().
		:raises NotImplementedError: for an unknown config.hideObj['mode']
		"""
		image = self.robot.get_image(self.externalCamEyePosition, self.externalCamTargetPosition)
		self.saveEpisodeImage(image)

		s = self.robot.calc_state()

		# at training time, generate triplets from time to time
		get_negative = self.config.RLTrain and self.np_random.rand() > self.config.pretextModelUpdatePairProb

		# sound_positive: the current sound heard by the agent
		# sound_negative: the sound that is not the current heard sound
		# sound_positive_ground_truth: the ground truth label for the sound heard by the agent
		sound_positive, sound_negative, sound_positive_ground_truth, _ = self.get_positive_negative(get_negative)

		if self.envStepCounter == 0:
			if self.config.hideObj['mode'] == 'random':
				# uniform over the objects that are not hidden
				prob = np.ones((self.config.taskNum,)) / (self.config.taskNum - self.config.hideObj['hideNum'])
				prob[self.hideObjIdx] = 0.
				self.goalObjIdx = self.np_random.choice(self.config.taskNum, replace=False, p=prob)

			elif self.config.hideObj['mode'] == 'fix':
				if self.config.taskNum == len(self.config.hideObj['hideIdx']):
					self.goalObjIdx = self.np_random.choice(self.config.taskNum)
				else:
					prob = np.ones((self.config.taskNum,)) / (self.config.taskNum - len(self.config.hideObj['hideIdx']))
					prob[self.config.hideObj['hideIdx']] = 0.
					self.goalObjIdx = self.np_random.choice(self.config.taskNum, replace=False, p=prob)

			# all objects are present
			elif self.config.hideObj['mode'] == 'none':
				# randomly select an object
				if self.config.RLTrain or self.config.render:
					self.goalObjIdx = self.np_random.randint(0, self.config.taskNum)

				else:
					# Test time: walk through the classes in order, class k for episodes in
					# [cumsum[k-1], cumsum[k]). NOTE(review): once episodeCounter reaches the
					# total this yields taskNum, which is not a valid object index.
					idx = np.where(self.size_per_class_cumsum <= self.episodeCounter)[0]
					self.goalObjIdx = 0 if len(idx) == 0 else int(idx.max() + 1)
			else:
				raise NotImplementedError

			self.goal_sound, self.goal_audio = self.audio.genSoundFeat(objIndx=self.goalObjIdx, featType='MFCC',
																	   rand_fn=self.np_random.randint)
			self.ground_truth = np.int32(self.goalObjIdx)
			if self.config.render or not self.config.RLTrain:
				if self.goal_audio is not None and self.config.render:
					sd.play(self.goal_audio, self.audio.fs)
				print('Goal object is', self.goalObjIdx)

		# Observations are dictionaries containing:
		# - an image (partially observable view of the environment)
		# - a sound mfcc feature (the name of the object if the agent sees it or empty)

		obs = {
			'image': np.transpose(image, (2, 0, 1)),  # HWC -> CHW
			'goal_sound': self.goal_sound,
			'current_sound': sound_positive,
			# float32 to match the declared observation_space['robot_pose']
			'robot_pose': np.array([s['eeState'][0], s['eeState'][1]], dtype=np.float32),
			'ground_truth': self.ground_truth,
		}

		if self.config.record:
			# NOTE(review): these lists are only flushed (and cleared) by step() at test time;
			# with RLTrain enabled they keep growing.
			self.episodeRecorder.eePositionList.append(s['eeState'])
			self.episodeRecorder.eePositionList_desired.append(self.robot.desiredEndEffectorPos.copy())
			if self.envStepCounter == 0:
				self.episodeRecorder.choice.append(self.goalObjIdx)

		return obs, s, sound_positive, sound_negative, sound_positive_ground_truth

	def keyboardControl(self):
		"""Map arrow keys to an end-effector displacement (dx, dy, dz); dz is always 0."""
		dx, dy, dz = 0.0, 0.0, 0.0
		# Query this env's own client rather than the module-level pybullet connection.
		k = self._p.getKeyboardEvents()
		if self._p.B3G_UP_ARROW in k:
			dx = -0.02
		if self._p.B3G_DOWN_ARROW in k:
			dx = 0.02
		if self._p.B3G_LEFT_ARROW in k:
			dy = -0.02
		if self._p.B3G_RIGHT_ARROW in k:
			dy = 0.02
		return dx, dy, dz

	def step(self, action):
		"""
		Apply one action, advance the simulation and build the next observation.

		:param action: array-like; action[0] and action[1] are clipped to [-1, 1] and scaled
			by 0.02 into x/y end-effector displacements. Ignored when config.RLManualControl
			is set, in which case the arrow keys drive the robot.
		:return: (obs, reward, done, infoDict). infoDict holds the triplet data
			('image', 'sound_positive', 'sound_negative', 'ground_truth') when a negative sound
			was generated, and 'goal_area_count' at the end of a test-time episode.
		:raises NotImplementedError: at test time for an unknown config.RLTask
		"""
		action = np.array(action)
		if self.config.RLManualControl:
			dx, dy, dz = self.keyboardControl()
		else:
			dv = 0.02  # displacement per step for a full-scale action
			dx = float(np.clip(action[0], -1, +1)) * dv  # move in x direction
			dy = float(np.clip(action[1], -1, +1)) * dv  # move in y direction
			dz = 0.

		self.robot.applyAction([dx, dy, dz], controlMethod=self.config.RLRobotControl)
		self.scene.global_step()
		self.envStepCounter = self.envStepCounter + 1

		obs, s, sound_positive, sound_negative, sound_positive_ground_truth = self.gen_obs()

		infoDict = {}

		self.reward = self.rewards()  # calculate reward
		self.episodeReward = self.episodeReward + self.reward
		self.done = self.termination(s)

		if sound_negative is not None:
			infoDict['image'] = obs['image']
			infoDict['sound_positive'] = sound_positive.astype(np.float32)
			infoDict['sound_negative'] = sound_negative.astype(np.float32)
			infoDict['ground_truth'] = np.array([sound_positive_ground_truth], dtype=np.int32)

		# at test time, perform rayTest and put performance into the infoDict
		if not self.config.RLTrain:
			contactRays = self._contactRays()
			if self.config.RLTask == 'approach':
				if contactRays[self.goalObjIdx]:
					self.goal_area_count = self.goal_area_count + 1
			elif self.config.RLTask == 'avoid':
				# count steps spent over exactly one object that is not the goal
				if sum(contactRays) == 1 and not contactRays[self.goalObjIdx]:
					self.goal_area_count = self.goal_area_count + 1
			else:
				raise NotImplementedError
			if self.done:
				infoDict['goal_area_count'] = self.goal_area_count
				print('goal area count', self.goal_area_count)
				self.goal_area_count = 0
				if self.config.record:
					self.episodeRecorder.saveEpisode(self.episodeCounter)

		if self.config.render:
			# move the end-effector marker (a short line along x) to the current link position
			eeState = self._p.getLinkState(self.robot.robot_ids, self.config.endEffectorIndex)[0]
			start = [eeState[0] - 0.1, eeState[1], eeState[2]]
			end = [eeState[0] + 0.1, eeState[1], eeState[2]]
			self._p.addUserDebugLine(start, end, self.config.rayMissColor, lineWidth=5, replaceItemUniqueId=self.eePositionDebugLine)

		return obs, self.reward, self.done, infoDict  # reset will be called if done

	def termination(self, s):
		"""Return True once the episode has reached config.RLEnvMaxSteps steps; s is unused."""
		if self.envStepCounter >= self.maxSteps:
			return True
		return False

	def rewards(self):
		"""This environment provides no reward signal; always returns 0."""
		return 0.


class Recorder(object):
	"""Accumulate end-effector positions and goal choices for an episode and write them to CSV files."""

	def __init__(self):
		self.eePositionList = []  # the measured position of the end effector, one entry per step
		self.eePositionList_desired = []  # the desired position of the end effector, one entry per step
		self.choice = []  # target block: 0, 1, 2, 3
		self.episodeInitNum = 0  # offset added to the episode counter when naming output folders
		self.saveTo = None  # root directory under which 'ep<N>' folders are created
		self.loadedAction = None  # numpy array filled by loadActions()
		self.loadFrom = None  # CSV path read by loadActions()

	@staticmethod
	def _xyFrame(positions):
		"""
		Return a DataFrame with the first two components of each position as columns 'x' and 'y'.

		An empty list yields an empty frame that still has both columns. Indexing
		np.array([])[:, 0] would raise IndexError here.
		"""
		if len(positions) == 0:
			return pd.DataFrame({'x': [], 'y': []})
		arr = np.array(positions)
		return pd.DataFrame({'x': arr[:, 0], 'y': arr[:, 1]})

	def saveEpisode(self, episodeCounter):
		"""
		Write position.csv, desired_position.csv and choice.csv for one episode, then clear the buffers.

		:param episodeCounter: episode index. Files go to saveTo/'ep<episodeInitNum + episodeCounter>'.
		"""
		savePath = os.path.join(self.saveTo, 'ep' + str(self.episodeInitNum + episodeCounter))
		os.makedirs(savePath, exist_ok=True)
		position = self._xyFrame(self.eePositionList)
		desired_position = self._xyFrame(self.eePositionList_desired)
		# episode info
		choice = pd.DataFrame({'choice': np.array(self.choice)})

		position.to_csv(os.path.join(savePath, 'position.csv'), index=False)
		desired_position.to_csv(os.path.join(savePath, 'desired_position.csv'), index=False)
		choice.to_csv(os.path.join(savePath, 'choice.csv'), index=False)

		print("csv written")
		self.clear()

	def clear(self):
		"""Empty the per-episode buffers (positions and choices)."""
		self.eePositionList = []  # the measured position of the end effector
		self.eePositionList_desired = []  # the desired position of the end effector
		self.choice = []  # target block: 0, 1, 2, 3

	def loadActions(self):
		"""Read the CSV at self.loadFrom into self.loadedAction as a numpy array."""
		print("Reading actions from", self.loadFrom)
		self.loadedAction = pd.read_csv(self.loadFrom)
		self.loadedAction = self.loadedAction.values
