# ---------
# Libraries
# ---------

import numpy as np
import time
import sys
import json
import os
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
import tensorflow.keras.backend as K
from tensorflow.keras.models import model_from_json

from BesselConv2d import BesselConv2d
from GroupConv2d import GroupConv2d

# ---------
# Callbacks
# ---------

# Record the wall-clock duration of each training epoch
class TimeHistory(keras.callbacks.Callback):
    """Keras callback that records how long each training epoch takes.

    After ``fit`` returns, ``times`` holds one float (seconds) per epoch,
    in the order the epochs ran.
    """

    def on_train_begin(self, logs=None):
        # Reset on every fit() so a reused callback does not mix runs.
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments (unlike time.time()).
        self.epoch_time_start = time.perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        self.times.append(time.perf_counter() - self.epoch_time_start)

# Deactivate random rotations during evaluation passes (validation/testing)
class DesactivateRot(keras.callbacks.Callback):
    """Turn off the model's random-rotation augmentation while evaluating.

    When an evaluation pass starts, this callback sets the ``factor``
    attribute of the model's first layer to 0. When the pass ends, it
    restores the value that attribute had before, rather than hard-coding 1.

    NOTE(review): assumes ``model.layers[0]`` is the rotation-augmentation
    layer. For tf.keras ``RandomRotation``, augmentation is already inactive
    when ``training=False``, and the rotation bounds appear to be derived
    from ``factor`` in ``__init__``. Reassigning ``factor`` afterwards may
    therefore have no effect. TODO confirm against the actual layer type.
    """

    def on_test_begin(self, logs=None):
        rotation_layer = self.model.layers[0]
        # Remember the training-time value so it can be restored exactly.
        self._saved_factor = rotation_layer.factor
        rotation_layer.factor = 0.

    def on_test_end(self, logs=None):
        # Fall back to 1. (the value previously always restored) if
        # on_test_begin did not run for this pass.
        self.model.layers[0].factor = getattr(self, '_saved_factor', 1.)

# ------------
# Loading data
# ------------

# Outex train/test splits, loaded from .npy files in the working directory
# (loaded in the order listed: train images, test images, train labels,
# test labels).
train_images, test_images, train_labels, test_labels = map(np.load, (
    './Outex_train_images.npy',
    './Outex_test_images.npy',
    './Outex_train_labels.npy',
    './Outex_test_labels.npy',
))

# ---------------------------
# Loading the model & running
# ---------------------------

# Integer identifier of this run, taken from the command line; it names the
# results file written at the end of the script.
if len(sys.argv) < 2:
    sys.exit('usage: ' + sys.argv[0] + ' RUN_ID')
run_id = int(sys.argv[1])

n_epochs = 200
batch_size = 32
# Serialized Keras architectures to train. Only files that actually end in
# '.json' are selected (e.g. 'net.json.bak' is skipped). The list is sorted
# because os.listdir order is arbitrary, and sorting keeps runs reproducible.
models = sorted(x for x in os.listdir('./models') if x.endswith('.json'))

results = {}
for model_file in models:

    print(model_file)

    # Loading the model: rebuild the architecture from its Keras JSON
    # description. The custom convolution layers are passed as
    # custom_objects so deserialization can resolve them by class name.
    with open(os.path.join('./models', model_file), 'r') as json_file:
        model = keras.models.model_from_json(
            json_file.read(),
            custom_objects={'BesselConv2d': BesselConv2d,
                            'GroupConv2d': GroupConv2d})

    # Training. from_logits=True means the loss applies the softmax itself,
    # so the model presumably outputs raw scores (confirm against the JSONs).
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.0005)
    model.compile(optimizer=optimizer,
                  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])

    time_callback = TimeHistory()
    desac_rot = DesactivateRot()
    # The test split is passed as validation data, so the val_* metrics
    # below are test-set metrics evaluated after every epoch.
    history = model.fit(train_images, train_labels, epochs=n_epochs, batch_size=batch_size,
                        validation_data=(test_images, test_labels),
                        callbacks=[time_callback, desac_rot],
                        verbose=2)

    # Saving. Metric values are cast to plain Python floats so json.dump
    # cannot fail on NumPy scalar types.
    hist = history.history
    results[model_file] = {
        'n_epochs': len(hist['accuracy']),
        'batch_size': int(batch_size),
        'n_trainable_params': int(sum(K.count_params(w) for w in model.trainable_weights)),
        'n_nontrainable_params': int(sum(K.count_params(w) for w in model.non_trainable_weights)),
        'time_per_epoch': [float(t) for t in time_callback.times],
        'training_loss': [float(v) for v in hist['loss']],
        'testing_loss': [float(v) for v in hist['val_loss']],
        'training_acc': [float(v) for v in hist['accuracy']],
        'testing_acc': [float(v) for v in hist['val_accuracy']],
    }

    # Release the global Keras state accumulated by this model before the
    # next one is built; otherwise memory grows with every loop iteration.
    K.clear_session()

# -------
# Dumping
# -------

# Make sure the output directory exists; otherwise open() raises
# FileNotFoundError only after every model has finished training.
os.makedirs('./results', exist_ok=True)
with open(os.path.join('./results', str(run_id) + '.json'), 'w') as fp:
    json.dump(results, fp, indent=4)