import os
import numpy as np
import cv2
import sounddevice as sd
import matplotlib.pyplot as plt
import gym
from ..RSI2.RL_env_RSI2 import RLEnvRSI2, Task
import warnings
from ai2thor.controller import Controller


class PretextEnvRSI1(RLEnvRSI2):
	"""Pretext-task environment built on top of RLEnvRSI2.

	The agent's action is ignored: each step the view is changed either by
	keyboard control or by a random teleport (see ``step``). Observations report
	which task-relevant objects are currently visible and whether the goal
	object is one of them.
	"""

	def __init__(self):
		RLEnvRSI2.__init__(self)

		# update observation space
		# NOTE(review): these integer Boxes use +/-inf bounds; how gym converts
		# inf to an int32 bound depends on the installed gym version — confirm.
		d = {
			'image': self.observation_space['image'],
			'goal_sound_label': gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.config.taskNum,), dtype=np.int32),
			# the ground_truth is the id of possible inSight result
			'ground_truth': gym.spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.int32),
			# zero-one encoding indicating which objects are being seen
			'inSight': gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.config.taskNum,), dtype=np.int32),
			# whether the goal object is being seen
			'exi': gym.spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.int32),
		}

		self.observation_space = gym.spaces.Dict(d)
		self.maxSteps = self.config.pretextEnvMaxSteps

		# NOTE(review): these are assigned after the parent constructor has run;
		# if the parent reads them while setting up the controller, these values
		# do not take effect there — confirm.
		self.visibleDist = self.config.pretextVisibilityDistance
		self.renderInstanceSegmentation = False  # for generating bounding box

	def setupTask(self):
		"""Apply domain randomization; for 'PickupObject' tasks also call pickUpByTask."""
		self.domainRandomization()
		if self.task.act == 'PickupObject':
			self.pickUpByTask(self.task)

	def _inSightEncoding(self):
		"""Build the zero-one visibility vector for the current location's tasks.

		Objects are visited in the iteration order of
		``config.allTasks[task.loc]``:

		- 'FloorLamp' / 'Television' contribute one entry per configured task.
		  When the object is visible, index 0 is set if it is toggled on and
		  index 1 otherwise. This assumes at least two tasks are configured for
		  these objects.
		- 'Pillow' contributes a single entry, set when the object is visible.

		:return: 1-D integer numpy array of 0/1 flags.
		:raises NotImplementedError: for any other object in the location.
		"""
		locTasks = self.config.allTasks[self.task.loc]
		flags = []
		for obj in locTasks:
			if obj in ('FloorLamp', 'Television'):
				entries = [0] * len(locTasks[obj])
				if self.visibility[obj]:
					entries[0 if self.objMeta[obj]["isToggled"] else 1] = 1
				flags.extend(entries)
			elif obj == 'Pillow':
				flags.append(1 if self.visibility[obj] else 0)
			else:
				raise NotImplementedError(
					f"inSight encoding is not defined for object {obj!r}")
		# explicit int dtype so an empty encoding is not float64
		return np.array(flags, dtype=int)

	def gen_obs(self):
		"""Collect the observation for the current step.

		Refreshes object metadata and visibility, then saves the current frame.
		On the first step of an episode (``envStepCounter == 0``) it fetches the
		goal audio, and plays it when ``config.render`` is set.

		:return: ``(obs, None, None)``. The trailing ``None`` values are
			presumably placeholders matching the parent's interface — confirm.
			``obs`` is a dict with these keys:

			- ``image``: the current frame transposed with axes (2, 0, 1).
			- ``goal_sound_label``: one-hot vector of length ``config.taskNum``
			  marking the current task's id.
			- ``ground_truth``: index of a randomly chosen set entry of
			  ``inSight``, or ``config.taskNum`` when nothing is in sight.
			- ``inSight``: zero-one visibility vector (see ``_inSightEncoding``).
			- ``exi``: 0-d array, 1 if the goal object is visible, else 0.
		"""
		# update object metadata and visibility before they are read below
		self.updateObjMeta(list(self.objMeta.keys()))
		self.checkVisible()

		rgb_image = self.controller.last_event.frame
		self.saveEpisodeImage(rgb_image)

		inSight = self._inSightEncoding()
		visibleIdx = np.flatnonzero(inSight == 1)
		if visibleIdx.size == 0:
			gt = self.config.taskNum  # sentinel id: nothing relevant in sight
		else:
			# this draw happens before the goal-audio draw below; keep the order
			# so seeded runs stay reproducible
			gt = int(self.np_random.choice(visibleIdx))

		exi = 1 if self.visibility[self.task.obj] else 0

		if self.envStepCounter == 0:
			self.goal_sound, self.goal_audio, self.transcription = self.audio.getAudioFromTask(self.np_random, self.task, Task)
			if self.goal_audio is not None and self.config.render:
				sd.play(self.goal_audio, self.audio.fs)
				print('Goal intent is', self.task.loc + ' ' + self.task.obj + ' ' + self.task.act)

		# ndarray (not list) for consistency with 'inSight'/'exi' and the declared Box
		goal_sound_label = np.zeros(self.config.taskNum, dtype=int)
		goal_sound_label[int(self.task2ID[self.task])] = 1

		obs = {
			'image': np.transpose(rgb_image, (2, 0, 1)),
			'goal_sound_label': goal_sound_label,
			'ground_truth': gt,
			'inSight': inSight,
			'exi': np.array(exi),
		}

		return obs, None, None

	def step(self, action):
		"""Advance the environment by one step.

		:param action: ignored. The view is changed by ``keyboardControl`` when
			``config.pretextManualControl`` is set, otherwise by
			``randomTeleport``.
		:return: ``(obs, reward, done, info)``.

			- ``reward``: sum of ``self.rewards()``.
			- ``done``: result of ``self.termination()``.
			- ``info``: an empty dict.
		"""
		if self.config.pretextManualControl:
			self.keyboardControl()
		else:
			self.randomTeleport()

		# fix the design choice that images from the Unity window lag by 1 step
		self.controller.step("Pass")

		self.envStepCounter += 1
		obs, _, _ = self.gen_obs()
		self.reward = sum(self.rewards())
		self.episodeReward = self.episodeReward + self.reward
		self.done = self.termination()

		return obs, self.reward, self.done, {}  # reset will be called if done
