# Script that creates the labeling structure for PhysioNet, used in downstream classification.
# Uses entire runs as input.

import os
import csv
import tqdm
import numpy as np
import mne
import pyedflib as edf

# Announce start-up on stdout.
# NOTE(review): the message names create_pretrain_structure.py while the file
# header describes building label files — confirm this is the intended name.
print("running create_pretrain_structure.py")

def load_data_edf(path):
    """Return the raw sample matrix of the EDF recording stored at *path*.

    The file is read fully into memory with ``mne.io.read_raw_edf`` and the
    samples are returned via ``Raw.get_data()``. The caller in this script
    treats ``result[0]`` as one channel's sample vector (presumably the array
    is (n_channels, n_times) as per MNE's API — confirm if relying on it).
    """
    raw = mne.io.read_raw_edf(path, preload=True)
    return raw.get_data()

def load_annotations_edf(path):
    """Return the EDF+ annotation table of the file at *path*.

    The value is whatever ``pyedflib.EdfReader.readAnnotations()`` returns.
    The caller in this script indexes it as ``[0]`` = onsets (seconds),
    ``[1]`` = durations, and ``[2]`` = annotation descriptions.
    """
    reader = edf.EdfReader(path)
    try:
        return reader.readAnnotations()
    finally:
        # Always release the file handle, even if reading the annotations fails.
        reader.close()

# Root of the downloaded PhysioNet eegmmidb dataset (one sub-directory per subject).
DATA_ROOT = "/scratch/sem23h11/physionet.org/files/eegmmidb/1.0.0"
# Output tree: <OUT_ROOT>/<subject dir>/<run name>/ann.npy
OUT_ROOT = "/scratch/sem23h11/BrainBERT/pretrain_data"
# Samples per second used to convert annotation onsets (seconds) to sample indices.
SFREQ = 160
# Annotation description -> integer class written into the label track.
# Descriptions not listed here leave their span at 0.
LABEL_CODES = {'T0': 0, 'T1': 1, 'T2': 2}


def _build_label_track(starting_times, labels, n_samples, dtype=np.float64, sfreq=SFREQ):
    """Return a per-sample label vector built from EDF+ annotations.

    Each annotation labels the samples from its own onset up to the onset of
    the next annotation; the last annotation runs to the end of the vector.
    Annotation durations are therefore not used. Samples not covered by a
    recognised label (see LABEL_CODES) stay 0.

    Args:
        starting_times: annotation onsets in seconds, one per annotation.
        labels: annotation descriptions, aligned with ``starting_times``.
        n_samples: length of the returned vector.
        dtype: dtype of the returned vector.
        sfreq: samples per second used to convert onsets to indices.

    Returns:
        1-D numpy array of length ``n_samples``.
    """
    track = np.zeros(n_samples, dtype=dtype)
    # round() instead of int(): an onset that should map to a whole sample
    # index can come out a hair below it in floating point, and truncation
    # would then move the segment boundary one sample early.
    onsets = [int(round(t * sfreq)) for t in starting_times]
    ends = onsets[1:] + [n_samples]
    for begin, end, label in zip(onsets, ends, labels):
        code = LABEL_CODES.get(label)
        if code is not None:
            track[begin:end] = code
    return track


# Work from the dataset root (the paths below are nevertheless built as absolute).
os.chdir(DATA_ROOT)

# Iterate over all subject directories and write one label track per .edf run.
for subject in sorted(os.listdir(DATA_ROOT)):
    subject_dir = os.path.join(DATA_ROOT, subject)
    # Only directories are scanned for runs; plain files at the root are skipped.
    if not os.path.isdir(subject_dir):
        continue
    for fname in sorted(os.listdir(subject_dir)):
        if not fname.endswith(".edf"):
            continue
        edf_path = os.path.join(subject_dir, fname)
        data = load_data_edf(edf_path)
        annotations = load_annotations_edf(edf_path)
        onsets = annotations[0]
        descriptions = annotations[2]
        # One label per sample of the first channel, same dtype as the signal
        # (matches the previous np.zeros_like(data[0])).
        label_track = _build_label_track(
            onsets, descriptions, data.shape[1], dtype=data.dtype
        )
        out_dir = os.path.join(OUT_ROOT, subject, os.path.splitext(fname)[0])
        # Create the per-run output directory if no earlier step has done so.
        os.makedirs(out_dir, exist_ok=True)
        out_path = os.path.join(out_dir, "ann.npy")
        np.save(out_path, label_track)
        print(out_path)