import numpy as np
import os
from scipy.io import loadmat
import einops

# Input root (holds one sub-directory per session number) and the directory
# that receives the generated .npy files.
raw_data_path = "../dataset/SEED_IV/eeg_feature_smooth/"
save_path = '../EEGDATA/SEED_IV/EEG/'

# Zero-based indices of the 24 trials loaded from every session file.
trial_list = list(range(24))

def eeg_data(trial_list, raw_path, subject_id):
    """Load the ``de_LDS`` features of one subject from one session directory.

    Args:
        trial_list: Zero-based trial indices; trial ``i`` is read from the
            key ``"de_LDS<i+1>"`` of the subject's .mat file.
        raw_path: Directory that contains the session's .mat files.
        subject_id: Subject number as a string; the file whose name starts
            with ``"<subject_id>_"`` is used.

    Returns:
        A list with one 2-D float array per requested trial, in the order of
        ``trial_list``. Each 3-D source array ``(w, h, c)`` is rearranged to
        ``(h, w * c)``.
        NOTE(review): presumably (channels, windows, bands) = (62, T, 5),
        giving (T, 310) per trial -- confirm against the dataset.

    Raises:
        IndexError: If no file in ``raw_path`` starts with ``"<subject_id>_"``.
        KeyError: If a requested trial key is missing from the file.
    """
    subject_file = find_strings(os.listdir(raw_path), subject_id + "_")[0]

    # Use the file name exactly as listed. The previous
    # ``name.strip('.mat')`` removed any of the characters '.', 'm', 'a', 't'
    # from both ends (not the suffix) and could mangle the name; loadmat
    # only appends ".mat" when it is missing, so no stripping is needed.
    mat_path = os.path.join(raw_path, subject_file)

    # Read the file once instead of once per trial.
    mat_contents = loadmat(mat_path, verify_compressed_data_integrity=False)

    key_prefix = "de_LDS"
    data_list = []
    for i in trial_list:
        trial = np.array(mat_contents[key_prefix + str(i + 1)]).astype('float')
        # Move the second axis to the front and flatten the other two.
        data_list.append(einops.rearrange(trial, 'w h c -> h (w c)'))

    return data_list

# Per-trial labels for the three SEED-IV sessions: entry k of sessionN_label
# is the label of the k-th trial (k = 0..23) of session N.
# NOTE(review): the values 0-3 presumably encode the four emotion classes of
# the dataset -- confirm the mapping against the SEED-IV documentation.
session1_label = [
    1, 2, 3, 0, 2, 0, 0, 1, 0, 1, 2, 1,
    1, 1, 2, 3, 2, 2, 3, 3, 0, 3, 0, 3,
]
session2_label = [
    2, 1, 3, 0, 0, 2, 0, 2, 3, 3, 2, 3,
    2, 0, 1, 1, 2, 1, 0, 3, 0, 1, 3, 1,
]
session3_label = [
    1, 2, 2, 1, 3, 3, 3, 1, 1, 2, 1, 0,
    2, 3, 3, 0, 2, 3, 0, 0, 2, 0, 1, 0,
]


def rearrange_labels_and_data(labels, data):
    """Interleave trials so that labels cycle in ascending order.

    Trials are grouped by label, keeping their original relative order inside
    each group. The output is then built round by round: each round takes the
    next unused trial of every label, visiting labels in ascending order and
    skipping labels whose trials are exhausted. For labels 0..3 this yields
    ``0, 1, 2, 3, 0, 1, 2, 3, ...``.

    Args:
        labels: Sequence of hashable, mutually sortable labels, one per trial.
        data: Sequence of per-trial items, aligned with ``labels``.

    Returns:
        ``(new_labels, new_data)``: two lists of ``len(labels)`` items holding
        the reordered labels and the correspondingly reordered data items.

    Raises:
        ValueError: If ``labels`` and ``data`` differ in length.
    """
    if len(labels) != len(data):
        raise ValueError(
            "labels and data must have the same length, got {} and {}".format(
                len(labels), len(data)
            )
        )

    # Original positions of every label, in order of appearance.
    positions_by_label = {}
    for position, label in enumerate(labels):
        positions_by_label.setdefault(label, []).append(position)

    # Derive the label order from the input instead of hard-coding 0..3.
    label_order = sorted(positions_by_label)
    rounds = max((len(p) for p in positions_by_label.values()), default=0)

    new_labels = []
    new_data = []
    for round_no in range(rounds):
        for label in label_order:
            positions = positions_by_label[label]
            # Labels with fewer trials simply drop out of later rounds.
            if round_no < len(positions):
                new_labels.append(label)
                new_data.append(data[positions[round_no]])

    return new_labels, new_data


def split_into_folds(new_labels, new_data, num_folds=3):
    """Cut labels and data into ``num_folds`` consecutive, equal-sized folds.

    Every fold holds ``len(new_labels) // num_folds`` items; items left over
    after the last full fold are not included in any fold.

    Returns:
        A list of ``(fold_labels, fold_data)`` tuples, one per fold.
    """
    width = len(new_labels) // num_folds
    bounds = [(k * width, (k + 1) * width) for k in range(num_folds)]
    return [(new_labels[lo:hi], new_data[lo:hi]) for lo, hi in bounds]

def find_strings(list1, str1):
    """Return the entries of ``list1`` that start with ``str1``, in order."""
    def has_prefix(candidate):
        return candidate.startswith(str1)

    return list(filter(has_prefix, list1))

def saveData(subject_id, fold_id, fold_data, fold_labels):
    """Write one fold of one subject to ``save_path`` as two .npy files.

    Args:
        subject_id: Subject identifier used in the file names.
        fold_id: Fold identifier used in the file names.
        fold_data: Non-empty list of arrays that are stacked along axis 0.
        fold_labels: One label per array in ``fold_data``; each label is
            repeated once for every row of its array.

    Writes ``de_<subject_id>_<fold_id>.npy`` (the stacked data) and
    ``label_<subject_id>_<fold_id>.npy`` (a column vector of row labels).
    """
    # np.save does not create missing directories, so make sure the target
    # exists. An empty save_path means the current directory.
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    data = np.concatenate(fold_data, axis=0)
    labels = np.concatenate([
        np.repeat(label, arr.shape[0]) for label, arr in zip(fold_labels, fold_data)
    ])
    labels = labels.reshape(-1, 1)

    # Format only the file-name template; formatting the joined path would
    # misbehave if save_path itself contained braces.
    suffix = '{}_{}'.format(subject_id, fold_id)
    np.save(os.path.join(save_path, 'de_' + suffix), data)
    np.save(os.path.join(save_path, 'label_' + suffix), labels)

# Label table of each session, listed in session order (1, 2, 3).
session_labels = (session1_label, session2_label, session3_label)

for subject_id in range(1, 16):

    # For every session: load the trials, interleave them by label and cut
    # the result into three folds.
    folds_per_session = []
    for session_no, labels in enumerate(session_labels, start=1):
        session_dir = os.path.join(raw_data_path, str(session_no))
        trials = eeg_data(trial_list, session_dir, str(subject_id))
        ordered_labels, ordered_trials = rearrange_labels_and_data(labels, trials)
        folds_per_session.append(split_into_folds(ordered_labels, ordered_trials))

    # Fold k of the subject is the concatenation of fold k of sessions 1-3.
    for fold_no in range(3):
        merged_labels = []
        merged_data = []
        for session_folds in folds_per_session:
            part_labels, part_data = session_folds[fold_no]
            merged_labels.extend(part_labels)
            merged_data.extend(part_data)
        saveData(subject_id, fold_no + 1, merged_data, merged_labels)


