#!/usr/bin/env python3
"""
Created on 21:09, Dec. 10th, 2022

@author: Anonymous
"""
import os, mne
import re, logging
import contextlib
import numpy as np
import seaborn as sns
import scipy.ndimage as nd
import mne_icalabel as mneicl
import matplotlib.pyplot as plt
from bidict import bidict
from scipy.stats import pearsonr
from scipy.stats import kurtosis
from mne.preprocessing.bads import _find_outliers
# local dep
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.path.join(os.pardir, os.pardir, os.pardir))
from utils import DotDict

__all__ = [
    "default_preprocess_params",                                        # macros
    "load_data", "get_events", "set_events", "init_channel_montage",    # data initialization
    "init_channel_types", "crop", "filter_spectrum", "resample",        # data preprocessing
    "fix_bad_channels", "set_ref_channels", "detect_bad_segments",
    "run_ssp", "run_ica",
    "preprocess",                                                        # pipeline
]

# Initialize environment configuration: set MNE's global log level to "WARNING" so that
# routine INFO-level messages emitted by the MNE calls in this module are not printed.
mne.set_log_level("WARNING")
# Define tool functions.
@contextlib.contextmanager
def _suppress_logger():
    """Temporarily suppress the logger output."""
    logger = logging.getLogger("mne")
    original_level = logger.level
    logger.setLevel(logging.ERROR)
    try:
        yield
    finally:
        logger.setLevel(original_level)

## Macros
# Default parameters of the preprocessing pipeline (one sub-dict per preprocessing stage).
default_preprocess_params = DotDict(dict(
    # Map from channel type to the channel names of that type; `None` keeps every channel as `eeg`.
    channel_types=None,
    # Channel types that survive `init_channel_types`.
    allowed_channel_types=["ecg", "emg", "eog", "eeg", "misc"],
    # Time ranges (unit: second) to keep when cropping.
    crop_ranges=[],
    # Bad channels noted by the experimenter during recording; algorithms may flag more.
    bad_channels=[],
    # Reference channels; an empty list selects the global average reference.
    ref_channels=[],
    # Spectrum-filtering parameters.
    filter=dict(
        # Band-pass range [low, high] (unit: Hz).
        band_frange=[0.1, 40],
        # Power-line noise base frequencies (unit: Hz); their harmonics are notched as well.
        notch_freqs=[50],
    ),
    # Sampling frequency after resampling (unit: Hz).
    resample_freq=100,
    # Bad-segment detection parameters.
    bad_segments_detection=dict(
        # Threshold on `|X| / MAD` used to flag bad points in `_detect_bad_segments`.
        scale_threshold=20.,
        # Bad segments shorter than this (unit: second) are ignored.
        duration_threshold=0.2,
        # Reciprocal of the convolution-window duration (unit: second).
        # Note (author): `window_freq` should be `resample_freq / 10`.
        window_freq=10,
    ),
    # SSP parameters.
    ssp=dict(
        # Explained-variance threshold for ECG-related projectors.
        explained_var_ecg=0.8,
    ),
    # ICA parameters.
    ica=dict(
        # Number of PCA components.
        n_components=50,
        # Seed for `numpy`'s random number generator, for reproducibility.
        seed=42,
        # Correlation threshold.
        corr_thres=0.2,
        # Whether to re-run ICA, e.g. after components have been labelled manually.
        rerun=False,
        # Markers of the events of interest.
        markers=None,
    ),
))

## Data initialization.
# def load_data func
def load_data(eeg_fname):
    """
    Read a BrainVision recording and rebuild it so that its first sample sits at 0 s.

    Annotation onsets are shifted by `first_time` so they stay aligned with the samples
    after the object is rebuilt.

    Args:
        eeg_fname: str - The path of specified eeg data.

    Returns:
        data: The loaded `mne.io.RawArray` object.
    """
    raw = mne.io.read_raw_brainvision(eeg_fname, preload=True)
    # Snapshot the annotations, re-expressed relative to the first sample. Rebuilding is required
    # because `mne.io.RawArray` is constructed without annotations.
    # TODO: `mne.io.RawArray` will erase all annotations!
    source_annotations = raw.annotations
    shifted_annotations = mne.Annotations(
        onset=[item["onset"] - raw.first_time for item in source_annotations],
        duration=[item["duration"] for item in source_annotations],
        description=[item["description"] for item in source_annotations],
        orig_time=None if len(source_annotations) == 0 else source_annotations.orig_time,
    )
    # Rebuild the container from the same buffer and info, forcing `first_samp` to 0.
    raw = mne.io.RawArray(raw._data, raw.info, first_samp=0)
    assert len(raw.annotations) == 0
    raw.set_annotations(shifted_annotations)
    return raw

# def get_events func
def get_events(data, marker_map=None):
    """
    Get events from annotations. `data` stores events as annotations whose `duration` equals
    `1 / data.info["sfreq"]`, so `mne.events_from_annotations` is used to turn them back into
    events aligned with the original data structure.

    Args:
        data: object - The loaded `mne.io.RawArray` object.
        marker_map: object - The map from descriptions to markers: either a dict or a callable,
            as accepted by the `event_id` argument of `mne.events_from_annotations`. If None
            (default), a fresh `mne.io.brainvision.brainvision._BVEventParser()` is created for
            this call.

    Returns:
        events: list - The list of event, each `event` DotDict contains [onset,marker,description].
    """
    # Build a fresh parser on every call instead of sharing one default instance across calls.
    # NOTE(review): MNE's BrainVision parser appears to cache the codes it assigns to
    # non-standard descriptions per instance. With a shared default, earlier calls would then
    # change the markers of later ones. Confirm this against the installed MNE version.
    if marker_map is None:
        marker_map = mne.io.brainvision.brainvision._BVEventParser()
    # Duration of one sample; converts sample indices into seconds.
    sample_period = 1. / data.info["sfreq"]
    # `events_np` rows are [sample, previous value, marker]. `markers` maps description -> marker.
    # The custom parser is passed because `mne.io.RawArray` has no BrainVision-specific parsing.
    events_np, markers = mne.events_from_annotations(data, event_id=marker_map)
    # `bidict` both enforces a one-to-one mapping and gives the marker -> description lookup.
    description_of = bidict(markers).inverse
    # Keep the detailed description next to the marker, so the marker can always be recovered
    # from the description without re-indexing stimulus events.
    events = [DotDict({
        "onset": float(sample_i) * sample_period,
        "marker": marker_i,
        "description": description_of[marker_i],
    }) for sample_i, _, marker_i in events_np]
    # Return the final `events`.
    return events

# def set_events func
def set_events(data, events):
    """
    Set events through annotations. `data` stores events as annotations whose `duration` equals
    `1 / data.info["sfreq"]`, so `set_annotations` is used to write event-like annotations
    aligned with the original data structure. The `onset` field of each event is the relative
    time (unit: second) of the sample, not its index.
    Note: `set_annotations` overwrites the original annotations, so the original events must be
    included in `events`.

    Args:
        data: object - The loaded `mne.io.RawArray` object.
        events: list - The list of event, each `event` DotDict contains [onset,description].

    Returns:
        data: object - The modified `mne.io.RawArray` object.
    """
    # Duration of one sample; also used as the duration of every event annotation.
    sample_period = 1. / data.info["sfreq"]
    # Snap every onset onto the sample grid via `time_as_index` (with rounding) and `sample_period`.
    # For example, `0.9995` may become `1.0` and `1.1175` may become `1.118`. Previously this
    # value was computed but then ignored in favour of the unrounded onsets.
    onset = data.time_as_index([event_i.onset for event_i in events], use_rounding=True) * sample_period
    annotations = mne.Annotations(
        onset=onset,
        duration=[sample_period] * len(events),
        description=[event_i.description for event_i in events],
        orig_time=data.annotations.orig_time,
    )
    # Set events through annotations (replaces all existing annotations).
    data.set_annotations(annotations)
    # Return the final `data`.
    return data

## Data preprocess.
# def init_channel_types func
def init_channel_types(data, channel_types=None, allowed_channel_types=None):
    """
    Initialize the type of each channel, then keep only channels of the allowed types.
    All channels are `EEG` by default.
    Note: All these operations are operated on the original object to reduce memory consumption.

    Args:
        data: object - The loaded `mne.io.RawArray` object.
        channel_types: dict - Map from channel type (e.g. `"eog"`) to the list of channel names of
            that type. If None, `data` is returned untouched: no retyping and no channel picking.
        allowed_channel_types: list - Channel types kept after retyping. If None (default),
            `["ecg","emg","eog","eeg","misc"]` is used.

    Returns:
        data: object - The modified `mne.io.RawArray` object.
    """
    # Check whether `channel_types` is None.
    if channel_types is None: return data
    # Build the default per call instead of sharing a mutable default list between calls.
    if allowed_channel_types is None:
        allowed_channel_types = ["ecg", "emg", "eog", "eeg", "misc"]
    # Retype the channels listed under each type.
    for type_i, names_i in channel_types.items():
        data.set_channel_types({name_i: type_i for name_i in names_i})
    # Keep [ecg,emg,eog] for [heartbeats,?,eye-movements]-detection, [eeg] for data processing,
    # and [misc] for channel reference.
    # Note: Because only meaningful channel types survive, `init_channel_types` must run before
    # `crop`, i.e. meaningful [bio,misc] signals must already be converted into [ecg,eog].
    data.pick(allowed_channel_types)
    # Return the final `data`.
    return data

# def init_channel_montage func
def init_channel_montage(data, path_montage):
    """
    Attach electrode positions (a custom montage read from `path_montage`) to `data`.

    A freshly loaded recording carries no montage, i.e. `data.info` has no `dig` entry:

        <Info | 7 non-empty values
         bads: []
         ch_names: Fp1, Fp2, F3, F4, C3, C4, P3, P4, O1, O2, F7, F8, T7, T8, P7, ...
         chs: 64 EEG
         custom_ref_applied: False
         highpass: 0.0 Hz
         lowpass: 1000.0 Hz
         meas_date: 2022-11-28 00:35:20 UTC
         nchan: 64
         projs: []
         sfreq: 500.0 Hz
        >

    The `projs` entry is empty as well; it is filled only once projectors are added later in
    the pipeline. After this call, `data.info` gains a `dig` entry:

        <Info | 8 non-empty values
         bads: []
         ch_names: Fp1, Fp2, F3, F4, C3, C4, P3, P4, O1, O2, F7, F8, T7, T8, P7, ...
         chs: 55 EEG, 9 EOG
         custom_ref_applied: False
         dig: 58 items (3 Cardinal, 55 EEG)
         highpass: 0.0 Hz
         lowpass: 1000.0 Hz
         meas_date: 2022-11-28 00:35:20 UTC
         nchan: 64
         projs: []
         sfreq: 500.0 Hz
        >

    (The EEG/EOG split in the second example presumably comes from an earlier channel-type
    assignment, not from the montage itself.)
    Note: All these operations are operated on the original object to reduce memory consumption.

    Args:
        data: object - The loaded `mne.io.RawArray` object.
        path_montage: path - The path of montage configuration file.

    Returns:
        data: object - The modified `mne.io.RawArray` object (unchanged when the path is missing).
    """
    if os.path.exists(path_montage):
        data.set_montage(mne.channels.read_custom_montage(path_montage))
    else:
        # A missing montage file is reported but not fatal.
        print("WARNING: The path of montage ({}) doesn't exist.".format(path_montage))
    return data

# def crop func
def crop(data, crop_ranges=None):
    """
    Keep only the requested time ranges of `data` and concatenate them, dropping unused segments.

    Args:
        data: object - The loaded `mne.io.RawArray` object; `data.first_time` must be 0.
        crop_ranges: list - The list of `[tmin, tmax]` pairs (unit: second) to keep. Each range
            is clipped to the data range, with an INFO message when clipping happens. If None
            or empty, `data` is returned unchanged.

    Returns:
        data: object - A new `mne.io.RawArray` object with `first_samp == 0` (or the input
            object itself when there is nothing to crop).
    """
    # Check whether `crop_ranges` is empty.
    if crop_ranges is None or len(crop_ranges) == 0: return data
    # Although each session only has a small part of unused head & tail, they have
    # great effect over the process of executing ICA (by Yunzhe), we cannot ignore them.
    # TODO: The crop ranges is relative to `data.first_time`, instead of 0s.
    assert data.first_time == 0.
    segments = []
    for crop_range_i in crop_ranges:
        tmin_i = crop_range_i[0] - data.first_time; tmax_i = crop_range_i[1] - data.first_time
        if tmin_i < data.tmin: print("INFO: The provided tmin {} is less than data tmin {}.".format(tmin_i, data.tmin))
        if tmax_i > data.tmax: print("INFO: The provided tmax {} is greater than data tmax {}.".format(tmax_i, data.tmax))
        # Clip the requested range to the available data before cropping a copy.
        segments.append(data.copy().crop(tmin=max(tmin_i, data.tmin),
            tmax=min(tmax_i, data.tmax), include_tmax=True))
    data = mne.concatenate_raws(segments)
    # Drop zero-duration annotations (e.g. the boundary markers inserted by concatenation) and
    # re-express onsets relative to the first sample, because the object is rebuilt below.
    # TODO: Remove the last annotation which have 0 duration.
    kept = [annotation_i for annotation_i in data.annotations if annotation_i["duration"] > 0.]
    annotations = mne.Annotations(
        onset=[(annotation_i["onset"] - data.first_time) for annotation_i in kept],
        duration=[annotation_i["duration"] for annotation_i in kept],
        description=[annotation_i["description"] for annotation_i in kept],
        orig_time=kept[0]["orig_time"] if len(kept) > 0 else None,
    )
    # Rebuild so that `first_samp` is 0; `mne.io.RawArray` starts without annotations.
    data = mne.io.RawArray(data._data, data.info, first_samp=0)
    assert len(data.annotations) == 0; data.set_annotations(annotations)
    # Return the final `data`.
    return data

# def filter_spectrum func
def filter_spectrum(data, filter_params):
    """
    Filter the parts of a signal's spectrum, using `filter` and `notch_filter`.
    Note: All these operations are operated on the original object to reduce memory consumption.
    Some operations (such as filtering, ICA, etc.) require the data to be loaded into RAM, so
    `preload` must be `True` when the data is loaded with `mne.io.read_raw_brainvision`.

    Args:
        data: object - The loaded `mne.io.RawArray` object.
        filter_params: DotDict - The parameters of filter, containing [band_frange,notch_freqs].
            band_frange: list - The range of band-pass freqs.
            notch_freqs: list - The list of power line noise freqs (each must be > 0); every
                multiple strictly below Nyquist is notched.

    Returns:
        data: object - The modified `mne.io.RawArray` object.

    Raises:
        ValueError: If any entry of `notch_freqs` is not positive. Without this check, the
            harmonic loop below would never terminate.
    """
    # Validate before touching `data`, so an invalid configuration has no side effects.
    if any(notch_freq_i <= 0 for notch_freq_i in filter_params.notch_freqs):
        raise ValueError("ERROR: notch_freqs ({}) must all be positive.".format(filter_params.notch_freqs))
    # Band-pass filter the original data. If `l_freq` is None, we will get a low-pass filter.
    # If `h_freq` is None, we will get a high-pass filter. Both None leads to no filter.
    data = data.filter(l_freq=filter_params.band_frange[0], h_freq=filter_params.band_frange[1])
    # Power-line noise also appears at every multiple of its base frequency, so collect all
    # harmonics below Nyquist. A set removes duplicates from overlapping series (e.g. [50, 100]).
    # TODO: After `notch_filter`, there are still some peaks at high freqency region.
    freq_max = data.info["sfreq"] / 2.
    harmonics = set()
    for notch_freq_i in filter_params.notch_freqs:
        factor = 1
        while notch_freq_i * factor < freq_max:
            harmonics.add(notch_freq_i * factor)
            factor += 1
    notch_freqs_ = sorted(harmonics)
    # Power-line noise affects the EEG channels, so the notch filter is applied to them only.
    if len(notch_freqs_) > 0: data = data.notch_filter(freqs=notch_freqs_, picks=mne.pick_types(data.info, eeg=True))
    # Return the final `data`.
    return data

# def resample func
def resample(data, resample_freq):
    """
    Resample `data` to `resample_freq`, then drop annotations that no longer fit the data.

    Downsampling is useful when high-frequency content is not needed and precise timing is not
    critical (e.g. computing EOG/ECG projectors on a long recording): it saves memory and compute.

    Args:
        data: object - The loaded `mne.io.RawArray` object; `data.first_time` must be 0.
        resample_freq: float - The frequency of resampled data.

    Returns:
        data: object - The modified `mne.io.RawArray` object (the same object when no
            resampling is needed).
    """
    if data.info["sfreq"] == resample_freq:
        return data
    # Pitfalls of resampling continuous data (see MNE's resampling guidance):
    # - Resampling Raw data jitters event timing, and that imprecision propagates to epoching.
    # - Resampling Epochs/Evoked instead adds edge artifacts to every epoch.
    # The recommended alternative is to low-pass the Raw data at or below 1/3 of the target rate
    # and decimate after epoching (exact only for integer rate ratios).
    # TODO: Consider the problem introduced by performing resampling on Raw data.
    # >>> RuntimeWarning: Resampling of the stim channels caused event information to become unreliable.
    # >>> Consider finding events on the original data and passing the event matrix as a parameter.
    data = data.filter(l_freq=None, h_freq=resample_freq / 2.)
    data = data.resample(sfreq=resample_freq)
    # Keep only annotations with positive duration that lie completely inside the data range.
    assert data.first_time == 0.
    lower = data.first_time + data.tmin
    upper = data.first_time + data.tmax

    def _inside(item):
        return item["duration"] > 0. and lower <= item["onset"] and item["onset"] + item["duration"] <= upper

    survivors = list(filter(_inside, data.annotations))
    data.set_annotations(mne.Annotations(
        onset=[item["onset"] for item in survivors],
        duration=[item["duration"] for item in survivors],
        description=[item["description"] for item in survivors],
        orig_time=survivors[0]["orig_time"] if survivors else None,
    ))
    return data

# def fix_bad_channels func
def fix_bad_channels(data, bad_channels=None):
    """
    Mark `bad_channels` as bad and interpolate the bad EEG channels from their neighbours.
    The specification of `bad_channels` can be loaded from json file.

    Args:
        data: object - The loaded `mne.io.RawArray` object. Its `info["bads"]` is overwritten.
        bad_channels: list - The list of bad channels to be masked. None (default) or `[]`
            means there are no bad channels.

    Returns:
        data: object - A new `mne.io.RawArray` object: the non-EEG channels (in their original
            order) followed by the EEG channels (in their original order), with bad EEG
            channels interpolated.
    """
    # Copy, so the caller's list (or a shared default) is never aliased into `info["bads"]`.
    bad_channels = [] if bad_channels is None else list(bad_channels)
    # Check whether all specified `bad_channels` are in `data.info["ch_names"]`.
    if len(bad_channels) > 0:
        assert all(bad_channel_i in data.info["ch_names"] for bad_channel_i in bad_channels), (
            "ERROR: bad_channels ({}) are not all in data.info[\"ch_names\"] ({})."
        ).format(bad_channels, data.info["ch_names"])
    data.info["bads"] = bad_channels
    # Split channel names by type. `exclude=[]` is essential: by default `mne.pick_types` drops
    # channels listed in `info["bads"]`. That would route bad EEG channels into `other_channels`,
    # so they would never be interpolated.
    eeg_channels = [data.ch_names[idx] for idx in mne.pick_types(data.info, eeg=True, exclude=[])]
    eeg_set = set(eeg_channels)
    # Preserve the original channel order (iterating a `set` would be hash-seed dependent).
    other_channels = [ch_name for ch_name in data.ch_names if ch_name not in eeg_set]
    # Work on a separate EEG-only object; interpolation uses EEG neighbours only.
    data_eeg = data.copy().pick_channels(eeg_channels)
    # Log information of detected bad channels.
    print((
        "INFO: Bad channels (including marked and detected) contain eeg ({})."
    ).format(data_eeg.info["bads"]))
    # Interpolating, rather than dropping, bad channels keeps the same data dimensionality across
    # subjects. Dropping every channel that is bad for any single subject could sharply reduce
    # rank and discard clean data.
    if len(data_eeg.info["bads"]) > 0: data_eeg.interpolate_bads(reset_bads=True)
    # Nothing to merge back when the recording has only EEG channels.
    if len(other_channels) == 0: return data_eeg
    data_other = data.copy().pick_channels(other_channels)
    # Concatenate interpolated `Raw` data objects together using `add_channels`.
    data = data_other.add_channels([data_eeg,])
    # Return the final `data`.
    return data

# def _find_bad_channels func
def _find_bad_channels(data, picks=None, max_iter=1, thres=3):
    """
    Automatically find bad channels by outlier detection on per-channel variance.
    This applies the variance criterion of the first (channel) step of the FASTER algorithm.

    Args:
        data: object - The loaded `mne.io.RawArray` object.
        picks: array-like of int - Indices of the channels to operate on. Defaults to all EEG
            channels, including those already marked bad.
        max_iter: int - The maximum number of iterations performed during outlier detection
            (defaults to 1, as in the original FASTER paper).
        thres: float - The threshold value, in standard deviations, to apply. A channel
            crossing this threshold value is marked as bad. Defaults to 3.

    Returns:
        bad_channels: list - The names of the bad channels (empty when nothing is picked).
    """
    # Initialize picks corresponding to eeg channels.
    picks = mne.pick_types(data.info, eeg=True, exclude=[]) if picks is None else picks
    # No channels means nothing to score; avoid running outlier detection on an empty array.
    if len(picks) == 0: return []
    # Read only the picked channels instead of materializing every channel and then indexing.
    # The picks should cover a single channel type so the variances are comparable.
    data_picked = data.get_data(picks=picks)
    scores = np.var(data_picked, axis=-1)
    # Map outlier positions (within `picks`) back to channel names.
    bad_channels = [data.ch_names[picks[i]] for i in _find_outliers(scores, thres, max_iter)]
    # Return the final `bad_channels`.
    return bad_channels

# def set_ref_channels func
def set_ref_channels(data, ref_channels=None):
    """
    Re-reference the EEG data. The specification of `ref_channels` can be loaded from json file.
    Note: All these operations are operated on the original object to reduce memory consumption.

    Args:
        data: object - The loaded `mne.io.RawArray` object.
        ref_channels: list - The list of reference channel names. If None (default) or empty,
            the "average" reference is used.

    Returns:
        data: object - The modified `mne.io.RawArray` object.
    """
    # Normalise to a fresh list, so no mutable default is shared between calls.
    ref_channels = [] if ref_channels is None else list(ref_channels)
    # Check whether all specified `ref_channels` are in `data.info["ch_names"]`.
    if len(ref_channels) > 0:
        assert all(ref_channel_i in data.info["ch_names"] for ref_channel_i in ref_channels), (
            "ERROR: ref_channels ({}) are not all in data.info[\"ch_names\"] ({})."
        ).format(ref_channels, data.info["ch_names"])
    # Remove bad channels first, so they do not affect the average. Bad channels themselves are
    # not modified by the re-referencing.
    data.set_eeg_reference(ref_channels=ref_channels if len(ref_channels) > 0 else "average")
    # Return the final `data`.
    return data

# def detect_bad_segments func
def detect_bad_segments(data, detect_params):
    """
    Detect bad segments and mark them with "bad-segment" annotations. Such annotations start
    with "bad", so data overlapping them is rejected by functions that expose a
    `reject_by_annotation` option.

    Args:
        data: object - The loaded `mne.io.RawArray` object; `data.first_time` must be 0.
        detect_params: DotDict - The parameters of bad segments detection, containing
            [scale_threshold,duration_threshold,window_freq].
            scale_threshold: float - The threshold to initialize bad points, passed as
                `threshold` to `_detect_bad_segments`.
            duration_threshold: float - Bad segments shorter than this (unit: second) are dropped.
            window_freq: float - The reciprocal of the duration of convolution window (unit: second).

    Returns:
        data: object - The modified `mne.io.RawArray` object (existing annotations are kept).
    """
    assert data.first_time == 0.
    # Length of the convolution window, in samples.
    window_len = int(np.ceil(data.info["sfreq"] / detect_params.window_freq))
    # Detect bad segments according to eeg channels.
    # mask - (seq_len,), True where at least one EEG channel is flagged.
    mask = _detect_bad_segments(data.get_data(picks=["eeg",]),
        threshold=detect_params.scale_threshold, window_len=window_len)
    # Run-length encode the True runs of `mask`. Padding with False on both ends gives every run
    # a rising (+1) and a falling (-1) edge, and makes an empty `mask` yield no segments.
    # bad_segments_start - (n_bad_segments,); bad_segments_duration - (n_bad_segments,)
    edges = np.diff(np.concatenate(([False], mask, [False])).astype(np.int8))
    bad_segments_start = np.flatnonzero(edges == 1)
    bad_segments_duration = np.flatnonzero(edges == -1) - bad_segments_start
    # Drop bad segments that are too short (durations converted from samples to seconds).
    duration_scale = 1. / data.info["sfreq"]
    too_short = (bad_segments_duration * duration_scale) < detect_params.duration_threshold
    bad_segments_start = bad_segments_start[~too_short]
    bad_segments_duration = bad_segments_duration[~too_short]
    # Build annotations in seconds. Onsets are shifted by `data.first_time`, which is 0 here
    # per the assertion above.
    annotations = mne.Annotations(
        onset=bad_segments_start * duration_scale + data.first_time,
        duration=bad_segments_duration * duration_scale,
        description=["bad-segment" for _ in bad_segments_start],
        orig_time=data.annotations.orig_time
    )
    # Add the bad-segment annotations while keeping the old ones.
    # Note: The following warning could happen, but that doesn't affect the effect of annnotationns.
    # >>> RuntimeWarning: Omitted 1 annotation(s) that were outside data range.
    data.set_annotations(data.annotations + annotations)
    # Log information related to the duration of each bad segment.
    print((
        "INFO: {:d} bad segments are detected, with [start, duration]-pairs as {}."
    ).format(len(bad_segments_start), list(zip(
        np.round(annotations.onset, decimals=4),
        np.round(annotations.duration, decimals=4)))))
    # Return the final `data`.
    return data

def _detect_bad_segments(X, threshold=10., window_len=10):
    """
    Detect bad segments according to the magnitude of signal.

    Args:
        X: (n_channels, seq_len) - The recorded meg data (e.g. eeg), which is getted through `data.get_data()`.
            But only one type of meg data is picked, e.g. `X` cannot include different channel types.
        threshold: float - The threshold to initialize bad points (above is rejected) according to `np.abs(X) / X_mad`.
        window_len: int - The length of convolution window, should be relevant to `data.info["sfreq"]`.

    Return:
        mask: (seq_len,) - The mask of rejected bad segments.
    """
    # Calculate the median absolute deviation `X_mad` of `X`.
    # X_mad - (1, seq_len)
    X_mad = np.median(np.abs(X - np.median(X, axis=0, keepdims=True)), axis=0, keepdims=True)
    # Initialize `mask` according to `threshold`.
    # mask - (n_channels, seq_len)
    mask = np.ones_like(X); mask[(np.abs(X) / X_mad) > threshold] = np.nan
    # Execute convolution to create bad segments (each bad point creates `2 * window_len - 1`-wise bad segment).
    # Note: We set the convolution kernel as a simple (1, window_len) ones vector, e.g. the
    # convolution will only be executed along the `seq_len` axix for each channel separately.
    mask = nd.convolve(mask, np.ones((1, window_len)), mode="constant", cval=0.) / window_len
    # Aggregate bad segments detected from each channel, using or logic.
    # mask - (seq_len,)
    mask = np.sum(np.isnan(mask), axis=0) > 0
    # Return the final `mask`.
    return mask

# def run_ssp func
def _show_or_save_figures(figs, path_output, fname_stem):
    """
    Show or save one matplotlib figure or a list of figures produced by an MNE plotting function.

    Args:
        figs: object or list - The figure, or list of figures, to display/save.
        path_output: path - The output directory. If `None`, the figures are shown. Otherwise each figure is
            saved as `<fname_stem>.pdf` (single figure) or `<fname_stem>.<idx>.pdf` (list of figures) and
            then closed, so that saved figures do not accumulate in memory.
        fname_stem: str - The file name without the `.pdf` suffix.
    """
    if path_output is None:
        for fig_i in (figs if isinstance(figs, list) else [figs,]): fig_i.show()
    elif isinstance(figs, list):
        for fig_idx, fig_i in enumerate(figs):
            fig_i.savefig(os.path.join(path_output, "{}.{:d}.pdf".format(fname_stem, fig_idx))); plt.close(fig_i)
    else:
        figs.savefig(os.path.join(path_output, "{}.pdf".format(fname_stem))); plt.close(figs)

def _select_projs_by_explained_var(projs, explained_var_thres):
    """
    Keep the shortest leading subset of `projs` whose cumulative explained variance reaches `explained_var_thres`.

    Args:
        projs: list - The SSP projectors (in the order returned by MNE), each with an "explained_var" entry.
        explained_var_thres: float - The target cumulative explained variance.

    Returns:
        selected_projs: list - The leading projectors that were kept.

    Raises:
        ValueError: If no projector gets selected, or all projectors together do not reach `explained_var_thres`.
    """
    selected_projs = []; explained_var_sum = 0.
    for proj_i in projs:
        # Stop as soon as the target is reached (`>=`, so an exact hit is accepted).
        if explained_var_sum >= explained_var_thres: break
        selected_projs.append(proj_i); explained_var_sum += proj_i["explained_var"]
    if len(selected_projs) == 0 or explained_var_sum < explained_var_thres:
        raise ValueError((
            "ERROR: Cannot select ECG projectors reaching explained variance {} "
            "({:d} projectors available, {:d} selected, reaching {:.3f})."
        ).format(explained_var_thres, len(projs), len(selected_projs), explained_var_sum))
    return selected_projs

def run_ssp(data, ssp_params, path_output=None):
    """
    Repair ECG artifacts using SSP.

    Args:
        data: object - The loaded `mne.io.RawArray` object.
        ssp_params: DotDict - The parameters of signal-space projection (ssp), containing [explained_var_ecg,].
            explained_var_ecg: float - The cumulative explained variance that the kept ECG projectors must reach.
        path_output: path - The path of output directory. If `None`, figures are shown instead of saved.

    Returns:
        data: object - The modified `mne.io.RawArray` object (all previous projectors are replaced by the
            selected ECG projectors, which are then applied).

    Raises:
        ValueError: If the computed ECG projectors cannot reach `ssp_params.explained_var_ecg`.
    """
    # Initialize the number of eeg channels.
    n_eeg = len(data.get_channel_types("eeg"))
    # Remove heartbeats from EEG data. We first visualize the ECG artifact before trying to repair it.
    # We have to ensure that there are large deflections at the onset of ECG events.
    # TODO: We get more than one ECG channels:
    # >>> RuntimeWarning: More than one ECG channel found. Using only LECG.
    ecg_evoked = mne.preprocessing.create_ecg_epochs(data).average(picks="all"); ecg_evoked.apply_baseline((None, None))
    with _suppress_logger(): ecg_erp_plot = ecg_evoked.plot_joint(show=False)
    _show_or_save_figures(ecg_erp_plot, path_output, "ecg.erp")
    # Compute SSP projectors for the heartbeat artifact (up to `n_eeg` EEG projectors are requested).
    # `mne.preprocessing.compute_proj_ecg` will:
    # 1) Filter the ECG data channel.
    # 2) Find ECG R wave peaks using `mne.preprocessing.find_ecg_events`.
    # 3) Filter the raw data.
    # 4) Create `Epochs` around the R wave peaks, capturing the heartbeats.
    # 5) Optionally average the `Epochs` to produce an `Evoked` if `average=True` was passed (default).
    # 6) Calculate SSP projection vectors on that data to capture the artifacts.
    # Notes on the arguments:
    # 1) `tmin` (-0.2) & `tmax` (0.4) are the default values, used when creating `Epochs`.
    # 2) `l_freq` & `h_freq` are `None`, as we have already done `filter_spectrum`.
    # 3) `average` keeps its default (True), i.e. SSP is computed on the averaged `Evoked`.
    # 4) `reject` is `None` to avoid any rejection.
    # 5) `no_proj` is True to only return projectors that remove ECG artifacts.
    # 6) `ecg_l_freq` & `ecg_h_freq` keep their defaults, as we are not sure whether changing
    #    their values would affect the detection of ECG events.
    # TODO: We get more than one ECG channels:
    # >>> RuntimeWarning: More than one ECG channel found. Using only LECG.
    ecg_projs_eeg, _ = mne.preprocessing.compute_proj_ecg(data, n_eeg=n_eeg,
        tmin=-0.2, tmax=0.4, l_freq=None, h_freq=None, reject=None, no_proj=True)
    # Note: In normal cases, we only have to keep 2~3 SSP projectors (~80% explained variance).
    ecg_projs = _select_projs_by_explained_var(ecg_projs_eeg, ssp_params.explained_var_ecg)
    # Visualize the scalp distribution.
    with _suppress_logger(): ecg_projs_topomap_plot = mne.viz.plot_projs_topomap(ecg_projs, info=data.info, show=False)
    _show_or_save_figures(ecg_projs_topomap_plot, path_output, "ecg.projs.topomap")
    # Do a joint plot of the projectors and their effect on the time-averaged epochs: data traces before/after
    # projection, the projector topomaps, and the traces projected onto the projectors (with the ECG trace).
    with _suppress_logger():
        ecg_projs_joint_plot = mne.viz.plot_projs_joint(ecg_projs, ecg_evoked, picks_trace="ecg", show=False)
    ecg_projs_joint_plot.suptitle("ECG projectors with explained vars {}.".format(
        ["{:.3f}".format(ecg_proj_i["explained_var"]) for ecg_proj_i in ecg_projs]
    ))
    _show_or_save_figures(ecg_projs_joint_plot, path_output, "ecg.projs.joint")
    # SSP removes components rigidly; EOG is not a clean pulse like ECG, so EOG is not removed with SSP (to avoid
    # affecting other components). ICA is used for EOG instead.
    # Remove all previous projs, then add & apply the selected ECG projs.
    data.del_proj().add_proj(ecg_projs, remove_existing=False).apply_proj()
    # Return the final `data`.
    return data

# def run_ica func
def run_ica(data, ica_params, path_output=None):
    """
    Repair signals using ICA, either by a fresh automatic run or from a previous manual labeling.

    Args:
        data: object - The loaded `mne.io.RawArray` object.
        ica_params: DotDict - The parameters of independent components analysis (ica). `rerun` selects the mode;
            for a first run the remaining fields [n_components,seed,corr_thres,markers] are consumed by `_run_ica`.
            rerun: bool - If False, fit ICA and select components automatically (`_run_ica`). If True, re-apply
                the ICA saved in `path_output` according to the manually sorted `keep`/`remove` figures (`_rerun_ica`).
        path_output: path - The path of output directory.

    Returns:
        data: object - The modified `mne.io.RawArray` object.
    """
    if ica_params.rerun:
        # Components have already been labeled manually: re-apply the saved ICA solution.
        return _rerun_ica(data, path_output)
    # First run: fit ICA and select components automatically.
    return _run_ica(data, ica_params, path_output=path_output)

# def _run_ica func
def _find_correlated_components(ica_sources, corr_dict, corr_thres):
    """
    Find ICA components that are strongly correlated with any reference channel.

    Args:
        ica_sources: (n_components, seq_len) - The ICA source time courses.
        corr_dict: dict - Maps a channel type (e.g. "ecg") to its (n_channels, seq_len) recordings.
        corr_thres: float - The threshold on the absolute Pearson correlation.

    Returns:
        remove_idxs: list - The indices of components whose |r| with any reference channel exceeds `corr_thres`.
    """
    # ICA sources have an arbitrary sign, so a strongly anti-correlated component is as artifactual as a
    # strongly correlated one; therefore the absolute correlation is compared against `corr_thres`.
    ref_signals = [ref_i for refs_i in corr_dict.values() for ref_i in refs_i]
    remove_idxs = []
    for ica_idx, source_i in enumerate(ica_sources):
        corrs_i = np.array([pearsonr(ref_i, source_i)[0] for ref_i in ref_signals], dtype=np.float64)
        if np.any(np.abs(corrs_i) > corr_thres): remove_idxs.append(ica_idx)
    return remove_idxs

def _get_marker_events(data, markers=None):
    """
    Convert the events returned by `get_events(data)` into an (n_events, 3) int64 array of
    [sample index, 0, marker]. If `markers` is not `None`, only events whose marker is one of
    `markers.values()` are kept (the result keeps its (n_events, 3) shape even when empty).
    """
    events = get_events(data)
    onsets = data.time_as_index(np.array([event_i.onset for event_i in events], dtype=np.float64), use_rounding=True)
    events = np.stack([
        np.asarray(onsets, dtype=np.int64),
        np.zeros(len(events), dtype=np.int64),
        np.array([event_i.marker for event_i in events], dtype=np.int64),
    ], axis=0).T
    if markers is not None:
        events = events[np.isin(events[:, -1], list(markers.values()))]
    return np.ascontiguousarray(events, dtype=np.int64)

def _run_ica(data, ica_params, path_output=None):
    """
    Automatically repair signals using ICA with the ICLabel model. ICLabel is designed to classify ICs
    fitted with an extended `infomax` ICA decomposition algorithm on EEG datasets referenced to a common
    average and filtered between [1., 100.] Hz. It is possible to run ICLabel on datasets that do not
    meet those specifications, but the classification performance might be negatively impacted.
    Note: We use the `fastica` ICA decomposition algorithm and do not enforce the [1., 100.]-Hz filtering
    before `run_ica`.
    Note: Re-referencing and the final ICA application operate on `data` itself; only the eeg channels are
    copied to fit ICA.

    Components are removed when ICLabel's most likely label is neither "brain" nor "other", or when their
    absolute Pearson correlation with any ECG/EOG/EMG channel exceeds `ica_params.corr_thres`.

    Args:
        data: object - The loaded `mne.io.RawArray` object.
        ica_params: DotDict - The parameters of independent components analysis (ica), containing
            [n_components,seed,corr_thres,markers].
            n_components: int, float or None - Passed to `mne.preprocessing.ICA(n_components=...)`.
            seed: int - The random seed passed as `random_state` to ensure reproducibility.
            corr_thres: float - The threshold on the absolute correlation with ECG/EOG/EMG channels.
            markers: dict or None - If not `None`, only events whose marker is one of its values are used to plot ERPs.
        path_output: path - The path of output directory. If `None`, figures are shown and nothing is saved.

    Returns:
        data: object - The modified `mne.io.RawArray` object.
    """
    ## Prepare ICA.
    # Apply a common average reference, to comply with the ICLabel requirements.
    # Note: `set_eeg_reference` is different from the implementation in matlab. Every time we use it, the value
    # of each channel is re-calculated, so no extra parameter is needed to handle repeated calls.
    data = data.set_eeg_reference("average")
    # Collect the recordings of reference channels (ecg & eog & emg) into `corr_dict`. Channel types without any
    # channel selected by `mne.pick_types` are skipped.
    corr_dict = DotDict()
    for ch_type_i in ("ecg", "eog", "emg"):
        picks_i = mne.pick_types(data.info, **{ch_type_i: True})
        if len(picks_i) > 0: corr_dict[ch_type_i] = data.get_data(picks=picks_i)
    # Note: We only fit ICA on a copy of the eeg channels.
    eeg_channels = [data.ch_names[ch_idx] for ch_idx in mne.pick_types(data.info, eeg=True)]
    data_eeg = data.copy().pick_channels(eeg_channels)
    ## Run ICA.
    # 1) ICA fitting is not deterministic (components may get a sign flip or a different order on different runs),
    #    so a fixed random seed is used to get identical results each time.
    # 2) ICLabel recommends extended `infomax`; its accuracy should not depend much on the algorithm, and `fastica`
    #    is much faster. `max_iter="auto"` sets the maximum iterations to 1000 for `fastica`.
    # 3) ICA only considers data channels (e.g. eeg), not [eog, stim, ...] channels. When `n_components` is None,
    #    MNE keeps the non-zero PCA components (the average reference removes one degree of freedom).
    ica_eeg = mne.preprocessing.ICA(n_components=ica_params.n_components,
        max_iter="auto", method="fastica", random_state=ica_params.seed)
    ica_eeg.fit(data_eeg)
    # Save ICA objects into specified `path_output`, so that `_rerun_ica` can reload them.
    if path_output is not None:
        ica_eeg.save(os.path.join(path_output, "data-eeg-ica.fif"), overwrite=True)
    # Select ICA components automatically with ICLabel, which assigns each component one of
    # [brain, muscle artifact, eye blink, heart beat, line noise, channel noise, other].
    # Note: `icl_eeg` only contains [y_pred_proba, labels], i.e. the probability of the max-likely label only, so
    # per-label thresholds are not possible. Keeping every component whose max-likely label is `brain` (or `other`)
    # saves more brain-related signal.
    icl_eeg = mneicl.label_components(data_eeg, ica_eeg, method="iclabel")
    remove_eeg_icl = [ica_idx for ica_idx, label in enumerate(icl_eeg["labels"]) if label not in ["brain", "other"]]
    # Also remove components that correlate (in absolute value) with ECG/EOG/EMG channels.
    ica_eeg_sources = ica_eeg._transform_raw(data_eeg, start=None, stop=None)
    remove_eeg_corr = _find_correlated_components(ica_eeg_sources, corr_dict, ica_params.corr_thres)
    # Merge both criteria into the final keep & remove lists.
    remove_eeg = sorted(set(remove_eeg_icl) | set(remove_eeg_corr))
    keep_eeg = sorted(set(range(ica_eeg.n_components_)) - set(remove_eeg))
    print((
        "INFO: For eeg data, keep {} ICA components, remove {} ICA components."
    ).format(keep_eeg, remove_eeg))
    # Create directories to store the figures of keep & remove ica components (used for manual relabeling).
    path_output_keep = path_output_remove = None
    if path_output is not None:
        path_output_keep = os.path.join(path_output, "keep"); os.makedirs(path_output_keep, exist_ok=True)
        path_output_remove = os.path.join(path_output, "remove"); os.makedirs(path_output_remove, exist_ok=True)
    # Plot ica properties & runs jointly (best-effort: plotting failures are reported but do not stop the pipeline).
    events = _get_marker_events(data, ica_params.markers)
    try:
        with _suppress_logger():
            ica_properties_eeg_plots = plot_ica_joint(ica_eeg, data_eeg, events,
                icl=icl_eeg, corr_dict=corr_dict, picks=None, show=False)
        for ica_idx, ica_plot_i in enumerate(ica_properties_eeg_plots):
            if path_output is None:
                ica_plot_i.show(); continue
            path_output_i = path_output_keep if ica_idx in keep_eeg else path_output_remove
            ica_plot_i.savefig(os.path.join(path_output_i, "eeg.{:d}.png".format(ica_idx)))
            # Close each saved figure, otherwise one large figure per component stays open in memory.
            plt.close(ica_plot_i)
    except Exception as e:
        print("ERROR: Cannot plot ica joint figure, due to {}.".format(e))
    ## Apply ICA.
    # `include` & `exclude` are complementary here: components in `remove_eeg` are zeroed out when re-mixing.
    # Only the channels `ica_eeg` was fitted on (eeg) are re-mixed.
    data = ica_eeg.apply(data, include=keep_eeg, exclude=remove_eeg)
    # Return the final `data`.
    return data

# def plot_ica_joint func
def plot_ica_joint(ica, data, events, icl=None, corr_dict=None, picks=None, show=True, show_bad_segments=False):
    """
    Display component properties & runs jointly. For each picked component, one figure is created.
    Its left part holds the `ica.plot_properties` panels (topography, epochs image, ERP/ERF, power
    spectrum, and epoch variance). Its right part shows the component's time course in each run.

    Args:
        ica : object - The fitted `mne.preprocessing.ICA` solution.
        data : object - The loaded `mne.io.RawArray` object. NOTE(review): presumably contains the channels
            `ica` was fitted on (the "eeg" explained-variance entry is read below) — confirm.
        events: (n_events, 3) - The events used to epoch `data` (tmin=-0.2, tmax=0.8) for the ERP panels.
        icl: dict - The fitted ica labels, with "labels" & "y_pred_proba" entries indexed by component.
        corr_dict: dict - Maps a name (e.g. "ecg") to an array whose rows are correlated with each component's
            source over the whole `data`; the (r, p) pairs are written below the topomap.
        picks: list - The list of picked components, all components by default.
        show: bool - Passed as `show` to `ica.plot_properties`.
        show_bad_segments: bool - Whether to draw "bad-segment" annotations (in red) in the run panels.

    Returns:
        plots: list - The list of matplotlib figures, one per picked component.
    """
    # Initialize `plots` & `picks`.
    plots = []; picks = np.arange(ica.n_components_) if picks is None else picks
    # Initialize `ica_sources` from `ica` & `data`, `epochs` from `data` & `events`.
    # Note: We cannot calculate `ica_sources` by matrix multiplying `ica_components` and `data`!
    # TODO: It seems `ica.get_sources` has some bugs in itself, sometimes we cannot get the full ica data!
    # TODO: `epochs` maybe empty, due to `reject_by_annotation`, may cause errors.
    ica_sources = ica._transform_raw(data, start=None, stop=None)
    epochs = mne.Epochs(data, events, tmin=-0.2, tmax=0.8, preload=True)
    # Initialize blocks from annotations: each annotation whose description starts with "block-" defines one run
    # as [onset, onset + duration]. Without such annotations, the recording is split into 6 equal-length runs.
    # Only the first `run_tmax` seconds of each run are plotted.
    # NOTE(review): `crop` is defined elsewhere in this file; the fallback ranges add `data.first_time`, so they are
    # assumed to use the same absolute time base as annotation onsets — confirm.
    block_ranges = [(annotation_i["onset"], annotation_i["onset"] + annotation_i["duration"])\
        for annotation_i in data.annotations if annotation_i["description"].startswith("block-")]
    n_runs = len(block_ranges) if len(block_ranges) > 0 else 6; run_tmax = 50.
    blocks = [crop(data.copy(), crop_ranges=[block_range_i,]) for block_range_i in block_ranges]\
        if len(block_ranges) > 0 else [crop(data.copy(), crop_ranges=[[
            data.tmin + (run_idx / n_runs) * (data.tmax - data.tmin) + data.first_time,
            data.tmin + ((run_idx + 1) / n_runs) * (data.tmax - data.tmin) + data.first_time,
        ],]) for run_idx in range(n_runs)]
    # TODO: It seems `ica.get_sources` has some bugs in itself, sometimes we cannot get the full ica data!
    ica_sources_blocks = [ica._transform_raw(block_i, start=None, stop=None) for block_i in blocks]
    # Initialize bad segments of each block: "bad-segment" annotations shifted by the block's `first_time`,
    # intended to align with `block_i.times` (assumes onsets share the `first_time` time base — confirm).
    bad_segments = [[(
        annotation_i["onset"] - block_i.first_time,
        annotation_i["onset"] - block_i.first_time + annotation_i["duration"]
    ) for annotation_i in block_i.annotations if annotation_i["description"] == "bad-segment"] for block_i in blocks]
    # Unless `show_bad_segments`, ignore bad segments so the whole block is drawn as "other" (green).
    # `other_segments` is the complement of the bad segments within [block_i.tmin, block_i.tmax]; this
    # construction assumes the bad segments of a block are sorted by onset and non-overlapping.
    bad_segments = [[] for _ in blocks] if not show_bad_segments else bad_segments; other_segments = [[] for _ in blocks]
    for block_idx, (block_i, bad_segments_i) in enumerate(zip(blocks, bad_segments)):
        if len(bad_segments_i) == 0:
            other_segments[block_idx].append([block_i.tmin, block_i.tmax])
        else:
            if bad_segments_i[0][0] > block_i.tmin: other_segments[block_idx].append([block_i.tmin, bad_segments_i[0][0]])
            for bad_segment_idx in range(len(bad_segments_i) - 1):
                other_segments[block_idx].append([bad_segments_i[bad_segment_idx][1], bad_segments_i[bad_segment_idx+1][0]])
            if bad_segments_i[-1][1] < block_i.tmax: other_segments[block_idx].append([bad_segments_i[-1][1], block_i.tmax])
    # Draw joint plots, including properties & runs.
    for pick_i in picks:
        # Initialize `plot_i` & `axes_i` of the joint plot: "ica-properties-*" axes for `plot_properties`,
        # plus one "ica-runs-*" axis per block.
        plot_i, axes_i = _create_properties_layout(n_runs=len(blocks))
        # Plot ica properties of specified `pick_i`.
        ica.plot_properties(epochs, picks=[pick_i,], show=show,
            axes=[ax_i for ax_i in axes_i if ax_i.get_label().startswith("ica-properties")])
        # Plot ica runs of specified `pick_i`.
        axes_runs = [ax_i for ax_i in axes_i if ax_i.get_label().startswith("ica-runs")]
        assert len(axes_runs) == len(blocks) == len(ica_sources_blocks) == len(bad_segments) == len(other_segments)
        for ax_i, block_i, ica_sources_block_i, bad_segments_i, other_segments_i in\
            zip(axes_runs, blocks, ica_sources_blocks, bad_segments, other_segments):
            # Initialize ica component of each run & time.
            x_i = block_i.times; y_i = ica_sources_block_i[pick_i,:]
            # Truncate `x_i` & `y_i`, i.e. we only plot the first `run_tmax` seconds.
            y_i = y_i[x_i < run_tmax]; x_i = x_i[x_i < run_tmax]
            # Plot other segments (green) & bad segments (red) separately.
            for other_segment_i in other_segments_i:
                mask_i = np.logical_and(x_i >= other_segment_i[0], x_i <= other_segment_i[1])
                sns.lineplot(x=x_i[mask_i], y=y_i[mask_i], color="green", ax=ax_i, linewidth=.2)
            for bad_segment_i in bad_segments_i:
                mask_i = np.logical_and(x_i >= bad_segment_i[0], x_i <= bad_segment_i[1])
                sns.lineplot(x=x_i[mask_i], y=y_i[mask_i], color="red", ax=ax_i, linewidth=.2)
            # Set the title of specified run: the explained variance ratio of this component (eeg channels) in this block.
            ax_i.set_title((
                "Explained Variance Ratio: {:.2f}%"
            ).format(ica.get_explained_variance_ratio(block_i, components=[pick_i,])["eeg"] * 100.))
        # Plot ica correlation of specified `pick_i`: the ICLabel label & probability (if given), then for each
        # `corr_dict` entry the signed Pearson (r, p) pairs against the component over the whole `data`.
        if corr_dict is not None:
            axes_text_corr = [ax_i for ax_i in axes_i if ax_i.get_label() == "ica-properties-topomap"]
            assert len(axes_text_corr) == 1; ax_text_corr = axes_text_corr[0]
            text_i = "" if icl is None else "{}: {:.3f}\n".format(icl["labels"][pick_i], icl["y_pred_proba"][pick_i])
            for key_i, vals_i in corr_dict.items():
                corr_pairs_i = [tuple(pearsonr(val_i, ica_sources[pick_i,:])) for val_i in vals_i]
                text_i += "{}: ".format(key_i.upper())
                for corr_pair_i in corr_pairs_i:
                    text_i += "({:.3f},{:.2f}),".format(corr_pair_i[0], corr_pair_i[1])
                text_i += "\n"
            ax_text_corr.text(-0.15, -0.2, text_i)
        plots.append(plot_i)
    return plots

def _create_properties_layout(figsize=None, n_runs=6):
    """
    Build the figure and axes grid used by `plot_ica_joint`.

    The left 40% of the figure hosts the five ICA-property axes, in the order
    [topomap, image, erp, spectrum, variance]. The right 60% hosts `n_runs` horizontal strips,
    stacked from top (run 0) to bottom.

    Args:
        figsize: tuple - The (width, height) of the figure, `[20, 12]` by default.
        n_runs: int - The number of block runs during one session, i.e. the number of run axes.

    Returns:
        fig: object - The created matplotlib figure.
        axes: list - The created axes, labelled "ica-properties-<panel>" or "ica-runs-<idx>".
    """
    if figsize is None:
        figsize = [20, 12]
    fig = plt.figure(figsize=figsize, facecolor=[0.95] * 3)

    def _region(x_bounds, y_bounds):
        # Rows: [left, right, width] and [bottom, top, height], in figure coordinates.
        bounds = np.array((x_bounds, y_bounds), dtype=np.float32)
        return np.concatenate([bounds, np.diff(bounds, axis=-1)], axis=-1)

    props = _region((0., 0.4), (0., 1.))
    runs = _region((0.4, 1.), (0., 1.))

    def _props_rect(x_frac, y_frac, w_frac, h_frac):
        # Map fractions of the properties region to a (left, bottom, width, height) rectangle.
        return [
            props[0, 0] + x_frac * props[0, 2],
            props[1, 0] + y_frac * props[1, 2],
            w_frac * props[0, 2], h_frac * props[1, 2],
        ]

    layout = [
        ("ica-properties-topomap", _props_rect(0.08, 0.5, 0.3, 0.45)),
        ("ica-properties-image", _props_rect(0.5, 0.6, 0.45, 0.35)),
        ("ica-properties-erp", _props_rect(0.5, 0.5, 0.45, 0.1)),
        ("ica-properties-spectrum", _props_rect(0.08, 0.1, 0.32, 0.3)),
        ("ica-properties-variance", _props_rect(0.5, 0.1, 0.45, 0.25)),
    ]
    # Each run gets an equal slice of the runs region; the axis fills the middle 60% of its slice.
    run_height = 1 / n_runs * runs[1, 2]
    for run_idx in range(n_runs):
        slot = n_runs - 1 - run_idx
        layout.append(("ica-runs-{:d}".format(run_idx), [
            runs[0, 0] + 0.05 * runs[0, 2],
            runs[1, 0] + slot * run_height + 0.2 * run_height,
            0.9 * runs[0, 2], 0.6 * run_height,
        ]))
    axes = [fig.add_axes(rect, label=label) for label, rect in layout]
    return fig, axes

# def _read_ica_component_indices func
def _read_ica_component_indices(path_dir, pattern):
    """
    Collect the ICA component indices encoded in the file names of one directory.

    Args:
        path_dir: path - The directory whose entries are scanned (e.g. the `keep` or `remove` directory).
        pattern: re.Pattern - The compiled pattern. It is matched at the start of each file name,
            and its first capture group must hold the component index.

    Returns:
        indices: list - The sorted component indices of all matching file names.
    """
    indices = []
    for fname_i in os.listdir(path_dir):
        match_i = pattern.match(fname_i)
        # Entries that do not follow the naming scheme are ignored.
        if match_i is not None:
            indices.append(int(match_i.group(1)))
    # `os.listdir` order is arbitrary; sort to make downstream output deterministic.
    return sorted(indices)

# def _rerun_ica func
def _rerun_ica(data, path_output):
    """
    Rerun ICA after labeling components manually, i.e. move png from keep to remove.

    `path_output` is expected to contain `data-eeg-ica.fif` (the saved ICA object) and the two
    directories `keep` & `remove`. Their file names start with `eeg`, one separator character
    and the component index (e.g. `eeg.3.png`).

    Args:
        data: object - The `mne.io.RawArray` object to clean.
        path_output: path - The ICA output directory of a previous run.

    Returns:
        data: object - The `mne.io.RawArray` object after applying ICA.

    Raises:
        AssertionError: If the labeled components do not add up to the fitted component count,
            or if a component is labeled both keep & remove.
    """
    ## Prepare ICA.
    # Read ICA objects from specified `path_output`.
    ica_eeg = mne.preprocessing.read_ica(os.path.join(path_output, "data-eeg-ica.fif"))
    # Read the indices of keep & remove components. Use a raw string so `\d` is a regex class, not an
    # (invalid) string escape. The `.` is intentionally left as a wildcard to stay compatible with
    # existing file names.
    pattern_eeg = re.compile(r"eeg.(\d+)")
    keep_eeg = _read_ica_component_indices(os.path.join(path_output, "keep"), pattern_eeg)
    remove_eeg = _read_ica_component_indices(os.path.join(path_output, "remove"), pattern_eeg)
    assert len(keep_eeg) + len(remove_eeg) == ica_eeg.n_components_, (
        "ERROR: Labeled {:d} keep + {:d} remove ICA components, but the ICA has {} components."
    ).format(len(keep_eeg), len(remove_eeg), ica_eeg.n_components_)
    # A component labeled twice would be silently resolved by `ica.apply()`; reject it instead.
    conflict_eeg = sorted(set(keep_eeg) & set(remove_eeg))
    assert not conflict_eeg, (
        "ERROR: ICA components {} are labeled both keep & remove."
    ).format(conflict_eeg)
    ## Re-run ICA.
    # `ica.apply()` re-mixes the ICA sources back into channel space after zeroing out components:
    #  - include: components to keep. If `None`/empty, all components are kept except those listed in
    #    `ica.exclude` and in the `exclude` argument.
    #  - exclude: components to zero out, in addition to `ica.exclude` (the union of both is removed).
    # NOTE(review): in MNE a non-empty `include` determines the kept components on its own, which is why
    # `keep_eeg` & `remove_eeg` are checked to be disjoint above. MNE documents that `apply` modifies
    # `data` in place and returns it — confirm against the pinned MNE version.
    data = ica_eeg.apply(data, include=keep_eeg, exclude=remove_eeg)
    print((
        "INFO: For eeg data, keep {} ICA components, remove {} ICA components."
    ).format(keep_eeg, remove_eeg))
    # Return the final `data`.
    return data

## Data preprocess package.
# def preprocess func
def preprocess(data, preprocess_params, path_output=None):
    """
    The whole pipeline to preprocess eeg data of specified session.

    Steps run in a fixed order: channel types -> montage -> crop -> filter -> resample ->
    bad channels -> re-reference -> bad-segment annotation -> SSP (optional) -> ICA,
    followed by example PSD plots.

    Args:
        data: object - The loaded `mne.io.RawArray` object.
        preprocess_params: DotDict - The parameters of data preprocessing. Read keys:
            `channel_types`, `allowed_channel_types`, `path_montage`, `crop_ranges`, `filter`,
            `resample_freq`, `bad_channels`, `ref_channels`, `bad_segments_detection`,
            `ssp` (`None` skips SSP) and `ica`.
        path_output: path - The path of output directory. When given, the `ssp` & `ica`
            sub-directories are created inside it (if missing). `None` is forwarded to
            `run_ssp` & `run_ica`, and makes `show_examples` show figures instead of saving them.

    Returns:
        data: object - The modified `mne.io.RawArray` object.
    """
    ## Data initialization.
    # Initialize the type of each channel. `preprocess_params.channel_types` is `None` by default,
    # and should be modified when loading from json config files of one eeg run.
    data = init_channel_types(data, channel_types=preprocess_params.channel_types,
        allowed_channel_types=preprocess_params.allowed_channel_types)
    # Initialize the locations of channels.
    data = init_channel_montage(data, preprocess_params.path_montage)
    ## Data preprocess.
    # Crop data to remove unused data segments & channels.
    data = crop(data, crop_ranges=preprocess_params.crop_ranges)
    # Filter parts of signals' spectrum.
    data = filter_spectrum(data, preprocess_params.filter)
    # Downsample data to specified sample frequency.
    data = resample(data, resample_freq=preprocess_params.resample_freq)
    # Fix [pre-defined, detected] bad channels of data.
    data = fix_bad_channels(data, bad_channels=preprocess_params.bad_channels)
    # Set reference channels.
    data = set_ref_channels(data, ref_channels=preprocess_params.ref_channels)
    # Mark bad segments of data using annotations.
    data = detect_bad_segments(data, preprocess_params.bad_segments_detection)
    # Run SSP over data (only when SSP parameters are configured).
    if preprocess_params.ssp is not None:
        path_output_ssp = None
        if path_output is not None:
            path_output_ssp = os.path.join(path_output, "ssp")
            # `exist_ok` avoids the check-then-create race of an explicit `os.path.exists` test.
            os.makedirs(path_output_ssp, exist_ok=True)
        data = run_ssp(data, preprocess_params.ssp, path_output=path_output_ssp)
    # Run ICA over data.
    path_output_ica = None
    if path_output is not None:
        path_output_ica = os.path.join(path_output, "ica")
        os.makedirs(path_output_ica, exist_ok=True)
    data = run_ica(data, preprocess_params.ica, path_output=path_output_ica)
    ## Data visualization.
    # Show examples of preprocessed data.
    show_examples(data, path_output=path_output)
    # Return the final `data`.
    return data

## Data Visualization.
# def show_examples func
def show_examples(data, path_output=None):
    """
    Plot the power spectral density of `data`, once non-averaged and once averaged.

    Each figure is saved as a pdf into `path_output` when a directory is given, and shown
    otherwise. A failure while producing one figure is reported and does not stop the other.
    All open matplotlib figures are closed at the end.

    Args:
        data: object - The loaded `mne.io.RawArray` object.
        path_output: path - The path of output directory.

    Returns:
        None
    """
    # One entry per psd figure: (average flag, output file name, error message template).
    psd_specs = (
        (False, "psd.nonaverage.pdf", "ERROR: Cannot plot non-average psd figure, due to {}."),
        (True, "psd.average.pdf", "ERROR: Cannot plot average psd figure, due to {}."),
    )
    for average_flag, psd_fname, error_template in psd_specs:
        try:
            # Silence MNE's logger while the figure is being built.
            with _suppress_logger():
                psd_fig = data.plot_psd(average=average_flag, show=False)
            if path_output is not None:
                psd_fig.savefig(fname=os.path.join(path_output, psd_fname))
            else:
                psd_fig.show()
        except Exception as err:
            # Best-effort plotting: report the problem and continue with the next figure.
            print(error_template.format(err))
    # Close all plotted figures.
    plt.close("all")

if __name__ == "__main__":
    import time, json
    import copy as cp

    # Initialize base path & output path.
    base = os.path.join(os.getcwd(), os.pardir, os.pardir, os.pardir)
    path_output = os.path.join(os.getcwd(), "output.session")
    os.makedirs(path_output, exist_ok=True)
    # Initialize eeg data path & the corresponding preprocess parameters.
    path_run = os.path.join(base, "data", "eeg.anonymous", "005", "20221223")
    path_run_eeg = os.path.join(path_run, "eeg", "image-audio-pre.vhdr")
    path_run_montage = os.path.join(path_run, "standard-1020-cap64.locs")
    path_run_params = os.path.join(path_run, "preprocess_params.json")
    # Initialize the parameters of preprocess from the json config, then fill in every
    # parameter the config does not provide with its default value.
    with open(path_run_params, "r") as f:
        preprocess_params = DotDict(json.load(f))
    for key_i in default_preprocess_params.keys():
        if key_i not in preprocess_params.keys():
            # NOTE(review): defaults are assigned by reference, so in-place edits below (e.g.
            # `preprocess_params.ica.rerun`) also modify `default_preprocess_params` when the key
            # came from the defaults — harmless in this script, but confirm before reusing the pattern.
            setattr(preprocess_params, key_i, getattr(default_preprocess_params, key_i))
    preprocess_params.path_montage = path_run_montage
    preprocess_params.ica.rerun = False
    # Initialize the markers of each event type.
    preprocess_params.markers = DotDict({
        # The image/audio cue markers.
        "image": {"alarm": 4, "apple": 5, "ball": 16, "book": 17, "box": 20, "chair": 21, "kiwi": 32, "microphone": 33,
            "motorcycle": 36, "pepper": 37, "sheep": 48, "shoes": 49, "strawberry": 52, "tomato": 53, "watch": 64,},
        "audio": {"alarm": 65, "apple": 68, "ball": 69, "book": 80, "box": 81, "chair": 84, "kiwi": 85, "microphone": 96,
            "motorcycle": 97, "pepper": 100, "sheep": 101, "shoes": 112, "strawberry": 113, "tomato": 116, "watch": 117,},
        # The response markers.
        "response": {"resp_corr": 144, "resp_wrong": 145, "resp_none": 148,},
        # The block markers.
        "block": {"block_start": 208, "block_end": 209,},
        # The trial markers.
        "trial": {"trial_cross": 1,},
        # The unlabeled markers.
        "unlabeled": [128, 129, 133, 149],
    })
    # The markers of ica are image markers.
    preprocess_params.ica.markers = preprocess_params.markers.image
    print("INFO: The parameters of preprocess is initialized as {}.".format(preprocess_params))

    # Record the start time.
    time_start = time.time()
    # Read the raw data from `vhdr` file. `tmin` & `duration` (unit: second) select an optional
    # sub-segment below; `duration = None` keeps the whole recording.
    tmin = 0.
    duration = None
    data = load_data(path_run_eeg)
    # At the very first, we use block markers to identify block runs.
    # TODO: The following operation is special for task data.
    events = get_events(data)
    block_start_events = [event_i for event_i in events if event_i.marker == preprocess_params.markers.block.block_start]
    block_end_events = [event_i for event_i in events if event_i.marker == preprocess_params.markers.block.block_end]
    assert len(block_start_events) == len(block_end_events)
    # Each block spans from its start marker to the matching end marker: (onset, duration).
    block_events = [(block_start_event_i.onset, block_end_event_i.onset - block_start_event_i.onset)\
        for block_start_event_i, block_end_event_i in zip(block_start_events, block_end_events)]
    block_annotations = mne.Annotations(
        onset=[block_event_i[0] for block_event_i in block_events],
        duration=[block_event_i[1] for block_event_i in block_events],
        description=["block-{:d}".format(block_idx) for block_idx in range(len(block_events))],
        orig_time=data.annotations.orig_time
    )
    data.set_annotations(data.annotations + block_annotations)
    # Crop unused data segments, but actually we should use concatenate blocks.
    data = crop(data, crop_ranges=[[events[0].onset, events[-1].onset],])
    # Crop specified data segments from remained data.
    data = crop(data, crop_ranges=[[tmin+data.first_time, tmin+duration+data.first_time],]) if duration is not None else data
    # Use events of data to get `crop_ranges`.
    events = get_events(data)
    preprocess_params.crop_ranges = [[events[0].onset, events[-1].onset],]
    # Execute preprocess pipeline.
    data = preprocess(data, preprocess_params, path_output=path_output)
    # Record the whole time of preprocess.
    print("INFO: The total time of preprocessing data is {:.2f}s.".format(time.time()-time_start))

