from ..RSI2.RL_env_RSI2 import RLEnvRSI2
from ..env_bases import BaseEnv
from Envs.pybullet.kuka.robot_manipulators import *
import cv2
import sounddevice as sd
import os


class RLEnvRSI3(RLEnvRSI2):
	def __init__(self):
		"""Initialise the parent environment, then declare the dict observation space.

		The observation space is a ``gym.spaces.Dict`` with one entry per key
		produced by ``gen_obs``. ``self.maxSteps`` is taken from
		``self.config.RLEnvMaxSteps``.
		"""
		RLEnvRSI2.__init__(self)

		cfg = self.config

		def unbounded(shape):
			# float32 Box without finite limits
			return gym.spaces.Box(low=-np.inf, high=np.inf, shape=shape, dtype=np.float32)

		feat_shape = (cfg.representationDim,)

		spaces = dict()
		spaces['image'] = gym.spaces.Box(low=0, high=255, shape=cfg.img_dim, dtype='uint8')
		spaces['goal_sound'] = unbounded(cfg.sound_dim)
		spaces['current_sound'] = unbounded(cfg.sound_dim)
		spaces['robot_pose'] = unbounded((2,))
		# [0:'left most block', 1:'left block', 2:'right block', 3:'right most block', 4:empty]
		spaces['goal_sound_label'] = gym.spaces.Box(low=0, high=4, shape=(1,), dtype=np.int32)
		spaces['goal_sound_feat'] = unbounded(feat_shape)
		spaces['image_feat'] = unbounded(feat_shape)

		self.observation_space = gym.spaces.Dict(spaces)
		self.maxSteps = cfg.RLEnvMaxSteps

	def gen_obs(self):
		"""Build the observation dictionary for the current step.

		On the first step of an episode (``self.envStepCounter == 0``) a goal
		object is selected, its MFCC sound feature and audio are generated via
		``self.audio.genSoundFeat`` and stored on the instance, and (when
		rendering or not training) the goal index is printed; the audio is
		played only when rendering and audio is available.

		Returns:
			tuple: ``(obs, s, sound_positive, sound_negative,
			sound_positive_ground_truth)`` where ``obs`` is the observation
			dict, ``s`` is the value returned by ``self.robot.calc_state()``
			and the remaining items come from ``self.get_positive_negative``.

		Raises:
			NotImplementedError: if ``self.config.hideObj['mode']`` is not one
				of ``'random'``, ``'fix'`` or ``'none'`` (first step only).
			ValueError: if every object is hidden so no goal can be sampled
				(first step only).
		"""
		image = self.robot.get_image()
		self.saveEpisodeImage(image)

		s = self.robot.calc_state()

		# sound_positive: the current sound heard by the agent
		# sound_negative: the sound that is not the current heard sound
		# sound_positive_ground_truth: the ground truth label for the sound heard by the agent
		sound_positive, sound_negative, sound_positive_ground_truth, _ = self.get_positive_negative(
			get_negative=False, generate_audio=True)

		if self.envStepCounter == 0:
			self.goalObjIdx = self._select_goal_index()

			self.goal_sound, self.goal_audio = self.audio.genSoundFeat(objIndx=self.goalObjIdx, featType='MFCC',
																	   rand_fn=self.np_random.randint)
			self.ground_truth = np.int32(self.goalObjIdx)
			# `not RLTrain` matches the test-time check used in step()
			if self.config.render or not self.config.RLTrain:
				if self.goal_audio is not None and self.config.render:
					sd.play(self.goal_audio, self.audio.fs)
				print('Goal object is', self.goalObjIdx)

		feat_dim = self.config.representationDim

		# Observations are dictionaries containing:
		# - an image (partially observable view of the environment)
		# - a sound mfcc feature (the name of the object if the agent sees it or empty)
		# Float entries are created as float32 so they match the dtype declared
		# in the observation space.
		obs = {
			'image': np.transpose(image, (2, 0, 1)),  # move the last image axis to the front
			'goal_sound': self.goal_sound,
			'current_sound': sound_positive,
			'robot_pose': np.array([s['eeState'][0], s['eeState'][1]], dtype=np.float32),
			# NOTE(review): this is a 0-d np.int32 while the observation space declares
			# shape (1,); left unchanged because consumers may rely on it -- confirm.
			'goal_sound_label': self.ground_truth,
			# zero placeholders; nothing in this method fills them in
			'goal_sound_feat': np.zeros((feat_dim,), dtype=np.float32),
			'image_feat': np.zeros((feat_dim,), dtype=np.float32),
		}

		return obs, s, sound_positive, sound_negative, sound_positive_ground_truth

	def _select_goal_index(self):
		"""Choose the goal object index for a new episode.

		The rule depends on ``self.config.hideObj['mode']``:

		* ``'random'``: uniform over objects not listed in ``self.hideObjIdx``.
		* ``'fix'``: uniform over objects not listed in
		  ``self.config.hideObj['hideIdx']``.
		* ``'none'``: uniform over all ``taskNum`` objects when training or
		  rendering; otherwise derived from ``self.episodeCounter`` and
		  ``self.size_per_class_cumsum``.

		Raises:
			NotImplementedError: for any other mode.
			ValueError: if all objects are hidden.
		"""
		hide_cfg = self.config.hideObj
		mode = hide_cfg['mode']
		task_num = self.config.taskNum

		if mode == 'random' or mode == 'fix':
			hidden = self.hideObjIdx if mode == 'random' else hide_cfg['hideIdx']
			prob = np.ones((task_num,))
			prob[hidden] = 0.
			n_visible = prob.sum()
			if n_visible == 0:
				raise ValueError('all objects are hidden; cannot choose a goal object')
			# Normalising by the number of visible objects keeps the distribution
			# valid even if the hidden-index list and hideNum disagree.
			prob = prob / n_visible
			return self.np_random.choice(task_num, replace=False, p=prob)

		# all 4 objects are present
		if mode == 'none':
			# randomly select an object
			if self.config.RLTrain or self.config.render:
				return self.np_random.randint(0, task_num)
			# otherwise: number of cumulative class sizes already reached by the episode counter
			idx = np.where(self.size_per_class_cumsum <= self.episodeCounter)[0]
			return 0 if len(idx) == 0 else int(idx.max() + 1)

		raise NotImplementedError

	def step(self, action):
		"""Apply one planar end-effector action and advance the simulation.

		Args:
			action: indexable with at least two entries; entries 0 and 1 are
				clipped to [-1, 1] and scaled to x / y displacements. Ignored
				when ``self.config.RLManualControl`` is set, in which case the
				arrow keys drive the motion instead.

		Returns:
			tuple: ``(obs, reward, done, infoDict)``. ``infoDict`` carries the
			image / sound triplet data when a negative sound is available and,
			at test time on the final step, ``'goal_area_count'``.

		Raises:
			NotImplementedError: at test time if ``self.config.RLTask`` is
				neither ``'approach'`` nor ``'avoid'``.
		"""
		action = np.array(action)
		dv = 0.02  # displacement magnitude per step along each axis

		if self.config.RLManualControl:
			dx, dy = 0.0, 0.0
			# NOTE(review): uses the module-level `p` rather than `self._p` -- confirm intended
			keys = p.getKeyboardEvents()
			if p.B3G_UP_ARROW in keys:
				dx = -dv
			if p.B3G_DOWN_ARROW in keys:
				dx = dv
			if p.B3G_LEFT_ARROW in keys:
				dy = -dv
			if p.B3G_RIGHT_ARROW in keys:
				dy = dv
		else:
			dx = float(np.clip(action[0], -1, +1)) * dv  # move in x direction
			dy = float(np.clip(action[1], -1, +1)) * dv  # move in y direction

		# no motion along z
		act = [dx, dy, 0.]

		self.robot.applyAction(act, controlMethod=self.config.RLRobotControl)
		self.scene.global_step()
		self.envStepCounter = self.envStepCounter + 1

		obs, s, sound_positive, sound_negative, sound_positive_ground_truth = self.gen_obs()

		infoDict = {}

		reward_terms = [self.rewards()]  # calculate reward
		self.reward = sum(reward_terms)
		self.episodeReward = self.episodeReward + self.reward
		self.done = self.termination(s)

		if sound_negative is not None:
			infoDict['image'] = obs['image']
			infoDict['sound_positive'] = sound_positive.astype(np.float32)
			infoDict['sound_negative'] = sound_negative.astype(np.float32)
			infoDict['ground_truth'] = np.array([sound_positive_ground_truth], dtype=np.int32)

		# at test time, perform rayTest and put performance into the infoDict
		if not self.config.RLTrain:
			rayTest = self.robot.ray_test(self.objUidList)
			# one flag per object: does the ray-test result equal that object's id?
			contactRays = [rayTest == uid for uid in self.objUidList]
			if self.config.RLTask == 'approach':
				# count steps on which the ray test reports the goal object
				if contactRays[self.goalObjIdx]:
					self.goal_area_count = self.goal_area_count + 1
			elif self.config.RLTask == 'avoid':
				# count steps on which exactly one object is reported and it is not the goal
				if sum(contactRays) == 1 and not contactRays[self.goalObjIdx]:
					self.goal_area_count = self.goal_area_count + 1
			else:
				raise NotImplementedError
			if self.done:
				infoDict['goal_area_count'] = self.goal_area_count
				print('goal area count', self.goal_area_count)
				self.goal_area_count = 0

		return obs, self.reward, self.done, infoDict  # reset will be called if done
