#!/usr/bin/env python3
"""
Created on 15:47, Jul. 28th, 2023

@author: Anonymous
"""
import os, sys
import time, json
import copy as cp
import numpy as np
import tensorflow as tf
from functools import partial
from collections import Counter
# local dep
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.path.join(os.pardir, os.pardir, os.pardir, os.pardir))
import utils; import utils.model
import utils.data.eeg
from models.domain_adaptation import DomainAdversarialTransformer as datn_model

# Public API of this module: only `train` is meant to be called from outside.
__all__ = ["train"]

# Module-level state shared by the functions below: `init` assigns `params` & `paths`,
# `_init_model` assigns `model`.
params = paths = model = None

"""
init funcs
"""
# def init func
def init(base_, params_):
    """
    Set up the module-level state used by `DomainAdversarialTransformer` training.

    Args:
        base_: str - The base path of current project.
        params_: DotDict - The parameters of current training process. A deep copy is stored,
            so the overrides applied later by `_init_model` never modify the caller's object.

    Returns:
        None
    """
    global params, paths
    # Keep a private copy of the parameters, then derive the run paths from it.
    params = cp.deepcopy(params_)
    paths = utils.Paths(base=base_, params=params)
    # Build the model first, then prepare the training process.
    for init_step in (_init_model, _init_train):
        init_step()

# def _init_model func
def _init_model():
    """
    Configure TensorFlow, override the model & training hyper-parameters, then build `model`.

    Reads and mutates the module-level `params`, and assigns the module-level `model`.

    Args:
        None

    Returns:
        None
    """
    global params, model
    # TF configuration. No random seed is set here; seeding has to happen before this function
    # instantiates `model` (the `__main__` entry point does it before calling `train`).
    tf.keras.backend.set_floatx(params._precision)
    # `use_graph_mode == False` makes every `tf.function` execute eagerly.
    tf.config.run_functions_eagerly(not params.train.use_graph_mode)
    # Hyper-parameter overrides. The number of domains and the domain classifier's output size
    # are set to the same value.
    params.model.n_domains = params.model.cls_domain.d_output = 4
    encoder = params.model.encoder
    encoder.d_model = 256                   # dimensions of the embedding
    encoder.max_len = 80                    # maximum length of the element sequence
    encoder.n_blocks = 8                    # depth of the encoder
    encoder.n_heads = 8                     # number of attention heads
    # NOTE(review): n_heads * d_head (1024) != d_model (256); presumably the attention layer
    # projects back to d_model -- confirm this is intended.
    encoder.d_head = 128                    # dimensions of each attention head
    encoder.mha_dropout_prob = 0.           # dropout probability of attention weights
    encoder.d_ff = 256                      # dimensions of the ffn hidden layer
    encoder.ff_dropout_prob = [0., 0.3]     # dropout probabilities used in the ffn
    # Weight factor scaling the domain loss; with 0 the domain loss presumably has no
    # effect on training -- confirm intended.
    params.model.w_loss_domain = .0
    params.train.batch_size = 256           # batch size of the training process
    params.train.n_epochs = 500             # number of training epochs
    # Build the model from the final parameters.
    model = datn_model(params.model)

# def _init_train func
def _init_train():
    """
    Initialize the training process.

    Currently a placeholder: no training-specific state needs to be prepared.

    Args:
        None

    Returns:
        None
    """
    return None

"""
data funcs
"""
# def load_data func
def load_data(load_params):
    """
    Load the source & target datasets of the experiment described by `load_params`.

    Dispatches to the module-level loader named `_load_data_<params.train.dataset>`.

    Args:
        load_params: DotDict - The load parameters of specified dataset.

    Returns:
        dataset_source: The source data returned by the dataset-specific loader
            (for "eeg_zhou2023cibr" a (train, validation, test) tuple of datasets).
        dataset_target: The target data returned by the dataset-specific loader
            (for "eeg_zhou2023cibr" a (train, validation, test) tuple of datasets).

    Raises:
        ValueError: If no loader exists for `params.train.dataset`.
        Any exception raised by the loader itself propagates unchanged.
    """
    global params
    # Only the loader lookup is guarded. Errors raised *while loading* (missing files, malformed
    # config, ...) must propagate instead of being misreported as an unknown dataset.
    try:
        func = getattr(sys.modules[__name__], "_".join(["_load_data", params.train.dataset]))
    except (AttributeError, TypeError) as err:
        raise ValueError("ERROR: Unknown dataset type {} in train.domain_adaptation.".format(
            params.train.dataset)) from err
    # Load data from specified dataset.
    dataset_source, dataset_target = func(load_params)
    # Return the final `dataset_*`.
    return dataset_source, dataset_target

# def _load_data func
def _load_data(path_tfrecords, parse_func, drop_remainder=True):
    """
    Build a shuffled, batched `tf.data.Dataset` from the specified tfrecord files.

    Args:
        path_tfrecords: list - The list of specified tfrecord files.
        parse_func: func - The function mapping one serialized record to one parsed element.
        drop_remainder: bool - Whether to drop the final incomplete batch. Defaults to True,
            which keeps every batch at exactly `params.train.batch_size` elements.

    Returns:
        dataset: tf.data.Dataset - The instantiated dataset.
    """
    # File order is shuffled; `cycle_length=1` reads the files one after another.
    tfrecord_files = tf.data.Dataset.list_files(path_tfrecords, shuffle=True)
    dataset = tfrecord_files.interleave(map_func=tf.data.TFRecordDataset, cycle_length=1)
    dataset = dataset.map(lambda x: parse_func(x))
    # Shuffle individual elements (the buffer size is an element count), then batch them.
    dataset = dataset.shuffle(buffer_size=params.train.buffer_size)
    dataset = dataset.batch(batch_size=params.train.batch_size, drop_remainder=drop_remainder)
    # `prefetch` counts *batches*. Reusing the element-level shuffle buffer size here would try to
    # hold up to `buffer_size` whole batches in memory, so let tf.data tune the depth instead.
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    # Return the final `dataset`.
    return dataset

# def _load_data_eeg_zhou2023cibr func
def _load_data_eeg_zhou2023cibr(load_params):
    """
    Load the source & target modalities of one eeg.zhou2023cibr cross-modality experiment.

    Args:
        load_params: DotDict - The load parameters; `sourceset` & `targetset` must each contain
            exactly one modality name.

    Returns:
        dataset_source: tuple - The (train, validation, test) datasets of the source modality.
        dataset_target: tuple - The (train, validation, test) datasets of the target modality.

    Raises:
        ValueError: If `sourceset` or `targetset` does not contain exactly one modality.
    """
    # Explicit check instead of `assert`, which is stripped when Python runs with `-O`.
    if not (len(load_params.sourceset) == len(load_params.targetset) == 1):
        raise ValueError((
            "ERROR: Expected exactly one source & one target modality, got sourceset {} & targetset {}."
        ).format(load_params.sourceset, load_params.targetset))
    # Get the corresponding source & target dataset.
    dataset_source = _load_data_eeg_zhou2023cibr_helper(load_params.sourceset[0])
    dataset_target = _load_data_eeg_zhou2023cibr_helper(load_params.targetset[0])
    # Return the final `dataset_*`.
    return dataset_source, dataset_target

# def _load_data_eeg_zhou2023cibr_helper func
def _load_data_eeg_zhou2023cibr_helper(modality):
    """
    Load the train/validation/test splits of one modality of the eeg.zhou2023cibr dataset.

    The modality directory is `<paths.base>/data/eeg.zhou2023cibr/dataset.crossmodality/<name>`,
    where `<name>` is `modality` with every "/" removed. It holds a `config.json` plus one
    sub-directory per split. Inside each split directory, the files whose names start with the
    split name are used as tfrecord files, in sorted order.

    Args:
        modality: str - The name of the modality (e.g. "image" or "N2/3").

    Returns:
        dataset_train: tf.data.Dataset - The batched train-set.
        dataset_validation: tf.data.Dataset - The batched validation-set.
        dataset_test: tf.data.Dataset - The batched test-set.
    """
    global params, paths
    # Initialize the path of the modality dataset.
    path_dataset = os.path.join(paths.base, "data", "eeg.zhou2023cibr",
        "dataset.crossmodality", modality.replace("/", ""))
    # The parse function needs the shapes stored in `config.json` (JSON is UTF-8 by spec).
    with open(os.path.join(path_dataset, "config.json"), "r", encoding="utf-8") as f:
        config = json.load(f)
    _parse_func = partial(_parse_func_eeg_zhou2023cibr, config=config)
    # Build one dataset per split from the tfrecord files whose names start with the split name.
    datasets = []
    for split in ("train", "validation", "test"):
        path_split = os.path.join(path_dataset, split)
        path_split_tfrecords = sorted(os.path.join(path_split, fname_i)
            for fname_i in os.listdir(path_split) if fname_i.startswith(split))
        datasets.append(_load_data(path_split_tfrecords, parse_func=_parse_func))
    dataset_train, dataset_validation, dataset_test = datasets
    # Return the final `dataset_*`.
    return dataset_train, dataset_validation, dataset_test

# def _parse_func_eeg_zhou2023cibr func
def _parse_func_eeg_zhou2023cibr(example_proto, config):
    """
    Decode one serialized record of the eeg.zhou2023cibr dataset.

    Each of the four features is stored as a single bytes value holding a serialized float32
    tensor (decoded with `tf.io.parse_tensor`), which is then reshaped using sizes from `config`.

    Args:
        example_proto: tf.Tensor - One serialized record read from a tfrecord file.
        config: dict - The dataset config, providing "n_labels", "n_subjects", "n_domains" &
            "data_shape".

    Returns:
        data: tf.Tensor - The data reshaped to `data_shape`, then transposed with perm [1, 0].
        label: tf.Tensor - Shape (n_labels,).
        subj_id: tf.Tensor - Shape (n_subjects,).
        domain_id: tf.Tensor - Shape (n_domains,).
    """
    # Sizes used to reshape the decoded tensors, taken from the dataset's `config.json`.
    n_labels, n_subjects, n_domains, data_shape = (
        config[key] for key in ("n_labels", "n_subjects", "n_domains", "data_shape"))
    feature_names = ("label", "data", "subj_id", "domain_id")
    feature_spec = {name: tf.io.FixedLenFeature((), tf.string) for name in feature_names}
    parsed = tf.io.parse_single_example(example_proto, feature_spec)
    decoded = {name: tf.io.parse_tensor(parsed[name], out_type=tf.float32) for name in feature_names}
    label = tf.reshape(decoded["label"], shape=(n_labels,))
    # perm=[1, 0] requires `data_shape` to be 2-D; the two axes are swapped after reshaping.
    data = tf.transpose(tf.reshape(decoded["data"], shape=data_shape), perm=[1,0])
    subj_id = tf.reshape(decoded["subj_id"], shape=(n_subjects,))
    domain_id = tf.reshape(decoded["domain_id"], shape=(n_domains,))
    return data, label, subj_id, domain_id

"""
train funcs
"""
# def _log func
def _log(msg):
    """
    Print a message and also record it in the summaries logger of the current run.

    Args:
        msg: str - The message to report.

    Returns:
        None
    """
    global paths
    print(msg); paths.run.logger.summaries.info(msg)

# def _count_batches func
def _count_batches(dataset):
    """
    Count the number of batches yielded by one full pass over `dataset`.

    Args:
        dataset: tf.data.Dataset - The (batched) dataset to count.

    Returns:
        n_batches: int - The number of elements produced by iterating `dataset` once.
    """
    # Stream through the dataset; `len(list(dataset))` would keep every batch in memory at once.
    return sum(1 for _ in dataset)

# def _batch_accuracy func
def _batch_accuracy(y_pred, y_true):
    """
    Compute the fraction of samples whose argmax over the last axis agrees between `y_pred` & `y_true`.

    Args:
        y_pred: np.array - The predicted class scores, classes along the last axis.
        y_true: np.array - The labels, classes along the last axis.

    Returns:
        accuracy: float - The fraction of matching argmax positions.
    """
    correct = np.argmax(y_pred, axis=-1) == np.argmax(y_true, axis=-1)
    return np.sum(correct) / correct.size

# def _weighted_mean func
def _weighted_mean(values, weights):
    """
    Average per-batch statistics, weighting each batch by its number of samples.

    Args:
        values: list - The per-batch statistics.
        weights: list - The per-batch sample counts.

    Returns:
        mean: float - The weighted mean of `values`.
    """
    return np.sum(np.array(values) * np.array(weights)) / np.sum(weights)

# def _train_epoch func
def _train_epoch(dataset_source, dataset_target, n_batches_source, n_batches_target):
    """
    Run one training epoch, pairing every source train batch with a target train batch.

    The target iterator is restarted each time `n_batches_target` batches have been consumed, so the
    (no larger) target train-set is cycled until all `n_batches_source` source batches are used.

    Args:
        dataset_source: tf.data.Dataset - The source train-set.
        dataset_target: tf.data.Dataset - The target train-set.
        n_batches_source: int - The number of batches in `dataset_source`.
        n_batches_target: int - The number of batches in `dataset_target`.

    Returns:
        accuracy_source: float - The sample-weighted accuracy on the source batches.
        loss_source: DotDict - The sample-weighted "class" & "domain" losses on the source batches.
        accuracy_target: float - The sample-weighted accuracy on the target batches.
        loss_target: DotDict - The sample-weighted "class" & "domain" losses on the target batches.

    Raises:
        ValueError: If `n_batches_target` is not positive or exceeds `n_batches_source`.
    """
    global params, model
    # Explicit check (not `assert`): an empty target set would otherwise divide by zero below.
    if n_batches_target <= 0 or n_batches_source < n_batches_target:
        raise ValueError((
            "ERROR: The source train-set ({:d} batches) must be at least as large as the"
            " non-empty target train-set ({:d} batches) in train.domain_adaptation."
        ).format(n_batches_source, n_batches_target))
    accuracy_source = []; batch_size_source = []; loss_source = utils.DotDict({"class": [], "domain": [],})
    accuracy_target = []; batch_size_target = []; loss_target = utils.DotDict({"class": [], "domain": [],})
    iter_source = iter(dataset_source); iter_target = iter(dataset_target)
    for batch_idx in range(n_batches_source):
        inputs_source_i = next(iter_source); inputs_target_i = next(iter_target)
        # Keep (data, label, domain_id); `subj_id` (index 2) is not fed to the model.
        inputs_source_i = [inputs_source_i[0], inputs_source_i[1], inputs_source_i[3]]
        inputs_target_i = [inputs_target_i[0], inputs_target_i[1], inputs_target_i[3]]
        inputs_i = utils.DotDict({"source": inputs_source_i, "target": inputs_target_i,})
        batch_size_source.append(len(inputs_source_i[0]))
        batch_size_target.append(len(inputs_target_i[0]))
        # If `iter_target` is used up, re-instantiate `iter_target`.
        if (batch_idx + 1) % n_batches_target == 0: iter_target = iter(dataset_target)
        # Train model for current batch.
        y_pred_i, loss_i = model.train(inputs_i, params=params.model)
        accuracy_source.append(_batch_accuracy(y_pred_i["source"]["class"].numpy(), inputs_source_i[1].numpy()))
        accuracy_target.append(_batch_accuracy(y_pred_i["target"]["class"].numpy(), inputs_target_i[1].numpy()))
        for key_i in ("class", "domain"):
            loss_source[key_i].append(loss_i["source"][key_i].numpy().mean())
            loss_target[key_i].append(loss_i["target"][key_i].numpy().mean())
    # Average the per-batch statistics, weighted by batch size.
    accuracy_source = _weighted_mean(accuracy_source, batch_size_source)
    loss_source["class"] = _weighted_mean(loss_source["class"], batch_size_source)
    loss_source["domain"] = _weighted_mean(loss_source["domain"], batch_size_source)
    accuracy_target = _weighted_mean(accuracy_target, batch_size_target)
    loss_target["class"] = _weighted_mean(loss_target["class"], batch_size_target)
    loss_target["domain"] = _weighted_mean(loss_target["domain"], batch_size_target)
    return accuracy_source, loss_source, accuracy_target, loss_target

# def _evaluate func
def _evaluate(dataset, n_batches):
    """
    Evaluate `model` on the first `n_batches` batches of a target dataset.

    Args:
        dataset: tf.data.Dataset - The target validation-set or test-set.
        n_batches: int - The number of batches to evaluate.

    Returns:
        accuracy: float - The sample-weighted accuracy.
        loss: DotDict - The sample-weighted "class" & "domain" losses.
    """
    global params, model
    accuracy = []; batch_size = []; loss = utils.DotDict({"class": [], "domain": [],})
    iter_data = iter(dataset)
    for _ in range(n_batches):
        inputs_i = next(iter_data)
        # Keep (data, label, domain_id); `subj_id` (index 2) is not fed to the model.
        inputs_i = [inputs_i[0], inputs_i[1], inputs_i[3]]
        batch_size.append(len(inputs_i[0]))
        # Test model for current batch.
        y_pred_i, loss_i = model(inputs_i, params=params.model)
        accuracy.append(_batch_accuracy(y_pred_i["class"].numpy(), inputs_i[1].numpy()))
        loss["class"].append(loss_i["class"].numpy().mean())
        loss["domain"].append(loss_i["domain"].numpy().mean())
    # Average the per-batch statistics, weighted by batch size.
    accuracy = _weighted_mean(accuracy, batch_size)
    loss["class"] = _weighted_mean(loss["class"], batch_size)
    loss["domain"] = _weighted_mean(loss["domain"], batch_size)
    return accuracy, loss

# def train func
def train(base_, params_):
    """
    Train `DomainAdversarialTransformer` and report the test accuracy at the best validation epoch.

    Args:
        base_: str - The base path of current project.
        params_: DotDict - The parameters of current training process.

    Returns:
        None

    Raises:
        ValueError: If `params_.train.dataset` is not supported, or the target train-set is empty
            or has more batches than the source train-set.
    """
    global params, paths, model
    # Initialize parameters & variables of current training process.
    init(base_, params_)
    # Log the start of current training process.
    paths.run.logger.summaries.info("Training started with dataset {}.".format(params.train.dataset))
    # Initialize load_params. Each load_params_i corresponds to a sub-dataset.
    if params.train.dataset == "eeg_zhou2023cibr":
        # `load_params` contains all the experiments that we want to execute for every run.
        load_params = utils.DotDict({
            "name": "source-task-all-image-target-tmr-n23-audio",
            "sourceset": ["image",],
            "targetset": ["N2/3",],
        })
    else:
        raise ValueError("ERROR: Unknown dataset {} in train.domain_adaptation.".format(params.train.dataset))
    # Load data from specified experiment.
    dataset_source, dataset_target = load_data(load_params)
    dataset_source_train, dataset_source_validation, dataset_source_test = dataset_source
    dataset_target_train, dataset_target_validation, dataset_target_test = dataset_target
    n_batches_source_train = _count_batches(dataset_source_train)
    n_batches_source_validation = _count_batches(dataset_source_validation)
    n_batches_source_test = _count_batches(dataset_source_test)
    n_batches_target_train = _count_batches(dataset_target_train)
    n_batches_target_validation = _count_batches(dataset_target_validation)
    n_batches_target_test = _count_batches(dataset_target_test)
    # Log information related to data loading.
    batch_size = params.train.batch_size
    msg = (
        "INFO: The source dataset includes train-set (~{:d}) & validation-set (~{:d}) & test-set (~{:d})," +\
        " the target dataset includes train-set (~{:d}) & validation-set (~{:d}) & test-set (~{:d})."
    ).format(n_batches_source_train * batch_size, n_batches_source_validation * batch_size, n_batches_source_test * batch_size,
        n_batches_target_train * batch_size, n_batches_target_validation * batch_size, n_batches_target_test * batch_size)
    _log(msg)
    # Execute experiments for each dataset run.
    _log("Start the training process of experiment {}.".format(load_params.name))
    accuracies_source = []; accuracies_target_train = []; accuracies_target_validation = []; accuracies_target_test = []
    for epoch_idx in range(params.train.n_epochs):
        # Record the start time of preparing data.
        time_start = time.time()
        # Reset the iteration information of params.
        params.iteration(iteration=epoch_idx)
        # Train, then evaluate on the target validation-set & test-set.
        accuracy_source, loss_source, accuracy_target_train, loss_target_train = _train_epoch(
            dataset_source_train, dataset_target_train, n_batches_source_train, n_batches_target_train)
        accuracy_target_validation, loss_target_validation = _evaluate(
            dataset_target_validation, n_batches_target_validation)
        accuracy_target_test, loss_target_test = _evaluate(dataset_target_test, n_batches_target_test)
        # Log information related to current training epoch.
        time_stop = time.time(); accuracies_source.append(accuracy_source)
        accuracies_target_train.append(accuracy_target_train)
        accuracies_target_validation.append(accuracy_target_validation)
        accuracies_target_test.append(accuracy_target_test)
        _log((
            "Finish train epoch {:d} in {:.2f} seconds, generating {:d} concrete functions."
        ).format(epoch_idx, time_stop-time_start, len(model.train.pretty_printed_concrete_signatures().split("\n\n"))))
        _log((
            "Accuracy(source): {:.2f}%. Loss(source): {:.5f} (class) {:.5f} (domain)."
        ).format(accuracy_source * 100., loss_source["class"], loss_source["domain"]))
        _log((
            "Accuracy(target-train): {:.2f}%. Loss(target): {:.5f} (class) {:.5f} (domain)."
        ).format(accuracy_target_train * 100., loss_target_train["class"], loss_target_train["domain"]))
        _log((
            "Accuracy(target-validation): {:.2f}%. Loss(target): {:.5f} (class) {:.5f} (domain)."
        ).format(accuracy_target_validation * 100., loss_target_validation["class"], loss_target_validation["domain"]))
        _log((
            "Accuracy(target-test): {:.2f}%. Loss(target): {:.5f} (class) {:.5f} (domain)."
        ).format(accuracy_target_test * 100., loss_target_test["class"], loss_target_test["domain"]))
        # Summarize model information.
        if epoch_idx == 0:
            model.summary(print_fn=print); model.summary(print_fn=paths.run.logger.summaries.info)
    # Select the epoch with max validation accuracy (ties broken by test accuracy).
    accuracies_target_validation = np.round(np.array(accuracies_target_validation, dtype=np.float32), decimals=4)
    accuracies_target_test = np.round(np.array(accuracies_target_test, dtype=np.float32), decimals=4)
    epoch_maxacc_idxs = np.where(accuracies_target_validation == np.max(accuracies_target_validation))[0]
    epoch_maxacc_idx = epoch_maxacc_idxs[np.argmax(accuracies_target_test[epoch_maxacc_idxs])]
    # Finish training process of current specified experiment.
    msg = (
        "Finish the training process of experiment {}, with test-accuracy ({:.2f}%)" +\
        " according to max validation-accuracy ({:.2f}%) at epoch {:d}, generating {:d} concrete functions."
    ).format(load_params.name, accuracies_target_test[epoch_maxacc_idx]*100.,
        accuracies_target_validation[epoch_maxacc_idx]*100., epoch_maxacc_idx,
        len(model.train.pretty_printed_concrete_signatures().split("\n\n")))
    _log(msg)

if __name__ == "__main__":
    # local dep
    from params.domain_adaptation_params import domain_adversarial_transformer_params

    # The dataset trained on when this file is run as a script.
    dataset = "eeg_zhou2023cibr"
    # Seed before `train` builds the model (`_init_model` does not seed).
    utils.model.set_seeds(42)
    # The project root is four directories above the working directory, matching the
    # `sys.path` entry inserted when the file runs as a script.
    base = os.path.join(os.getcwd(), *([os.pardir] * 4))
    # Build the `DomainAdversarialTransformer` parameters, then train it.
    datn_params_inst = domain_adversarial_transformer_params(dataset=dataset)
    train(base, datn_params_inst)

