import matplotlib.pyplot as plt
import numpy as np


import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.datasets import cifar10

from tensorflow.keras.losses import MSE


from tensorflow.compat.v1.layers import batch_normalization

from plots import  line_plot, see_latent_space, save_history



# Seed NumPy's global RNG (TensorFlow's RNG is not seeded in this file).
np.random.seed(0)

# Dataset shape: 10 classes, 28x28 = 784 flattened pixels per image.
num_classes = 10
num_features = 28 * 28

# Mean / standard deviation of the Gaussian noise used for latent sampling
# (sample_z epsilon and the fixed evaluation noise).
mean, std = 0, 1

# Training parameters (earlier experiments: lr 0.005; 1_000_000 and 5_000_000 steps).
lr = 0.001
training_steps = 500_500
batch_size = 1_000
display_step = 500

# Network parameters.
# Encoder output width: the first z_space//2 columns are means, the next
# z_space//2 are log-variances.
z_space = 100
# Hidden widths (other sizes tried: 512, 406, 294 -- 294 for baseline comparison).
n_hidden_1 = n_hidden_2 = 800



def _encode_decode(x):
	"""Encode ``x``, draw one reparameterised latent sample per row and decode it.

	The shared network ``neural_net_d`` is run twice. The first pass uses a
	linear output, whose first ``z_space`` columns are split into ``mu`` and
	``log_sigma``. The second pass uses a sigmoid output on the zero-padded
	latent sample.

	Args:
		x: batch of flattened images; one latent sample is drawn per row.

	Returns:
		Tuple ``(x_hat, mu, log_sigma)``.
	"""
	out = neural_net_d(x, use_sigmoid=False)
	mu = out[:, 0:z_space//2] #First half
	log_sigma = out[:, z_space//2:z_space] #Second half
	z_raw = sample_z([mu, log_sigma], batch=x.shape[0])
	x_hat = neural_net_d(format_noise(z_raw))
	return x_hat, mu, log_sigma


def main():
	"""Train the single-network VAE on MNIST, with periodic evaluation.

	Every ``display_step`` steps:
	  - print the train MSE and the (recon, KL, total) VAE losses, both on the
	    current batch and on the full test set;
	  - save a grid of decoded samples;
	  - append the total losses to the history lists.

	Every ``display_step*20`` steps, the histories are written with
	``save_history``.

	Every ``display_step*1000`` steps, a Gaussian KDE is fitted to decoder
	samples from ``contant_kde_noise``, with its bandwidth chosen by 3-fold
	grid search. The mean and std of the test-set log-density are then printed.
	"""
	# scikit-learn is only needed for the KDE evaluation.
	from sklearn.neighbors import KernelDensity
	from sklearn.model_selection import GridSearchCV, KFold

	(x_train, y_train), (x_test, y_test) = mnist.load_data()

	#FOR CIFAR (presumably num_features must then be 32*32*3 -- the reshape below uses it)
	#(x_train, y_train), (x_test, y_test) = cifar10.load_data()
	#x_train = np.reshape(x_train, (x_train.shape[0], x_train.shape[1]*x_train.shape[2]*x_train.shape[3]))
	#x_test = np.reshape(x_test, (x_test.shape[0], x_test.shape[1]*x_test.shape[2]*x_test.shape[3]))

	# Cast, flatten, and normalize to [0, 1]
	x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)
	x_train, x_test = x_train.reshape([-1, num_features]), x_test.reshape([-1, num_features])
	x_train, x_test = x_train / 255., x_test / 255.

	# Infinite shuffled stream; repeat() comes before batch(), so every batch has batch_size rows.
	train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
	train_data = train_data.repeat().shuffle(5000).batch(batch_size).prefetch(1) #5000 is fine since the set is balanced

	t_ones = tf.ones(batch_size)

	print("-------------- Trainable Variables:")
	print(get_total_param_number(d_weights, d_biases))
	print("--------------")

	test_history = []
	train_history = []

	for step, (batch_x, batch_y) in enumerate(train_data.take(training_steps), 1):

		optimize_vae(batch_x, t_ones)

		if step % display_step != 0:
			continue

		# Train losses on the current batch
		x_hat, mu, log_sigma = _encode_decode(batch_x)
		vae_losses = vae_loss(batch_x, x_hat, mu, log_sigma, verbose=1)
		train_loss = np.mean(MSE(x_hat, batch_x))

		# Test losses on the full test set
		x_hat_test, mu, log_sigma = _encode_decode(x_test)
		test_losses = vae_loss(x_test, x_hat_test, mu, log_sigma, verbose=1)

		print("{} -- Train MSE: {:.7f}, Test Losses: {}, Train Losses:{}".format(step, train_loss, test_losses, vae_losses))

		save_imgs(step, x_test, y_test)
		#save_manifold(step, x_test)

		# vae_loss(verbose=1) returns (recon, kl, total); keep the total
		train_history.append(vae_losses[-1])
		test_history.append(test_losses[-1])

		# Both checks below are multiples of display_step larger than it, so they
		# never fire on the first evaluation.
		if step % (display_step*20) == 0:
			#line_plot({'test_loss': test_history, 'train_loss': train_history}, interval=display_step, title="'HGN-AE BASELINE {}x{}".format(n_hidden_1, n_hidden_2), folder="enn_photos\\line_plots\\")

			# Write loss history
			test_key = "single_vae_{}_test_{}_{}_long".format(z_space//2, n_hidden_1, n_hidden_2)
			train_key = "single_vae_{}_train_{}_{}_long".format(z_space//2, n_hidden_1, n_hidden_2)
			save_history(train_key, test_key, train_history, test_history)

		if step % (display_step*1000) == 0:
			# Fit a KDE to decoder samples from the fixed noise, then score the real test set
			x_hat_kde = neural_net_d(contant_kde_noise)

			bandwidth = 10 ** np.linspace(-1, 1, 10)
			grid = GridSearchCV(KernelDensity(kernel='gaussian'), param_grid={'bandwidth': bandwidth}, cv=KFold(n_splits=3))
			grid.fit(x_hat_kde.numpy())
			kde = grid.best_estimator_
			print("FIT COMPLETE")
			scores = kde.score_samples(x_test)
			print(np.mean(scores))
			print(np.std(scores))






# Optimizer shared by every optimize_vae step.
# Earlier alternatives: tf.optimizers.SGD(lr), Adam(beta_1=0.6, beta_2=0.9)
# and Adam(beta_1=0.99, beta_2=0.99).
vae_optimizer = tf.keras.optimizers.Adam(
	learning_rate=lr,
	beta_1=0.6,
	beta_2=0.99,
	epsilon=1e-07,
)


def optimize_vae(x, t):
	"""Run one optimizer step on the VAE loss for the batch ``x``.

	The shared network ``neural_net_d`` serves as both encoder and decoder:
	  - Encoder pass: linear output, whose first ``z_space`` columns are read
	    as mu / log-variance.
	  - Decoder pass: sigmoid output, fed the zero-padded latent sample.

	Gradients are applied to every variable in ``d_weights`` and ``d_biases``.

	Args:
		x: batch of flattened images, one latent sample drawn per row.
		t: unused; kept so existing callers keep working.
	"""
	with tf.GradientTape() as g:
		# Sample z ~ Q(z|X)
		out = neural_net_d(x, use_sigmoid=False)
		mu = out[:, 0:z_space//2] #First half
		log_sigma = out[:, z_space//2:z_space] #Second half

		# Call the sampler directly. Wrapping it in a new Keras Lambda layer on
		# every step only built a throw-away layer object per iteration.
		# The noise batch is sized from mu rather than a global constant.
		z_raw = sample_z([mu, log_sigma], batch=tf.shape(mu)[0])
		x_hat = neural_net_d(format_noise(z_raw))

		loss = vae_loss(x, x_hat, mu, log_sigma)

	# Calculate and apply loss gradients
	trainable_variables = list(d_weights.values()) + list(d_biases.values())
	gradients = g.gradient(loss, trainable_variables)
	vae_optimizer.apply_gradients(zip(gradients, trainable_variables))









def sample_z(args, batch=None):
	"""Reparameterised sample ``mu + exp(0.5 * log_sigma) * epsilon``.

	``epsilon`` is Gaussian with the module-level ``mean`` / ``std``.
	``log_sigma`` is treated as a log-variance, hence the 0.5 factor.

	Args:
		args: pair ``[mu, log_sigma]``; epsilon has ``z_space//2`` columns,
			so both must broadcast against that width.
		batch: number of noise rows to draw. When None (the default), it
			matches the number of rows in ``mu``, instead of assuming the
			global training ``batch_size``.
	"""
	mu, log_sigma = args
	if batch is None:
		batch = tf.shape(mu)[0]
	epsilon = tf.random.normal(shape=(batch, z_space//2), mean=mean, stddev=std)
	return mu + tf.math.exp(log_sigma * 0.5) * epsilon

def inverse_sigmoid(x):
	"""Return the logit ``log(x / (1 - x))``, the inverse of the sigmoid.

	Not guarded: x == 0 gives -inf and x == 1 gives +inf.
	"""
	odds = x / (1 - x)
	return tf.math.log(odds)




def vae_loss(y_true, y_pred, mu, log_sigma, verbose=0):
	"""VAE objective: Bernoulli reconstruction NLL plus KL(q(z|x) || N(0, I)).

	Both terms are summed over features / latent dims per row, then averaged
	over rows.

	Args:
		y_true: target pixels in [0, 1], one row per sample.
		y_pred: decoder outputs (sigmoid probabilities), same shape as ``y_true``.
		mu: posterior means, one row per sample.
		log_sigma: posterior log-*variance* (``exp(log_sigma)`` is used as
			the variance in the KL term).
		verbose: if 1, return a tuple of NumPy scalars
			``(mean recon, mean KL, mean total)``. Otherwise return the mean
			total loss as a differentiable scalar tensor.
	"""
	eps = 1e-10

	# E[log P(X|z)] (Bernoulli log-likelihood).
	# (1 - y_pred) must be parenthesised. Written as `eps + 1 - y_pred`, Python
	# evaluates `eps + 1` first, and that value rounds to exactly 1.0 in float32.
	# A saturated sigmoid (y_pred == 1.0) then gave log(0) = -inf, and with
	# y_true == 1 the product 0 * -inf = NaN.
	recon_loss = y_true * tf.math.log(eps + y_pred) + (1 - y_true) * tf.math.log(eps + (1 - y_pred))
	recon_loss = -tf.math.reduce_sum(recon_loss, 1)

	# KL(q(z|x) || N(0, I)) with log_sigma as log-variance
	kl_loss = 1 + log_sigma - tf.math.square(mu) - tf.exp(log_sigma)
	kl_loss = -0.5 * tf.math.reduce_sum(kl_loss, 1)

	if verbose == 1:
		return (tf.math.reduce_mean(recon_loss).numpy(), tf.math.reduce_mean(kl_loss).numpy(), tf.math.reduce_mean(recon_loss + kl_loss).numpy())

	return tf.math.reduce_mean(recon_loss + kl_loss)





#NN Specs: one MLP num_features -> n_hidden_1 -> n_hidden_2 -> num_features,
# used by neural_net_d as both the VAE encoder and decoder.

# Shared initializer for every weight matrix (Keras RandomNormal defaults).
random_normal = tf.initializers.RandomNormal()

# Weight matrices, created in the order h1, h2, h3.
d_weights = {
	key: tf.Variable(random_normal(shape))
	for key, shape in (
		('h1', [num_features, n_hidden_1]),
		('h2', [n_hidden_1, n_hidden_2]),
		('h3', [n_hidden_2, num_features]),
	)
}

# Zero-initialised bias vectors, one per layer.
d_biases = {
	key: tf.Variable(tf.zeros([width]))
	for key, width in (
		('b1', n_hidden_1),
		('b2', n_hidden_2),
		('b3', num_features),
	)
}





def neural_net_d(x, use_noise=False, use_sigmoid=True):
	"""Forward pass of the shared 3-layer MLP used as both VAE encoder and decoder.

	Architecture: num_features -> n_hidden_1 (sigmoid) -> n_hidden_2 (sigmoid)
	-> num_features.

	Args:
		x: 2-D input with ``num_features`` columns (images, or latent samples
			zero-padded to that width).
		use_noise: if True, add Gaussian noise (stddev 0.06) to the first hidden layer.
		use_sigmoid: if truthy, apply a sigmoid to the output (decoder use).
			Otherwise return the linear output, which callers slice into
			mu / log-variance (encoder use).
	"""
	layer_1 = tf.nn.sigmoid(tf.matmul(x, d_weights['h1']) + d_biases['b1'])

	if use_noise:
		# tf.shape also works when the batch dimension is not statically known.
		# (Was std 0.04 for great model at 1200, 1200)
		noise = tf.random.normal(shape=tf.shape(layer_1), mean=0, stddev=0.06, dtype=tf.dtypes.float32, seed=0)
		layer_1 = layer_1 + noise

	layer_2 = tf.nn.sigmoid(tf.matmul(layer_1, d_weights['h2']) + d_biases['b2'])

	layer_3 = tf.matmul(layer_2, d_weights['h3']) + d_biases['b3']
	if use_sigmoid:
		layer_3 = tf.nn.sigmoid(layer_3)

	return layer_3



def format_noise(z_raw):
	"""Zero-pad latent samples on the right to ``num_features`` columns for the decoder.

	The decoder is ``neural_net_d``, whose input width is ``num_features``.
	The latent values occupy the leading columns.

	Args:
		z_raw: 2-D array or tensor, one row per sample. Usually
			``z_space//2`` columns, but any width up to ``num_features`` works.

	Returns:
		float32 tensor of shape (rows, num_features).
	"""
	# Cast so NumPy input (e.g. float64 arrays) concatenates with float32 zeros.
	z_raw = tf.cast(z_raw, tf.float32)
	# Pad by the actual latent width instead of assuming z_space//2 columns.
	padding = tf.zeros([tf.shape(z_raw)[0], num_features - tf.shape(z_raw)[1]])
	return tf.concat([z_raw, padding], axis=1)


# Fixed, seeded latent noise. It is built once at import time and reused at
# every evaluation.

# 10000 samples decoded for the KDE fit in main()
contant_kde_noise = format_noise(tf.random.normal((10000, z_space//2), mean, std, tf.float32, seed=0))

# 88 = 8 x 11 samples for the image grid in save_imgs()
constant_noise = format_noise(tf.random.normal((88, z_space//2), mean, std, tf.float32, seed=0))

def save_imgs(step, x, y):
	"""Decode the fixed ``constant_noise`` latents and save them as an 8x11 image grid.

	Args:
		step: training step, used in the output filename.
		x, y: unused; kept so existing callers keep working.
	"""
	r = 8
	c = 11

	gen_images = neural_net_d(constant_noise).numpy()
	# The decoder output is already a sigmoid in [0, 1]. The former
	# `0.5 * x + 0.5` rescale had no visual effect, because imshow scales each
	# image to its own min/max.

	fig, axs = plt.subplots(r, c, figsize=(11/2.3, 8/2.3))

	# axs.flat walks the grid row by row, matching the order of constant_noise
	for cnt, ax in enumerate(axs.flat):
		# interpolation belongs on imshow. savefig does not accept it; passing it
		# there raised a TypeError that the old bare `except:` hid, so no image
		# was ever written.
		ax.imshow(np.reshape(gen_images[cnt], [28, 28]), cmap='gray', interpolation='nearest')
		ax.axis('off')
		ax.set_facecolor('black')
		ax.set_aspect('equal')

	fig.subplots_adjust(hspace=0.0001, wspace=0.0001)

	try:
		fig.savefig("enn_photos\\single_vae\\mnist_%d.png" % step)
	except OSError as err:
		# e.g. the output directory does not exist; keep training going
		print("Photo not saved: {}".format(err))
	finally:
		plt.close(fig)



#Function code from https://blog.keras.io/building-autoencoders-in-keras.html
def save_manifold(step, x):
	"""Decode a 17x17 grid of latent points and save the tiled digits as one image.

	Only the first two latent coordinates vary; all other latent coordinates
	stay at zero. Both coordinates sweep ``np.linspace(-15, 15, 17) * std / 4``.

	Args:
		step: training step, used in the output filename.
		x: unused; kept so existing callers keep working.

	Raises:
		ValueError: if the latent space (``z_space//2``) has fewer than 2 dimensions.
	"""
	n = 17  # figure with 17x17 digits
	digit_size = 28
	latent_dim = z_space//2
	if latent_dim < 2:
		raise ValueError("save_manifold needs at least 2 latent dimensions, got {}".format(latent_dim))

	# True division. The old `* std//4` parsed as `(coords * std) // 4`, which
	# floor-divided the grid into integer steps, so neighbouring tiles repeated.
	scale = std / 4
	grid_x = np.linspace(-15, 15, n)
	grid_y = np.linspace(-15, 15, n)
	figure = np.zeros((digit_size * n, digit_size * n))

	for i, yi in enumerate(grid_x):
		for j, xi in enumerate(grid_y):
			# A full-width latent row (z_space//2 columns, float32 to match the
			# weights). A bare 1x2 array gave format_noise the wrong width
			# whenever z_space//2 != 2.
			z_raw = np.zeros((1, latent_dim), dtype=np.float32)
			z_raw[0, 0] = xi * scale
			z_raw[0, 1] = yi * scale
			x_decoded = neural_net_d(format_noise(z_raw)).numpy()

			digit = x_decoded[0].reshape(digit_size, digit_size)
			figure[i * digit_size:(i + 1) * digit_size, j * digit_size:(j + 1) * digit_size] = digit

	plt.figure(figsize=(10, 10))
	plt.grid(False)
	plt.xticks([])
	plt.yticks([])
	plt.imshow(figure, cmap='gray')
	plt.savefig("enn_photos\\single_vae\\man_mnist_%d.png" % step)
	plt.close('all')






def get_total_param_number(weights, biases):
	"""Print the element count of every weight and bias variable and return their sum.

	Weight sizes are printed first, then bias sizes, each in dict order.

	Args:
		weights: dict whose values are weight variables/tensors.
		biases: dict whose values are bias variables/tensors.

	Returns:
		Total number of scalar parameters as a Python int. This is 0 when both
		dicts are empty; previously that case raised, because ``0 .numpy()``
		does not exist.
	"""
	total = 0
	for var in list(weights.values()) + list(biases.values()):
		size = tf.size(var)
		print(size)
		# Accumulate as a Python int: no dependence on a tensor-valued total
		# and no int32 wrap-around.
		total += int(size)
	return total







# Start training only when run as a script; importing the module does not train.
if __name__ == "__main__":
	main()