from tqdm import trange
import copy
import pickle
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import numpy as np
import os

import time
from utils import get_scheduler, medoids_with_ground_truth
from dataset import loadEnvData
from shutil import copyfile
import torch.optim as optim
from cfg import main_config, gym_register
import gym
import pandas as pd
from RSI2.pretext_RSI2 import manuallyCollectPretextData, collectPretextData

from models.supConLoss import SupConLoss
import matplotlib.pyplot as plt
import cv2
import glob


def plotRepresentationRSI3(generator, torch_device, net, config, **kwargs):
	"""Scatter-plot image and sound embeddings produced by ``net``, coloured by label.

	Embeds the first view of batches 0..``config.plotNumBatch`` (inclusive) from
	``generator``. Image features are drawn as circles and sound features as
	triangles, coloured by ground-truth index (indices 0-5 are drawn, other labels
	are skipped). Optionally overlays the images found in ``config.plotExtraPath``
	(black stars annotated with their index) and the ``medoids`` keyword argument
	(large black stars).

	Args:
		generator: iterable yielding ``(im, sp, gt)``; ``im`` and ``sp`` are
			sequences of augmented views and only ``im[0]`` / ``sp[0]`` are used.
		torch_device: device on which ``net`` runs.
		net: model called as ``net(image, sound, None)`` whose first two outputs
			are the image and sound features.
		config: experiment config; ``representationDim`` must be < 4.
		**kwargs: ``medoids`` -- optional array with one cluster centre per row.

	Returns:
		The matplotlib ``(fig, ax)`` pair (returned after ``plt.show()``).
	"""
	assert config.representationDim < 4
	fig = plt.figure()
	if config.representationDim == 3:  # 3d scatter plot
		ax = fig.add_subplot(111, projection='3d')
		ax.set_zlabel('Z Label')
	else:
		ax = fig.add_subplot(111)

	ax.set_xlabel('X Label')
	ax.set_ylabel('Y Label')
	colors = ['r', 'y', 'b', 'g', 'tab:purple', 'c']

	with torch.no_grad():
		for n, (im, sp, gt) in enumerate(generator):
			if n > config.plotNumBatch:
				break  # plot batches 0..config.plotNumBatch (inclusive)

			# the generator yields lists of augmented views; only the first view is plotted
			im = im[0]
			sp = sp[0]
			annotate = n == config.plotNumBatch and config.annotateLastBatch

			if annotate:
				# save the last batch's images so the point IDs annotated on the plot can be looked up
				os.makedirs(config.episodeImgSaveDir, exist_ok=True)
				for j, pic in enumerate(im):
					pic = np.transpose((pic.numpy() * 255).astype(np.uint8), (1, 2, 0))
					imgSave = cv2.resize(pic, (config.episodeImgSize[1], config.episodeImgSize[0]))
					if config.episodeImgSize[2] == 3:
						imgSave = cv2.cvtColor(imgSave, cv2.COLOR_RGB2BGR)
					fileName = 'lastBatch' + str(j) + '.jpg'
					cv2.imwrite(os.path.join(config.episodeImgSaveDir, fileName), imgSave)

			image_feat, sp_feat, _, _, _ = net(im.to(torch_device), sp.float().to(torch_device), None)
			image_feat, sp_feat = image_feat.to('cpu'), sp_feat.to('cpu')
			for j, color in enumerate(colors):
				idx = np.where(gt == j)[0]
				if idx.size == 0:
					continue
				np_img_feat = image_feat[idx].numpy()
				np_sp_feat = sp_feat[idx].numpy()
				if np_img_feat.shape[1] == 2:
					ax.scatter(np_img_feat[:, 0], np_img_feat[:, 1], marker='o', color=color)
					ax.scatter(np_sp_feat[:, 0], np_sp_feat[:, 1], marker='v', color=color)
					if annotate:
						# label each image point with its index within the saved batch
						for k, txt in enumerate(idx):
							ax.annotate(str(txt), (np_img_feat[k, 0], np_img_feat[k, 1]))
				else:
					ax.scatter(np_img_feat[:, 0], np_img_feat[:, 1], np_img_feat[:, 2], marker='o', color=color, s=20, alpha=0.2)
					ax.scatter(np_sp_feat[:, 0], np_sp_feat[:, 1], np_sp_feat[:, 2], marker='v', color=color, s=20, alpha=0.2)
					if annotate:
						# label each image point with its index within the saved batch
						for k, txt in enumerate(idx):
							ax.text(np_img_feat[k, 0], np_img_feat[k, 1], np_img_feat[k, 2], str(txt))

		if config.plotRepresentationExtra:
			# load extra images named <int>.jpg, in numeric order
			imageList = []
			fileList = [os.path.basename(x) for x in glob.glob(os.path.join(config.plotExtraPath, '*.jpg'))]
			for filePath in sorted(fileList, key=lambda x: int(x[0:-4])):
				image_in = cv2.cvtColor(cv2.imread(os.path.join(config.plotExtraPath, filePath)), cv2.COLOR_BGR2RGB)
				image_in = cv2.resize(image_in, (config.img_dim[2], config.img_dim[1]))
				imageList.append(np.transpose(image_in, (2, 0, 1)))

			if imageList:
				img = torch.from_numpy(np.array(imageList))
				# dummy all-zero sound inputs: only the image features are plotted
				sp = torch.zeros((len(imageList), config.sound_dim[0], config.sound_dim[1], config.sound_dim[2]))
				sn = torch.zeros((len(imageList), config.sound_dim[0], config.sound_dim[1], config.sound_dim[2]))
				features = net((img / 255.).float().to(torch_device),
							   sp.to(torch_device),
							   sn.to(torch_device))
				np_img_feat = features[0].to('cpu').numpy()
				imageStr = np.arange(0, len(imageList))
				if np_img_feat.shape[1] == 2:
					ax.scatter(np_img_feat[:, 0], np_img_feat[:, 1], marker='*', color='k', s=20)
					for j, txt in enumerate(imageStr):
						ax.annotate(str(txt), (np_img_feat[j, 0], np_img_feat[j, 1]))
				else:
					ax.scatter(np_img_feat[:, 0], np_img_feat[:, 1], np_img_feat[:, 2], marker='*', color='k', s=20)
					for j, txt in enumerate(imageStr):
						ax.text(np_img_feat[j, 0], np_img_feat[j, 1], np_img_feat[j, 2], str(txt))

		if config.representationDim == 3:
			# draw coordinate axes and a unit-sphere wireframe as visual reference
			ax.plot([-1, 1], [0, 0], [0, 0], color="k", alpha=0.2, linewidth=1)
			ax.plot([0, 0], [-1, 1], [0, 0], color="k", alpha=0.2, linewidth=1)
			ax.plot([0, 0], [0, 0], [-1, 1], color="k", alpha=0.2, linewidth=1)
			ax.set_axis_off()
			u, v = np.mgrid[0:2 * np.pi:40j, 0:np.pi:20j]
			x = np.cos(u) * np.sin(v)
			y = np.sin(u) * np.sin(v)
			z = np.cos(v)
			ax.plot_wireframe(x, y, z, color="lightgray", alpha=0.2, linewidth=1)

		if 'medoids' in kwargs:
			medoids = kwargs['medoids']
			if config.representationDim == 3:
				ax.scatter(medoids[:, 0], medoids[:, 1], medoids[:, 2], marker='*', color='k', s=800)
			else:
				# on a 2-D Axes the 3rd positional scatter argument is the marker size, so pass only x, y
				ax.scatter(medoids[:, 0], medoids[:, 1], marker='*', color='k', s=800)

		plt.show()
		return fig, ax


def trainRepresentation(model, epoch, lr, start_ep=0, plot=False):
	"""Train the pretext (representation) model with a supervised contrastive loss.

	Reads the module-level ``config`` and ``device`` globals, which are set in
	the ``__main__`` section of this script.

	Each batch from the data loader is ``(image_views, sound_views, gt)``, where
	the first two are lists of augmented views. All views are embedded in one
	forward pass and regrouped into a ``(batch, n_views, dim)`` tensor for
	``SupConLoss``. When ``config.pretextEmptyCenter`` is set, the model's
	image/sound norm outputs are also trained with ``BCEWithLogitsLoss``
	towards 0 for ``gt == config.taskNum`` and 1 otherwise.

	Args:
		model: pretext model, called as ``model(image, sound, None, is_train=True)``
			and unpacked as ``(image_feat, sound_feat, _, image_norm, sound_norm)``.
		epoch: number of epochs to run (0 skips the training loop).
		lr: Adam learning rate.
		start_ep: offset added to the epoch index in checkpoint file names.
		plot: enable periodic plots (every ``config.plotRepresentation`` epochs
			when that is > 0) and the final plot (when it is >= 0).
	"""
	print('Begin representation training')
	# load data
	data_generator, ds = loadEnvData(data_dir=config.pretextDataDir,
									 config=config,
									 batch_size=config.pretextTrainBatchSize,
									 shuffle=True,
									 num_workers=config.pretextDataNumWorkers,  # change it to 0 if hangs
									 drop_last=True,
									 loadNum=config.pretextDataFileLoadNum,
									 dtype=config.pretextDataset)

	os.makedirs(config.pretextModelSaveDir, exist_ok=True)

	model.train()

	optimizer = optim.Adam(filter(lambda parameters: parameters.requires_grad, model.parameters()),
						   lr=lr,
						   weight_decay=config.pretextAdamL2)

	scheduler = get_scheduler(config, optimizer)
	criterion = SupConLoss(temperature=config.representationTau)
	norm_criterion = torch.nn.BCEWithLogitsLoss()

	loss_saved = []

	# plotRepresentation == 0 still enables the final plot below, but must never be used as a modulus
	periodic_plot = plot and config.plotRepresentation > 0

	# main training loop
	for ep in trange(epoch, position=0):
		if periodic_plot and ep > 0 and ep % config.plotRepresentation == 0:
			model.eval()
			plotRepresentationRSI3(data_generator, device, model, config)
			model.train()

		loss_ep_supCon = []
		loss_ep_img_norm = []
		loss_ep_sound_norm = []

		for image, sound_positive, gt in data_generator:
			# image and sound_positive are lists of augmented views
			optimizer.zero_grad()
			imgViewNum = len(image)
			soundViewNum = len(sound_positive)
			# embed every view in one forward pass; rows are ordered view by view
			image = torch.cat(image, dim=0)
			sound_positive = torch.cat(sound_positive, dim=0)

			image_feat, sound_feat_positive, _, image_norm, sound_norm = \
				model(image.to(device), sound_positive.float().to(device), None, is_train=True)

			# undo the view-wise concatenation: one (batch, dim) chunk per view, then
			# stack them into feat of shape (batch, n_img_views + n_sound_views, dim)
			feat_list = torch.chunk(image_feat, imgViewNum, dim=0) + \
				torch.chunk(sound_feat_positive, soundViewNum, dim=0)
			feat = torch.stack(feat_list, dim=1)

			loss_supCon = criterion(feat, labels=gt if config.representationUseLabels else None)

			if config.pretextEmptyCenter:
				# BCE targets: 0 for the empty class (gt == taskNum), 1 otherwise,
				# repeated once per view to line up with the concatenated batch
				gt_norm = (gt != config.taskNum).float().to(device)
				loss_img_norm = norm_criterion(image_norm, torch.cat([gt_norm] * imgViewNum, dim=0))
				loss_sound_norm = norm_criterion(sound_norm, torch.cat([gt_norm] * soundViewNum, dim=0))
				loss = 4.0 * loss_supCon + 0.3 * loss_img_norm + 0.3 * loss_sound_norm
				loss_ep_img_norm.append(loss_img_norm.item())
				loss_ep_sound_norm.append(loss_sound_norm.item())
			else:
				loss = 4.0 * loss_supCon
			loss.backward()
			optimizer.step()

			loss_ep_supCon.append(loss_supCon.item())

		if config.pretextLRStep == "step":
			scheduler.step()

		if (ep + 1) % config.pretextModelSaveInterval == 0 or ep + 1 == epoch:
			fname = os.path.join(config.pretextModelSaveDir, str(start_ep + ep) + '.pt')
			torch.save(model.state_dict(), fname, _use_new_zipfile_serialization=False)
			print('Model saved to ' + fname)

		avg_supCon_loss = np.mean(loss_ep_supCon)
		loss_saved.append(avg_supCon_loss)

		if config.pretextEmptyCenter:
			avg_img_norm = np.mean(loss_ep_img_norm)
			avg_sound_norm = np.mean(loss_ep_sound_norm)
			print('average loss', avg_img_norm, avg_supCon_loss, avg_sound_norm)
		else:
			print('average loss', avg_supCon_loss)

	if config.pretextTrain:
		df = pd.DataFrame({'avg_loss': loss_saved})
		save_path = os.path.join(config.pretextModelSaveDir, 'progress.csv')
		df.to_csv(save_path, mode='w', header=True, index=False)
		print('results saved to', save_path)
	print('Pretext Training Complete')
	model.eval()
	if config.plotRepresentation >= 0 and plot:
		plotRepresentationRSI3(data_generator, device, model, config)

if __name__ == '__main__':
	# `config` and `device` are deliberately module-level: trainRepresentation reads them as globals.
	config = main_config()
	gym_register(config)
	device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
	print("Using device:", device)
	cudnn.benchmark = True
	torch.manual_seed(config.pretextEnvSeed)
	torch.cuda.manual_seed_all(config.pretextEnvSeed)

	# trained with SupConLoss in trainRepresentation; placed on `device` rather than forced onto CUDA
	pretextModel = config.pretextModel(config).to(device).eval()

	if config.pretextCollection:
		if config.pretextManualCollect:
			manuallyCollectPretextData()
		else:
			collectPretextData(config)
		print('Data Collection Complete')

	if config.pretextTrain:  # train the pretext model (optionally starting from saved weights)
		if config.pretextModelFineTune:
			weight_path = config.pretextModelLoadDir
			# map_location lets a checkpoint saved on GPU load on a CPU-only machine
			pretextModel.load_state_dict(torch.load(weight_path, map_location=device))
			print('Load weights for pretextModel from', weight_path)

		os.makedirs(config.pretextModelSaveDir, exist_ok=True)
		# keep a copy of the environment config next to the checkpoints
		copyfile(os.path.join('..', 'Envs', config.envFolder, 'RSI3', 'config.py'),
				 os.path.join(config.pretextModelSaveDir, 'config.py'))
		trainRepresentation(model=pretextModel,
							epoch=config.pretextEpoch, lr=config.pretextLR, start_ep=0,
							plot=config.plotRepresentation >= 0)

	elif not config.pretextCollection:  # neither training nor collecting: load the trained pretext model
		weight_path = config.pretextModelLoadDir
		pretextModel.load_state_dict(torch.load(weight_path, map_location=device))
		print('Load weights for pretextModel from', weight_path)

		if config.plotRepresentationOnly:
			# epoch=0 skips the training loop; only the final plot is produced
			# (when config.plotRepresentation >= 0)
			trainRepresentation(model=pretextModel, epoch=0, lr=0, start_ep=0, plot=True)
		else:  # test our representation
			# Speech-recognition test: each sound sample is assigned to the medoid with the
			# highest similarity; accuracy is the fraction assigned to its own task's medoid.
			from Envs.audioLoader import audioLoader

			data_generator, ds = loadEnvData(data_dir=config.pretextDataDir,
											 config=config,
											 batch_size=config.pretextTrainBatchSize,
											 shuffle=True,
											 num_workers=config.pretextDataNumWorkers,  # change it to 0 if multiprocessing error
											 drop_last=True,
											 loadNum=config.pretextDataFileLoadNum,
											 dtype=config.pretextDataset)

			medoids = medoids_with_ground_truth(data_generator=data_generator,
												torch_device=device,
												model=pretextModel,
												config=config)
			print(medoids)

			audio = audioLoader(config=config)
			audio.loadData()

			task_audio = [audio.words['none']['lamp']['activate'],
						  audio.words['none']['lamp']['deactivate'],
						  audio.words['none']['music']['activate'],
						  audio.words['none']['music']['deactivate']]
			mfcc_param = audio.param_dict[config.soundSource['dataset']]
			for i in range(config.taskNum):
				w = task_audio[i]
				mfcc_feat = [audio.get_mfcc(sound, param=mfcc_param) for sound in w]
				x = torch.from_numpy(np.array(mfcc_feat))
				with torch.no_grad():
					features = pretextModel(None, x.float().to(device), None)
					sp_feat = features[1].to('cpu').numpy()
				# (samples, medoids) dot-product similarity via broadcasting; this equals cosine
				# similarity only if features and medoids are unit-norm -- assumed, not checked here
				cos_sim = np.sum(sp_feat[:, None, :] * medoids, axis=-1)

				idx = np.argmax(cos_sim, axis=-1)
				count = len(np.where(idx == i)[0])
				print('task id', i, 'has accuracy', count / len(w))
