#!/usr/bin/env python3
"""
Created on 19:59, Dec. 25th, 2022

@author: Anonymous
"""
import os
import numpy as np
# local dep
if __name__ == "__main__":
    import os, sys
    sys.path.insert(0, os.path.join(os.pardir, os.pardir, os.pardir))
from utils import DotDict
from utils.data import load_pickle

# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    "load_run_task",
    "load_run_tmr",
]

# def load_run_task func
def load_run_task(path_run, session_type="image-audio-pre"):
    """
    Load one task session of the specified run and convert it to arrays.

    The trials are read from `<path_run>/dataset.task` via `load_pickle`, then shuffled
    in-place with `np.random.shuffle` (so the order depends on numpy's global random state).

    Args:
        path_run: str - The directory of the specified run, which holds `dataset.task`.
        session_type: str - The key of the session to load, could be one of
            [image-audio-[pre,post],audio-image-[pre,post]].

    Returns:
        X: DotDict - The data of each modality, containing [image,audio]. Each entry is a
            `np.float32` array stacked over trials with its last two axes swapped, which is
            documented as (n_samples, seq_len, n_channels). Both entries are `None` when
            the session holds no trial.
        y: DotDict - The labels of each modality, containing [image,audio]. Each entry is a
            `np.int64` array holding the position of the trial's name in the sorted set of
            names. Both entries are `None` when the session holds no trial.

    Raises:
        AssertionError: If the image names and the audio names do not form the same set.
    """
    # Read the requested session of this run.
    trials = load_pickle(os.path.join(path_run, "dataset.task"))[session_type]
    # Randomize the trial order in-place.
    np.random.shuffle(trials)
    # An empty session has nothing to convert. The explicit length check (instead of
    # truthiness) also works for array-like containers.
    if len(trials) == 0:
        X_empty = DotDict({"image":None,"audio":None,})
        y_empty = DotDict({"image":None,"audio":None,})
        return X_empty, y_empty
    # Collect the name of each trial, per modality.
    image_names = [trial.image.name for trial in trials]
    audio_names = [trial.audio.name for trial in trials]
    # Both modalities have to cover exactly the same set of names.
    assert set(image_names) == set(audio_names)
    # The class index of a name is its position in the sorted set of names.
    labels = sorted(set(image_names))
    # Build the image part: integer labels, and data with the last two axes swapped.
    y_image = np.array(list(map(labels.index, image_names)), dtype=np.int64)
    X_image = np.array([trial.image.data for trial in trials], dtype=np.float32).transpose(0, 2, 1)
    # Build the audio part in the same way.
    y_audio = np.array(list(map(labels.index, audio_names)), dtype=np.int64)
    X_audio = np.array([trial.audio.data for trial in trials], dtype=np.float32).transpose(0, 2, 1)
    # Pack the final `X` & `y`.
    X = DotDict({"image":X_image,"audio":X_audio,})
    y = DotDict({"image":y_image,"audio":y_audio,})
    return X, y

# def load_run_tmr func
def load_run_tmr(path_run, session_type="N2/3"):
    """
    Load one tmr session of the specified run and convert it to arrays.

    The trials are read from `<path_run>/dataset.tmr` via `load_pickle`, then shuffled
    in-place with `np.random.shuffle` (so the order depends on numpy's global random state).

    Args:
        path_run: str - The directory of the specified run, which holds `dataset.tmr`.
        session_type: str - The key of the session to load, could be one of [N2/3,REM].

    Returns:
        X: DotDict - The data, with keys [image,audio]. `image` is always `None`; `audio` is a
            `np.float32` array stacked over trials with its last two axes swapped, which is
            documented as (n_samples, seq_len, n_channels), or `None` when the session
            holds no trial.
        y: DotDict - The labels, with keys [image,audio]. `image` is always `None`; `audio` is
            a `np.int64` array holding the position of the trial's name in the sorted set of
            names, or `None` when the session holds no trial.
    """
    # Read the requested session of this run.
    trials = load_pickle(os.path.join(path_run, "dataset.tmr"))[session_type]
    # Randomize the trial order in-place.
    np.random.shuffle(trials)
    # An empty session has nothing to convert. The explicit length check (instead of
    # truthiness) also works for array-like containers.
    if len(trials) == 0:
        X_empty = DotDict({"image":None,"audio":None,})
        y_empty = DotDict({"image":None,"audio":None,})
        return X_empty, y_empty
    # Collect the audio name of each trial.
    audio_names = [trial.audio.name for trial in trials]
    # The class index of a name is its position in the sorted set of names.
    labels = sorted(set(audio_names))
    # Build the audio part: integer labels, and data with the last two axes swapped.
    y_audio = np.array(list(map(labels.index, audio_names)), dtype=np.int64)
    X_audio = np.array([trial.audio.data for trial in trials], dtype=np.float32).transpose(0, 2, 1)
    # Pack the final `X` & `y`; tmr sessions carry no image data.
    X = DotDict({"image":None,"audio":X_audio,})
    y = DotDict({"image":None,"audio":y_audio,})
    return X, y

if __name__ == "__main__":
    # Resolve the directory three levels above the current working directory,
    # then the path of the run used for this manual check.
    base = os.path.join(os.getcwd(), *([os.pardir] * 3))
    path_run = os.path.join(base, "data", "eeg.anonymous", "005", "20221223")

    # Run both loaders on that run; `data` ends up holding the tmr result.
    data = load_run_task(path_run)
    data = load_run_tmr(path_run)

