import os
import torch
import numpy as np

from einops import rearrange
from utils.io import write_json
from data.process.utils import processed_path, downloadextract, targetpaths


# The EEG datasets come from this repo: https://github.com/mims-harvard/TFC-pretraining

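# Sleep-stage label values (kept for reference; these appear to belong to the
# sleep-EEG dataset and are NOT the labels of the Epilepsy data processed here)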
# W = 0
# N1 = 1
# N2 = 2
# N3 = 3
# REM = 4
# UNKNOWN = 5


def main():
    """Fetch the raw epilepsy dataset via ``downloadextract`` and preprocess it.

    Runs ``preprocess_EpilepsyData`` with its default settings afterwards.
    """
    dataset_name = "epilepsy"
    downloadextract(dataset_name)
    preprocess_EpilepsyData()


def preprocess_EpilepsyData(reprocess=False):
    """Convert the Epilepsy ``.pt`` splits into ``.npy`` arrays.

    For each split in ``train``/``val``/``test`` this loads
    ``<targetpath>/<split>.pt``, reorders ``data["samples"]`` from
    ``b c t`` to ``b t c``, and saves the samples and ``data["labels"]`` to
    ``<processedpath>/<split>_data_subseq.npy`` and
    ``<processedpath>/<split>_labels_subseq.npy``.

    Args:
        reprocess: If False (default) and every output file already exists,
            print a message and skip the work. If True, always regenerate
            the outputs.
    """
    splits = ["train", "val", "test"]
    processedpath = processed_path["epilepsy"]
    targetpath = targetpaths["epilepsy"]

    expected_outputs = [
        os.path.join(processedpath, f"{split}_{kind}_subseq.npy")
        for split in splits
        for kind in ("data", "labels")
    ]
    # Check for existing outputs *before* os.makedirs: testing the directory
    # after creating it would always succeed, and without the return the
    # data would be re-converted on every call.
    if not reprocess and all(os.path.exists(p) for p in expected_outputs):
        print("EEG data has already been processed")
        return

    os.makedirs(processedpath, exist_ok=True)

    for split in splits:
        data = torch.load(os.path.join(targetpath, f"{split}.pt"))

        # Move the second axis to the end; einops raises if samples are not 3-D.
        # assumes the axes are (batch, channels, time) as the pattern names
        # suggest — TODO confirm against the TFC-pretraining data.
        subseq = rearrange(data["samples"], "b c t -> b t c")
        labels = data["labels"]

        np.save(os.path.join(processedpath, f"{split}_data_subseq.npy"), subseq)
        np.save(os.path.join(processedpath, f"{split}_labels_subseq.npy"), labels)

    # NOTE(review): the disabled mapping below lists sleep-stage labels
    # (W/N1/N2/N3/REM), which do not appear to match the Epilepsy dataset,
    # so no label_name.json is written for it.
    # label_name = {
    #     "0": "W",
    #     "1": "N1",
    #     "2": "N2",
    #     "3": "N3",
    #     "4": "REM",
    #     }

    # write_json(label_name, os.path.join(processedpath, "label_name.json"))


if __name__ == "__main__":
    # Run main() only when this file is executed as a script, not on import.
    main()
