from os import path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE, Isomap
import seaborn as sns




# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" and later
# releases dropped the old names, so use the new name whenever it is available.
_darkgrid_style = 'seaborn-v0_8-darkgrid' if 'seaborn-v0_8-darkgrid' in plt.style.available else 'seaborn-darkgrid'
plt.style.use(_darkgrid_style)
# Qualitative colormap; line_plot picks curve colours from it by integer index.
palette = plt.get_cmap('Set2')


def line_plot(data_dict, interval, title, folder):
	"""Plot equally spaced loss curves on the current figure and save them as a PNG.

	Curves are drawn in the iteration order of ``data_dict`` and coloured in
	pairs: items 0 and 1 share a colour, items 2 and 3 share the next one, and
	so on. The first item of each pair is drawn opaque, the second at alpha 0.5
	(in history.csv each run's "test" column precedes its "train" column).

	Args:
		data_dict: mapping of curve label -> sequence of y values. Every sequence
			is assumed to be as long as the first one.
		interval: x distance between consecutive samples; the x values are
			interval, 2*interval, ..., len*interval.
		title: figure title.
		folder: directory the PNG is written to. The file is named
			``train_and_test_full_<len*interval>.png``.

	Raises:
		ValueError: if ``data_dict`` is empty.
	"""
	if not data_dict:
		raise ValueError("line_plot needs at least one series to plot")

	data_length = len(next(iter(data_dict.values())))
	# Scale an integer range rather than calling np.arange(start, stop, step): with a
	# float step, arange can return one element too many or too few.
	time = interval * np.arange(1, data_length + 1)

	for idx, (name, data) in enumerate(data_dict.items()):
		# Pairs (0, 1), (2, 3), ... share palette entries 2, 4, ...
		colour = palette(2 + 2 * (idx // 2))
		if idx % 2 == 0:
			plt.plot(time, data, label=name, marker='', color=colour, linewidth=1)
		else:
			plt.plot(time, data, label=name, marker='', color=colour, linewidth=1, alpha=0.5)

	plt.legend()
	plt.title(title, fontsize=8)
	plt.ylabel('MSE')
	plt.xlabel('epoch')

	# Pass the on/off flag positionally: the keyword was ``b`` in old Matplotlib and
	# is ``visible`` in new releases (``b`` was deprecated and later removed).
	plt.grid(True, which='major', axis='both')
	plt.savefig(path.join(folder, "train_and_test_full_{}.png".format(data_length * interval)))

	plt.close()



def save_history(train_key, test_key, train_history, test_history):
	"""Write one run's train and test loss histories into enn_photos/history.csv.

	Row labels are training steps (500, 1000, 1500, ...): value k of a history
	is stored in the row labelled (k + 1) * 500. If the CSV already exists, the
	two columns are (re)created as all-NaN before being filled, so other runs'
	columns are kept while earlier values under the same keys are discarded.
	Otherwise a new CSV with rows 500 .. 1,000,000 is created.

	Args:
		train_key: column name for the training history.
		test_key: column name for the test history.
		train_history: one value per 500 steps (callers must log at that rate).
		test_history: one value per 500 steps; may differ in length from
			train_history.
	"""
	from os import makedirs

	folder = "enn_photos"
	file = path.join(folder, "history.csv")
	display_step = 500  # callers must record one history value every 500 steps
	max_step = 1000000  # number of steps pre-allocated when the CSV is first created

	if path.exists(file):
		df_load = pd.read_csv(file, index_col=0)
		# Clear (or create) this run's columns before writing the new values.
		df_load[test_key] = np.nan
		df_load[train_key] = np.nan
	else:
		print("~~ CREATING HISTORY CSV ~~")
		makedirs(folder, exist_ok=True)
		index = range(display_step, max_step + display_step, display_step)
		df_load = pd.DataFrame(np.nan, columns=[test_key, train_key], index=index)

	# Assign through df.loc[rows, column], not df[column].loc[rows]: the chained form
	# writes to a temporary copy and is silently lost under pandas Copy-on-Write.
	# Label slices are inclusive, so [:n * display_step] selects exactly n rows.
	df_load.loc[:len(test_history) * display_step, test_key] = test_history
	df_load.loc[:len(train_history) * display_step, train_key] = train_history
	df_load.to_csv(file)



def load_and_plot_history():
	"""Plot the leading fully-populated rows of enn_photos/history_backup.csv.

	Rows are kept from the top of the file up to (not including) the first row
	that has a NaN in any column. They are then passed to line_plot, with a
	500-step x spacing, which writes the figure into enn_photos/line_plots.
	"""
	file = path.join("enn_photos", "history_backup.csv")
	interval = 500
	load_df = pd.read_csv(file, index_col=0)

	# Series.argmax returns a *position* in pandas >= 1.0, so slice positionally
	# with iloc. When no row contains a NaN, keep every row.
	row_has_nan = load_df.isnull().any(axis=1).to_numpy()
	n_complete = int(row_has_nan.argmax()) if row_has_nan.any() else len(load_df)
	data_dict = load_df.iloc[:n_complete].to_dict(orient='list')
	print(data_dict)

	line_plot(data_dict, interval=interval, title="Comparison of Model Loss", folder=path.join("enn_photos", "line_plots"))


#load_and_plot_history()






def see_latent_space(epoch, Z_plot, Y_plot, title, folder):
	"""Scatter-plot a 2-D latent embedding coloured by label and save it as a PNG.

	Columns 0 and 1 of ``Z_plot`` are plotted on fixed axes [-7, 7] x [-7, 7],
	with the legend placed outside the plot on the right.

	Args:
		epoch: step/epoch number; used only in the output file name.
		Z_plot: encoder outputs, one row per sample and at least two columns.
			Accepts anything with a ``.numpy()`` method (e.g. an eager
			TensorFlow tensor) or anything ``np.asarray`` accepts.
		Y_plot: label for each row of Z_plot; used as the colour (hue).
		title: prefix of the output file name ``<title>_data_viz_<epoch>.png``.
		folder: directory the PNG is written to.

	Example:
		rand_perm = np.random.permutation(X_test.shape[0])[:3000] #scale of test
		X = encoder.predict(X_test[rand_perm])
		Y = Y_test[rand_perm]
	"""
	z_values = Z_plot.numpy() if hasattr(Z_plot, "numpy") else np.asarray(Z_plot)
	z_feat_cols = ['z_' + str(i) for i in range(z_values.shape[1])]
	z_df = pd.DataFrame(z_values, columns=z_feat_cols)
	z_df['Label'] = Y_plot

	plt.figure(figsize=(12, 12))
	plt.xlim(-7, 7)
	plt.ylim(-7, 7)

	ax = sns.scatterplot(
		x="z_0", y="z_1",
		hue="Label",
		palette=sns.color_palette("hls", 10),
		data=z_df,
		alpha=0.9,
	)

	# NOTE: sns.set changes global seaborn/matplotlib settings, so it also affects
	# every figure drawn after this call.
	sns.set(rc={'xtick.labelsize': 16, 'ytick.labelsize': 16})
	# Blank axis labels (seaborn would otherwise label the axes "z_0"/"z_1").
	ax.set_xlabel(' ')
	ax.set_ylabel(' ')

	# plt.legend replaces seaborn's default legend, so its font size is set here.
	plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., fontsize=19, markerscale=2)

	plt.savefig(path.join(folder, "{}_data_viz_{}.png".format(title, epoch)), bbox_inches='tight')
	plt.close('all')




def four_corners_plot(epoch, photo_list, z_list, targets, title, folder):
		# NOTE(review): unfinished — see the docstring and the NOTE(review) comments below.
		"""Unfinished 2x2 figure of scatter plots (plus photo tiles) saved as a PNG.

		NOTE(review): this function cannot run as written. It references names that
		are not defined in this function or anywhere in the visible module
		(Z_plot, Y_plot, z_feat_cols, x_df), so the first statement after this
		docstring raises NameError. The parameters z_list and targets look like the
		intended replacements for Z_plot and Y_plot (the comment above was copied
		from see_latent_space) — confirm before fixing. Nothing builds the PCA data
		(x_df with "pca-2d-one"/"pca-2d-two" columns) that panels 1 and 3 expect.

		Args:
			epoch: step/epoch number; used in the panel-4 title and the file name.
			photo_list: images handed to build_photo_tiles (which reshapes each to 28x28).
			z_list: presumably the 2-D encoder outputs — currently unused.
			targets: labels handed to build_photo_tiles.
			title: prefix of the output file name.
			folder: directory the PNG is written to.

		Example of preparing inputs (copied from see_latent_space):
			rand_perm = np.random.permutation(X_test.shape[0])[:3000] #scale of test
			X = encoder.predict(X_test[rand_perm])
			Y = Y_test[rand_perm]
		"""
		#These will be photos to show. H_0 := H_1z, H_1 =: H_0x.  left and top right
		# NOTE(review): Z_plot is undefined here -> NameError on the next line.
		feat_cols = ['z_'+str(i) for i in range(Z_plot.shape[1])]
		# NOTE(review): each call creates a separate figure that is neither used in the
		# 2x2 plot below nor closed (plt.close() at the end closes only the current one).
		H_0 = build_photo_tiles(photo_list, targets, r=6, c=6)
		H_1 = build_photo_tiles(photo_list, targets, r=6, c=6)

		#H_2 := H_1x, H_3 =: H_0z. Bottom left and bottom right
		#These are 2d plots
		# NOTE(review): placeholders; never used below.
		H_2 = 0
		H_3 = 0

		# NOTE(review): duplicate of the feat_cols line above. z_df is built from
		# z_feat_cols and Y_plot, neither of which is defined in this function.
		feat_cols = ['z_'+str(i) for i in range(Z_plot.shape[1])]
		z_df = pd.DataFrame(Z_plot.numpy(), columns=z_feat_cols)
		z_df['y'] = Y_plot


		plt.figure(figsize=(24,12))

		# Top-left panel. NOTE(review): x_df is never defined.
		ax1 = plt.subplot(2, 2, 1)
		ax1.set_title("H(0, x)") 
		sns.scatterplot(
			x="pca-2d-one", y="pca-2d-two",
			hue="y",
			palette=sns.color_palette("hls", 10),
			data=x_df,
			legend="full",
			alpha=0.9,
			ax=ax1
		)
		# Top-right panel: latent columns z_0 vs z_1.
		ax2 = plt.subplot(2, 2, 2)
		ax2.set_title("H(1, x)")
		sns.scatterplot(
			x="z_0", y="z_1",
			hue="y",
			palette=sns.color_palette("hls", 10),
			data=z_df,
			legend="full",
			alpha=0.9,
			ax=ax2
		)

		# Bottom-left panel. NOTE(review): the title and scatter both target ax1, not
		# ax3, so ax3 stays empty and ax1's title is overwritten; the .format(epoch)
		# call has no placeholder to fill.
		ax3 = plt.subplot(2, 2, 3)
		ax1.set_title("H(1, x)".format(epoch)) 
		sns.scatterplot(
			x="pca-2d-one", y="pca-2d-two",
			hue="y",
			palette=sns.color_palette("hls", 10),
			data=x_df,
			legend="full",
			alpha=0.9,
			ax=ax1
		)

		# Bottom-right panel. NOTE(review): the scatter targets ax2, not ax4, so ax4
		# only receives its title.
		ax4 = plt.subplot(2, 2, 4)
		ax4.set_title("2d Encoded Space (Step {})".format(epoch))
		sns.scatterplot(
			x="z_0", y="z_1",
			hue="y",
			palette=sns.color_palette("hls", 10),
			data=z_df,
			legend="full",
			alpha=0.9,
			ax=ax2
		)


		# NOTE(review): "\\" separator is Windows-only; os.path.join would be portable.
		plt.savefig(folder+"\\{}_data_viz_{}.png".format(title, epoch))
		plt.close()




"""
Example (TensorFlow): building the photo_list input for build_photo_tiles

constant_noise = tf.random.normal(shape=(64, n_hidden_2), mean=0, stddev=std, dtype=tf.dtypes.float32, seed=0) #+0.02 to std to see more outliers
gen_images = neural_net(x_test[:r*c], constant_noise)
gen_images = tf.reshape(gen_images, (tf.shape(gen_images)[0], tf.shape(gen_images)[1], 1))
# Rescale images 0 - 1
gen_images = 0.5 * gen_images + 0.5

"""

def build_photo_tiles(photo_list, y_test, r=8, c=8, image_shape=(28, 28)):
	"""Arrange the first r*c images of ``photo_list`` in an r x c grid figure.

	Images fill the grid row by row in grayscale, with all axes hidden. When
	fewer than r*c images are given, the remaining cells are left blank.

	Args:
		photo_list: sequence of images; each must be reshapeable to ``image_shape``.
		y_test: labels matching photo_list. Currently unused (per-tile titles are
			disabled); kept for interface compatibility.
		r: number of grid rows.
		c: number of grid columns.
		image_shape: shape each image is reshaped to before display. The default
			(28, 28) matches the original hard-coded shape.

	Returns:
		The matplotlib Figure. It is neither saved nor closed here; the caller
		owns it.
	"""
	gen_images = photo_list[:r * c]

	# squeeze=False keeps axs two-dimensional even when r == 1 or c == 1.
	fig, axs = plt.subplots(r, c, squeeze=False)
	# axs.flat walks the grid row by row, matching the image order.
	for cnt, ax in enumerate(axs.flat):
		if cnt < len(gen_images):
			ax.imshow(np.reshape(gen_images[cnt], image_shape), cmap='gray')
		ax.axis('off')

	return fig











def custom_plot(run="MNIST_2", ylim=(80, 170)):
	"""Plot smoothed single-VAE vs baseline-VAE train/test curves for one run and save them.

	Reads enn_photos/history.csv and multiplies its index by 1000 for the x axis.
	Each of the four curves (single/baseline x test/train) is smoothed with a 1-D
	Gaussian filter. The curves are drawn on a log-scaled x axis, test solid and
	train dashed, single-VAE blue and baseline-VAE orange. The result is saved as
	enn_photos/train_and_test_full_<run>.png.

	Args:
		run: experiment whose columns are plotted: "MNIST_2" (default),
			"MNIST_10", "MNIST_50", "MNIST_260", "SVHN_5", "SVHN_50" or "SVHN_500".
		ylim: (bottom, top) y-axis limits, or None to keep autoscaling. The
			default suits MNIST_2; other runs probably need their own range.

	Raises:
		ValueError: if ``run`` is not a known experiment.
	"""
	# scipy.ndimage.filters is a deprecated alias namespace; import from scipy.ndimage.
	from scipy.ndimage import gaussian_filter1d

	# run -> (single-VAE column template, baseline-VAE column template, smoothing sigma).
	# "{}" becomes "test" or "train". The "singe_" spelling matches the CSV column names.
	runs = {
		"MNIST_2": ("single_vae_2_{}_800_800", "baseline_vae_2_{}_657_657", 2),
		"MNIST_10": ("single_vae_10_{}_800_800", "baseline_vae_10_{}_653_653", 2),
		"MNIST_50": ("single_vae_50_{}_800_800", "baseline_vae_50_{}_635_635", 2),
		"MNIST_260": ("single_vae_260_{}_800_800", "baseline_vae_260_{}_550_550", 2),
		"SVHN_5": ("singe_vae_svhn_5_{}_512_512", "baseline_vae_svhn_5_{}_479_479", 3),
		"SVHN_50": ("singe_vae_svhn_50_{}_512_512", "baseline_vae_svhn_50_{}_471_471", 3),
		"SVHN_500": ("singe_vae_svhn_500_{}_512_512", "baseline_vae_svhn_500_{}_404_404", 3),
	}
	if run not in runs:
		raise ValueError("unknown run {!r}; expected one of {}".format(run, sorted(runs)))
	single_col, baseline_col, sigma = runs[run]

	# NOTE: plt.rc changes global rcParams, which also affects later figures.
	plt.rc('xtick', labelsize=20)
	plt.rc('ytick', labelsize=20)

	folder = "enn_photos"
	file = path.join(folder, "history.csv")
	df = pd.read_csv(file, index_col=0)

	print(df.columns)

	df.index = df.index * 1000

	plt.figure(figsize=(5, 5))

	# Draw order: single test, single train, baseline test, baseline train.
	for col_template, model_name, colour in ((single_col, "single-VAE", "blue"), (baseline_col, "baseline-VAE", "orange")):
		for split, linestyle in (("test", "solid"), ("train", "dashed")):
			plt.plot(
				df.index,
				gaussian_filter1d(df[col_template.format(split)], sigma=sigma),
				label="{} {}".format(model_name, split),
				linestyle=linestyle,
				color=colour,
				linewidth=3,
			)

	plt.xscale('log')
	if ylim is not None:
		plt.ylim(*ylim)

	# Positional on/off flag works with both the old ``b`` and the new ``visible`` keyword.
	plt.grid(True, which='major', axis='both')
	plt.savefig(path.join(folder, "train_and_test_full_{}.png".format(run)), bbox_inches='tight')

	plt.close()


#custom_plot()



















def reg_plot(run="baseline_gan_adam_bn"):
	"""Plot hatched value bands for a GAN's discriminator and generator and save them.

	Reads enn_photos/dist_history_backup.csv and multiplies its index by 1000 for
	the x axis. For the discriminator ("d") and the generator ("g"), it draws the
	``<prefix>_15`` and ``<prefix>_85`` columns as black lines and hatches the
	band between them. These are presumably the 15th/85th percentiles of the
	logged values; confirm against the code that writes the CSV. The figure is
	saved as enn_photos/activations_<run>.png.

	Args:
		run: experiment to plot: "single_gan" (single network, SGD),
			"baseline_gan_adam_bn" (default; baseline with Adam and batch norm) or
			"baseline_gan_no_bn_sgd" (baseline with SGD, no batch norm).

	Raises:
		ValueError: if ``run`` is not a known experiment.
	"""
	# run -> column-name prefix template; "{}" becomes "d" or "g".
	runs = {
		"single_gan": "single_gan_{}",
		"baseline_gan_adam_bn": "base_gan_{}_wbn",
		"baseline_gan_no_bn_sgd": "base_gan_{}_no_bn_sgd",
	}
	if run not in runs:
		raise ValueError("unknown run {!r}; expected one of {}".format(run, sorted(runs)))
	template = runs[run]

	folder = "enn_photos"
	file = path.join(folder, "dist_history_backup.csv")
	df = pd.read_csv(file, index_col=0)

	# NOTE: plt.rc changes global rcParams, which also affects later figures.
	plt.rc('xtick', labelsize=20)
	plt.rc('ytick', labelsize=20)

	print(df.columns)

	df.index = df.index * 1000

	plt.figure(figsize=(6, 3))

	# (network key, hatch pattern, band face colour, legend label)
	bands = (
		("d", '\\\\\\\\', 'red', "Discriminator"),
		("g", '////', 'blue', "Generator"),
	)
	for net, hatch, facecolor, label in bands:
		prefix = template.format(net)
		col_15 = df[prefix + "_15"]
		col_85 = df[prefix + "_85"]
		plt.plot(df.index, col_15, linestyle='solid', color='black', linewidth=2)
		plt.plot(df.index, col_85, color='black', linewidth=2)
		plt.fill_between(df.index, col_85, col_15, hatch=hatch, linewidth=1, facecolor=facecolor, edgecolor='black', alpha=0.2, label=label)

	# Positional on/off flag works with both the old ``b`` and the new ``visible`` keyword.
	plt.grid(True, which='major', axis='both')
	plt.savefig(path.join(folder, "activations_{}.png".format(run)), bbox_inches='tight')
	plt.close()


#reg_plot()



