import torch
import torch.backends.cudnn as cudnn
import numpy as np
from argument_parser import argument_parser
from datasets.datasets import Datasets
from torch.utils.data.dataloader import DataLoader
import pickle
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import os
from os.path import join, isfile
from utils import model_loader
from metrics.manifold import tsne
from analysis.plot import check_plot_save
from analysis import plot as custom_plot

args = argument_parser()

# Output directory: normalise it to end with '/' and create it when missing.
out_folder = args.out_folder
if not out_folder.endswith("/"):
	out_folder += "/"
if not os.path.exists(os.path.dirname(out_folder)):
	os.makedirs(os.path.dirname(out_folder), exist_ok=True)

# Use the first GPU only when CUDA is present and it was requested via the arguments.
use_cuda = torch.cuda.is_available() and args.cuda
if use_cuda:
	torch.cuda.init()
device = torch.device('cuda', 0) if use_cuda else torch.device('cpu')
ngpu = int(args.ngpu)

# Dataset options.
root, test_root, dataset_path = args.root, args.test_root, args.dataset
debug, n_samples = args.debug, args.n_samples

# Model save/load options.
save, load, model = args.save, args.load, args.model

def print_metrics(data=None, title='Data'):
	"""Print the min, max, mean and standard deviation of ``data``, one line each.

	Args:
		data: Array-like object (e.g. a NumPy array or torch tensor) whose
			``min()``, ``max()``, ``mean()`` and ``std()`` results support
			``.item()``. Defaults to ``np.ones(1)`` when omitted.
		title: Prefix placed in front of every printed line.

	Note:
		The std convention follows the input type: NumPy's ``std`` is the
		population std by default, torch's ``std`` is the sample std.
	"""
	if data is None:
		# Build the default per call instead of sharing one array between calls.
		data = np.ones(1)
	for label, method in (('Min', 'min'), ('Max', 'max'), ('Mean', 'mean'), ('Std', 'std')):
		value = getattr(data, method)().item()
		print(f'{title} {label}: {value:.3f}')

def _save_relevance_histogram(values, file_stem, font_size, weights=None, percent=False):
	"""Draw one histogram of ``values`` on a new figure and save it as PNG and EPS.

	Files are written to ``out_folder`` as ``<file_stem>.png`` and ``<file_stem>.eps``.
	When ``percent`` is true the y-axis is formatted as a percentage of 1.
	"""
	fig, ax = plt.subplots()
	if percent:
		ax.yaxis.set_major_formatter(mtick.PercentFormatter(1))
	# Style this figure's own axes so both histograms get the same tick sizes.
	ax.tick_params(labelsize=font_size)
	ax.hist(values, bins=10, alpha=0.7, rwidth=0.85, weights=weights)
	ax.grid(axis='y', alpha=0.75)
	ax.set_ylabel('Frequency', fontsize=font_size)
	for extension in ('.png', '.eps'):
		fig.savefig(join(out_folder, file_stem + extension), bbox_inches='tight', pad_inches=0)
	plt.close(fig)


def plot_histogram_relevances(relevances=None):
	"""Save a normalised and a raw histogram of relevance values.

	Writes ``<name>_percentage.{png,eps}`` (bar heights sum to 1, y-axis shown in
	percent) and ``<name>_raw.{png,eps}`` (plain counts) into ``out_folder``, where
	``<name>`` is ``args.dataset`` up to its first '.'.

	Args:
		relevances: 1-D sequence of relevance values. Torch tensors are detached
			and moved to CPU first, because matplotlib cannot consume CUDA tensors
			or tensors that require grad.

	Raises:
		TypeError: If ``relevances`` is None.

	Side effect: updates global matplotlib rc font sizes, which stay in effect
	after this call.
	"""
	if relevances is None:
		raise TypeError('plot_histogram_relevances requires relevance values, got None')
	if torch.is_tensor(relevances):
		relevances = relevances.detach().cpu().numpy()
	relevances = np.asarray(relevances).ravel()

	BIGGER_SIZE = 20
	# Set fonts before creating any figure so every text element picks them up.
	plt.rc('font', size=BIGGER_SIZE)          # default text sizes
	plt.rc('axes', titlesize=BIGGER_SIZE)     # axes title
	plt.rc('axes', labelsize=BIGGER_SIZE)     # x and y labels
	plt.rc('xtick', labelsize=BIGGER_SIZE)    # tick labels
	plt.rc('ytick', labelsize=BIGGER_SIZE)    # tick labels
	plt.rc('legend', fontsize=BIGGER_SIZE)    # legend
	plt.rc('figure', titlesize=BIGGER_SIZE)   # figure title

	dataset_name = args.dataset.split('.')[0]
	n_values = len(relevances)
	# Uniform weights make the bars sum to 1; skip them for empty input (no division by zero).
	weights = np.ones(n_values) / n_values if n_values else None

	_save_relevance_histogram(relevances, dataset_name + '_percentage', BIGGER_SIZE, weights=weights, percent=True)
	_save_relevance_histogram(relevances, dataset_name + '_raw', BIGGER_SIZE)

def generate_threshold_relevances_all(model, range_of_selected_nodes=10, below=False, below_str=False, node_control=None):
	"""Decode and save relevance-thresholded prototypes for several activated nodes.

	For each of the first ``range_of_selected_nodes`` activated nodes and each
	threshold 0.1, 0.2, ..., 0.9, the node's prototype is masked by
	``threshold_relevances``. The decoded image is then saved through
	``custom_plot.plot_image`` into ``out_folder``.

	Args:
		model: Model exposing ``som`` (with ``node_control``, ``weights`` and
			``relevance``) and ``decoder``. It is used directly instead of any
			module-level model.
		range_of_selected_nodes: Maximum number of activated nodes to process.
			It is clamped to the number of nodes actually activated, so no index
			past the activated set is requested.
		below: Forwarded to ``threshold_relevances``.
		below_str: Label used in titles and file names. When falsy (the default)
			it is derived from ``below``: 'below' if ``below`` else 'high'.
		node_control: Node activation mask; defaults to ``model.som.node_control``.
	"""
	if node_control is None:
		node_control = model.som.node_control
	if not below_str:
		below_str = 'below' if below else 'high'

	n_active = int(node_control.bool().sum().item())
	n_nodes = min(range_of_selected_nodes, n_active)
	dataset_name = args.dataset.split('.')[0]
	thresholds = np.around(np.arange(0.1, 1, 0.1), decimals=1)

	for selected_node in range(n_nodes):
		for threshold in thresholds:
			_, prototypes_decoded_after = threshold_relevances(model=model, node_control=node_control,
																selected_node=selected_node, threshold=threshold, below=below)
			custom_plot.plot_image(
				title='Clean Features of ' + str(selected_node) + ' with relevances ' + below_str + ' threshold ' + str(threshold),
				data=prototypes_decoded_after,
				save_path=join(out_folder, dataset_name + '_select_node' + str(selected_node) +
								'_threshold' + str(threshold) + '_' + below_str + '.png'),
				figsize=(20, 20), constrained_layout=False, cmap='gray', save=True)

def generate_threshold_relevances_paper(model, node_control, below=False, selected_node=7, thresholds=(0.0, 0.3, 0.4, 0.5)):
	"""Save (and show) the decoded, thresholded prototype of one node at several thresholds.

	For every value in ``thresholds`` the prototype of activated node
	``selected_node`` is masked by ``threshold_relevances`` and decoded. The
	result is saved as a grayscale PNG without axes into ``out_folder`` and then
	shown with ``plt.show()``.

	Args:
		model: Model passed to ``threshold_relevances``.
		node_control: Node activation mask passed to ``threshold_relevances``.
		below: Masking mode passed to ``threshold_relevances``. It also selects the
			file-name label ('below' or 'high'), so names match the masking used.
		selected_node: Index into the activated nodes (default 7).
		thresholds: Relevance thresholds to render.
	"""
	below_str = 'below' if below else 'high'
	dataset_name = args.dataset.split('.')[0]
	for th in thresholds:
		_, prototypes_decoded_after = threshold_relevances(model, node_control, selected_node=selected_node, threshold=th, below=below)

		plt.figure()
		ax1 = plt.subplot(1, 1, 1)
		ax1.imshow(prototypes_decoded_after, cmap='gray')
		plt.axis('off')

		plt.savefig(join(out_folder, dataset_name + '_select_node' + str(selected_node) +
								'_threshold' + str(th) + '_' + below_str + '.png'), bbox_inches='tight', pad_inches=0)
		plt.show()
		plt.close()


def threshold_relevances(model, node_control, selected_node=0, threshold=0.3, below=True, image_shape=(28, 28)):
	"""Decode one activated prototype before and after masking it by relevance.

	The model is switched to eval mode. Rows of ``model.som.weights`` and
	``model.som.relevance`` are restricted to the nodes where ``node_control`` is
	nonzero, and ``selected_node`` indexes into that activated subset, not into
	all nodes.

	Args:
		model: Object with ``eval()``, ``som.weights``, ``som.relevance`` and a
			callable ``decoder``.
		node_control: Node activation mask (converted with ``.bool()``).
		selected_node: Index of the prototype within the activated nodes.
		threshold: Relevance threshold used to build the component mask.
		below: If True, keep components whose relevance is > ``threshold`` (the
			lower-relevance ones are zeroed). If False, keep components whose
			relevance is <= ``threshold``.
		image_shape: Shape each decoded output is reshaped to. The decoder output
			must contain exactly that many elements (784 for the default 28x28).

	Returns:
		Tuple ``(decoded_before, decoded_after)`` of detached CPU tensors with
		shape ``image_shape``. The model's weights are not modified.
	"""
	model.eval()
	active = node_control.bool()

	# Pure inference: no autograd graph, and no in-place writes into intermediates.
	with torch.no_grad():
		prototype = model.som.weights[active][selected_node]
		relevance = model.som.relevance[active][selected_node]

		keep = relevance > threshold if below else relevance <= threshold

		prototypes_decoded_before = model.decoder(prototype)
		prototypes_decoded_after = model.decoder(prototype * keep)

	return (prototypes_decoded_before.view(*image_shape).cpu().detach(),
			prototypes_decoded_after.view(*image_shape).cpu().detach())


if not load or model is None:
	# TODO: support visualising relevances without a pre-trained model.
	print("For now, you must have a pre-trained model beforehand to visualize the relevances.")
	# Non-zero status because the script cannot do its job. Raising SystemExit avoids
	# relying on the site-injected ``exit`` builtin, which is absent under ``python -S``.
	raise SystemExit(1)

combined_model, _, _, _, _, _ = model_loader.load_autoencodersom_model(model, device)

if use_cuda:
	combined_model.cuda()
	cudnn.benchmark = True


dataset = Datasets(dataset=dataset_path, root_folder=root,
				   debug=debug, n_samples=n_samples)

train_loader = DataLoader(dataset.train_data, shuffle=True)
test_loader = DataLoader(dataset.test_data, shuffle=False)

combined_model.eval()

range_of_selected_nodes = 10
below = False
below_str = 'below' if below else 'high'
node_control = combined_model.som.node_control
relevances_activated = combined_model.som.relevance[node_control.bool()]

# Detach and move to CPU: matplotlib cannot plot CUDA tensors or tensors that require grad.
plot_histogram_relevances(relevances_activated.detach().cpu().view(-1))

# Optional analyses, uncomment to run:
# print_metrics(data=relevances_activated, title='Relevances')
# generate_threshold_relevances_all(combined_model, range_of_selected_nodes=range_of_selected_nodes, below=below, below_str=below_str)
# generate_threshold_relevances_paper(combined_model, node_control)