#coding:utf-8

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from sklearn.decomposition import PCA
from scipy.cluster.hierarchy import linkage, dendrogram
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import auc
import matplotlib

def plot_lambda(result_fnames, names, fname, select=("MAX",)):
	"""Plot each summary metric against lambda (one panel per metric) and print the AUCs.

	Parameters
	----------
	result_fnames : dict
		Maps a lambda value to the path of a CSV summary file. Each file is read
		with its first column as index; rows are looked up by metric name
		("interbatch_div", "intrabatch_div", "pos_div", "precision", "relevance"),
		columns are named "<algorithm>_<MODE>" and cells are strings of the
		form "mean +-std".
	names : list of str
		Algorithm names to plot.
	fname : str
		Path of the saved figure.
	select : iterable of str
		Modes to plot, among "MAX" and "SAMPLE".

	Raises
	------
	ValueError
		If select contains a mode that is neither "MAX" nor "SAMPLE".
	"""
	markers = {"MAX": "o", "SAMPLE": "v"}
	## Validate against the modes that actually have a marker (and a column suffix)
	unknown = [s for s in select if s not in markers]
	if unknown:
		raise ValueError(f"Unknown selection mode(s) {unknown}; expected a subset of {sorted(markers)}")
	fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(8,5))
	palette = ["b","c","g","k","m","r","y"]
	## Cycle through the palette so that more than len(palette) names do not crash
	colors = {n: palette[inn % len(palette)] for inn, n in enumerate(names)}
	lambdas = sorted(result_fnames)
	order_axes = {(0,0): "interbatch_div", (0,2): "intrabatch_div", 
		(1,0): "pos_div", (1,1): "precision", (1,2): "relevance"}
	names_axes = {(0,0): "EDV", (0,2): "ADV", 
		(1,0): "PDV", (1,1): "PREC", (1,2): "REL"}
	names_algos = {"QDDecomposition":"QDDecomp.","ConditionalDPP":"CondDPP",
		"EpsGreedy": r"$\varepsilon$-Greedy","Fabaphe": "Pabrah", "kMarkovDPP": "MarkovDPP","MMR":"MMR"}
	axes[0,1].axis("off")
	AUCs = {}
	fs = 15
	## Read every summary file once instead of once per panel
	frames = {}
	for lbd in lambdas:
		df = pd.read_csv(result_fnames[lbd], index_col=0)
		kept = [col for col in df.columns if ((col.split("_MAX")[0] in names) or (col.split("_SAMPLE")[0] in names))]
		frames[lbd] = df[kept]
	for (i, j), metric in order_axes.items():
		ax = axes[i,j]
		## key "<name>_<mode>" -> values collected across lambdas
		series = {}
		for lbd in lambdas:
			row = frames[lbd].loc[metric]
			for n in names:
				for s in select:
					key = n+"_"+s
					if (key not in row.index):
						continue
					vals = [float(x) for x in row.loc[key].split(" +-")]
					entry = series.setdefault(key, {"name": n, "mode": s, "x": [], "mean": [], "std": []})
					## Keep the lambda next to the value so that a series which is
					## missing from some files is still plotted against its own x values
					entry["x"].append(lbd)
					entry["mean"].append(vals[0])
					entry["std"].append(vals[1])
		AUC = {}
		for inn, (key, entry) in enumerate(series.items()):
			color = colors[entry["name"]]
			fmt = f"{color}{'-' if (inn%2==0) else '--'}{markers[entry['mode']]}"
			if (i==0 and j==0):
				## Only the first panel carries the labels used by the legend
				ax.plot(entry["x"], entry["mean"], fmt, alpha=0.5, label=names_algos.get(entry["name"], entry["name"]))
			else:
				ax.plot(entry["x"], entry["mean"], fmt, alpha=0.5)
			means = np.array(entry["mean"])
			stds = np.array(entry["std"])
			maxvals = means+stds
			## The lower bound of the band is clipped at 0
			minvals = np.maximum(means-stds, 0)
			AUC.update({key : auc(entry["x"], entry["mean"])})
			## Same color as the line, otherwise the band takes the next color of the default cycle
			ax.fill_between(entry["x"], list(minvals), list(maxvals), alpha=0.1, color=color)
		if (i==0 and j==0):
			ax.legend(bbox_to_anchor=(1.1, 1.), fontsize=fs)
		ax.set_xlabel(r"$\lambda$", fontsize=fs)
		ax.set_ylabel(names_axes[(i,j)], fontsize=fs)
		## Freeze the tick positions before re-setting the labels with the target font size
		ax.set_yticks(ax.get_yticks())
		ax.set_xticks(ax.get_xticks())
		ax.set_xlim((0,1))
		ax.set_yticklabels(ax.get_yticklabels(), fontsize=fs)
		ax.set_xticklabels(ax.get_xticklabels(), fontsize=fs)
		AUCs.update({names_axes[(i,j)] : AUC})
	print(pd.DataFrame(AUCs))
	plt.subplots_adjust(wspace=0.7, hspace=0.7)
	plt.savefig(fname, bbox_inches="tight")
	plt.close()
	
if __name__ == "__main__":
	from glob import glob
	from pathlib import Path
	fnames = glob("../results/SYNTHETIC250_lambda*/final_summary.csv")
	## The lambda value is encoded in the name of the folder holding each summary file;
	## read it from the parent folder name rather than splitting on a hard-coded "/"
	result_fnames = {float(Path(fn).parent.name.split("_lambda")[-1]): fn for fn in fnames}
	names = ["QDDecomposition","ConditionalDPP","EpsGreedy","Fabaphe","kMarkovDPP","MMR"]
	#names = ["QDDecomposition","ConditionalDPP","EpsGreedy","Fabaphe","MMR"]
	plot_lambda(result_fnames, names, "../results/plot_SYNTHETIC250_lambda.png")
	## Stop here: when this file is run as a script, nothing below this block is executed.
	## SystemExit is raised directly because exit() is only injected by the site module.
	raise SystemExit

def plot_regret(results, names, fname):
	"""Draw one cumulative-regret curve per entry of names and save the figure to fname+".png".

	results[name].loc["regret"] is expected to be the string representation of a
	sequence of numbers; the last value is reported in the panel title.
	"""
	n_panels = len(names)
	fig, axes = plt.subplots(nrows=1, ncols=n_panels, figsize=(5*n_panels,2))
	## plt.subplots returns a bare Axes (not an array) when there is a single panel
	panels = [axes] if (n_panels==1) else axes
	for panel, name in zip(panels, names):
		## NOTE(review): eval executes whatever the results table contains;
		## only use with result files that you produced yourself
		regret = eval(results[name].loc["regret"])
		steps = range(len(regret))
		panel.plot(steps, regret, "b-", label="Cum. regret")
		final_regret = np.round(regret[-1],2)
		panel.set_title(f"{name} = {final_regret}")
	plt.savefig(fname+".png", bbox_inches="tight")
	plt.close()

def plot_grid_single_user(results, env, names, fname, seed=1245, s=10):
	"""Show, for each algorithm, which items were recommended at each time step.

	Items are ordered according to the leaves of a hierarchical clustering
	(average linkage, Euclidean metric) of their embeddings, which are first
	standardized and reduced by PCA when they have more than 10 columns. The
	dendrogram is drawn on the first row and one heatmap per entry of names is
	drawn below it. The figure is saved to fname+".png".

	Parameters
	----------
	results : pandas.DataFrame
		results[names].loc["recs"] holds, for each name, the string
		representation of the list of recommended item indices per time step.
	env : object
		Must provide nitem and item_embs(list_of_item_ids); the returned
		embeddings must support .toarray() (presumably a scipy sparse matrix).
	names : list of str
		Columns of results to display.
	fname : str
		Output path without the ".png" extension.
	seed : int
		Seed of the random generator used to pick the extra items.
	s : int
		Number of randomly chosen items kept in addition to the recommended
		ones (capped at the number of items).
	"""
	rng = np.random.default_rng(seed)
	recs = results[names].loc["recs"]
	## NOTE(review): eval executes whatever the results table contains;
	## only use with result files that you produced yourself
	T = len(eval(recs.loc[recs.index[0]]))
	R = np.zeros((env.nitem, T, len(names)))
	for iin, name in enumerate(names):
		## Parse the recommendation string once per algorithm, not once per time step
		rec_steps = eval(recs.loc[name])
		for t in range(T):
			R[rec_steps[t],t,iin] = 1
	## Cap the sample size: drawing without replacement fails if s exceeds the number of items
	ids1 = rng.choice(R.shape[0], size=min(s, R.shape[0]), replace=False)
	ids = R.sum(axis=1).sum(axis=1)>0
	ids[ids1] = True
	## Keep the items recommended at least once, plus the randomly drawn ones
	R = R[ids,:,:]
	Phi = env.item_embs(np.argwhere(ids).ravel().tolist())
	if (Phi.shape[1]>10):
		pca = PCA(n_components=min(10,min(Phi.shape[0], Phi.shape[1])-1))
		Phi = pca.fit_transform(StandardScaler().fit_transform(Phi.toarray()))
	else:
		Phi = Phi.toarray()
	fig, axes = plt.subplots(nrows=1+len(names),ncols=1,figsize=(50,5*len(names)/5))#, constrained_layout = True)
	#fig.tight_layout()
	Z = linkage(Phi, method='average', metric='euclidean')
	leaves = dendrogram(Z,  ax=axes[0], orientation="top", no_labels=True, get_leaves=True)["leaves"]
	axes[0].set_yticklabels([])
	R = R[leaves,:,:] ## reorder according to dendrogram
	## https://matplotlib.org/stable/users/explain/colors/colormaps.html
	cmaps = ['Greens', 'Blues', 'Greys', 'Reds', 'Purples', 'Oranges',
		'YlOrBr', 'YlOrRd', 'OrRd', 'PuRd', 'RdPu', 'BuPu',
		'GnBu', 'PuBu', 'YlGnBu', 'PuBuGn', 'BuGn', 'YlGn']
	for iin, name in enumerate(names):
		## Cycle through the colormaps so that more than len(cmaps) names do not crash
		sns.heatmap(R[:,:,iin].T, vmin=0, vmax=2, cmap=cmaps[iin % len(cmaps)], center=1, ax=axes[iin+1], cbar=False, square=True, xticklabels=False, yticklabels=False)
		#axes[iin+1].set_xlabel("T")
		#axes[iin+1].set_title(name,rotation=90)
		#axes[iin+1].set_xticklabels(range(1,T+1))
		axes[iin+1].yaxis.set_label_coords(-0.1,1.02)
		axes[iin+1].set_ylabel(name,fontsize=6,rotation='horizontal')
		#axes[iin+1].set_ylabel("Items")
	plt.subplots_adjust(left=0.1, right=0.2, 
		top=0.5, bottom=0.1, 
		wspace=0.05, hspace=0.4)
	plt.savefig(fname+".png", bbox_inches="tight")
	plt.close()
	
def plot_fulltraj_single_user(results, user, env, names, fname, initial_H=None, thres=0.5):
	"""Plot, per algorithm (rows) and time step (columns), the items recommended to one user.

	All items appearing in a recommendation or in initial_H are embedded in 2D
	(standardization followed by a 2-component PCA) and drawn in gray on every
	panel. Items recommended at the current step are drawn in green when their
	feedback score is above thres and in red otherwise; from the second step
	on, previously selected items (including initial_H) are drawn in blue.
	The figure is saved to fname+".png".

	Parameters
	----------
	results : pandas.DataFrame
		results[name].loc["recs"] holds the string representation of the list
		of recommended item ids per time step. The number of steps is read
		from the first column of results.
	user : object
		User passed as-is to env.feedback.
	env : object
		Must provide item_embs(items) and feedback(items, user); both results
		must support .toarray() (presumably scipy sparse matrices).
	names : list of str
		Columns of results to display.
	fname : str
		Output path without the ".png" extension.
	initial_H : list, optional
		Item ids considered as already selected before the first step.
	thres : float
		Score above which a recommended item counts as positive.
	"""
	initial_H = [] if (initial_H is None) else list(initial_H)
	## NOTE(review): eval executes whatever the results table contains;
	## only use with result files that you produced yourself
	T = len(eval(results[results.columns[0]].loc["recs"]))
	## squeeze=False always returns a 2D array of axes, whatever len(names) and T are
	fig, axes = plt.subplots(nrows=len(names), ncols=T, figsize=(T,len(names)), squeeze=False)
	recs_by_name = {}
	for name in names:
		recs_by_name[name] = eval(results[name].loc["recs"])
	items = list(set([r for name in names for rec in recs_by_name[name] for r in rec]+initial_H))
	## Row of each item id in X and scores
	position = {item: idx for idx, item in enumerate(items)}
	#Phi = env.item_embs_slice(0, env.nitem)
	Phi = env.item_embs(items)
	pca = PCA(n_components=2)
	X = pca.fit_transform(StandardScaler().fit_transform(Phi.toarray()))
	scores = env.feedback(items, user).toarray()
	## initial_H holds item ids whereas X is indexed by row position
	past_init = [position[i] for i in initial_H]
	for inn, name in enumerate(names):
		recs_all = recs_by_name[name]
		recs_past = list(past_init)
		for t in range(T):
			ax = axes[inn, t]
			recs = [position[i] for i in recs_all[t]]
			pos, neg = [], []
			for i in recs:
				if (scores[i]>thres):
					pos.append(i)
				else:
					neg.append(i)
			ax.scatter(X[:, 0], X[:, 1], marker='+', c="gray", alpha=0.1)
			## Past selections are only shown from the second step on
			if (t>0 and len(recs_past)>0):
				ax.scatter(X[recs_past, 0], X[recs_past, 1], marker='*', c='blue', label='selected', alpha=0.3)
			prefix = "" if (t==0) else "extra "
			if (len(pos)>0):
				ax.scatter(X[pos, 0], X[pos, 1], marker='o', c='forestgreen', label=f"{prefix}selected positive", alpha=0.3)
			if (len(neg)>0):
				ax.scatter(X[neg, 0], X[neg, 1], marker='o', c='red', label=f"{prefix}selected negative", alpha=0.3)
			recs_past += recs
			ax.grid(False)
			ax.set_xticklabels([])
			ax.set_yticklabels([])
		## The name of the algorithm is written next to the first panel of its row
		ax = axes[inn, 0]
		ax.yaxis.set_label_coords(-0.8,1.02)
		ax.set_ylabel(f"{name}",fontsize=6,rotation='horizontal')
	#axes[-1,-1].set_xlabel(f"T")
	plt.savefig(fname+".png", bbox_inches="tight")
	plt.close()
	
if __name__ == "__main__":
	## run simulate.py first
	## NOTE: the earlier __main__ section of this file terminates the interpreter,
	## so this section is not reached when the file is run as a script unless
	## that early exit is removed
	from time import time
	from known_setting.known_envs import SyntheticCosine
	seed = 1234
	thres = 0.5
	user = 0
	beta = None
	nitem = 5000
	## The folder name encodes the non-default settings
	result_folder_name = "../results/Synthetic/DenudedKernel"
	result_folder_name += (f"_beta={beta}" if (beta is not None) else "")
	result_folder_name += (f"_nitem={nitem}" if (nitem != 500000) else "")
	result_folder_name += "/"
	env = SyntheticCosine(dict(name="../../datasets/Synthetic", nitem=nitem, nuser=100, d=5, seed=seed, 
		quantize_digit=-1, new=False))
	results = pd.read_csv(f"{result_folder_name}/final_{env.name.split('/')[-1]}_linearKernel.csv", index_col=0)
	#names = list(sorted([f"{rec_name}_{rec_type}" for rec_type in ["SAMPLE", "MAX"] for rec_name in ["ConditionalDPP", "Fabaphe","kMarkovDPP"]]))
	plot_grid_single_user(results, env, list(results.columns), f"{result_folder_name}/grid_plot", seed)
	initial_H = []
	## Use the threshold defined above instead of repeating the literal value
	plot_fulltraj_single_user(results, user, env, list(results.columns), f"{result_folder_name}/traj_plot", initial_H=initial_H, thres=thres)
