import os
from models.pretext.ai2thor_pretext_model import RSI2PretextNet, representationImgLinear, representationSoundLinear
from dataset import RSI2Dataset, RSI2FineTuneDataset
import warnings
import numpy as np
from collections import OrderedDict
from RSI2.pretext_RSI2 import plotRepresentationRSI2


class AI2ThorConfig(object):
	"""Settings container for the AI2-THOR environment, the pretext task and the RL task.

	Every setting is a plain instance attribute assigned in ``__init__``; edit
	the values there to change an experiment. Creating an instance also prints
	a banner and runs a few sanity checks on the chosen values.
	"""

	def __init__(self):
		"""Assign every setting, then sanity-check the combination.

		Raises:
			Exception: if ``RLTrain`` and ``RLManualControl`` are both True.

		Warns (through ``warnings.warn``):
			* when any of ``RLTrain`` / ``pretextTrain`` / ``pretextCollection``
			  is on while ``soundSource['train_test']`` is ``'test'``;
			* when ``-1 < episodeImgSaveInterval < 5`` (the value 1 assigned
			  below falls inside that range, so this warning fires as shipped).
		"""
		self.name = type(self).__name__

		# ------------------------------------------------------------------
		# common env configuration
		# ------------------------------------------------------------------
		self.envFolder = 'ai2thor'
		self.img_dim = (3, 96, 96)  # (channel, image_height, image_width)
		self.sound_dim = (1, 600, 40)  # sound matrix dimension (1, frames, numFeat)

		self.commonMediaPath = os.path.join('..', 'commonMedia')

		# fluent speech dataset metadata
		sound_split = 'train'
		self.soundSource = {
			'dataset': 'FSC',
			'dataset_path': os.path.join(self.commonMediaPath, 'FSC'),
			'train_test': sound_split,
			'max_sound_dur': 6.0,
			'size': 1000,
			'obj_act': {
				'lights': ['activate', 'deactivate'],
				'music': ['activate', 'deactivate'],
				'lamp': ['activate', 'deactivate'],
				'shoes': ['bring'],
			},
			'locations': ['none'],
			# the csv name is derived from the split chosen above
			'csv': sound_split + '_data.csv',
		}

		self.allScene = {
			'livingRoom': [
				201, 202, 206, 207, 208, 209, 210,
				213, 214, 215, 216, 217, 218, 219, 220,
			],
		}

		# keyboard key -> action name; insertion order fixes the order of allActions
		key_bindings = (
			('w', 'MoveAhead'),
			('s', 'MoveBack'),
			('a', 'MoveLeft'),
			('d', 'MoveRight'),
			('q', 'RotateLeft'),
			('e', 'RotateRight'),
			('T', 'ToggleObjectOn'),
			('t', 'ToggleObjectOff'),
			('p', 'PickupObject'),
		)
		self.keyBoardMapping = OrderedDict(key_bindings)
		self.allActions = [action for action in self.keyBoardMapping.values()]

		# location -> object -> list of actions.
		# OrderedDict keeps the order so that task ids are the same every time.
		living_room_tasks = OrderedDict()
		living_room_tasks['FloorLamp'] = ['ToggleObjectOn', 'ToggleObjectOff']
		living_room_tasks['Television'] = ['ToggleObjectOn', 'ToggleObjectOff']
		living_room_tasks['Pillow'] = ['PickupObject']
		self.allTasks = OrderedDict()
		self.allTasks['livingRoom'] = living_room_tasks

		# one task per (location, object, action) triple
		self.taskNum = sum(
			len(actions)
			for objects in self.allTasks.values()
			for actions in objects.values()
		)

		# scene id -> step size. Large rooms use 0.5, small rooms use 0.25;
		# the larger step size makes sure that the robot can finish a task
		# within RLEnvMaxSteps.
		small_step_scenes = list(range(201, 221)) + list(range(226, 230))
		self.gridSize = {scene: 0.25 for scene in small_step_scenes}
		self.gridSize[230] = 0.5
		self.snapToGrid = False
		self.rotateStepDegrees = 45
		self.fieldOfView = 90

		# relations between ai2thor and the fluent speech dataset
		# key: ai2thor name, value: fluent speech dataset names
		self.synonym = {
			# location
			'kitchen': ['kitchen', 'none'],
			'livingRoom': ['none'],
			# object: when an object itself contains an action, it is written as object_action
			'FloorLamp': ['lights', 'lamp'],
			'Television': ['music'],
			'Microwave': ['heat_increase'],
			'Fridge': ['heat_decrease'],
			'Pillow': ['shoes'],
			# action
			'ToggleObjectOn': ['increase', 'activate'],
			'ToggleObjectOff': ['decrease', 'deactivate'],
			'PickupObject': ['bring'],
		}

		# ------------------------------------------------------------------
		# env control
		# ------------------------------------------------------------------
		self.domainRandomization = ['randomInitialPose', 'randomObjState']
		# Turn on a matplotlib window to show what the robot sees.
		# The Unity window is always on, so render is not decided by it.
		self.render = True
		# if True, show a 3rd camera view in a window when render is True (debugging aid)
		self.use3rdCam = False
		data_root = os.path.join('..', 'data')
		episode_record_root = os.path.join(data_root, 'episodeRecord')
		# output turtlebot camera images to this location
		self.episodeImgSaveDir = os.path.join(episode_record_root, 'tempImgs')
		# -1 for not saving; otherwise save an episode of camera images every this many episodes
		self.episodeImgSaveInterval = 1
		saved_img_side = 96 * 5
		self.episodeImgSize = (saved_img_side, saved_img_side, 3)  # (height, width, channel)
		self.RSI_ver = 2  # the version of robot sound interpretation

		# ------------------------------------------------------------------
		# pretext env configuration. For RSI1, it is pre training
		# ------------------------------------------------------------------
		self.pretextEnvName = 'ai2thor-pretext-v2'
		self.pretextEnvMaxSteps = 15  # the length of an episode
		self.pretextEnvSeed = 977
		# number of envs to collect data; a single env whenever rendering is on
		self.pretextNumEnvs = 1 if self.render else 4
		self.pretextVisibilityDistance = 100.0

		# ------------------------------------------------------------------
		# pretext task configuration
		# ------------------------------------------------------------------
		self.pretextManualControl = False
		self.pretextManualCollect = False
		self.pretextTrain = False  # collect data and perform training
		# if False, data is assumed to be successfully collected and saved on disk
		self.pretextCollection = False
		self.pretextModelFineTune = False  # if True, load from self.pretextModelLoadDir
		self.pretextDataDir = ['path/to/pretextDataDir']
		self.pretextDataFileLoadNum = ['all'] * 3
		# the dataset class follows the fine-tune flag set just above
		if self.pretextModelFineTune:
			self.pretextDataset = RSI2FineTuneDataset
		else:
			self.pretextDataset = RSI2Dataset
		self.pretextModel = RSI2PretextNet
		pretext_model_root = os.path.join(data_root, 'pretext_model')
		self.pretextModelSaveDir = os.path.join(
			pretext_model_root, 'ai2thor_20room_RSI2_5task_pillowShoes')
		self.pretextModelLoadDir = os.path.join(
			pretext_model_root, 'ai2thor_20room_RSI2_5task_pillowShoes_229', '39_RSI3Ver.pt')
		self.pretextModelSaveInterval = 10
		self.pretextCollectNum = [1000] * 6
		self.pretextDataNumWorkers = 8
		# the total number of episodes = pretextNumEnvs * pretextDataEpisode * pretextDataNumFiles
		self.pretextDataEpisode = 200  # each env will collect this number of episodes
		self.pretextDataNumFiles = 20  # this number of pickle files will be generated
		self.pretextTrainBatchSize = 128
		self.pretextTestBatchSize = 128
		self.pretextLR = 1e-4  # 4 objs: 1e-3, 3 objs: 1e-4
		self.pretextAdamL2 = 1e-6
		self.pretextLRStep = 'step'  # choose from ["cos", "step", "none"]
		self.pretextEpoch = 40
		self.pretextLRDecayEpoch = [20, 30]  # milestones for learning rate decay
		self.pretextLRDecayGamma = 0.2  # multiplicative factor of learning rate decay
		self.representationDim = 3
		self.tripletMargin = 1.0  # for mix or urbansound: 1.2; for others: 1.0
		self.pretextTrainLinear = False
		self.pretextLinearEpoch = 40
		self.pretextLinearLR = 5e-3
		self.pretextTestMethod = 'linear'
		self.pretextLinearImgModel = representationImgLinear
		self.pretextLinearSoundModel = representationSoundLinear
		self.pretextEmptyCenter = False
		# plot the representation space every this number of epochs, -1 for not plotting
		self.plotRepresentation = 50
		self.plotRepresentationOnly = True
		self.plotFunc = plotRepresentationRSI2
		self.plotNumBatch = 7
		self.annotateLastBatch = False
		# draw datapoints for images in episodeImgSaveDir or sound in commonMedia
		self.plotRepresentationExtra = False
		self.plotExtraPath = os.path.join(episode_record_root, 'extra')

		# ------------------------------------------------------------------
		# pretrain network
		# ------------------------------------------------------------------
		self.usePretrainedModel = False
		self.freezePretrainedModel = False
		self.trainPretrainedModel = False
		self.pretrainedUseExi = False
		self.pretrainedEpoch = 30
		self.pretrainedModelSaveDir = os.path.join(pretext_model_root, 'ai2thor_20room_RSI1')
		self.pretrainedModelLoadDir = os.path.join(
			pretext_model_root, 'ai2thor_20room_RSI1', '39.pt')
		self.pretrainedBatchSize = 128
		self.pretrainedLR = 2e-5
		self.pretrainedLRDecayEpoch = [25, 30]  # milestones for learning rate decay
		self.pretrainedLRDecayGamma = 0.2  # multiplicative factor of learning rate decay

		# ------------------------------------------------------------------
		# RL env configuration
		# ------------------------------------------------------------------
		# the max number of actions (decisions) for an episode. Time horizon N.
		self.RLEnvMaxSteps = 50
		self.RLEnvName = 'ai2thor-RL-v2'
		self.RLEnvSeed = 3099
		self.RLNumEnvs = 1 if self.render else 8  # a single env whenever rendering is on
		self.RLVisibilityDistance = 1.5
		self.RLVisibleGrid = 9  # should be an odd number
		# the observation names that will be ignored for RL training
		self.RLObsIgnore = {'current_sound', 'goal_sound', 'goal_sound_label'}
		self.RLActionDim = (len(self.allActions),)

		# ------------------------------------------------------------------
		# RL task configuration
		# ------------------------------------------------------------------
		self.RLManualControl = True
		self.RLTrain = False
		self.RLRealTimePlot = False
		rl_model_root = os.path.join(data_root, 'RL_model')
		self.RLLogDir = os.path.join(rl_model_root, 'ai2thor')
		self.calcMedoids = False
		self.RLTask = 'approach'
		# use numBatchMedoids * pretextTrainBatchSize to approximate medoids
		self.numBatchMedoids = 10
		self.RLPolicyBase = 'ai2thor_RSI2'
		self.RLGamma = 0.99
		self.RLRecurrentPolicy = True
		self.ppoClipParam = 0.2
		self.ppoEpoch = 4
		self.ppoNumMiniBatch = 2
		self.ppoValueLossCoef = 0.5
		self.ppoEntropyCoef = 0.01
		self.ppoUseGAE = True
		self.ppoGAELambda = 0.95
		self.RLLr = 6e-5
		self.RLEps = 1e-5
		self.RLMaxGradNorm = 0.5
		self.ppoNumSteps = self.RLEnvMaxSteps
		self.RLTotalSteps = 1e6
		self.RLModelSaveInterval = 200
		self.RLLogInterval = 100
		self.RLModelFineTune = True
		# use the dot product between goal sound and current sound as reward
		self.RLRewardSoundSound = False
		self.RLModelSaveDir = os.path.join(
			rl_model_root, 'ai2thor_20room_RSI2_5task_pillowShoes_230')
		self.RLModelLoadDir = os.path.join(
			rl_model_root, 'ai2thor_20room_RSI2_5task_pillowShoes_230', '02200.pt')
		self.RLUseProperTimeLimits = False
		self.RLRecurrentSize = 1024
		self.RLRecurrentInputSize = 128
		self.RLActionHiddenSize = 128
		self.pretextModelUpdateInterval = np.inf  # 1500

		# ------------------------------------------------------------------
		# test
		# ------------------------------------------------------------------
		self.success_threshold = 1

		# ------------------------------------------------------------------
		# checking configuration and output errors or warnings
		# ------------------------------------------------------------------
		print("######Configuration Checking######")
		conflicting_rl_modes = self.RLTrain and self.RLManualControl
		if conflicting_rl_modes:
			raise Exception('self.RLTrain and self.RLManualControl cannot be both True')

		training_or_collecting = self.RLTrain or self.pretextTrain or self.pretextCollection
		if training_or_collecting and self.soundSource['train_test'] == 'test':
			warnings.warn("You are using the test set for training")

		if self.episodeImgSaveInterval > -1 and self.episodeImgSaveInterval < 5:
			warnings.warn("You may save the episode image too frequently")
