import tensorflow.keras.backend as K
import tensorflow as tf
import numpy as np
from tensorflow import keras


### Define Training Loop ###

# Metric names returned by model.train_step / model.test_step. Validation
# entries in the history use the same names with a 'val_' prefix.
_METRIC_KEYS = ('Acc-U', 'Acc-L', 'Loss-U', 'Loss-L')


def _format_progress(step, total, metrics, prefix=''):
    """Build the carriage-return progress line printed after each batch."""
    values = " ".join(f"{prefix}{key}: {metrics[key].numpy():.4f}" for key in _METRIC_KEYS)
    return f"\r{step}/{total} {values}"


def _record_metrics(history, metrics, prefix=''):
    """Append one epoch's metric values (converted with .numpy()) to history."""
    for key in _METRIC_KEYS:
        history[prefix + key].append(metrics[key].numpy())


def _run_pass(dataset, step_fn, verbose, label, prefix=''):
    """Apply step_fn to every batch of dataset and return the metrics of the last batch.

    Raises:
        ValueError: if dataset yields no batches, so there are no metrics to record.
    """
    # len() is only needed for the progress display; compute it once per pass.
    total = len(dataset) if verbose else None
    metrics = None
    for step, batch in enumerate(dataset, start=1):
        metrics = step_fn(batch)
        if verbose:
            print(_format_progress(step, total, metrics, prefix), end="", flush=True)
    if verbose:
        print()
    if metrics is None:
        raise ValueError(f"The {label} dataset produced no batches; nothing to train/evaluate on.")
    return metrics


def training_loop(model, x, y, batch_size, epochs, verbose, val_data, val_batch_size, path,
                  lr_schedule=None, early_stopping=None, check_epoch=None):
    """Train ``model`` with a custom epoch loop, validating and saving checkpoints.

    Each epoch shuffles and batches ``(x, y)`` and calls ``model.train_step`` on
    every batch. It then calls ``model.test_step`` on every batch of ``val_data``.
    The metrics returned for the LAST batch of each pass are appended to the
    history.

    Args:
        model: object providing ``train_step((x, y), batch_size)`` and
            ``test_step((x, y))``. Each must return a mapping with keys 'Acc-U',
            'Acc-L', 'Loss-U' and 'Loss-L' whose values support ``.numpy()``.
            The model must also provide ``save(filepath)`` and, when
            ``lr_schedule`` is given, ``optimizer.learning_rate.assign(value)``.
        x, y: training inputs and targets (anything
            ``tf.data.Dataset.from_tensor_slices`` accepts).
        batch_size: training batch size.
        epochs: number of epochs to run; epochs are numbered from 1.
        verbose: if truthy, print an epoch header and a per-batch progress line.
        val_data: ``(x_val, y_val)`` pair used for validation.
        val_batch_size: validation batch size.
        path: filename prefix. The final model is saved to ``f"{path}_model.keras"``.
        lr_schedule: optional ``callable(epoch) -> lr``. Its result is assigned to
            the optimizer's learning rate after epoch ``epoch`` finishes, so it
            takes effect from the next epoch.
        early_stopping: optional ``callable(history, epoch) -> bool``. When it
            returns truthy, the model is saved to ``f"{path}_model.keras"`` and
            training stops.
        check_epoch: optional positive int. If given, an intermediate checkpoint
            is saved to ``f"{path}_model_{epoch}.keras"`` every ``check_epoch``
            epochs. The final epoch is skipped here because it gets the regular
            final save. ``None`` (the default) disables intermediate checkpoints.

    Returns:
        dict mapping 'Acc-U', 'Acc-L', 'Loss-U', 'Loss-L' and their
        'val_'-prefixed counterparts to lists with one value per completed epoch.

    Raises:
        ValueError: if the training or validation data yields no batches.
    """
    result_history = {prefix + key: [] for prefix in ('', 'val_') for key in _METRIC_KEYS}

    # Training data is shuffled (1024-element buffer); validation keeps its order.
    train_dataset = (tf.data.Dataset.from_tensor_slices((x, y))
                     .shuffle(buffer_size=1024)
                     .batch(batch_size))
    x_val, y_val = val_data
    val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(val_batch_size)

    def train_batch(batch):
        x_batch, y_batch = batch
        # train_step also receives the actual size of this batch, since the
        # last batch may be smaller than batch_size.
        return model.train_step((x_batch, y_batch), x_batch.shape[0])

    def test_batch(batch):
        x_batch, y_batch = batch
        return model.test_step((x_batch, y_batch))

    final_path = f"{path}_model.keras"
    for epoch in range(1, epochs + 1):
        if verbose:
            print(f"Epoch: {epoch}/{epochs}")

        # NOTE(review): metrics are never reset between epochs or between the
        # train and validation passes here. Whether the recorded values are
        # per-batch or running averages depends on model.train_step/test_step
        # -- confirm.
        train_metrics = _run_pass(train_dataset, train_batch, verbose, 'training')
        _record_metrics(result_history, train_metrics)

        val_metrics = _run_pass(val_dataset, test_batch, verbose, 'validation', prefix='val_')
        _record_metrics(result_history, val_metrics, prefix='val_')

        if lr_schedule is not None:
            model.optimizer.learning_rate.assign(lr_schedule(epoch))

        if early_stopping is not None and early_stopping(result_history, epoch):
            model.save(final_path)
            break

        # Periodic checkpoints. Previously this read an undefined CHECK_EPOCH
        # global and raised NameError. A falsy check_epoch (None or 0) disables
        # them, which also avoids modulo-by-zero.
        if check_epoch and epoch % check_epoch == 0 and epoch != epochs:
            model.save(f"{path}_model_{epoch}.keras")

        if epoch == epochs:
            model.save(final_path)
    return result_history

