import numpy as np
import os
from models.pretext.kinova_pretext_model import KinovaPretextNetwork
import warnings


class KinovaGen3Config(object):
	"""Settings for the Kinova Gen3 environment, pretext training and RL training.

	Every setting is a plain instance attribute assigned in ``__init__``. After all
	attributes are assigned, a few consistency checks run: contradictory flags raise
	an exception and suspicious combinations emit ``warnings.warn``.
	"""

	def __init__(self):
		"""
		Put all env settings here

		Raises:
			ValueError: if ``soundSourcePreset`` is neither 'mix' nor 'normal'.
			Exception: if ``RLTrain`` and ``RLManualControl`` are both True.
		"""
		self.name=self.__class__.__name__
		# common env configuration
		self.objList=['block', 'block', 'block', 'block']
		self.taskNum=len(self.objList)
		# hide the obj so that the camera can't see
		# mode can be 'random', 'fix', 'none'. If 'random', it will hide 'hideNum' blocks and 'hideIdx' is irrelevant
		# If 'fix', it will hide the blocks indicated by the list 'hideIdx' , 'hideNum' is irrelevant
		# if 'none', hide no blocks
		self.hideObj={'mode': 'none', 'hideNum':1, 'hideIdx':[0,1,2,3]}
		self.objInterval=0.1 # the distance between two objects
		self.objZ=0
		self.tableZ=-0.65
		# object and end-effector location range
		self.xMax=0.55
		self.xMin=0.3
		self.yMax=0.2
		self.yMin=-0.35
		self.img_dim = (3, 96, 96)  # (channel, image_height, image_width)
		self.sound_dim = (1, 100, 40)  # sound matrix dimension (1, frames, numFeat)

		# sound source preset: 'mix' draws from GoogleCommand and UrbanSound,
		# 'normal' draws from GoogleCommand only.
		# NOTE(review): the 'items'/'size' lists have 4 entries, presumably one per
		# object in objList — confirm against the env that consumes them.
		self.soundSourcePreset='normal'
		if self.soundSourcePreset=='mix':
			self.soundSource = {'dataset': ['GoogleCommand', 'UrbanSound'],
								'items': {'GoogleCommand':['house', 'tree', 'bird', 'dog'],
										  'UrbanSound':['jackhammer', None, None, 'dog_bark'] },
								'size': {'GoogleCommand': [25, 50, 50, 25], 'UrbanSound':[25, 0, 0, 25]},
								'train_test': 'test',
								}
		elif self.soundSourcePreset=='normal':
			self.soundSource = {'dataset': ['GoogleCommand'],
								'items': {'GoogleCommand': ['zero', 'one', 'two', 'three']},
								'size': {'GoogleCommand': [50, 50, 50, 50]},
								'train_test': 'test',
								}
		else:
			# Without this branch self.soundSource would stay undefined and the
			# configuration checks below would fail with an opaque AttributeError.
			raise ValueError("Unknown soundSourcePreset {!r}; expected 'mix' or 'normal'".format(self.soundSourcePreset))
		self.commonMediaPath = os.path.join('..','commonMedia')
		self.mediaPath = os.path.join('..', "Envs", "pybullet", "kinova_gen3", "media")  # objects' model
		self.envFolder = os.path.join('pybullet', 'kinova_gen3')
		self.render = True
		self.frameSkip = 16
		self.rayHitColor = [1, 0, 0]
		self.rayMissColor = [0, 1, 0]

		# simulated robot configuration
		self.robotName = 'base_link'
		self.robotScale = 1
		self.endEffectorHeight=0.22
		self.selfCollision= True
		self.endEffectorIndex= 7  # we mainly control this joint for position
		self.positionControlMaxForce=500
		self.positionControlPositionGain = 0.04
		self.positionControlVelGain=1.
		self.fingerAForce=2
		self.fingerBForce=2
		self.fingerTipForce= 2

		# inverse kinematics settings
		self.ik_useNullSpace= True
		self.ik_useOrientation=True
		self.ik_ll = [-2.967, -2, -2.96, -2.29, -2.96, -2.09, -3.05] # lower limits for null space
		self.ik_ul = [2.967, 2, 2.96, 2.29, 2.96, 2.09, 3.05] # upper limits for null space
		self.ik_jr = [5.8, 4, 5.8, 4, 5.8, 4, 6] # joint ranges for null space
		self.ik_rp = [0.3, 0.68, 0.0, 1.57, 0.0, 0.77, 1.57] # restposes for null space
		# joint damping coefficients
		# NOTE(review): 13 entries vs. 7 for the null-space lists above; presumably one
		# per joint of the loaded robot model (e.g. including gripper joints) — confirm
		self.ik_jd = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,0.1,0.1,0.1,0.1,0.1]

		self.externalCamRenderSize = (75, 100, 3)  # simulation render (height, width, channel)
		self.externalCamFov = 58.0 # the camera vertical FoV
		self.externalCamAspect=4./3. # we will use 640 x 480 setting for the RealSense
		self.externalCamEyePosition=[0.9, -0.11, 0.35] # put your real camera at about [0.9, -0.11, 0.42]
		self.externalCamTargetPosition=[0.5,-0.11,0]

		# pybullet debug GUI viewing angle and distance
		self.debugCam_dist = 1.0
		self.debugCam_yaw = 90
		self.debugCam_pitch = -30

		# run on a real robot, external camera is RealSense D455
		self.realRobot = True
		self.ROSStepInterval = 0.2

		# env control
		self.ifReset = True  # if you want to reset the scene after an episode ends
		self.domainRandomization = False  # True if you want to randomize textures for sim2real transfer
		self.numTexture = 700  # number of texture to load. Bigger number will take more memory
		self.record = False  # save important data of an episode.
		self.loadAction = False  # load action from file so that the neural network decision is not used
		# recorder load action from location
		self.loadActionFile = os.path.join('..','data', 'episodeRecord', 'simEp', 'ep' + str(4), 'desired_position.csv')
		# recorder save information to this location
		self.recordSaveDir = os.path.join('..','data', 'episodeRecord', 'realEp')
		self.recorderInitNum=0
		self.episodeImgSaveDir = os.path.join('..','data', 'episodeRecord',
											  'tempImgs')  # output episode camera images to this location
		self.episodeImgSaveInterval = -1  # -1 for not saving. Save one episode of camera images every episodeImgSaveInterval episodes
		self.episodeImgSize = (224, 224, 3)  # (height, width, channel)

		# pretext env configuration
		self.pretextEnvName = 'Kinova-ROS-pretext-v2'
		self.pretextActionDim = (2,)
		self.pretextEnvMaxSteps = 10  # the length of an episode
		self.pretextEnvSeed = 11
		self.pretextNumEnvs = 4 if not self.render else 1  # number of Envs to collect data

		# pretext robot control
		self.pretextRobotControl = 'position'

		# pretext task configuration
		self.pretextTrain = True  # collect data and perform training
		self.pretextCollection = False  # if False, data is assumed to be successfully collected and saved on disk
		self.pretextManualCollect = True
		self.pretextManualControl=True
		self.pretextModelFineTune=True
		self.pretextDataDir = [os.path.join('..','data', 'pretext_training', 'Kinova_0123_realCollect')]
		self.pretextDataFileLoadNum=['all']
		self.pretextCollectNum=[60, 60, 60, 60, 60]
		self.pretextModel = KinovaPretextNetwork
		self.pretextModelSaveDir = os.path.join('..','data', 'pretext_model', 'Kinova_0123_realCollect')
		self.pretextModelLoadDir = os.path.join('..','data', 'pretext_model', 'Kinova_0123_realCollect', '39.pt') # mainly used when self.RLTrain=False
		self.pretextModelSaveInterval = 10
		self.pretextDataNumWorkers = 4
		# the total number of episodes=pretextNumEnvs*pretextDataEpisode*pretextDataNumFiles
		self.pretextDataEpisode = 200  # each env will collect this number of episode
		self.pretextDataNumFiles = 20  # this number of pickle files will be generated
		self.pretextTrainBatchSize = 256
		self.pretextTestBatchSize = 128
		self.pretextLR = 1e-3
		self.pretextAdamL2 = 1e-6
		self.pretextLRStep = 'step'  # choose from ["cos", "step", "none"]
		self.pretextEpoch = 40
		# milestones for learning rate decay
		# NOTE(review): milestone 40 equals pretextEpoch; if epochs are counted 0..39
		# that decay step is never reached — confirm against the training loop
		self.pretextLRDecayEpoch = [30, 40]
		self.pretextLRDecayGamma = 0.2  # multiplicative factor of learning rate decay
		self.representationDim = 3
		self.tripletMargin = 1.0

		self.plotRepresentation = 50  # plot the representation space every this number of epoch, -1 for not plotting
		self.plotNumBatch = 20
		self.annotateLastBatch = False
		self.plotRepresentationExtra = False  # draw datapoint for images in episodeImgSaveDir or sound in commonMedia
		self.plotExtraPath = os.path.join('..','data', 'episodeRecord', 'extra')

		# RL env configuration
		self.RLEnvMaxSteps = 100  # the max number of actions (decisions) for an episode. Time horizon N.
		self.RSI_ver = 2
		self.RLEnvName = 'Kinova-ROS-RL-v2'
		self.RLActionDim = (2,)
		self.RLEnvSeed = 101
		self.RLNumEnvs = 8 if (not self.render and not self.realRobot) else 1

		# RL robot control
		self.RLRobotControl = 'position'

		# RL task configuration
		self.RLManualControl = False
		self.RLTrain = False
		self.RLRealTimePlot=False
		self.RLPolicyBase = 'armRobot'
		self.calcMedoids = False
		self.RLTask='approach'
		self.numBatchMedoids=100 # use numBatchMedoids*pretextTrainBatchSize to approximate medoids
		self.RLGamma = 0.99
		self.RLRecurrentPolicy = True
		self.ppoClipParam = 0.2
		self.ppoEpoch = 4
		self.ppoNumMiniBatch = 2 if (not self.render and not self.realRobot) else 1
		self.ppoValueLossCoef = 0.5
		self.ppoEntropyCoef = 0.01
		self.ppoUseGAE = True
		self.ppoGAELambda = 0.95
		self.RLLr = 8e-6
		self.RLEps = 1e-5
		self.RLMaxGradNorm = 0.5
		self.ppoNumSteps = self.RLEnvMaxSteps*2
		self.RLTotalSteps = 1e6
		self.RLModelSaveInterval = 10
		self.RLLogInterval = 10
		self.RLModelFineTune = True
		self.RLRewardSoundSound = False  # use the dot product between goal sound and current sound as reward
		self.RLModelSaveDir = os.path.join('..','data', 'RL_model', 'Kinova_0123_realCollect')
		self.RLModelLoadDir = os.path.join('..','data', 'RL_model', 'Kinova_0123_realCollect', '1', '00050.pt')
		self.RLUseProperTimeLimits = False
		self.RLRecurrentSize = 512
		self.RLRecurrentInputSize = 128
		self.RLActionHiddenSize = 128
		self.RLAuxSoundLossWeight = 1.
		self.RLAuxInSightLossWeight = 1.
		self.RLAuxExiLossWeight = 1.

		# update representation
		# NOTE(review): np.inf presumably means "never update the pretext model" — confirm
		self.pretextModelUpdateInterval = np.inf
		self.pretextModelUpdateEpoch=5
		self.pretextModelUpdateLR=self.pretextLR*self.pretextLRDecayGamma*0.2
		# NOTE(review): named a probability but set to 2 (> 1) — confirm how it is consumed
		self.pretextModelUpdatePairProb= 2
		self.pretextModelUpdateDataDir=os.path.join('..','data', 'pretext_training', 'Kinova_update')

		# test
		self.success_threshold=25

		# checking configuration and output errors or warnings
		print("######Configuration Checking######")
		if self.RLTrain and self.RLManualControl:
			raise Exception('self.RLTrain and self.RLManualControl cannot be both True')

		if self.RLTrain and self.record:
			warnings.warn("You are doing episode recording during training")

		if self.RLTrain or self.pretextTrain:
			if self.soundSource['train_test']=='test':
				warnings.warn("You are using the test set for training")

		if not self.RLTrain:
			if self.soundSource['train_test']=='train':
				warnings.warn("You are using the train set for testing")

		# saving every 0..4 episodes is allowed but likely produces a lot of images
		if -1<self.episodeImgSaveInterval<5:
			warnings.warn("You may save the episode image too frequently")
		print("##################################")