import os
import time

import tensorflow as tf

from svgp_nn_inducing.tf2.utils import save_model

class EpochCSVLogger(tf.keras.callbacks.CSVLogger):
    """``CSVLogger`` that also records per-epoch wall time and, optionally, test metrics.

    On top of the columns written by ``CSVLogger``, every epoch row gets:

    * ``total_time_train`` -- wall-clock seconds measured between
      ``on_epoch_begin`` and ``on_epoch_end`` for that epoch (despite the
      name, this is the per-epoch time; the running sum is kept in
      ``self.total_time_training``);
    * ``<model.metrics_names[1]>_test`` and ``nll_test`` -- only when
      ``predict_test`` is true, obtained by evaluating the model on
      ``(X_test, y_test)``.

    When training ends, the model is evaluated on the train and test sets and a
    one-line summary is written to ``<root>_final<ext>``, where ``root``/``ext``
    come from ``filename``.

    NOTE(review): ``model.evaluate`` is unpacked into exactly three values, so
    the model is assumed to be compiled with a loss plus exactly two metrics
    (an error/RMSE metric first, an NLL metric second) -- confirm against the
    model's ``compile`` call.
    """

    def __init__(self, X_train, y_train, X_test, y_test, batch_size = None, filename = "./training.log", separator=',', append=False,
                 predict_test=True, optimizer = '', path_results = ''):
        """Store the evaluation data and forward CSV options to ``CSVLogger``.

        Args:
            X_train, y_train: data used for the final training-set evaluation.
            X_test, y_test: data used for per-epoch (optional) and final
                test-set evaluation.
            batch_size: passed as ``batch_size`` to every ``model.evaluate``.
            filename: path of the per-epoch CSV log; also determines the path
                of the ``_final`` summary file.
            separator: CSV separator forwarded to ``CSVLogger``.
            append: whether ``CSVLogger`` appends to an existing file.
            predict_test: if true, evaluate on the test set after every epoch.
            optimizer, path_results: stored only; they belong to the currently
                disabled checkpointing call (see ``on_epoch_end``).
        """
        self.X_train = X_train
        self.y_train = y_train
        self.X_test = X_test
        self.y_test = y_test
        self.batch_size = batch_size
        self.filename_ = filename
        self.optimizer = optimizer
        self.path_results = path_results
        self.predict_test = predict_test
        # Forward the caller's CSV options (they used to be silently ignored).
        super().__init__(filename, separator=separator, append=append)

    @staticmethod
    def _final_path(filename):
        """Return ``filename`` with ``_final`` inserted before its extension."""
        # splitext only splits on the last dot of the basename, so paths like
        # "./training.log" or "a.b/run.txt" are handled correctly.
        root, ext = os.path.splitext(os.fspath(filename))
        return root + '_final' + ext

    def on_train_begin(self, logs=None):
        """Reset the timing accumulators, then let ``CSVLogger`` open its file."""
        self.total_time_training = 0.0
        self.predict_time = 0
        super().on_train_begin(logs)

    def on_train_end(self, logs=None):
        """Evaluate on train/test data, print and save a summary, close the CSV.

        ``prediction_time`` is the wall-clock time of the test-set evaluation only.
        """
        try:
            _, rmse_train, nll_train = self.model.evaluate(self.X_train, self.y_train, batch_size=self.batch_size)
            start_predict = time.time()
            _, rmse_test, nll_test = self.model.evaluate(self.X_test, self.y_test, batch_size=self.batch_size)
            self.predict_time = time.time() - start_predict
            print("RMSE_train {}, NLL_train {}, RMSE_test {}, NLL_test {}, total_training_time {}, prediction_time {}".format(
                rmse_train, nll_train, rmse_test, nll_test, self.total_time_training, self.predict_time))
            values = (rmse_train, nll_train, rmse_test, nll_test,
                      self.total_time_training, self.predict_time)
            with open(self._final_path(self.filename), "w") as myfile:
                myfile.write('RMSE_train, NLL_train, RMSE_test, NLL_test, total_training_time, prediction_time' '\n')
                # NOTE: values are space-separated while the header uses ", ";
                # kept as-is for compatibility with existing readers.
                myfile.write(" ".join(str(v) for v in values) + '\n')
        finally:
            # Always let CSVLogger close its file, even if evaluation failed.
            super().on_train_end(logs)

    def on_epoch_begin(self, epoch, logs=None):
        """Record the wall-clock start time of the epoch."""
        self.start_epoch = time.time()

    def on_epoch_end(self, epoch, logs=None):
        """Add timing (and optional test metrics) to ``logs``, then write the CSV row."""
        # NOTE: checkpointing via save_model(self.model, self.optimizer,
        # self.path_results) is currently disabled.
        training_time = time.time() - self.start_epoch
        self.total_time_training += training_time
        logs = {} if logs is None else logs
        # Per-epoch time despite the column name; key kept for CSV compatibility.
        logs['total_time_train'] = training_time
        if self.predict_test:
            _, err_test, nll_test = self.model.evaluate(self.X_test, self.y_test, batch_size=self.batch_size)
            logs[self.model.metrics_names[1] + '_test'] = err_test
            logs['nll_test'] = nll_test
        super().on_epoch_end(epoch, logs)

class NBatchCSVLogger(tf.keras.callbacks.CSVLogger):
    """``CSVLogger`` that additionally logs test metrics every N training batches.

    * Per-epoch rows are written by the inherited ``CSVLogger.on_epoch_end``.
    * Every ``each_n_batches``-th training batch (counted in ``self.batch``,
      starting at 0; the counter is not reset between ``fit`` calls), the
      model is evaluated on ``(X_test, y_test)``. A line
      ``batch err_test nll_test training_time`` is appended to
      ``self.filename_batch``, where ``training_time`` is the cumulative
      wall-clock time spent inside training batches (evaluation excluded).
    * When training ends, the model is evaluated on the test set and a summary
      is written to ``<root>_final<ext>``, where ``root``/``ext`` come from
      ``filename``.

    NOTE(review): ``model.evaluate`` is unpacked into exactly three values, so
    the model is assumed to be compiled with a loss plus exactly two metrics
    (an error/RMSE metric first, an NLL metric second) -- confirm.
    """

    def __init__(self, X_test, y_test, batch_size = None, filename = "./training.log", separator=',', append=False, each_n_batches=500):
        """Store the evaluation data and forward CSV options to ``CSVLogger``.

        Args:
            X_test, y_test: data used for periodic and final test evaluation.
            batch_size: passed as ``batch_size`` to every ``model.evaluate``.
            filename: path of the per-epoch CSV log; also determines the paths
                of the ``_batch`` and ``_final`` files.
            separator: CSV separator forwarded to ``CSVLogger``.
            append: whether ``CSVLogger`` appends to an existing file.
            each_n_batches: evaluate every this many training batches (>= 1).

        Raises:
            ValueError: if ``each_n_batches`` is smaller than 1.
        """
        if each_n_batches < 1:
            raise ValueError("each_n_batches must be >= 1, got {}".format(each_n_batches))
        self.X_test = X_test
        self.y_test = y_test
        self.batch_size = batch_size
        self.filename_ = filename
        self.filename_batch = self._batch_path(filename)
        self.each_n_batches = each_n_batches
        self.batch = 0
        # Forward the caller's CSV options (they used to be silently ignored).
        super().__init__(filename, separator=separator, append=append)

    @staticmethod
    def _batch_path(filename):
        """Return the path of the per-N-batches log derived from ``filename``.

        ``"run.txt"`` -> ``"run_batch.txt"``; for names without ``.txt`` the
        extension is stripped instead (``"./training.log"`` ->
        ``"./training_batch.txt"``) rather than chopping off the last character.
        """
        name = os.fspath(filename)
        idx = name.rfind('.txt')
        root = name[:idx] if idx != -1 else os.path.splitext(name)[0]
        return root + '_batch' + '.txt'

    @staticmethod
    def _final_path(filename):
        """Return ``filename`` with ``_final`` inserted before its extension."""
        root, ext = os.path.splitext(os.fspath(filename))
        return root + '_final' + ext

    def on_train_begin(self, logs=None):
        """Reset timers, truncate the batch log and write its header line."""
        self.total_time_training = 0.0
        self.predict_time = 0
        with open(self.filename_batch, "w") as myfile:
            myfile.write("batch err_test nll_test training_time\n")
        super().on_train_begin(logs)

    def on_train_end(self, logs=None):
        """Evaluate on test data, print and save a summary, close the CSV.

        ``prediction_time`` is the wall-clock time of the test-set evaluation.
        """
        try:
            start_predict = time.time()
            _, rmse_test, nll_test = self.model.evaluate(self.X_test, self.y_test, batch_size=self.batch_size)
            self.predict_time = time.time() - start_predict

            print("RMSE_test {}, NLL_test {}, total_training_time {}, prediction_time {}".format(
                rmse_test, nll_test, self.total_time_training, self.predict_time))
            values = (rmse_test, nll_test, self.total_time_training, self.predict_time)
            with open(self._final_path(self.filename), "w") as myfile:
                myfile.write('RMSE_test, NLL_test, total_training_time, prediction_time' '\n')
                # NOTE: values are space-separated while the header uses ", ";
                # kept as-is for compatibility with existing readers.
                myfile.write(" ".join(str(v) for v in values) + '\n')
        finally:
            # Always let CSVLogger close its file, even if evaluation failed.
            super().on_train_end(logs)

    def on_train_batch_begin(self, batch, logs=None):
        """Record the wall-clock start time of the training batch."""
        self.start_epoch_batch = time.time()
        super().on_train_batch_begin(batch, logs)

    def on_train_batch_end(self, batch, logs=None):
        """Accumulate batch time; every N batches evaluate and append a log line."""
        training_time = time.time() - self.start_epoch_batch
        self.total_time_training += training_time

        if self.batch % self.each_n_batches == 0:
            _, err_test, nll_test = self.model.evaluate(self.X_test, self.y_test, batch_size=self.batch_size)
            with open(self.filename_batch, "a") as myfile:
                myfile.write(str(self.batch) + " " + str(err_test) + " " + str(nll_test) + " " +
                             str(self.total_time_training) + '\n')
        self.batch += 1
        super().on_train_batch_end(batch, logs)
        
