from tqdm import trange
import copy
import pickle
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import os
from Envs.vec_env.envs import make_vec_envs
import time
from utils import get_scheduler, medoids_with_ground_truth
from dataset import loadEnvData
from shutil import copyfile
import torch.optim as optim
from cfg import main_config, gym_register


from RSI2.pretext_RSI2 import collectPretextData
from dataset import RSI1Dataset, RSI2Dataset


def pretrainRSI1(model, epoch, lr, start_ep=0, cfgName=None):
	"""Pre-train ``model`` on the pretext (auxiliary) tasks.

	Relies on the module-level ``config`` object (created in the ``__main__``
	section of this script) for data paths, batch size, optimizer settings,
	loss weights and checkpoint settings.

	Two training modes are selected by ``cfgName``:

	* any config other than ``'KukaConfig'`` / ``'KinovaGen3Config'``: the
	  RSI1 dataset is loaded and the model's three outputs (sound, in-sight,
	  existence) are trained with cross-entropy / BCE / BCE losses combined
	  with the ``config.RLAux*LossWeight`` weights;
	* ``'KukaConfig'`` / ``'KinovaGen3Config'``: the RSI2 dataset is reused and
	  only the sound classifier is trained with cross-entropy; samples whose
	  label equals 4 (the "empty sound") are dropped from each batch, and a
	  batch with no remaining samples is skipped.

	Args:
		model: network to train. Inputs are moved to the device of its
			parameters.
		epoch: number of epochs to run.
		lr: Adam learning rate.
		start_ep: offset added to the epoch index in checkpoint file names.
		cfgName: name of the active config class (see above).

	Checkpoints are written to ``config.pretextModelSaveDir`` every
	``config.pretextModelSaveInterval`` epochs and after the last epoch.
	The model is left in eval mode on return.
	"""
	print('Begin pre-training')
	# for kuka and kinova, using RSI2 dataset is OK. There is no need to collect new set
	sound_only = cfgName in ['KukaConfig', 'KinovaGen3Config']
	datasetType = RSI2Dataset if sound_only else RSI1Dataset
	data_generator, ds = loadEnvData(data_dir=config.pretextDataDir,
									 config=config,
									 batch_size=config.pretextTrainBatchSize,
									 shuffle=True,
									 num_workers=config.pretextDataNumWorkers, # change it to 0 if hangs
									 drop_last=True,
									 loadNum=config.pretextDataFileLoadNum,
									 dtype=datasetType
									 )

	os.makedirs(config.pretextModelSaveDir, exist_ok=True)

	# Send inputs to wherever the model lives instead of a hard-coded .cuda()
	# or the module-level `device`, which only exists when run as a script.
	model_device = next(model.parameters()).device

	model.train()

	optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
						   lr=lr,
						   weight_decay=config.pretextAdamL2)
	scheduler = get_scheduler(config, optimizer)

	inSight_criterion = torch.nn.BCEWithLogitsLoss()
	exi_criterion = torch.nn.BCEWithLogitsLoss()
	sound_aux_criterion = torch.nn.CrossEntropyLoss()

	def mean_loss(values):
		# nan instead of a divide-by-zero warning when the loader yielded no batches
		return float(np.mean(values)) if values else float('nan')

	def save_checkpoint(ep):
		# save every `pretextModelSaveInterval` epochs and always after the final one
		if (ep + 1) % config.pretextModelSaveInterval == 0 or ep + 1 == epoch:
			fname = os.path.join(config.pretextModelSaveDir, str(start_ep + ep) + '.pt')
			os.makedirs(config.pretextModelSaveDir, exist_ok=True)
			torch.save(model.state_dict(), fname, _use_new_zipfile_serialization=False)
			print('Model saved to ' + fname)

	# main training loop
	for ep in trange(epoch, position=0):
		if not sound_only:
			inSight_loss_ep = []
			exi_loss_ep = []
			soundAux_loss_ep = []
			for image, goal_sound, goal_sound_label, _, inSight_label, exi_label in data_generator:
				optimizer.zero_grad()

				pred_sound, pred_inSight, pred_exi = model(image.to(model_device),
														   goal_sound.float().to(model_device))
				inSight_label = inSight_label.float().to(model_device)
				goal_sound_label = goal_sound_label.float().to(model_device)
				exi_label = exi_label.float().to(model_device)

				inSightLoss = inSight_criterion(pred_inSight, inSight_label)
				exiLoss = exi_criterion(pred_exi, exi_label)
				# class index = position of the largest entry along dim 1
				# (presumably a one-hot label — confirm against the dataset)
				soundLoss = sound_aux_criterion(pred_sound, torch.max(goal_sound_label, 1)[1])
				loss = config.RLAuxInSightLossWeight * inSightLoss + \
					   config.RLAuxExiLossWeight * exiLoss + \
					   config.RLAuxSoundLossWeight * soundLoss

				loss.backward()
				optimizer.step()

				inSight_loss_ep.append(inSightLoss.item())
				exi_loss_ep.append(exiLoss.item())
				soundAux_loss_ep.append(soundLoss.item())

			if config.pretextLRStep == "step":
				scheduler.step()

			save_checkpoint(ep)
			print('inSight_loss', mean_loss(inSight_loss_ep))
			print('exi_loss', mean_loss(exi_loss_ep))
			print('sound_loss', mean_loss(soundAux_loss_ep))
		else:
			soundAux_loss_ep = []
			for _, sound_positive, _, gt in data_generator:
				optimizer.zero_grad()

				# remove the empty sound (label 4) in a batch
				keep = torch.where(gt != 4)[0]
				if keep.numel() == 0:
					# nothing left to train on; an empty batch would give a nan loss
					continue

				# reshape(-1), not squeeze(): squeeze() turns a single remaining
				# sample into a 0-d target that no longer matches the (1, C) logits
				sound_label = gt[keep].long().to(model_device).reshape(-1)
				sound_pred = model(sound_positive[keep].float().to(model_device).squeeze(1))

				loss = sound_aux_criterion(sound_pred, sound_label)

				loss.backward()
				optimizer.step()

				soundAux_loss_ep.append(loss.item())

			if config.pretextLRStep == "step":
				scheduler.step()
			save_checkpoint(ep)
			print('sound_loss', mean_loss(soundAux_loss_ep))

	print('Complete')
	model.eval()


if __name__ == '__main__':
	# `config` and `device` are deliberately module-level: pretrainRSI1 reads `config` as a global.
	config = main_config()
	gym_register(config)

	device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
	print("Using device:", device)
	cudnn.benchmark = True

	if config.pretextCollection:
		collectPretextData(config)
		print('Data Collection Complete')

	if config.pretextTrain:
		# honour the device chosen above instead of hard-coding .cuda()
		pretrainModel = config.pretextModel(config).to(device)
		if config.pretextModelFineTune:
			weight_path = config.pretextModelLoadDir
			# map_location lets a checkpoint saved on another device be loaded here
			load_result = pretrainModel.load_state_dict(torch.load(weight_path, map_location=device), strict=False)
			# strict=False ignores mismatched keys; report them rather than hide them
			if load_result.missing_keys or load_result.unexpected_keys:
				print('Missing keys:', load_result.missing_keys)
				print('Unexpected keys:', load_result.unexpected_keys)
			print('Load weights for fine tune model from', weight_path)

		os.makedirs(config.pretextModelSaveDir, exist_ok=True)
		# keep a copy of the environment's RSI1 config.py next to the checkpoints
		copyfile(os.path.join('..', 'Envs', config.envFolder, 'RSI1', 'config.py'),
				 os.path.join(config.pretextModelSaveDir, 'config.py'))

		pretrainRSI1(model=pretrainModel, epoch=config.pretextEpoch, lr=config.pretextLR, start_ep=0, cfgName=config.name)
