import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model

class DataSampler():
	'''
	Serve shuffled mini-batches of (X, y) pairs to a training loop.

	The sampler keeps a cursor into a random permutation of the sample
	indexes. Each call to ``get_data`` returns the next ``batch_size``
	samples (the last batch of an epoch may be smaller). Once the
	permutation is exhausted it is redrawn and the epoch flag is raised.
	'''
	def __init__(self, batch_size=1):
		self.batch_size = batch_size
		# -1 means "not initialised yet": the dataset size is only known on
		# the first call to get_data.
		self.index = -1
		# Defined up front so check_if_shuffle is safe to call before the
		# first get_data (previously this raised AttributeError).
		self.num_instances = 0
		self.order = np.arange(0)
		self.next_epoch = False

	def check_if_shuffle(self):
		'''
		Reshuffle once the cursor has run past the last sample.

		When ``self.index`` reaches ``self.num_instances`` a new permutation
		is drawn, the cursor is rewound to 0 and ``self.next_epoch`` is set
		to True so the caller knows an epoch has just finished.
		'''
		if self.index >= self.num_instances:
			self.order = np.random.permutation(self.num_instances)
			self.index = 0
			self.next_epoch = True

	def get_data(self, X, y):
		'''
		Return the next mini-batch.

		Parameters
		----------
		X : indexable sequence (list, array, ...)
		    Samples; element ``i`` is read as ``X[i]``. ``len(X)`` fixes the
		    dataset size on the first call.
		y : indexable sequence
		    Targets, read at the same positions as ``X``.

		Returns
		-------
		X_batch, y_batch : np.ndarray of dtype float32
		    The gathered samples and targets for this batch.
		next_epoch : bool
		    True when this batch exhausted the current permutation.

		NOTE(review): the dataset size is captured only on the first call;
		passing a dataset of a different length later is not detected.
		'''
		self.next_epoch = False

		# Lazily size the sampler from the first dataset it sees.
		if self.index == -1:
			self.index = 0
			self.num_instances = len(X)
			self.order = np.random.permutation(self.num_instances)

		start_index = self.index
		# The final batch of an epoch is truncated rather than wrapped around.
		end_index = min(start_index + self.batch_size, self.num_instances)
		batch_indexes = self.order[start_index:end_index]

		X_batch = np.array([X[i] for i in batch_indexes], dtype=np.float32)
		y_batch = np.array([y[i] for i in batch_indexes], dtype=np.float32)

		self.index = end_index
		self.check_if_shuffle()

		return X_batch, y_batch, self.next_epoch


class SoftmaxRegressionModel(Model):
	'''
	Multinomial logistic regression as a Keras ``Model``: one dense layer
	producing ``n_classes`` scores, followed by a softmax.
	'''
	def __init__(self, n_classes=3):
		super().__init__()
		# One output unit per class.
		self.Dense_1 = Dense(units=n_classes)

	def call(self, inputs):
		'''Return class probabilities for ``inputs``.'''
		scores = self.Dense_1(inputs)
		return tf.nn.softmax(scores)

class Learner():
	'''
	Base learner that contains common functions for most of the learners.

	Subclasses assign a callable model to ``self.model``; ``predict``,
	``score`` and ``get_weights`` delegate to it.
	'''
	def __init__(self):
		self.model = None

	def get_training_parameters(self, params=None):
		'''
		Function that returns the parameters needed for the training procedure.

		Parameters
		----------
		params : dict, optional
		    May contain 'epochs' (default 10000), 'learning_rate' (default
		    1E-5), 'optimizer' (default ``tf.keras.optimizers.Adam`` with
		    that learning rate), 'loss_metric' (default
		    ``tf.keras.metrics.Mean()``) and 'batch_size' (default 50).
		    None is treated as an empty dict.

		Returns
		-------
		tuple
		    (epochs, lr, optimizer, loss_metric, batch_size)
		'''
		if params is None:
			params = {}

		epochs = params.get('epochs', 10000)
		lr = params.get('learning_rate', 1E-5)

		# The TF defaults are built only when missing, so a caller-supplied
		# optimizer/metric does not cause an unused object to be created
		# (dict.get would evaluate the default eagerly).
		if 'optimizer' in params:
			optimizer = params['optimizer']
		else:
			optimizer = tf.keras.optimizers.Adam(lr)

		if 'loss_metric' in params:
			loss_metric = params['loss_metric']
		else:
			loss_metric = tf.keras.metrics.Mean()

		batch_size = params.get('batch_size', 50)

		return epochs, lr, optimizer, loss_metric, batch_size

	def predict(self, X):
		'''Run the underlying model on ``X`` and return its raw output.'''
		return self.model(X)

	def score(self, X, y, thresh=0.5):
		'''
		Thresholded element-wise accuracy of the model's outputs against ``y``.

		Predictions (converted with ``.numpy()``) and ``y`` are both flattened
		and binarised with ``> thresh``. The fraction of matching entries is
		returned.

		NOTE(review): for multi-column outputs (e.g. a softmax over several
		classes) this counts agreement per output entry, not per sample, so
		it differs from argmax classification accuracy — confirm intended.
		'''
		predictions_bool = np.reshape(self.predict(X).numpy(), (-1)) > thresh
		ground_truth_bool = np.reshape(y, (-1)) > thresh

		return np.mean(np.equal(predictions_bool, ground_truth_bool))

	def get_weights(self):
		'''Return the underlying model's weights via ``model.get_weights()``.'''
		return self.model.get_weights()


class SoftmaxRegression(Learner):
	'''
	Softmax regression learner trained with mini-batch gradient steps on a
	summed categorical cross-entropy loss.
	'''
	def __init__(self, n_classes=3):
		super(SoftmaxRegression, self).__init__()
		self.model = SoftmaxRegressionModel(n_classes)

	def train(self, X, y, params=None):
		'''
		Fit the model on (X, y).

		Parameters
		----------
		X : indexable sequence of feature vectors
		y : indexable sequence of targets; presumably one-hot encoded since
		    the loss is CategoricalCrossentropy — TODO confirm with callers
		params : dict, optional
		    Options read by ``get_training_parameters`` ('epochs',
		    'learning_rate', 'optimizer', 'loss_metric', 'batch_size').

		Training runs for at most ``epochs`` epochs. Every second epoch the
		mean batch loss of that epoch is printed and compared with the value
		from the previous check; training stops when they differ by < 1E-4.
		'''
		if params is None:
			params = {}

		# The learning rate is only needed to build the default optimizer.
		epochs, _, optimizer, loss_metric, batch_size = self.get_training_parameters(params)
		previous_loss = np.inf
		cross_entropy_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=False,
			reduction=tf.keras.losses.Reduction.SUM)

		data_sampler = DataSampler(batch_size)

		for epoch in range(epochs):
			# Average only this epoch's batches. Without the reset the metric
			# is a running mean since training began, and the convergence
			# test compares two nearly identical cumulative averages.
			loss_metric.reset_state()

			next_epoch = False
			while not next_epoch:
				# Sampling is plain NumPy work, so it stays outside the tape.
				X_batch, y_batch, next_epoch = data_sampler.get_data(X, y)

				with tf.GradientTape() as tape:
					y_hat = self.model(X_batch)
					loss = cross_entropy_loss(y_true=y_batch, y_pred=y_hat)

				grads = tape.gradient(loss, self.model.trainable_weights)
				optimizer.apply_gradients(zip(grads, self.model.trainable_weights))

				loss_metric(loss)

			if epoch % 2 == 0:
				loss_t = loss_metric.result().numpy()
				print('Epoch ', epoch, ' Loss: ', loss_t, end='\r')

				if np.abs(loss_t - previous_loss) < 1E-4:
					print('\n')
					print('Ending condition met. Weights converged')
					break

				previous_loss = loss_t
def main():
	'''Placeholder entry point: does no work and returns the status -1.'''
	status = -1
	return status


if __name__ == '__main__':
	main()