import torch.nn as nn
import torch
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
import pandas as pd
import cv2
import os
import glob
from torch.optim.lr_scheduler import MultiStepLR
import pickle
import matplotlib.font_manager as font_manager
from matplotlib.legend_handler import HandlerTuple
import time


class Flatten(nn.Module):
	"""Collapse every dimension after the first (batch) axis into one, returning a (batch, -1) view."""

	def forward(self, x):
		batch_size = x.size(0)
		return x.view(batch_size, -1)


def get_kMedoids(data_generator, torch_device, model, filter_empty, config, n_clusters=5):
	"""
	Use the K-Medoids algorithm to find centroids without using ground truth labels.

	Sound features are collected from the first ``config.numBatchMedoids`` batches of
	``data_generator`` and clustered with a cosine metric.

	:param data_generator: iterable of 4-tuples ``(img, sp, sn, gt)``; only ``sp`` is used
	:param torch_device: device the model inputs are moved to
	:param model: pretext model; ``model(img, sound, None)[1]`` is taken as the sound feature
	:param filter_empty: if True, drop the medoid with the highest cosine similarity to the
		feature produced for an all-zero image / all-zero sound input
	:param config: must provide ``numBatchMedoids``; ``sound_dim`` and ``img_dim`` are also
		needed when ``filter_empty`` is True
	:param n_clusters: number of medoids to fit (default 5)
	:return: numpy array of medoids, one row per cluster (one row fewer when ``filter_empty``)
	:raises ValueError: if no batch was collected
	"""
	from sklearn_extra.cluster import KMedoids
	X = []
	with torch.no_grad():
		for n, (_img, sp, _sn, _gt) in enumerate(data_generator):
			# use numBatchMedoids * batch_size samples to calculate medoids
			if n >= config.numBatchMedoids:
				break
			features = model(None, sp.to(torch_device), None)
			X.append(features[1].to('cpu').numpy())
	if not X:
		raise ValueError("no batches collected from data_generator; cannot compute medoids")
	X = np.concatenate(X, axis=0)
	kmedoids = KMedoids(n_clusters=n_clusters, random_state=0, metric='cosine', init='k-medoids++').fit(X)
	centers = kmedoids.cluster_centers_
	if not filter_empty:
		return centers

	empty_sound = torch.zeros([1] + list(config.sound_dim))
	empty_img = torch.zeros([1] + list(config.img_dim))
	with torch.no_grad():
		features = model(empty_img.to(torch_device), empty_sound.to(torch_device), None)
		empty_feat = features[1].squeeze().to('cpu').numpy()
	# Normalise so this is a true cosine similarity, consistent with the clustering metric;
	# a raw dot product would favour medoids with large norms.
	eps = np.finfo(np.float64).tiny
	center_norms = np.maximum(np.linalg.norm(centers, axis=1), eps)
	empty_norm = max(float(np.linalg.norm(empty_feat)), eps)
	cos_sim = np.dot(centers, empty_feat) / (center_norms * empty_norm)
	empty_idx = np.argmax(cos_sim)
	return np.delete(centers, empty_idx, 0)


def _unpack_batch(data, rsi_ver):
	"""Return ``(img, sound_positive, ground_truth)`` from one batch of an RSI v2 or v3 dataset."""
	if rsi_ver == 2:
		return data[0], data[1], data[3]  # sound negative is data[2]
	if rsi_ver == 3:
		return data[0][0], data[1][0], data[2]
	raise NotImplementedError


def project_to_representation(data_generator, config, pretextModel, device, test_method='medoid'):
	"""
	Project every batch of ``data_generator`` through ``pretextModel`` and group the resulting
	features by ground-truth task id.

	:param data_generator: iterable of batches laid out as expected for ``config.RSI_ver`` (2 or 3)
	:param config: must provide ``RSI_ver`` and ``taskNum``; task ids ``0..taskNum`` are collected
	:param pretextModel: callable ``pretextModel(img, sound, None)`` returning a dict of feature tensors
	:param device: device the inputs are moved to
	:param test_method: ``'medoid'`` stacks image and sound features (``'image_feat'``,
		``'sound_feat_positive'``) together per task. ``'linear'`` keeps the raw features
		(``'image_feat_raw'``, ``'pos_sound_raw'``) separately under ``'img'`` and ``'sound'``.
	:return: list indexed by task id. For ``'medoid'`` each entry is a numpy array of shape
		``(num_points, representation_dim)``. For ``'linear'`` each entry is a dict
		``{'img': array, 'sound': array}``.
	:raises ValueError: unknown ``test_method`` or ``data_generator`` yielded nothing
	:raises NotImplementedError: unsupported ``config.RSI_ver``
	"""
	num_tasks = config.taskNum + 1
	if test_method == 'medoid':
		img_key, sound_key = 'image_feat', 'sound_feat_positive'
		# [[feature points for task id 0],[feature points for task id 1],[], ...]
		feat_point = [[] for _ in range(num_tasks)]
	elif test_method == 'linear':
		img_key, sound_key = 'image_feat_raw', 'pos_sound_raw'
		# [{'img':img feature points for task id 0, 'sound':sound feature points for task id 0},{},{}, ...]
		feat_point = [{'img': [], 'sound': []} for _ in range(num_tasks)]
	else:
		raise ValueError("test_method must be 'medoid' or 'linear', got %r" % (test_method,))

	# The accumulators are created once, before the loop, so that features from every batch
	# are kept (re-creating them per batch would silently keep only the last batch).
	num_batches = 0
	with torch.no_grad():
		for data in data_generator:
			num_batches += 1
			img, sp, gt = _unpack_batch(data, config.RSI_ver)
			features = pretextModel(img.to(device), sp.float().to(device), None)
			img_feat = features[img_key].to('cpu').numpy()
			sp_feat = features[sound_key].to('cpu').numpy()
			for i in range(num_tasks):
				idx = np.where(gt == i)[0]
				if test_method == 'medoid':
					feat_point[i].append(img_feat[idx])
					feat_point[i].append(sp_feat[idx])
				else:
					feat_point[i]['img'].append(img_feat[idx])
					feat_point[i]['sound'].append(sp_feat[idx])

	if num_batches == 0:
		raise ValueError("data_generator yielded no batches")

	if test_method == 'medoid':
		return [np.concatenate(points, axis=0) for points in feat_point]
	return [{key: np.concatenate(points, axis=0) for key, points in task.items()} for task in feat_point]


def medoids_with_ground_truth(data_generator, torch_device, model, config):
	"""
	Compute one medoid per task, using the ground-truth labels of all data points.

	Useful for analysing the learned representation. Features are grouped by task id via
	``project_to_representation`` (default ``'medoid'`` mode). A single-cluster cosine K-Medoids
	is then fitted to each group.

	:return: numpy array with one medoid row per task id ``0..config.taskNum``
	"""
	from sklearn_extra.cluster import KMedoids
	per_task_points = project_to_representation(data_generator, config, model, torch_device)

	def _single_medoid(points):
		fitted = KMedoids(n_clusters=1, random_state=0, metric='cosine', init='k-medoids++').fit(points)
		return fitted.cluster_centers_[0]

	return np.array([_single_medoid(per_task_points[task_id]) for task_id in range(config.taskNum + 1)])


def drawArrows(ax, fig, v_img, v_sound, quiver_img, quiver_sound):
	"""
	Redraw the image and sound vectors as arrows starting at the origin of a 3-D axes.

	:param ax: axes providing a 3-D ``quiver(x, y, z, u, v, w, ...)``
	:param fig: figure whose canvas is redrawn
	:param v_img: batch of image vectors; only the first three components of ``v_img[0]`` are drawn
	:param v_sound: batch of sound vectors; only the first three components of ``v_sound[0]`` are drawn
	:param quiver_img: image arrow returned by a previous call, or None
	:param quiver_sound: sound arrow returned by a previous call, or None
	:return: ``(quiver_img, quiver_sound)`` - the new arrows, to pass back in on the next call
	"""
	# Remove each stale arrow independently, so that none is left behind when only one
	# of them is passed in.
	for old_quiver in (quiver_img, quiver_sound):
		if old_quiver is not None:
			old_quiver.remove()
	img_vec = v_img[0]
	sound_vec = v_sound[0]
	# Both arrows share colour 'm'; the image arrow is semi-transparent to tell them apart.
	quiver_img = ax.quiver(0., 0., 0., img_vec[0], img_vec[1], img_vec[2], color='m', alpha=.6, lw=3)
	quiver_sound = ax.quiver(0., 0., 0., sound_vec[0], sound_vec[1], sound_vec[2], color='m', alpha=1., lw=3)
	fig.canvas.draw_idle()
	# Run the GUI event loop briefly, presumably so an interactive backend processes the redraw.
	fig.canvas.start_event_loop(0.001)
	return quiver_img, quiver_sound


def get_scheduler(config, optimizer):
	"""
	Build the pretext-training learning-rate scheduler described by ``config``.

	:return: a ``MultiStepLR`` decaying by ``config.pretextLRDecayGamma`` at the epochs in
		``config.pretextLRDecayEpoch`` when ``config.pretextLRStep == "step"``; otherwise None
	"""
	if config.pretextLRStep != "step":
		return None
	milestones = config.pretextLRDecayEpoch
	gamma = config.pretextLRDecayGamma
	return MultiStepLR(optimizer, milestones=milestones, gamma=gamma)


def plotTrainingCurve(csv_paths=None, legends=None, metrics=None, x_key='misc/total_timesteps'):
	"""
	Plot training curves from one or more ``progress.csv`` logs, one figure per metric.

	The average of each plotted metric is printed per log. Calling with no arguments
	reproduces the original hard-coded setup.

	:param csv_paths: CSV files to compare; defaults to two placeholder paths
	:param legends: legend label per entry of ``csv_paths`` (by position); defaults to
		``['Kuka', '']``. Logs without a label fall back to their 1-based index.
	:param metrics: column names to plot, one figure each; defaults to ``['min']``
	:param x_key: column used as the x axis; its last ``/``-separated part is the x label
	"""
	if csv_paths is None:
		# Change the folder directories here, or pass csv_paths explicitly.
		csv_paths = ["path/to/first/progress.csv", "path/to/second/progress.csv"]
	if legends is None:
		legends = ['Kuka', '']
	if metrics is None:
		metrics = ['min']

	logDicts = {key: pd.read_csv(csv_path) for key, csv_path in enumerate(csv_paths, start=1)}

	# for each metric
	for fig_idx, metric in enumerate(metrics):
		plt.figure(fig_idx)
		plt.title(metric)
		legendList = []
		# j indexes the log's position, so legends stay aligned even when a log is skipped
		for j, (key, log) in enumerate(logDicts.items()):
			if metric not in log:
				continue
			plt.plot(log[x_key], log[metric])
			legendList.append(legends[j] if j < len(legends) else str(key))
			print('avg', str(key), metric, np.average(log[metric]))
		print('------------------------')

		plt.xlabel(x_key.split('/')[-1])
		plt.legend(legendList, loc='lower right')

	plt.show()


def get_img_realsense_alone():
	"""
	Stream colour frames from a RealSense camera and show them in an OpenCV window.

	A horizontal line is drawn at y=80 on every frame. The sensor brightness is set to 32
	while streaming and restored to its default afterwards. Press 'q' or Esc in the window
	to stop; Ctrl+C also works.
	"""
	import pyrealsense2 as rs

	# Configure the colour stream (no depth stream is enabled)
	pipeline = rs.pipeline()
	config = rs.config()

	# Resolve the configuration against the attached device before streaming
	# (presumably fails early when no camera is connected - verify against pyrealsense2 docs).
	pipeline_wrapper = rs.pipeline_wrapper(pipeline)
	config.resolve(pipeline_wrapper)

	config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30)

	# Start streaming
	pipeline.start(config)
	try:
		# NOTE(review): sensor index 1 is assumed to be the colour sensor - confirm for the camera model.
		sensor = pipeline.get_active_profile().get_device().query_sensors()[1]
		option_range = sensor.get_option_range(rs.option.brightness)
		sensor.set_option(rs.option.brightness, 32)
		try:
			cv2.namedWindow('RealSense', cv2.WINDOW_AUTOSIZE)
			while True:
				frames = pipeline.wait_for_frames()
				color_frame = frames.get_color_frame()
				if not color_frame:
					continue

				# Convert the colour frame to a numpy array
				color_image = np.asanyarray(color_frame.get_data())

				# Show the image with the reference line
				cv2.line(color_image, (0, 80), (640, 80), (255, 0, 0), thickness=2)
				cv2.imshow('RealSense', color_image)
				key = cv2.waitKey(1) & 0xFF
				if key in (ord('q'), 27):  # 27 is Esc
					break
		finally:
			# Change the brightness option back to default
			sensor.set_option(rs.option.brightness, option_range.default)
	finally:
		# Stop streaming even if sensor setup failed, and close the window
		pipeline.stop()
		cv2.destroyAllWindows()


def convert_pickle_protocol(path, protocol=2):
	"""
	Re-save every ``*.pickle`` file directly inside ``path`` using the given pickle protocol.

	Each file is first written to a temporary file in the same directory and then atomically
	swapped in with ``os.replace``. A failure while dumping therefore cannot leave a truncated
	file behind. The original file's permission bits are copied onto the replacement.

	WARNING: ``pickle.load`` can execute arbitrary code; only run this on trusted files.

	:param path: directory containing the ``.pickle`` files
	:param protocol: pickle protocol to write (default 2, as before)
	"""
	import shutil
	import tempfile
	for filePath in glob.glob(os.path.join(path, '*.pickle')):
		with open(filePath, 'rb') as f:
			x = pickle.load(f)
		fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(filePath) or '.', suffix='.tmp')
		try:
			with os.fdopen(fd, 'wb') as f:
				pickle.dump(x, f, protocol=protocol)
			shutil.copymode(filePath, tmp_path)
			os.replace(tmp_path, filePath)
		except BaseException:
			# Clean up the partial temporary file; the original file is still intact.
			if os.path.exists(tmp_path):
				os.remove(tmp_path)
			raise


def update_RSI2_checkpoint(path):
	"""
	Rename the parameters of an old RSI2 checkpoint so that it loads into the new network.

	The renamed state dict is saved next to the input as ``<name>_RSI3Ver.pt``, using the
	legacy (non-zipfile) serialization. The input file itself is not modified.

	:param path: the path to the checkpoint
	:return: None
	"""
	state = torch.load(path)  # an OrderedDict
	# old RSI2 parameter name -> new parameter name (for ithor env)
	renames = zip(
		['imgBranch.17.weight', 'imgBranch.17.bias', 'imgBranch.19.weight', 'imgBranch.19.bias', 'fc.0.weight',
		 'fc.0.bias', 'fc.2.weight', 'fc.2.bias', 'fc.4.weight', 'fc.4.bias'],
		['imgTriplet.0.weight', 'imgTriplet.0.bias', 'imgTriplet.2.weight', 'imgTriplet.2.bias',
		 'soundTriplet.0.weight', 'soundTriplet.0.bias', 'soundTriplet.2.weight',
		 'soundTriplet.2.bias', 'soundTriplet.4.weight', 'soundTriplet.4.bias'],
	)
	for old_name, new_name in renames:
		state[new_name] = state.pop(old_name)

	stem = os.path.splitext(os.path.basename(path))[0]
	out_path = os.path.join(os.path.dirname(path), stem + '_RSI3Ver.pt')
	torch.save(state, out_path, _use_new_zipfile_serialization=False)


def removeSoundFromDataset(path):
	"""
	Remove sound data from every pickled dataset file in ``path``.

	Converts the old dataset format (image, sound+, sound-, gt) to the new dataset format
	(image, gt). The ``'sound_positive'`` and ``'sound_negative'`` entries are dropped from each
	item, and task id 4 (empty) becomes task id 5. Each ``<name>.pickle`` is written to
	``imgGt_<name>.pickle`` in the same directory; the input file is left untouched. Files
	already carrying the ``imgGt_`` prefix are skipped, so re-running on the same directory
	does not reprocess (and crash on) earlier output.

	Prints the number of items per task id, counted with the ids before renumbering.

	WARNING: ``pickle.load`` can execute arbitrary code; only run this on trusted files.
	"""
	out_prefix = 'imgGt_'
	dataNum = [0, 0, 0, 0, 0]
	for filePath in glob.glob(os.path.join(path, '*.pickle')):
		if os.path.basename(filePath).startswith(out_prefix):
			continue  # output of a previous run

		print("Processing", filePath)
		with open(filePath, 'rb') as f:
			data = pickle.load(f)

		for item in data:
			dataNum[int(item['ground_truth'])] += 1
			item.pop('sound_positive')
			item.pop('sound_negative')
			if item['ground_truth'] == 4:
				item['ground_truth'] = 5
		stem = os.path.splitext(os.path.basename(filePath))[0]
		new_filepath = os.path.join(os.path.dirname(filePath), out_prefix + stem + '.pickle')
		with open(new_filepath, 'wb') as ff:
			pickle.dump(data, ff)

	print("Dataset contains", dataNum)


def changeIndexFromDataset(path):
	"""
	Renumber task id 4 (empty) to task id 5 in every pickled dataset file in ``path``.

	Each ``<name>.pickle`` is written to ``index_<name>.pickle`` in the same directory; the
	input file is left untouched. Files already carrying the ``index_`` prefix are skipped,
	so re-running on the same directory does not reprocess (and crash on) earlier output.

	Prints the number of items per task id, counted with the ids before renumbering.

	WARNING: ``pickle.load`` can execute arbitrary code; only run this on trusted files.
	"""
	out_prefix = 'index_'
	dataNum = [0, 0, 0, 0, 0]
	for filePath in glob.glob(os.path.join(path, '*.pickle')):
		if os.path.basename(filePath).startswith(out_prefix):
			continue  # output of a previous run

		print("Processing", filePath)
		with open(filePath, 'rb') as f:
			data = pickle.load(f)

		for item in data:
			dataNum[int(item['ground_truth'])] += 1
			if item['ground_truth'] == 4:
				item['ground_truth'] = 5
		stem = os.path.splitext(os.path.basename(filePath))[0]
		new_filepath = os.path.join(os.path.dirname(filePath), out_prefix + stem + '.pickle')
		with open(new_filepath, 'wb') as ff:
			pickle.dump(data, ff)

	print("Dataset contains", dataNum)


def plot_Kinova_sim2real_trajectory(epNum):
	"""
	Scatter-plot the recorded x/y positions of one Kinova episode, simulation versus real robot.

	Reads ``position.csv`` and ``desired_position.csv`` from ``data/episodeRecord/simEp/ep<epNum>``
	and ``position.csv`` from ``data/episodeRecord/realEp/ep<epNum>``. Columns 0 and 1 are plotted
	as x and y (axes labelled in metres).
	"""
	record_root = os.path.join('data', 'episodeRecord')
	episode_dir = 'ep' + str(epNum)
	sim_dir = os.path.join(record_root, 'simEp', episode_dir)
	real_dir = os.path.join(record_root, 'realEp', episode_dir)

	plt.figure(figsize=[7, 7])
	# Read every trajectory first, then draw them in a fixed order (sim, desired, real).
	series = [
		(pd.read_csv(os.path.join(sim_dir, "position.csv")).values, 'Sim position'),
		(pd.read_csv(os.path.join(sim_dir, "desired_position.csv")).values, 'Desired position'),
		(pd.read_csv(os.path.join(real_dir, "position.csv")).values, 'Real position'),
	]
	for points, label in series:
		plt.scatter(points[:, 0], points[:, 1], label=label)
	plt.xlabel("x[m]")
	plt.ylabel("y[m]")
	plt.title("Sim & Real position")
	plt.legend()
	plt.show()


def listen_audioAugmentation(tsk, Task):
	"""
	Play 20 audio clips drawn via ``audioLoader.getAudioFromTask`` for the given task, 3 s apart.

	Presumably used to listen to the audio augmentation by ear; ``sd.play`` does not block, so
	the fixed sleep is what spaces the clips. Element ``[1]`` of each returned sample is played
	at the loader's sample rate ``fs``.
	"""
	from Envs.audioLoader import audioLoader
	from Envs.ai2thor.RSI3.config import AI2ThorConfig
	import sounddevice as sd

	loader = audioLoader(config=AI2ThorConfig())
	loader.loadData()

	for _ in range(20):
		sample = loader.getAudioFromTask(np.random, tsk, Task)
		sd.play(sample[1], loader.fs)
		time.sleep(3.)
