from tqdm import trange
import copy
import pickle
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import os

import time
from utils import get_scheduler, medoids_with_ground_truth, project_to_representation
from dataset import loadEnvData
from shutil import copyfile
import torch.optim as optim
from cfg import main_config, gym_register
import gym
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import glob


def collectPretextData(config, fileName=None):
	"""Collect labelled observation pairs from parallel envs and pickle them.

	After every reset/step the pairs in ``envs.unwrapped.obs_list`` are kept
	(deep-copied) only while the quota of their ``ground_truth`` class,
	``config.pretextCollectNum[class]``, is not yet reached. After each outer
	epoch the kept pairs are dumped to ``<config.pretextDataDir[0]>/train/`` and
	the buffer is cleared. If the quotas are still not met when the last planned
	file is reached, ``config.pretextDataNumFiles`` is increased by 3 (the config
	object is mutated), so collection continues until every quota is filled.

	Args:
		config: experiment config (reads pretextCollectNum, taskNum,
			pretextEnvName, pretextEnvSeed, pretextNumEnvs, pretextDataNumFiles,
			pretextDataEpisode, pretextEnvMaxSteps, render, pretextManualControl,
			pretextDataDir).
		fileName: optional base name of the pickle files. When None, files are
			named ``data_<epoch>.pickle``. When given, the first file is
			``<fileName>.pickle`` and later ones ``<fileName>_<epoch>.pickle``.

	Returns:
		The epoch index of the last file written.
	"""
	print("Begin collecting...")
	targetNum = config.pretextCollectNum
	collectedNum = [0] * (config.taskNum + 1)

	def _keepPairs(obs_list, buffer):
		"""Append a deep copy of every pair whose class quota is not yet full."""
		for pairs in obs_list:
			label = int(pairs['ground_truth'])
			if collectedNum[label] < targetNum[label]:
				buffer.append(copy.deepcopy(pairs))
				collectedNum[label] += 1

	# create parallel Envs
	from Envs.vec_env.envs import make_vec_envs
	envs = make_vec_envs(env_name=config.pretextEnvName,
						 seed=config.pretextEnvSeed,
						 num_processes=config.pretextNumEnvs,
						 gamma=None,
						 device=None,
						 randomCollect=True,
						 config=config)

	try:
		# observations is a list of dict
		# [{'image':, 'sound_positive':, 'sound_negative':, 'ground_truth':}, ...]
		observations = []
		envs.reset()
		_keepPairs(envs.unwrapped.obs_list, observations)

		epoch = 0
		while epoch <= config.pretextDataNumFiles:
			if epoch == config.pretextDataNumFiles and sum(collectedNum) < sum(targetNum):
				config.pretextDataNumFiles = config.pretextDataNumFiles + 3
				print('Increase number of files')
			print("Number of pairs for each object", collectedNum)
			for episode in trange(config.pretextDataEpisode, position=0):
				for i in range(config.pretextEnvMaxSteps):
					if config.render:
						envs.render()
						if not config.pretextManualControl:
							time.sleep(2)
					# dummy action. True random action is decided in env
					action = [0] * config.pretextNumEnvs
					envs.step(action)
					_keepPairs(envs.unwrapped.obs_list, observations)

				if sum(collectedNum) == sum(targetNum):
					break

			# save observations as pickle files
			saveDir = os.path.join(config.pretextDataDir[0], 'train')
			os.makedirs(saveDir, exist_ok=True)
			if fileName is None:
				baseName = 'data_' + str(epoch)
			elif epoch == 0:
				baseName = fileName
			else:
				# suffix the epoch so later files do not overwrite <fileName>.pickle
				baseName = fileName + '_' + str(epoch)
			with open(os.path.join(saveDir, baseName + '.pickle'), 'wb') as f:
				pickle.dump(observations, f, protocol=pickle.HIGHEST_PROTOCOL)
			observations.clear()

			if sum(collectedNum) == sum(targetNum):
				break

			epoch = epoch + 1
	finally:
		# release the env worker processes even if collection fails
		envs.close()
	return epoch

def manuallyCollectPretextData():
	"""Step the pretext env forever, printing the image/goal-sound embedding reward.

	Despite the name, nothing is written to disk. On each step the function
	renders the env and embeds the current ``'image'`` (scaled by 1/255) and
	``'sound_positive'`` observations with ``pretextModel``. It then prints the
	dot product of the first ``config.representationDim`` image-feature columns
	with the sound feature.

	Relies on the module-level globals ``config``, ``pretextModel`` and
	``device`` created in the ``__main__`` block. The env is reset whenever an
	episode reports ``done``. It is closed when the loop is interrupted
	(e.g. Ctrl-C).
	"""
	env = gym.make(config.pretextEnvName)
	env.seed(0)

	pretextModel.load_state_dict(torch.load(config.pretextModelLoadDir, map_location=device))
	print('Load weights for pretextModel from', config.pretextModelLoadDir)

	env.reset()
	try:
		while True:
			env.render()
			obs, _, done, _ = env.step([0])
			with torch.no_grad():
				image_feat, goal_sound_feat, _ = pretextModel(
					torch.from_numpy(obs['image'][None, :] / 255.).float().to(device),
					torch.from_numpy(obs['sound_positive'][None, :]).float().to(device),
					None)
			image_feat = image_feat.to('cpu').numpy()
			goal_sound_feat = goal_sound_feat.to('cpu').numpy()

			img_sound_dot = np.sum(image_feat[:, :config.representationDim] * goal_sound_feat, axis=1)
			print("embReward", img_sound_dot)

			if done:
				# Gym envs must be reset once an episode has terminated
				env.reset()
	finally:
		env.close()


def trainRepresentation(model, epoch, lr, start_ep=0, plot=False):
	"""Train the pretext model with a triplet loss (plus optional BCE norm losses).

	Reads the module-level globals ``config`` and ``device`` created in the
	``__main__`` block.

	The loss depends on ``config.pretextEmptyCenter``:

	* set: ``4.0 * triplet + 0.3 * BCE(image_norm) + 0.3 * BCE(sound_norm)``.
	  The BCE target is 0 for label ``config.taskNum`` and 1 for every other label.
	* unset: the triplet loss alone.

	Checkpoints ``<start_ep + ep>.pt`` are written every
	``config.pretextModelSaveInterval`` epochs and after the last epoch. When
	``config.pretextTrain`` is set, ``progress.csv`` is also written; it holds
	the per-epoch mean triplet loss.

	Args:
		model: network called as
			``model(image, sound_positive, sound_negative, is_train=True)`` returning
			``(image_feat, sound_feat_positive, sound_feat_negative, image_norm, sound_norm)``.
		epoch: number of passes over the data; 0 skips training (plot-only runs).
		lr: Adam learning rate.
		start_ep: offset added to the epoch index in checkpoint file names.
		plot: enable ``plotRepresentationRSI2`` calls during and after training.
	"""
	print('Begin representation training')
	# load data
	data_generator, _ = loadEnvData(data_dir=config.pretextDataDir,
									config=config,
									batch_size=config.pretextTrainBatchSize,
									shuffle=True,
									num_workers=config.pretextDataNumWorkers,  # change it to 0 if hangs
									drop_last=True,
									loadNum=config.pretextDataFileLoadNum,
									dtype=config.pretextDataset)

	os.makedirs(config.pretextModelSaveDir, exist_ok=True)

	model.train()

	optimizer = optim.Adam(filter(lambda parameters: parameters.requires_grad, model.parameters()),
						   lr=lr,
						   weight_decay=config.pretextAdamL2)

	scheduler = get_scheduler(config, optimizer)
	criterion = torch.nn.TripletMarginLoss(margin=config.tripletMargin, p=2)
	norm_criterion = torch.nn.BCEWithLogitsLoss()
	norm_weight = 0.3  # weight of each BCE norm term when config.pretextEmptyCenter is set

	# A positive config.plotRepresentation is the period (in epochs) of intermediate
	# plots; 0 means "plot only after training". Requiring > 0 avoids the
	# ZeroDivisionError that ``ep % 0`` would raise.
	plotEvery = config.plotRepresentation if (plot and config.plotRepresentation > 0) else 0

	loss_list = []

	# main training loop
	for ep in trange(epoch, position=0):
		if plotEvery and ep > 0 and ep % plotEvery == 0:
			model.eval()
			plotRepresentationRSI2(data_generator, device, model, config)
			model.train()

		loss_ep = []
		loss_ep_img_norm = []
		loss_ep_sound_norm = []

		for image, sound_positive, sound_negative, gt in data_generator:
			# the optimizer covers every trainable parameter, so this clears all grads
			optimizer.zero_grad()
			image_feat, sound_feat_positive, sound_feat_negative, image_norm, sound_norm = model(
				image.to(device),
				sound_positive.float().to(device),
				sound_negative.float().to(device),
				is_train=True)
			loss_triplet = criterion(image_feat, sound_feat_positive, sound_feat_negative)

			if config.pretextEmptyCenter:
				# BCE ground truth: 0 for label == config.taskNum, 1 otherwise
				gt_norm = (gt != config.taskNum).float().to(device)
				loss_img_norm = norm_criterion(image_norm, gt_norm)
				loss_sound_norm = norm_criterion(sound_norm, gt_norm)

				loss = 4.0 * loss_triplet + norm_weight * loss_img_norm + norm_weight * loss_sound_norm
				loss_ep_img_norm.append(loss_img_norm.item())
				loss_ep_sound_norm.append(loss_sound_norm.item())
			else:
				loss = 1.0 * loss_triplet

			loss.backward()
			optimizer.step()
			# only the triplet term is tracked and reported as avg_loss
			loss_ep.append(loss_triplet.item())

		if config.pretextLRStep == "step":
			scheduler.step()

		if (ep + 1) % config.pretextModelSaveInterval == 0 or ep + 1 == epoch:
			fname = os.path.join(config.pretextModelSaveDir, str(start_ep + ep) + '.pt')
			os.makedirs(config.pretextModelSaveDir, exist_ok=True)
			torch.save(model.state_dict(), fname, _use_new_zipfile_serialization=False)
			print('Model saved to ' + fname)

		avg_loss = np.mean(loss_ep)
		loss_list.append(avg_loss)

		if config.pretextEmptyCenter:
			avg_img_norm = np.mean(loss_ep_img_norm)
			avg_sound_norm = np.mean(loss_ep_sound_norm)
			print('average loss', avg_img_norm, avg_loss, avg_sound_norm)
		else:
			print('average loss', avg_loss)

	if config.pretextTrain:
		df = pd.DataFrame({'avg_loss': loss_list})
		save_path = os.path.join(config.pretextModelSaveDir, 'progress.csv')
		df.to_csv(save_path, mode='w', header=True, index=False)
		print('results saved to', save_path)
	print('Pretext Training Complete')
	model.eval()
	if config.plotRepresentation >= 0 and plot:
		plotRepresentationRSI2(data_generator, device, model, config)


def plotRepresentationRSI2(generator, torch_device, net, config, **kwargs):
	"""Scatter-plot image and positive-sound embeddings, colored by class label.

	Image features use 'o' markers and positive-sound features use 'v' markers.
	The color is picked by ground-truth label; labels without an entry in
	``colors`` are not drawn. Batches ``0 .. config.plotNumBatch`` (inclusive)
	are plotted.

	When ``config.annotateLastBatch`` is set, the images of the last plotted
	batch are saved to ``config.episodeImgSaveDir`` as ``lastBatch<j>.jpg``.
	Their image features are labelled with ``j``.

	With ``config.plotRepresentationExtra``, the ``<int>.jpg`` files in
	``config.plotExtraPath`` are embedded and drawn as numbered black '*'
	markers. For 3-D representations, axis lines and a unit wireframe sphere
	are added.

	Args:
		generator: iterable of ``(img, sound_positive, sound_negative, gt)`` batches.
		torch_device: device the network inputs are moved to.
		net: network called as ``net(img, sp, sn)``; its result is indexed with
			``'image_feat'`` and ``'sound_feat_positive'``.
		config: experiment config.
		**kwargs: accepted for call-site compatibility; unused.

	Returns:
		``(fig, ax)`` once ``plt.show()`` returns.

	Raises:
		ValueError: if ``config.representationDim`` is 4 or more.
	"""
	if config.representationDim >= 4:
		raise ValueError('representationDim must be < 4 to be plotted, got %s' % (config.representationDim,))

	fig = plt.figure()
	if config.representationDim == 3:  # 3d scatter plot
		ax = fig.add_subplot(111, projection='3d')
		ax.set_zlabel('Z Label')
	else:
		ax = fig.add_subplot(111)

	ax.set_xlabel('X Label')
	ax.set_ylabel('Y Label')
	colors = ['r', 'y', 'b', 'g', 'tab:purple', 'c']

	with torch.no_grad():
		for n, (img, sp, sn, gt) in enumerate(generator):
			if n > config.plotNumBatch:
				break  # batches 0..config.plotNumBatch (inclusive) are shown
			annotate = n == config.plotNumBatch and config.annotateLastBatch
			if annotate:
				# save this batch of images to config.episodeImgSaveDir with their index
				os.makedirs(config.episodeImgSaveDir, exist_ok=True)
				for j, pic in enumerate(img):
					# NOTE(review): cv2.imwrite expects 0-255 pixel values; if the loader
					# yields images scaled to [0, 1] these files will look black — confirm.
					imgSave = cv2.resize(np.transpose(pic.numpy(), (1, 2, 0)),
										 (config.episodeImgSize[1], config.episodeImgSize[0]))
					if config.episodeImgSize[2] == 3:
						imgSave = cv2.cvtColor(imgSave, cv2.COLOR_RGB2BGR)
					fileName = 'lastBatch' + str(j) + '.jpg'
					cv2.imwrite(os.path.join(config.episodeImgSaveDir, fileName), imgSave)

			features = net(img.to(torch_device), sp.float().to(torch_device), sn.float().to(torch_device))
			img_feat = features['image_feat'].to('cpu')
			sp_feat = features['sound_feat_positive'].to('cpu')
			for j, color in enumerate(colors):
				idx = np.where(gt == j)[0]
				if idx.size == 0:
					continue
				np_img_feat = img_feat[idx].numpy()
				np_sp_feat = sp_feat[idx].numpy()
				if np_img_feat.shape[1] == 2:
					ax.scatter(np_img_feat[:, 0], np_img_feat[:, 1], marker='o', color=color)
					ax.scatter(np_sp_feat[:, 0], np_sp_feat[:, 1], marker='v', color=color)
					if annotate:
						# annotate the points with their index inside the saved batch
						for k, txt in enumerate(idx):
							ax.annotate(str(txt), (np_img_feat[k, 0], np_img_feat[k, 1]))
				else:
					ax.scatter(np_img_feat[:, 0], np_img_feat[:, 1], np_img_feat[:, 2], marker='o', color=color, s=20, alpha=0.2)
					ax.scatter(np_sp_feat[:, 0], np_sp_feat[:, 1], np_sp_feat[:, 2], marker='v', color=color, s=20, alpha=0.2)
					if annotate:
						# annotate the points with their index inside the saved batch
						for k, txt in enumerate(idx):
							ax.text(np_img_feat[k, 0], np_img_feat[k, 1], np_img_feat[k, 2], str(txt))

		if config.plotRepresentationExtra:
			# load extra images named <int>.jpg, in numeric order
			imageList = []
			fileList = [os.path.basename(x) for x in glob.glob(os.path.join(config.plotExtraPath, '*.jpg'))]
			for filePath in sorted(fileList, key=lambda x: int(x[0:-4])):
				image_in = cv2.cvtColor(cv2.imread(os.path.join(config.plotExtraPath, filePath)), cv2.COLOR_BGR2RGB)
				image_in = cv2.resize(image_in, (config.img_dim[2], config.img_dim[1]))
				imageList.append(np.transpose(image_in, (2, 0, 1)))

			if imageList:
				img = torch.from_numpy(np.array(imageList))
				soundShape = (len(imageList), config.sound_dim[0], config.sound_dim[1], config.sound_dim[2])
				sp = torch.zeros(soundShape)
				sn = torch.zeros(soundShape)
				features = net((img / 255.).float().to(torch_device),
							   sp.to(torch_device),
							   sn.to(torch_device))
				# same key-based access as for the generator batches above
				np_img_feat = features['image_feat'].to('cpu').numpy()
				imageStr = np.arange(0, len(imageList))
				if np_img_feat.shape[1] == 2:
					ax.scatter(np_img_feat[:, 0], np_img_feat[:, 1], marker='*', color='k', s=20)
					for j, txt in enumerate(imageStr):
						ax.annotate(str(txt), (np_img_feat[j, 0], np_img_feat[j, 1]))
				else:
					ax.scatter(np_img_feat[:, 0], np_img_feat[:, 1], np_img_feat[:, 2], marker='*', color='k', s=20)
					for j, txt in enumerate(imageStr):
						ax.text(np_img_feat[j, 0], np_img_feat[j, 1], np_img_feat[j, 2], str(txt))

		if config.representationDim == 3:
			# draw the coordinate axes and a unit sphere as visual reference
			ax.plot([-1, 1], [0, 0], [0, 0], color="k", alpha=0.2, linewidth=1)
			ax.plot([0, 0], [-1, 1], [0, 0], color="k", alpha=0.2, linewidth=1)
			ax.plot([0, 0], [0, 0], [-1, 1], color="k", alpha=0.2, linewidth=1)
			ax.set_axis_off()
			u, v = np.mgrid[0:2 * np.pi:40j, 0:np.pi:20j]
			x = np.cos(u) * np.sin(v)
			y = np.sin(u) * np.sin(v)
			z = np.cos(v)
			ax.plot_wireframe(x, y, z, color="lightgray", alpha=0.2, linewidth=1)

		plt.show()
		return fig, ax


def trainLinear(model, epoch, lr, start_ep, config):
	"""Evaluate a frozen representation by training linear classifiers on its features.

	``model`` is only used, in eval mode, to project the pretext dataset via
	``project_to_representation(..., test_method='linear')``. Two linear heads,
	``config.pretextLinearImgModel`` and ``config.pretextLinearSoundModel``, are
	then trained with softmax cross entropy on the image and sound features.

	Their weights are saved to ``config.pretextModelSaveDir`` as
	``<ckpt>_linear_img<ep>.pt`` and ``<ckpt>_linear_sound<ep>.pt``, where
	``<ckpt>`` is the file stem of ``config.pretextModelLoadDir``.

	Args:
		model: trained representation network (not updated here).
		epoch: number of epochs to train the linear heads.
		lr: Adam learning rate for both heads.
		start_ep: unused; kept for interface compatibility.
		config: experiment config.
	"""
	device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
	print('Begin training a linear layer for representation evaluation')
	# load data
	data_generator, _ = loadEnvData(data_dir=config.pretextDataDir,
									config=config,
									batch_size=config.pretextTrainBatchSize,
									shuffle=True,
									num_workers=config.pretextDataNumWorkers,  # change it to 0 if hangs
									drop_last=True,
									loadNum=config.pretextDataFileLoadNum,
									dtype=config.pretextDataset)

	model.eval()  # model is your representation which should not be updated here
	linear_model_img = config.pretextLinearImgModel(config).to(device).train()
	linear_model_sound = config.pretextLinearSoundModel(config).to(device).train()

	optimizer_img = optim.Adam(filter(lambda parameters: parameters.requires_grad, linear_model_img.parameters()),
							   lr=lr,
							   weight_decay=config.pretextAdamL2)
	optimizer_sound = optim.Adam(filter(lambda parameters: parameters.requires_grad, linear_model_sound.parameters()),
								 lr=lr,
								 weight_decay=config.pretextAdamL2)

	scheduler_img = get_scheduler(config, optimizer_img)
	scheduler_sound = get_scheduler(config, optimizer_sound)
	criterion = torch.nn.CrossEntropyLoss()

	def _step(linear_model, optimizer, features, labels):
		"""Run one optimisation step on ``linear_model`` and return the loss value."""
		# the optimizer covers every trainable parameter of linear_model
		optimizer.zero_grad()
		loss = criterion(linear_model(features.float().to(device)), labels.long().to(device))
		loss.backward()
		optimizer.step()
		return loss.item()

	feat_point = project_to_representation(data_generator, config, model, device, test_method='linear')
	from dataset import featureDataset

	feature_dataset = featureDataset(feat_point, config)
	generator = torch.utils.data.DataLoader(feature_dataset,
											batch_size=config.pretextTrainBatchSize,
											shuffle=True,
											num_workers=config.pretextDataNumWorkers,  # change it to 0 if hangs
											pin_memory=True,
											drop_last=True)
	checkpoint_number = os.path.splitext(os.path.basename(config.pretextModelLoadDir))[0]

	# main training loop
	for ep in trange(epoch, position=0):
		loss_ep_img = []
		loss_ep_sound = []

		for image_data, sound_data, image_gt, sound_gt in generator:
			loss_ep_img.append(_step(linear_model_img, optimizer_img, image_data, image_gt))
			loss_ep_sound.append(_step(linear_model_sound, optimizer_sound, sound_data, sound_gt))

		if config.pretextLRStep == "step":
			scheduler_img.step()
			scheduler_sound.step()

		if (ep + 1) % config.pretextModelSaveInterval == 0 or ep + 1 == epoch:
			fname_img = os.path.join(config.pretextModelSaveDir, str(checkpoint_number) + '_linear_img' + str(ep) + '.pt')
			fname_sound = os.path.join(config.pretextModelSaveDir, str(checkpoint_number) + '_linear_sound' + str(ep) + '.pt')
			os.makedirs(config.pretextModelSaveDir, exist_ok=True)
			torch.save(linear_model_img.state_dict(), fname_img, _use_new_zipfile_serialization=False)
			torch.save(linear_model_sound.state_dict(), fname_sound, _use_new_zipfile_serialization=False)
			print('Model saved to ' + fname_img)
			print('Model saved to ' + fname_sound)

		avg_loss_img = np.mean(loss_ep_img)
		avg_loss_sound = np.mean(loss_ep_sound)
		print('average img loss', avg_loss_img, 'average sound loss', avg_loss_sound)

	print('Linear Training Complete')


def testRepresentation(config, pretextModel):
	"""Print per-class accuracies of a trained pretext representation.

	``config.pretextTestMethod`` selects the evaluation:

	* ``'medoid'``: each feature is assigned to the medoid with the largest dot
	  product. A class's accuracy is the fraction of its features assigned to its
	  own medoid. Medoids are read from ``<ckpt>_medoids.pickle`` next to
	  ``config.pretextModelLoadDir``. If that file is missing, they are computed
	  with ``medoids_with_ground_truth`` and saved there, and the process exits.
	* ``'linear'``: the linear heads written by ``trainLinear`` (epoch
	  ``config.pretextLinearEpoch - 1``) are loaded from the directory of
	  ``config.pretextModelLoadDir``. Their per-class accuracy is reported for
	  image and sound features.

	A class with no samples gets accuracy NaN, and NaNs are ignored by the
	printed means.

	Raises:
		ValueError: if ``config.pretextTestMethod`` is neither 'medoid' nor 'linear'.
		SystemExit: right after medoids have been computed and saved.
	"""
	device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
	checkpoint_dir, checkpoint_file = os.path.split(config.pretextModelLoadDir)
	checkpoint_number = os.path.splitext(checkpoint_file)[0]

	def _loadData(**extra):
		"""Build the pretext DataLoader shared by every evaluation path."""
		data_generator, _ = loadEnvData(data_dir=config.pretextDataDir,
										config=config,
										batch_size=config.pretextTrainBatchSize,
										shuffle=True,
										num_workers=config.pretextDataNumWorkers,
										drop_last=True,
										loadNum=config.pretextDataFileLoadNum,
										dtype=config.pretextDataset,
										**extra)
		return data_generator

	def _ratio(count, total):
		"""Return ``count / total``, or NaN for a class without samples."""
		return count / total if total else float('nan')

	if config.pretextTestMethod == 'medoid':
		medoidFilePath = os.path.join(checkpoint_dir, checkpoint_number + '_medoids.pickle')
		if os.path.exists(medoidFilePath):
			print('Found medoids.pickle, load directly')
			with open(medoidFilePath, 'rb') as fp:
				medoids = pickle.load(fp)
			print(medoids)
		else:  # calculate the medoids
			medoids = medoids_with_ground_truth(data_generator=_loadData(),
												torch_device=device,
												model=pretextModel,
												config=config,
												)
			# save it for the next time
			with open(medoidFilePath, 'wb') as fp:
				pickle.dump(medoids, fp)
			print('medoid.pickle saved to', medoidFilePath)
			print(medoids)
			raise SystemExit("Change the config from train to test")

		# NOTE(review): evaluation data is loaded with train_test='train'; confirm
		# whether a held-out split should be used here instead.
		feat_point = project_to_representation(_loadData(train_test='train'), config, pretextModel, device, test_method='medoid')

		acc_list = []
		for i in range(config.taskNum + 1):
			X = feat_point[i]
			if i == config.taskNum and config.RSI_ver == 3:
				# last class under RSI_ver 3: a feature counts as correct when it sums to < 0.5
				count = sum(1 for f in X if np.sum(f) < 0.5)
			else:
				# (N, 1, D) broadcast against medoids (assumed (taskNum + 1, D)) gives one
				# dot product per medoid: (N, taskNum + 1)
				cos_sim = np.sum(X[:, None, :] * medoids, axis=-1)
				count = int(np.count_nonzero(np.argmax(cos_sim, axis=-1) == i))
			acc_list.append(_ratio(count, X.shape[0]))
			print('task id', i, 'has accuracy', acc_list[-1])
		print("Mean accuracy", np.nanmean(acc_list))

	elif config.pretextTestMethod == 'linear':
		# NOTE(review): evaluation data is loaded with train_test='train'; confirm
		# whether a held-out split should be used here instead.
		feat_point = project_to_representation(_loadData(train_test='train'), config, pretextModel, device, test_method='linear')

		def _loadLinear(modelFactory, tag, label):
			"""Build a linear head and load ``<ckpt>_linear_<tag><pretextLinearEpoch - 1>.pt``."""
			weight_path = os.path.join(checkpoint_dir, str(checkpoint_number) + '_linear_' + tag + str(config.pretextLinearEpoch - 1) + '.pt')
			linearModel = modelFactory(config)
			linearModel.load_state_dict(torch.load(weight_path, map_location=device))
			linearModel.to(device).eval()
			print('Load weights for ' + label + ' from', weight_path)
			return linearModel

		def _linearAccuracy(linearModel, feats, label):
			"""Fraction of ``feats`` that ``linearModel`` classifies as ``label``."""
			with torch.no_grad():
				_, predicted_label = torch.max(linearModel(torch.from_numpy(feats).float().to(device)), 1)
			count = int(np.count_nonzero(predicted_label.cpu().numpy() == label))
			return _ratio(count, feats.shape[0])

		linearModelImg = _loadLinear(config.pretextLinearImgModel, 'img', 'linearModelImg')
		linearModelSound = _loadLinear(config.pretextLinearSoundModel, 'sound', 'linearModelSound')

		acc_list_img = []
		acc_list_sound = []
		for i in range(config.taskNum + 1):
			X = feat_point[i]
			acc_list_img.append(_linearAccuracy(linearModelImg, X['img'], i))
			print('task id', i, 'has img accuracy', acc_list_img[-1])

			acc_list_sound.append(_linearAccuracy(linearModelSound, X['sound'], i))
			print('task id', i, 'has sound accuracy', acc_list_sound[-1])
			print()
		print("Mean accuracy for img", np.nanmean(acc_list_img))
		print("Mean accuracy for sound", np.nanmean(acc_list_sound))
		print("Mean accuracy for both", np.nanmean(acc_list_img + acc_list_sound))

	else:
		raise ValueError("Unknown pretextTestMethod %r (expected 'medoid' or 'linear')" % (config.pretextTestMethod,))

if __name__ == '__main__':
	config = main_config()
	gym_register(config)
	device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
	print("Using device:", device)
	cudnn.benchmark = True

	# config, device and pretextModel stay module-level globals on purpose:
	# manuallyCollectPretextData() and trainRepresentation() read them directly.
	# The model is placed on the device selected above and will be trained with triplet loss.
	pretextModel = config.pretextModel(config).to(device).eval()

	if config.pretextCollection:
		if config.pretextManualCollect:
			manuallyCollectPretextData()
		else:
			collectPretextData(config)
		print('Data Collection Complete')

	if config.pretextTrain:  # if we want to train the pretext model from scratch
		if config.pretextModelFineTune:
			weight_path = config.pretextModelLoadDir
			pretextModel.load_state_dict(torch.load(weight_path, map_location=device))
			print('Load weights for pretextModel from', weight_path)

		os.makedirs(config.pretextModelSaveDir, exist_ok=True)
		# keep a copy of the environment config next to the checkpoints
		copyfile(os.path.join('..', 'Envs', config.envFolder, 'RSI2', 'config.py'),
				 os.path.join(config.pretextModelSaveDir, 'config.py'))
		trainRepresentation(model=pretextModel,
							epoch=config.pretextEpoch, lr=config.pretextLR, start_ep=0,
							plot=config.plotRepresentation >= 0)

	if config.pretextTrainLinear:
		# train a linear layer for representation evaluation
		weight_path = config.pretextModelLoadDir
		pretextModel.load_state_dict(torch.load(weight_path, map_location=device))
		print('Load weights for pretextModel from', weight_path)
		trainLinear(model=pretextModel, epoch=config.pretextLinearEpoch, lr=config.pretextLinearLR, start_ep=0, config=config)

	if not (config.pretextTrain or config.pretextCollection or config.pretextTrainLinear):  # test
		weight_path = config.pretextModelLoadDir
		pretextModel.load_state_dict(torch.load(weight_path, map_location=device))
		print('Load weights for pretextModel from', weight_path)

		if config.plotRepresentationOnly:
			# epoch=0 skips training and only draws the final representation plot
			trainRepresentation(model=pretextModel, epoch=0, lr=0, start_ep=0, plot=True)
		else:  # test our representation
			testRepresentation(config, pretextModel)
