# ---------
# Libraries
# ---------

import numpy as np
import time
import sys
import json
import os
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
import tensorflow.keras.backend as K
from tensorflow.keras.models import model_from_json

from BesselConv2d import BesselConv2d
from GroupConv2d import GroupConv2d

# ---------
# Callbacks
# ---------

# Get the computation time for each epoch and for each batch
class TimeHistory(keras.callbacks.Callback):
    """Keras callback that records how long each epoch and each batch takes.

    Attributes (re-created in ``on_train_begin``, i.e. at the start of ``fit``):
        times: one duration in seconds per epoch, measured from
            ``on_epoch_begin`` to ``on_epoch_end``.
        times_batch: one duration in seconds per batch, measured from
            ``on_batch_begin`` to ``on_batch_end``.

    Durations are taken with ``time.perf_counter()``, a monotonic
    high-resolution clock. Unlike differences of ``time.time()``, they cannot
    jump when the system clock is adjusted mid-run.
    """

    def on_train_begin(self, logs=None):
        self.times = []
        self.times_batch = []

    # Keras passes the epoch index as the first positional argument; the
    # parameter keeps its original name ``batch`` for backward compatibility.
    def on_epoch_begin(self, batch, logs=None):
        self.epoch_time_start = time.perf_counter()

    def on_batch_begin(self, batch, logs=None):
        self.batch_time_start = time.perf_counter()

    def on_epoch_end(self, batch, logs=None):
        self.times.append(time.perf_counter() - self.epoch_time_start)

    def on_batch_end(self, batch, logs=None):
        self.times_batch.append(time.perf_counter() - self.batch_time_start)

# Get the training loss for each batch
class BatchHistory(keras.callbacks.Callback):
    """Keras callback that records the training loss reported after each batch.

    Attributes:
        train_loss_batch: one entry per ``on_batch_end`` call, re-created in
            ``on_train_begin``. Each entry is the ``'loss'`` value from the
            batch logs as a Python float, or ``None`` if no loss was logged.
    """

    def on_train_begin(self, logs=None):
        self.train_loss_batch = []

    def on_batch_end(self, batch, logs=None):
        # Guard against Keras passing logs=None instead of a dict.
        loss = (logs or {}).get('loss')
        # Coerce to a plain Python float so the list can go straight into
        # json.dump; depending on the TF version the logged value may be a
        # NumPy scalar, which the json module cannot serialize.
        self.train_loss_batch.append(None if loss is None else float(loss))

# ------------
# Loading data
# ------------

# Load data: add a trailing channel axis and divide pixel values by 255 so
# they lie in [0, 1].
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
train_images = train_images[:, :, :, tf.newaxis] / 255.0
test_images = test_images[:, :, :, tf.newaxis] / 255.0

# Rotate testing: every test image gets its own angle drawn uniformly from
# [0, 2*pi) (tfa.image.rotate takes radians). The number of angles follows the
# test set size instead of a hard-coded 10000, so angles always line up
# one-to-one with images. No seed is set here, so angles differ between runs.
angles = np.random.uniform(0, 2. * np.pi, size=len(test_images))
test_rotated = tfa.image.rotate(test_images, angles=angles)

# ---------------------------
# Loading the model & running
# ---------------------------

# The single CLI argument is an integer run id used to name the results file.
if len(sys.argv) < 2:
    sys.exit('usage: python ' + os.path.basename(sys.argv[0]) + ' RUN_ID')
run_id = int(sys.argv[1])

n_epochs = 20
batch_size = 32
models_dir = './models'
# Pick up every serialized architecture in ./models. Match on the '.json'
# suffix (a substring test would also catch names like 'net.json.bak') and
# sort, because os.listdir returns entries in arbitrary order and we want a
# reproducible training order.
models = sorted(x for x in os.listdir(models_dir) if x.endswith('.json'))
if not models:
    print('warning: no .json model files found in ' + models_dir)

results = {}
for model_file in models:

    print(model_file)

    # Reset Keras' global state left over from the previous iteration so memory
    # does not keep growing while one model after another is built here.
    K.clear_session()

    # Loading the model (the custom layers must be registered to deserialize it)
    with open(os.path.join(models_dir, model_file), 'r') as json_file:
        model = keras.models.model_from_json(
            json_file.read(),
            custom_objects={'BesselConv2d': BesselConv2d,
                            'GroupConv2d': GroupConv2d},
        )

    # Training: a fresh optimizer per model; the loss treats model outputs as
    # raw logits (from_logits=True).
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
    model.compile(optimizer=optimizer,
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])

    time_callback = TimeHistory()
    batch_callback = BatchHistory()
    # Train on upright images; validate each epoch on the randomly rotated test set.
    history = model.fit(train_images, train_labels, epochs=n_epochs, batch_size=batch_size,
                        validation_data=(test_rotated, test_labels),
                        callbacks=[time_callback, batch_callback],
                        verbose=2)

    # Saving: parameter counts are cast to int so they are JSON-serializable.
    results[model_file] = {
        'n_epochs': len(history.history['accuracy']),
        'batch_size': int(batch_size),
        'n_trainable_params': int(np.sum([K.count_params(w) for w in model.trainable_weights])),
        'n_nontrainable_params': int(np.sum([K.count_params(w) for w in model.non_trainable_weights])),
        'time_per_epoch': time_callback.times,
        'time_per_batch': time_callback.times_batch,
        'training_loss': history.history['loss'],
        'training_loss_batch': batch_callback.train_loss_batch,
        'testing_loss': history.history['val_loss'],
        'training_acc': history.history['accuracy'],
        'testing_acc': history.history['val_accuracy'],
    }

# -------
# Dumping
# -------

# Write all collected metrics to ./results/<run_id>.json. Create the directory
# first so a long training run is not lost to a missing output folder.
os.makedirs('./results', exist_ok=True)
with open(os.path.join('./results', str(run_id) + '.json'), 'w') as fp:
    json.dump(results, fp, indent=4)