import os
from models.pretext.turtlebot_pretext_model import RSI2PretextNet, representationImgLinear, representationSoundLinear
import warnings
from dataset import RSI2Dataset, RSI2FineTuneDataset


class TurtleBotConfig(object):
	"""
	Configuration container for the TurtleBot pybullet environments.

	Every setting is a plain instance attribute assigned in __init__. The
	settings are grouped into: common env configuration, pretext (representation
	learning) configuration, pretrained network configuration, RL configuration
	and test configuration. After all attributes are set, a few sanity checks
	are run by _checkConfig().
	"""

	def __init__(self):
		"""
		Put all env settings here

		Raises:
			ValueError: if soundSourcePreset is not one of the known presets,
				or if RLTrain and RLManualControl are both True.
		"""
		self.name = self.__class__.__name__
		# common env configuration
		self.objList = ['cube', 'sphere', 'cone', 'cylinder']
		self.taskNum = len(self.objList)
		self.img_dim = (3, 96, 96)  # (channel, image_height, image_width)
		self.sound_dim = (1, 100, 40)  # sound matrix dimension (1, frames, numFeat)

		# soundSourcePreset selects which soundSource dictionary is built below
		self.soundSourcePreset = 'normal'
		if self.soundSourcePreset == 'mix':
			self.soundSource = {
				'dataset': ['GoogleCommand', 'UrbanSound'],
				'items': {'GoogleCommand': ['house', 'tree', 'bird', 'dog'],
						  'UrbanSound': ['jackhammer', None, None, 'dog_bark']},
				'size': {'GoogleCommand': [25, 50, 50, 25], 'UrbanSound': [25, 0, 0, 25]},
				'train_test': 'test',
			}
		elif self.soundSourcePreset == 'normal':
			self.soundSource = {
				'dataset': ['GoogleCommand'],
				'items': {'GoogleCommand': ['zero', 'one', 'two', 'three']},
				'size': {'GoogleCommand': [1000, 1000, 1000, 1000]},
				'train_test': 'train',
			}
		else:
			# Without this branch self.soundSource would stay undefined and the
			# configuration checks would fail later with an AttributeError.
			raise ValueError(
				"Unknown soundSourcePreset {!r}: expected 'mix' or 'normal'".format(self.soundSourcePreset))
		self.commonMediaPath = os.path.join('..', 'commonMedia')
		self.mediaPath = os.path.join('..', "Envs", "pybullet", "turtlebot", "media")  # objects' model
		self.envFolder = os.path.join('pybullet', 'turtlebot')

		self.numRays = 11
		self.rayLen = 4
		self.rayHitColor = [1, 0, 0]
		self.rayMissColor = [0, 1, 0]
		self.render = False
		self.frameSkip = 36

		# robot_bases.py will use the name indicated by robotName as the robot body
		# the robot body represents the whole robot, which is usually the base of a mobile robot
		self.robotName = 'base_link'
		self.robotScale = 1

		# robot camera
		self.robotCamOffset = 0.02  # it is used to adjust the near clipping plane of the camera
		self.robotCamRenderSize = (75, 100, 3)  # simulation render (height, width, channel)
		self.robotFov = 48.8
		# pybullet debug GUI viewing angle and distance
		self.debugCam_dist = 1.8
		self.debugCam_yaw = -90
		self.debugCam_pitch = -65

		# objects and collision checking
		# the radius of the region for robot initial position.
		# e.g. 0.5->a circle with radius=0.5 centered at the world frame origin
		self.robotInitRegion_radius = 0.5
		# we define the radius of an entity as the radius of the circle tangent to the entity's xy-plane bounding box
		self.robotRadius = 0.143
		# the original radius of the models
		self.objectsRadius = {'cube': 0.15, 'sphere': 0.10, 'cone': 0.065, 'cylinder': 0.05}
		# extend the original radius of models for collision checking
		self.objectsExpandDistance = {'cube': 0.05, 'sphere': 0.05, 'cone': 0.05, 'cylinder': 0.15}
		# Originally, the robot's radius is 0.22. We add this value so that the
		# collision checking radius is 0.22+robotExpandDistance. See robot_locomotors.isCollide()
		# NOTE(review): robotRadius above is 0.143, not 0.22 -- confirm which radius this comment refers to.
		self.robotExpandDistance = 0.097

		self.placementExtension = 0.25  # the locations of objects will be more sparse with bigger value

		# robot control
		self.pointFollowerLinearGain = 1
		self.pointFollowerAngularGain = 1
		self.rotPosPGain = 1.5  # P control gain for rotPos control mode
		self.robotWheelDistance = 0.287  # the distance between two wheels
		self.robotWheelRadius = 0.033  # the radius of the wheels
		# Turtlebot3 max translational velocity=0.26 m/s and max rotational velocity=1.82 rad/s (104.27 deg/s)
		self.robotMaxTransVel = 0.25  # max translational velocity in m/s
		self.robotMinTransVel = -0.1
		self.robotMaxRotVel = 1.1  # max rotational velocity in rad/s

		# env control
		self.ifReset = True  # if you want to reset the arena after an episode ends
		self.domainRandomization = False  # True if you want to randomize textures for the wall and objects
		self.numTexture = 700  # number of textures to load. Bigger number will take more memory
		self.record = False  # save important data of an episode.
		self.loadAction = False  # load action from file so that the neural network decision is not used
		# recorder load action from location
		self.loadActionFile = os.path.join('..', 'libs', 'PPO', 'savedModel', 'episodeRecoder', 'ep' + str(2), 'action.csv')
		# recorder save information to this location
		self.recordSaveDir = os.path.join('..', 'data', 'episodeRecord', 'savedModel', 'episodeRecoder', 'test')

		# output turtlebot camera images to this location
		self.episodeImgSaveDir = os.path.join('..', 'data', 'episodeRecord', 'tempImgs')
		# -1 for not saving. Save an episode of turtlebot camera images every episodeImgSaveInterval episodes
		self.episodeImgSaveInterval = -1
		self.episodeImgSize = (224, 224, 3)  # (height, width, channel)

		self.RSI_ver = 2  # the version of robot sound interpretation

		# pretext env configuration
		self.pretextEnvName = 'turtlebot-pretext-v2'
		self.pretextActionDim = (3,)
		self.pretextEnvMaxSteps = 30  # the length of an episode
		self.pretextEnvSeed = 10
		self.pretextNumEnvs = 4 if not self.render else 1  # number of Envs to collect data

		# pretext robot control
		self.pretextRobotControl = 'setPose'
		self.xyMax = 0.8

		# pretext task configuration
		self.pretextManualControl = False
		self.pretextManualCollect = False
		self.pretextTrain = False  # collect data and perform training
		self.pretextCollection = False  # if False, data is assumed to be successfully collected and saved on disk
		self.pretextDataDir = ['path/to/pretextDataDir']
		self.pretextDataFileLoadNum = ['all', 'all', 'all']
		self.pretextCollectNum = [1000, 1000, 1000, 1000, 1000]
		# pretextModelFineTune must be set before pretextDataset, which is chosen from it
		self.pretextModelFineTune = False
		self.pretextDataset = RSI2FineTuneDataset if self.pretextModelFineTune else RSI2Dataset
		self.pretextModel = RSI2PretextNet
		self.pretextModelSaveDir = os.path.join('..', 'data', 'pretext_model', 'TurtleBot_0123')
		self.pretextModelLoadDir = os.path.join(self.pretextModelSaveDir, '39_RSI3Ver.pt')

		self.pretextModelSaveInterval = 10
		self.pretextDataNumWorkers = 4
		# the total number of episodes=pretextNumEnvs*pretextDataEpisode*pretextDataNumFiles
		self.pretextDataEpisode = 500  # each env will collect this number of episode
		self.pretextDataNumFiles = 20  # this number of pickle files will be generated
		self.pretextTrainBatchSize = 128
		self.pretextTestBatchSize = 128
		self.pretextLR = 1e-4  # 4 objs: 1e-3, 3objs: 1e-4
		self.pretextAdamL2 = 1e-6
		self.pretextLRStep = 'step'  # choose from ["cos", "step", "none"]
		self.pretextEpoch = 10
		# NOTE(review): both milestones are larger than pretextEpoch (10) -- confirm this is intended.
		self.pretextLRDecayEpoch = [30, 40]  # milestones for learning rate decay
		self.pretextLRDecayGamma = 0.2  # multiplicative factor of learning rate decay
		self.representationDim = 3
		self.tripletMargin = 1.0  # for mix or urbansound: 1.2; for others: 1.0

		self.plotRepresentation = 50  # plot the representation space every this number of epoch, -1 for not plotting
		self.plotRepresentationOnly = True
		self.plotNumBatch = 7
		self.annotateLastBatch = False
		self.plotRepresentationExtra = False  # draw datapoint for images in episodeImgSaveDir or sound in commonMedia
		self.plotExtraPath = os.path.join('..', 'data', 'episodeRecord', 'extra')

		self.pretextTrainLinear = False
		self.pretextLinearEpoch = 40
		self.pretextLinearLR = 5e-3
		self.pretextTestMethod = 'linear'
		self.pretextLinearImgModel = representationImgLinear
		self.pretextLinearSoundModel = representationSoundLinear

		# pretrain network
		self.usePretrainedModel = True
		self.trainPretrainedModel = False
		self.pretrainedEpoch = 30
		self.pretrainedModelSaveDir = os.path.join('..', 'data', 'pretext_model', 'TurtleBot_mix2_VGG')
		self.pretrainedModelLoadDir = os.path.join('..', 'data', 'pretext_model', 'TurtleBot_nsynth_noAux_sparseR', '9.pt')
		self.pretrainedBatchSize = 128
		self.pretrainedLR = 2e-5
		self.pretrainedLRDecayEpoch = [25, 30]  # milestones for learning rate decay
		self.pretrainedLRDecayGamma = 0.2  # multiplicative factor of learning rate decay

		# RL env configuration
		self.RLEnvMaxSteps = 80  # the max number of actions (decisions) for an episode. Time horizon N.
		self.RLEnvName = 'turtlebot-RL-v2'
		self.RLActionDim = (2,)
		self.RLEnvSeed = 66
		self.RLNumEnvs = 8 if not self.render else 1

		# RL robot control
		# velocity: the action will be translational velocity v and rotational velocity omega
		# rotPos: the action will be translational velocity v and rotational position. P control is applied
		# debug: move holonomically, (x,y,theta)
		self.RLRobotControl = 'rotPos'

		# RL task configuration
		self.RLManualControl = False
		self.RLTrain = True
		self.RLModelFineTune = False
		self.RLRealTimePlot = False
		self.RLLogDir = os.path.join('..', 'data', 'RL_model', 'TurtleBot')
		self.calcMedoids = False
		self.RLTask = 'approach'
		self.numBatchMedoids = 10  # use numBatchMedoids*pretextTrainBatchSize to approximate medoids

		self.RLPolicyBase = 'turtlebot_RSI2'
		self.RLGamma = 0.99
		self.RLRecurrentPolicy = True
		self.ppoClipParam = 0.2
		self.ppoEpoch = 4
		self.ppoNumMiniBatch = 2 if not self.render else 1
		self.ppoValueLossCoef = 0.5
		self.ppoEntropyCoef = 0.01
		self.ppoUseGAE = True
		self.ppoGAELambda = 0.95
		self.RLLr = 8e-5  # 8e-6
		self.RLEps = 1e-5
		self.RLMaxGradNorm = 0.5
		self.ppoNumSteps = 80
		self.RLTotalSteps = 3e6
		self.RLModelSaveInterval = 200
		self.RLLogInterval = 100
		self.RLModelSaveDir = os.path.join('..', 'data', 'RL_model', 'TurtleBot_nsynth_noAux_sparseR2')
		self.RLModelLoadDir = os.path.join('..', 'data', 'RL_model', 'TurtleBot_nsynth_noAux_sparseR2', '04600.pt')
		self.RLUseProperTimeLimits = False
		self.RLRecurrentSize = 1024
		self.RLRecurrentInputSize = 128
		self.RLActionHiddenSize = 128

		# test
		self.success_threshold = 5

		# checking configuration and output errors or warnings
		self._checkConfig()

	def _checkConfig(self):
		"""
		Check the configuration for conflicting settings and output errors or warnings.

		Raises:
			ValueError: if RLTrain and RLManualControl are both True.
		"""
		print("######Configuration Checking######")
		if self.RLTrain and self.RLManualControl:
			raise ValueError('self.RLTrain and self.RLManualControl cannot be both True')

		soundSplit = self.soundSource['train_test']

		if (self.RLTrain or self.pretextTrain) and soundSplit == 'test':
			warnings.warn("You are using the test set for training")

		# NOTE(review): this only looks at RLTrain, so it also fires when
		# pretextTrain is True and RLTrain is False -- confirm that is intended.
		if not self.RLTrain and soundSplit == 'train':
			warnings.warn("You are using the train set for testing")

		if 0 < self.episodeImgSaveInterval < 5:
			warnings.warn("You may save the episode image too frequently")
