#!/usr/bin/env python3
"""
Created on 14:24, Jul. 21st, 2023

@author: Anonymous
"""
import os, sys, time
import copy as cp
import numpy as np
import tensorflow as tf
from sklearn import datasets
# local dep
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.pardir)
import utils; import utils.model
import utils.data.eeg; import utils.data.meg
from models.cnn_ensemble import cnn_ensemble as cnn_ensemble_model

# Public API of this module.
__all__ = ["train",]

# Module-level state: `params` & `paths` are set by `init`, `model` & `optimizer` by `train`.
params = None
paths = None
model = None
optimizer = None

"""
init funcs
"""
# def init func
def init(base_, params_):
    """
    Set up the module-level state used by `train`.

    Args:
        base_: str - The base path of current project.
        params_: DotDict - The parameters of current training process. A deep copy is stored,
            so later mutation of `params_` by the caller does not leak into training.

    Returns:
        None
    """
    global params, paths
    # Keep a private copy of the parameters, then derive the project paths from that copy.
    params = cp.deepcopy(params_)
    paths = utils.Paths(base=base_, params=params)
    # Configure TensorFlow first, then the training process itself.
    for init_step in (_init_model, _init_train):
        init_step()

# def _init_model func
def _init_model():
    """
    Configure TensorFlow global settings before any model is built.

    Sets the Keras default float type from `params._precision`, and makes `tf.function`s run
    eagerly unless `params.train.use_graph_mode` is enabled. Random seeds are intentionally NOT
    set here; that should be done before `model` is initialized.
    """
    # Default dtype of every Keras layer/variable created afterwards.
    tf.keras.backend.set_floatx(params._precision)
    # Graph mode keeps `tf.function`s compiled; otherwise they are executed eagerly.
    eager_mode = not params.train.use_graph_mode
    tf.config.run_functions_eagerly(eager_mode)

# def _init_train func
def _init_train():
    """
    Initialize the training process.
    """
    pass

"""
data funcs
"""
# def load_data func
def load_data(load_params):
    """
    Load data from specified dataset.

    Dispatches to the module-level loader named `_load_data_<params.train.dataset>`.

    Args:
        load_params: DotDict - The load parameters of specified dataset.

    Returns:
        dataset_train_: tf.data.Dataset - The input train dataset.
        dataset_validation_: tf.data.Dataset - The input validation dataset.
        dataset_test_: tf.data.Dataset - The input test dataset.
        Loaders may instead return empty tuples to signal that the experiment should be skipped.

    Raises:
        ValueError: If no loader exists for `params.train.dataset`.
        Any exception raised by the loader itself is propagated unchanged.
    """
    global params
    # Only the loader lookup is guarded: previously, every error raised *inside* the loader
    # (missing files, failed checks, ...) was masked as an "Unknown dataset type" error.
    loader_name = "_".join(["_load_data", params.train.dataset])
    try:
        func = getattr(sys.modules[__name__], loader_name)
    except AttributeError as e:
        raise ValueError("ERROR: Unknown dataset type {} in train.cnn_ensemble.".format(params.train.dataset)) from e
    dataset_train_, dataset_validation_, dataset_test_ = func(load_params)
    # Return the final `dataset_train_` & `dataset_validation_` & `dataset_test_`.
    return dataset_train_, dataset_validation_, dataset_test_

# def _load_data_eeg_anonymous func
def _load_data_eeg_anonymous(load_params):
    """
    Load eeg data from specified subject run, and construct train/validation/test datasets.

    Args:
        load_params: DotDict - The load parameters of specified experiment. This function reads
            `trainset` & `testset` (lists of dataset names, e.g. "task-image-audio-pre-audio"),
            `type`, `resample_rate`, `permutation`, and (optionally) `path_run`.

    Returns:
        dataset_train_: tf.data.Dataset - The shuffled & batched train dataset of (X, y) pairs.
        dataset_validation_: tf.data.Dataset - The batched validation dataset (a random half of the test split).
        dataset_test_: tf.data.Dataset - The batched test dataset (the remaining half of the test split).
        When the experiment cannot be run (no train/test data items, or not enough classes),
        `(), (), ()` is returned instead, so that callers can always unpack three values.
    """
    global params, paths
    # Initialize path_run.
    path_run = os.path.join(paths.base, "data", "eeg.anonymous", "020", "20230405")\
        if not hasattr(load_params, "path_run") else load_params.path_run
    # Load data from specified subject run. The names are sorted because the iteration order of a
    # `set` of strings varies across interpreter runs (hash randomization); a fixed order keeps the
    # concatenation below, and hence the seeded validation split, reproducible.
    # (`data_items` also avoids shadowing the module-level `sklearn.datasets` import.)
    data_items = utils.DotDict()
    dataset_names = sorted(set(load_params.trainset) | set(load_params.testset))
    for dataset_name_i in dataset_names:
        # e.g. "task-image-audio-pre-audio" -> session type "task-image-audio-pre" & data type "audio",
        # then "task-image-audio-pre" -> loader `load_run_task` & session "image-audio-pre".
        session_type_i, data_type_i = dataset_name_i.rsplit("-", 1)
        loader_type_i, _, session_i = session_type_i.partition("-")
        func_i = getattr(utils.data.eeg.anonymous, "_".join(["load_run", loader_type_i]))
        X_i, y_i = func_i(path_run, session_type=session_i)
        # Check whether current dataset has data items.
        if X_i[data_type_i] is None:
            msg = "WARNING: Skip dataset {} with no data item.".format(dataset_name_i)
            print(msg); paths.run.logger.summaries.info(msg); continue
        X_i = X_i[data_type_i].astype(np.float32); y_i = y_i[data_type_i].astype(np.int64)
        # Truncate `X_i` to let them have the same length.
        # TODO: Here, we only keep the [0.0~0.8]s-part of [audio,image] that after onset. And we should
        # note that the [0.0~0.8]s-part of image is the whole onset time of image, the [0.0~0.8]s-part
        # of audio is the sum of the whole onset time of audio and the following 0.3s padding.
        # X_i - (n_samples, seq_len, n_channels)
        if load_params.type in ("default", "lvbj"):
            X_i = X_i[:,20:100,:]
        else:
            raise ValueError("ERROR: Unknown type {} of dataset".format(load_params.type))
        # Do cross-trial normalization (one mean & std over all trials, time points and channels).
        X_i = (X_i - np.mean(X_i)) / np.std(X_i)
        # Set the corresponding item of dataset.
        data_items[dataset_name_i] = utils.DotDict({"X":X_i,"y":y_i,})
    # Initialize trainset & testset.
    X_train = []; y_train = []; X_test = []; y_test = []
    for dataset_name_i, dataset_i in data_items.items():
        in_train_i = dataset_name_i in load_params.trainset
        in_test_i = dataset_name_i in load_params.testset
        if in_train_i and in_test_i:
            # Same dataset on both sides: the first `train_ratio` part trains, the rest tests.
            n_train_i = int(params.train.train_ratio * dataset_i.X.shape[0])
            X_train.append(dataset_i.X[:n_train_i,:,:]); y_train.append(dataset_i.y[:n_train_i])
            X_test.append(dataset_i.X[n_train_i:,:,:]); y_test.append(dataset_i.y[n_train_i:])
        elif in_train_i:
            X_train.append(dataset_i.X); y_train.append(dataset_i.y)
        elif in_test_i:
            X_test.append(dataset_i.X); y_test.append(dataset_i.y)
        else:
            raise ValueError("ERROR: Unknown dataset name {}.".format(dataset_name_i))
    # Check whether trainset & testset both have data items (three values, matching `load_data`).
    if len(X_train) == 0 or len(X_test) == 0: return (), (), ()
    # X - (n_samples, seq_len, n_channels); y - (n_samples,)
    X_train = np.concatenate(X_train, axis=0); y_train = np.concatenate(y_train, axis=0)
    X_test = np.concatenate(X_test, axis=0); y_test = np.concatenate(y_test, axis=0)
    # Make sure there is no overlap between X_train & X_test: look for test trials whose
    # channel-0 value matches a train trial at every one of several random time points.
    samples_same = None; n_samples = 10; assert X_train.shape[1] == X_test.shape[1]
    for _ in range(n_samples):
        sample_idx = np.random.randint(X_train.shape[1])
        sample_same_i = np.intersect1d(X_train[:,sample_idx,0], X_test[:,sample_idx,0], return_indices=True)[-1].tolist()
        samples_same = set(sample_same_i) if samples_same is None else set(sample_same_i) & samples_same
    assert len(samples_same) == 0
    # Check whether labels are enough: both sets must contain exactly the same `n_labels` classes,
    # otherwise a test label would have no one-hot index. An explicit check (not `assert`) so that
    # it is not stripped under `python -O`.
    n_labels = 15; labels = sorted(set(y_train))
    if not (len(labels) == n_labels and set(y_test) == set(labels)):
        msg = (
            "WARNING: Skip experiment (train:{};test:{}) due to that the classes of test cases are not enough."
        ).format(set(load_params.trainset), set(load_params.testset))
        print(msg); paths.run.logger.summaries.info(msg); return (), (), ()
    # y - (n_samples, n_labels); `labels` is sorted and contains every label, so `searchsorted`
    # yields each label's class index.
    labels = np.array(labels, dtype=np.int64)
    y_train = np.eye(len(labels))[np.searchsorted(labels, y_train)]
    y_test = np.eye(len(labels))[np.searchsorted(labels, y_test)]
    # Downsample trainset & testset to test the time redundancy.
    if load_params.type in ("default", "lvbj"):
        sample_rate = 100
    else:
        raise ValueError("ERROR: Unknown type {} of dataset".format(load_params.type))
    if sample_rate % load_params.resample_rate != 0:
        raise ValueError("ERROR: Resample rate {} does not divide sample rate {}.".format(
            load_params.resample_rate, sample_rate))
    downsample_rate = sample_rate // load_params.resample_rate
    if downsample_rate != 1:
        # Keep every `downsample_rate`-th time point, once per phase offset, stacking the phase
        # copies along the sample axis; labels are repeated in the same order.
        assert X_train.shape[1] % downsample_rate == 0 and X_test.shape[1] % downsample_rate == 0
        X_train = np.concatenate([X_train[:,start_i::downsample_rate,:]
            for start_i in range(downsample_rate)], axis=0)
        y_train = np.concatenate([y_train for _ in range(downsample_rate)], axis=0)
        X_test = np.concatenate([X_test[:,start_i::downsample_rate,:]
            for start_i in range(downsample_rate)], axis=0)
        y_test = np.concatenate([y_test for _ in range(downsample_rate)], axis=0)
        # NOTE(review): the validation/test split below happens after this replication, so phase
        # copies of one test trial can land in both validation-set & test-set — confirm acceptable.
    # Execute sample permutation: shuffle the train labels only (along the first axis), which
    # breaks the X-y pairing (presumably as a chance-level control).
    if load_params.permutation: np.random.shuffle(y_train)
    # Further split test-set into validation-set & test-set (random halves).
    validation_idxs = np.random.choice(np.arange(X_test.shape[0]), size=int(X_test.shape[0]/2), replace=False)
    validation_mask = np.zeros((X_test.shape[0],), dtype=np.bool_); validation_mask[validation_idxs] = True
    X_validation = X_test[validation_mask,:,:]; y_validation = y_test[validation_mask,:]
    X_test = X_test[~validation_mask,:,:]; y_test = y_test[~validation_mask,:]
    # Log information of data loading.
    msg = (
        "INFO: Data preparation complete, with train-set ({}) & validation-set ({}) & test-set ({})."
    ).format(X_train.shape, X_validation.shape, X_test.shape)
    print(msg); paths.run.logger.summaries.info(msg)
    # Construct dataset from data items.
    dataset_train_ = tf.data.Dataset.from_tensor_slices((X_train, y_train))
    dataset_validation_ = tf.data.Dataset.from_tensor_slices((X_validation, y_validation))
    dataset_test_ = tf.data.Dataset.from_tensor_slices((X_test, y_test))
    # Shuffle and then batch the dataset.
    dataset_train_ = dataset_train_.shuffle(params.train.buffer_size).batch(params.train.batch_size, drop_remainder=False)
    dataset_validation_ = dataset_validation_.shuffle(
        params.train.buffer_size).batch(params.train.batch_size, drop_remainder=False)
    dataset_test_ = dataset_test_.shuffle(params.train.buffer_size).batch(params.train.batch_size, drop_remainder=False)
    # Return the final `dataset_train_` & `dataset_validation_` & `dataset_test_`.
    return dataset_train_, dataset_validation_, dataset_test_

"""
train funcs
"""
# def train func
def train(base_, params_):
    """
    Train & evaluate a fresh `cnn_ensemble` model for every (subject run, experiment) pair.

    For each run in `path_runs` and each experiment in `load_params`, the data is loaded with
    `load_data`, a new model & ADAM optimizer are built, and the model is trained for
    `params.train.n_epochs` epochs. Train/validation/test accuracy & loss are logged after every
    epoch; at the end, the test-accuracy at the epoch of max validation-accuracy is reported.

    Args:
        base_: str - The base path of current project.
        params_: DotDict - The parameters of current training process.

    Returns:
        None
    """
    global _forward, _train
    global params, paths, model, optimizer
    # Initialize parameters & variables of current training process.
    init(base_, params_)
    # Log the start of current training process.
    paths.run.logger.summaries.info("Training started with dataset {}.".format(params.train.dataset))
    # Initialize load_params. Each load_params_i corresponds to a sub-dataset.
    if params.train.dataset == "eeg_anonymous":
        # The (subject, date) pairs of runs that we want to execute experiments on.
        # Uncomment a pair to include that run.
        run_ids = [
            #("005", "20221223"),
            #("006", "20230103"),
            #("007", "20230106"),
            #("011", "20230214"),
            #("013", "20230308"),
            #("018", "20230331"),
            #("019", "20230403"),
            #("020", "20230405"),
            #("021", "20230407"),
            #("023", "20230412"),
            #("024", "20230414"),
            #("025", "20230417"),
            #("026", "20230419"),
            #("027", "20230421"),
            #("028", "20230424"),
            #("029", "20230428"),
            #("030", "20230504"),
            #("031", "20230510"),
            #("sz-002", "20230509"),
            #("033", "20230517"),
            #("034", "20230519"),
            #("sz-003", "20230524"),
            #("sz-004", "20230528"),
            #("sz-005", "20230601"),
            #("036", "20230526"),
            #("037", "20230529"),
            #("038", "20230531"),
            #("039", "20230605"),
            #("040", "20230607"),
            #("sz-006", "20230603"),
            #("sz-007", "20230608"),
            #("sz-008", "20230610"),
            #("042", "20230614"),
            #("043", "20230616"),
            #("044", "20230619"),
            #("045", "20230626"),
            #("046", "20230628"),
            #("sz-009", "20230613"),
            #("sz-010", "20230615"),
            #("sz-012", "20230623"),
            #("sz-013", "20230627"),
            #("sz-014", "20230629"),
            #("sz-015", "20230701"),
            #("047", "20230703"),
            #("048", "20230705"),
            #("sz-016", "20230703"),
            #("sz-017", "20230706"),
            #("049", "20230710"),
            #("050", "20230712"),
            #("051", "20230717"),
            #("052", "20230719"),
            #("sz-018", "20230710"),
            #("sz-019", "20230712"),
            #("sz-020", "20230714"),
            #("sz-021", "20230718"),
            ("053", "20230724"),
            ("054", "20230726"),
            ("sz-022", "20230724"),
            ("sz-023", "20230726"),
            ("sz-024", "20230728"),
            ("sz-025", "20230803"),
            ("sz-026", "20230805"),
            ("sz-027", "20230807"),
        ]
        path_runs = [os.path.join(paths.base, "data", "eeg.anonymous", subject_i, date_i)
            for subject_i, date_i in run_ids]
        load_type = "default"
        # Dataset names of the four task sessions, per modality.
        task_image = [
            "task-image-audio-pre-image", "task-audio-image-pre-image",
            "task-image-audio-post-image", "task-audio-image-post-image",
        ]
        task_audio = [
            "task-image-audio-pre-audio", "task-audio-image-pre-audio",
            "task-image-audio-post-audio", "task-audio-image-post-audio",
        ]
        task_all = task_audio + task_image
        tmr_n23 = ["tmr-N2/3-audio",]; tmr_rem = ["tmr-REM-audio",]
        # All the experiments that we want to execute for every run, as (name, trainset, testset).
        # The loader only uses set-membership of `trainset` & `testset`, so list order is irrelevant.
        experiments = [
            ("train-task-all-all-test-task-all-all", task_all, task_all),
            ("train-task-all-image-test-task-all-image", task_image, task_image),
            ("train-task-all-audio-test-task-all-audio", task_audio, task_audio),
            ("train-tmr-n23-audio-test-tmr-n23-audio", tmr_n23, tmr_n23),
            ("train-tmr-n23-audio-test-tmr-rem-audio", tmr_n23, tmr_rem),
            ("train-tmr-n23-audio-test-task-all-image", tmr_n23, task_image),
            ("train-tmr-n23-audio-test-task-all-audio", tmr_n23, task_audio),
            ("train-task-all-image-test-tmr-n23-audio", task_image, tmr_n23),
            ("train-task-all-audio-test-tmr-n23-audio", task_audio, tmr_n23),
            ("train-task-all-image-test-tmr-rem-audio", task_image, tmr_rem),
            ("train-task-all-audio-test-tmr-rem-audio", task_audio, tmr_rem),
            ("train-task-all-image-tmr-n23-audio-test-tmr-n23-audio", task_image + tmr_n23, tmr_n23),
            ("train-task-all-audio-tmr-n23-audio-test-tmr-n23-audio", task_audio + tmr_n23, tmr_n23),
        ]
        load_params = [utils.DotDict({
            "name": name_i, "trainset": list(trainset_i), "testset": list(testset_i),
            "type": load_type, "permutation": False, "resample_rate": 100,
        }) for name_i, trainset_i, testset_i in experiments]
    else:
        raise ValueError("ERROR: Unknown dataset {} in train.cnn_ensemble.".format(params.train.dataset))
    # Execute experiments for each dataset run.
    for path_run_i in path_runs:
        # Log the start of current training iteration.
        msg = "Training started with sub-dataset {}.".format(path_run_i)
        print(msg); paths.run.logger.summaries.info(msg)
        for load_params_template_i in load_params:
            # Add `path_run` to a private copy of the experiment parameters.
            load_params_i = cp.deepcopy(load_params_template_i); load_params_i.path_run = path_run_i
            # Load data from specified experiment.
            dataset_train, dataset_validation, dataset_test = load_data(load_params_i)
            # Check whether all three sets have data. A skipped experiment comes back as empty tuples,
            # and a tiny test split can also produce an empty (zero-batch) dataset.
            if len(dataset_train) == 0 or len(dataset_validation) == 0 or len(dataset_test) == 0:
                msg = (
                    "INFO: Skip experiment {} with train-set ({:d} items)" +\
                    " & validation-set ({:d} items) & test-set ({:d} items)."
                ).format(load_params_i.name, _count_items(dataset_train),
                    _count_items(dataset_validation), _count_items(dataset_test))
                print(msg); paths.run.logger.summaries.info(msg); continue
            # Start training process of current specified experiment.
            msg = "Start the training process of experiment {}.".format(load_params_i.name)
            print(msg); paths.run.logger.summaries.info(msg)
            # Validation & test accuracies of every epoch.
            accuracies_validation = []; accuracies_test = []
            # Reset the iteration information of params.
            params.iteration(iteration=0)
            # Initialize a fresh model & ADAM optimizer for current experiment.
            model = cnn_ensemble_model(params.model)
            optimizer = tf.keras.optimizers.Adam(learning_rate=params.train.lr_i)
            # Re-wrap the step functions in fresh `tf.function`s, presumably so that no concrete
            # function traced against the previous model & optimizer is reused.
            _forward = tf.function(_forward.python_function); _train = tf.function(_train.python_function)
            for epoch_idx in range(params.train.n_epochs):
                # Record the start time of current epoch.
                time_start = time.time()
                paths.run.logger.summaries.info("Start training epoch {:d}.".format(epoch_idx))
                # Train on trainset, then evaluate on validation-set & test-set.
                accuracy_train, loss_train = _run_epoch(dataset_train, _train)
                accuracy_validation, loss_validation = _run_epoch(dataset_validation, _forward)
                accuracy_test, loss_test = _run_epoch(dataset_test, _forward)
                # Log information related to current training epoch.
                time_stop = time.time()
                accuracies_validation.append(accuracy_validation); accuracies_test.append(accuracy_test)
                msg = (
                    "Finish train epoch {:d} in {:.2f} seconds, generating {:d} concrete functions."
                ).format(epoch_idx, time_stop-time_start, len(_train.pretty_printed_concrete_signatures().split("\n\n")))
                print(msg); paths.run.logger.summaries.info(msg)
                msg = (
                    "Accuracy(train): {:.2f}%. Loss(train): {:.5f}."
                ).format(accuracy_train * 100., loss_train)
                print(msg); paths.run.logger.summaries.info(msg)
                msg = (
                    "Accuracy(validation): {:.2f}%. Loss(validation): {:.5f}."
                ).format(accuracy_validation * 100., loss_validation)
                print(msg); paths.run.logger.summaries.info(msg)
                msg = (
                    "Accuracy(test): {:.2f}%. Loss(test): {:.5f}."
                ).format(accuracy_test * 100., loss_test)
                print(msg); paths.run.logger.summaries.info(msg)
                # Summarize model information (once, after the model has been built by the first epoch).
                if epoch_idx == 0:
                    model.summary(print_fn=print); model.summary(print_fn=paths.run.logger.summaries.info)
            # Pick the epoch with max validation-accuracy (ties broken by higher test-accuracy).
            accuracies_validation = np.round(np.array(accuracies_validation, dtype=np.float32), decimals=4)
            accuracies_test = np.round(np.array(accuracies_test, dtype=np.float32), decimals=4)
            epoch_maxacc_idxs = np.where(accuracies_validation == np.max(accuracies_validation))[0]
            epoch_maxacc_idx = epoch_maxacc_idxs[np.argmax(accuracies_test[epoch_maxacc_idxs])]
            # Finish training process of current specified experiment.
            msg = (
                "Finish the training process of experiment {}, with test-accuracy ({:.2f}%)" +\
                " according to max validation-accuracy ({:.2f}%) at epoch {:d}, generating {:d} concrete functions."
            ).format(load_params_i.name, accuracies_test[epoch_maxacc_idx]*100.,
                accuracies_validation[epoch_maxacc_idx]*100., epoch_maxacc_idx,
                len(_train.pretty_printed_concrete_signatures().split("\n\n")))
            print(msg); paths.run.logger.summaries.info(msg)
        # Log the end of current training iteration.
        msg = "Training finished with sub-dataset {}.".format(path_run_i)
        print(msg); paths.run.logger.summaries.info(msg)
    # Log the end of current training process.
    msg = "Training finished with dataset {}.".format(params.train.dataset)
    print(msg); paths.run.logger.summaries.info(msg)

# def _count_items func
def _count_items(dataset_):
    """
    Count the samples held by `dataset_`, for logging.

    Args:
        dataset_: tuple or tf.data.Dataset - Either a tuple whose first element holds the samples
            (possibly the empty tuple returned for skipped experiments), or a batched dataset of
            `(X, y)` pairs.

    Returns:
        n_items: int - The number of samples.
    """
    if isinstance(dataset_, tf.data.Dataset):
        return int(sum(len(batch_i[0]) for batch_i in dataset_))
    return len(dataset_[0]) if len(dataset_) > 0 else 0

# def _run_epoch func
def _run_epoch(dataset_, step_fn):
    """
    Apply `step_fn` to every batch of `dataset_`, and aggregate accuracy & loss over all samples.

    Args:
        dataset_: tf.data.Dataset - Batched dataset of `(X, y)` pairs, with one-hot `y`.
        step_fn: callable - `_train` or `_forward`; maps a batch to `(outputs, loss)` tensors.

    Returns:
        accuracy: float - The batch-size-weighted mean of the per-batch accuracies.
        loss: float - The batch-size-weighted mean of the per-batch losses.
    """
    accuracies = []; losses = []; batch_sizes = []
    for batch_i in dataset_:
        batch_sizes.append(len(batch_i[0]))
        outputs_i, loss_i = step_fn(batch_i); outputs_i, loss_i = outputs_i.numpy(), loss_i.numpy()
        # Compare predicted & true class indices (`y` is one-hot).
        correct_i = np.argmax(outputs_i, axis=-1) == np.argmax(batch_i[1], axis=-1)
        accuracies.append(np.sum(correct_i) / correct_i.size); losses.append(loss_i)
    batch_sizes = np.array(batch_sizes)
    accuracy = np.sum(np.array(accuracies) * batch_sizes) / np.sum(batch_sizes)
    loss = np.sum(np.array(losses) * batch_sizes) / np.sum(batch_sizes)
    return accuracy, loss

# def _forward func
@tf.function
def _forward(inputs, training=False):
    """
    Run the module-level `model` on one batch of tensors.

    Args:
        inputs: The batch handed straight to `model`. The datasets built in this file yield
            `(X, y)` pairs, and `train` passes such batches here.
        training: bool - Whether to call the model in training mode.

    Returns:
        Whatever `model` returns; callers in this file unpack it as `(outputs, loss)`.
    """
    outputs = model(inputs, training=training)
    return outputs

# def _train func
@tf.function
def _train(inputs):
    """
    Perform one optimization step of the module-level `model` with the module-level `optimizer`.

    Args:
        inputs: One batch, passed unchanged to `_forward` (the datasets in this file yield `(X, y)` pairs).

    Returns:
        predictions: The model outputs for the batch.
        loss: The loss returned by the model, which is the quantity being minimized.
    """
    # Record the forward pass so the loss can be differentiated w.r.t. the weights.
    with tf.GradientTape() as tape:
        predictions, loss = _forward(inputs, training=True)
    weights = model.trainable_variables
    grads = tape.gradient(loss, weights)
    optimizer.apply_gradients(zip(grads, weights))
    return predictions, loss

if __name__ == "__main__":
    # local dep
    from params.cnn_ensemble_params import cnn_ensemble_params

    # Seed the random number generators through the project helper before anything else runs.
    utils.model.set_seeds(42)
    # The project root is the parent of the current working directory.
    base = os.path.join(os.getcwd(), os.pardir)
    # Build the parameters for the selected dataset, then train cnn_ensemble.
    dataset = "eeg_anonymous"
    cnn_ensemble_params_inst = cnn_ensemble_params(dataset=dataset)
    train(base, cnn_ensemble_params_inst)

