
import matplotlib.pyplot as plt
import numpy as np

import tensorflow as tf
from tensorflow.keras.datasets import mnist

from tensorflow.keras.losses import MSE


from tensorflow.compat.v1.layers import batch_normalization


# NOTE(review): this seeds NumPy's global RNG only; TensorFlow initializers and
# tf.data shuffling draw from TensorFlow's own RNG and are not affected.
np.random.seed(0)

# MNIST dataset parameters.
num_classes = 10  # digit classes 0-9
num_features = 28 * 28  # flattened image length (784)

# Latent noise z is drawn as N(mean, std**2).
mean = 0
std = 1

# Training schedule.
lr_gan = 0.33
training_steps = 200500
batch_size = 1000
display_step = 500  # steps between progress reports / sample grids

# Network sizes: the latent space is as wide as a flattened image.
z_space = num_features
n_hidden_1 = 512
n_hidden_2 = 512




def main():
	"""Train the shared generator/discriminator network on MNIST.

	Every ``display_step`` steps this prints the discriminator's BCE on a
	generated and a real batch, saves a sample grid (``save_imgs``) and records
	layer-2 pre-activation statistics with ``tools.dist_tracker``. Every
	``display_step * 400`` steps it also reports the KDE log-likelihood of the
	MNIST test set under generated samples (``_report_kde_scores``).
	"""
	(x_train, _), (x_test, _) = mnist.load_data()

	# Cast to float32, flatten each 28x28 image to a vector, scale to [0, 1].
	x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)
	x_train, x_test = x_train.reshape([-1, num_features]), x_test.reshape([-1, num_features])
	x_train, x_test = x_train / 255., x_test / 255.

	# Images are paired with themselves; the second element is ignored below.
	train_data = tf.data.Dataset.from_tensor_slices((x_train, x_train))
	train_data = train_data.repeat().shuffle(5000).batch(batch_size).prefetch(1)

	# Constant BCE targets. The dataset repeats before batching, so every batch
	# has exactly batch_size rows and these shapes always match.
	ones = tf.ones((batch_size, z_space), dtype=tf.dtypes.float32)
	zeros = tf.zeros((batch_size, z_space), dtype=tf.dtypes.float32)

	print("--------------")
	print("Total trainable params:")
	print(get_total_param_number(d_weights, d_biases))
	print("--------------")

	# Distribution histories (time, percentiles) filled by dist_tracker.
	d_dist_hist = []
	g_dist_hist = []

	for step, (batch_x, _) in enumerate(train_data.take(training_steps), 1):
		noise = tf.random.normal(shape=(batch_size, z_space), mean=mean, stddev=std, dtype=tf.dtypes.float32, seed=0)
		optimize_gan(batch_x, noise, ones, zeros)

		if step % display_step == 0:
			gen = neural_net_d(noise, use_noise=False)
			d_real_out = neural_net_d(batch_x, use_noise=False)
			d_genn_out = neural_net_d(gen, use_noise=False)

			# Keras losses are called as bce_loss(y_true, y_pred): target first.
			d_gen_loss = float(bce_loss(zeros, d_genn_out))
			d_true_loss = float(bce_loss(ones, d_real_out))
			print("{} -- D_gen_loss: {:.7f}, D_true_loss: {:.7f}".format(step, d_gen_loss, d_true_loss))

			save_imgs(step)

			from tools import dist_tracker
			dist_tracker(dist_hist=d_dist_hist, new=neural_net_d(batch_x, track=True), name="single_gan_d")
			# NOTE(review): unlike the generator pass in optimize_gan, this call keeps
			# the default use_noise=True — confirm that is intended.
			dist_tracker(dist_hist=g_dist_hist, new=neural_net_d(noise, track=True), name="single_gan_g")

		# enumerate starts at 1, so no separate "step > 0" guard is needed.
		if step % (display_step * 400) == 0:
			_report_kde_scores(x_test)


def _report_kde_scores(x_test):
	"""Print the mean/std log-likelihood of ``x_test`` under a KDE of generated samples.

	A Gaussian ``KernelDensity`` is fit to the generator's output on the fixed
	``contant_kde_noise`` batch. Its bandwidth is chosen by 3-fold grid search
	over 10 log-spaced values in [0.1, 10].
	"""
	from sklearn.model_selection import GridSearchCV
	from sklearn.model_selection import KFold
	from sklearn.neighbors import KernelDensity

	# Generator pass: use_noise=False, matching optimize_gan and save_imgs.
	x_hat_kde = neural_net_d(contant_kde_noise, use_noise=False)

	bandwidth = 10 ** np.linspace(-1, 1, 10)
	grid = GridSearchCV(KernelDensity(kernel='gaussian'), param_grid={'bandwidth': bandwidth}, cv=KFold(n_splits=3))
	grid.fit(x_hat_kde.numpy())
	kde = grid.best_estimator_
	print("FIT COMPLETE")
	scores = kde.score_samples(x_test)
	print(np.mean(scores))
	print(np.std(scores))

	





# Plain SGD at rate lr_gan; optimize_gan applies both the generator and the
# discriminator gradients through this single optimizer.
gan_optimizer = tf.optimizers.SGD(learning_rate=lr_gan)

# Binary cross-entropy on probabilities (the network ends in a sigmoid).
# Keras losses are called as bce_loss(y_true, y_pred): target first.
bce_loss = tf.keras.losses.BinaryCrossentropy()

def optimize_gan(x, noise, ones, zeros):
	"""Run one SGD step for the network that acts as both generator and discriminator.

	The same parameters (``d_weights`` / ``d_biases``) are used for the generator
	pass (``use_noise=False`` on ``noise``) and for the discriminator passes
	(default ``use_noise=True``). Gradients of the generator loss and of the
	discriminator loss are applied as two separate ``gan_optimizer`` updates.

	Args:
		x: batch of real images, shape (batch_size, num_features).
		noise: latent batch for the generator, shape (batch_size, z_space).
		ones: all-ones target with the same shape as the network output.
		zeros: all-zeros target with the same shape as the network output.
	"""
	# Generated batch computed outside the tapes: the discriminator's loss on
	# these fakes must not propagate back into the generator pass.
	gen_unwatched = neural_net_d(noise, use_noise=False)

	with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
		gen = neural_net_d(noise, use_noise=False)  # generator pass to optimize

		# NOTE(review): gen_loss is differentiated w.r.t. every shared variable, so
		# its gradients also update the weights along this discriminator pass.
		d_fake_out = neural_net_d(gen)
		d_real = neural_net_d(x)
		d_fake = neural_net_d(gen_unwatched)

		# Keras losses are called as bce_loss(y_true, y_pred): target first.
		true_loss = bce_loss(ones, d_real)  # real images should score 1
		fake_loss = bce_loss(zeros, d_fake)  # generated images should score 0

		# L1 pull of every discriminator output toward 0.5 (summed over all elements).
		regularizer = tf.norm((d_real - 0.5), ord=1) + tf.norm((d_fake - 0.5), ord=1)

		# NOTE(review): the 0.000011 weight (and lr_gan) were tuned while the
		# bce_loss arguments were reversed, which made each loss linear in the
		# prediction; re-tune them if training degrades.
		d_loss = true_loss + fake_loss + 0.000011 * regularizer
		gen_loss = bce_loss(ones, d_fake_out)  # generator wants fakes scored 1

	trainable_variables = list(d_weights.values()) + list(d_biases.values())

	grads_of_gen = gen_tape.gradient(gen_loss, trainable_variables)
	grads_of_disc = disc_tape.gradient(d_loss, trainable_variables)

	# Two sequential updates of the same variables: generator step, then discriminator step.
	gan_optimizer.apply_gradients(zip(grads_of_gen, trainable_variables))
	gan_optimizer.apply_gradients(zip(grads_of_disc, trainable_variables))





# Parameters of the single MLP used as both generator and discriminator.
# Layer widths: z_space -> n_hidden_1 -> n_hidden_2 -> num_features.

random_normal = tf.initializers.RandomNormal()

# Weight matrices, created in h1, h2, h3 order (initializer draws follow this order).
d_weights = {
	key: tf.Variable(random_normal([fan_in, fan_out]))
	for key, fan_in, fan_out in (
		('h1', z_space, n_hidden_1),
		('h2', n_hidden_1, n_hidden_2),
		('h3', n_hidden_2, num_features),
	)
}

# Zero-initialized bias vectors, one per layer.
d_biases = {
	key: tf.Variable(tf.zeros([width]))
	for key, width in (
		('b1', n_hidden_1),
		('b2', n_hidden_2),
		('b3', num_features),
	)
}



def neural_net_d(x, use_noise=True, track=False):
	"""Forward pass of the 3-layer sigmoid MLP shared by generator and discriminator.

	Args:
		x: input batch whose last dimension matches the rows of
			``d_weights['h1']`` (z_space).
		use_noise: if true, add Gaussian noise (stddev 0.05) to the first hidden
			layer's activations. NOTE(review): the default is True; generator
			passes elsewhere pass False explicitly.
		track: when exactly ``True``, return the second layer's pre-sigmoid
			values instead of the final output (used for distribution monitoring).

	Returns:
		Sigmoid outputs of shape (batch, num_features), or the layer-2
		pre-activations of shape (batch, n_hidden_2) when ``track is True``.
	"""
	def affine(inputs, weight_key, bias_key):
		# inputs @ W + b using the module-level shared parameters.
		return tf.add(tf.matmul(inputs, d_weights[weight_key]), d_biases[bias_key])

	hidden_1 = tf.nn.sigmoid(affine(x, 'h1', 'b1'))

	if use_noise:
		# Stddev 0.04 was reported to work well at layer widths 1200/1200.
		jitter = tf.random.normal(shape=(hidden_1.shape[0], n_hidden_1), mean=0, stddev=0.05, dtype=tf.dtypes.float32, seed=0)
		hidden_1 = hidden_1 + jitter

	pre_activation_2 = affine(hidden_1, 'h2', 'b2')
	if track is True:
		return pre_activation_2

	hidden_2 = tf.nn.sigmoid(pre_activation_2)
	return tf.nn.sigmoid(affine(hidden_2, 'h3', 'b3'))







# Fixed latent batch of 10000 samples for the periodic KDE evaluation.
contant_kde_noise = tf.random.normal(
	shape=(10000, z_space),
	mean=mean,
	stddev=std,
	dtype=tf.float32,
	seed=0,
)

# Fixed latent batch for the 8 x 11 sample grid drawn by save_imgs.
constant_noise = tf.random.normal(
	shape=(88, z_space),
	mean=mean,
	stddev=std,
	dtype=tf.float32,
	seed=0,
)

def save_imgs(step):
	"""Save an 8 x 11 grid of generator samples as ``enn_photos/single_gan/mnist_<step>.png``.

	Samples come from the generator pass (``use_noise=False``) on the fixed
	``constant_noise`` batch. Saving is best-effort: a filesystem error is
	printed and training continues.

	Args:
		step: training step, embedded in the output filename.
	"""
	import os

	rows, cols = 8, 11  # rows * cols must equal len(constant_noise) (88)

	# Sigmoid outputs are already in [0, 1], and imshow autoscales each tile, so
	# no rescaling is needed.
	gen_images = neural_net_d(constant_noise, use_noise=False).numpy()

	fig, axs = plt.subplots(rows, cols, figsize=(cols / 2.3, rows / 2.3))
	try:
		# axs.flat walks the grid row by row, matching the sample order.
		for idx, ax in enumerate(axs.flat):
			# 'nearest' is an imshow option; savefig() does not accept it.
			ax.imshow(np.reshape(gen_images[idx], [28, 28]), cmap='gray', interpolation='nearest')
			ax.axis('off')
			ax.set_facecolor('black')
			ax.set_aspect('equal')

		fig.subplots_adjust(hspace=0.0001, wspace=0.0001)

		# os.path.join instead of hard-coded backslashes so this works on every OS.
		path = os.path.join("enn_photos", "single_gan", "mnist_%d.png" % step)
		try:
			os.makedirs(os.path.dirname(path), exist_ok=True)
			fig.savefig(path)
		except OSError as err:
			print("Photo not saved: {}".format(err))
	finally:
		plt.close(fig)



def build_1d_mask(size):
	"""Return a float32 ``size x size`` matrix that is 1 at [0, 0] and 0 elsewhere.

	Right-multiplying a ``(batch, size)`` tensor by it keeps the first column
	and zeroes the rest.
	"""
	diagonal = tf.concat([tf.ones(1), tf.zeros(size - 1)], axis=0)
	return tf.linalg.tensor_diag(diagonal)

# In the visible code this mask is referenced only from commented-out lines.
num_features_to_1d_mask = build_1d_mask(num_features) #To broadcast outputs to 1 dimension




def get_total_param_number(weights, biases):
	"""Return the total number of scalar parameters in ``weights`` and ``biases``.

	Prints the element count of every weight tensor, then of every bias tensor,
	in dict order.

	Args:
		weights: mapping of name -> tensor/variable (anything with a ``shape``).
		biases: mapping of name -> tensor/variable.

	Returns:
		The total element count as a Python int (0 when both mappings are empty).
	"""
	total = 0
	for tensor in list(weights.values()) + list(biases.values()):
		# Python ints cannot overflow the way summed int32 tf.size() results can,
		# and an empty input no longer ends in calling .numpy() on a plain int.
		count = int(np.prod(tensor.shape, dtype=np.int64))
		print(count)
		total += count
	return total




if __name__ == "__main__":
	# Train only when run as a script, not when this module is imported.
	main()