import os
import numpy as np
import cv2
import sounddevice as sd
import gym
from ..RSI2.RL_env_RSI2 import RLEnvRSI2


class PretextEnvRSI2(RLEnvRSI2):
	"""Pretext-task variant of ``RLEnvRSI2``.

	Each ``step`` moves the robot to a new pose (typed in by hand when
	``config.pretextManualControl`` is set, otherwise sampled at random until
	``isFeasiblePose`` accepts it) and returns an observation made of the camera
	image plus the positive/negative sound features and ground-truth label
	produced by ``get_positive_negative`` (defined outside this class).
	"""

	def __init__(self):
		RLEnvRSI2.__init__(self)

		# Replace the parent's observation space with the pretext-task layout.
		# NOTE(review): 'ground_truth' is declared as an int32 label in [0, 4];
		# confirm this matches what get_positive_negative returns.
		d = {
			'image': gym.spaces.Box(low=0, high=255, shape=self.config.img_dim, dtype=np.uint8),
			'sound_positive': gym.spaces.Box(low=-np.inf, high=np.inf,
											 shape=self.config.sound_dim, dtype=np.float32),
			'sound_negative': gym.spaces.Box(low=-np.inf, high=np.inf,
											 shape=self.config.sound_dim, dtype=np.float32),
			'ground_truth': gym.spaces.Box(low=0, high=4, shape=(1,), dtype=np.int32)
		}

		self.observation_space = gym.spaces.Dict(d)
		self.maxSteps = self.config.pretextEnvMaxSteps

		# Action space: pretextActionDim values in [-1, 1]. Bounds are built as
		# float32 so they match the Box dtype without a lossy cast.
		high = np.ones(self.config.pretextActionDim, dtype=np.float32)
		self.action_space = gym.spaces.Box(-high, high, dtype=np.float32)

	def gen_obs(self):
		"""Build the observation for the current simulator state.

		Side effects: every ``config.episodeImgSaveInterval``-th episode (when
		the interval is positive) the camera frame is written as a JPEG to
		``config.episodeImgSaveDir``; when ``config.render`` is set and a
		positive audio clip exists, it is played through ``sounddevice``.

		:return: tuple ``(obs, s)`` where ``obs`` is a dict with keys
			'image', 'sound_positive', 'sound_negative' and 'ground_truth', and
			``s`` is the robot state returned by ``robot.calc_state``.
		"""
		image = self.robot.get_image()

		if self.config.episodeImgSaveInterval > 0 and self.episodeCounter % self.config.episodeImgSaveInterval == 0:
			# cv2.resize takes dsize as (width, height); episodeImgSize is used
			# here as (height, width, channels).
			imgSave = cv2.resize(image, (self.config.episodeImgSize[1], self.config.episodeImgSize[0]))
			if self.config.episodeImgSize[2] == 3:
				# cv2.imwrite expects BGR channel order.
				imgSave = cv2.cvtColor(imgSave, cv2.COLOR_RGB2BGR)
			fileName = f'{self.givenSeed}out{self.envStepCounter}.jpg'
			cv2.imwrite(os.path.join(self.config.episodeImgSaveDir, fileName), imgSave)

		s = self.robot.calc_state(self.objPoseList, self.objList, self.objInScene)

		# For each object: was it hit by any ray other than the first and last?
		# NOTE(review): why the outermost rays are excluded is not visible here — confirm.
		rayTest = self.robot.ray_test(self.objUidList)
		objHitByRay = [any((rayTest == uid)[1:-1]) for uid in self.objUidList]

		# Objects that are not seen keep a distance of np.inf; a seen object
		# (hit by a ray and at a non-negative distance) gets its real distance.
		inViewList = [np.inf] * self.config.taskNum
		for i in range(len(self.objList)):
			if objHitByRay[i] and 0 <= s['dist'][i]:
				inViewList[i] = s['dist'][i]

		sound_positive, sound_negative, ground_truth, positive_audio = self.get_positive_negative(inViewList)
		if self.config.render and positive_audio is not None:
			sd.play(positive_audio, self.audio.fs)  # non-blocking playback at the audio sample rate

		# The observation dict contains:
		# - 'image': the camera view, axes permuted with (2, 0, 1) so the channel
		#   axis (assumed last in get_image's output) comes first, as PyTorch
		#   convolutions expect
		# - 'sound_positive' / 'sound_negative' / 'ground_truth': as returned by
		#   get_positive_negative for the current in-view distances
		obs = {
			'image': np.transpose(image, (2, 0, 1)),
			'sound_positive': sound_positive,
			'sound_negative': sound_negative,
			'ground_truth': ground_truth,
		}

		return obs, s

	def _pretext_sample_pose(self):
		"""Return the target robot pose ``[x, y, theta]`` for one pretext step.

		With ``config.pretextManualControl`` the pose is read from stdin and
		clipped to ``[-xyMax, xyMax]`` for x/y and ``[-pi, pi]`` for theta; no
		feasibility check is applied. Otherwise x, y and theta are drawn
		uniformly within those ranges using ``self.np_random`` until
		``isFeasiblePose`` accepts the pose.
		"""
		max_xy = self.config.xyMax
		if self.config.pretextManualControl:
			x = np.clip(float(input('desired x:')), -max_xy, max_xy)
			y = np.clip(float(input('desired y:')), -max_xy, max_xy)
			theta = np.clip(float(input('desired theta:')), -np.pi, np.pi)
			return [x, y, theta]

		# Rejection sampling.
		# NOTE(review): this never terminates if no feasible pose exists.
		while True:
			x, y = self.np_random.uniform(low=-max_xy, high=max_xy, size=(2,))
			theta = self.np_random.uniform(low=-np.pi, high=np.pi)
			if self.isFeasiblePose(new_obj=('robot', [x, y, theta]), currentPlan=self.objPoseList,
								   objList=self.objList, objInScene=self.objInScene):
				return [x, y, theta]

	def step(self, action):
		"""Advance the environment by one pretext step.

		In the supported control modes ('pointFollower' and 'setPose') the
		robot is sent to a pose chosen by ``_pretext_sample_pose``. The
		``action`` argument does NOT drive the robot; it is only stored in
		``episodeRecorder.actionList`` when ``config.record`` is set.

		:param action: policy output (array-like). When both ``config.record``
			and ``config.loadAction`` are set, it is replaced by the pre-recorded
			action for the current step.
		:return: ``(obs, reward, done, info)``. ``reward`` is the sum of
			``self.rewards()`` and ``info`` is an empty dict.
		:raises NotImplementedError: if ``config.pretextRobotControl`` is not
			'pointFollower' or 'setPose'.
		"""
		action = np.array(action)
		if self.config.record and self.config.loadAction:  # replay recorded actions instead of the network output
			action = self.episodeRecorder.loadedAction[self.envStepCounter, :]

		if self.config.pretextRobotControl not in ('pointFollower', 'setPose'):
			raise NotImplementedError(
				f"pretextRobotControl {self.config.pretextRobotControl!r} is not supported; "
				f"expected 'pointFollower' or 'setPose'")
		act = self._pretext_sample_pose()  # [x, y, theta]

		real_action = self.robot.apply_action(act, controlMethod=self.config.pretextRobotControl)
		self.scene.global_step()
		self.envStepCounter += 1

		obs, _ = self.gen_obs()

		r = self.rewards()  # per-term rewards; the step reward is their sum
		self.reward = sum(r)
		self.episodeReward += self.reward
		self.done = self.termination()

		infoDict = {}
		if self.config.record:
			self.episodeRecorder.wheelVelList.append(real_action)
			self.episodeRecorder.actionList.append(list(action))

		return obs, self.reward, self.done, infoDict  # reset will be called if done
