import argparse
import os.path as op
import mne
import numpy as np
import os
import pickle

def reorganize_WM(project_dir, band, save_dir):
    """Cut each watermelon-dataset recording into 40 one-minute trials and pickle them.

    For every subject ``S0``..``S9``, ``<project_dir>/<sub>.cnt`` is read with
    MNE. The data is optionally band-pass filtered, resampled to 128 Hz and
    converted to microvolts. The channels ``VEO``, ``HEO``, ``EKG`` and ``EMG``
    are dropped. Forty windows of 60 s (7680 samples) are then extracted, one
    every 100 s (60 s trial + 40 s rest) starting at 60 s. Window ``i`` is
    stored in a randomly permuted slot (global NumPy seed 2024, set once per
    call). Only the first 32 channels are kept.

    Labels are assigned per *slot* (10 slots each of class 0, 1, 2, 3 in slot
    order), so a window's label depends on the random slot it lands in.
    NOTE(review): confirm this random label/trial pairing is intended.

    Parameters
    ----------
    project_dir : str
        Directory containing the ``S*.cnt`` files.
    band : str
        One of ``'delta'``, ``'theta'``, ``'alpha'``, ``'beta'``,
        ``'low_gamma'``, ``'high_gamma'``. Any other value leaves the data
        unfiltered.
    save_dir : str
        Output directory, created if missing. One ``<sub>.pkl`` is written per
        subject, holding ``{'EEG': float32 array (40, 7680, 32),
        'label': int array (40,)}``.

    Raises
    ------
    ValueError
        If a recording is too short to contain all 40 windows, or has fewer
        than 64 channels left after dropping the non-EEG ones.
    """
    # (l_freq, h_freq) in Hz passed to ``raw.filter``; l_freq=None is MNE's
    # documented way to request a pure low-pass.
    band_edges = {
        'delta': (None, 4),
        'theta': (4, 8),
        'alpha': (8, 12),
        'beta': (12, 32),
        'low_gamma': (32, 45),
        # NOTE(review): the resample to 128 Hz below discards everything above
        # the 64 Hz Nyquist limit, so this band effectively becomes ~55-64 Hz.
        'high_gamma': (55, 95),
    }
    excluded_channels = {'VEO', 'HEO', 'EKG', 'EMG'}

    sfreq = 128
    n_trials = 40
    trial_len = 60 * sfreq            # 60 s window -> 7680 samples
    stride = (60 + 40) * sfreq        # 60 s trial + 40 s rest
    first_start = 60 * sfreq          # first trial starts at 60 s
    n_needed = first_start + (n_trials - 1) * stride + trial_len

    os.makedirs(save_dir, exist_ok=True)

    sub_name = [f'S{i}' for i in range(10)]
    # fix the random seed so the trial shuffling is reproducible
    np.random.seed(2024)

    for sub in sub_name:
        cnt_fname = op.join(project_dir, f'{sub}.cnt')
        # preload=True already loads the samples into memory
        raw = mne.io.read_raw_cnt(cnt_fname, preload=True, data_format='int32')

        if band in band_edges:
            raw.filter(*band_edges[band])

        # downsample to 128 Hz
        raw.resample(sfreq, npad='auto')
        # MNE returns EEG data in volts; convert to microvolts
        data_raw = raw.get_data() * 1e6
        sel_idx = [idx for idx, name in enumerate(raw.ch_names)
                   if name not in excluded_channels]
        eegdata_raw = data_raw[sel_idx, :].T  # -> (time, channel)

        if eegdata_raw.shape[0] < n_needed:
            raise ValueError(
                f'{cnt_fname}: recording has {eegdata_raw.shape[0]} samples at '
                f'{sfreq} Hz, need at least {n_needed} for {n_trials} trials')
        if eegdata_raw.shape[1] < 64:
            raise ValueError(
                f'{cnt_fname}: only {eegdata_raw.shape[1]} channels left after '
                f'dropping {sorted(excluded_channels)}, expected at least 64')

        # float32 allocation gives the same values as casting a float64 buffer
        eegdata = np.zeros((n_trials, trial_len, 64), dtype=np.float32)
        trial_random = np.random.permutation(n_trials)
        for i, slot in enumerate(trial_random):
            start = first_start + i * stride
            eegdata[slot] = eegdata_raw[start:start + trial_len, 0:64]

        # 10 consecutive slots per class: [0]*10 + [1]*10 + [2]*10 + [3]*10
        eeglabel = np.repeat(np.arange(4), n_trials // 4)
        eeg_data_label = {'EEG': eegdata[:, :, 0:32], 'label': eeglabel}
        # save the eeg_data_label pickle file
        eeg_savedir = op.join(save_dir, f'{sub}.pkl')
        with open(eeg_savedir, 'wb') as f:
            pickle.dump(eeg_data_label, f)



def reorganize_SK(project_dir, band, save_dir, *, n_runs=8, trials_per_run=4):
    """Cut sparrKULee short-story runs into one-minute trials and pickle them.

    Subjects ``sub-007``..``sub-016`` are processed and saved as ``S0``..``S9``.
    For each subject, ``n_runs`` runs of session ``ses-shortstories01`` are
    loaded from pickles, declared as 64 EEG channels at 1000 Hz, optionally
    band-pass filtered and resampled to 128 Hz. ``trials_per_run`` windows of
    60 s (7680 samples) are cut from each run, one every 100 s starting at 60 s.
    Windows are written to randomly permuted slots (global NumPy seed 2024, set
    once per call) of a 40-slot array. Only the first 32 channels are kept.

    With the defaults (8 runs x 4 trials) only 32 of the 40 slots are filled.
    The remaining 8 slots stay all-zero but are still labelled, and a
    ``UserWarning`` is emitted in that case.

    Labels are assigned per *slot* (10 slots each of class 0, 1, 2, 3 in slot
    order). NOTE(review): confirm this random label/trial pairing is intended.

    Parameters
    ----------
    project_dir : str
        Root directory containing ``<sub>/ses-shortstories01/`` folders.
    band : str
        One of ``'delta'``, ``'theta'``, ``'alpha'``, ``'beta'``,
        ``'low_gamma'``, ``'high_gamma'``. Any other value leaves the data
        unfiltered.
    save_dir : str
        Output directory, created if missing. One ``S<k>.pkl`` is written per
        subject, holding ``{'EEG': float32 array (40, 7680, 32),
        'label': int array (40,)}``.
    n_runs : int, keyword-only
        Number of runs (``run-01`` .. ``run-<n_runs>``) read per subject.
    trials_per_run : int, keyword-only
        Number of 60 s windows cut from each run.

    Raises
    ------
    ValueError
        If ``n_runs * trials_per_run`` exceeds the 40 available slots, or a run
        is too short to hold ``trials_per_run`` windows.
    """
    import warnings

    # (l_freq, h_freq) in Hz passed to ``raw.filter``; l_freq=None is MNE's
    # documented way to request a pure low-pass.
    band_edges = {
        'delta': (None, 4),
        'theta': (4, 8),
        'alpha': (8, 12),
        'beta': (12, 32),
        'low_gamma': (32, 45),
        # NOTE(review): the resample to 128 Hz below discards everything above
        # the 64 Hz Nyquist limit, so this band effectively becomes ~55-64 Hz.
        'high_gamma': (55, 95),
    }

    sfreq_in = 1000
    sfreq = 128
    n_trials = 40
    trial_len = 60 * sfreq            # 60 s window -> 7680 samples
    stride = (60 + 40) * sfreq        # 60 s trial + 40 s rest
    first_start = 60 * sfreq          # first trial starts at 60 s
    n_needed = first_start + (trials_per_run - 1) * stride + trial_len

    n_filled = n_runs * trials_per_run
    if n_filled > n_trials:
        raise ValueError(
            f'n_runs * trials_per_run = {n_filled} exceeds the {n_trials} '
            f'available trial slots')
    if n_filled < n_trials:
        warnings.warn(
            f'reorganize_SK: n_runs * trials_per_run = {n_filled} fills only '
            f'{n_filled} of {n_trials} trial slots; the remaining '
            f'{n_trials - n_filled} slots stay all-zero but are still labelled',
            stacklevel=2)

    os.makedirs(save_dir, exist_ok=True)

    sub_name = [f'S{i}' for i in range(10)]
    sub_raw_name = [f'sub-{i:03d}' for i in range(7, 17)]
    # fix the random seed so the trial shuffling is reproducible
    np.random.seed(2024)

    for out_name, sub in zip(sub_name, sub_raw_name):
        eeg_sub_dir = op.join(project_dir, sub, 'ses-shortstories01')

        trial_random = np.random.permutation(n_trials)
        # float32 allocation gives the same values as casting a float64 buffer
        eegdata = np.zeros((n_trials, trial_len, 64), dtype=np.float32)

        for run in range(1, n_runs + 1):
            name_prefix = (f'{sub}_ses-shortstories01_task-listeningActive_'
                           f'run-{run:02d}_desc-preproc-audio')
            eeg_fname = op.join(eeg_sub_dir, f'{name_prefix}-_eeg.pickle')
            # SECURITY: pickle.load can execute arbitrary code; only point this
            # at trusted dataset files.
            with open(eeg_fname, 'rb') as f:
                eeg_data = pickle.load(f)
            # NOTE(review): scale factor kept from the original; the units of
            # the stored array are not visible here -- confirm.
            eeg_data = eeg_data / 1e3

            # RawArray requires eeg_data shaped (64, n_times); channels are
            # named '0'..'63' and the data is declared as sampled at 1000 Hz.
            raw = mne.io.RawArray(
                eeg_data,
                mne.create_info(ch_names=64, sfreq=sfreq_in, ch_types='eeg'))

            if band in band_edges:
                raw.filter(*band_edges[band])

            # downsample to 128 Hz
            raw.resample(sfreq, npad='auto')
            run_data = raw.get_data().T  # -> (time, channel)

            if run_data.shape[0] < n_needed:
                raise ValueError(
                    f'{eeg_fname}: run has {run_data.shape[0]} samples at '
                    f'{sfreq} Hz, need at least {n_needed} for '
                    f'{trials_per_run} trials')

            for i in range(trials_per_run):
                start = first_start + i * stride
                slot = trial_random[(run - 1) * trials_per_run + i]
                eegdata[slot] = run_data[start:start + trial_len, 0:64]

        # 10 consecutive slots per class: [0]*10 + [1]*10 + [2]*10 + [3]*10
        eeglabel = np.repeat(np.arange(4), n_trials // 4)
        eeg_data_label = {'EEG': eegdata[:, :, 0:32], 'label': eeglabel}
        # save the eeg_data_label pickle file
        eeg_savedir = op.join(save_dir, f'{out_name}.pkl')
        with open(eeg_savedir, 'wb') as f:
            pickle.dump(eeg_data_label, f)

