import os
import numpy as np
import cv2
import sounddevice as sd
import matplotlib.pyplot as plt
import gym
from .RL_env_RSI2 import RLEnvRSI2, Task
import pickle
from datetime import datetime


class PretextEnvRSI2(RLEnvRSI2):

	def __init__(self):
		"""
		Initialise the base RSI2 environment, then set the observation space and
		the episode settings used by the pretext environment.

		The observation space is a gym Dict with three entries: 'image',
		'sound_negative_id' and 'ground_truth'.
		"""
		RLEnvRSI2.__init__(self)

		cfg = self.config

		# update observation space
		image_space = gym.spaces.Box(low=0, high=255, shape=cfg.img_dim, dtype='uint8')
		# both ID entries are length-1 int32 boxes bounded by [0, taskNum+1]
		negative_id_space = gym.spaces.Box(low=0, high=cfg.taskNum+1, shape=(1,), dtype=np.int32)
		# sound positive label
		ground_truth_space = gym.spaces.Box(low=0, high=cfg.taskNum+1, shape=(1,), dtype=np.int32)

		self.observation_space = gym.spaces.Dict({
			'image': image_space,
			'sound_negative_id': negative_id_space,
			'ground_truth': ground_truth_space,
		})

		# episode length and visibility distance come from the pretext-specific config entries
		self.maxSteps = cfg.pretextEnvMaxSteps
		self.visibleDist = cfg.pretextVisibilityDistance

	def setupTask(self):
		"""
		Run domain randomization, then hand the current task to self.pickUpByTask
		when its action is 'PickupObject'.
		"""
		self.domainRandomization()

		current_task = self.task
		if current_task.act != 'PickupObject':
			# nothing else to prepare for other task types
			return
		self.pickUpByTask(current_task)

	def get_pos_act(self, obj_in_view):
		"""
		Choose the action of the positive task for the object in view.

		:param obj_in_view: key of the object, used to index self.config.allTasks[self.task.loc] and self.objMeta
		:return: the only configured action when exactly one exists; otherwise 'ToggleObjectOn' or 'ToggleObjectOff', matching the object's current "isToggled" state
		"""
		candidate_acts = self.config.allTasks[self.task.loc][obj_in_view]
		if len(candidate_acts) == 1:
			return candidate_acts[0]

		# check the current state of the obj_in_view and choose the same
		return 'ToggleObjectOn' if self.objMeta[obj_in_view]["isToggled"] else 'ToggleObjectOff'

	def get_positive_negative(self, get_negative, generate_audio):
		"""
		Determine the positive task for the current view and, when requested, its
		audio and a negative task ID.

		The positive task is resolved in this order:
		1. the first object in the agent's inventory (action 'PickupObject');
		2. otherwise the object in view, if exactly one object other than "Pillow" is visible;
		3. otherwise there is no positive task and ground_truth is self.config.taskNum.

		:param get_negative: in RSI2 pretext env, we should get negatives, which are not needed in RSI3
		:param generate_audio: audio is not needed for pretext envs (audios will be paired in dataset.py)
		:return: tuple (sound_positive, neg_taskID, ground_truth, positive_audio); sound_positive, neg_taskID and positive_audio stay None when they are not produced
		"""
		# every object except "Pillow" that is currently flagged visible
		seen = [name for name in self.visibility if name != "Pillow" and self.visibility[name]]

		held = self.controller.last_event.metadata['inventoryObjects']
		if len(held) != 0:
			# a held object takes priority over whatever is in view
			target = Task(loc=self.task.loc, obj=held[0]['objectType'], act='PickupObject')
		elif len(seen) == 1:
			# the agent sees an object
			only_obj = seen[0]
			chosen_act = self.get_pos_act(only_obj)
			target = Task(loc=self.task.loc, obj=only_obj, act=chosen_act)
		else:
			# zero or more than one object in view: no positive task
			target = None

		sound_positive = None
		positive_audio = None
		neg_taskID = None
		has_no_target = target is None

		if has_no_target:
			# no sound is given; the label is the extra ID self.config.taskNum
			ground_truth = np.int32(self.config.taskNum)
			if generate_audio:
				sound_positive = np.zeros(shape=self.config.sound_dim)
		else:
			ground_truth = np.int32(self.task2ID[target])
			# audio is also produced when rendering, even if the caller did not request it
			if generate_audio or self.config.render:
				sound_positive, positive_audio, _ = self.audio.getAudioFromTask(self.np_random, target, Task)

		if get_negative:
			neg_taskID = self.get_negatives(empty=has_no_target, ground_truth=ground_truth)

		return sound_positive, neg_taskID, ground_truth, positive_audio

	def gen_obs(self):
		"""
		Refresh object metadata and visibility, then assemble the observation for
		the latest frame.

		:return: tuple (obs, 0). obs is a dict with 'image' (the last frame with axes reordered by (2, 0, 1)), 'sound_negative_id' (length-1 array holding the negative task ID) and 'ground_truth' (length-1 array holding the positive label)
		"""
		# update object metadata
		tracked_objects = list(self.objMeta.keys())
		self.updateObjMeta(tracked_objects)
		self.checkVisible()

		if self.config.render:
			# the local occupancy map is only refreshed when rendering
			self.agentMeta = self.controller.last_event.metadata["agent"]
			agent_pos = self.agentMeta['position']
			agent_rot = self.agentMeta['rotation']
			self.local_occupancy = self.get_local_occupancy_map(
				x=agent_pos['x'],
				z=agent_pos['z'],
				y=agent_rot['y'],
			)

		frame = self.controller.last_event.frame
		self.saveEpisodeImage(frame)

		# the positive sound feature is not part of the observation, so it is discarded here
		_, neg_taskID, ground_truth, positive_audio = self.get_positive_negative(
			get_negative=True,
			generate_audio=False,
		)
		if self.config.render:
			if positive_audio is not None:
				sd.play(positive_audio, self.audio.fs)

		# Observations are dictionaries containing:
		# - an image (partially observable view of the environment)
		# - the negative task ID and the ground-truth (positive) label
		obs = {}
		obs['image'] = np.transpose(frame, (2, 0, 1))  # for PyTorch convolution
		obs['sound_negative_id'] = np.array([neg_taskID])
		obs['ground_truth'] = np.array([ground_truth])

		return obs, 0

	def step(self, action):
		"""
		Advance the environment by one step and return the new observation.

		The agent is moved by self.keyboardControl() when config.pretextManualControl
		or config.pretextManualCollect is set, and by self.randomTeleport() otherwise.
		The `action` argument is converted to a numpy array but is not otherwise
		used inside this method.

		:param action: action supplied by the caller; it does not drive the agent in this method
		:return: tuple (obs, reward, done, infoDict): obs is the dict from self.gen_obs(), reward is sum(self.rewards()), done is self.termination(), and infoDict is an empty dict
		"""
		action=np.array(action)
		infoDict = {}
		# k holds whatever keyboardControl() returns; it stays None in the random-teleport mode
		k=None

		if self.config.pretextManualControl or self.config.pretextManualCollect:
			k=self.keyboardControl()
		else:
			self.randomTeleport()

		self.controller.step("Pass") # fix the design choice that images from the Unity window lag by 1 step

		# update counters
		self.envStepCounter = self.envStepCounter + 1
		# get new obs (the second value returned by gen_obs is discarded)
		obs,_ = self.gen_obs()

		if self.config.use3rdCam:
			self.update3rdCam("Update")

		# NOTE(review): 'r' and 'z' are presumably key names returned by keyboardControl() - confirm there
		if k=='r': # save this triplet to buffer
			self.saved_pairs.append(obs)
			print("Number of triplets collected", len(self.saved_pairs))
		elif k=='z': # save collected triplets in the buffer to disk
			self.saveManualPairs()
			print("Triplets saved to", self.config.pretextDataDir[0])

		# self.rewards() returns an iterable of reward terms; the step reward is their sum
		r =self.rewards() # calculate reward
		self.reward = sum(r)
		self.episodeReward = self.episodeReward + self.reward
		self.done = self.termination()

		return obs, self.reward, self.done, infoDict # reset will be called if done
