import numpy as np
import os
from scipy.io import loadmat
import einops
import re

# Root directory of the extracted-feature .mat files; load_eeg_data() looks
# inside its "<session_id>" sub-directories.
raw_data_path = "../eeg_data/ExtractedFeatures/"
# Output roots for the two splits (currently the same directory);
# save_data() appends the "train/..." or "test/..." sub-path.
train_save_path = test_save_path = '../EEGDATA/SEED/'

# Trial labels of the SEED dataset: one row per session (3 rows), one entry
# per trial (15 entries); every session uses the same sequence. Values are
# 0/1/2 -- presumably the dataset's -1/0/+1 emotion labels shifted by one;
# confirm against the SEED documentation.
label_seed3 = [
    [2, 1, 0, 0, 1, 2, 0, 1, 2, 2, 1, 0, 1, 2, 0]
    for _session in range(3)
]
# Zero-based trial indices; trial i is stored in the .mat variable
# "de_LDS<i + 1>".
trial_list = list(range(15))

def find_strings(list1, str1):
    """Return, in their original order, the items of *list1* starting with *str1*.

    Args:
        list1: Iterable of strings to filter (e.g. an ``os.listdir`` result).
        str1: Required prefix.

    Returns:
        A new list holding only the matching items.
    """
    return list(filter(lambda candidate: candidate.startswith(str1), list1))

def load_eeg_data(subject_id, session_id, raw_path, feature="de_LDS"):
    """Load one subject/session's per-trial feature matrices from a .mat file.

    The directory ``<raw_path>/<session_id>`` is searched for a file whose
    name starts with ``"<subject_id>_"``. For every index ``i`` in the
    module-level ``trial_list``, the .mat variable ``<feature><i + 1>`` is
    read as float64 and rearranged from ``(w, h, c)`` to ``(h, w * c)``.

    Args:
        subject_id: Subject number, used as the file-name prefix.
        session_id: Session number; also the name of the sub-directory.
        raw_path: Root directory containing one sub-directory per session.
        feature: Prefix of the .mat variables to read. Defaults to
            ``"de_LDS"``, the value that was previously hard-coded.

    Returns:
        A list with one float64 array of shape ``(h, w * c)`` per trial,
        in ``trial_list`` order.

    Raises:
        FileNotFoundError: If the session directory holds no file for the subject.
        KeyError: If a requested trial variable is missing from the file.
    """
    base_path = os.path.join(raw_path, str(session_id))
    # sorted() makes the pick deterministic; os.listdir order is arbitrary.
    matches = sorted(find_strings(os.listdir(base_path), str(subject_id) + "_"))
    if not matches:
        raise FileNotFoundError(
            f"no file starting with '{subject_id}_' in {base_path}"
        )

    # Load the file once rather than once per trial. The real file name is
    # used directly: str.strip('.mat') strips any of the characters '.', 'm',
    # 'a', 't' from both ends instead of removing the '.mat' suffix.
    mat = loadmat(
        os.path.join(base_path, matches[0]),
        verify_compressed_data_integrity=False,
    )

    data_list = []
    for i in trial_list:
        meta_data = np.asarray(mat[feature + str(i + 1)], dtype=float)
        # (w, h, c) -> (h, w * c). Presumably (channels, windows, bands) ->
        # (windows, channels * bands); TODO confirm against the dataset.
        data_list.append(einops.rearrange(meta_data, 'w h c -> h (w c)'))
    return data_list


def save_data(subject_id, session_id, data, labels, path, train=True):
    """Write one subject/session split as ``.npy`` feature and label files.

    All trials' feature rows are stacked into one matrix. Each trial's label
    is repeated once per row of that trial, giving an ``(n_rows, 1)`` column.

    Files written (missing directories are created):
        ``<path>/<train|test>/de/<subject_id>_<session_id>.npy``
        ``<path>/<train|test>/label/<subject_id>_<session_id>.npy``

    Args:
        subject_id: Used only to build the output file name.
        session_id: Used only to build the output file name.
        data: Sequence of 2-D arrays, one per trial, with equal column counts.
        labels: One label per entry of ``data``.
        path: Output root directory.
        train: Selects the ``train`` (True) or ``test`` (False) sub-folder.

    Raises:
        ValueError: If ``data`` and ``labels`` differ in length, or ``data``
            is empty (raised by ``np.concatenate``).
    """
    data = list(data)
    labels = list(labels)
    if len(data) != len(labels):
        # zip() would silently drop the surplus and misalign features/labels.
        raise ValueError(f"got {len(data)} trials but {len(labels)} labels")

    np_data = np.concatenate(data, axis=0)
    np_labels = np.concatenate(
        [np.repeat(label, arr.shape[0]) for label, arr in zip(labels, data)]
    ).reshape(-1, 1)

    suffix = 'train' if train else 'test'
    file_name = f'{subject_id}_{session_id}.npy'
    # Features first, then labels (same order as before).
    for sub_dir, array in (('de', np_data), ('label', np_labels)):
        out_dir = os.path.join(path, suffix, sub_dir)
        os.makedirs(out_dir, exist_ok=True)
        np.save(os.path.join(out_dir, file_name), array)


# Driver: for each of the 15 subjects and each of the 3 sessions, load the
# per-trial features and write the first 9 trials as the training split and
# the remaining trials as the test split.
for subject_id in range(1, 16):  # 15 subjects
    for session_id, session_labels in enumerate(label_seed3, start=1):  # 3 sessions
        all_data = load_eeg_data(subject_id, session_id, raw_data_path)

        # (is_train, which trials, output root) -- train is written first.
        for is_train, trial_slice, out_root in (
            (True, slice(None, 9), train_save_path),
            (False, slice(9, None), test_save_path),
        ):
            save_data(
                subject_id,
                session_id,
                all_data[trial_slice],
                session_labels[trial_slice],
                out_root,
                train=is_train,
            )