import os
import numpy as np
import cv2
import sounddevice as sd
import gym
from ..RSI2.RL_env_RSI2 import RLEnvRSI2


class PretextEnvRSI3(RLEnvRSI2):
	"""
	Pretext-task environment built on top of RLEnvRSI2.

	Observations are dicts holding the robot camera image and a
	'ground_truth' value; on every step the robot is sent to a pose that is
	either typed in by the user or sampled at random until it is feasible.
	"""

	def __init__(self):
		RLEnvRSI2.__init__(self)

		# Replace the parent's observation space with an image + label dict.
		self.observation_space = gym.spaces.Dict({
			'image': gym.spaces.Box(low=0, high=255, shape=self.config.img_dim, dtype='uint8'),
			'ground_truth': gym.spaces.Box(low=0, high=self.config.taskNum+1, shape=(1,), dtype=np.int32),
		})
		self.maxSteps = self.config.pretextEnvMaxSteps

		# Symmetric action space: every component lies in [-1, 1].
		bound = np.ones(self.config.pretextActionDim)
		self.action_space = gym.spaces.Box(low=-bound, high=bound, dtype=np.float32)

	def gen_obs(self):
		"""
		Build the current observation.

		Side effects: the camera image is handed to saveEpisodeImage, and when
		rendering is enabled and an audio clip is available it is played.

		:return: tuple (obs, s) where obs is a dict with keys 'image' and
			'ground_truth', and s is the result of robot.calc_state
		"""
		image = self.robot.get_image()
		self.saveEpisodeImage(image)

		s = self.robot.calc_state(self.objPoseList, self.objList, self.objInScene)

		# One flag per object uid: is it hit by any ray other than the first and last one?
		rayTest = self.robot.ray_test(self.objUidList)
		hitByRay = [any((rayTest == uid)[1:-1]) for uid in self.objUidList]

		# Distance to every object in view; objects that are not seen stay at infinity.
		inViewList = [np.inf] * self.config.taskNum
		for idx in range(len(self.objList)):
			if not hitByRay[idx]:
				continue
			dist = s['dist'][idx]
			if 0 <= dist:
				inViewList[idx] = dist

		# Only the ground truth and the (optional) audio clip are needed here.
		_, _, ground_truth, positive_audio = self.get_positive_negative(
			inViewList,
			get_negative=False,
			generate_audio=False,
		)

		if self.config.render and positive_audio is not None:
			sd.play(positive_audio, self.audio.fs)

		obs = {
			# axes reordered for PyTorch convolution (presumably HWC -> CHW)
			'image': np.transpose(image, axes=(2, 0, 1)),
			'ground_truth': ground_truth,
		}
		return obs, s

	def step(self, action):
		"""
		Advance the environment by one step.

		The incoming action does not drive the robot: the applied pose comes
		from manual input or random sampling. The action (or the loaded one
		that replaces it when config.record and config.loadAction are set) is
		only stored in the episode recorder when recording is enabled.

		:param action: array-like action, converted with np.array
		:return: tuple (obs, reward, done, info) where info is an empty dict
		:raises NotImplementedError: if config.pretextRobotControl is neither
			'pointFollower' nor 'setPose'
		"""
		action = np.array(action)
		if self.config.record and self.config.loadAction:
			# replace network output with the pre-loaded action for this step
			action = self.episodeRecorder.loadedAction[self.envStepCounter, :]

		if self.config.pretextRobotControl not in ('pointFollower', 'setPose'):
			raise NotImplementedError

		# Target pose is [x, y, theta].
		xyLimit = self.config.xyMax
		if self.config.pretextManualControl:
			x = np.clip(float(input('desired x:')), -xyLimit, xyLimit)
			y = np.clip(float(input('desired y:')), -xyLimit, xyLimit)
			theta = np.clip(float(input('desired theta:')), -np.pi, np.pi)
		else:
			# Keep drawing poses until one is feasible (there is no attempt limit).
			while True:
				x, y = self.np_random.uniform(low=-xyLimit, high=xyLimit, size=(2,))
				theta = self.np_random.uniform(low=-np.pi, high=np.pi)
				if self.isFeasiblePose(new_obj=('robot', [x, y, theta]), currentPlan=self.objPoseList,
						objList=self.objList, objInScene=self.objInScene):
					break

		real_action = self.robot.apply_action([x, y, theta], controlMethod=self.config.pretextRobotControl)
		self.scene.global_step()
		self.envStepCounter += 1

		obs, _ = self.gen_obs()

		# Total reward is the sum of the individual reward terms.
		self.reward = sum(self.rewards())
		self.episodeReward = self.episodeReward + self.reward
		self.done = self.termination()

		if self.config.record:
			self.episodeRecorder.wheelVelList.append(real_action)
			self.episodeRecorder.actionList.append(list(action))

		return obs, self.reward, self.done, {}
