from ..RSI2.RL_env_RSI2 import RLEnvRSI2
from ..robot_manipulators import *
import cv2
import sounddevice as sd


class RLEnvRSI1(RLEnvRSI2):
	"""RSI1 variant of the RSI2 environment whose goal is given as a sound.

	On the first step of every episode a goal block index is chosen and an MFCC
	feature of the corresponding sound is placed in the observation. On every
	later step of the episode the sound feature is replaced by ``inf``. The agent
	moves the end effector in the x/y plane. It is rewarded for approaching the
	goal block and for the robot's ray test hitting that block.
	"""

	# Maximum end-effector displacement per step along each controlled axis.
	_STEP_SIZE = 0.02

	def __init__(self):
		RLEnvRSI2.__init__(self)

		# Keep the parent's image and ground-truth spaces and add a sound-feature
		# goal and the planar end-effector pose. The resulting Dict space
		# contains exactly these four keys.
		d = {
			'image': self.observation_space['image'],
			'goal_sound': gym.spaces.Box(low=-np.inf, high=np.inf, shape=self.config.sound_dim, dtype=np.float32),
			'robot_pose': gym.spaces.Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.float32),
			'ground_truth': self.observation_space['ground_truth'],  # 0: left most block, 3: right most block
		}

		self.observation_space = gym.spaces.Dict(d)

	def gen_obs(self):
		"""Build the observation for the current step.

		When ``envStepCounter == 0`` a goal block is selected, and its sound
		feature and raw audio are generated. In render mode the audio is also
		played. On every other step the stored sound feature is overwritten by
		an all-``inf`` array of the same shape.

		Returns:
			(obs, s): the observation dict (keys as in ``self.observation_space``)
			and the raw robot state returned by ``self.robot.calc_state()``.
		"""
		image = self.robot.get_image()
		self.saveEpisodeImage(image)

		s = self.robot.calc_state()

		if self.envStepCounter == 0:
			if self.config.RLTrain or self.config.render:
				# Training / interactive rendering: uniformly random goal block.
				self.goalObjIdx = self.np_random.randint(0, self.config.taskNum)
			else:
				# Evaluation: walk through the classes in order. The goal index is
				# one past the last cumulative class size the episode counter has reached.
				# NOTE(review): once episodeCounter reaches the final cumsum entry
				# this yields len(size_per_class_cumsum), presumably past the last
				# valid class index — confirm the evaluation loop stops before that.
				idx = np.where(self.size_per_class_cumsum <= self.episodeCounter)[0]
				self.goalObjIdx = 0 if len(idx) == 0 else int(idx.max() + 1)

			self.goal_sound, self.goal_audio = self.audio.genSoundFeat(
				objIndx=self.goalObjIdx, featType='MFCC', rand_fn=self.np_random.randint)
			self.ground_truth = np.int32(self.goalObjIdx)

			if self.config.render or not self.config.RLTrain:
				if self.goal_audio is not None and self.config.render:
					sd.play(self.goal_audio, self.audio.fs)
				print('Goal object is', self.goalObjIdx)
		else:
			# Only the first observation of an episode carries the goal sound.
			self.goal_sound = np.ones_like(self.goal_sound) * np.inf

		obs = {
			# Move the last (presumably channel) axis to the front.
			'image': np.transpose(image, (2, 0, 1)),
			'goal_sound': self.goal_sound,
			# Cast to the float32 dtype declared for 'robot_pose' in __init__.
			'robot_pose': np.array([s['eeState'][0], s['eeState'][1]], dtype=np.float32),
			'ground_truth': self.ground_truth,
		}

		return obs, s

	def step(self, action):
		"""Advance the simulation by one control step.

		Args:
			action: indexable whose first two entries are x/y commands. Each is
				clipped to [-1, 1] and scaled by ``_STEP_SIZE``. It is ignored
				when ``config.RLManualControl`` is set; the arrow keys then move
				the end effector by ``_STEP_SIZE`` instead.

		Returns:
			(obs, reward, done, info). At test time (``not config.RLTrain``),
			``info`` holds ``'goal_area_count'`` on the terminal step. This is the
			number of steps in the episode on which the ray test hit the goal block.
		"""
		if self.config.RLManualControl:
			dx, dy = 0.0, 0.0
			keys = p.getKeyboardEvents()
			# Later checks win if opposite keys are pressed together.
			if p.B3G_UP_ARROW in keys:
				dx = -self._STEP_SIZE
			if p.B3G_DOWN_ARROW in keys:
				dx = self._STEP_SIZE
			if p.B3G_LEFT_ARROW in keys:
				dy = -self._STEP_SIZE
			if p.B3G_RIGHT_ARROW in keys:
				dy = self._STEP_SIZE
			act = [dx, dy, 0.0]
		else:
			action = np.asarray(action)
			dx = float(np.clip(action[0], -1, +1)) * self._STEP_SIZE  # move in x direction
			dy = float(np.clip(action[1], -1, +1)) * self._STEP_SIZE  # move in y direction
			act = [dx, dy, 0.]  # z is never commanded

		self.robot.applyAction(act, controlMethod=self.config.RLRobotControl)
		self.scene.global_step()
		self.envStepCounter += 1

		obs, s = self.gen_obs()

		infoDict = {}

		r = self.rewards(s)
		self.reward = sum(r)
		self.episodeReward = self.episodeReward + self.reward
		self.done = self.termination(s)

		# At test time, count the steps on which the ray test hits the goal block
		# and report (then reset) the count when the episode ends.
		if not self.config.RLTrain:
			if self._goal_ray_hit():
				self.goal_area_count = self.goal_area_count + 1

			if self.done:
				infoDict['goal_area_count'] = self.goal_area_count
				print('goal area count', self.goal_area_count)
				self.goal_area_count = 0

		return obs, self.reward, self.done, infoDict  # reset will be called if done

	def termination(self, s):
		"""Return True once ``self.maxSteps`` steps have elapsed (``s`` is unused)."""
		return bool(self.envStepCounter >= self.maxSteps)

	def rewards(self, s):
		"""Return the weighted (potential, goal) reward pair for robot state ``s``.

		The potential term uses the planar (x/y) distance ``d`` between the end
		effector (``s['eeState']``) and the goal block:

			r1 = -w1 * d - v1 * (log(d**2 + alpha1) + offset1)

		``alpha1`` keeps the log term finite at ``d == 0``. The goal term ``r2``
		is 1.0 when the robot's ray test hits the goal block and 0.0 otherwise.
		The terms are scaled by ``config.potentialRewardWeight`` and
		``config.goalRewardWeight`` respectively.
		"""
		w1 = 0.5
		v1 = 0.1
		alpha1 = 0.00005
		offset1 = 3.5

		# Only the goal block's position enters the reward.
		goal_pos, _ = self._p.getBasePositionAndOrientation(self.objUidList[self.goalObjIdx])
		ee_pos = s['eeState']

		d = np.linalg.norm(np.array([goal_pos[0], goal_pos[1]]) - np.array([ee_pos[0], ee_pos[1]]))
		r1 = -w1 * d - v1 * (np.log(np.square(d) + alpha1) + offset1)

		r2 = 1.0 if self._goal_ray_hit() else 0.

		return r1 * self.config.potentialRewardWeight, r2 * self.config.goalRewardWeight

	def _goal_ray_hit(self):
		"""Return True if ``robot.ray_test`` over ``objUidList`` returns the goal block's uid."""
		hit_uid = self.robot.ray_test(self.objUidList)
		return bool(hit_uid == self.objUidList[self.goalObjIdx])
