import numpy as np
import os
from models.pretext.kuka_pretext_model import RSI3PretextNet, representationImgLinear, representationSoundLinear
import warnings
from dataset import RSI3Dataset, RSI3FineTuneDataset
from RSI3.pretext_RSI3 import plotRepresentationRSI3


class KukaConfig(object):
	"""Configuration for the Kuka robot sound-interpretation experiments.

	Every setting is a plain instance attribute assigned in ``__init__``.
	A few consistency checks run at the end of construction: they either
	raise ``ValueError`` or emit ``warnings.warn``.
	"""

	def __init__(self):
		"""
		Put all env settings here.

		Raises:
			ValueError: if ``soundSourcePreset`` is not one of 'mix', 'normal'
				or 'FSC', or if ``RLTrain`` and ``RLManualControl`` are both True.
		"""
		self.name = self.__class__.__name__
		# common env configuration
		self.objList = ['key', 'key', 'key', 'key']
		self.taskNum = len(self.objList)
		# hide the obj so that the camera can't see
		# mode can be 'random', 'fix', 'none'. If 'random', it will hide 'hideNum' blocks and 'hideIdx' is irrelevant
		# If 'fix', it will hide the blocks indicated by the list 'hideIdx', 'hideNum' is irrelevant
		# if 'none', hide no blocks
		self.hideObj = {'mode': 'none', 'hideNum': 1, 'hideIdx': [2]}
		self.objInterval = 0.1  # the distance between two objects
		self.objZ = -0.085
		self.tableZ = -0.75
		# object and end-effector location range
		self.xMax = 0.75
		self.xMin = 0.45
		self.yMax = 0.35
		self.yMin = -0.25
		self.img_dim = (3, 96, 96)  # (channel, image_height, image_width)

		self.commonMediaPath = os.path.join('..', 'commonMedia')
		# sound dataset preset: one of 'mix', 'normal', 'FSC'
		self.soundSourcePreset = 'normal'
		if self.soundSourcePreset == 'mix':
			self.sound_dim = (1, 100, 40)  # sound matrix dimension (1, frames, numFeat)
			self.soundSource = {
				'dataset': ['GoogleCommand', 'UrbanSound'],
				'items': {'GoogleCommand': ['house', 'tree', 'bird', 'dog'],
						  'UrbanSound': ['jackhammer', None, None, 'dog_bark']},
				'size': {'GoogleCommand': [25, 50, 50, 25], 'UrbanSound': [25, 0, 0, 25]},
				'train_test': 'test',
			}
		elif self.soundSourcePreset == 'normal':
			self.sound_dim = (1, 100, 40)  # sound matrix dimension (1, frames, numFeat)
			self.soundSource = {
				'dataset': ['GoogleCommand'],
				'items': {'GoogleCommand': ['zero', 'one', 'two', 'three']},
				'size': {'GoogleCommand': [50, 50, 50, 50]},
				'train_test': 'test',
			}
		elif self.soundSourcePreset == 'FSC':
			self.sound_dim = (1, 600, 40)  # sound matrix dimension (1, frames, numFeat)
			self.soundSource = {
				'dataset': ['FSC'],
				'dataset_path': os.path.join(self.commonMediaPath, 'FSC'),
				'csv': 'train_data.csv',
				'max_sound_dur': 6.,
				'size': {'FSC': [1000, 1000, 1000, 1000]},
				'objs': ['lights', 'lights', 'music', 'music', 'lamp', 'lamp'],
				'act': ['activate', 'deactivate', 'activate', 'deactivate', 'activate', 'deactivate'],
				'locations': ['none', ],
				'train_test': 'train',
			}
		else:
			# Fail fast: otherwise sound_dim/soundSource stay undefined and the
			# configuration check below dies with an unhelpful AttributeError.
			raise ValueError("Unknown soundSourcePreset %r; expected 'mix', 'normal' or 'FSC'"
							 % (self.soundSourcePreset,))

		self.mediaPath = os.path.join('..', "Envs", "pybullet", "kuka", "media")  # objects' model
		self.envFolder = os.path.join('pybullet', 'kuka')

		self.render = True
		self.frameSkip = 16
		self.rayHitColor = [1, 0, 0]
		self.rayMissColor = [0, 1, 0]

		# robot configuration
		self.robotName = 'base_link'
		self.robotScale = 1
		self.endEffectorHeight = 0.22

		self.selfCollision = True
		self.endEffectorIndex = 6  # we mainly control this joint for position
		self.positionControlMaxForce = 500
		self.positionControlPositionGain = 0.03
		self.positionControlVelGain = 1.0
		self.changeDynamics = False
		self.fingerAForce = 2
		self.fingerBForce = 2
		self.fingerTipForce = 2

		# inverse kinematics settings
		self.ik_useNullSpace = True
		self.ik_useOrientation = True
		self.ik_ll = [-.967, -2, -2.96, 0.19, -2.96, -2.09, -3.05]  # lower limits for null space
		self.ik_ul = [.967, 2, 2.96, 2.29, 2.96, 2.09, 3.05]  # upper limits for null space
		self.ik_jr = [5.8, 4, 5.8, 4, 5.8, 4, 6]  # joint ranges for null space
		self.ik_rp = [0, 0, 0, 0.5 * np.pi, 0, -np.pi * 0.5 * 0.66, 0]  # rest poses for null space
		self.ik_jd = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]  # joint damping coefficients

		# robot camera
		self.robotCamOffset = 0  # it is used to adjust the near clipping plane of the camera
		self.robotCamRenderSize = (75, 100, 3)  # simulation render (height, width, channel)
		self.robotFov = 48.8

		# pybullet debug GUI viewing angle and distance
		self.debugCam_dist = 1.0
		self.debugCam_yaw = 90
		self.debugCam_pitch = -30

		# env control
		self.ifReset = True  # if you want to reset the scene after an episode ends
		self.domainRandomization = False  # True if you want to randomize textures for the wall and objects
		self.numTexture = 700  # number of textures to load. Bigger number will take more memory
		self.record = False  # save important data of an episode.
		self.loadAction = False  # load action from file so that the neural network decision is not used
		# recorder load action from location
		self.loadActionFile = os.path.join('..', 'libs', 'PPO', 'savedModel', 'episodeRecoder', 'ep' + str(2), 'action.csv')
		# recorder save information to this location
		self.recordSaveDir = os.path.join('..', 'data', 'episodeRecord', 'savedModel', 'episodeRecoder', 'test')

		# output kuka camera images to this location
		self.episodeImgSaveDir = os.path.join('..', 'data', 'episodeRecord', 'tempImgs')
		# -1 for not saving. Otherwise save an episode of robot camera images every episodeImgSaveInterval episodes
		self.episodeImgSaveInterval = -1
		self.episodeImgSize = (224, 224, 3)  # (height, width, channel)

		# pretext env configuration
		self.pretextEnvName = 'kuka-pretext-v3'
		self.pretextActionDim = (2,)
		self.pretextEnvMaxSteps = 10  # the length of an episode
		self.pretextEnvSeed = 977
		self.pretextNumEnvs = 1 if self.render else 4  # number of Envs to collect data (a single one when rendering)

		# pretext robot control
		self.pretextRobotControl = 'position'

		# pretext task configuration
		self.pretextTrain = True  # collect data and perform training
		self.pretextCollection = False  # if False, data is assumed to be successfully collected and saved on disk
		self.pretextManualCollect = False
		self.pretextManualControl = False
		self.pretextDataDir = ['path/to/pretextDataDir']
		# NOTE(review): pretextCollectNum has 5 entries and pretextDataFileLoadNum has 3,
		# while pretextDataDir has 1 -- confirm how the consumers pair these lists.
		self.pretextCollectNum = [100, 100, 100, 100, 1400]
		self.pretextDataFileLoadNum = ['all', 'all', 'all']
		self.pretextDataset = RSI3Dataset
		self.pretextModelFineTune = True
		self.pretextModel = RSI3PretextNet
		self.pretextModelSaveDir = os.path.join('..', 'data', 'pretext_model', 'RSI3', 'Kuka_RSI3_0123_dynamics_perception')
		# mainly used when self.RLTrain=False
		self.pretextModelLoadDir = os.path.join('..', 'data', 'pretext_model', 'RSI3', 'Kuka_RSI3_0123_dynamics_perception', '29.pt')
		self.pretextModelSaveInterval = 10
		self.pretextDataNumWorkers = 4
		# the total number of episodes = pretextNumEnvs * pretextDataEpisode * pretextDataNumFiles
		self.pretextDataEpisode = 500  # each env will collect this number of episodes
		self.pretextDataNumFiles = 20  # this number of pickle files will be generated
		self.pretextTrainBatchSize = 128
		self.pretextTestBatchSize = 128
		self.pretextLR = 1e-3
		self.pretextAdamL2 = 1e-6
		self.pretextLRStep = 'step'  # choose from ["cos", "step", "none"]
		self.pretextEpoch = 30
		# milestones for learning rate decay
		# NOTE(review): milestone 30 equals pretextEpoch, so it may never be reached -- confirm intent.
		self.pretextLRDecayEpoch = [25, 30]
		self.pretextLRDecayGamma = 0.2  # multiplicative factor of learning rate decay
		self.representationDim = 3
		self.representationTau = 0.3
		self.pretextEmptyCenter = True
		# labels are used only when the pretext model is not being fine-tuned
		self.representationUseLabels = not self.pretextModelFineTune

		self.pretextTrainLinear = False
		self.pretextLinearEpoch = 40
		self.pretextLinearLR = 5e-3
		self.pretextTestMethod = 'linear'
		self.pretextLinearImgModel = representationImgLinear
		self.pretextLinearSoundModel = representationSoundLinear

		# plot the representation space every this number of epochs, -1 for not plotting
		# NOTE(review): 50 exceeds pretextEpoch (30), so presumably no plot happens during training -- confirm.
		self.plotRepresentation = 50
		self.plotRepresentationOnly = False  # plot representation only and exit when pretextTrain, RLManualControl, RLTrain = False
		self.plotNumBatch = 10
		self.plotFunc = plotRepresentationRSI3
		self.annotateLastBatch = False
		self.plotRepresentationExtra = False  # draw datapoints for images in episodeImgSaveDir or sound in commonMedia
		self.plotExtraPath = os.path.join('..', 'data', 'episodeRecord', 'extra')

		# pretrain network
		self.usePretrainedModel = False
		self.trainPretrainedModel = False
		self.freezePretrainedModel = False
		self.pretrainedModelLoadDir = os.path.join('..', 'data', 'pretext_model', 'Kuka_0123_pretrain', '29.pt')

		# RL env configuration
		self.RLEnvMaxSteps = 100  # the max number of actions (decisions) for an episode. Time horizon N.
		self.RSI_ver = 3  # the version of robot sound interpretation

		self.RLEnvName = 'kuka-RL-v3'
		self.RLActionDim = (2,)
		self.RLEnvSeed = 966
		self.RLNumEnvs = 1 if self.render else 8  # a single env when rendering
		self.RLRewardSoundSound = False

		# RL robot control
		self.RLRobotControl = 'position'

		# RL task configuration
		self.RLManualControl = True
		self.RLTrain = False
		self.RLRealTimePlot = True

		self.RLPolicyBase = 'kuka_RSI3'
		self.calcMedoids = False
		self.RLTask = 'approach'
		self.numBatchMedoids = 100  # use numBatchMedoids*pretextTrainBatchSize to approximate medoids
		self.RLGamma = 0.99
		self.RLRecurrentPolicy = True
		self.ppoClipParam = 0.2
		self.ppoEpoch = 4

		self.ppoNumMiniBatch = 1 if self.render else 2
		self.ppoValueLossCoef = 0.5
		self.ppoEntropyCoef = 0.01
		self.ppoUseGAE = True
		self.ppoGAELambda = 0.95
		self.RLLr = 8e-6
		self.RLEps = 1e-5
		self.RLMaxGradNorm = 0.5
		self.ppoNumSteps = 100

		self.RLTotalSteps = 1e6
		self.RLModelSaveInterval = 200
		# the observation names that will be ignored for RL training
		self.RLObsIgnore = {'current_sound', 'goal_sound', 'goal_sound_label'}
		self.RLLogInterval = 100
		self.RLModelFineTune = True
		self.RLModelSaveDir = os.path.join('..', 'data', 'RL_model', 'Kuka_RSI3_0123_dynamics_perception')
		self.RLModelLoadDir = os.path.join('path/to/RLModelLoadDir')
		self.RLUseProperTimeLimits = False
		self.RLRecurrentSize = 512
		self.RLRecurrentInputSize = 128
		self.RLActionHiddenSize = 128

		# update representation
		self.pretextModelUpdateInterval = np.inf  # presumably an infinite interval disables updates -- confirm
		self.pretextModelUpdateEpoch = 5
		self.pretextModelUpdateLR = self.pretextLR * self.pretextLRDecayGamma * 0.2
		# NOTE(review): a pair "probability" of 2 is outside [0, 1] (the old inline comment
		# suggested 0.97) -- confirm how the consumer interprets values above 1.
		self.pretextModelUpdatePairProb = 2
		# NOTE(review): 'TurtleBot_update' in a Kuka config may be a copy-paste leftover -- confirm.
		self.pretextModelUpdateDataDir = os.path.join('..', 'data', 'pretext_training', 'TurtleBot_update')

		# test
		self.success_threshold = 50

		self.hierarchy = False
		self.skillInfos = [
			{'path': os.path.join('..', 'data', 'RL_model', 'Kuka_RSI3_0123_dynamics_perception', '00600.pt'),
			 'actionDim': 2, },
		]

		# checking configuration and output errors or warnings
		print("######Configuration Checking######")
		if self.RLTrain and self.RLManualControl:
			# ValueError subclasses Exception, so existing `except Exception` handlers still catch it
			raise ValueError('self.RLTrain and self.RLManualControl cannot be both True')

		if self.RLTrain:
			if self.soundSource['train_test'] == 'test':
				warnings.warn("You are using the test set for training")
		elif self.soundSource['train_test'] == 'train':
			warnings.warn("You are using the train set for testing")

		if 0 < self.episodeImgSaveInterval < 5:
			warnings.warn("You may save the episode image too frequently")
		print("##################################")
