import argparse
import os.path as op
import mne
import numpy as np
import os
import pickle

def reorganize_WM(project_dir, band, save_dir):
    """Slice each subject's continuous CNT recording into fixed windows and pickle them.

    For every subject ``S0`` .. ``S9`` this:

    1. reads ``<project_dir>/<sub>.cnt`` with ``mne.io.read_raw_cnt``;
    2. optionally band-filters it according to ``band``;
    3. drops the ``VEO``/``HEO``/``EKG``/``EMG`` channels and scales the data
       by 1e6 (MNE returns volts, so this is presumably microvolts);
    4. cuts 40 trials x 50 windows x 500 samples. Windows are taken from the
       first 64 remaining channels, and those 64 channels are duplicated side
       by side to give 128 columns;
    5. writes ``{'EEG': float32 array (40, 50, 500, 128), 'label': [0..39]}``
       to ``<save_dir>/<sub>.pkl``.

    Args:
        project_dir: Directory containing the ``S*.cnt`` files.
        band: ``'delta'`` (<4 Hz), ``'theta'`` (4-8), ``'alpha'`` (8-12),
            ``'beta'`` (12-32), ``'low_gamma'`` (32-45) or ``'high_gamma'``
            (55-95). Any other value leaves the data unfiltered.
        save_dir: Output directory; created if it does not exist.
    """
    os.makedirs(save_dir, exist_ok=True)

    # (l_freq, h_freq) for raw.filter. A low edge of None asks MNE explicitly
    # for a pure low-pass filter.
    band_edges = {
        'delta': (None, 4),
        'theta': (4, 8),
        'alpha': (8, 12),
        'beta': (12, 32),
        'low_gamma': (32, 45),
        'high_gamma': (55, 95),
    }
    edges = band_edges.get(band)  # None -> no filtering (unknown band)

    # Window layout in samples. NOTE(review): these numbers read as a
    # 1000 Hz recording with a 10 s lead-in, 35 s per trial and 0.5 s per
    # window. The recording's sfreq is never checked -- confirm.
    n_trials, n_pics, win = 40, 50, 500
    lead_in, trial_stride = 10000, 35000
    drop_channels = {'VEO', 'HEO', 'EKG', 'EMG'}

    # Fix the global NumPy RNG seed.
    np.random.seed(2024)

    for sub in [f'S{i}' for i in range(10)]:
        cnt_fname = op.join(project_dir, f'{sub}.cnt')
        raw = mne.io.read_raw_cnt(cnt_fname, preload=True, data_format='int32')
        if edges is not None:
            raw.filter(*edges)

        keep_idx = [i for i, name in enumerate(raw.ch_names)
                    if name not in drop_channels]
        # (n_times, n_channels) after the transpose.
        eeg = (raw.get_data()[keep_idx] * 1e6).T

        # Allocate float32 up front: same values as building float64 and
        # casting afterwards, at about a third of the peak memory.
        eegdata = np.zeros((n_trials, n_pics, win, 128), dtype=np.float32)

        # Not used for indexing (trials are stored in presentation order).
        # It is drawn anyway so the global NumPy RNG state after this call
        # is unchanged for any code that runs afterwards.
        np.random.permutation(n_trials)

        for tr in range(n_trials):
            for pic in range(n_pics):
                start = lead_in + tr * trial_stride + pic * win
                window = eeg[start:start + win, :64]
                # Duplicate the 64 channels to fill the 128-column layout.
                eegdata[tr, pic] = np.concatenate((window, window), axis=1)

        eeg_data_label = {'EEG': eegdata, 'label': list(range(n_trials))}
        with open(op.join(save_dir, f'{sub}.pkl'), 'wb') as f:
            pickle.dump(eeg_data_label, f)

def reorganize_SK(project_dir, band, save_dir):
    """Slice the sparrKULee short-story EEG runs into fixed windows and pickle them.

    Subjects ``sub-007`` .. ``sub-016`` are read from
    ``<project_dir>/<sub>/ses-shortstories01`` and saved as ``S0`` .. ``S9``.
    For each of runs 1-8 the function:

    1. unpickles ``<prefix>-_eeg.pickle``;
    2. divides it by 1e3;
    3. wraps it in an ``mne.io.RawArray`` with 64 EEG channels at 1000 Hz;
    4. optionally band-filters it according to ``band``;
    5. cuts 4 blocks of 50 windows x 500 samples from the first 64 channels.
       The channels are duplicated side by side to give 128 columns.

    Run ``r`` fills trials ``4*(r-1)`` .. ``4*r-1``. The result is written to
    ``<save_dir>/S<i>.pkl`` as
    ``{'EEG': float32 array (40, 50, 500, 128), 'label': [0..39]}``.

    NOTE(review): only 8 runs x 4 blocks = 32 trials are filled. Trials
    32-39 stay all-zero even though ``label`` lists 40 entries -- confirm
    this is intended.

    Args:
        project_dir: Root of the sparrKULee dataset.
        band: ``'delta'`` (<4 Hz), ``'theta'`` (4-8), ``'alpha'`` (8-12),
            ``'beta'`` (12-32), ``'low_gamma'`` (32-45) or ``'high_gamma'``
            (55-95). Any other value leaves the data unfiltered.
        save_dir: Output directory; created if it does not exist.
    """
    os.makedirs(save_dir, exist_ok=True)

    # (l_freq, h_freq) for raw.filter. A low edge of None asks MNE explicitly
    # for a pure low-pass filter.
    band_edges = {
        'delta': (None, 4),
        'theta': (4, 8),
        'alpha': (8, 12),
        'beta': (12, 32),
        'low_gamma': (32, 45),
        'high_gamma': (55, 95),
    }
    edges = band_edges.get(band)  # None -> no filtering (unknown band)

    # Window layout in samples at the 1000 Hz rate given to create_info below:
    # 10000-sample lead-in, 35000 samples per block, 500 samples per window.
    n_trials, n_pics, win = 40, 50, 500
    blocks_per_run = 4
    lead_in, block_stride = 10000, 35000

    # Fix the global NumPy RNG seed; nothing below draws from it.
    np.random.seed(2024)

    out_names = [f'S{i}' for i in range(10)]
    raw_names = [f'sub-{i:03d}' for i in range(7, 17)]

    for out_name, sub in zip(out_names, raw_names):
        eeg_sub_dir = op.join(project_dir, sub, 'ses-shortstories01')

        # Allocate float32 up front: same values as building float64 and
        # casting afterwards, at about a third of the peak memory.
        eegdata = np.zeros((n_trials, n_pics, win, 128), dtype=np.float32)

        for run in range(1, 9):
            name_prefix = f'{sub}_ses-shortstories01_task-listeningActive_run-{run:02d}_desc-preproc-audio'
            eeg_fname = op.join(eeg_sub_dir, f'{name_prefix}-_eeg.pickle')
            # NOTE(security): pickle.load can execute arbitrary code stored in
            # the file; only run this on a trusted copy of the dataset.
            with open(eeg_fname, 'rb') as f:
                eeg_data = pickle.load(f)
            eeg_data = eeg_data / 1e3

            # RawArray expects (n_channels, n_times). Presumably the pickle
            # holds a (64, n_times) array -- confirm.
            raw = mne.io.RawArray(
                eeg_data, mne.create_info(ch_names=64, sfreq=1000, ch_types='eeg'))
            if edges is not None:
                raw.filter(*edges)

            # (n_times, n_channels) after the transpose.
            eeg = raw.get_data().T

            first_trial = (run - 1) * blocks_per_run
            for blk in range(blocks_per_run):
                for pic in range(n_pics):
                    start = lead_in + blk * block_stride + pic * win
                    window = eeg[start:start + win, :64]
                    # Duplicate the 64 channels to fill the 128-column layout.
                    eegdata[first_trial + blk, pic] = np.concatenate(
                        (window, window), axis=1)

        eeg_data_label = {'EEG': eegdata, 'label': list(range(n_trials))}
        with open(op.join(save_dir, f'{out_name}.pkl'), 'wb') as f:
            pickle.dump(eeg_data_label, f)