import os
from models.pretext.ai2thor_pretext_model import RSI1Pretrain
import warnings
import numpy as np
from collections import OrderedDict


class AI2ThorConfig(object):
	"""Settings for the AI2-THOR robot sound interpretation (RSI) experiments.

	Every setting is a plain attribute assigned in ``__init__``: common env
	settings, the sound dataset metadata, the pretext (pre-training) stage and
	the RL stage. The constructor ends with a few consistency checks that either
	raise or emit ``warnings``.
	"""

	def __init__(self):
		"""Populate all env / pretext / RL settings and sanity-check them.

		Raises:
			Exception: if ``RLTrain`` and ``RLManualControl`` are both True.
		"""
		self.name = self.__class__.__name__

		# common env configuration
		self.envFolder = 'ai2thor'
		self.img_dim = (3, 96, 96)  # (channel, image_height, image_width)
		self.sound_dim = (1, 600, 40)  # sound matrix dimension (1, frames, numFeat)
		self.commonMediaPath = os.path.join('..', 'commonMedia')

		# fluent speech dataset metadata
		self.soundSource = {'dataset': 'FSC',
							'dataset_path': os.path.join(self.commonMediaPath, 'FSC'),
							'train_test': 'test',
							'max_sound_dur': 6.,
							'size': 50,
							'obj_act': {
								'lights': ['activate', 'deactivate'], 'music': ['activate', 'deactivate'],
								'lamp': ['activate', 'deactivate'], 'shoes': ['bring']
							},
							'locations': ['none'],
							}
		# e.g. 'test' -> 'test_data.csv'
		self.soundSource['csv'] = self.soundSource['train_test'] + '_data.csv'

		# presumably scene ids per location; they key into self.gridSize below
		self.allScene = {'livingRoom': [230]}
		self.keyBoardMapping = OrderedDict(
			[('w', "MoveAhead"), ('s', 'MoveBack'), ('a', 'MoveLeft'), ('d', 'MoveRight'),
			 ('q', "RotateLeft"), ('e', "RotateRight"),
			 ('T', "ToggleObjectOn"), ('t', "ToggleObjectOff"),
			 ('p', "PickupObject")
			 ])
		# action list in key-binding order (OrderedDict keeps it stable)
		self.allActions = list(self.keyBoardMapping.values())

		# we use orderedDict to maintain the order so that task id will be the same every time
		self.allTasks = OrderedDict([
			('livingRoom', OrderedDict(
				[
					('FloorLamp', ['ToggleObjectOn', 'ToggleObjectOff']),
					('Television', ['ToggleObjectOn', 'ToggleObjectOff']),
					('Pillow', ['PickupObject'])
				]
			)),
		])
		# total number of (object, action) pairs over all locations
		self.taskNum = sum(
			len(actions)
			for objects in self.allTasks.values()
			for actions in objects.values()
		)

		# the step size for large rooms: 0.5, the step size for small rooms: 0.25
		# the larger step size makes sure that the robot can finish a task within RLEnvMaxSteps
		self.gridSize = {201: 0.25, 202: 0.25, 203: 0.25, 204: 0.25, 205: 0.25, 206: 0.25,
						 207: 0.25, 208: 0.25, 209: 0.25, 210: 0.25, 211: 0.25, 212: 0.25,
						 213: 0.25, 214: 0.25, 215: 0.25, 216: 0.25, 217: 0.25, 218: 0.25,
						 219: 0.25, 220: 0.25, 226: 0.25, 227: 0.25, 228: 0.25, 229: 0.25, 230: 0.5}
		self.snapToGrid = False
		self.rotateStepDegrees = 45
		self.fieldOfView = 90

		# a dict that defines the relations between ai2thor and fluent speech dataset
		# key: ai2thor name, val: list of fluent speech dataset names
		self.synonym = {
			# location
			'kitchen': ['kitchen', 'none'], 'livingRoom': ['none'],
			# object: when an object itself contains action, we write as object_action
			'FloorLamp': ['lights', 'lamp'], 'Television': ['music'],
			'Microwave': ['heat_increase'], 'Fridge': ['heat_decrease'],
			'Pillow': ['shoes'],
			# action
			'ToggleObjectOn': ['increase', 'activate'], 'ToggleObjectOff': ['decrease', 'deactivate'],
			'PickupObject': ['bring']
		}

		# env control
		self.domainRandomization = ['randomInitialPose', 'randomObjState']
		# Turn on a matplotlib window to show what the robot sees.
		# Since the Unity window is always on, we don't decide render based on Unity window
		self.render = False
		self.use3rdCam = False  # if True, show a 3rd camera view in a window when render is True. It is for debugging
		self.episodeImgSaveDir = os.path.join('..', 'data', 'episodeRecord', 'tempImgs')  # output turtlebot camera images to this location
		self.episodeImgSaveInterval = -1  # -1 for not saving. Save an episode of turtlebot camera images every imgSaveInterval episode
		self.episodeImgSize = (96 * 5, 96 * 5, 3)  # (height, width, channel)
		self.RSI_ver = 1  # the version of robot sound interpretation

		# pretext env configuration. For RSI1, it is pre training
		self.pretextEnvName = 'ai2thor-pretext-v1'
		self.pretextEnvMaxSteps = 15  # the length of an episode
		self.pretextEnvSeed = 999
		self.pretextNumEnvs = 4 if not self.render else 1  # number of envs to collect data (1 when rendering)
		self.pretextVisibilityDistance = 100.

		# pretext task configuration
		self.pretextManualControl = False
		self.pretextManualCollect = False
		self.pretextTrain = True  # collect data and perform training
		self.pretextCollection = False  # if False, data is assumed to be successfully collected and saved on disk
		self.pretextDataDir = ['path/to/pretextDataDir']
		# NOTE(review): 3 entries here vs 1 entry in pretextDataDir -- confirm how these pair up
		self.pretextDataFileLoadNum = ['all', 'all', 'all']
		self.pretextModel = RSI1Pretrain
		self.pretextModelFineTune = False  # if true, load from self.pretextModelLoadDir
		self.pretextModelSaveDir = os.path.join('..', 'data', 'pretext_model', 'ai2thor_RSI1_5task_pillowShoes2')
		self.pretextModelLoadDir = os.path.join('..', 'data', 'pretext_model', 'ai2thor_RSI1_5task_pillowShoes2', '39.pt')
		self.pretextModelSaveInterval = 10
		# NOTE(review): 6 entries while taskNum evaluates to 5 -- confirm the intended length
		self.pretextCollectNum = [3333, 3333, 3333, 3333, 3333, 3333]
		self.pretextDataNumWorkers = 8
		# the total number of episodes=pretextNumEnvs*pretextDataEpisode*pretextDataNumFiles
		self.pretextDataEpisode = 200  # each env will collect this number of episode
		self.pretextDataNumFiles = 20  # this number of pickle files will be generated
		self.pretextTrainBatchSize = 128
		self.pretextTestBatchSize = 128
		self.pretextLR = 1e-4  # 4 objs: 1e-3, 3objs: 1e-4
		self.pretextAdamL2 = 1e-6
		self.pretextLRStep = 'step'  # choose from ["cos", "step", "none"]
		self.pretextEpoch = 40
		self.pretextLRDecayEpoch = [30, 40]  # milestones for learning rate decay
		self.pretextLRDecayGamma = 0.2  # multiplicative factor of learning rate decay
		self.representationDim = 0

		# pretrain network
		self.usePretrainedModel = False
		self.freezePretrainedModel = False
		self.pretrainedModelLoadDir = os.path.join('..', 'data', 'pretext_model', 'ai2thor_RSI1_5task_pillowShoes', '39.pt')

		# RL env configuration
		self.RLEnvMaxSteps = 50  # the max number of actions (decisions) for an episode. Time horizon N.
		self.RLEnvName = 'ai2thor-RL-v1'
		self.RLEnvSeed = 1333
		self.RLNumEnvs = 8 if not self.render else 1
		self.RLVisibilityDistance = 1.5
		self.RLVisibleGrid = 9  # should be an odd number (checked below)
		self.distanceRewardFactor = 1.
		self.RLActionDim = (len(self.allActions),)  # one entry per keyboard action

		# RL task configuration
		self.RLManualControl = False
		self.RLTrain = False
		self.RLUseSoundLabel = False  # if True, the network uses one-hot sound label directly as input without using the mfccs.
		# observation keys to ignore; always a set (an empty `{}` literal would be a dict)
		self.RLObsIgnore = {'goal_sound'} if self.RLUseSoundLabel else set()
		self.RLLogDir = os.path.join('..', 'data', 'RL_model', 'ai2thor')
		self.RLPolicyBase = 'ai2thor_RSI1'
		self.RLGamma = 0.99
		self.RLRecurrentPolicy = True
		self.ppoClipParam = 0.2
		self.ppoEpoch = 4
		self.ppoNumMiniBatch = 2
		self.ppoValueLossCoef = 0.5
		self.ppoEntropyCoef = 0.01
		self.ppoUseGAE = True
		self.ppoGAELambda = 0.95
		self.RLLr = 1e-4
		self.RLEps = 1e-5
		self.RLMaxGradNorm = 0.5
		self.ppoNumSteps = self.RLEnvMaxSteps  # rollout length equals the episode horizon
		self.RLTotalSteps = 1e6
		self.RLModelSaveInterval = 200
		self.RLLogInterval = 100
		self.RLModelFineTune = False
		self.RLModelSaveDir = os.path.join('..', 'data', 'RL_model', 'ai2thor_RSI1_5task_pillowShoes_230')
		self.RLModelLoadDir = os.path.join('..', 'data', 'RL_model', 'ai2thor_RSI1_5task_pillowShoes_230', '02400.pt')
		self.RLUseProperTimeLimits = False
		self.RLRecurrentSize = 1024
		self.RLRecurrentInputSize = 128
		self.RLActionHiddenSize = 128
		self.RLAuxSoundLossWeight = 1.
		self.RLAuxInSightLossWeight = 1.
		self.RLAuxExiLossWeight = 1.

		# test
		self.success_threshold = 1

		# checking configuration and output errors or warnings
		print("######Configuration Checking######")
		if self.RLTrain and self.RLManualControl:
			raise Exception('self.RLTrain and self.RLManualControl cannot be both True')

		if self.RLTrain or self.pretextTrain or self.pretextCollection:
			if self.soundSource['train_test'] == 'test':
				warnings.warn("You are using the test set for training")

		if -1 < self.episodeImgSaveInterval < 5:
			warnings.warn("You may save the episode image too frequently")

		# the visible grid is documented as odd-sized; warn rather than fail so
		# existing experiments are not blocked
		if self.RLVisibleGrid % 2 == 0:
			warnings.warn("self.RLVisibleGrid should be an odd number")
