import os
import numpy as np
import cv2
import sounddevice as sd

import gym
from ..RSI2.RL_env_RSI2 import RLEnvRSI2

class PretextEnvRSI1(RLEnvRSI2):
	"""
	Pretext-task environment built on top of RLEnvRSI2.

	The action passed to step() is ignored. Every step moves the robot to a pose
	that is either typed in by hand (config.pretextManualControl) or
	rejection-sampled until isFeasiblePose() accepts it. Observations carry the
	camera image, the goal sound, a one-hot goal label and labels describing
	which objects are currently visible. The reward is always 0.
	"""

	def __init__(self):
		super().__init__()

		task_num = self.config.taskNum
		d = {
			# reuse the parent's image and goal-sound spaces unchanged
			'image': self.observation_space['image'],
			'goal_sound': self.observation_space['goal_sound'],
			# one-hot encoding of the goal object index
			'goal_sound_label': gym.spaces.Box(low=-np.inf, high=np.inf, shape=(task_num,), dtype=np.int32),
			# index of one randomly chosen visible object, or taskNum when nothing is visible
			# NOTE(review): gen_obs emits a Python int here, not a (1,) array — confirm consumers rely on broadcasting
			'ground_truth': gym.spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.int32),
			# zero-one encoding indicating which objects are being seen
			'inSight': gym.spaces.Box(low=-np.inf, high=np.inf, shape=(task_num,), dtype=np.int32),
			# 1 if the goal object is being seen, else 0
			# NOTE(review): gen_obs emits a 0-d array here, not a (1,) array
			'exi': gym.spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.int32),
		}

		self.observation_space = gym.spaces.Dict(d)
		# presumably consulted by termination() in the parent class — TODO confirm
		self.maxSteps = self.config.pretextEnvMaxSteps

		# symmetric [-1, 1] action box; note that step() ignores the action values
		high = np.ones(self.config.pretextActionDim)
		self.action_space = gym.spaces.Box(-high, high, dtype=np.float32)

	def gen_obs(self):
		"""
		Build the observation for the current step.

		On the first step of an episode (envStepCounter == 0) a new goal object is
		drawn and its MFCC sound feature is generated; when rendering, the goal
		audio is also played and the goal index printed.

		:return: (obs, None) where obs is a dict with keys
			'image' (image with its last axis moved first),
			'goal_sound', 'goal_sound_label' (int32 one-hot over taskNum objects),
			'ground_truth' (int: a random visible object index, or taskNum if none),
			'inSight' (int32 0/1 per object) and 'exi' (0-d int array: goal visible).
		"""
		image = self.robot.get_image()
		self.saveEpisodeImage(image)

		# NOTE(review): assumes ray_test returns a 1-D array of hit object uids, one per
		# ray — TODO confirm. The first and last rays are deliberately ignored.
		inner_hits = np.asarray(self.robot.ray_test(self.objUidList))[1:-1]
		inSight = np.array([np.any(inner_hits == uid) for uid in self.objUidList], dtype=np.int32)

		if self.envStepCounter == 0:
			# start of an episode: pick a new goal object and synthesize its sound
			self.goalObjIdx = self.np_random.randint(0, self.config.taskNum)
			self.goal_sound, self.goal_audio = self.audio.genSoundFeat(
				objIndx=self.goalObjIdx, featType='MFCC', rand_fn=self.np_random.randint)

			if self.goal_audio is not None and self.config.render:
				sd.play(self.goal_audio, self.audio.fs)
				print('Goal object is', self.goalObjIdx)

		# int32 array so the value matches the declared Box space (like 'inSight')
		goal_sound_label = np.zeros(self.config.taskNum, dtype=np.int32)
		goal_sound_label[self.goalObjIdx] = 1

		exi = int(inSight[self.goalObjIdx])

		# taskNum acts as the extra "nothing visible" class; otherwise pick one visible object
		visible = np.flatnonzero(inSight)
		if visible.size == 0:
			gt = self.config.taskNum
		else:
			gt = self.np_random.choice(visible)

		obs = {
			# move the last axis (presumably channels) first for PyTorch convolutions
			'image': np.transpose(image, (2, 0, 1)),
			'goal_sound': self.goal_sound,
			'goal_sound_label': goal_sound_label,
			'ground_truth': int(gt),
			'inSight': inSight,
			'exi': np.array(exi),
		}

		return obs, None

	def _select_pretext_pose(self):
		"""Return an (x, y, theta) robot pose, read from stdin or rejection-sampled."""
		max_xy = self.config.xyMax
		if self.config.pretextManualControl:
			x = np.clip(float(input('desired x:')), -max_xy, max_xy)
			y = np.clip(float(input('desired y:')), -max_xy, max_xy)
			theta = np.clip(float(input('desired theta:')), -np.pi, np.pi)
			return x, y, theta

		# Rejection sampling: draw uniformly over the workspace until the pose is feasible.
		# NOTE(review): this never terminates if no feasible pose exists.
		while True:
			x, y = self.np_random.uniform(low=-max_xy, high=max_xy, size=(2,))
			theta = self.np_random.uniform(low=-np.pi, high=np.pi)
			if self.isFeasiblePose(new_obj=('robot', [x, y, theta]), currentPlan=self.objPoseList,
					objList=self.objList, objInScene=self.objInScene):
				return x, y, theta

	def step(self, action):
		"""
		Move the robot to a new pose and return the resulting observation.

		:param action: ignored; the pose comes from manual input or random sampling
		:return: (obs, reward, done, info) with reward always 0. and an empty info dict;
			reset is expected to be called by the caller once done is True.
		:raises NotImplementedError: if config.pretextRobotControl is not
			'pointFollower' or 'setPose'.
		"""
		control = self.config.pretextRobotControl
		if control not in ('pointFollower', 'setPose'):
			raise NotImplementedError('unsupported pretextRobotControl: {}'.format(control))

		# the control method expects a 3-element [x, y, theta] command
		x, y, theta = self._select_pretext_pose()
		self.robot.apply_action([x, y, theta], controlMethod=control)
		self.scene.global_step()
		self.envStepCounter += 1

		obs, _ = self.gen_obs()

		self.reward = 0.
		self.episodeReward += self.reward
		self.done = self.termination()

		return obs, self.reward, self.done, {}
