import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import MaxPool2D, Conv2D, Dense, Softmax, Flatten
import pickle

def normalize_img(image, label):
	"""Convert an image to `float32` and divide it by 255.

	For `uint8` input this maps pixel values into the range [0, 1].
	The label is handed back untouched so the function can be mapped over
	(image, label) pairs.

	Returns:
		A tuple `(scaled_image, label)`.
	"""
	scaled_image = tf.cast(image, tf.float32) / 255.
	return scaled_image, label
	
class VGG16Model(Model):
	"""Truncated VGG16-style convolutional classifier.

	Only three convolutional stages are built: two 64-filter convolutions,
	two 128-filter convolutions and three 256-filter convolutions, all 3x3
	with ReLU and 'same' padding. The first two stages are each followed by
	a 2x2 max-pool; the third stage is not pooled. The result is flattened
	and fed through two ReLU dense layers and a final dense layer that
	produces one raw score per class.
	"""

	def __init__(self, num_classes, num_last_hidden=1000):
		"""Create the layers of the network.

		Args:
			num_classes: number of units in the final (prediction) layer.
			num_last_hidden: number of units in the second dense layer,
				whose output is returned by `call` as the embedding.
		"""
		super().__init__()

		def conv3x3(filters):
			# Every convolution in the network shares these settings
			return Conv2D(kernel_size=[3,3], filters=filters, activation='relu', padding='same')

		def pool2x2():
			return MaxPool2D(pool_size=[2,2])

		# Stage 1: 64 filters
		self.conv1 = conv3x3(64)
		self.conv2 = conv3x3(64)
		self.pool1 = pool2x2()

		# Stage 2: 128 filters
		self.conv3 = conv3x3(128)
		self.conv4 = conv3x3(128)
		self.pool2 = pool2x2()

		# Stage 3: 256 filters, no pooling afterwards
		self.conv5 = conv3x3(256)
		self.conv6 = conv3x3(256)
		self.conv7 = conv3x3(256)

		# The full VGG16 layout would continue with pool3, conv8-conv10 + pool4
		# and conv11-conv13 + pool5 (512 filters each); those layers are not
		# built in this model.

		self.flat = Flatten()

		# Fully connected head
		self.hidden1 = Dense(units=4096, activation='relu')
		self.hidden2 = Dense(units=num_last_hidden, activation='relu')
		self.final = Dense(units=num_classes)

	def call(self, inputs):
		"""Run the forward pass.

		Args:
			inputs: batch of images fed to the first convolution.

		Returns:
			A tuple `(embedding, z, y)` where `embedding` is the output of
			the second dense layer, `z` holds the raw scores of the final
			layer and `y` is the softmax of `z`.
		"""
		# Everything up to (and including) the first dense layer is a plain chain
		feature_stack = (
			self.conv1, self.conv2, self.pool1,
			self.conv3, self.conv4, self.pool2,
			self.conv5, self.conv6, self.conv7,
			self.flat,
			self.hidden1,
		)
		features = inputs
		for layer in feature_stack:
			features = layer(features)

		# Second fully connected layer doubles as the embedding
		embedding = self.hidden2(features)

		# Class scores and their softmax
		logits = self.final(embedding)
		probabilities = tf.nn.softmax(logits)

		return embedding, logits, probabilities

class VGG16():
	"""Training, inference and persistence wrapper around `VGG16Model`."""

	def __init__(self, num_classes, num_last_hidden=1000):
		"""Build the wrapped `VGG16Model` with the given layer sizes."""
		self.model = VGG16Model(num_classes, num_last_hidden)

	def get_training_parameters(self, params):
		"""Return the parameters needed for the training procedure.

		Args:
			params: dict that may hold any of the keys 'epochs',
				'learning_rate', 'optimizer' and 'loss_metric'. `None` is
				treated as an empty dict.

		Returns:
			A tuple `(epochs, lr, optimizer, loss_metric)`. Missing keys
			fall back to 500 epochs, a learning rate of 1E-5, an Adam
			optimizer built with that learning rate and a `Mean` metric.
		"""
		if params is None:
			params = {}

		epochs = params.get('epochs', 500)
		lr = params.get('learning_rate', 1E-5)

		# The defaults are only constructed when the caller did not supply one
		if 'optimizer' in params:
			optimizer = params['optimizer']
		else:
			optimizer = tf.keras.optimizers.Adam(lr)

		if 'loss_metric' in params:
			loss_metric = params['loss_metric']
		else:
			loss_metric = tf.keras.metrics.Mean()

		return epochs, lr, optimizer, loss_metric

	def _save_checkpoint(self, save_name, input_shape):
		"""Pickle `[input_shape, weights]` of the current model to `save_name`."""
		with open(save_name, 'wb') as handle:
			pickle.dump([input_shape, self.model.get_weights()], handle)

	def train(self, dataset, params=None, save_name='./Model_CNN.pkl'):
		"""Train the model with sparse categorical cross-entropy.

		After every epoch the input size and the current weights are pickled
		to `save_name`. Training stops early once the tracked loss changes by
		less than 1E-4 between two consecutive epochs.

		Args:
			dataset: iterable yielding `(X_batch, y_batch)` pairs; it is
				iterated once per epoch.
			params: optional dict of training settings, see
				`get_training_parameters`.
			save_name: path of the pickle file that receives the checkpoint.

		Raises:
			ValueError: if `dataset` yields no batch at all.
		"""
		epochs, _, optimizer, loss_metric = self.get_training_parameters(params)
		previous_loss = np.inf
		cross_entropy_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

		# Spatial size of the inputs, stored next to the weights so that
		# load_from_pickle can rebuild the model before restoring them
		input_shape = None

		for epoch in range(epochs):
			for X_batch, y_batch in dataset:
				with tf.GradientTape() as tape:
					_, _, y_hat = self.model(X_batch)
					loss = cross_entropy_loss(y_true=y_batch, y_pred=y_hat)

				grads = tape.gradient(loss, self.model.trainable_weights)
				optimizer.apply_gradients(zip(grads, self.model.trainable_weights))

				loss_metric(loss)
				input_shape = X_batch.shape[1:3]

			if input_shape is None:
				raise ValueError('Cannot train on an empty dataset: no batch was produced')

			# NOTE(review): loss_metric is never reset, so with the default Mean
			# metric this is the running mean over every batch of every epoch so
			# far, not the mean of this epoch alone - confirm this is intended.
			loss_t = loss_metric.result().numpy()
			print('Epoch ', epoch, ' Loss: ', loss_t, end='\r')

			# Save before testing for convergence so the file on disk always
			# matches the weights held in memory, also after the early stop
			self._save_checkpoint(save_name, input_shape)

			if np.abs(loss_t - previous_loss) < 1E-4:
				print('\n')
				print('Ending condition met. Weights converged')
				break

			previous_loss = loss_t

	def predict(self, X):
		"""Call the model on `X` and return its `(embedding, z, y)` output."""
		return self.model(X)

	def get_weights(self):
		"""Return the weights of the wrapped model."""
		return self.model.get_weights()

	def load_from_pickle(self, pickle_path, channels=3):
		"""Restore weights from a checkpoint written by `train`.

		Args:
			pickle_path: path of a pickle file holding `[shape, weights]`.
			channels: number of input channels used to build the model.
		"""
		# NOTE(security): unpickling can execute arbitrary code; only load
		# checkpoint files that come from a trusted source.
		with open(pickle_path, 'rb') as handle:
			shape, weights = pickle.load(handle)

		# Make a fake prediction to initialize the model
		self.model.predict(np.zeros((1,shape[0],shape[1],channels)))

		# Load the weights
		self.model.set_weights(weights)

def main():
	"""Placeholder entry point; does no work and always returns -1."""
	status = -1
	return status

# Call main() only when this file is executed directly as a script;
# the value main() returns is discarded.
if __name__ == '__main__':
	main()
