import argparse
import os.path as op
import mne
import numpy as np
import os
import pickle

def reorganize_WM(project_dir, band, save_dir):
    """Convert the watermelon-dataset ``.cnt`` recordings into KUL-style pickles.

    For each subject ``S0`` .. ``S9`` this:

    1. reads ``<project_dir>/<sub>.cnt``;
    2. optionally filters it to a named band;
    3. resamples it to 128 Hz;
    4. drops the ``VEO``/``HEO``/``EKG``/``EMG`` channels;
    5. cuts eight 360 s segments, each preceded by a skipped 120 s lead-in;
    6. stores the segments at randomly permuted slots;
    7. writes ``{'EEG': float32 array (8, 360*128, 64), 'label': list}`` to
       ``<save_dir>/<sub>.pkl``.

    The label list is fixed (``[0, 0, 0, 0, 1, 1, 1, 1]``) while the segments
    are shuffled. The label of a slot therefore depends on the slot position,
    not on which recorded segment landed there. The global NumPy RNG is seeded
    with 2024 once, so the shuffle is reproducible.

    Args:
        project_dir: directory containing ``S0.cnt`` .. ``S9.cnt``.
        band: one of ``'delta'``, ``'theta'``, ``'alpha'``, ``'beta'``,
            ``'low_gamma'``, ``'high_gamma'``. Any other value leaves the
            data unfiltered.
        save_dir: output directory; created if it does not exist.

    Raises:
        ValueError: if a recording does not have exactly 64 channels after
            the non-EEG channels are dropped, or is too short to contain
            eight (120 s + 360 s) blocks.
    """
    if not op.exists(save_dir):
        os.makedirs(save_dir)

    sfreq = 128      # output sampling rate (Hz)
    trial_sec = 360  # length of each kept segment (s)
    rest_sec = 120   # skipped lead-in before every segment (s)
    n_trials = 8
    n_channels = 64

    # (l_freq, h_freq) passed positionally to Raw.filter for each named band.
    # NOTE(review): after resampling to 128 Hz the Nyquist frequency is 64 Hz,
    # so most of the 55-95 Hz 'high_gamma' band cannot survive the resample —
    # confirm this is intended.
    band_edges = {
        'delta': (0, 4),
        'theta': (4, 8),
        'alpha': (8, 12),
        'beta': (12, 32),
        'low_gamma': (32, 45),
        'high_gamma': (55, 95),
    }
    dropped_channels = {'VEO', 'HEO', 'EKG', 'EMG'}

    trial_len = trial_sec * sfreq
    block_len = (trial_sec + rest_sec) * sfreq
    needed_len = n_trials * block_len

    sub_name = [f'S{i}' for i in range(10)]
    # fix the random seed so the per-subject trial permutation is reproducible
    np.random.seed(2024)

    for sub in sub_name:
        cnt_fname = op.join(project_dir, f'{sub}.cnt')
        # preload=True already reads the samples into memory.
        raw = mne.io.read_raw_cnt(cnt_fname, preload=True, data_format='int32')

        if band in band_edges:
            raw.filter(*band_edges[band])

        # downsample to 128 Hz
        raw.resample(sfreq, npad='auto')

        keep = [idx for idx, name in enumerate(raw.ch_names)
                if name not in dropped_channels]
        # Scale by 1e6 (presumably V -> uV) and reorder to (samples, channels).
        eegdata_raw = np.transpose(raw.get_data()[keep, :] * 1e6)

        if eegdata_raw.shape[1] != n_channels:
            raise ValueError(
                f'{cnt_fname}: expected {n_channels} channels after dropping '
                f'{sorted(dropped_channels)}, got {eegdata_raw.shape[1]}')
        if eegdata_raw.shape[0] < needed_len:
            raise ValueError(
                f'{cnt_fname}: need at least {needed_len} samples at {sfreq} Hz '
                f'for {n_trials} trials, got {eegdata_raw.shape[0]}')

        # Allocated as float32 up front to save memory; the assignment below
        # rounds float64 -> float32 exactly as a final astype would.
        eegdata = np.zeros((n_trials, trial_len, n_channels), dtype=np.float32)
        eeglabel = [0, 0, 0, 0, 1, 1, 1, 1]

        trial_random = np.random.permutation(n_trials)

        # Recording layout: [120 s rest][360 s trial] repeated eight times.
        # Segment i is written to slot trial_random[i].
        for i, slot in enumerate(trial_random):
            start = i * block_len + rest_sec * sfreq
            eegdata[slot] = eegdata_raw[start:start + trial_len, :]

        eeg_data_label = {'EEG': eegdata, 'label': eeglabel}
        # save the eeg_data_label pickle file
        eeg_savedir = op.join(save_dir, f'{sub}.pkl')
        with open(eeg_savedir, 'wb') as f:
            pickle.dump(eeg_data_label, f)

def reorganize_SK(project_dir, band, save_dir):
    """Convert sparrKULee short-story EEG pickles into KUL-style pickles.

    For each raw subject ``sub-007`` .. ``sub-016`` (saved as ``S0`` .. ``S9``),
    each of the eight ``..._run-XX_desc-preproc-audio-_eeg.pickle`` arrays
    under ``<project_dir>/<sub>/ses-shortstories01`` goes through these steps:

    1. divide by 1e3;
    2. wrap in a 64-channel, 1000 Hz MNE ``RawArray``;
    3. optionally filter to a named band;
    4. resample to 128 Hz;
    5. keep the 360 s window that starts at 60 s;
    6. store it at a randomly permuted slot.

    The result ``{'EEG': float32 array (8, 360*128, 64), 'label': list}`` is
    written to ``<save_dir>/S<k>.pkl``.

    The label list is fixed (``[0, 0, 0, 0, 1, 1, 1, 1]``) while the runs are
    shuffled. The label of a slot therefore depends on the slot position, not
    on the run stored there. The global NumPy RNG is seeded with 2024 once, so
    the shuffle is reproducible.

    Args:
        project_dir: root of the sparrKULee derivative tree.
        band: one of ``'delta'``, ``'theta'``, ``'alpha'``, ``'beta'``,
            ``'low_gamma'``, ``'high_gamma'``. Any other value leaves the
            data unfiltered.
        save_dir: output directory; created if it does not exist.

    Raises:
        ValueError: if a run is shorter than 60 s + 360 s after resampling.
    """
    if not op.exists(save_dir):
        os.makedirs(save_dir)

    sfreq = 128          # output sampling rate (Hz)
    in_sfreq = 1000      # NOTE(review): assumed rate of the pickled arrays — confirm
    trial_sec = 360      # length of the kept window (s)
    skip_sec = 60        # discarded start of every run (s)
    n_trials = 8
    n_channels = 64

    # (l_freq, h_freq) passed positionally to Raw.filter for each named band.
    # NOTE(review): after resampling to 128 Hz the Nyquist frequency is 64 Hz,
    # so most of the 55-95 Hz 'high_gamma' band cannot survive the resample —
    # confirm this is intended.
    band_edges = {
        'delta': (0, 4),
        'theta': (4, 8),
        'alpha': (8, 12),
        'beta': (12, 32),
        'low_gamma': (32, 45),
        'high_gamma': (55, 95),
    }

    trial_len = trial_sec * sfreq
    start = skip_sec * sfreq

    sub_name = [f'S{i}' for i in range(10)]
    sub_raw_name = [f'sub-{i:03d}' for i in range(7, 17)]
    # fix the random seed so the per-subject run permutation is reproducible
    np.random.seed(2024)

    for out_name, sub in zip(sub_name, sub_raw_name):
        eeg_sub_dir = op.join(project_dir, sub, 'ses-shortstories01')

        trial_random = np.random.permutation(n_trials)
        # Allocated as float32 up front to save memory; the assignment below
        # rounds float64 -> float32 exactly as a final astype would.
        eegdata = np.zeros((n_trials, trial_len, n_channels), dtype=np.float32)
        eeglabel = [0, 0, 0, 0, 1, 1, 1, 1]

        # Run r (1-based) is written to slot trial_random[r - 1].
        for run, slot in enumerate(trial_random, start=1):
            name_prefix = (f'{sub}_ses-shortstories01_task-listeningActive'
                           f'_run-{run:02d}_desc-preproc-audio')
            eeg_fname = op.join(eeg_sub_dir, f'{name_prefix}-_eeg.pickle')
            # NOTE: pickle.load can execute arbitrary code; only point this
            # at trusted dataset files.
            with open(eeg_fname, 'rb') as f:
                eeg_data = pickle.load(f)
            eeg_data = eeg_data / 1e3

            raw = mne.io.RawArray(
                eeg_data,
                mne.create_info(ch_names=n_channels, sfreq=in_sfreq, ch_types='eeg'))

            if band in band_edges:
                raw.filter(*band_edges[band])

            # downsample to 128 Hz
            raw.resample(sfreq, npad='auto')

            # (channels, samples) -> (samples, channels)
            eeg_data = np.transpose(raw.get_data())
            if eeg_data.shape[0] < start + trial_len:
                raise ValueError(
                    f'{eeg_fname}: need at least {start + trial_len} samples at '
                    f'{sfreq} Hz, got {eeg_data.shape[0]}')

            eegdata[slot] = eeg_data[start:start + trial_len, :n_channels]

        eeg_data_label = {'EEG': eegdata, 'label': eeglabel}
        # save the eeg_data_label pickle file
        eeg_savedir = op.join(save_dir, f'{out_name}.pkl')
        with open(eeg_savedir, 'wb') as f:
            pickle.dump(eeg_data_label, f)