from ..RSI2.RL_env_RSI2 import RLEnvRSI2
from ..robot_manipulators import *
import cv2
import sounddevice as sd


class PretextEnvRSI3(RLEnvRSI2):
	"""Pretext-task variant of the RSI2 environment.

	Actions passed to ``step`` are ignored: every step moves the end effector by
	a random planar displacement. Observations contain the camera image and a
	ground-truth label; the matching sound is paired with each observation
	according to that label later, when the representation is trained.
	"""

	def __init__(self):
		super().__init__()

		# Observation space. Only the image and the ground-truth label are
		# exposed; sounds are paired by label at representation-training time.
		# NOTE(review): gen_obs returns the image transposed to channel-first;
		# confirm config.img_dim is expressed in that same layout.
		d = {
			'image': gym.spaces.Box(low=0, high=255, shape=self.config.img_dim, dtype='uint8'),
			'ground_truth': gym.spaces.Box(low=0, high=self.config.taskNum + 1, shape=(1,), dtype=np.int32),
		}

		self.observation_space = gym.spaces.Dict(d)
		self.maxSteps = self.config.pretextEnvMaxSteps

	def gen_obs(self):
		"""Build the observation for the current simulator state.

		Side effects: the raw camera image is passed to ``saveEpisodeImage``.
		When ``config.render`` is set and a positive audio clip is available,
		that clip is played through sounddevice. ``sd.play`` does not block.

		Returns:
			tuple: ``(obs, s)``.
			``obs`` is a dict with two keys:
			``'image'`` is the camera image transposed with axes (2, 0, 1), i.e.
			channel-first for PyTorch convolutions, assuming ``get_image``
			returns HWC.
			``'ground_truth'`` is the label returned by ``get_positive_negative``.
			``s`` is ``self.robot.calc_state()``.
		"""
		image = self.robot.get_image()
		self.saveEpisodeImage(image)

		s = self.robot.calc_state()

		# Only the label (and, for rendering, the positive clip) is needed
		# here; the positive/negative sound features are not used.
		_, _, ground_truth, positive_audio = self.get_positive_negative(
			get_negative=False, generate_audio=False)
		if self.config.render and positive_audio is not None:
			sd.play(positive_audio, self.audio.fs)

		obs = {
			'image': np.transpose(image, (2, 0, 1)),  # HWC -> CHW for PyTorch convolution
			'ground_truth': ground_truth,
		}

		return obs, s

	def step(self, action):
		"""Advance the simulation by one pretext step.

		``action`` is accepted only for Gym API compatibility and is ignored.
		The end effector is instead displaced by a random ``dx`` drawn from
		``uniform(-0.3, 0.3)`` and a random ``dy`` from ``uniform(-0.4, 0.4)``.
		``dz`` is always 0. The displacement is applied with
		``config.pretextRobotControl``.

		Returns:
			tuple: ``(obs, reward, done, info)``. ``info`` is always an empty
			dict. The caller is expected to call ``reset`` once ``done`` is true.
		"""
		# Random exploration: the policy's action is deliberately not used.
		# dx is drawn before dy so the RNG stream matches the original order.
		dx = self.np_random.uniform(-0.3, 0.3)
		dy = self.np_random.uniform(-0.4, 0.4)
		dz = 0.0
		self.robot.applyAction([dx, dy, dz], controlMethod=self.config.pretextRobotControl)

		self.scene.global_step()
		self.envStepCounter += 1

		obs, s = self.gen_obs()

		# The reward terms are summed; extend this list to add shaping terms.
		reward_terms = [self.rewards()]
		self.reward = sum(reward_terms)
		self.episodeReward = self.episodeReward + self.reward
		self.done = self.termination(s)

		return obs, self.reward, self.done, {}  # reset will be called if done
