import os
import numpy as np
import cv2
import sounddevice as sd
import matplotlib.pyplot as plt
import gym
from ..RSI2.RL_env_RSI2 import RLEnvRSI2, Task


class PretextEnvRSI3(RLEnvRSI2):
	def __init__(self):
		"""Initialise the parent environment, then override the pretext settings.

		After RLEnvRSI2.__init__ has run, this replaces the observation space
		with a Dict of two Box spaces ('image' and 'ground_truth') and
		overwrites maxSteps, visibleDist and renderInstanceSegmentation.
		"""
		RLEnvRSI2.__init__(self)

		cfg = self.config

		# Raw camera frame: values in [0, 255], shape taken from the config.
		image_space = gym.spaces.Box(low=0, high=255, shape=cfg.img_dim, dtype='uint8')
		# Integer label; the sound is paired with it later, when the
		# representation is trained.
		# NOTE(review): the largest label emitted by get_positive is taskNum,
		# so the upper bound taskNum+1 leaves one unused value -- confirm intent.
		label_space = gym.spaces.Box(low=0, high=cfg.taskNum+1, shape=(1,), dtype=np.int32)

		# Collection protocol as described by the original author (the
		# bounding-box / extra-sample handling is not done in this constructor):
		#   - no object seen:  image, inf bounding box, empty label
		#   - one object seen: image, bounding box, positive sound
		#   - two objects seen: three samples
		#       1st: image, inf bounding box (no bounding box), empty label
		#            -> passed as the observation
		#       2nd / 3rd: image of each object, bounding box, positive sound
		#            -> stored in info
		#     all of them are collected in pretext_RSI3
		self.observation_space = gym.spaces.Dict({
			'image': image_space,
			'ground_truth': label_space,
		})

		self.maxSteps = cfg.pretextEnvMaxSteps
		self.visibleDist = cfg.pretextVisibilityDistance
		# instance segmentation would be needed for generating bounding boxes
		self.renderInstanceSegmentation = False

	def setupTask(self):
		self.domainRandomization()
		if self.task.act=='PickupObject':
			self.pickUpByTask(self.task)

	def get_pos_act(self, obj_in_view):
		if len(self.config.allTasks[self.task.loc][obj_in_view]) == 1:
				act = self.config.allTasks[self.task.loc][obj_in_view][0]
		else:
			# check the current state of the obj_in_view and choose the same
			if self.objMeta[obj_in_view]["isToggled"]:
				act = 'ToggleObjectOn'
			else:
				act = 'ToggleObjectOff'
		return act

	def get_positive(self):
		positive_audio = None

		num_visible=0
		obj_in_view=None
		for k in self.visibility:
			if k != "Pillow":
				if self.visibility[k]:
					num_visible=num_visible+1
					obj_in_view=k

		inventory = self.controller.last_event.metadata['inventoryObjects']
		if len(inventory) != 0:
			pos_tsk = Task(loc=self.task.loc, obj=inventory[0]['objectType'], act='PickupObject')
			ground_truth = np.int32(self.task2ID[pos_tsk])
			if self.config.render:
				_, positive_audio, _ = self.audio.getAudioFromTask(self.np_random, pos_tsk, Task)

		else:
			if num_visible!=1:
				# the agent sees nothing, no sound is given
				pos_tsk=None
				ground_truth = np.int32(self.config.taskNum)

			else:  # the agent sees an object
				act=self.get_pos_act(obj_in_view)
				pos_tsk = Task(loc=self.task.loc, obj=obj_in_view, act=act)
				ground_truth = np.int32(self.task2ID[pos_tsk])
				if self.config.render:
					_, positive_audio,_ = self.audio.getAudioFromTask(self.np_random, pos_tsk, Task)

		return ground_truth, positive_audio

	def gen_obs(self):
		"""
		:return: a dict containing various type of observations
		"""
		# update object metadata
		self.updateObjMeta(list(self.objMeta.keys())) 
		self.checkVisible()

		if self.config.render:
			self.agentMeta = self.controller.last_event.metadata["agent"]
			self.local_occupancy = self.get_local_occupancy_map(x=self.agentMeta['position']['x'],
																z=self.agentMeta['position']['z'],
																y=self.agentMeta['rotation']['y'])

		image=self.controller.last_event.frame
		self.saveEpisodeImage(image)

		ground_truth, positive_audio=self.get_positive()

		if self.config.render and positive_audio is not None:
			sd.play(positive_audio, self.audio.fs)

		# Observations are dictionaries containing:
		# - an image (partially observable view of the environment)
		# - a sound mfcc feature (the name of the object if the agent sees it or empty)

		obs = {
			'image': np.transpose(image, (2, 0, 1)), # for PyTorch convolution,
			'ground_truth': np.array([ground_truth]),
		}

		return obs, 0

	def step(self, action):
		action=np.array(action)
		infoDict = {}
		k=None

		if self.config.pretextManualControl or self.config.pretextManualCollect:
			k=self.keyboardControl()
		else:
			self.randomTeleport()

		self.controller.step("Pass") # fix the design choice that images from the Unity window lag by 1 step

		# update counters
		self.envStepCounter = self.envStepCounter + 1
		# get new obs
		obs,_ = self.gen_obs()

		if self.config.use3rdCam:
			self.update3rdCam("Update")

		if k=='r': # save this triplet to buffer
			self.saved_pairs.append(obs)
			print("Number of triplets collected", len(self.saved_pairs))
		elif k=='z': # save collected triplets in the buffer to disk
			self.saveManualPairs()
			print("Triplets saved to", self.config.pretextDataDir[0])

		r =self.rewards() # calculate reward
		self.reward = sum(r)
		self.episodeReward = self.episodeReward + self.reward
		self.done = self.termination()

		return obs, self.reward, self.done, infoDict # reset will be called if done
