import matplotlib.pyplot as plt
import numpy as np
from pandas import read_csv

import tensorflow as tf
from tensorflow.keras.datasets import mnist


from tensorflow.keras.losses import MSE




# Seed NumPy's global RNG (TensorFlow ops below are seeded separately).
np.random.seed(0)

# MNIST dataset parameters.
num_classes = 10  # digit classes (labels are not used for training here)
num_features = 784  # 28 * 28 pixels per flattened image

# Width of the generator's latent noise vector; the discriminator's output
# layer has this same width.
z_space = 500

# Mean / standard deviation of the Gaussian latent noise.
mean, std = 0, 1

# Training parameters.
# NOTE(review): lr_gan is not passed to the active SGD optimizer (which uses a
# hard-coded 0.1); only the commented-out Adam configuration references it.
lr_gan = 0.0002  # earlier experiments: 0.0005, 0.1
training_steps = 500500  # previously 100000
batch_size = 1000

display_step = 500  # log losses / save images every `display_step` steps

# Network params: hidden-layer widths (the discriminator reuses these).
n_hidden_1 = 330
n_hidden_2 = 330

def main():
	"""Train the GAN on MNIST.

	Every `display_step` steps this prints the discriminator's BCE on generated
	and real samples, saves a grid of generated digits and passes hidden-layer
	activations to `tools.dist_tracker`; every `display_step * 400` steps it also
	runs a KDE evaluation of generator samples against the test set.
	"""
	(x_train, y_train), (x_test, y_test) = mnist.load_data()
	# Convert to float32.
	x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)
	# Flatten images to 1-D vector of 784 features (28*28).
	x_train, x_test = x_train.reshape([-1, num_features]), x_test.reshape([-1, num_features])
	# Normalize images value from [0, 255] to [0, 1].
	x_train, x_test = x_train / 255., x_test / 255.

	# Use tf.data API to shuffle and batch the (unlabelled) training images.
	train_data = tf.data.Dataset.from_tensor_slices(x_train)
	train_data = train_data.repeat().shuffle(5000).batch(batch_size).prefetch(1)

	# Discriminator targets, shaped like its (batch_size, z_space) output:
	# real -> 1, fake -> 0 (labels are NOT currently flipped).
	ones = tf.ones((batch_size, z_space), dtype=tf.dtypes.float32)
	zeros = tf.zeros((batch_size, z_space), dtype=tf.dtypes.float32)

	print("--------------")
	print(get_total_param_number(d_weights, d_biases) + get_total_param_number(g_weights, g_biases))
	print("--------------")

	# Distribution history: (time, percentiles). Used to save all the info from the session.
	d_dist_hist = []
	g_dist_hist = []

	# Run training for the given number of steps (step counts from 1).
	for step, batch_x in enumerate(train_data.take(training_steps), 1):
		# TODO: change this distribution?
		noise = tf.random.normal(shape=(batch_size, z_space), mean=mean, stddev=std, dtype=tf.dtypes.float32, seed=0)

		optimize_gan(batch_x, noise, ones, zeros)

		if step % display_step != 0:
			continue

		generated = neural_net_g(noise)
		d_g = neural_net_d(generated)
		d_true = neural_net_d(batch_x)

		# Keras losses are called as (y_true, y_pred): targets first, predictions second.
		print("{} -- D_gen_loss: {}, D_true_loss: {}".format(step, bce_loss(zeros, d_g), bce_loss(ones, d_true)))

		save_imgs_clean(step)

		# Look inside the networks and hand the raw second-layer activations to
		# dist_tracker (presumably persisted as percentiles/CSV — see tools.py).
		from tools import dist_tracker
		dist_tracker(dist_hist=d_dist_hist, new=neural_net_d(batch_x, track=True), name="base_gan_d_no_bn_sgd")
		dist_tracker(dist_hist=g_dist_hist, new=neural_net_g(noise, track=True), name="base_gan_g_no_bn_sgd")

		# `step` starts at 1, so no separate `step > 0` guard is needed.
		if step % (display_step * 400) == 0:  # currently every 200,000 steps
			_evaluate_kde(x_test)


def _evaluate_kde(x_test):
	"""Fit a Gaussian KDE to generator samples and print test-set log-density stats.

	Args:
		x_test: flattened, [0, 1]-scaled test images, one per row.
	"""
	from sklearn.model_selection import GridSearchCV
	from sklearn.model_selection import KFold
	from sklearn.neighbors import KernelDensity

	# Generator samples from the fixed module-level KDE noise batch.
	x_hat_kde = neural_net_g(contant_kde_noise)

	# Choose the bandwidth from 10 log-spaced values in [0.1, 10] by 3-fold CV.
	bandwidth = 10 ** np.linspace(-1, 1, 10)
	grid = GridSearchCV(KernelDensity(kernel='gaussian'), param_grid={'bandwidth': bandwidth}, cv=KFold(n_splits=3))
	grid.fit(x_hat_kde.numpy())
	kde = grid.best_estimator_
	print("FIT COMPLETE")
	# score_samples gives the log-density of each test image under the KDE.
	scores = kde.score_samples(x_test)
	print(np.mean(scores))

	print(np.std(scores))





# Plain SGD (learning rate 0.1) for both networks. The generator and the
# discriminator each get their OWN optimizer instance: newer (Keras 3)
# optimizers are built for the variable set of their first apply_gradients
# call and reject any other variables, and stateful optimizers such as Adam
# must not share step counters / slots between two networks.
optimizer_g = tf.optimizers.SGD(0.1)
optimizer_d = tf.optimizers.SGD(0.1)
# Backward-compatible alias for code that still refers to `optimizer`.
optimizer = optimizer_g
# Alternative configuration (create one instance per network):
# tf.keras.optimizers.Adam(learning_rate=lr_gan, beta_1=0.7, beta_2=0.99, epsilon=1e-07)

# Binary cross-entropy on sigmoid outputs; call as bce_loss(y_true, y_pred).
bce_loss = tf.keras.losses.BinaryCrossentropy()
def optimize_gan(x, noise, ones, zeros):
	"""Run one simultaneous update of the generator and the discriminator.

	Args:
		x: batch of real images, in the layout `neural_net_d` accepts.
		noise: latent noise batch fed to `neural_net_g`.
		ones: targets for discriminator outputs that should read "real".
		zeros: targets for discriminator outputs that should read "fake".
	"""
	# Generated OUTSIDE the tape, so the discriminator loss on these samples
	# contributes no gradient to the generator's weights.
	gen_unwatched = neural_net_g(noise)

	# A single persistent tape records both losses; it is queried twice below.
	with tf.GradientTape(persistent=True) as gen_tape:
		# Generator objective: make D label generated samples as real (1).
		# Only generator variables are differentiated w.r.t. gen_loss below.
		gen = neural_net_g(noise)
		d_fake_out = neural_net_d(gen, use_noise=False)
		# Keras losses take (y_true, y_pred): targets first, predictions second.
		gen_loss = bce_loss(ones, d_fake_out)

		# Discriminator objective on real and (unwatched) generated samples.
		d_real = neural_net_d(x)
		d_fake = neural_net_d(gen_unwatched)

		true_loss = bce_loss(ones, d_real)  # wants to guess true == 1
		fake_loss = bce_loss(zeros, d_fake)  # wants to guess false == 0

		# L1 penalty on the distance of D's outputs from 0.5.
		regularizer = tf.norm((d_real - 0.5), ord=1) + tf.norm((d_fake - 0.5), ord=1)

		# Coefficient notes from experiments: 0.000025 for no-batch-norm; try a
		# little lower than 0.0002. NOTE(review): it was tuned while the BCE
		# arguments above were swapped, so it may need re-tuning.
		d_loss = true_loss + fake_loss + 0.000023 * regularizer

	g_trainable_variables = list(g_weights.values()) + list(g_biases.values())
	d_trainable_variables = list(d_weights.values()) + list(d_biases.values())

	grads_of_gen = gen_tape.gradient(gen_loss, g_trainable_variables)
	grads_of_disc = gen_tape.gradient(d_loss, d_trainable_variables)
	# A persistent tape holds its resources until released; drop it now.
	del gen_tape

	optimizer_g.apply_gradients(zip(grads_of_gen, g_trainable_variables))
	optimizer_d.apply_gradients(zip(grads_of_disc, d_trainable_variables))


# Discriminator hidden widths, kept as their own names (an older bug fix);
# they currently mirror the generator's n_hidden_1 / n_hidden_2.
d_n_hidden_1 = n_hidden_1
d_n_hidden_2 = n_hidden_2

# One shared initializer instance draws every weight matrix; biases start at 0.
# NOTE(review): some Keras versions warn that reusing an unseeded initializer
# instance can return identical values for identical shapes (D's and G's 'h2'
# share a shape here) — confirm for the installed version.
random_normal = tf.initializers.RandomNormal()

# Discriminator: num_features -> d_n_hidden_1 -> d_n_hidden_2 -> z_space.
d_weights = dict(
	h1=tf.Variable(random_normal([num_features, d_n_hidden_1])),
	h2=tf.Variable(random_normal([d_n_hidden_1, d_n_hidden_2])),
	h3=tf.Variable(random_normal([d_n_hidden_2, z_space])),
)
d_biases = dict(
	b1=tf.Variable(tf.zeros([d_n_hidden_1])),
	b2=tf.Variable(tf.zeros([d_n_hidden_2])),
	b3=tf.Variable(tf.zeros([z_space])),
)

# Generator: z_space -> n_hidden_1 -> n_hidden_2 -> num_features.
g_weights = dict(
	h1=tf.Variable(random_normal([z_space, n_hidden_1])),
	h2=tf.Variable(random_normal([n_hidden_1, n_hidden_2])),
	h3=tf.Variable(random_normal([n_hidden_2, num_features])),
)
g_biases = dict(
	b1=tf.Variable(tf.zeros([n_hidden_1])),
	b2=tf.Variable(tf.zeros([n_hidden_2])),
	b3=tf.Variable(tf.zeros([num_features])),
)


def neural_net_g(x, track=False):
	"""Generator forward pass: three sigmoid dense layers, latent -> pixels.

	Args:
		x: latent batch whose last dimension matches g_weights['h1'] (z_space).
		track: if true, stop early and return the second layer's
			pre-sigmoid activations (used for distribution tracking).

	Returns:
		Sigmoid outputs of width num_features, or the layer-2 pre-activations
		when `track` is set.
	"""
	hidden = tf.nn.sigmoid(tf.matmul(x, g_weights['h1']) + g_biases['b1'])

	pre_activation = tf.matmul(hidden, g_weights['h2']) + g_biases['b2']
	if track:
		return pre_activation
	hidden = tf.nn.sigmoid(pre_activation)

	return tf.nn.sigmoid(tf.matmul(hidden, g_weights['h3']) + g_biases['b3'])


def neural_net_d(x, use_noise=False, track=False):
	"""Discriminator forward pass: three sigmoid dense layers, pixels -> scores.

	Args:
		x: batch whose last dimension matches d_weights['h1'] (num_features).
		use_noise: if true, add Gaussian noise (std 0.05) to the first hidden
			layer's activations.
		track: if true, stop early and return the second layer's
			pre-sigmoid activations (used for distribution tracking).

	Returns:
		Sigmoid outputs of width z_space, or the layer-2 pre-activations when
		`track` is set.
	"""
	hidden = tf.nn.sigmoid(tf.matmul(x, d_weights['h1']) + d_biases['b1'])

	if use_noise:
		# Std 0.04 was used for an earlier good model with 1200/1200 hidden units.
		hidden += tf.random.normal(shape=(hidden.shape[0], d_n_hidden_1), mean=0, stddev=0.05, dtype=tf.dtypes.float32, seed=0)

	pre_activation = tf.matmul(hidden, d_weights['h2']) + d_biases['b2']
	if track:
		return pre_activation
	hidden = tf.nn.sigmoid(pre_activation)

	return tf.nn.sigmoid(tf.matmul(hidden, d_weights['h3']) + d_biases['b3'])





# Fixed latent batches, drawn once at import time (in this order):
# 10000 samples for the KDE evaluation and 88 (an 8 x 11 image grid) for the
# saved sample images. `contant_kde_noise` keeps its original spelling because
# main() refers to it by that name.
_noise_kwargs = dict(mean=mean, stddev=std, dtype=tf.dtypes.float32, seed=0)
contant_kde_noise = tf.random.normal(shape=(10000, z_space), **_noise_kwargs)
constant_noise = tf.random.normal(shape=(88, z_space), **_noise_kwargs)
del _noise_kwargs



def save_imgs_clean(step):
	"""Save an 8 x 11 grid of generator samples as a PNG for this step.

	The samples come from the fixed module-level `constant_noise` batch, so
	images from different steps are directly comparable.

	Args:
		step: training step number, embedded in the output file name.
	"""
	import os

	r = 8
	c = 11
	digit_size = 28

	gen_images = neural_net_g(constant_noise).numpy()

	# Tile the first r*c flattened images row-major into one canvas of shape
	# (r * 28, c * 28): image k lands at grid row k // c, column k % c.
	figure = (gen_images[:r * c]
		.reshape(r, c, digit_size, digit_size)
		.transpose(0, 2, 1, 3)
		.reshape(r * digit_size, c * digit_size))

	plt.figure(figsize=(c / 2.3, r / 2.3))
	plt.grid(False)
	plt.axis('off')
	plt.imshow(figure, cmap='gray')

	# Build the path portably (hard-coded "\\" separators only work on Windows)
	# and make sure the output directory exists before saving.
	out_dir = os.path.join("enn_photos", "gan_baseline")
	os.makedirs(out_dir, exist_ok=True)
	plt.savefig(os.path.join(out_dir, "mnist_%d.png" % step))
	plt.close('all')




def get_total_param_number(weights, biases):
	"""Count the scalar parameters in a network's weight and bias dictionaries.

	Prints the size of every tensor (weights first, then biases) as it goes.

	Args:
		weights: mapping of name -> weight tensor/variable.
		biases: mapping of name -> bias tensor/variable.

	Returns:
		Total number of elements as a Python int (0 for empty mappings).
	"""
	total = 0
	for tensor in list(weights.values()) + list(biases.values()):
		size = tf.size(tensor)
		print(size)
		# Accumulate as a Python int: works for empty inputs (the old
		# `total.numpy()` failed on the plain int 0) and cannot overflow int32.
		total += int(size)

	return total


# Train only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
	main()










