from .RL_env_RSI2 import RLEnvRSI2
from ..robot_manipulators import *
import cv2
import sounddevice as sd


class PretextEnvRSI2(RLEnvRSI2):
	"""Pretext-task variant of ``RLEnvRSI2``.

	Observations are a dict holding the camera image, a "positive" and a
	"negative" sound feature and a ground-truth entry (the last three come
	from ``self.get_positive_negative()``, defined outside this class).
	The action passed to :meth:`step` is ignored: every step moves the robot
	by a random planar offset drawn from ``self.np_random``. Episodes end
	after ``config.pretextEnvMaxSteps`` steps and the reward is always 0.
	"""

	def __init__(self):
		RLEnvRSI2.__init__(self)

		# Observation space: keep the parent's 'image' and 'ground_truth'
		# spaces and add two unbounded float32 sound spaces of shape
		# ``config.sound_dim``.
		d = {
			'image': self.observation_space['image'],
			'sound_positive': gym.spaces.Box(low=-np.inf, high=np.inf, shape=self.config.sound_dim, dtype=np.float32),
			'sound_negative': gym.spaces.Box(low=-np.inf, high=np.inf, shape=self.config.sound_dim, dtype=np.float32),
			'ground_truth': self.observation_space['ground_truth'],
		}
		self.observation_space = gym.spaces.Dict(d)
		self.maxSteps = self.config.pretextEnvMaxSteps

		# Action space: a [-1, 1] box of size ``config.pretextActionDim``.
		# NOTE(review): step() ignores the action; the space is only declared
		# to satisfy the gym interface.
		high = np.ones(self.config.pretextActionDim)
		self.action_space = gym.spaces.Box(-high, high, dtype=np.float32)

	def gen_obs(self):
		"""Build the observation for the current simulator state.

		Every ``config.episodeImgSaveInterval``-th episode (when the interval
		is positive) the camera frame is also written as a JPEG to
		``config.episodeImgSaveDir``. When ``config.render`` is set and a
		positive audio clip is available, that clip is played.

		Returns:
			tuple: ``(obs, s)``. ``obs`` is a dict with keys 'image',
			'sound_positive', 'sound_negative' and 'ground_truth'. ``s`` is
			the result of ``self.robot.calc_state()``.
		"""
		image = self.robot.get_image()

		save_interval = self.config.episodeImgSaveInterval
		if save_interval > 0 and self.episodeCounter % save_interval == 0:
			# cv2.resize takes dsize as (width, height), so episodeImgSize is
			# read as (height, width, channels).
			img_size = self.config.episodeImgSize
			img_save = cv2.resize(image, (img_size[1], img_size[0]))
			if img_size[2] == 3:
				# cv2.imwrite expects BGR channel order.
				img_save = cv2.cvtColor(img_save, cv2.COLOR_RGB2BGR)
			file_name = str(self.givenSeed) + 'out' + str(self.envStepCounter) + '.jpg'
			cv2.imwrite(os.path.join(self.config.episodeImgSaveDir, file_name), img_save)

		s = self.robot.calc_state()

		sound_positive, sound_negative, ground_truth, positive_audio = self.get_positive_negative()
		if self.config.render and positive_audio is not None:
			# presumably self.audio.fs is the clip's sample rate — confirm
			sd.play(positive_audio, self.audio.fs)

		obs = {
			# Move the last axis first for PyTorch convolutions
			# (assumes get_image() returns HWC — TODO confirm).
			'image': np.transpose(image, (2, 0, 1)),
			'sound_positive': sound_positive,
			'sound_negative': sound_negative,
			'ground_truth': ground_truth,
		}
		return obs, s

	def step(self, action):
		"""Advance the simulation by one step using a random robot motion.

		Args:
			action: Ignored. Kept only for the gym ``step`` signature; the
				robot is instead moved by a random offset.

		Returns:
			tuple: ``(obs, reward, done, info)`` in the classic gym layout.
			``info`` is always an empty dict.
		"""
		# Random planar displacement. The draw order (x, then y) determines
		# the sequence produced by a seeded ``self.np_random``.
		dx = self.np_random.uniform(-0.3, 0.3)
		dy = self.np_random.uniform(-0.4, 0.4)
		dz = 0.0
		self.robot.applyAction([dx, dy, dz], controlMethod=self.config.pretextRobotControl)

		self.scene.global_step()
		self.envStepCounter = self.envStepCounter + 1

		obs, s = self.gen_obs()

		self.reward = self.rewards()
		self.episodeReward = self.episodeReward + self.reward
		self.done = self.termination(s)

		return obs, self.reward, self.done, {}  # reset will be called if done

	def termination(self, s):
		"""Return True once the episode has run ``self.maxSteps`` steps.

		Args:
			s: Robot state from ``gen_obs``. Unused here; kept for interface
				compatibility with the parent environment.
		"""
		return bool(self.envStepCounter >= self.maxSteps)

	def rewards(self):
		"""Return the per-step reward, which is always 0 for the pretext task."""
		return 0
