import os
import numpy as np
import cv2
import sounddevice as sd
import pybullet as p
import gym
from ..RSI2.RL_env_RSI2 import RLEnvRSI2

class RLEnvRSI1(RLEnvRSI2):
	def __init__(self):
		"""
		Initialise the parent environment and replace its observation space.

		The new dict space reuses the parent's 'image' and 'ground_truth' entries and
		adds 'goal_sound', 'robot_pose', 'inSight' and 'exi' as float32 boxes without
		finite limits. The reward potentials start out as empty lists.
		"""
		RLEnvRSI2.__init__(self)

		def unbounded_box(shape):
			# float32 box ranging over (-inf, inf) with the requested shape
			return gym.spaces.Box(low=-np.inf, high=np.inf, shape=shape, dtype=np.float32)

		# update observation space
		parent_space = self.observation_space
		spaces = {}
		spaces['image'] = parent_space['image']
		spaces['goal_sound'] = unbounded_box(self.config.sound_dim)
		spaces['robot_pose'] = unbounded_box((2,))
		# 0: cube, 1: sphere, 2: cone, 3: cylinder
		spaces['ground_truth'] = parent_space['ground_truth']
		# one-hot encoding indicating which objects are being seen
		spaces['inSight'] = unbounded_box((self.config.taskNum,))
		# whether the goal object is being seen
		spaces['exi'] = unbounded_box((1,))
		self.observation_space = gym.spaces.Dict(spaces)

		# for calculating the rewards
		self.potential = dict(disPotential=[], angPotential=[])

	def envReset(self):
		if self.robot.robot_ids is None: # if it is the first run, load all models
			self.loadArena()
			# load objects and put them far away from the arena
			for item in self.objList :
				self.objUidList.append(self.loadObj(os.path.join(self.config.mediaPath, 'objects', item + ".obj"), item))

			# load robot so that robot_ids is not None
			self.robot.load_model()
			self.robot._p=self._p
			self.robot.eyes=self.robot.parts["eyes"] # in the urdf, the link for the camera is called "eyes"
			self.objInScene=np.arange(len(self.objList))
			self.randomization()

		if self.config.ifReset and self.episodeCounter>0:
			self.randomization()

		ret=self.gen_obs()
		obs=ret[0]
		s=ret[1]
		self.potential = {'disPotential': [], 'angPotential': []}  # potential
		self.calc_potential(s)  # get initial potential
		return obs

	def gen_obs(self):
		"""
		:return: a dict containing various type of observations
		"""
		image=self.robot.get_image()

		self.saveEpisodeImage(image)
		s = self.robot.calc_state(self.objPoseList, self.objList, self.objInScene)

		rayTest = self.robot.ray_test(self.objUidList)
		contactRays = [True if any((rayTest == Uid)[1:-1]) else False for Uid in self.objUidList]

		inViewList = [np.inf] * self.config.taskNum
		for i in range(len(self.objList)):
			if contactRays[i] and 0 <= s['dist'][i] <= 0.15:
				inViewList[i] = s['dist'][i]  # object not being seen has inf distance

		if self.envStepCounter==0:
			if self.config.RLTrain:
				self.goalObjIdx = self.np_random.randint(0, self.config.taskNum)
			else:
				idx = np.where(self.size_per_class_cumsum <= self.episodeCounter)[0]
				self.goalObjIdx = 0 if len(idx) == 0 else int(idx.max() + 1)
				self.goalObjIdx = np.clip(self.goalObjIdx, a_min=0, a_max=self.config.taskNum-1)

			self.goal_sound, self.goal_audio = self.audio.genSoundFeat(objIndx=self.goalObjIdx, featType='MFCC',
																	   rand_fn=self.np_random.randint)
			self.ground_truth = np.int32(self.goalObjIdx)
			if self.config.render or self.config.RLTrain == False:
				if self.goal_audio is not None and self.config.render:
					sd.play(self.goal_audio, self.audio.fs)
				print('Goal object is', self.goalObjIdx)

		else:
			self.goal_sound=np.ones_like(self.goal_sound)*np.inf

		inSight = np.array(contactRays, dtype=np.float32)
		exi = inSight[self.goalObjIdx]

		obs = {
			'image': np.transpose(image, (2, 0, 1)),
			'goal_sound': self.goal_sound,
			'robot_pose': np.array([s['motor'][0], s['motor'][1]]),
			'ground_truth': self.ground_truth,
			'inSight': inSight,
			'exi': exi,
		}

		return obs, s, inViewList

	def step(self, action):
		action=np.array(action)

		act=[]
		if self.config.RLRobotControl in ['rotPos']:
			if self.config.RLManualControl:
				k = p.getKeyboardEvents()
				dx=0.
				dtheta=0.0
				if p.B3G_UP_ARROW in k:
					dx = 0.05
				if p.B3G_DOWN_ARROW in k:
					dx = -0.05
				if p.B3G_LEFT_ARROW in k:
					dtheta = 0.15
				if p.B3G_RIGHT_ARROW in k:
					dtheta = -0.15
				act=[dx,dtheta]

			else:
				# action[0] transitional  velocity, action[1] rotational position
				deltaTrans = 0.05
				dTrans = float(np.clip(action[0], -1, +1)) * deltaTrans  # delta transitional velocity

				deltaRot = 0.15
				dRot = float(np.clip(action[1], -1, +1)) * deltaRot  # delta rotational position
				act = [dTrans, dRot]

		else:
			raise NotImplementedError

		real_action = self.robot.apply_action(act,
											  controlMethod=self.config.RLRobotControl,
											  currentPlan=self.objPoseList,
											  objList=self.objList,
											  objInScene=self.objInScene)
		self.scene.global_step()
		self.envStepCounter = self.envStepCounter + 1

		obs, s, inViewList = self.gen_obs()

		r =self.rewards(s) # calculate reward
		self.reward = sum(r)
		self.episodeReward = self.episodeReward + self.reward
		self.done = self.termination()

		infoDict={}

		if not self.config.RLTrain:
			if not np.isinf(inViewList[self.goalObjIdx]):
				self.goal_area_count=self.goal_area_count+1
			if self.done:
				infoDict['goal_area_count']=self.goal_area_count
				print('goal area count', self.goal_area_count)
				self.goal_area_count = 0

		if self.config.record:
			self.episodeRecorder.wheelVelList.append(real_action)
			self.episodeRecorder.actionList.append(list(action))

		return obs, self.reward, self.done, infoDict # reset will be called if done

	def rewards(self, state_dict):
		# know which object is the target
		num = self.goalObjIdx
		# get collision one-hot vector
		collide = state_dict['collide']
		# get distance and angle list
		dis = state_dict['dist']
		d = dis[num]

		ang = state_dict['ang']
		a = abs(ang[num])

		potential_old = self.potential
		self.calc_potential(state_dict)  # update potential

		collisionPenalty = 0
		distanceProgress = 0
		angleProgress = 0
		goalReward = 0

		# collision
		if collide[num] == 1:
			collisionPenalty = -0.1
			self.collision = True
		elif collide[num] == 0 and sum(collide) > 0:
			collisionPenalty = -0.3
			self.collision = True

		# distance
		distance_factor = 50
		distanceProgress = distanceProgress + distance_factor * (
					self.potential['disPotential'][num] - potential_old['disPotential'][num])
		#
		# # angle
		angle_factor = 20
		angleProgress = angleProgress + angle_factor * (
					self.potential['angPotential'][num] - potential_old['angPotential'][num])

		# achieve goal reward
		if 0 <= a <= 0.35:
			goalReward = goalReward + 0.5

		if -0 <= d <= 0.15:
			goalReward = goalReward + 0.5

		# perfect goal configuration without collision
		if (0 <= a <= 0.35) and (-0 <= d <= 0.15):
			goalReward = goalReward + 1

		rewards = [
			collisionPenalty*self.config.collisionPenaltyWeight,
			distanceProgress*self.config.distanceProgressWeight,
			angleProgress*self.config.angleProgressWeight,
			goalReward,
		]

		return rewards

	def calc_potential(self, state_dict):
		# get distance and angle list
		dis = state_dict['dist']
		ang = state_dict['ang']
		# the potential has two parts: a part for distance and a part for angle
		# calculate the potential for all the objects
		self.potential = {'disPotential': [], 'angPotential': []}
		for i in range(len(self.objList)):
			if i in self.objInScene:  # calculate potential if the object is in the arena for this episode
				self.potential['disPotential'].append(-abs(dis[i]))
				self.potential['angPotential'].append(-abs(ang[i]))
			else:  # insert 0 if the object is not in the arena for this episode
				self.potential['disPotential'].append(0)
				self.potential['angPotential'].append(0)
