#!/usr/bin/env python3
"""
Created on 16:32, Jul. 20th, 2023

@author: Anonymous
"""
import os, json
import copy as cp
import numpy as np
import tensorflow as tf
from sklearn import preprocessing
from collections import Counter
# local dep
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.path.join(os.pardir, os.pardir, os.pardir))
from utils import DotDict
from utils.data import load_pickle

# The export list is empty, so `from <module> import *` exposes no names from this module.
__all__ = [

]

"""
tool funcs
"""
# def _bytes_feature func
def _bytes_feature(value):
    """
    Wrap a single string / bytes value into a `tf.train.Feature` holding a bytes list.

    Args:
        value: The value to wrap; an eager tensor is converted through `.numpy()` first.

    Returns:
        feature: tf.train.Feature - The feature whose `bytes_list` contains `value` only.
    """
    # `tf.train.BytesList` cannot unpack the content of an EagerTensor, convert it beforehand.
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        value = value.numpy()
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)

# def _float_feature func
def _float_feature(value):
    """
    Wrap a single float / double value into a `tf.train.Feature` holding a float list.

    Args:
        value: The scalar to wrap.

    Returns:
        feature: tf.train.Feature - The feature whose `float_list` contains `value` only.
    """
    wrapped = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=wrapped)

# def _int64_feature func
def _int64_feature(value):
    """
    Wrap a single bool / enum / int / uint value into a `tf.train.Feature` holding an int64 list.

    Args:
        value: The scalar to wrap.

    Returns:
        feature: tf.train.Feature - The feature whose `int64_list` contains `value` only.
    """
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)

"""
preprocess funcs
"""
# def _robust_scale func
def _robust_scale(data, baseline_range=(0,20), crop_range=(20,100), axis=0):
    """
    Baseline-correct, crop, robust-scale and clamp the data.

    Args:
        data: (n_channels, seq_len) - The loaded data.
        baseline_range: (2[tuple],) - The [start,end) columns averaged to get the per-row baseline.
        crop_range: (2[tuple],) - The [start,end) columns kept for the following calculation.
        axis: int - The axis handed to `preprocessing.robust_scale`.

    Returns:
        data: (n_channels, cropped_len) - The preprocessed data, limited to [-20,20].
    """
    baseline_window = slice(baseline_range[0], baseline_range[1])
    crop_window = slice(crop_range[0], crop_range[1])
    # offset - (n_channels, 1), the mean of each row over the baseline window.
    offset = np.mean(data[:, baseline_window], axis=-1, keepdims=True)
    corrected = data[:, crop_window] - offset
    # Center on the median and scale by the inter-quartile range.
    scaled = preprocessing.robust_scale(
        corrected,
        axis=axis,
        with_centering=True,
        with_scaling=True,
        quantile_range=(25.0, 75.0),
        unit_variance=False,
    )
    # Limit the magnitude of normalized values to minimize the impact of large outlier samples.
    lower, upper = -20., 20.
    return np.clip(scaled, lower, upper)

# def _scale func
def _scale(data, crop_range=(20,100), axis=0):
    """
    Crop the data, then standardize it with `preprocessing.scale`.

    Args:
        data: (n_channels, seq_len) - The loaded data.
        crop_range: (2[tuple],) - The [start,end) columns kept for the following calculation.
        axis: int - The axis handed to `preprocessing.scale`.

    Returns:
        data: (n_channels, cropped_len) - The preprocessed data.
    """
    crop_window = slice(crop_range[0], crop_range[1])
    cropped = data[:, crop_window]
    return preprocessing.scale(cropped, axis=axis)

"""
cross-modality funcs
"""
# def prepare_crossmodality func
def prepare_crossmodality(base, batch_size):
    """
    Create the crossmodality tf-records of every available modality.

    Args:
        base: str - The path of project base.
        batch_size: int - The size of batch to split tfrecords.

    Returns:
        None
    """
    # The dataset lives under `<base>/data/eeg.anonymous`.
    dataset_root = os.path.join(base, "data", "eeg.anonymous")
    # Every entry is `<subject>/<run>`.
    subj_run_pool = [
        "005/20221223", "006/20230103", "007/20230106", "011/20230214", "013/20230308",
        "018/20230331", "019/20230403", "020/20230405", "021/20230407", "023/20230412",
        "024/20230414", "025/20230417", "026/20230419", "027/20230421", "028/20230424",
        "029/20230428", "030/20230504", "031/20230510", "033/20230517", "034/20230519",
        "036/20230526", "037/20230529", "038/20230531", "039/20230605", "040/20230607",
        "042/20230614", "043/20230616", "044/20230619", "045/20230626", "046/20230628",
        "047/20230703", "048/20230705", "049/20230710", "050/20230712", "051/20230717",
        "052/20230719", "053/20230724", "054/20230726",
    ]
    subj_run_pool.append("sz-003/20230524")
    # Each subject must contribute exactly one run.
    subject_ids = {entry.split("/")[0] for entry in subj_run_pool}
    assert len(subject_ids) == len(subj_run_pool)
    # Create one dataset per modality; the full modality list is forwarded to every call.
    modality_pool = ["image", "audio", "N2/3", "REM"]
    for current_modality in modality_pool:
        _prepare_crossmodality(
            dataset_root, subj_run_pool, modality_pool, batch_size, modality=current_modality
        )

# def _prepare_crossmodality func
def _prepare_crossmodality(path_dataset, subj_runs, modalities,
    batch_size, modality="N2/3", train_ratio=0.8, thres_dataset=100):
    """
    Prepare crossmodality data tf-records for dataset.

    Args:
        path_dataset: str - The path to dataset base.
        subj_runs: list - The list of available subj-runs.
        modalities: list - The list of available modalities.
        batch_size: int - The size of batch to split tfrecords.
        modality: str - The specified data modality, supporting [N2/3,REM].

    Returns:
        None
    """
    # Prepare crossmodality data tf-records according to `modality`.
    if modality in ["image", "audio"]:
        _prepare_crossmodality_task(path_dataset, subj_runs, modalities, batch_size,
            modality=modality, train_ratio=train_ratio, thres_dataset=thres_dataset)
    elif modality in ["N2/3", "REM"]:
        _prepare_crossmodality_tmr(path_dataset, subj_runs, modalities, batch_size,
            modality=modality, train_ratio=train_ratio, thres_dataset=thres_dataset)
    else:
        raise ValueError("ERROR: Get unknown modality ({}).".format(modality))

# def _prepare_crossmodality_task func
def _prepare_crossmodality_task(path_dataset, subj_runs, modalities,
    batch_size, modality="N2/3", train_ratio=0.8, thres_dataset=100):
    """
    Prepare crossmodality data tf-records for task dataset.

    Args:
        path_dataset: str - The path to dataset base.
        subj_runs: list - The list of available subj-runs.
        modalities: list - The list of available modalities.
        batch_size: int - The size of batch to split tfrecords.
        modality: str - The specified data modality, supporting [N2/3,REM].

    Returns:
        None
    """
    assert modality in ["image", "audio"]
    ## Prepare for dataset creation.
    # Initialize the path to save tf-records.
    path_dataset_tfrecords = os.path.join(path_dataset, "dataset.crossmodality")
    if not os.path.exists(path_dataset_tfrecords): os.makedirs(path_dataset_tfrecords)
    path_dataset_tfrecords = os.path.join(path_dataset_tfrecords, modality.replace("/", ""))
    if not os.path.exists(path_dataset_tfrecords): os.makedirs(path_dataset_tfrecords)
    # Initialize `labels`, which is the map between category and the corresponding index.
    labels = DotDict({"alarm": 0, "apple": 1, "ball": 2, "book": 3, "box": 4, "chair": 5, "kiwi": 6, "microphone": 7,
        "motorcycle": 8, "pepper": 9, "sheep": 10, "shoes": 11, "strawberry": 12, "tomato": 13, "watch": 14,})
    n_labels = len(labels)
    # Initialize the number of subjects.
    subj_runs = sorted(set(subj_runs)); n_subjects = len(subj_runs)
    # Initialize the number of domains.
    domains = sorted(set(modalities)); n_domains = len(domains); domain_idx = domains.index(modality)
    ## Execute dataset creation.
    # Initialize datasets.
    datasets = DotDict({})
    for subj_run_i in subj_runs:
        # Load the corresponding dataset.
        datasets_i = load_pickle(os.path.join(path_dataset, os.sep.join(subj_run_i.split("/")), "dataset.task"))
        dataset_i = [data_i for dataset_name_i in sorted(datasets_i) for data_i in datasets_i[dataset_name_i]]
        # If the size of dataset is too small, ignore it.
        if len(dataset_i) < thres_dataset: continue
        # If the number of each label in dataset is not enough, ignore it.
        label_counter_i = Counter([data_i[modality]["name"] for data_i in dataset_i])
        if min(label_counter_i.values()) < 3: continue
        # Only keep audio part of `dataset_i`.
        dataset_i = [data_i[modality] for data_i in dataset_i]
        # Update `datasets` with `dataset_i`.
        datasets[subj_run_i] = dataset_i
    print(("INFO: Get {:d} datasets (including {}).").format(len(datasets), sorted(datasets.keys())))
    # Prepare to construct train-set & validation-set & test-set.
    train_ratio = train_ratio; validation_ratio = test_ratio = (1. - train_ratio) / 2.
    # Construct `*sets` according to `datasets`.
    trainsets = DotDict({}); validationsets = DotDict({}); testsets = DotDict({})
    for subj_run_i in sorted(datasets.keys()):
        # Initialize `*sets` as empty list.
        trainsets[subj_run_i] = []; validationsets[subj_run_i] = []; testsets[subj_run_i] = []
        # Get the corresponding dataset.
        dataset_i = datasets[subj_run_i]
        # Count the number of samples corresponding to each label.
        label_counter_i = Counter([data_i["name"] for data_i in dataset_i])
        label_counter_validation_i = DotDict({label_name_i:max(1,
            int(label_counter_i[label_name_i] * validation_ratio)
        ) for label_name_i in sorted(label_counter_i.keys())})
        label_counter_test_i = DotDict({label_name_i:max(1,
            int(label_counter_i[label_name_i] * test_ratio)
        ) for label_name_i in sorted(label_counter_i.keys())})
        label_counter_train_i = DotDict({label_name_i:(label_counter_i[label_name_i] -\
            (label_counter_validation_i[label_name_i] + label_counter_test_i[label_name_i])
        ) for label_name_i in sorted(label_counter_i.keys())})
        # Loop over all labels to get the corresponding indices.
        for label_name_i in sorted(label_counter_i.keys()):
            # Get the indices of label, then split it.
            label_idxs = [data_idx for data_idx in range(len(dataset_i)) if dataset_i[data_idx]["name"] == label_name_i]
            validation_idxs = np.random.choice(label_idxs,
                size=label_counter_validation_i[label_name_i], replace=False).tolist()
            test_idxs = np.random.choice(sorted(set(label_idxs) - set(validation_idxs)),
                size=label_counter_test_i[label_name_i], replace=False).tolist()
            train_idxs = sorted(set(label_idxs) - set(validation_idxs) - set(test_idxs))
            assert len(set(train_idxs) & set(validation_idxs)) == 0
            assert len(set(train_idxs) & set(test_idxs)) == 0
            assert len(set(validation_idxs) & set(test_idxs)) == 0
            # Use these indices to fill up `*sets`.
            trainsets[subj_run_i].extend([dataset_i[data_idx] for data_idx in train_idxs])
            validationsets[subj_run_i].extend([dataset_i[data_idx] for data_idx in validation_idxs])
            testsets[subj_run_i].extend([dataset_i[data_idx] for data_idx in test_idxs])
    # Log information related to train-set & validation-set & test-set of dataset.
    n_samples = DotDict({
        "train-set": np.sum([len(trainset_i) for trainset_i in trainsets.values()]),
        "validation-set": np.sum([len(validationset_i) for validationset_i in validationsets.values()]),
        "test-set": np.sum([len(testset_i) for testset_i in testsets.values()]),
    })
    print((
        "INFO: Complete the segmentation of dataset, with train-set ({:d})" +\
        " & validation-set ({:d}) & test-set ({:d}), including {}."
    ).format(n_samples["train-set"], n_samples["validation-set"], n_samples["test-set"], sorted(datasets.keys())))
    ### Write created dataset to tf-records.
    ## Initialize the path of dataset.
    path_dataset_tfrecords_train = os.path.join(path_dataset_tfrecords, "train")
    if not os.path.exists(path_dataset_tfrecords_train): os.makedirs(path_dataset_tfrecords_train)
    path_dataset_tfrecords_validation = os.path.join(path_dataset_tfrecords, "validation")
    if not os.path.exists(path_dataset_tfrecords_validation): os.makedirs(path_dataset_tfrecords_validation)
    path_dataset_tfrecords_test = os.path.join(path_dataset_tfrecords, "test")
    if not os.path.exists(path_dataset_tfrecords_test): os.makedirs(path_dataset_tfrecords_test)
    ## Save constructed dataset.
    # Construct `dataset_*` according to `*sets`.
    dataset_train = [{
        "label": tf.cast(np.eye(n_labels)[labels[data_i["name"]]], dtype=tf.float32),
        "data": tf.convert_to_tensor(_robust_scale(data_i["data"]), dtype=tf.float32),
        "subj_id": tf.cast(np.eye(n_subjects)[subj_runs.index(subj_run_i)], dtype=tf.float32),
        "domain_id": tf.cast(np.eye(n_domains)[domain_idx], dtype=tf.float32),
    } for subj_run_i in trainsets for data_i in trainsets[subj_run_i]]
    dataset_validation = [{
        "label": tf.cast(np.eye(n_labels)[labels[data_i["name"]]], dtype=tf.float32),
        "data": tf.convert_to_tensor(_robust_scale(data_i["data"]), dtype=tf.float32),
        "subj_id": tf.cast(np.eye(n_subjects)[subj_runs.index(subj_run_i)], dtype=tf.float32),
        "domain_id": tf.cast(np.eye(n_domains)[domain_idx], dtype=tf.float32),
    } for subj_run_i in validationsets for data_i in validationsets[subj_run_i]]
    dataset_test = [{
        "label": tf.cast(np.eye(n_labels)[labels[data_i["name"]]], dtype=tf.float32),
        "data": tf.convert_to_tensor(_robust_scale(data_i["data"]), dtype=tf.float32),
        "subj_id": tf.cast(np.eye(n_subjects)[subj_runs.index(subj_run_i)], dtype=tf.float32),
        "domain_id": tf.cast(np.eye(n_domains)[domain_idx], dtype=tf.float32),
    } for subj_run_i in testsets for data_i in testsets[subj_run_i]]
    # Shuffle `dataset_*` to get shuffled list.
    np.random.shuffle(dataset_train); np.random.shuffle(dataset_validation); np.random.shuffle(dataset_test)
    # Save `dataset_train` to tf-records.
    data_idx = 0; tfrecord_idx = 0
    while data_idx < len(dataset_train):
        # Initialize the path of tf-record to save current batch.
        path_tfrecord_i = os.path.join(path_dataset_tfrecords_train, "train-{:d}.tfrecords".format(tfrecord_idx))
        tfrecord_writer_i = tf.io.TFRecordWriter(path_tfrecord_i)
        # Write current batch to tfrecord.
        for _ in range(batch_size):
            if data_idx < len(dataset_train):
                feature_i = {
                    "label": _bytes_feature(tf.io.serialize_tensor(dataset_train[data_idx]["label"])),
                    "data": _bytes_feature(tf.io.serialize_tensor(dataset_train[data_idx]["data"])),
                    "subj_id": _bytes_feature(tf.io.serialize_tensor(dataset_train[data_idx]["subj_id"])),
                    "domain_id": _bytes_feature(tf.io.serialize_tensor(dataset_train[data_idx]["domain_id"])),
                }; example_i = tf.train.Example(features=tf.train.Features(feature=feature_i))
                tfrecord_writer_i.write(example_i.SerializeToString())
            data_idx += 1
        # Close tfrecord writer, then update `tfrecord_idx`.
        tfrecord_writer_i.close(); tfrecord_idx += 1
    # Save `dataset_validation` to tf-records.
    data_idx = 0; tfrecord_idx = 0
    while data_idx < len(dataset_validation):
        # Initialize the path of tf-record to save current batch.
        path_tfrecord_i = os.path.join(path_dataset_tfrecords_validation, "validation-{:d}.tfrecords".format(tfrecord_idx))
        tfrecord_writer_i = tf.io.TFRecordWriter(path_tfrecord_i)
        # Write current batch to tfrecord.
        for _ in range(batch_size):
            if data_idx < len(dataset_validation):
                feature_i = {
                    "label": _bytes_feature(tf.io.serialize_tensor(dataset_validation[data_idx]["label"])),
                    "data": _bytes_feature(tf.io.serialize_tensor(dataset_validation[data_idx]["data"])),
                    "subj_id": _bytes_feature(tf.io.serialize_tensor(dataset_validation[data_idx]["subj_id"])),
                    "domain_id": _bytes_feature(tf.io.serialize_tensor(dataset_train[data_idx]["domain_id"])),
                }; example_i = tf.train.Example(features=tf.train.Features(feature=feature_i))
                tfrecord_writer_i.write(example_i.SerializeToString())
            data_idx += 1
        # Close tfrecord writer, then update `tfrecord_idx`.
        tfrecord_writer_i.close(); tfrecord_idx += 1
    # Save `dataset_test` to tf-records.
    data_idx = 0; tfrecord_idx = 0
    while data_idx < len(dataset_test):
        # Initialize the path of tf-record to save current batch.
        path_tfrecord_i = os.path.join(path_dataset_tfrecords_test, "test-{:d}.tfrecords".format(tfrecord_idx))
        tfrecord_writer_i = tf.io.TFRecordWriter(path_tfrecord_i)
        # Write current batch to tfrecord.
        for _ in range(batch_size):
            if data_idx < len(dataset_test):
                feature_i = {
                    "label": _bytes_feature(tf.io.serialize_tensor(dataset_test[data_idx]["label"])),
                    "data": _bytes_feature(tf.io.serialize_tensor(dataset_test[data_idx]["data"])),
                    "subj_id": _bytes_feature(tf.io.serialize_tensor(dataset_test[data_idx]["subj_id"])),
                    "domain_id": _bytes_feature(tf.io.serialize_tensor(dataset_train[data_idx]["domain_id"])),
                }; example_i = tf.train.Example(features=tf.train.Features(feature=feature_i))
                tfrecord_writer_i.write(example_i.SerializeToString())
            data_idx += 1
        # Close tfrecord writer, then update `tfrecord_idx`.
        tfrecord_writer_i.close(); tfrecord_idx += 1
    ## Write JSON configuration related to current dataset.
    data_shape = tuple(dataset_train[0]["data"].shape)
    json_config = {"n_labels": n_labels, "n_subjects": n_subjects, "n_domains": n_domains, "data_shape": data_shape,}
    with open(os.path.join(path_dataset_tfrecords, "config.json"), "w") as f:
        json.dump(json_config, f)

# def _prepare_crossmodality_tmr func
def _prepare_crossmodality_tmr(path_dataset, subj_runs, modalities,
    batch_size, modality="N2/3", train_ratio=0.8, thres_dataset=100):
    """
    Prepare crossmodality data tf-records for tmr dataset.

    Args:
        path_dataset: str - The path to dataset base.
        subj_runs: list - The list of available subj-runs.
        modalities: list - The list of available modalities.
        batch_size: int - The size of batch to split tfrecords.
        modality: str - The specified data modality, supporting [N2/3,REM].

    Returns:
        None
    """
    assert modality in ["N2/3", "REM"]
    ## Prepare for dataset creation.
    # Initialize the path to save tf-records.
    path_dataset_tfrecords = os.path.join(path_dataset, "dataset.crossmodality")
    if not os.path.exists(path_dataset_tfrecords): os.makedirs(path_dataset_tfrecords)
    path_dataset_tfrecords = os.path.join(path_dataset_tfrecords, modality.replace("/", ""))
    if not os.path.exists(path_dataset_tfrecords): os.makedirs(path_dataset_tfrecords)
    # Initialize `labels`, which is the map between category and the corresponding index.
    labels = DotDict({"alarm": 0, "apple": 1, "ball": 2, "book": 3, "box": 4, "chair": 5, "kiwi": 6, "microphone": 7,
        "motorcycle": 8, "pepper": 9, "sheep": 10, "shoes": 11, "strawberry": 12, "tomato": 13, "watch": 14,})
    n_labels = len(labels)
    # Initialize the number of subjects.
    subj_runs = sorted(set(subj_runs)); n_subjects = len(subj_runs)
    # Initialize the number of domains.
    domains = sorted(set(modalities)); n_domains = len(domains); domain_idx = domains.index(modality)
    ## Execute dataset creation.
    # Initialize datasets.
    datasets = DotDict({})
    for subj_run_i in subj_runs:
        # Load the corresponding dataset.
        dataset_i = load_pickle(os.path.join(path_dataset, os.sep.join(subj_run_i.split("/")), "dataset.tmr"))[modality]
        # If the size of dataset is too small, ignore it.
        if len(dataset_i) < thres_dataset: continue
        # If the number of each label in dataset is not enough, ignore it.
        label_counter_i = Counter([data_i["audio"]["name"] for data_i in dataset_i])
        if min(label_counter_i.values()) < 3: continue
        # Only keep audio part of `dataset_i`.
        dataset_i = [data_i["audio"] for data_i in dataset_i]
        # Update `datasets` with `dataset_i`.
        datasets[subj_run_i] = dataset_i
    print(("INFO: Get {:d} datasets (including {}).").format(len(datasets), sorted(datasets.keys())))
    # Prepare to construct train-set & validation-set & test-set.
    train_ratio = train_ratio; validation_ratio = test_ratio = (1. - train_ratio) / 2.
    # Construct `*sets` according to `datasets`.
    trainsets = DotDict({}); validationsets = DotDict({}); testsets = DotDict({})
    for subj_run_i in sorted(datasets.keys()):
        # Initialize `*sets` as empty list.
        trainsets[subj_run_i] = []; validationsets[subj_run_i] = []; testsets[subj_run_i] = []
        # Get the corresponding dataset.
        dataset_i = datasets[subj_run_i]
        # Count the number of samples corresponding to each label.
        label_counter_i = Counter([data_i["name"] for data_i in dataset_i])
        label_counter_validation_i = DotDict({label_name_i:max(1,
            int(label_counter_i[label_name_i] * validation_ratio)
        ) for label_name_i in sorted(label_counter_i.keys())})
        label_counter_test_i = DotDict({label_name_i:max(1,
            int(label_counter_i[label_name_i] * test_ratio)
        ) for label_name_i in sorted(label_counter_i.keys())})
        label_counter_train_i = DotDict({label_name_i:(label_counter_i[label_name_i] -\
            (label_counter_validation_i[label_name_i] + label_counter_test_i[label_name_i])
        ) for label_name_i in sorted(label_counter_i.keys())})
        # Loop over all labels to get the corresponding indices.
        for label_name_i in sorted(label_counter_i.keys()):
            # Get the indices of label, then split it.
            label_idxs = [data_idx for data_idx in range(len(dataset_i)) if dataset_i[data_idx]["name"] == label_name_i]
            validation_idxs = np.random.choice(label_idxs,
                size=label_counter_validation_i[label_name_i], replace=False).tolist()
            test_idxs = np.random.choice(sorted(set(label_idxs) - set(validation_idxs)),
                size=label_counter_test_i[label_name_i], replace=False).tolist()
            train_idxs = sorted(set(label_idxs) - set(validation_idxs) - set(test_idxs))
            assert len(set(train_idxs) & set(validation_idxs)) == 0
            assert len(set(train_idxs) & set(test_idxs)) == 0
            assert len(set(validation_idxs) & set(test_idxs)) == 0
            # Use these indices to fill up `*sets`.
            trainsets[subj_run_i].extend([dataset_i[data_idx] for data_idx in train_idxs])
            validationsets[subj_run_i].extend([dataset_i[data_idx] for data_idx in validation_idxs])
            testsets[subj_run_i].extend([dataset_i[data_idx] for data_idx in test_idxs])
    # Log information related to train-set & validation-set & test-set of dataset.
    n_samples = DotDict({
        "train-set": np.sum([len(trainset_i) for trainset_i in trainsets.values()]),
        "validation-set": np.sum([len(validationset_i) for validationset_i in validationsets.values()]),
        "test-set": np.sum([len(testset_i) for testset_i in testsets.values()]),
    })
    print((
        "INFO: Complete the segmentation of dataset, with train-set ({:d})" +\
        " & validation-set ({:d}) & test-set ({:d}), including {}."
    ).format(n_samples["train-set"], n_samples["validation-set"], n_samples["test-set"], sorted(datasets.keys())))
    ### Write created dataset to tf-records.
    ## Initialize the path of dataset.
    path_dataset_tfrecords_train = os.path.join(path_dataset_tfrecords, "train")
    if not os.path.exists(path_dataset_tfrecords_train): os.makedirs(path_dataset_tfrecords_train)
    path_dataset_tfrecords_validation = os.path.join(path_dataset_tfrecords, "validation")
    if not os.path.exists(path_dataset_tfrecords_validation): os.makedirs(path_dataset_tfrecords_validation)
    path_dataset_tfrecords_test = os.path.join(path_dataset_tfrecords, "test")
    if not os.path.exists(path_dataset_tfrecords_test): os.makedirs(path_dataset_tfrecords_test)
    ## Save constructed dataset.
    # Construct `dataset_*` according to `*sets`.
    dataset_train = [{
        "label": tf.cast(np.eye(n_labels)[labels[data_i["name"]]], dtype=tf.float32),
        "data": tf.convert_to_tensor(_robust_scale(data_i["data"]), dtype=tf.float32),
        "subj_id": tf.cast(np.eye(n_subjects)[subj_runs.index(subj_run_i)], dtype=tf.float32),
        "domain_id": tf.cast(np.eye(n_domains)[domain_idx], dtype=tf.float32),
    } for subj_run_i in trainsets for data_i in trainsets[subj_run_i]]
    dataset_validation = [{
        "label": tf.cast(np.eye(n_labels)[labels[data_i["name"]]], dtype=tf.float32),
        "data": tf.convert_to_tensor(_robust_scale(data_i["data"]), dtype=tf.float32),
        "subj_id": tf.cast(np.eye(n_subjects)[subj_runs.index(subj_run_i)], dtype=tf.float32),
        "domain_id": tf.cast(np.eye(n_domains)[domain_idx], dtype=tf.float32),
    } for subj_run_i in validationsets for data_i in validationsets[subj_run_i]]
    dataset_test = [{
        "label": tf.cast(np.eye(n_labels)[labels[data_i["name"]]], dtype=tf.float32),
        "data": tf.convert_to_tensor(_robust_scale(data_i["data"]), dtype=tf.float32),
        "subj_id": tf.cast(np.eye(n_subjects)[subj_runs.index(subj_run_i)], dtype=tf.float32),
        "domain_id": tf.cast(np.eye(n_domains)[domain_idx], dtype=tf.float32),
    } for subj_run_i in testsets for data_i in testsets[subj_run_i]]
    # Shuffle `dataset_*` to get shuffled list.
    np.random.shuffle(dataset_train); np.random.shuffle(dataset_validation); np.random.shuffle(dataset_test)
    # Save `dataset_train` to tf-records.
    data_idx = 0; tfrecord_idx = 0
    while data_idx < len(dataset_train):
        # Initialize the path of tf-record to save current batch.
        path_tfrecord_i = os.path.join(path_dataset_tfrecords_train, "train-{:d}.tfrecords".format(tfrecord_idx))
        tfrecord_writer_i = tf.io.TFRecordWriter(path_tfrecord_i)
        # Write current batch to tfrecord.
        for _ in range(batch_size):
            if data_idx < len(dataset_train):
                feature_i = {
                    "label": _bytes_feature(tf.io.serialize_tensor(dataset_train[data_idx]["label"])),
                    "data": _bytes_feature(tf.io.serialize_tensor(dataset_train[data_idx]["data"])),
                    "subj_id": _bytes_feature(tf.io.serialize_tensor(dataset_train[data_idx]["subj_id"])),
                    "domain_id": _bytes_feature(tf.io.serialize_tensor(dataset_train[data_idx]["domain_id"])),
                }; example_i = tf.train.Example(features=tf.train.Features(feature=feature_i))
                tfrecord_writer_i.write(example_i.SerializeToString())
            data_idx += 1
        # Close tfrecord writer, then update `tfrecord_idx`.
        tfrecord_writer_i.close(); tfrecord_idx += 1
    # Save `dataset_validation` to tf-records.
    data_idx = 0; tfrecord_idx = 0
    while data_idx < len(dataset_validation):
        # Initialize the path of tf-record to save current batch.
        path_tfrecord_i = os.path.join(path_dataset_tfrecords_validation, "validation-{:d}.tfrecords".format(tfrecord_idx))
        tfrecord_writer_i = tf.io.TFRecordWriter(path_tfrecord_i)
        # Write current batch to tfrecord.
        for _ in range(batch_size):
            if data_idx < len(dataset_validation):
                feature_i = {
                    "label": _bytes_feature(tf.io.serialize_tensor(dataset_validation[data_idx]["label"])),
                    "data": _bytes_feature(tf.io.serialize_tensor(dataset_validation[data_idx]["data"])),
                    "subj_id": _bytes_feature(tf.io.serialize_tensor(dataset_validation[data_idx]["subj_id"])),
                    "domain_id": _bytes_feature(tf.io.serialize_tensor(dataset_train[data_idx]["domain_id"])),
                }; example_i = tf.train.Example(features=tf.train.Features(feature=feature_i))
                tfrecord_writer_i.write(example_i.SerializeToString())
            data_idx += 1
        # Close tfrecord writer, then update `tfrecord_idx`.
        tfrecord_writer_i.close(); tfrecord_idx += 1
    # Save `dataset_test` to tf-records.
    data_idx = 0; tfrecord_idx = 0
    while data_idx < len(dataset_test):
        # Initialize the path of tf-record to save current batch.
        path_tfrecord_i = os.path.join(path_dataset_tfrecords_test, "test-{:d}.tfrecords".format(tfrecord_idx))
        tfrecord_writer_i = tf.io.TFRecordWriter(path_tfrecord_i)
        # Write current batch to tfrecord.
        for _ in range(batch_size):
            if data_idx < len(dataset_test):
                feature_i = {
                    "label": _bytes_feature(tf.io.serialize_tensor(dataset_test[data_idx]["label"])),
                    "data": _bytes_feature(tf.io.serialize_tensor(dataset_test[data_idx]["data"])),
                    "subj_id": _bytes_feature(tf.io.serialize_tensor(dataset_test[data_idx]["subj_id"])),
                    "domain_id": _bytes_feature(tf.io.serialize_tensor(dataset_train[data_idx]["domain_id"])),
                }; example_i = tf.train.Example(features=tf.train.Features(feature=feature_i))
                tfrecord_writer_i.write(example_i.SerializeToString())
            data_idx += 1
        # Close tfrecord writer, then update `tfrecord_idx`.
        tfrecord_writer_i.close(); tfrecord_idx += 1
    ## Write JSON configuration related to current dataset.
    data_shape = tuple(dataset_train[0]["data"].shape)
    json_config = {"n_labels": n_labels, "n_subjects": n_subjects, "n_domains": n_domains, "data_shape": data_shape,}
    with open(os.path.join(path_dataset_tfrecords, "config.json"), "w") as f:
        json.dump(json_config, f)

"""
uni-modality funcs
"""
# def prepare_unimodality func
def prepare_unimodality(base, batch_size):
    """
    Create the unimodality tf-records of every available modality.

    The last available subj-run is used as finetune dataset, all the others as pretrain dataset.

    Args:
        base: str - The path of project base.
        batch_size: int - The size of batch to split tfrecords.

    Returns:
        None
    """
    # The dataset lives under `<base>/data/eeg.anonymous`.
    dataset_root = os.path.join(base, "data", "eeg.anonymous")
    # Every entry is `<subject>/<run>`.
    subj_run_pool = [
        "005/20221223", "006/20230103", "007/20230106", "011/20230214", "013/20230308",
        "018/20230331", "019/20230403", "020/20230405", "021/20230407", "023/20230412",
        "024/20230414", "025/20230417", "026/20230419", "027/20230421", "028/20230424",
        "029/20230428", "030/20230504", "031/20230510", "033/20230517", "034/20230519",
        "036/20230526", "037/20230529", "038/20230531", "039/20230605", "040/20230607",
        "042/20230614", "043/20230616", "044/20230619", "045/20230626", "046/20230628",
        "047/20230703", "048/20230705", "049/20230710", "050/20230712", "051/20230717",
        "052/20230719", "053/20230724", "054/20230726",
    ]
    # Each subject must contribute exactly one run.
    subject_ids = {entry.split("/")[0] for entry in subj_run_pool}
    assert len(subject_ids) == len(subj_run_pool)
    # Only the last subj-run is held out for finetune.
    held_out_runs = subj_run_pool[-1:]
    modality_pool = ["image", "audio", "N2/3", "REM"]
    for held_out_run in held_out_runs:
        assert held_out_run in subj_run_pool
        # Pretrain on everything but the held-out subj-run.
        pretrain_runs = sorted(set(subj_run_pool) - set([held_out_run,]))
        split_config = DotDict({
            "pretrain": pretrain_runs,
            "finetune": [held_out_run,],
        })
        # Create one dataset per modality.
        for current_modality in modality_pool:
            _prepare_unimodality(dataset_root, split_config, batch_size, modality=current_modality)

# def _prepare_unimodality func
def _prepare_unimodality(path_dataset, datasets_config, batch_size, modality="N2/3", train_ratio=0.8, thres_dataset=100):
    """
    Prepare unimodality data tf-records for dataset.

    Args:
        path_dataset: str - The path to dataset base.
        datasets_config: dict - The dict of configuration corresponding to pretrain dataset & finetune dataset,
            each contains train-set & validation-set & test-set.
        batch_size: int - The size of batch to split tfrecords.
        modality: str - The specified data modality, supporting [N2/3,REM].

    Returns:
        None
    """
    # Prepare unimodality data tf-records according to `modality`.
    if modality in ["image", "audio"]:
        _prepare_unimodality_task(path_dataset, datasets_config, batch_size,
            modality=modality, train_ratio=train_ratio, thres_dataset=thres_dataset)
    elif modality in ["N2/3", "REM"]:
        _prepare_unimodality_tmr(path_dataset, datasets_config, batch_size,
            modality=modality, train_ratio=train_ratio, thres_dataset=thres_dataset)
    else:
        raise ValueError("ERROR: Get unknown modality ({}).".format(modality))

# def _prepare_unimodality_task func
def _prepare_unimodality_task(path_dataset, datasets_config, batch_size, modality="N2/3", train_ratio=0.8, thres_dataset=100):
    """
    Prepare unimodality data tf-records for task dataset.

    Args:
        path_dataset: str - The path to dataset base.
        datasets_config: dict - The dict of configuration corresponding to pretrain dataset & finetune dataset,
            each contains train-set & validation-set & test-set.
        batch_size: int - The size of batch to split tfrecords.
        modality: str - The specified data modality, supporting [N2/3,REM].

    Returns:
        None
    """
    assert modality in ["image", "audio"]

# def _prepare_unimodality_tmr func
def _prepare_unimodality_tmr(path_dataset, datasets_config, batch_size, modality="N2/3", train_ratio=0.8, thres_dataset=100):
    """
    Prepare unimodality data tf-records for TMR dataset.

    The audio part of each subj-run is loaded from `dataset.tmr`, split per label into train-set
    & validation-set & test-set, then written as tf-records (at most `batch_size` samples per file) to
    `<path_dataset>/dataset.unimodality/<finetune-subj-run>/<modality>/{pretrain,finetune}/{train,validation,test}`.
    A `config.json` (n_labels, n_subjects, data_shape) is written next to them once all records are saved.

    Args:
        path_dataset: str - The path to dataset base.
        datasets_config: dict - The dict of configuration, `pretrain` & `finetune` each map to a list of
            subj-runs; `finetune` must contain exactly one subj-run, which is not listed in `pretrain`.
        batch_size: int - The size of batch to split tfrecords, must be positive.
        modality: str - The specified data modality, supporting [N2/3,REM].
        train_ratio: float - The ratio of train-set, the rest is split evenly between
            validation-set & test-set (each label keeps at least one sample in both of them).
        thres_dataset: int - The minimum number of samples of one subj-run, smaller subj-runs are ignored.

    Returns:
        None

    Raises:
        ValueError: If `batch_size` is not positive, or no subj-run passes the size checks.
    """
    assert modality in ["N2/3", "REM"]
    # A non-positive `batch_size` can never consume the samples, reject it before any file is created.
    if batch_size < 1:
        raise ValueError("ERROR: Get non-positive batch_size ({}).".format(batch_size))
    # Make sure finetune dataset only contains one subj-run, and has no overlap with pretrain datasets.
    assert len(datasets_config["finetune"]) == 1
    assert len(set(datasets_config["pretrain"]) & set(datasets_config["finetune"])) == 0
    ## Prepare for dataset creation.
    # Initialize the path to save tf-records, i.e. `dataset.unimodality/<subj>.<run>/<modality>`.
    path_dataset_tfrecords = os.path.join(path_dataset, "dataset.unimodality",
        ".".join(datasets_config["finetune"][0].split("/")), modality.replace("/", ""))
    os.makedirs(path_dataset_tfrecords, exist_ok=True)
    # Initialize `labels`, which is the map between category and the corresponding index.
    labels = DotDict({"alarm": 0, "apple": 1, "ball": 2, "book": 3, "box": 4, "chair": 5, "kiwi": 6, "microphone": 7,
        "motorcycle": 8, "pepper": 9, "sheep": 10, "shoes": 11, "strawberry": 12, "tomato": 13, "watch": 14,})
    n_labels = len(labels)
    ## Execute dataset creation.
    # Initialize pretrain & finetune datasets.
    datasets_pretrain = _unimodality_tmr_load(path_dataset, datasets_config["pretrain"], modality, thres_dataset)
    datasets_finetune = _unimodality_tmr_load(path_dataset, datasets_config["finetune"], modality, thres_dataset)
    print((
        "INFO: Get {:d} pretrain datasets (including {}), and {:d} finetune datasets (including {})."
    ).format(len(datasets_pretrain), sorted(datasets_pretrain.keys()),
        len(datasets_finetune), sorted(datasets_finetune.keys())))
    # Initialize the number of subjects.
    subj_runs = sorted(set(datasets_pretrain.keys()) | set(datasets_finetune.keys())); n_subjects = len(subj_runs)
    # Without any usable subj-run there is nothing to write, and `data_shape` cannot be determined.
    if n_subjects == 0:
        raise ValueError("ERROR: No subj-run has enough samples for modality ({}).".format(modality))
    # Prepare to construct train-set & validation-set & test-set.
    validation_ratio = test_ratio = (1. - train_ratio) / 2.
    splits = ("train", "validation", "test")
    # Construct `sets_finetune` according to `datasets_finetune`. Finetune is split before pretrain,
    # so that the consumption order of numpy random state stays the same for seeded runs.
    sets_finetune = _unimodality_tmr_split(datasets_finetune, validation_ratio, test_ratio)
    # Use builtin `sum` to count samples, `np.sum([])` returns float 0.0 which breaks `{:d}`.
    n_samples_finetune = {split_i: sum(len(set_i) for set_i in sets_finetune[split_i].values()) for split_i in splits}
    print((
        "INFO: Complete the segmentation of finetune dataset, with train-set ({:d})" +\
        " & validation-set ({:d}) & test-set ({:d}), including {}."
    ).format(n_samples_finetune["train"], n_samples_finetune["validation"],
        n_samples_finetune["test"], sorted(datasets_finetune.keys())))
    # Construct `sets_pretrain` according to `datasets_pretrain`.
    sets_pretrain = _unimodality_tmr_split(datasets_pretrain, validation_ratio, test_ratio)
    n_samples_pretrain = {split_i: sum(len(set_i) for set_i in sets_pretrain[split_i].values()) for split_i in splits}
    print((
        "INFO: Complete the segmentation of pretrain dataset, with train-set ({:d})," +\
        " & validation-set ({:d}) & test-set ({:d}), including {}."
    ).format(n_samples_pretrain["train"], n_samples_pretrain["validation"],
        n_samples_pretrain["test"], sorted(datasets_pretrain.keys())))
    ### Write created dataset to tf-records.
    # Finetune dataset is saved before pretrain dataset, each in train -> validation -> test order.
    data_shapes = []
    for stage_i, sets_i in (("finetune", sets_finetune), ("pretrain", sets_pretrain)):
        for split_i in splits:
            # Initialize the path of current split, e.g. `finetune/train`.
            path_split_i = os.path.join(path_dataset_tfrecords, stage_i, split_i)
            os.makedirs(path_split_i, exist_ok=True)
            # Save current split, then record the shape of its data (None for empty split).
            data_shapes.append(_unimodality_tmr_write(sets_i[split_i],
                path_split_i, split_i, batch_size, labels, subj_runs))
    ## Write JSON configuration related to current dataset.
    # Prefer the shape of finetune train-set; if it is empty, fall back to the first non-empty split.
    available_shapes = [data_shape_i for data_shape_i in data_shapes if data_shape_i is not None]
    if len(available_shapes) == 0:
        raise ValueError("ERROR: No sample is available for modality ({}).".format(modality))
    json_config = {"n_labels": n_labels, "n_subjects": n_subjects, "data_shape": available_shapes[0],}
    with open(os.path.join(path_dataset_tfrecords, "config.json"), "w") as f:
        json.dump(json_config, f)

# def _unimodality_tmr_load func
def _unimodality_tmr_load(path_dataset, subj_runs, modality, thres_dataset, min_label_count=3):
    """
    Load the audio part of TMR datasets, ignoring subj-runs without enough samples.

    Args:
        path_dataset: str - The path to dataset base.
        subj_runs: list - The list of subj-runs (e.g. `005/20221223`) to load.
        modality: str - The key of `dataset.tmr` to load.
        thres_dataset: int - The minimum number of samples of one subj-run.
        min_label_count: int - The minimum number of samples of each label present in one subj-run.

    Returns:
        datasets: DotDict - The map from kept subj-run to the list of its audio items.
    """
    datasets = DotDict({})
    for subj_run_i in subj_runs:
        # Load the corresponding dataset.
        dataset_i = load_pickle(os.path.join(path_dataset, os.sep.join(subj_run_i.split("/")), "dataset.tmr"))[modality]
        # If the size of dataset is too small, ignore it.
        if len(dataset_i) < thres_dataset: continue
        # If the number of each label in dataset is not enough, ignore it. An empty dataset has no label at all.
        label_counter_i = Counter([data_i["audio"]["name"] for data_i in dataset_i])
        if len(label_counter_i) == 0 or min(label_counter_i.values()) < min_label_count: continue
        # Only keep audio part of `dataset_i`.
        datasets[subj_run_i] = [data_i["audio"] for data_i in dataset_i]
    return datasets

# def _unimodality_tmr_split func
def _unimodality_tmr_split(datasets, validation_ratio, test_ratio):
    """
    Split each dataset per label into train-set & validation-set & test-set.

    Args:
        datasets: dict - The map from subj-run to the list of its items, each item has `name` field.
        validation_ratio: float - The ratio of validation-set.
        test_ratio: float - The ratio of test-set.

    Returns:
        sets: dict - The map from split name (train, validation, test) to
            the map from subj-run to the list of items of that split.
    """
    sets = {"train": DotDict({}), "validation": DotDict({}), "test": DotDict({})}
    for subj_run_i in sorted(datasets.keys()):
        # Get the corresponding dataset, then initialize its splits as empty list.
        dataset_i = datasets[subj_run_i]
        trainset_i = []; validationset_i = []; testset_i = []
        # Group the indices of samples by label in one pass, indices of each label are ascending.
        label_idxs_i = {}
        for data_idx, data_i in enumerate(dataset_i):
            label_idxs_i.setdefault(data_i["name"], []).append(data_idx)
        # Loop over all labels (in sorted order) to split the corresponding indices.
        for label_name_i in sorted(label_idxs_i.keys()):
            label_idxs = label_idxs_i[label_name_i]
            # Each label keeps at least one sample in validation-set & test-set.
            n_validation_i = max(1, int(len(label_idxs) * validation_ratio))
            n_test_i = max(1, int(len(label_idxs) * test_ratio))
            # Test indices are drawn from what validation left, train takes the rest,
            # thus the three splits are disjoint by construction.
            validation_idxs = np.random.choice(label_idxs, size=n_validation_i, replace=False).tolist()
            test_idxs = np.random.choice(sorted(set(label_idxs) - set(validation_idxs)),
                size=n_test_i, replace=False).tolist()
            train_idxs = sorted(set(label_idxs) - set(validation_idxs) - set(test_idxs))
            # Use these indices to fill up the splits.
            trainset_i.extend([dataset_i[data_idx] for data_idx in train_idxs])
            validationset_i.extend([dataset_i[data_idx] for data_idx in validation_idxs])
            testset_i.extend([dataset_i[data_idx] for data_idx in test_idxs])
        sets["train"][subj_run_i] = trainset_i
        sets["validation"][subj_run_i] = validationset_i
        sets["test"][subj_run_i] = testset_i
    return sets

# def _unimodality_tmr_write func
def _unimodality_tmr_write(datasets, path_tfrecords, prefix, batch_size, labels, subj_runs):
    """
    Convert one split to tensors, shuffle it, then write it to tf-records.

    Args:
        datasets: dict - The map from subj-run to the list of items (with `name` & `data` fields) of one split.
        path_tfrecords: str - The existing directory to save tf-records.
        prefix: str - The prefix of tf-record files, i.e. `<prefix>-<idx>.tfrecords`.
        batch_size: int - The max number of samples of each tf-record file, must be positive.
        labels: dict - The map between category and the corresponding index.
        subj_runs: list - The list of all subj-runs, the position in it identifies the subject.

    Returns:
        data_shape: tuple - The shape of `data` of the first shuffled sample, None if the split is empty.
    """
    # One-hot rows are invariant over samples, create the identity matrices only once.
    label_onehots = np.eye(len(labels)); subj_onehots = np.eye(len(subj_runs))
    samples = [{
        "label": tf.cast(label_onehots[labels[data_i["name"]]], dtype=tf.float32),
        "data": tf.convert_to_tensor(_robust_scale(data_i["data"]), dtype=tf.float32),
        "subj_id": tf.cast(subj_onehots[subj_runs.index(subj_run_i)], dtype=tf.float32),
    } for subj_run_i in datasets for data_i in datasets[subj_run_i]]
    # Shuffle `samples` to get shuffled list.
    np.random.shuffle(samples)
    # Write every `batch_size` samples to one tf-record, the last one may hold fewer samples.
    for tfrecord_idx, batch_start in enumerate(range(0, len(samples), batch_size)):
        path_tfrecord_i = os.path.join(path_tfrecords, "{}-{:d}.tfrecords".format(prefix, tfrecord_idx))
        # The context manager closes the writer even if serialization fails.
        with tf.io.TFRecordWriter(path_tfrecord_i) as tfrecord_writer_i:
            for sample_i in samples[batch_start:batch_start+batch_size]:
                feature_i = {
                    "label": _bytes_feature(tf.io.serialize_tensor(sample_i["label"])),
                    "data": _bytes_feature(tf.io.serialize_tensor(sample_i["data"])),
                    "subj_id": _bytes_feature(tf.io.serialize_tensor(sample_i["subj_id"])),
                }; example_i = tf.train.Example(features=tf.train.Features(feature=feature_i))
                tfrecord_writer_i.write(example_i.SerializeToString())
    return tuple(samples[0]["data"].shape) if len(samples) > 0 else None

if __name__ == "__main__":
    # The size of batch to split tfrecords.
    batch_size = 256
    # The project base is three levels above the current working directory.
    base = os.path.join(os.getcwd(), *([os.pardir,] * 3))

    # Prepare crossmodality datasets. To prepare unimodality datasets instead,
    # call `prepare_unimodality(base, batch_size)` here.
    prepare_crossmodality(base, batch_size)

