#!/usr/bin/env python3
"""
Created on 23:59, Jun. 3rd, 2023

@author: Anonymous
"""
import re, mne
import json, yasa
import logging
# `os` must be bound at module level: the functions below call `os.path` / `os.makedirs`
# even when this file is imported (not only when it is run as a script).
import os
import contextlib
import copy as cp
import numpy as np
from bidict import bidict
import mne_icalabel as mneicl
from scipy.stats import pearsonr
from sortedcontainers import SortedSet
# local dep
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.path.join(os.pardir, os.pardir, os.pardir, os.pardir))
    from preprocess.eeg.anonymous import session
else:
    from preprocess.eeg.anonymous import session
from utils import DotDict
from utils.data import save_pickle

# Public entry points of this module.
__all__ = ["preprocess_task", "preprocess_tmr"]

# Initialize environment configuration: MNE only reports warnings and errors.
mne.set_log_level("WARNING")
# Define tool functions.
@contextlib.contextmanager
def _suppress_logger():
    """
    Raise the "mne" logger to ERROR level for the duration of a `with` block.

    The level the logger had on entry is restored on exit, whether the body
    finishes normally or raises.
    """
    mne_logger = logging.getLogger("mne")
    with contextlib.ExitStack() as stack:
        # Register the restore first, capturing the level as it is right now.
        stack.callback(mne_logger.setLevel, mne_logger.level)
        mne_logger.setLevel(logging.ERROR)
        yield

## Macros
# def default_task_preprocess_params macro
# Default parameters of the task sessions: a deep copy of `session.default_preprocess_params`
# extended with task-specific fields. `_preprocess_task` deep-copies this object and only takes
# the top-level keys that are missing from the run's `preprocess_params.json`.
default_task_preprocess_params = cp.deepcopy(session.default_preprocess_params)
default_task_preprocess_params.markers = DotDict({
    # The image/audio cue markers; the same 15 cue names map to 1-15 (image) and 101-115 (audio).
    "image": {"alarm": 1, "apple": 2, "ball": 3, "book": 4, "box": 5, "chair": 6, "kiwi": 7, "microphone": 8,
        "motorcycle": 9, "pepper": 10, "sheep": 11, "shoes": 12, "strawberry": 13, "tomato": 14, "watch": 15,},
    "audio": {"alarm": 101, "apple": 102, "ball": 103, "book": 104, "box": 105, "chair": 106, "kiwi": 107, "microphone": 108,
        "motorcycle": 109, "pepper": 110, "sheep": 111, "shoes": 112, "strawberry": 113, "tomato": 114, "watch": 115,},
    # The response markers.
    # NOTE(review): the keys spell "consistant"; they are runtime keys, so they are left unchanged.
    "response": {
        "resp_corr_consistant": 16, "resp_corr_inconsistant": 17,
        "resp_wrong_consistant": 18, "resp_wrong_inconsistant": 19,
        "resp_none_consistant": 20, "resp_none_inconsistant": 21,
    },
    # The block markers; the task data is cropped to the ranges between `block_start` and `block_end`.
    "block": {"block_start": 22, "block_end": 23,},
    # The trial markers; every `trial_cross` event opens a new trial.
    "trial": {"trial_cross": 25,},
    # The unlabeled markers.
    "unlabeled": [26, 151, 152, 153, 154, 201, 202, 203, 204, 205],
})
# Epoch windows (tmin, tmax) in seconds, keyed by the cue order of the session type. Entry k is
# used for the k-th cue, e.g. for "image-audio" the image window is [0] and the audio window is [1].
# `tmax` is treated as exclusive: one sample period is subtracted from it when epoching.
default_task_preprocess_params.epochs = DotDict({
    "image-audio": [(-0.2, 0.8+0.3), (-0.2, 0.8)],
    "audio-image": [(-0.2, 0.5+0.3), (-0.2, 0.8)],
})
# TODO: The following parameters override the ones defined in session.py!
default_task_preprocess_params.bad_segments_detection = DotDict({
    # The threshold to initialize bad points, used in `_detect_bad_segments`.
    "scale_threshold": 50.,
    # The threshold to ignore too short bad segments.
    "duration_threshold": 0.2,
    # The reciprocal of the duration of convolution window (unit: second).
    # Note: `window_freq` should be `resample_freq / 10`.
    "window_freq": 10,
})
# def default_tmr_preprocess_params macro
# Default parameters of the tmr (sleep) recordings, built the same way as the task defaults.
default_tmr_preprocess_params = cp.deepcopy(session.default_preprocess_params)
default_tmr_preprocess_params.markers = DotDict({
    # The audio cue markers (101-115, the same codes as in the task sessions).
    "audio": {"alarm": 101, "apple": 102, "ball": 103, "book": 104, "box": 105, "chair": 106, "kiwi": 107, "microphone": 108,
        "motorcycle": 109, "pepper": 110, "sheep": 111, "shoes": 112, "strawberry": 113, "tomato": 114, "watch": 115,},
})
# Epoch window (tmin, tmax) in seconds around each audio cue; `tmax` is treated as exclusive.
default_tmr_preprocess_params.epochs = [-0.2, 4.]
# Sleep-staging channels. Each channel is re-referenced by subtracting its reference channel before
# staging; when a channel or its reference is None, that re-referencing step is skipped.
default_tmr_preprocess_params.sls = DotDict({
    # The reference eeg channel & the corresponding reference channel for sleep stage.
    "eeg_channel": "C3", "eeg_ref_channel": "M2",
    # The reference eog channel & the corresponding reference channel for sleep stage.
    "eog_channel": "EOG", "eog_ref_channel": "M2",
    # The reference emg channel & the corresponding reference channel for sleep stage.
    "emg_channel": None, "emg_ref_channel": None,
})
# ICA uses the audio cue markers.
default_tmr_preprocess_params.ica.markers = default_tmr_preprocess_params.markers.audio
# Padding (base, jitter) used when cropping around tmr events: each crop edge is padded by
# `crop_padding[0]` seconds plus a uniform random offset between -crop_padding[1] and crop_padding[1].
default_tmr_preprocess_params.crop_padding = (5., 2.)
# Within one sleep-stage range, an audio cue whose gap to the previous cue is at least
# `interval_threshold` times the smallest gap starts a new tmr block.
default_tmr_preprocess_params.interval_threshold = 5.
# TODO: The following parameters override the ones defined in session.py!
default_tmr_preprocess_params.bad_segments_detection = DotDict({
    # The threshold to initialize bad points, used in `_detect_bad_segments`.
    "scale_threshold": 50.,
    # The threshold to ignore too short bad segments.
    "duration_threshold": 0.2,
    # The reciprocal of the duration of convolution window (unit: second).
    # Note: `window_freq` should be `resample_freq / 10`.
    "window_freq": 10,
})

## Data Preprocess.
# def preprocess_task func
def preprocess_task(path_run, rerun=False, seed=42):
    """
    Preprocess every task session of one eeg run and save them together, e.g. [ant-001,...].

    The session types "image-audio-pre", "image-audio-post", "audio-image-pre" and
    "audio-image-post" are handled by `_preprocess_task`, in that order. The results are
    stored in one `DotDict` keyed by session type and written with `save_pickle` to
    `<path_run>/dataset.task`.

    Args:
        path_run: str - The path of specified eeg run.
        rerun: bool - The flag indicates whether rerun ICA.
        seed: int - The random seed used to initialize ICA.

    Returns:
        None
    """
    session_types = ("image-audio-pre", "image-audio-post", "audio-image-pre", "audio-image-post")
    dataset = DotDict()
    for session_type in session_types:
        dataset[session_type] = _preprocess_task(path_run, session_type=session_type, rerun=rerun, seed=seed)
    # Save all per-session datasets into one file.
    save_pickle(os.path.join(path_run, "dataset.task"), dataset)

# def _preprocess_task func
def _preprocess_task(path_run, session_type="image-audio-pre", rerun=False, seed=42):
    """
    The whole pipeline to preprocess task eeg data of specified run/session_type, e.g. [ant-001,...].

    Steps:
        1. Load the BrainVision files of `session_type` and crop them to the annotated blocks.
        2. Group events into trials; each trial opens with a trial-cross marker.
        3. Drop trials that contain dropped-pattern events, are incomplete, or overlap bad segments.
        4. Run `session.preprocess`.
        5. Epoch the image & audio cue of every remaining trial.

    Args:
        path_run: str - The path of specified eeg run.
        session_type: str - The type of specified session, e.g. "image-audio-pre". Its first two
            "-"-separated fields give the cue order used to select the epoch windows.
        rerun: bool - The flag indicates whether rerun ICA.
        seed: int - The random seed used to initialize ICA.

    Returns:
        dataset: (n_samples[list],) - The list of data pairs. Each item is a `DotDict` with keys
            "image" & "audio", each a `DotDict({"name", "data"})` taken from the SAME trial.
            Returns [] when no eeg file or no valid trial is found.
    """
    # Initialize path of params & eeg & montage & output.
    path_run_params = os.path.join(path_run, "preprocess_params.json")
    path_run_montage = os.path.join(path_run, "standard-1020-cap64.locs")
    path_run_eeg = os.path.join(path_run, "eeg")
    path_run_output_session = os.path.join(path_run, "output.run_va1.task", session_type)
    # `makedirs` also creates the intermediate output directory; `exist_ok` makes it idempotent.
    os.makedirs(path_run_output_session, exist_ok=True)
    # Load default preprocess_params, then fill the top-level keys missing from the json file.
    default_preprocess_params = cp.deepcopy(default_task_preprocess_params)
    with open(path_run_params, "r") as f:
        preprocess_params = DotDict(json.load(f))
    for key_i in default_preprocess_params.keys():
        if key_i not in preprocess_params.keys():
            setattr(preprocess_params, key_i, getattr(default_preprocess_params, key_i))
    preprocess_params.path_montage = path_run_montage; preprocess_params.ica.rerun = rerun; preprocess_params.ica.seed = seed
    # ICA uses the markers of the first cue type of this session (e.g. "image" for "image-audio-pre").
    preprocess_params.ica.markers = preprocess_params.markers[session_type.split("-")[0]]
    print("INFO: The parameters of preprocess is initialized as {}.".format(preprocess_params))
    # Save the initialized preprocess_params for further future analysis.
    save_pickle(os.path.join(path_run_output_session, "preprocess_params"), preprocess_params)
    # Load data from specified eeg data path.
    eeg_fnames = sorted([fname_i for fname_i in os.listdir(path_run_eeg)\
        if fname_i.startswith(session_type) and fname_i.endswith(".vhdr")])
    print("INFO: The detected eeg file names labelled by task are {}.".format(eeg_fnames))
    # Check whether file exists, if not, directly return [].
    if len(eeg_fnames) == 0: return []
    # Initialize data as a `Raw` data object.
    # TODO: We directly concatenate all eeg data segments, ignoring the splits!
    data = mne.concatenate_raws([session.load_data(os.path.join(path_run_eeg, eeg_fname_i)) for eeg_fname_i in eeg_fnames])\
        if len(eeg_fnames) > 1 else session.load_data(os.path.join(path_run_eeg, eeg_fnames[0]))
    # At the very first, we use block markers to identify block runs.
    events = session.get_events(data)
    block_start_events = [event_i for event_i in events if event_i.marker == preprocess_params.markers.block.block_start]
    block_end_events = [event_i for event_i in events if event_i.marker == preprocess_params.markers.block.block_end]
    if len(block_start_events) != len(block_end_events):
        print((
            "WARNING: The number of block start events ({:d}) is different from the number of block end events ({:d})."+
            "\nThe block start events are\n\t{}.\nThe block end events are\n\t{}."
        ).format(len(block_start_events), len(block_end_events), block_start_events, block_end_events))
        assert abs(len(block_start_events) - len(block_end_events)) == 1
        # If missing one end event, append one to the end.
        if len(block_start_events) > len(block_end_events):
            block_end_events.append(events[-1])
        # If missing one start event, insert one at the start.
        else:
            block_start_events.insert(0, events[0])
    block_events = [(block_start_event_i.onset, block_end_event_i.onset - block_start_event_i.onset)\
        for block_start_event_i, block_end_event_i in zip(block_start_events, block_end_events)]
    block_annotations = mne.Annotations(
        onset=[block_event_i[0] for block_event_i in block_events],
        duration=[block_event_i[1] for block_event_i in block_events],
        description=["block-{:d}".format(block_idx) for block_idx in range(len(block_events))],
        orig_time=data.annotations.orig_time
    )
    data.set_annotations(mne.Annotations(
        onset=[annotation_i["onset"] - data.first_time for annotation_i in data.annotations],
        duration=[annotation_i["duration"] for annotation_i in data.annotations],
        description=[annotation_i["description"] for annotation_i in data.annotations],
        orig_time=data.annotations.orig_time
    ) + block_annotations)
    # Then crop data according to blocks with no crop padding.
    crop_ranges = [(block_start_event_i.onset, block_end_event_i.onset)\
        for block_start_event_i, block_end_event_i in zip(block_start_events, block_end_events)]
    data = session.crop(data, crop_ranges=crop_ranges)
    # Find events from un-preprocessed data, then use trials to re-organize them.
    # Events before the first trial-cross marker are ignored.
    events = session.get_events(data); trials = []
    for event_i in events:
        if event_i.marker == preprocess_params.markers.trial.trial_cross:
            trials.append([event_i,])
        elif len(trials) == 0:
            continue
        else:
            trials[-1].append(event_i)
    # Remove trials that contains buffer overflow event.
    trials_ = []; dropped_patterns = ["New Segment/", "Comment/Buffer Overflow"]
    for trial_i in trials:
        if np.any([(event_i.description in dropped_patterns) for event_i in trial_i]): continue
        trials_.append(trial_i)
    print((
        "INFO: Drop {:d} trials according to unknown markers, the number of remaining trials is {:d}."
    ).format(len(trials)-len(trials_), len(trials_)))
    trials = trials_
    # Make sure each trial in `trials` are complete!
    trials_ = []
    for trial_i in trials:
        # Check whether there exists 1 image markers.
        image_markers_i = [event_i.marker for event_i in trial_i\
            if event_i.marker in preprocess_params.markers.image.values()]
        if len(image_markers_i) != 1: continue
        # Check whether there exists 1 audio markers.
        audio_markers_i = [event_i.marker for event_i in trial_i\
            if event_i.marker in preprocess_params.markers.audio.values()]
        if len(audio_markers_i) != 1: continue
        # Check whether there exists 1 trial markers.
        trial_markers_i = [event_i.marker for event_i in trial_i\
            if event_i.marker in preprocess_params.markers.trial.values()]
        if len(trial_markers_i) != 1: continue
        # Pass all checks, then append it to `trials_`.
        trials_.append(trial_i)
    print((
        "INFO: Drop {:d} trials according to trial checks, the number of remaining trials is {:d}."
    ).format(len(trials)-len(trials_), len(trials_)))
    trials = trials_
    # Execute preprocess over the whole data.
    data = session.preprocess(data, preprocess_params, path_output=path_run_output_session)
    # Save the preprocessed data for further future analysis. There are some argument settings we have to note:
    # 1) fmt: "auto" (default). Format to export. Could be one of ["auto", "brainvision", "edf", "eeglab"].
    # 2) physical_range: "auto" (default). Only used for exporting EDF files.
    # 3) add_ch_type: False (default). Only used for exporting EDF files.
    # 4) overwrite: True. If True (default False), overwrite the destination file if it exists.
    mne.export.export_raw(os.path.join(path_run_output_session, "data.vhdr"), data, overwrite=True)
    # Get bad segments from preprocessed data.
    bad_segments = [(annotation_i["onset"], annotation_i["duration"])\
        for annotation_i in data.annotations if annotation_i["description"] == "bad-segment"]
    # Use `bad_segments` to drop trials that overlap with bad segments.
    # Two-pointer sweep: assumes both `trials` and `bad_segments` are sorted by onset.
    trials_ = []; trial_idx = 0; segment_idx = 0
    if len(bad_segments) > 0:
        while trial_idx < len(trials):
            # Ensure that the end time of `bad_segment_i` is greater than the start time of current trial.
            while (bad_segments[segment_idx][0] + bad_segments[segment_idx][1] < trials[trial_idx][0].onset) and\
                  (segment_idx + 1 < len(bad_segments)): segment_idx += 1
            # The current trial has no overlap with `bad_segment_i`.
            if (bad_segments[segment_idx][0] > trials[trial_idx][-1].onset) or\
               (bad_segments[segment_idx][0] + bad_segments[segment_idx][1] < trials[trial_idx][0].onset):
                trials_.append(trials[trial_idx]); trial_idx += 1
            # Check whether the start time of `bad_segment_i` is smaller than the end time of current trial.
            else:
                trial_idx += 1
        # Calculate the differences between the endpoints of the ranges, then
        # check if any of the differences have opposite signs (i.e., overlap).
        if len(trials_) > 0:
            trials_ranges = np.array([(trial_i[0].onset, trial_i[-1].onset) for trial_i in trials_], dtype=np.float32)
            bad_segments_ranges = np.array(bad_segments, dtype=np.float32); bad_segments_ranges[:,1] += bad_segments_ranges[:,0]
            differences = np.subtract.outer(trials_ranges, bad_segments_ranges)
            assert differences.shape[-1] == 2 and not np.any(np.sign(differences[:,:,:,0]) != np.sign(differences[:,:,:,1]))
    else:
        trials_ = trials
    print((
        "INFO: Drop {:d} trials according to bad segments, the number of remaining trials is {:d}."
    ).format(len(trials)-len(trials_), len(trials_)))
    trials = trials_
    # If the number of remaining trials is 0, directly return [].
    if len(trials) == 0: return []
    # Before epoching data, use `events` to get pure cue events `cue_events`.
    # Then set events of `data_lvbj`, and finally save that MNE object.
    image_markers = bidict(preprocess_params.markers.image).inverse
    audio_markers = bidict(preprocess_params.markers.audio).inverse
    cue_events = [event_i for trial_i in trials for event_i in trial_i\
        if (event_i.marker in image_markers.keys()) or (event_i.marker in audio_markers.keys())]
    data_lvbj = session.set_events(data.copy(), cue_events)
    mne.export.export_raw(os.path.join(path_run_output_session, "data-lvbj.vhdr"), data_lvbj, overwrite=True)
    # Epoch data from `trials`, only using cue markers.
    epoch_range = preprocess_params.epochs["-".join(session_type.split("-")[:2])]
    cue_types = session_type.split("-")[:2]
    # Every remaining trial holds exactly one image & one audio cue (see the trial checks above), so row `k`
    # of both event arrays belongs to `trials[k]`. `reshape(-1, 3)` keeps the arrays 2-D even for a single
    # trial (`np.squeeze` would collapse that case to shape (3,) and break `[:,-1]` and `mne.Epochs`).
    image_events = np.array([[
        data.time_as_index(event_i.onset, use_rounding=True)[0], 0, event_i.marker
    ] for trial_i in trials for event_i in trial_i if event_i.marker in image_markers.keys()], dtype=np.int64).reshape(-1, 3)
    audio_events = np.array([[
        data.time_as_index(event_i.onset, use_rounding=True)[0], 0, event_i.marker
    ] for trial_i in trials for event_i in trial_i if event_i.marker in audio_markers.keys()], dtype=np.int64).reshape(-1, 3)
    # Note: `baseline=(0, 0)` is kept as before. NOTE(review): in MNE a (0, 0) baseline subtracts the mean over
    # [0, 0] s (i.e. the t=0 sample), whereas `baseline=None` disables baseline correction -- confirm intent.
    image_epochs = mne.Epochs(data, image_events, tmin=epoch_range[cue_types.index("image")][0],
        tmax=epoch_range[cue_types.index("image")][1]-(1./data.info["sfreq"]), baseline=(0, 0), preload=True)
    image_epochs = image_epochs.pick(picks=mne.pick_types(image_epochs.info, eeg=True))
    audio_epochs = mne.Epochs(data, audio_events, tmin=epoch_range[cue_types.index("audio")][0],
        tmax=epoch_range[cue_types.index("audio")][1]-(1./data.info["sfreq"]), baseline=(0, 0), preload=True)
    audio_epochs = audio_epochs.pick(picks=mne.pick_types(audio_epochs.info, eeg=True))
    # MNE may drop epochs (e.g. those touching "bad*" annotations, or exceeding the data bounds). Therefore:
    # - key each kept epoch by its row in the event array (`Epochs.selection`), i.e. by its trial index;
    # - label it with the events of the kept epochs (`Epochs.events`), not the full event arrays.
    image_epochs = {int(trial_idx): DotDict({"name":image_markers[image_event_i],"data":image_epoch_i,})
        for trial_idx, image_event_i, image_epoch_i in zip(image_epochs.selection, image_epochs.events[:,-1], image_epochs.get_data())}
    audio_epochs = {int(trial_idx): DotDict({"name":audio_markers[audio_event_i],"data":audio_epoch_i,})
        for trial_idx, audio_event_i, audio_epoch_i in zip(audio_epochs.selection, audio_epochs.events[:,-1], audio_epochs.get_data())}
    # Construct data from [image,audio]_epochs, only pairing cues that come from the same trial.
    trial_idxs = sorted(set(image_epochs.keys()) & set(audio_epochs.keys()))
    print((
        "INFO: Drop {:d} trials during epoching, the number of remaining trials is {:d}."
    ).format(len(trials)-len(trial_idxs), len(trial_idxs)))
    dataset = [DotDict({"image":image_epochs[trial_idx],"audio":audio_epochs[trial_idx],}) for trial_idx in trial_idxs]
    # Return the final `dataset`.
    return dataset

# def preprocess_tmr func
def preprocess_tmr(path_run, rerun=False, seed=42):
    """
    The whole pipeline to preprocess tmr eeg data of specified run, e.g. [ant-001,...].

    Steps:
        1. Sleep-stage the tmr recordings.
        2. Crop the merged N2/N3 and REM ranges down to the blocks of tmr (audio cue) events
           they contain.
        3. Preprocess & epoch each stage separately with `_preprocess_tmr`.

    The result `DotDict({"REM": [...], "N2/3": [...]})` is written with `save_pickle` to
    `<path_run>/dataset.tmr`.

    Args:
        path_run: str - The path of specified eeg run.
        rerun: bool - The flag indicates whether rerun ICA.
        seed: int - The random seed used to initialize ICA. It also seeds the random jitter of the
            crop padding, so that the crops are reproducible.

    Returns:
        None
    """
    # Initialize path of params & eeg & montage & output & dataset.
    path_run_params = os.path.join(path_run, "preprocess_params.json")
    path_run_montage = os.path.join(path_run, "standard-1020-cap64.locs")
    path_run_eeg = os.path.join(path_run, "eeg")
    path_run_output = os.path.join(path_run, "output.run_va1.tmr"); path_run_dataset = os.path.join(path_run, "dataset.tmr")
    path_run_output_n23 = os.path.join(path_run_output, "N23"); path_run_output_rem = os.path.join(path_run_output, "REM")
    # `makedirs` also creates the parent `path_run_output`; `exist_ok` makes it idempotent.
    os.makedirs(path_run_output_n23, exist_ok=True); os.makedirs(path_run_output_rem, exist_ok=True)
    # Load default preprocess_params, then fill the top-level keys missing from the json file.
    default_preprocess_params = cp.deepcopy(default_tmr_preprocess_params)
    with open(path_run_params, "r") as f:
        preprocess_params = DotDict(json.load(f))
    for key_i in default_preprocess_params.keys():
        if key_i not in preprocess_params.keys():
            setattr(preprocess_params, key_i, getattr(default_preprocess_params, key_i))
    preprocess_params.path_montage = path_run_montage; preprocess_params.ica.rerun = rerun; preprocess_params.ica.seed = seed
    print("INFO: The parameters of preprocess is initialized as {}.".format(preprocess_params))
    # Load data from specified eeg data path.
    eeg_fnames = sorted([fname_i for fname_i in os.listdir(path_run_eeg)\
        if fname_i.startswith("tmr") and fname_i.endswith(".vhdr")])
    print("INFO: The detected eeg file names labelled by tmr are {}.".format(eeg_fnames))
    if len(eeg_fnames) == 0: return
    # Initialize data as a `Raw` data object.
    # TODO: We directly concatenate all eeg data segments, ignoring the splits!
    data = mne.concatenate_raws([session.load_data(os.path.join(path_run_eeg, eeg_fname_i)) for eeg_fname_i in eeg_fnames])\
        if len(eeg_fnames) > 1 else session.load_data(os.path.join(path_run_eeg, eeg_fnames[0]))
    # Execute yasa sleep staging.
    data = _sleep_stage(data, preprocess_params.sls)
    # Split data according to yasa sleep staging (annotations described as "sleep-<stage>").
    sleep_stages = DotDict({"N2/3":[],"REM":[],})
    for annotation_i in [annotation_i for annotation_i in data.annotations if annotation_i["description"].startswith("sleep-")]:
        stage_type = annotation_i["description"].split("-")[-1]
        if stage_type in ["N2", "N3"]:
            sleep_stages["N2/3"].append((annotation_i["onset"], annotation_i["onset"] + annotation_i["duration"]))
        elif stage_type == "REM":
            sleep_stages["REM"].append((annotation_i["onset"], annotation_i["onset"] + annotation_i["duration"]))
    sleep_stages["N2/3"] = _merge_overlapping_ranges(sleep_stages["N2/3"])
    sleep_stages["REM"] = _merge_overlapping_ranges(sleep_stages["REM"])
    # Construct data only from ranges that include tmr events. A local, seeded generator draws the
    # crop-padding jitter, so the crops do not depend on the global numpy random state.
    rng = np.random.default_rng(seed)
    data_n23 = []; data_rem = []
    for sleep_range_i in sleep_stages["N2/3"]:
        data_i = _crop_tmr_sleep_range(data, sleep_range_i, preprocess_params, rng)
        if data_i is not None: data_n23.append(data_i)
    for sleep_range_i in sleep_stages["REM"]:
        data_i = _crop_tmr_sleep_range(data, sleep_range_i, preprocess_params, rng)
        if data_i is not None: data_rem.append(data_i)
    data_n23 = mne.concatenate_raws(data_n23) if len(data_n23) > 1 else (data_n23[0] if len(data_n23) == 1 else None)
    data_rem = mne.concatenate_raws(data_rem) if len(data_rem) > 1 else (data_rem[0] if len(data_rem) == 1 else None)
    # Preprocess & epoch each sleep stage; a stage without tmr data yields [].
    dataset = DotDict({
        "REM": _preprocess_tmr(data_rem, preprocess_params, path_run_output_rem, rerun=rerun) if data_rem is not None else [],
        "N2/3": _preprocess_tmr(data_n23, preprocess_params, path_run_output_n23, rerun=rerun) if data_n23 is not None else [],
    })
    # Save data to dataset.
    save_pickle(path_run_dataset, dataset)

# def _crop_tmr_sleep_range func
def _crop_tmr_sleep_range(data, sleep_range, preprocess_params, rng):
    """
    Crop one sleep-stage range of `data` down to the blocks of tmr (audio cue) events it contains.

    Audio cues are grouped into blocks. A cue whose gap to the previous cue is at least
    `interval_threshold` times the smallest gap in the range starts a new block. Every crop edge
    is padded by `crop_padding[0]` seconds plus a uniform random offset between -crop_padding[1]
    and crop_padding[1].

    Args:
        data: object - The sleep-staged raw data; only copies of it are cropped.
        sleep_range: tuple - The (start, end) time range of one merged sleep stage.
        preprocess_params: DotDict - The parameters of preprocess; reads `markers.audio`,
            `crop_padding` & `interval_threshold`.
        rng: np.random.Generator - The random generator for the crop-padding jitter.

    Returns:
        data_i: object - The cropped data, or None if the range contains no tmr event.
    """
    audio_marker_values = set(preprocess_params.markers.audio.values())
    crop_padding = preprocess_params.crop_padding; interval_threshold = preprocess_params.interval_threshold
    def _padding():
        # Draw one edge padding: `crop_padding[0]` jittered uniformly by +/-`crop_padding[1]`.
        return crop_padding[0] + rng.uniform(-crop_padding[1], crop_padding[1])
    ## Prepare data crop.
    # Crop data range, then keep the tmr events inside it; skip the range if there is none.
    data_i = session.crop(data.copy(), crop_ranges=[sleep_range,])
    tmr_events_i = [event_i for event_i in session.get_events(data_i) if event_i.marker in audio_marker_values]
    if len(tmr_events_i) == 0: return None
    ## Start & end tmr event detection.
    # Use the first & last tmr event to further crop data.
    start_time_i = tmr_events_i[0].onset - _padding()
    end_time_i = tmr_events_i[-1].onset + _padding()
    data_i = session.crop(data_i.copy(), crop_ranges=[(start_time_i, end_time_i),])
    ## Long interval detection.
    # Since there are long intervals during tmr, we use `interval_threshold` to split the cues into blocks.
    audio_events_i = [event_i for event_i in session.get_events(data_i) if event_i.marker in audio_marker_values]
    blocks_i = [[audio_events_i[0],],]
    if len(audio_events_i) > 1:
        intervals_i = np.diff([audio_event_i.onset for audio_event_i in audio_events_i])
        max_gap_i = np.min(intervals_i) * interval_threshold
        for interval_i, audio_event_i in zip(intervals_i, audio_events_i[1:]):
            # Close to the previous cue: same block; too far away: open a new block.
            if interval_i < max_gap_i:
                blocks_i[-1].append(audio_event_i)
            else:
                blocks_i.append([audio_event_i,])
    # Use `blocks_i` to crop the data once more, keeping only the padded blocks.
    crop_ranges_i = [(block_i[0].onset - _padding(), block_i[-1].onset + _padding()) for block_i in blocks_i]
    return session.crop(data_i.copy(), crop_ranges=crop_ranges_i)

# def _preprocess_tmr func
def _preprocess_tmr(data, preprocess_params, path_run_output, rerun=False):
    """
    The whole pipeline to preprocess tmr eeg data of one sleep stage, e.g. [ant-001,...].

    Steps:
        1. Preprocess the data with `session.preprocess`.
        2. Keep only the audio-cue events.
        3. Drop events whose epoch window `preprocess_params.epochs` overlaps a "bad-segment"
           annotation.
        4. Epoch the remaining audio cues.

    Args:
        data: object - The loaded `mne.io.brainvision.brainvision.RawBrainVision` object.
        preprocess_params: DotDict - The parameters of preprocess.
        path_run_output: str - The output directory of the current sleep stage.
        rerun: bool - The flag indicates whether rerun ICA. Not read in this function;
            kept for interface compatibility.

    Returns:
        dataset: (n_samples[list],) - The list of `DotDict({"audio": DotDict({"name", "data"})})`
            items, or [] if no event survives.
    """
    # Save the initialized preprocess_params for further future analysis.
    save_pickle(os.path.join(path_run_output, "preprocess_params"), preprocess_params)
    # Execute preprocess over the whole data.
    data = session.preprocess(data, preprocess_params, path_output=path_run_output); events = session.get_events(data)
    # Save the preprocessed data for further future analysis. There are some argument settings we have to note:
    # 1) fmt: "auto" (default). Format to export. Could be one of ["auto", "brainvision", "edf", "eeglab"].
    # 2) physical_range: "auto" (default). Only used for exporting EDF files.
    # 3) add_ch_type: False (default). Only used for exporting EDF files.
    # 4) overwrite: True. If True (default False), overwrite the destination file if it exists.
    mne.export.export_raw(os.path.join(path_run_output, "data.vhdr"), data, overwrite=True)
    # Get bad segments from preprocessed data.
    bad_segments = [(annotation_i["onset"], annotation_i["duration"])\
        for annotation_i in data.annotations if annotation_i["description"] == "bad-segment"]
    # Detect un-expected event, then fix events.
    # TODO: We directly remove all events that are not audio markers, ignoring their effect on previous event.
    audio_markers = preprocess_params.markers.audio; epoch_range = preprocess_params.epochs
    audio_marker_values = set(audio_markers.values())
    events = [event_i for event_i in events if event_i.marker in audio_marker_values]
    # Use `bad_segments` to drop events whose epoch window overlaps with bad segments.
    # Two-pointer sweep: assumes both `events` and `bad_segments` are sorted by onset.
    events_ = []; event_idx = 0; segment_idx = 0
    if len(bad_segments) > 0:
        while event_idx < len(events):
            # Ensure that the end time of `bad_segment_i` is greater than the start time of current event.
            while (bad_segments[segment_idx][0] + bad_segments[segment_idx][1] < events[event_idx].onset + epoch_range[0]) and\
                  (segment_idx + 1 < len(bad_segments)): segment_idx += 1
            # The current event has no overlap with `bad_segment_i`.
            if (bad_segments[segment_idx][0] > events[event_idx].onset + epoch_range[1]) or\
               (bad_segments[segment_idx][0] + bad_segments[segment_idx][1] < events[event_idx].onset + epoch_range[0]):
                events_.append(events[event_idx]); event_idx += 1
            # Check whether the start time of `bad_segment_i` is smaller than the end time of current event.
            else:
                event_idx += 1
        # Calculate the differences between the endpoints of the ranges, then
        # check if any of the differences have opposite signs (i.e., overlap).
        if len(events_) > 0:
            events_ranges = np.array([(event_i.onset + epoch_range[0],
                event_i.onset + epoch_range[1]) for event_i in events_], dtype=np.float32)
            bad_segments_ranges = np.array(bad_segments, dtype=np.float32); bad_segments_ranges[:,1] += bad_segments_ranges[:,0]
            differences = np.subtract.outer(events_ranges, bad_segments_ranges)
            assert differences.shape[-1] == 2 and not np.any(np.sign(differences[:,:,:,0]) != np.sign(differences[:,:,:,1]))
    else:
        events_ = events
    print((
        "INFO: Drop {:d} events according to bad segments, the number of remaining events is {:d}."
    ).format(len(events)-len(events_), len(events_)))
    events = events_
    # If the number of remaining events is 0, directly return [].
    if len(events) == 0: return []
    # Log information about the number of diff durations that are less than `epoch_range[1]`.
    onsets = [event_i.onset for event_i in events]; onset_diffs = np.diff(onsets)
    print((
        "WARNING: There are {:d} diff durations that are less than epoch_range[1] ({:.2f}s)."
    ).format(np.sum(onset_diffs <= epoch_range[1]), epoch_range[1]))
    # Before epoching data, use `events` to get pure audio events `audio_events`.
    # Then set events of `data_lvbj`, and finally save that MNE object.
    audio_markers = bidict(audio_markers).inverse
    audio_events = [event_i for event_i in events if event_i.marker in audio_markers.keys()]
    data_lvbj = session.set_events(data.copy(), audio_events)
    mne.export.export_raw(os.path.join(path_run_output, "data-lvbj.vhdr"), data_lvbj, overwrite=True)
    # Epoch data from `events`, only using audio markers.
    audio_events = np.array([[
        data.time_as_index(event_i.onset, use_rounding=True)[0], 0, event_i.marker
    ] for event_i in events if event_i.marker in audio_markers.keys()], dtype=np.int64)
    # Note: `baseline=(0, 0)` is kept as before. NOTE(review): in MNE a (0, 0) baseline subtracts the mean over
    # [0, 0] s (i.e. the t=0 sample), whereas `baseline=None` disables baseline correction -- confirm intent.
    audio_epochs = mne.Epochs(data, audio_events, tmin=epoch_range[0],
        tmax=epoch_range[1]-(1./data.info["sfreq"]), baseline=(0, 0), preload=True)
    audio_epochs = audio_epochs.pick(picks=mne.pick_types(audio_epochs.info, eeg=True))
    # Epochs may be dropped while epoching (e.g. other annotations starting with "bad"), so label each kept
    # epoch with its own event (`Epochs.events`) rather than zipping with the full `audio_events` array.
    audio_epochs = [DotDict({"name":audio_markers[audio_event_i],"data":audio_epoch_i,})\
        for audio_event_i, audio_epoch_i in zip(audio_epochs.events[:,-1], audio_epochs.get_data())]
    print((
        "INFO: Drop {:d} events due to other bad annotations, the number of remaining epochs is {:d}."
    ).format(audio_events.shape[0] - len(audio_epochs), len(audio_epochs)))
    # Construct data from [audio,]_epochs.
    dataset = [DotDict({"audio":audio_epoch_i,}) for audio_epoch_i in audio_epochs]
    # Return the final `dataset`.
    return dataset

# def _sleep_stage func
def _sleep_stage(data, sls_params):
    """
    Perform sleep stage over raw data.

    A copy of `data` is re-referenced according to `sls_params`, then yasa's `SleepStaging` predicts
    one stage label per `duration_scale`-second window. Consecutive windows sharing the same stage are
    merged into ranges, which are added to `data.annotations` with descriptions "sleep-{stage}"
    (e.g. "sleep-N2").

    Args:
        data: object - The loaded `mne.io.brainvision.brainvision.RawBrainVision` object.
        sls_params: DotDict - The parameters of sleep stage. It provides `{eeg,eog,emg}_channel` and
            `{eeg,eog,emg}_ref_channel`; a channel is re-referenced only when both are not None.

    Returns:
        data: object - The modified `mne.io.brainvision.brainvision.RawBrainVision` object.

    Raises:
        ValueError: If yasa predicts an empty hypnogram.
    """
    # Initialize duration_scale, i.e. the scale of predicted sleep stage duration.
    duration_scale = 30.; hypno_map = bidict({"W":0,"N1":1,"N2":2,"N3":3,"REM":4,})
    # Construct `data_sls` from `sls_params`. Re-referencing happens on a copy, so `data` itself is untouched.
    data_sls = data.copy()
    for modality_i in ("eeg", "eog", "emg"):
        # Only read the reference channel when the channel itself is set (same short-circuit as before).
        ch_name_i = getattr(sls_params, "{}_channel".format(modality_i))
        if ch_name_i is None: continue
        ref_name_i = getattr(sls_params, "{}_ref_channel".format(modality_i))
        if ref_name_i is None: continue
        ch_idx_i = mne.pick_channels(data_sls.ch_names, [ch_name_i,])[0]
        ref_idx_i = mne.pick_channels(data_sls.ch_names, [ref_name_i,])[0]
        # Note: Writes directly into the private `_data` buffer of the (preloaded) copy.
        data_sls._data[ch_idx_i,:] = data_sls._data[ch_idx_i,:] - data_sls._data[ref_idx_i,:]
    # Initialize sleep stage object.
    sls = yasa.SleepStaging(data_sls, eeg_name=sls_params.eeg_channel,
        eog_name=sls_params.eog_channel, emg_name=sls_params.emg_channel)
    # Predict `hypno_pred`, then convert "W" to 0, "N1" to 1, "N2" to 2, "N3" to 3, "REM" to 4.
    hypno_pred = sls.predict(); hypno_pred = np.asarray(yasa.hypno_str_to_int(hypno_pred))
    if hypno_pred.size == 0:
        raise ValueError("ERROR: yasa predicted an empty hypnogram, cannot set sleep stage annotations.")
    # Find sleep stage ranges according to `hypno_pred`. `change_idxs` holds the last index of every run
    # except the final one; when it is empty (constant hypnogram), there is a single run over all windows.
    # Every run (including a single-window first run) is kept.
    change_idxs = np.flatnonzero(np.diff(hypno_pred) != 0)
    range_starts = np.concatenate([[0], change_idxs + 1])
    range_ends = np.concatenate([change_idxs, [hypno_pred.size - 1]])
    ranges = [[hypno_map.inverse[int(hypno_pred[start_i])],
        (float(start_i * duration_scale), float((end_i - start_i + 1) * duration_scale))]
        for start_i, end_i in zip(range_starts, range_ends)]
    # TODO: `data.tmax` may lead to loss of last point.
    ranges[-1] = [ranges[-1][0], (ranges[-1][1][0], data.tmax - ranges[-1][1][0])]
    # Set annotations according to yasa sleep stage.
    annotations = mne.Annotations(
        onset=[range_i[1][0] + data.first_time for range_i in ranges],
        duration=[range_i[1][1] for range_i in ranges],
        description=["sleep-{}".format(range_i[0]) for range_i in ranges],
        orig_time=data.annotations.orig_time,
    )
    # NOTE(review): existing onsets are shifted by `-data.first_time` while the new ones are shifted by
    # `+data.first_time`; the two only agree when `data.first_time == 0` -- confirm against MNE's
    # `orig_time` handling if `data` may have been cropped before this call.
    data.set_annotations(mne.Annotations(
        onset=[annotation_i["onset"] - data.first_time for annotation_i in data.annotations],
        duration=[annotation_i["duration"] for annotation_i in data.annotations],
        description=[annotation_i["description"] for annotation_i in data.annotations],
        orig_time=data.annotations.orig_time
    ) + annotations)
    # Log information related to sleep stage, grouping (onset, duration) pairs by stage.
    stage_ranges = {}
    for stage_i, range_i in ranges:
        stage_ranges.setdefault(stage_i, []).append(range_i)
    msg = "INFO: Get total {:d} sleep stages:\n".format(len(ranges))
    for key_i in sorted(stage_ranges.keys()):
        msg += "\t{}: total {:d} sleep stages with median duration {:.2f}s. {}.\n".format(
            key_i, len(stage_ranges[key_i]), np.median([range_i[1] for range_i in stage_ranges[key_i]]), stage_ranges[key_i])
    print(msg)
    # Return the final `data`.
    return data

# def _merge_overlapping_ranges func
def _merge_overlapping_ranges(ranges):
    """
    Merge ranges that have overlaps.

    Args:
        ranges: list - The list of ranges.

    Returns:
        merged_ranges: list - The list of merged ranges.
    """
    # Initialize `ranges` & `merged_ranges`, then set `merged_ranges`.
    ranges = SortedSet(ranges); merged_ranges = []; range_i = None
    for start_i, end_i in ranges:
        if range_i is None:
            range_i = (start_i, end_i)
        elif start_i <= range_i[1]:
            range_i = (range_i[0], max(range_i[1], end_i))
        else:
            merged_ranges.append(range_i); range_i = (start_i, end_i)
    # If `ranges` has at least 1 range, `range_i` cannot be None.
    if range_i is not None: merged_ranges.append(range_i)
    # Return the final `merged_ranges`.
    return merged_ranges

if __name__ == "__main__":
    import time

    # Initialize macros.
    rerun = False; seed = 42
    # Initialize base path & subj_runs.
    # Note: `base` is resolved relative to the current working directory (like the `sys.path` setup at
    # the top of this file), so this script presumably expects to be launched from its own directory.
    base = os.path.join(os.getcwd(), os.pardir, os.pardir, os.pardir, os.pardir)
    subj_runs = ["ant-001/20230602", "ant-002/20230603",]

    # Initialize random seed.
    np.random.seed(seed)

    # Loop preprocess over all available `subj_runs`.
    for subj_run_i in subj_runs:
        # Record the start time. `perf_counter` is monotonic, so the elapsed time is not affected by wall-clock changes.
        time_start_i = time.perf_counter(); print("INFO: Start preprocessing subj-run ({}).".format(subj_run_i))
        # Get `subj_i` & `run_i` from `subj_run_i`, then get `path_run_i`.
        subj_i, run_i = subj_run_i.split("/"); path_run_i = os.path.join(base, "data", "eeg.anonymous", subj_i, run_i)
        # Execute task preprocess pipeline.
        preprocess_task(path_run_i, rerun=rerun, seed=seed)
        # Execute tmr preprocess pipeline.
        preprocess_tmr(path_run_i, rerun=rerun, seed=seed)
        # Record the whole time of preprocess.
        print("INFO: The total time of preprocessing subj-run ({}) is {:.2f}s.".format(subj_run_i, time.perf_counter()-time_start_i))

