"""
Training time experiment:
Usage:
  training_time_experiment.py <experiment> --gpu <gpu> [--monitfreq <frequency>] [-s <root>]
  training_time_experiment.py -h | --help
  training_time_experiment.py --version

Options:
  --rev                  Whether to do a reversed experiment - remove everything except from the explanation per predicate.
  -s --save <root>       Specify whether and where to save the models and results.
  -h --help              Show this screen.
  --version              Show version.
"""
# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the Licence is available in LICENCE, or at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
from docopt import docopt
import os
import ampligraph
import numpy as np
import time
import pickle as pkl
import json
import keras
from ampligraph.datasets import load_fb15k_237
from ampligraph.utils import save_model
import json
from ampligraph.compat import TransE
import copy
from ampligraph.datasets import data_adapter
from ampligraph.latent_features.layers.calibration import CalibrationLayer
from ampligraph.latent_features.layers.corruption_generation import CorruptionGenerationLayerTrain
from ampligraph.datasets.data_indexer import DataIndexer
from ampligraph.latent_features import optimizers
from ampligraph.latent_features import loss_functions
from ampligraph.evaluation import train_test_split_no_unseen
from tensorflow.python.keras import callbacks as callbacks_module
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.eager import def_function
from tensorflow.python.keras import metrics as metrics_mod
from tensorflow.python.keras.engine import compile_utils
import tensorflow as tf
from ampligraph.evaluation import hits_at_n_score, mr_score, mrr_score
from ampligraph.compat import evaluate_performance
from ampligraph.utils import save_model


# Threshold used by MonitorTargetTripleProbaCallback: its saturation counter
# starts at 1, is incremented each time the monitored predictions are found
# unchanged with respect to the previous dump, and training is flagged to stop
# once the counter exceeds this value.
SATURATION = 5

def proba_not_changed(predictions, fname):
    """Compare ``predictions`` against the predictions pickled in ``fname``.

    Parameters
    ----------
    predictions : array-like
        Freshly computed predictions.
    fname : str
        Path of a pickle file holding the predictions to compare with.

    Returns
    -------
    bool
        ``True`` when ``np.allclose`` considers both sets of predictions
        equal (default tolerances), ``False`` otherwise.
    """
    with open(fname, 'rb') as handle:
        previous = pkl.load(handle)
    return np.allclose(predictions, previous)


class MonitorTargetTripleProbaCallback(keras.callbacks.Callback):
    """Callback that periodically dumps calibrated predictions for test triples.

    Every ``monit_freq`` epochs (epoch 0 excluded) a calibration layer is fitted
    on ``X['valid']``, calibrated predictions are computed for ``X['test']`` and
    pickled under ``<root>/<explainer_name>/``. Each time a dump is found equal
    (``np.allclose``) to the dump of the previous monitored epoch, a saturation
    counter is incremented; once it exceeds ``SATURATION`` the model is flagged
    to stop training.

    Parameters
    ----------
    X : mapping
        Must provide the ``'valid'`` and ``'test'`` entries used for
        calibration and prediction respectively.
    name : str
        Suffix used in the names of the pickled prediction files.
    root : str
        Directory under which the ``explainer_name`` sub-directory is created.
    monit_freq : int
        Number of epochs between two monitoring steps.
    explainer_name : str
        Name of the sub-directory of ``root`` receiving the prediction dumps.
    """

    def __init__(self, X, name, root, monit_freq, explainer_name):
        # Let the base class initialise its own state (model handle, etc.).
        super().__init__()
        self.X = X
        self.explainer_name = explainer_name
        self.name = name
        self.root = root
        # makedirs also creates ``root`` when missing and does not fail if the
        # directory already exists (no check-then-create race).
        os.makedirs(os.path.join(root, explainer_name), exist_ok=True)
        self.monit_freq = monit_freq
        # Starts at 1; compared against SATURATION in on_epoch_begin.
        self.saturation_count = 1

    def on_train_begin(self, logs=None):
        print("Starting training with monitoring predictions of test triples.")

    def on_train_end(self, logs=None):
        # ``logs`` defaults to None, so guard before asking for its keys.
        keys = list((logs or {}).keys())
        print("Stop training; got log keys: {}".format(keys))

    def on_epoch_begin(self, epoch, logs=None):
        if self.saturation_count > SATURATION:
            print("Saturated probabilities, ending retraining.")
            self.model.stop_training = True
        if epoch % self.monit_freq == 0 and epoch != 0:
            print(f"Evaluating predictions at epoch {epoch}")
            # calibration
            calibration_layer = external_calibrate(self.model, self.X['valid'], positive_base_rate=0.5, data_indexer=self.model.data_indexer)
            predictions = external_predict_proba(self.model, calibration_layer, self.X['test'], data_indexer=self.model.data_indexer)
            with open(f"{self.root}/{self.explainer_name}/predictions_at_epoch_{epoch}_{self.name}.pkl", 'wb') as f:
                f.write(pkl.dumps(predictions))
            # Dump written at the previous monitored epoch.
            fname = f"{self.root}/{self.explainer_name}/predictions_at_epoch_{epoch-self.monit_freq}_{self.name}.pkl"
            # At the first monitored epoch there is no earlier dump to compare with.
            if epoch != self.monit_freq and proba_not_changed(predictions, fname):
                self.saturation_count += 1


def external_calibrate(model, X_pos, X_neg=None, positive_base_rate=None, batch_size=32, epochs=50, verbose=1, data_indexer=True):
    """Fit a ``CalibrationLayer`` on the scores produced by ``model``.

    Positives are scored with the model's encoding and scoring layers. The
    negative scores come either from corruptions generated with
    ``model.corruption_layer`` (when ``X_neg`` is None) or from the triples
    supplied in ``X_neg``. The calibration layer's trainable weights are then
    optimised with Adam.

    Parameters
    ----------
    model :
        Model exposing ``encoding_layer``, ``scoring_layer``,
        ``corruption_layer`` and ``num_ents``.
    X_pos :
        Positive triples, handed to ``data_adapter.DataHandler``.
    X_neg : optional
        Negative triples. When omitted, ``positive_base_rate`` is mandatory.
    positive_base_rate : float, optional
        Must lie strictly between 0 and 1. When ``X_neg`` is given and the
        rate is None, it is set to ``pos_size / (pos_size + neg_size)``.
    batch_size : int
        Batch size used for the positives.
    epochs : int
        Number of epochs handed to the data handlers.
    verbose :
        Currently unused.
    data_indexer :
        Forwarded to the data handlers as ``use_indexer``.

    Returns
    -------
    CalibrationLayer
        The fitted calibration layer.

    Raises
    ------
    AssertionError
        If both ``X_neg`` and ``positive_base_rate`` are None.
    ValueError
        If ``positive_base_rate`` is not strictly between 0 and 1.
    """
    data_handler_calibrate_pos = data_adapter.DataHandler(X_pos,
                                                          batch_size=batch_size,
                                                          dataset_type='test',
                                                          epochs=epochs,
                                                          use_filter=False,
                                                          use_indexer=data_indexer)

    pos_size = data_handler_calibrate_pos._parent_adapter.get_data_size()
    neg_size = pos_size

    if X_neg is None:
        assert positive_base_rate is not None, 'Please provide the negatives or positive base rate!'
        is_calibrate_with_corruption = True
    else:
        is_calibrate_with_corruption = False

        pos_batch_count = int(np.ceil(pos_size / batch_size))

        data_handler_calibrate_neg = data_adapter.DataHandler(X_neg,
                                                              batch_size=batch_size,
                                                              dataset_type='test',
                                                              epochs=epochs,
                                                              use_filter=False,
                                                              use_indexer=data_indexer)

        neg_size = data_handler_calibrate_neg._parent_adapter.get_data_size()
        neg_batch_count = int(np.ceil(neg_size / batch_size))

        # Re-batch the negatives so that they are split into as many batches
        # as the positives; both iterators are advanced in lockstep below.
        if pos_batch_count != neg_batch_count:
            batch_size_neg = int(np.ceil(neg_size / pos_batch_count))
            data_handler_calibrate_neg = data_adapter.DataHandler(X_neg,
                                                                  batch_size=batch_size_neg,
                                                                  dataset_type='test',
                                                                  epochs=epochs,
                                                                  use_filter=False,
                                                                  use_indexer=data_indexer)

        if positive_base_rate is None:
            positive_base_rate = pos_size / (pos_size + neg_size)

    if positive_base_rate is not None and (positive_base_rate <= 0 or positive_base_rate >= 1):
        raise ValueError("positive_base_rate must be a value between 0 and 1.")

    calibration_layer = CalibrationLayer(pos_size, neg_size, positive_base_rate)

    def score_batch(inputs):
        # Embed a batch of triples and score it.
        return model.scoring_layer(model.encoding_layer(inputs))

    def calibrate_with_corruption(iterator):
        # Score the next batch together with corruptions generated from it.
        inputs = next(iterator)
        inp_score = score_batch(inputs)
        corruptions = model.corruption_layer(inputs, model.num_ents, 1)
        corr_score = score_batch(corruptions)
        return inp_score, corr_score

    def calibrate_with_negatives(iterator):
        # Score only the next batch. Used when explicit negatives are given:
        # the corruption-based step returns a (positive, corruption) tuple and
        # must not be used for either the positive or the negative iterator.
        return score_batch(next(iterator))

    optimizer = tf.keras.optimizers.Adam()

    if not is_calibrate_with_corruption:
        negative_iterator = iter(data_handler_calibrate_neg.enumerate_epochs())

    for _, iterator in data_handler_calibrate_pos.enumerate_epochs(True):
        if not is_calibrate_with_corruption:
            _, neg_handle = next(negative_iterator)

        with data_handler_calibrate_pos.catch_stop_iteration():
            for step in data_handler_calibrate_pos.steps():
                if is_calibrate_with_corruption:
                    scores_pos, scores_neg = calibrate_with_corruption(iterator)
                else:
                    scores_pos = calibrate_with_negatives(iterator)
                    with data_handler_calibrate_neg.catch_stop_iteration():
                        scores_neg = calibrate_with_negatives(neg_handle)

                with tf.GradientTape() as tape:
                    out = calibration_layer(scores_pos, scores_neg, 1)

                gradients = tape.gradient(out, calibration_layer._trainable_weights)
                # update the trainable params
                optimizer.apply_gradients(zip(gradients, calibration_layer._trainable_weights))
    return calibration_layer

def external_predict_proba(model, calibration_layer,
                           x,
                           batch_size=32,
                           verbose=1,
                           callbacks=None, data_indexer=False):
    """Return calibrated predictions of ``model`` for the triples in ``x``.

    The input is consumed batch by batch for a single epoch. Each batch is
    run through the model's predict step and the raw outputs are passed to
    ``calibration_layer`` (called with ``training=0``). The calibrated batches
    are concatenated into one array.

    Parameters
    ----------
    model :
        Model used for prediction.
    calibration_layer :
        Callable applied to the raw outputs of every batch.
    x :
        Triples to predict, handed to ``data_adapter.DataHandler``.
    batch_size : int
        Batch size used for prediction.
    verbose : int
        Forwarded to the callback list; a progress bar is added when non-zero.
    callbacks : optional
        Callbacks to run; wrapped into a ``CallbackList`` unless already one.
    data_indexer :
        Forwarded to the data handler as ``use_indexer``.

    Returns
    -------
    numpy.ndarray
        Concatenation of the calibrated outputs of all batches.
    """
    handler = data_adapter.DataHandler(x,
                                       batch_size=batch_size,
                                       dataset_type='test',
                                       epochs=1,
                                       use_filter=False,
                                       use_indexer=data_indexer)

    already_wrapped = isinstance(callbacks, callbacks_module.CallbackList)
    if not already_wrapped:
        show_progress = verbose != 0
        callbacks = callbacks_module.CallbackList(
            callbacks,
            add_history=True,
            add_progbar=show_progress,
            model=model,
            verbose=verbose,
            epochs=1,
            steps=handler.inferred_steps)

    run_batch = external_make_predict_function(model)
    callbacks.on_predict_begin()
    calibrated_batches = []
    for _, batch_iterator in handler.enumerate_epochs():
        with handler.catch_stop_iteration():
            for step in handler.steps():
                callbacks.on_predict_batch_begin(step)
                raw_scores = run_batch(batch_iterator)
                calibrated_batches.append(calibration_layer(raw_scores, training=0))
                # The batch-end hook receives the raw (uncalibrated) outputs.
                callbacks.on_predict_batch_end(step, {'outputs': raw_scores})
    callbacks.on_predict_end()
    return np.concatenate(calibrated_batches)


def external_make_predict_function(model):
    """Build the function that runs one predict step of ``model``.

    The returned function pulls the next batch from a dataset iterator and
    feeds it to ``model.predict_step``. Unless the model runs eagerly or uses
    partitioned training, the function is wrapped with
    ``def_function.function`` (with ``experimental_relax_shapes=True``).

    Parameters
    ----------
    model :
        Object exposing ``predict_step``, ``run_eagerly`` and
        ``is_partitioned_training``.

    Returns
    -------
    out: Function handle.
          Handle to the predict step function
    """
    def predict_function(iterator):
        return model.predict_step(next(iterator))

    # Eager or partitioned models keep the plain Python function.
    if model.run_eagerly or model.is_partitioned_training:
        return predict_function

    return def_function.function(predict_function,
                                 experimental_relax_shapes=True)

def get_time_experiment_callback(X, name, root, monit_freq, explainer_name):
    """Create the prediction-monitoring callback for one experiment.

    All arguments are forwarded unchanged to
    ``MonitorTargetTripleProbaCallback``.

    Returns
    -------
    MonitorTargetTripleProbaCallback
        The freshly constructed callback.
    """
    return MonitorTargetTripleProbaCallback(X=X,
                                            name=name,
                                            root=root,
                                            monit_freq=monit_freq,
                                            explainer_name=explainer_name)


def train_model_with_probabilities_monit(model, X, name, root, monit_freq=50, explainer_name='none'):
    """Train ``model`` with prediction monitoring, then evaluate and save it.

    The model is fitted on ``X['train']`` with the monitoring callback
    attached, saved with ``save_model``, and evaluated on ``X['test']``. The
    resulting metrics are printed and written as JSON under ``root``.

    Parameters
    ----------
    model :
        Model exposing ``fit``; also handed to ``save_model`` and
        ``evaluate_performance``.
    X : mapping
        Must provide the ``'train'``, ``'valid'`` and ``'test'`` splits.
    name : str
        Suffix of the results file name; also forwarded to the callback.
    root : str
        Directory receiving the results file; also forwarded to the callback.
    monit_freq : int
        Monitoring frequency forwarded to the callback.
    explainer_name : str
        Sub-directory name forwarded to the callback.

    Returns
    -------
    tuple
        ``(result, model)`` where ``result`` maps ``"mr"``, ``"mrr"``,
        ``"H@1"``, ``"H@3"`` and ``"H@10"`` to the computed scores.
    """
    monitor = get_time_experiment_callback(X, name, root, monit_freq, explainer_name)
    # All known triples, used as filter for both fitting and evaluation.
    all_triples = np.concatenate((X['train'], X['valid'], X['test']))

    # presumably the early-stopping settings of the compat ``fit`` API,
    # preceded by the flag enabling them -- confirm against the library.
    stopping_params = {
        'x_valid': X['valid'][::2],
        'criteria': 'mrr',
        'x_filter': all_triples,
        'stop_interval': 4,
        'burn_in': 0,
        'check_interval': 50
    }
    model.fit(X["train"], True, stopping_params, callbacks=[monitor])
    # NOTE(review): the model name contains literal single quotes and ignores
    # both ``root`` and ``name`` -- confirm this is intended.
    save_model(model, "'model_fb15k_237_monitoring_TT'")

    ranks = evaluate_performance(X['test'],
                                 model,
                                 all_triples,
                                 verbose=True)

    # compute and print metrics:
    result = {
        "mr": mr_score(ranks),
        "mrr": mrr_score(ranks),
        "H@1": hits_at_n_score(ranks, n=1),
        "H@3": hits_at_n_score(ranks, n=3),
        "H@10": hits_at_n_score(ranks, n=10)
    }
    print(result)

    results_path = f'{root}/results_model_fb15k_237_monitoring_TT_{name}.json'
    serialized = json.dumps(result)
    with open(results_path, 'w') as f:
        f.write(serialized)
    return result, model


if __name__ == "__main__":
    # Command-line entry point: parse the arguments, read the experiment
    # configuration from ./config.json, build the dataset and the model named
    # by <experiment> ("<dataset>_<model>...") and run the monitored training.
    arguments = docopt(__doc__, version='Training time experiment')
    root = arguments['--save']
    gpu = arguments['<gpu>']
    experiment_name = arguments['<experiment>']

    # --monitfreq is optional: when it is omitted <frequency> is None and the
    # default frequency of train_model_with_probabilities_monit is used.
    frequency = arguments['<frequency>']
    monit_kwargs = {} if frequency is None else {'monit_freq': int(frequency)}

    # The monitoring callback and the results file are always written below
    # ``root``; without it the run would only fail later with a TypeError.
    if root is None:
        raise SystemExit("A save directory is required: pass -s/--save <root>.")

    with open('./config.json') as f:
        config = json.load(f)

    name_parts = experiment_name.split('_')
    dataset = name_parts[0].upper()
    model_name = name_parts[1].upper()
    hyperparams = config['hyperparams'][dataset][model_name]

    os.makedirs(root, exist_ok=True)
    print(arguments)
    model_cls = config['model_name_map'][model_name]
    load_func = getattr(ampligraph.datasets,
                        config["load_function_map"][dataset])

    # Restrict the visible GPUs before the dataset and model are created.
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    X = load_func()
    model = getattr(ampligraph.compat, model_cls)(**hyperparams)
    name = f"{experiment_name}_full"
    result = train_model_with_probabilities_monit(model, X, name, root, **monit_kwargs)
