#!/usr/bin/env python3
"""
Created on 21:27, May. 17th, 2023

@author: Anonymous
"""
import os, sys
import time, json
import copy as cp
import numpy as np
import tensorflow as tf
from functools import partial
# local dep
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.path.join(os.pardir, os.pardir, os.pardir))
import utils; import utils.model; import utils.data.eeg
from models.conv_net import conv_ensemble_net as conv_net_model

# Public API of this module: only `train` is meant to be called from outside.
__all__ = ["train"]

# Module-level training state shared by the functions below.
# `params`/`paths` are set by `init`; `model` by `_init_model`; `optimizer` by `_init_train`.
params = None
paths = None
model = None
optimizer = None

"""
init funcs
"""
# def init func
def init(base_, params_):
    """
    Set up the module-level state used by `conv_net` training.

    Stores a deep copy of `params_` and the `utils.Paths` derived from it in the
    module globals, then builds the model and, after it, the optimizer.

    Args:
        base_: str - The base path of current project.
        params_: DotDict - The parameters of current training process.

    Returns:
        None
    """
    global params, paths
    # Deep-copy so that later mutation of `params` never touches the caller's object.
    params = cp.deepcopy(params_)
    paths = utils.Paths(base=base_, params=params)
    # Order matters: the model is built before the optimizer.
    for initializer in (_init_model, _init_train):
        initializer()

# def _init_model func
def _init_model():
    """
    Configure TensorFlow globally and build the model used for training.

    Sets the Keras float type from `params._precision`, decides whether
    `tf.function`s run eagerly from `params.train.use_graph_mode`, and stores a
    `conv_net_model` built from `params.model` in the module-level `model`.
    """
    global model
    # NOTE: no random seed is set here; seeding must happen before this builds `model`.
    tf.keras.backend.set_floatx(params._precision)
    # Graph mode disabled => every `tf.function` runs eagerly.
    run_eagerly = not params.train.use_graph_mode
    tf.config.run_functions_eagerly(run_eagerly)
    model = conv_net_model(params.model)

# def _init_train func
def _init_train():
    """
    Create the optimizer used by `_train`.

    Stores an Adam optimizer with learning rate `params.train.lr_pretrain_i`
    in the module-level `optimizer`.
    """
    global optimizer
    learning_rate = params.train.lr_pretrain_i
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

"""
data funcs
"""
# def load_data func
def load_data(load_params):
    """
    Load data from the dataset named by `params.train.dataset`.

    Dispatches to the module-level loader `_load_data_<dataset>`
    (e.g. `_load_data_eeg_zhou2023cibr`) and forwards `load_params` to it.

    Args:
        load_params: DotDict - The load parameters of specified dataset.

    Returns:
        dataset_train: tf.data.Dataset - The input train dataset.
        dataset_validation: tf.data.Dataset - The input validation dataset.
        dataset_test: tf.data.Dataset - The input test dataset.

    Raises:
        ValueError: If no loader exists for `params.train.dataset`.
    """
    global params
    # Resolve the dataset-specific loader. Only a missing loader is reported as an
    # unknown dataset; errors raised while loading (missing files, bad config, ...)
    # now propagate unchanged instead of being masked as "Unknown dataset type".
    func_name = "_load_data_{}".format(params.train.dataset)
    func = getattr(sys.modules[__name__], func_name, None)
    if not callable(func):
        raise ValueError("ERROR: Unknown dataset type {} in train.conv_net.".format(params.train.dataset))
    dataset_train, dataset_validation, dataset_test = func(load_params)
    # Return the final `dataset_*`.
    return dataset_train, dataset_validation, dataset_test

# def _load_data func
def _load_data(path_tfrecords, parse_func):
    """
    Build a shuffled, batched dataset from the given tfrecord files.

    Args:
        path_tfrecords: list - The list of specified tfrecord files.
        parse_func: func - Maps one serialized example to a parsed item.

    Returns:
        dataset: tf.data.Dataset - Batches (size `params.train.batch_size`) of parsed items.
    """
    # File order is shuffled (no seed given); files are then read one at a time.
    tfrecord_files = tf.data.Dataset.list_files(path_tfrecords, shuffle=True)
    dataset = tfrecord_files.interleave(map_func=tf.data.TFRecordDataset, cycle_length=1)
    dataset = dataset.map(lambda x: parse_func(x))
    # Shuffle samples, then crop them into batches.
    dataset = dataset.shuffle(buffer_size=params.train.buffer_size)
    dataset = dataset.batch(batch_size=params.train.batch_size)
    # `prefetch` comes after `batch`, so its buffer is measured in batches, not samples.
    # Passing `batch_size` here used to buffer `batch_size` whole batches; let tf.data tune it.
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    # Return the final `dataset`.
    return dataset

# def _list_tfrecords func
def _list_tfrecords(path_split, prefix):
    """
    List the files in `path_split` whose names start with `prefix`.

    Args:
        path_split: str - The directory of one data split.
        prefix: str - The file-name prefix to keep (e.g. "train").

    Returns:
        path_tfrecords: list - Sorted full paths of the matching files.
    """
    return sorted(os.path.join(path_split, fname_i)
        for fname_i in os.listdir(path_split) if fname_i.startswith(prefix))

# def _load_data_eeg_zhou2023cibr func
def _load_data_eeg_zhou2023cibr(load_params):
    """
    Load eeg data of the `zhou2023cibr` dataset for the configured modality.

    Reads `config.json` from the dataset directory, then builds one dataset per
    split from the files in `pretrain/<split>/` whose names start with `<split>`.

    Args:
        load_params: DotDict - The load parameters of specified dataset (currently unused).

    Returns:
        dataset_train: tf.data.Dataset - The input train dataset.
        dataset_validation: tf.data.Dataset - The input validation dataset.
        dataset_test: tf.data.Dataset - The input test dataset.
    """
    global params, paths
    # Initialize the path of base dataset & dataset config & pretrain dataset.
    path_dataset = os.path.join(paths.base, "data", "eeg.zhou2023cibr", "dataset.unimodality",
        "054.20230726", params.train.modality.replace("/", ""))
    path_dataset_config = os.path.join(path_dataset, "config.json")
    path_dataset_pretrain = os.path.join(path_dataset, "pretrain")
    # Compose the parse function according to configuration (JSON is UTF-8 by spec).
    with open(path_dataset_config, "r", encoding="utf-8") as f:
        config = json.load(f)
    _parse_func = partial(_parse_func_eeg_zhou2023cibr, config=config)
    # List the tfrecord files of every split first, then build the datasets.
    splits = ("train", "validation", "test")
    path_tfrecords = {split: _list_tfrecords(os.path.join(path_dataset_pretrain, split), prefix=split)
        for split in splits}
    dataset_train, dataset_validation, dataset_test = (
        _load_data(path_tfrecords[split], parse_func=_parse_func) for split in splits)
    # Return the final `dataset_*`.
    return dataset_train, dataset_validation, dataset_test

# def _parse_func_eeg_zhou2023cibr func
def _parse_func_eeg_zhou2023cibr(example_proto, config):
    """
    Decode one serialized example into `(data, label, subj_id)`.

    Each feature is stored as a serialized float32 tensor in a scalar string. The
    features are decoded, reshaped with the sizes read from `config`, and `data` is
    additionally transposed with perm [1, 0].

    Args:
        example_proto: tf.Tensor - One serialized example.
        config: dict - Dataset config providing "n_labels", "n_subjects" and "data_shape".

    Returns:
        data: tf.Tensor - (seq_len, n_channels), after reshaping to `data_shape` and transposing.
        label: tf.Tensor - (n_labels,).
        subj_id: tf.Tensor - (n_subjects,).
    """
    n_labels, n_subjects, data_shape = config["n_labels"], config["n_subjects"], config["data_shape"]
    # Every feature is one serialized tensor stored as a scalar string.
    feature_spec = {key: tf.io.FixedLenFeature((), tf.string) for key in ("label", "data", "subj_id")}
    parsed = tf.io.parse_single_example(example_proto, feature_spec)
    decoded = {key: tf.io.parse_tensor(parsed[key], out_type=tf.float32) for key in ("label", "data", "subj_id")}
    label = tf.reshape(decoded["label"], shape=(n_labels,))
    # `perm=[1, 0]` requires `data_shape` to be 2-D.
    data = tf.transpose(tf.reshape(decoded["data"], shape=data_shape), perm=[1,0])
    subj_id = tf.reshape(decoded["subj_id"], shape=(n_subjects,))
    return data, label, subj_id

"""
train funcs
"""
# def _log_msg func
def _log_msg(msg):
    """
    Print `msg` and also write it to the run's summaries logger.

    Args:
        msg: str - The message to emit.
    """
    global paths
    print(msg); paths.run.logger.summaries.info(msg)

# def _count_concrete_functions func
def _count_concrete_functions():
    """
    Count the entries in `_train.pretty_printed_concrete_signatures()`.

    Returns:
        n_functions: int - Number of blank-line-separated signature entries.
    """
    return len(_train.pretty_printed_concrete_signatures().split("\n\n"))

# def _run_split func
def _run_split(dataset, step_func):
    """
    Run `step_func` on every batch of `dataset` and aggregate accuracy & loss.

    Args:
        dataset: tf.data.Dataset - Yields `(data, label, subj_id)` batches.
        step_func: func - `_train` or `_forward`; maps the batch list to `(y_pred, loss)`.

    Returns:
        accuracy: float - Batch-size-weighted mean accuracy over `dataset`.
        loss: float - Batch-size-weighted mean loss over `dataset`.
    """
    global params
    accuracies = []; losses = []; batch_sizes = []
    for batch_i in dataset:
        # Update `batch_i` to have pseudo-`Y_f` (uniform noise of shape (batch_size, d_contra)).
        batch_i = [batch_i[0], batch_i[1], tf.random.uniform((batch_i[0].shape[0], params.model.d_contra)), batch_i[2]]
        # Get the number of current batch_i.
        batch_size_i = len(batch_i[0]); batch_sizes.append(batch_size_i)
        # Run the step (train or forward-only) on current batch.
        y_pred_i, loss_i = step_func(batch_i); y_pred_i = y_pred_i.numpy(); loss_i = loss_i.numpy()
        # Calculate the corresponding accuracy (argmax of prediction vs. argmax of label).
        accuracy_i = np.argmax(y_pred_i, axis=-1) == np.argmax(batch_i[1], axis=-1)
        accuracy_i = np.sum(accuracy_i) / accuracy_i.size
        accuracies.append(accuracy_i); losses.append(loss_i)
    accuracy = np.sum(np.array(accuracies) * np.array(batch_sizes)) / np.sum(batch_sizes)
    loss = np.sum(np.array(losses) * np.array(batch_sizes)) / np.sum(batch_sizes)
    return accuracy, loss

# def train func
def train(base_, params_):
    """
    Train the model.

    Runs `params.train.n_epochs.pretrain` epochs. Each epoch reloads the datasets,
    trains on the train-set via `_train`, evaluates the validation/test sets via
    `_forward`, logs metrics, and saves weights every `params.train.i_model` epochs
    and at the last epoch. Finally, it reports the test accuracy at the epoch with the
    best validation accuracy (ties broken by the best test accuracy).

    Args:
        base_: str - The base path of current project.
        params_: DotDict - The parameters of current training process.

    Returns:
        None

    Raises:
        ValueError: If `params_.train.dataset` is not supported.
    """
    global params, paths, model, optimizer
    # Initialize parameters & variables of current training process.
    init(base_, params_)
    # Log the start of current training process.
    paths.run.logger.summaries.info("Training started with dataset {}.".format(params.train.dataset))
    # Initialize load_params. Each load_params_i corresponds to a sub-dataset.
    if params.train.dataset == "eeg_zhou2023cibr":
        load_params = utils.DotDict()
    else:
        raise ValueError("ERROR: Unknown dataset {} in train.conv_net.".format(params.train.dataset))
    # Execute experiments for each dataset run.
    accuracies_validation = []; accuracies_test = []
    for epoch_idx in range(params.train.n_epochs.pretrain):
        # Record the start time of preparing data.
        time_start = time.time()
        # Reset the iteration information of params.
        # TODO: Update the parameters in model.
        params.iteration(iteration=epoch_idx)
        # Data is reloaded after `params.iteration` every epoch; presumably so per-iteration
        # parameter changes take effect — confirm before hoisting out of the loop.
        dataset_train, dataset_validation, dataset_test = load_data(load_params)
        # Execute train & validation & test process.
        accuracy_train, loss_train = _run_split(dataset_train, _train)
        accuracy_validation, loss_validation = _run_split(dataset_validation, _forward)
        accuracy_test, loss_test = _run_split(dataset_test, _forward)
        # Log information related to current training epoch.
        time_stop = time.time(); accuracies_validation.append(accuracy_validation); accuracies_test.append(accuracy_test)
        _log_msg((
            "Finish train epoch {:d} in {:.2f} seconds, generating {:d} concrete functions."
        ).format(epoch_idx, time_stop-time_start, _count_concrete_functions()))
        _log_msg((
            "Accuracy(train): {:.2f}%. Loss(train): {:.5f}."
        ).format(accuracy_train * 100., loss_train))
        _log_msg((
            "Accuracy(validation): {:.2f}%. Loss(validation): {:.5f}."
        ).format(accuracy_validation * 100., loss_validation))
        _log_msg((
            "Accuracy(test): {:.2f}%. Loss(test): {:.5f}."
        ).format(accuracy_test * 100., loss_test))
        # Summarize model information.
        if epoch_idx == 0:
            model.summary(print_fn=print); model.summary(print_fn=paths.run.logger.summaries.info)
        # Save model every `i_model` epochs.
        if (epoch_idx % params.train.i_model == 0) or (epoch_idx + 1 == params.train.n_epochs.pretrain):
            path_checkpoint = os.path.join(paths.run.model, "pretrain-{:d}.ckpt".format(epoch_idx))
            model.save_weights(path_checkpoint)
    # Without any epoch there is nothing to select (`np.max` on an empty array would raise).
    if not accuracies_validation:
        _log_msg("Finish the training process without running any epoch.")
        return
    # Convert `accuracies_validation` & `accuracies_test` to `np.array`.
    accuracies_validation = np.round(np.array(accuracies_validation, dtype=np.float32), decimals=4)
    accuracies_test = np.round(np.array(accuracies_test, dtype=np.float32), decimals=4)
    # Among epochs tied at max validation-accuracy, pick the one with the highest test-accuracy.
    epoch_maxacc_idxs = np.where(accuracies_validation == np.max(accuracies_validation))[0]
    epoch_maxacc_idx = epoch_maxacc_idxs[np.argmax(accuracies_test[epoch_maxacc_idxs])]
    # Finish training process of current specified experiment.
    _log_msg((
        "Finish the training process, with test-accuracy ({:.2f}%)" +\
        " according to max validation-accuracy ({:.2f}%) at epoch {:d}, generating {:d} concrete functions."
    ).format(accuracies_test[epoch_maxacc_idx]*100., accuracies_validation[epoch_maxacc_idx]*100.,
        epoch_maxacc_idx, _count_concrete_functions()))

# def _forward func
@tf.function
def _forward(inputs, training=False):
    """
    Run one forward pass of the module-level `model` on one batch.

    Everything entering this function is expected to already be a tensor. Callers in
    this file build `inputs` as `[data, label, pseudo_Y_f, subj_id]`.

    Args:
        inputs: list - The model inputs, passed unchanged to `model`.
        training: bool - Whether the model runs in training mode.

    Returns:
        outputs_: (n_samples, n_labels) - The predicted labels of inputs.
        loss_: The corresponding loss returned by `model`.
    """
    # `model` is only read here, so no `global` declaration is required.
    predictions_and_loss = model(inputs, training=training)
    return predictions_and_loss

# def _train func
@tf.function
def _train(inputs):
    """
    Take one optimization step on a single batch.

    Runs `_forward` in training mode under a gradient tape, then applies the gradients
    of the returned loss to `model.trainable_variables` with the module-level `optimizer`.

    Args:
        inputs: list - The model inputs (already tensors), forwarded to `_forward`.

    Returns:
        outputs_: (n_samples, n_labels) - The predictions computed before the update.
        loss_: The loss computed before the update.
    """
    with tf.GradientTape() as tape:
        predictions, loss = _forward(inputs, training=True)
    variables = model.trainable_variables
    grads = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
    return predictions, loss

if __name__ == "__main__":
    # local dep (`os` is already imported at module level)
    from params.conv_net_params import conv_net_params

    # Dataset to train on.
    dataset = "eeg_zhou2023cibr"
    # Seed the RNGs before `train` builds the model.
    utils.model.set_seeds(42)
    # The project root is three directory levels above the current working directory.
    base = os.path.join(os.getcwd(), *([os.pardir] * 3))
    # Build the parameters for `dataset`, then train conv_net with them.
    conv_net_params_inst = conv_net_params(dataset=dataset)
    train(base, conv_net_params_inst)

