from get_data import *
import time
import matplotlib.pyplot as plt
import numpy as np
import warnings
import os
from datetime import datetime
from sklearn.utils import shuffle
import json
warnings.filterwarnings("ignore")


from RFMS import *

from baseline_MS import *
from baseline_MSPP import *


if __name__ == '__main__':
	# ---- Experiment configuration -------------------------------------------
	# Datasets to benchmark, the clustering runners to compare, and how many
	# times each (dataset, method) pair is repeated.
	datasets = ['WirelessLocalization','UserKnowledge','iris','WallRobot']
	methods = [run_RFMS,run_MS,run_MSPP]
	n_trail = 10

	# ---- Result tables ------------------------------------------------------
	# Every table has the shape {dataset: {method name: value}}; each cell
	# starts as None and is filled in once the method has been run on the
	# dataset.
	def _blank_table():
		"""Return a fresh {dataset: {method name: None}} table."""
		return {
			dataset_name: {runner.__name__: None for runner in methods}
			for dataset_name in datasets
		}

	times, time_stds = _blank_table(), _blank_table()
	NMIs, NMI_stds = _blank_table(), _blank_table()
	AMIs, AMI_stds = _blank_table(), _blank_table()
	ARSs, ARS_stds = _blank_table(), _blank_table()
	RSs, RS_stds = _blank_table(), _blank_table()

	# ---- Run every method on every dataset ----------------------------------
	for dataset in datasets:

		# Flag passed as the second argument to load_data; it is True for
		# every dataset named in this list.
		norm = dataset in ['UserKnowledge','iris','WallRobot','WirelessLocalization']

		X, y,count = load_data(dataset,norm)
		# Shuffle once per dataset; all methods and trials see the same order.
		X, y = shuffle(X, y)
		print("# Class: ",count)
		print(X.shape)
		assert (len(X) == len(y))

		for method in methods:
			method_name = method.__name__
			print("===================================================",dataset,"-",method_name)

			# Per-trial samples for this (dataset, method) pair.
			_time, _NMI, _AMI, _ARS, _RS = [], [], [], [], []

			for trail in range(n_trail):
				print("===================",trail)

				try:
					# Only the method call itself is timed.
					start_time = time.time()
					NMI,AMI,ARS,RS = method(X, y,count,dataset)
					duration = time.time() - start_time
				except KeyboardInterrupt:
					# Ctrl-C during a run ends the whole script.
					exit()

				print(method_name,"-",dataset)
				print(NMI,"-",AMI,"-",ARS,"-",RS)
				_time.append(duration)
				_NMI.append(NMI)
				_AMI.append(AMI)
				_ARS.append(ARS)
				_RS.append(RS)

			# Store mean and standard deviation over the trials, rounded to
			# three decimals, in the matching result tables.
			for samples, mean_table, std_table in (
				(_time, times, time_stds),
				(_NMI, NMIs, NMI_stds),
				(_AMI, AMIs, AMI_stds),
				(_ARS, ARSs, ARS_stds),
				(_RS, RSs, RS_stds),
			):
				mean_table[dataset][method_name] = round(np.mean(samples),3)
				std_table[dataset][method_name] = round(np.std(samples),3)

	print("===========================End of Exp===========================")

	# ---- Persist the aggregated results -------------------------------------
	# One timestamped directory per run. makedirs also creates the parent
	# "results" directory when it does not exist yet (os.mkdir would raise
	# FileNotFoundError in that case).
	result_path = os.path.join("results", datetime.now().strftime("%m_%d_%y_%H_%M_%S"))
	os.makedirs(result_path, exist_ok=True)

	# Each metric gets its own file holding a single JSON object:
	#     {"mean": {dataset: {method: value}}, "std": {dataset: {method: value}}}
	# Writing both tables as one object keeps the file parseable by json.load;
	# two consecutive json.dump calls would concatenate two documents.
	# os.path.join is used so the files land inside result_path on every OS
	# (a literal backslash is part of the file name on POSIX systems).
	for file_name, mean_table, std_table in (
		("time.json", times, time_stds),
		("NMI.json", NMIs, NMI_stds),
		("AMI.json", AMIs, AMI_stds),
		("ARS.json", ARSs, ARS_stds),
		("RS.json", RSs, RS_stds),
	):
		with open(os.path.join(result_path, file_name), "w") as file:
			json.dump({"mean": mean_table, "std": std_table}, file, indent=4)


	# ---- Plot one bar chart per (dataset, metric) ---------------------------
	# Maps the graph prefix used in titles and file names to the tables that
	# hold the per-method means and standard deviations of that metric.
	metric_tables = {
		'time_': (times, time_stds),
		'NMI_': (NMIs, NMI_stds),
		'AMI_': (AMIs, AMI_stds),
		'ARS_': (ARSs, ARS_stds),
		'RS_': (RSs, RS_stds),
	}
	# Bar labels are the method function names, in the order of `methods`.
	labels = [method.__name__ for method in methods]

	for dataset in datasets:
		for graph, (temp_means, temp_stds) in metric_tables.items():
			means = [temp_means[dataset][label] for label in labels]
			stds = [temp_stds[dataset][label] for label in labels]

			print('=======================')
			print('++++',dataset,'-',graph)
			print(labels)
			print(means)
			print(stds)
			print('=======================')

			# Bars show the mean; error bars show the standard deviation.
			plt.bar(labels, means, yerr=stds, capsize=5)
			xlocs, xlabs = plt.xticks()
			# Write each mean slightly above and to the left of its bar.
			for i, v in enumerate(means):
				plt.text(xlocs[i] - 0.25, v + 0.01, str(v))

			plt.xlabel('Method')
			plt.ylabel('Value')
			plt.title(graph + dataset)
			# os.path.join keeps the image inside result_path on every OS; a
			# hard-coded backslash is only a path separator on Windows.
			plt.savefig(os.path.join(result_path, dataset + '_' + graph + ".png"))
			# Clear the figure so the next chart starts empty.
			plt.clf()









