import os
from models.pretext.ai2thor_pretext_model import RSI3PretextNet, representationImgLinear, representationSoundLinear
from dataset import RSI3Dataset, RSI3FineTuneDataset
from RSI3.pretext_RSI3 import plotRepresentationRSI3
import warnings
import numpy as np
from collections import OrderedDict


class AI2ThorConfig(object):
	"""Settings for the AI2-THOR robot sound interpretation (RSI) experiments.

	Every setting is a plain instance attribute assigned in ``__init__``. The
	groups are: common env settings, fluent speech (FSC) dataset metadata,
	AI2-THOR scenes/actions/tasks, pretext (representation) env and training
	settings, pretrained-network settings, RL env and training settings,
	representation-update settings and test settings. ``__init__`` finishes by
	printing a banner and checking the settings for contradictions (raises)
	and suspicious values (warns).
	"""

	def __init__(self):
		"""Populate all env/training settings and validate them.

		Raises:
			ValueError: if ``RLTrain`` and ``RLManualControl`` are both True.

		Warns (via ``warnings.warn``) when training/collection uses the FSC
		test split, when episode images would be saved very frequently, and,
		when ``hierarchy`` is enabled, when ``taskID2Skill`` does not map every
		task to an existing entry of ``skillInfos``.
		"""
		self.name = self.__class__.__name__

		# ---- common env configuration ----
		self.envFolder = 'ai2thor'
		self.img_dim = (3, 96, 96)  # (channel, image_height, image_width)
		self.sound_dim = (1, 600, 40)  # sound matrix dimension (1, frames, numFeat)
		self.commonMediaPath = os.path.join('..', 'commonMedia')

		# ---- fluent speech dataset (FSC) metadata ----
		self.soundSource = {
			'dataset': 'FSC',
			'dataset_path': os.path.join(self.commonMediaPath, 'FSC'),
			'train_test': 'test',
			'max_sound_dur': 6.,
			'size': 50,
			# FSC object names and the FSC actions paired with each of them
			'obj_act': {
				'lights': ['activate', 'deactivate'],
				'music': ['activate', 'deactivate'],
				'lamp': ['activate', 'deactivate'],
				'shoes': ['bring'],
			},
			'locations': ['none'],
		}
		self.soundSource['csv'] = self.soundSource['train_test'] + '_data.csv'
		self.allScene = {'livingRoom': [226]}
		# keyboard key -> AI2-THOR action; the order of the values defines
		# allActions and therefore RLActionDim below
		self.keyBoardMapping = OrderedDict([
			('w', "MoveAhead"), ('s', 'MoveBack'), ('a', 'MoveLeft'), ('d', 'MoveRight'),
			('q', "RotateLeft"), ('e', "RotateRight"),
			('T', "ToggleObjectOn"), ('t', "ToggleObjectOff"),
			('p', "PickupObject"),
		])
		self.allActions = list(self.keyBoardMapping.values())

		# we use OrderedDict to maintain the order so that task id will be the same every time
		# location -> object -> list of actions; each (object, action) pair is one task
		self.allTasks = OrderedDict([
			('livingRoom', OrderedDict([
				('FloorLamp', ['ToggleObjectOn', 'ToggleObjectOff']),
				('Television', ['ToggleObjectOn', 'ToggleObjectOff']),
				('Pillow', ['PickupObject']),
			])),
		])
		# total number of (object, action) tasks over all locations
		self.taskNum = sum(len(actions)
						   for objects in self.allTasks.values()
						   for actions in objects.values())

		# the step size for large rooms: 0.5, the step size for the small rooms: 0.25
		# the larger step size makes sure that the robot can finish a task within RLEnvMaxSteps
		self.gridSize = {201: 0.25, 202: 0.25, 203: 0.25, 204: 0.25, 205: 0.25, 206: 0.25,
						 207: 0.25, 208: 0.25, 209: 0.25, 210: 0.25, 211: 0.25, 212: 0.25,
						 213: 0.25, 214: 0.25, 215: 0.25, 216: 0.25, 217: 0.25, 218: 0.25,
						 219: 0.25, 220: 0.25, 226: 0.25, 227: 0.25, 228: 0.25, 229: 0.25, 230: 0.5}
		self.snapToGrid = False
		self.rotateStepDegrees = 45
		self.fieldOfView = 90

		# a dict that defines the relations between ai2thor and fluent speech dataset
		# key: ai2thor name, val: list of fluent speech dataset names
		self.synonym = {
			# location
			'kitchen': ['kitchen', 'none'], 'livingRoom': ['none'],
			# object: when an object itself contains action, we write as object_action
			'FloorLamp': ['lights', 'lamp'], 'Television': ['music'],
			'Microwave': ['heat_increase'], 'Fridge': ['heat_decrease'],
			'Pillow': ['shoes'],
			# action
			'ToggleObjectOn': ['increase', 'activate'], 'ToggleObjectOff': ['decrease', 'deactivate'],
			'PickupObject': ['bring'],
		}

		# ---- env control ----
		self.domainRandomization = ['randomInitialPose', 'randomObjState']
		# Turn on a matplotlib window to show what the robot sees.
		# Since the Unity window is always on, we don't decide render based on Unity window
		self.render = True
		self.use3rdCam = False  # if True, show a 3rd camera view in a window when render is True. It is for debugging
		self.episodeImgSaveDir = os.path.join('..', 'data', 'episodeRecord', 'tempImgs')  # output robot camera images to this location
		self.episodeImgSaveInterval = -1  # -1 for not saving. Save an episode of camera images every episodeImgSaveInterval episodes
		self.episodeImgSize = (96 * 5, 96 * 5, 3)  # (height, width, channel)
		self.RSI_ver = 3  # the version of robot sound interpretation

		# ---- pretext env configuration ----
		self.pretextEnvName = 'ai2thor-pretext-v3'
		self.pretextEnvMaxSteps = 15  # the length of an episode
		self.pretextEnvSeed = 500
		self.pretextNumEnvs = 1 if self.render else 4  # number of envs to collect data (1 while rendering)
		self.pretextVisibilityDistance = 100.
		self.pretextBBPixelAmount = 5  # augment the bounding box from the simulator by this amount of pixels

		# ---- pretext task configuration ----
		self.pretextManualControl = False
		self.pretextManualCollect = False
		self.pretextTrain = True  # collect data and perform training
		self.pretextCollection = False  # if False, data is assumed to be successfully collected and saved on disk
		self.pretextModelFineTune = True  # if true, load from self.pretextModelLoadDir
		self.pretextDataDir = ['path/to/pretextDataDir']
		self.pretextDataset = RSI3FineTuneDataset if self.pretextModelFineTune else RSI3Dataset
		# NOTE(review): 3 entries here but pretextDataDir lists 1 directory;
		# confirm how the dataset loader pairs the two lists
		self.pretextDataFileLoadNum = ['all', 'all', 'all']
		self.pretextModel = RSI3PretextNet
		self.pretextModelSaveDir = os.path.join('..', 'data', 'pretext_model', 'ai2thor_20room_RSI3_5task_pillowShoes_newSoundBranch_226_1200')
		self.pretextModelLoadDir = os.path.join('..', 'data', 'pretext_model', 'RSI3', 'ai2thor_20room_RSI3_5task_pillowShoes_newSoundBranch_226_1200', '39.pt')
		self.pretextModelSaveInterval = 5
		self.pretextCollectNum = [100, 100, 100, 100, 100, 100]
		self.pretextDataNumWorkers = 8
		# the total number of episodes = pretextNumEnvs * pretextDataEpisode * pretextDataNumFiles
		self.pretextDataEpisode = 200  # each env will collect this number of episodes
		self.pretextDataNumFiles = 20  # this number of pickle files will be generated
		self.pretextTrainBatchSize = 128
		self.pretextTestBatchSize = 128
		self.pretextLR = 1e-4  # 5e-4 for training, 5e-5 for fine-tuning
		self.pretextAdamL2 = 1e-5  # was 1e-6
		self.pretextLRStep = 'step'  # choose from ["cos", "step", "none"]
		self.pretextEpoch = 40
		self.pretextLRDecayEpoch = [15, 30]  # milestones for learning rate decay
		self.pretextLRDecayGamma = 0.2  # multiplicative factor of learning rate decay
		self.representationDim = 3
		self.representationTau = 0.3  # the temperature parameter in the supcon loss. alpha in triplet: alpha = 2tau
		self.pretextEmptyCenter = True  # make the empty class appear at the center of the VAR
		# True: we use supcon loss. False: all the other (image, S+) pairs within a batch are negatives;
		# the current (image, S+) is the only positive, similar to self-supervised methods like SimCLR
		self.representationUseLabels = not self.pretextModelFineTune
		self.pretextUseAudioAug = False
		self.pretextTrainLinear = False
		self.pretextLinearEpoch = 40
		self.pretextLinearLR = 5e-3
		self.pretextTestMethod = 'linear'
		self.pretextLinearImgModel = representationImgLinear
		self.pretextLinearSoundModel = representationSoundLinear
		self.plotRepresentation = 50  # plot the representation space every this number of epochs, -1 for not plotting
		self.plotFunc = plotRepresentationRSI3
		self.plotRepresentationOnly = True
		self.plotNumBatch = 7
		self.annotateLastBatch = False
		self.plotRepresentationExtra = False  # draw datapoints for images in episodeImgSaveDir or sound in commonMedia
		self.plotExtraPath = os.path.join('..', 'data', 'episodeRecord', 'extra')

		# ---- pretrain network ----
		self.usePretrainedModel = False
		self.freezePretrainedModel = False
		self.trainPretrainedModel = False
		self.pretrainedUseExi = False
		self.pretrainedEpoch = 30
		self.pretrainedModelSaveDir = os.path.join('..', 'data', 'pretext_model', 'ai2thor_RSI3_5task_pillowShoes_VGG')
		self.pretrainedModelLoadDir = os.path.join('..', 'data', 'pretext_model', 'ai2thor_RSI3_5task_pillowShoes_VGG', '29.pt')
		self.pretrainedBatchSize = 128
		self.pretrainedLR = 2e-5
		self.pretrainedLRDecayEpoch = [25, 30]  # milestones for learning rate decay
		self.pretrainedLRDecayGamma = 0.2  # multiplicative factor of learning rate decay

		# ---- RL env configuration ----
		self.RLEnvMaxSteps = 50  # the max number of actions (decisions) for an episode. Time horizon N.
		self.RLEnvName = 'ai2thor-RL-v3'
		self.RLEnvSeed = 1024
		self.RLObsIgnore = {'current_sound', 'goal_sound', 'goal_sound_label'}  # the observation names that will be ignored for RL training
		self.RLNumEnvs = 1 if self.render else 8
		self.RLVisibilityDistance = 1.5
		self.RLVisibleGrid = 9  # should be an odd number
		self.RLActionDim = (len(self.allActions),)

		# ---- RL task configuration ----
		self.RLManualControl = True
		self.RLTrain = False
		self.RLRealTimePlot = True
		self.RLLogDir = os.path.join('..', 'data', 'RL_model', 'ai2thor')
		self.calcMedoids = False
		self.RLTask = 'approach'
		self.numBatchMedoids = 10  # use numBatchMedoids*pretextTrainBatchSize to approximate medoids
		self.RLPolicyBase = 'ai2thor_RSI3'
		self.RLGamma = 0.99
		self.RLRecurrentPolicy = True
		self.ppoClipParam = 0.2
		self.ppoEpoch = 4
		self.ppoNumMiniBatch = 2
		self.ppoValueLossCoef = 0.5
		self.ppoEntropyCoef = 0.01
		self.ppoUseGAE = True
		self.ppoGAELambda = 0.95
		# the learning rate should be less than 5e-4 and greater than 7e-5
		# NOTE(review): the current value 5e-5 is below that stated lower bound; confirm intended
		self.RLLr = 5e-5
		self.RLEps = 1e-5
		self.RLMaxGradNorm = 0.5
		self.ppoNumSteps = self.RLEnvMaxSteps
		self.RLTotalSteps = 1e6
		self.RLModelSaveInterval = 200
		self.RLLogInterval = 100
		self.RLModelFineTune = True
		self.RLRewardSoundSound = False  # use the dot product between goal sound and current sound as reward
		self.RLModelSaveDir = os.path.join('..', 'data', 'RL_model', 'ai2thor_20room_RSI3_5task_pillowShoes_lowLR_226_1200')
		self.RLModelLoadDir = os.path.join('..', 'data', 'RL_model', 'ai2thor_20room_RSI3_5task_pillowShoes_lowLR', '22400.pt')
		self.RLUseProperTimeLimits = False
		self.RLRecurrentSize = 1024
		self.RLRecurrentInputSize = 128
		self.RLActionHiddenSize = 128
		self.RLAuxExiLossWeight = 1.

		# ---- update representation ----
		self.pretextModelUpdateInterval = np.inf  # presumably np.inf disables periodic updates -- TODO confirm
		self.pretextModelUpdateEpoch = 5
		self.pretextModelUpdateLR = self.pretextLR * self.pretextLRDecayGamma * 0.2
		# NOTE(review): named as a probability but set to 2 (the comment's previous value was 0.97); confirm
		self.pretextModelUpdatePairProb = 2
		self.pretextModelUpdateDataDir = os.path.join('..', 'data', 'pretext_training', 'TurtleBot_update')

		# ---- test ----
		self.success_threshold = 1
		self.hierarchy = False
		self.skillInfos = [
			{'path': os.path.join('..', 'data', 'RL_model', 'ai2thor_20room_RSI3_5task_pillowShoes_lowLR_226_1200', '01000.pt'),
			 'actionDim': 9, },
		]

		# taskID as key, index into skillInfos as value
		# for example, {0: 0, 1: 0} means taskIDs 0 and 1 correspond to skill 0
		# NOTE(review): indices 1 and 2 below have no entry in skillInfos; this is
		# only reported (see the check at the end) when hierarchy is True
		self.taskID2Skill = {0: 0, 1: 0, 2: 1, 3: 1, 4: 2}

		# ---- checking configuration and output errors or warnings ----
		print("######Configuration Checking######")
		if self.RLTrain and self.RLManualControl:
			# ValueError subclasses Exception, so callers catching Exception still work
			raise ValueError('self.RLTrain and self.RLManualControl cannot be both True')

		usesData = self.RLTrain or self.pretextTrain or self.pretextCollection
		if usesData and self.soundSource['train_test'] == 'test':
			warnings.warn("You are using the test set for training")

		if -1 < self.episodeImgSaveInterval < 5:
			warnings.warn("You may save the episode image too frequently")

		if self.hierarchy:
			# assumes task ids are numbered 0..taskNum-1 in allTasks order -- TODO confirm
			unmappedTasks = [t for t in range(self.taskNum) if t not in self.taskID2Skill]
			if unmappedTasks:
				warnings.warn("taskID2Skill has no skill for task id(s) {}".format(unmappedTasks))
			badSkills = sorted({s for s in self.taskID2Skill.values()
								if not 0 <= s < len(self.skillInfos)})
			if badSkills:
				warnings.warn("taskID2Skill refers to skill index(es) {}, but skillInfos only has {} skill(s)"
							  .format(badSkills, len(self.skillInfos)))
