import h5py
import numpy as np
from pathlib import Path
import scipy.io as sio
import os
import mne
import os.path as osp
class h5Dataset:
    """Small convenience wrapper around one ``<name>.hdf5`` file opened with h5py.

    The file is opened in ``__init__`` and stays open until :meth:`save` is
    called.  The object can also be used as a context manager, in which case
    the file is closed when the ``with`` block exits, even if it raised.
    """

    def __init__(self, path:Path, name:str,mode:str='a') -> None:
        """Open ``path/<name>.hdf5``.

        Args:
            path: Directory containing the file (``Path`` or ``str``).
            name: File stem; ``.hdf5`` is appended.
            mode: ``'a'`` or ``'r'``; forwarded unchanged to ``h5py.File``.

        Raises:
            ValueError: If ``mode`` is neither ``'a'`` nor ``'r'``.
        """
        self.__name = name
        if mode not in ('a', 'r'):
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError(f'can not set mode to {mode}, only "a" or "r"')
        # Path(...) lets callers pass a plain string directory as well.
        self.__f = h5py.File(Path(path) / f'{name}.hdf5', mode)

    def __enter__(self) -> 'h5Dataset':
        """Return ``self`` so the wrapper can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the file on leaving the ``with`` block; exceptions propagate."""
        self.save()

    def get_group_names(self):
        """Return the names of the top-level entries of the file as a list."""
        return list(self.__f.keys())

    def get_dataset_names_from_group(self,grpName:str):
        """Return the names of the entries stored under group ``grpName``."""
        return list(self.__f[grpName].keys())

    def get_group(self,grpName:str):
        """Return the object stored under the top-level key ``grpName``."""
        return self.__f[grpName]

    def get_dataset_from_group(self,grpName:str,dsName:str):
        """Return the object stored at ``grpName``/``dsName``."""
        return self.__f[grpName][dsName]

    def addGroup(self, grpName:str):
        """Print the current group names, then create and return group ``grpName``."""
        print(self.get_group_names())
        return self.__f.create_group(grpName)

    def addDataset(self, grp:'h5py.Group', dsName:str, arr:np.ndarray, chunks:tuple=None):
        """Create dataset ``dsName`` in ``grp`` from ``arr`` and return it.

        Args:
            grp: Group that receives the dataset.
            dsName: Name of the new dataset.
            arr: Data to store.
            chunks: Optional chunk shape; when ``None`` no ``chunks`` argument
                is passed to ``create_dataset``.
        """
        if chunks is None:
            return grp.create_dataset(dsName, data=arr)
        return grp.create_dataset(dsName, data=arr, chunks=chunks)

    def addAttributes(self, src:'h5py.Dataset|h5py.Group', attrName:str, attrValue):
        """Set attribute ``attrName`` of ``src`` to ``attrValue``."""
        src.attrs[f'{attrName}'] = attrValue

    def save(self):
        """Close the underlying file."""
        self.__f.close()

    @property
    def name(self):
        """File stem given at construction time."""
        return self.__name


def normalize(data):
    """Z-score ``data`` along its last axis.

    The mean and standard deviation are taken over ``axis=-1`` and kept as a
    size-1 axis so they broadcast back onto ``data``.  Wherever the standard
    deviation is exactly zero it is replaced by 1, so constant rows come out
    as zeros instead of triggering a division by zero.

    Args:
        data: Array-like with at least one axis.

    Returns:
        ``(data - mean) / std`` with the same shape as ``data``.
    """
    centre = np.mean(data, axis=-1, keepdims=True)
    spread = np.std(data, axis=-1, keepdims=True)

    # Guard the division below: a zero spread is swapped for 1.
    flat_rows = spread == 0
    spread[flat_rows] = 1

    return (data - centre) / spread


class SEED_ori():
    """Labelled, fixed-length EEG windows cut from SEED ``.mat`` recordings.

    For every file the first nine non-metadata variables are treated as trial
    arrays.  Each one is cut along axis 1 into consecutive, non-overlapping
    windows of ``time_window`` samples (a trailing partial window is dropped)
    and every window inherits the label of its trial from ``label.mat``.
    All windows are stacked and passed through ``normalize``.

    Attributes:
        eeg: Stacked, normalised windows; ``eeg[k]`` is one window.
        label: Float array with one label per window; a ``-1`` in
            ``label.mat`` is stored as ``2``.
        len: Number of windows.
        time_window: Window length in samples.
    """

    def __init__(self, file_list, time_window=200, type='train'):
        """Load, window and normalise every recording in ``file_list``.

        Args:
            file_list: File names relative to ``../data/SEED/Preprocessed_EEG``.
            time_window: Window length in samples; must be positive.
            type: Unused; kept so existing call sites keep working.

        Raises:
            ValueError: If ``time_window`` is not positive or no window could
                be extracted.
        """
        eeg_path = '../data/SEED/Preprocessed_EEG'
        n_trials = 9  # only the first nine trial arrays of each file are read
        if time_window <= 0:
            raise ValueError(f'time_window must be positive, got {time_window}')
        self.time_window = time_window

        # label.mat is shared by all recordings, so it is read once instead of
        # once per trial.  astype() returns a copy, so remapping is safe.
        trial_labels = sio.loadmat(eeg_path + '/label.mat')['label'][0].astype('float')
        trial_labels[trial_labels == -1] = 2

        X_train = []
        y_train = []
        for file in file_list:
            print(file)
            data = sio.loadmat(os.path.join(eeg_path, file))
            # Skip loadmat's bookkeeping entries ('__header__', ...) by name
            # rather than by position.
            trial_keys = [key for key in data if not key.startswith('__')]

            for i in range(n_trials):
                eeg = data[trial_keys[i]].astype('float')
                eeg_length = eeg.shape[1]
                # Same windows as "while start + window <= length".
                for start_ind in range(0, eeg_length - time_window + 1, time_window):
                    X_train.append(eeg[:, start_ind: start_ind + time_window])
                    y_train.append(trial_labels[i])

        if not X_train:
            raise ValueError('no EEG windows were extracted; check file_list and time_window')

        self.eeg = normalize(np.stack(X_train, axis=0))
        self.label = np.array(y_train)
        self.len = self.eeg.shape[0]

    def __len__(self):
        """Return the number of windows."""
        return self.len

    def __getitem__(self, item):
        """Return ``(window, label)`` for index ``item``."""
        return self.eeg[item], self.label[item]


class SEEDIV_ori():
    """Unlabelled, fixed-length EEG windows cut from SEED-IV raw ``.mat`` files.

    Every entry of ``file_list`` is a directory below
    ``../data/SEED_IV/eeg_raw_data``.  For each ``.mat`` file in it the first
    sixteen non-metadata variables are cut along axis 1 into consecutive,
    non-overlapping windows of ``time_window`` samples (a trailing partial
    window is dropped).  All windows are stacked and passed through
    ``normalize``.

    Attributes:
        eeg: Stacked, normalised windows; ``eeg[k]`` is one window.
        len: Number of windows.
        time_window: Window length in samples.
    """

    def __init__(self, file_list, time_window=200, type='train'):
        """Load, window and normalise every recording below ``file_list``.

        Args:
            file_list: Directory names relative to ``../data/SEED_IV/eeg_raw_data``.
            time_window: Window length in samples; must be positive.
            type: Unused; kept so existing call sites keep working.

        Raises:
            ValueError: If ``time_window`` is not positive or no window could
                be extracted.
        """
        eeg_path = '../data/SEED_IV/eeg_raw_data'
        n_trials = 16  # only the first sixteen trial arrays of each file are read
        if time_window <= 0:
            raise ValueError(f'time_window must be positive, got {time_window}')
        self.time_window = time_window

        X_train = []
        print(file_list)
        for file in file_list:
            print(file)
            file_path = os.path.join(eeg_path, file)
            # os.listdir order is arbitrary; sorting keeps the sample order
            # reproducible across machines.
            for new_file in sorted(os.listdir(file_path)):
                data = sio.loadmat(os.path.join(file_path, new_file))
                # Skip loadmat's bookkeeping entries ('__header__', ...) by
                # name rather than by position.
                trial_keys = [key for key in data if not key.startswith('__')]

                for i in range(n_trials):
                    eeg = data[trial_keys[i]].astype('float')
                    eeg_length = eeg.shape[1]
                    # Same windows as "while start + window <= length".
                    for start_ind in range(0, eeg_length - time_window + 1, time_window):
                        X_train.append(eeg[:, start_ind: start_ind + time_window])

        if not X_train:
            raise ValueError('no EEG windows were extracted; check file_list and time_window')

        self.eeg = normalize(np.stack(X_train, axis=0))
        self.len = self.eeg.shape[0]

    def __len__(self):
        """Return the number of windows."""
        return self.len

    def __getitem__(self, item):
        """Return the window at index ``item``."""
        return self.eeg[item]


class SEEDVII_ori():
    """Unlabelled, fixed-length EEG windows cut from SEED-VII ``.mat`` files.

    Every non-metadata variable of every file is cut along axis 1 into
    consecutive, non-overlapping windows of ``time_window`` samples (a
    trailing partial window is dropped).  All windows are stacked and passed
    through ``normalize``.

    Attributes:
        eeg: Stacked, normalised windows; ``eeg[k]`` is one window.
        len: Number of windows.
        time_window: Window length in samples.
    """

    def __init__(self, file_list, time_window=200):
        """Load, window and normalise every recording in ``file_list``.

        Args:
            file_list: File names relative to ``../data/SEED_VII/EEG_preprocessed``.
            time_window: Window length in samples; must be positive.

        Raises:
            ValueError: If ``time_window`` is not positive or no window could
                be extracted.
        """
        eeg_path = '../data/SEED_VII/EEG_preprocessed'
        if time_window <= 0:
            raise ValueError(f'time_window must be positive, got {time_window}')
        self.time_window = time_window

        X_train = []
        for file in file_list:
            data = sio.loadmat(os.path.join(eeg_path, file))
            # Everything except loadmat's '__header__' / '__version__' /
            # '__globals__' style bookkeeping entries is a recording.
            for ind in (key for key in data if not key.startswith('__')):
                eeg = data[ind]
                eeg_length = eeg.shape[1]
                # Inclusive upper bound: a recording whose length is an exact
                # multiple of the window keeps its last full window, matching
                # SEED_ori and SEEDIV_ori.
                for start_ind in range(0, eeg_length - time_window + 1, time_window):
                    X_train.append(eeg[:, start_ind: start_ind + time_window])

        if not X_train:
            raise ValueError('no EEG windows were extracted; check file_list and time_window')

        self.eeg = normalize(np.stack(X_train, axis=0))
        self.len = self.eeg.shape[0]

    def __len__(self):
        """Return the number of windows."""
        return self.len

    def __getitem__(self, item):
        """Return the window at index ``item``."""
        return self.eeg[item]

class ERP2015a():
    """Unlabelled EEG windows built from the ``.mat`` files in ``../data/ERP2015a``.

    Channel names are read from ``Header.mat`` after removing its first and
    last two columns; the same columns are removed from each file's ``DATA``
    matrix.  Each recording is transposed, filtered with
    ``mne.filter.filter_data`` (0.1-75, ``sfreq=512``) and
    ``mne.filter.notch_filter`` (50, ``Fs=512``), cut along axis 1 into
    consecutive non-overlapping windows of ``time_window`` samples, and the
    stacked windows are passed through ``normalize``.

    Attributes:
        channel: Channel names as an array of h5py string dtype.
        eeg: Stacked, normalised windows.
        len: Number of windows.
        time_window: Window length in samples.
    """

    def __init__(self, file_list, time_window=512):
        """Load, filter, window and normalise every file in ``file_list``."""
        root = '../data/ERP2015a'
        self.time_window = time_window

        # Channel names: same column selection as applied to DATA below.
        header = sio.loadmat(os.path.join(root, 'Header.mat'))['Header']
        header = np.delete(header, [0, -2, -1], axis=1)
        names = [header[0][col][0] for col in range(header.shape[1])]
        self.channel = np.array(names, dtype=h5py.string_dtype())

        segments = []
        for file in file_list:
            mat = sio.loadmat(os.path.join(root, file))
            signal = np.delete(mat['DATA'], [0, -2, -1], axis=1)
            signal = signal.transpose(1, 0)
            signal = mne.filter.filter_data(signal, sfreq=512, l_freq=0.1, h_freq=75)
            signal = mne.filter.notch_filter(signal, Fs=512, freqs=50)

            n_samples = signal.shape[1]
            offset = 0
            while True:
                stop = offset + self.time_window
                if stop > n_samples:
                    break  # remaining tail is shorter than one window
                segments.append(signal[:, offset: stop])
                offset = stop

        self.eeg = normalize(np.stack(segments, axis=0))
        self.len = self.eeg.shape[0]

    def __len__(self):
        """Return the number of windows."""
        return self.len

    def __getitem__(self, item):
        """Return the window at index ``item``."""
        return self.eeg[item]

class individual_imagery():
    """Unlabelled EEG windows from the ``../data/Individual imagery`` files.

    Each file contributes two runs (``data[0][0]`` and ``data[0][1]``).  A
    run's ``X`` matrix is transposed, filtered with
    ``mne.filter.filter_data`` (0.1-75, ``sfreq=256``) and cut into
    consecutive windows of ``time_window`` samples whose start positions lie
    in ``[256 * 60, n_samples - 256 * 60)``.  The stacked windows are passed
    through ``normalize``.

    Attributes:
        eeg: Stacked, normalised windows.
        len: Number of windows.
        time_window: Window length in samples.
    """

    def __init__(self, file_list, time_window=256):
        """Load, filter, window and normalise every file in ``file_list``."""
        self.time_window = time_window
        root = '../data/Individual imagery'
        margin = 256 * 60  # samples skipped at the start and kept clear at the end

        segments = []
        for file in file_list:
            mat = sio.loadmat(os.path.join(root, file))
            # Both runs go through exactly the same steps, run 0 first.
            for run in (0, 1):
                signal = mat['data'][0][run]['X'][0][0].transpose(1, 0)
                signal = mne.filter.filter_data(signal, sfreq=256, l_freq=0.1, h_freq=75)
                last_start = signal.shape[1] - margin
                offset = margin
                while offset < last_start:
                    segments.append(signal[:, offset: offset + time_window])
                    offset += self.time_window

        self.eeg = normalize(np.stack(segments, axis=0))
        self.len = self.eeg.shape[0]

    def __len__(self):
        """Return the number of windows."""
        return self.len

    def __getitem__(self, item):
        """Return the window at index ``item``."""
        return self.eeg[item]
            
class center_spller():
    """Unlabelled EEG windows from the ``../data/Speller`` recordings.

    For each file the ``X`` matrix is transposed and cut into consecutive
    windows of ``time_window`` samples, starting at the first entry of the
    file's ``trial`` field and continuing while
    ``start + time_window < n_samples - 60 * time_window``.  The stacked
    windows are passed through ``normalize``.  Channel names are taken from
    the last file that was read.

    Attributes:
        channel: Channel names as an array of h5py string dtype.
        eeg: Stacked, normalised windows.
        len: Number of windows.
        time_window: Window length in samples.
    """

    def __init__(self, file_list, time_window=250):
        """Load, window and normalise every file in ``file_list``.

        Raises:
            ValueError: If ``file_list`` is empty, so no channel names or
                windows are available.
        """
        eeg_path = '../data/Speller'
        self.time_window = time_window
        tail = self.time_window * 60  # samples kept clear at the end of a recording

        X_train = []
        channels = None
        for file in file_list:
            data = sio.loadmat(os.path.join(eeg_path, file))['data']
            eeg = data[0][0]['X'][0][0].transpose(1, 0)
            eeg_length = eeg.shape[1]
            trial = data[0][0]['trial'][0][0]
            # loadmat may return a narrow integer dtype here; a Python int
            # cannot overflow while the index is advanced.
            start_ind = int(trial[0][0])
            while start_ind + self.time_window < eeg_length - tail:
                X_train.append(eeg[:, start_ind: start_ind + self.time_window])
                start_ind += self.time_window

            channels = data[0][0]['channels'][0][0]

        if channels is None:
            raise ValueError('file_list is empty: no recordings to load')

        self.channel = np.array(
            [channels[0][i][0] for i in range(channels.shape[1])],
            dtype=h5py.string_dtype(),
        )

        self.eeg = normalize(np.stack(X_train, axis=0))
        self.len = self.eeg.shape[0]

    def __len__(self):
        """Return the number of windows."""
        return self.len

    def __getitem__(self, item):
        """Return the window at index ``item``."""
        return self.eeg[item]


class RSVP():
    """Unlabelled EEG windows from the ``../data/RSVP`` recordings.

    Files whose transposed ``X`` matrix does not have 63 rows are skipped.
    The remaining recordings are cut into consecutive windows of
    ``time_window`` samples, starting at the first entry of the file's
    ``trial`` field and continuing while
    ``start + time_window < n_samples - 60 * time_window``.  The stacked
    windows are passed through ``normalize``.  Channel names are taken from
    the last file that was not skipped.

    Attributes:
        channel: Channel names as an array of h5py string dtype.
        eeg: Stacked, normalised windows.
        len: Number of windows.
        time_window: Window length in samples.
    """

    def __init__(self, file_list, time_window=200):
        """Load, window and normalise every usable file in ``file_list``.

        Raises:
            ValueError: If no file with 63 rows was found, so no channel
                names or windows are available.
        """
        self.time_window = time_window
        eeg_path = '../data/RSVP'
        expected_rows = 63
        tail = self.time_window * 60  # samples kept clear at the end of a recording

        X_train = []
        channels = None
        for file in file_list:
            data = sio.loadmat(os.path.join(eeg_path, file))['data']
            eeg = data[0][0]['X'][0][0].transpose(1, 0)
            if eeg.shape[0] != expected_rows:
                continue
            eeg_length = eeg.shape[1]
            trial = data[0][0]['trial'][0][0]
            # loadmat may return a narrow integer dtype here; a Python int
            # cannot overflow while the index is advanced.
            start_ind = int(trial[0][0])
            while start_ind + self.time_window < eeg_length - tail:
                X_train.append(eeg[:, start_ind: start_ind + self.time_window])
                start_ind += self.time_window

            channels = data[0][0]['channels'][0][0]

        if channels is None:
            raise ValueError('no recording with 63 rows found in file_list')

        self.channel = np.array(
            [channels[0][i][0] for i in range(channels.shape[1])],
            dtype=h5py.string_dtype(),
        )

        self.eeg = normalize(np.stack(X_train, axis=0))
        self.len = self.eeg.shape[0]

    def __len__(self):
        """Return the number of windows."""
        return self.len

    def __getitem__(self, item):
        """Return the window at index ``item``."""
        return self.eeg[item]

class musicBCI():
    """Unlabelled EEG windows from the ``../data/musicBCI`` recordings.

    For each file the last column of ``X`` is removed, the matrix is
    transposed and filtered with ``mne.filter.filter_data`` (0.1-75,
    ``sfreq=200``) and ``mne.filter.notch_filter`` (50, ``Fs=200``).
    Windows of ``time_window`` samples are taken from sample 4000 onwards
    while ``start + time_window < n_samples - 200 * 60``.  The stacked
    windows are passed through ``normalize``.  The first 63 entries of the
    last file's ``clab`` field become the channel names.

    Attributes:
        channel: Channel names as an array of h5py string dtype.
        eeg: Stacked, normalised windows.
        len: Number of windows.
        time_window: Window length in samples.
    """

    def __init__(self, file_list, time_window=200):
        """Load, filter, window and normalise every file in ``file_list``."""
        self.time_window = time_window
        root = '../data/musicBCI'
        first_sample = 4000
        tail = 200 * 60

        segments = []
        for file in file_list:
            mat = sio.loadmat(os.path.join(root, file))['data']
            signal = np.delete(mat[0][0]['X'], [-1], axis=1).transpose(1, 0)
            signal = mne.filter.filter_data(signal, sfreq=200, l_freq=0.1, h_freq=75)
            signal = mne.filter.notch_filter(signal, Fs=200, freqs=50)

            offset = first_sample
            limit = signal.shape[1] - tail
            while offset + self.time_window < limit:
                segments.append(signal[:, offset: offset + self.time_window])
                offset += self.time_window

            channels = mat[0][0]['clab']

        label_row = channels[0]
        self.channel = np.array(
            [label_row[i][0] for i in range(63)],
            dtype=h5py.string_dtype(),
        )

        self.eeg = normalize(np.stack(segments, axis=0))
        self.len = self.eeg.shape[0]

    def __len__(self):
        """Return the number of windows."""
        return self.len

    def __getitem__(self, item):
        """Return the window at index ``item``."""
        return self.eeg[item]


class SHU_ori():
    """Labelled, fixed-length EEG windows from the ``../data/SHU/mat`` files.

    Each file provides a ``data`` array and a ``labels`` array.  Every
    ``data[i]`` is cut along its last axis into consecutive, non-overlapping
    windows of ``time_window`` samples (a trailing partial window is
    dropped); each window is labelled ``labels[0][i] - 1``.  The stacked
    windows are passed through ``normalize``.

    Attributes:
        eeg: Stacked, normalised windows.
        label: One label per window.
        len: Number of windows.
        time_window: Window length in samples.
    """

    def __init__(self, file_list, time_window=250):
        """Load, window and normalise every file in ``file_list``.

        Args:
            file_list: File names relative to ``../data/SHU/mat``.
            time_window: Window length in samples; must be positive.

        Raises:
            ValueError: If ``time_window`` is not positive or no window could
                be extracted.
        """
        if time_window <= 0:
            raise ValueError(f'time_window must be positive, got {time_window}')
        self.time_window = time_window
        eeg_path = '../data/SHU/mat'

        X = []
        y = []
        # Each file is read exactly once.  The earlier "for session in
        # range(1, 6)" loop never used `session`, so it re-read the same file
        # and appended every window five times.
        for num, file in enumerate(file_list, start=1):
            print(num)  # progress: 1-based index of the current file
            data = sio.loadmat(osp.join(eeg_path, file))
            eeg = np.array(data['data'], dtype=np.float64)
            labels = data['labels']
            for i in range(eeg.shape[0]):
                # Same windows as "while start + window <= length".
                for start_ind in range(0, eeg.shape[-1] - time_window + 1, time_window):
                    X.append(eeg[i][:, start_ind: start_ind + time_window])
                    y.append(labels[0][i] - 1)

        if not X:
            raise ValueError('no EEG windows were extracted; check file_list and time_window')

        self.eeg = normalize(np.stack(X, axis=0))
        self.label = np.array(y)
        self.len = self.eeg.shape[0]

    def __len__(self):
        """Return the number of windows."""
        return self.len

    def __getitem__(self, item):
        """Return ``(window, label)`` for index ``item``."""
        return self.eeg[item], self.label[item]


class Attention_ori():
    """Labelled, fixed-length EEG windows from the ``../data/attention`` files.

    For each file the ``X`` matrix is transposed and rows 1 and 6 are
    removed.  For every pair of consecutive entries in ``trial`` the span
    between them is cut into windows of ``time_window`` samples while
    ``start + time_window < next_trial_start``; each window is labelled
    ``y[i] - 1``.  The last ``trial`` entry only serves as an end marker.
    The stacked windows are passed through ``normalize``.  Channel names come
    from the first 62 ``clab`` entries of the last file, minus the two
    removed rows.

    Attributes:
        channel: Channel names as an array of h5py string dtype.
        eeg: Stacked, normalised windows.
        label: One label per window.
        len: Number of windows.
        time_window: Window length in samples.
    """

    def __init__(self, file_list, time_window=200):
        """Load, window and normalise every file in ``file_list``.

        Raises:
            ValueError: If ``file_list`` is empty, so no channel names or
                windows are available.
        """
        self.time_window = time_window
        eeg_path = '../data/attention'
        dropped_rows = [1, 6]  # removed from the data and from the name list

        X = []
        y = []
        channels = None
        for file in file_list:
            data = sio.loadmat(osp.join(eeg_path, file))['data']
            eeg = data['X'][0][0].transpose(1, 0)
            eeg = np.delete(eeg, dropped_rows, axis=0)
            label = data['y'][0][0][0]
            # loadmat may return a narrow integer dtype here; Python ints
            # cannot overflow while the window start is advanced.
            trial = [int(t) for t in data['trial'][0][0][0]]

            for i in range(len(trial) - 1):
                start_ind = trial[i]
                while start_ind + self.time_window < trial[i + 1]:
                    X.append(eeg[:, start_ind: start_ind + self.time_window])
                    y.append(label[i] - 1)
                    start_ind += self.time_window
            channels = data['clab'][0][0]

        if channels is None:
            raise ValueError('file_list is empty: no recordings to load')

        self.channel = np.array(
            [channels[0][i][0] for i in range(62) if i not in dropped_rows],
            dtype=h5py.string_dtype(),
        )

        X = np.stack(X, axis=0)
        y = np.array(y)
        X = normalize(X)
        print(X.shape)

        self.eeg = X
        self.label = y
        self.len = self.eeg.shape[0]

    def __len__(self):
        """Return the number of windows."""
        return self.len

    def __getitem__(self, item):
        """Return ``(window, label)`` for index ``item``."""
        return self.eeg[item], self.label[item]

# Channel-name lists.  Each list is written as a whitespace-separated text
# block and split into a plain list of strings; element order is significant.
# Presumably each list follows the channel order of the matching recordings
# -- verify against the data files before relying on it.

seed_ch_list = """
    Fp1 Fpz Fp2
    AF3 AF4
    F7 F5 F3 F1 Fz F2 F4 F6 F8
    FT7 FC5 FC3 FC1 FCz FC2 FC4 FC6 FT8
    T7 C5 C3 C1 Cz C2 C4 C6 T8
    TP7 CP5 CP3 CP1 CPz CP2 CP4 CP6 TP8
    P7 P5 P3 P1 Pz P2 P4 P6 P8
    PO7 PO5 PO3 POz PO4 PO6 PO8
    CB1 O1 Oz O2 CB2
""".split()

attention_ch_list = """
    Iz I1 AF3 AF4 I2
    F7 F5 F3 F1 Fz F2 F4 F6 F8
    PO9 FC5 FC3 FC1 FCz FC2 FC4 FC6 PO10
    T7 C5 C3 C1 Cz C2 C4 C6 T8
    TP7 CP5 CP3 CP1 CPz CP2 CP4 CP6 TP8
    P9 P7 P5 P3 P1 Pz P2 P4 P6 P8 P10
    PO7 PO3 POz PO4 PO8
    O1 Oz O2
""".split()

center_ch_list = """
    Iz Fp2 I1 AF3 AF4 I2
    F9 F7 F5 F3 F1 Fz F2 F4 F6 F8 F10
    PO9 FC5 FC3 FC1 FCz FC2 FC4 FC6 PO10
    T7 C5 C3 C1 Cz C2 C4 C6 T8
    TP7 CP5 CP3 CP1 CPz CP2 CP4 CP6 TP8
    P9 P7 P5 P3 P1 Pz P2 P4 P6 P8 P10
    PO7 PO3 POz PO4 PO8
    O1 Oz O2
""".split()

ERP_ch_list = """
    FP1 FP2 AFz
    F7 F3 F4 F8
    FC5 FC1 FC2 FC6
    T7 C3 Cz C4 T8
    CP5 CP1 CP2 CP6
    P7 P3 Pz P4 P8
    PO7 O1 Oz O2 PO8
    PO9 PO10
""".split()

individual_ch_list = """
    AFz
    F7 F3 Fz F4 F8
    FC3 FCz FC4
    T3 C3 Cz C4 T4
    CP3 CPz CP4
    P7 P5 P3 P1 Pz P2 P4 P6 P8
    PO3 PO4
    O1 O2
""".split()

music_ch_list = """
    Fp1 Fp2
    AF7 AF3 AF4 AF8
    F9 F7 F5 F3 F1 Fz F2 F4 F6 F8 F10
    FT7 FC5 FC3 FC1 FCz FC2 FC4 FC6 FT8
    T7 C5 C3 C1 Cz C2 C4 C6 T8
    TP7 CP5 CP3 CP1 CPz CP2 CP4 CP6 TP8
    P9 P7 P5 P3 P1 Pz P2 P4 P6 P8 P10
    PO7 PO3 POz PO4 PO8
    O1 Oz O2
""".split()

RSVP_ch_list = """
    Fp1 Fp2
    AF3 AF4
    F9 F7 F5 F3 F1 Fz F2 F4 F6 F8 F10
    FT7 FC5 FC3 FC1 FCz FC2 FC4 FC6 FT8
    T7 C5 C3 C1 Cz C2 C4 C6 T8
    TP7 CP5 CP3 CP1 CPz CP2 CP4 CP6 TP8
    P9 P7 P5 P3 P1 Pz P2 P4 P6 P8 P10
    PO9 PO7 PO3 POz PO4 PO8 PO10
    O1 Oz O2
""".split()


shu_ch_list = """
    FP1 FP2
    FZ F3 F4 F7 F8
    FC1 FC2 FC5 FC6
    CZ C3 C4 T7 T8 A1 A2
    CP1 CP2 CP5 CP6
    PZ P3 P4 P7 P8
    PO3 PO4
    OZ O1 O2
""".split()

standard_channel_list = """
    Fp1 Fpz Fp2 Fp9 Fp10 Nz
    AF1 AF2 AFz AF3 AF4 AF5 AF6 AF7 AF8 AF9 AF10
    F1 F2 Fz F3 F4 F5 F6 F7 F8 F9 F10
    FC1 FC2 FCz FC3 FC4 FC5 FC6 FT7 FT8 FT9 FT10
    C1 C2 Cz C3 C4 C5 C6 T7 T8 T9 T10 I1 I2
    CP1 CP2 CPz CP3 CP4 CP5 CP6 TP7 TP8 TP9 TP10
    P1 P2 Pz P3 P4 P5 P6 P7 P8 P9 P10
    PO1 PO2 POz PO3 PO4 PO5 PO6 PO7 PO8 PO9 PO10
    O1 O2 Oz O9 O10 Iz CB1 CB2 A1 A2
""".split()



if __name__ == '__main__':
    # Disabled example: build the SHU windows and append them to
    # ../data/all_data.hdf5 as group 'SHU' with a 'channels' attribute.
    #
    # eeg_path = '../data/SHU/mat'
    # file_list = os.listdir(eeg_path)
    # data = SHU_ori(file_list)
    #
    # channels = shu_ch_list
    # channels = np.array(channels, dtype=h5py.string_dtype())
    #
    # print(data.eeg.shape)
    #
    # save_path = Path('../data')
    # h5_dataset = h5Dataset(save_path, 'all_data')
    # group = h5_dataset.addGroup('SHU')
    # ds = h5_dataset.addDataset(group, 'data', data.eeg, chunks=(1, 32, 250))
    # h5_dataset.addAttributes(ds, 'channels', channels)
    # h5_dataset.save()

    # standard_channel_list = [s.upper() for s in standard_channel_list]

    # Active part: open the combined file read-only and print, for every
    # top-level group, its name, the shape of its 'data' dataset and that
    # dataset's 'channels' attribute.
    with h5py.File("../data/all_data.hdf5", 'r') as f:
        group_names = list(f.keys())
        print(group_names)
        for group_name in group_names:
            print(group_name)
            print(f[group_name]['data'].shape)
            print(f[group_name]['data'].attrs['channels'])
            # Disabled: map each channel onto its index in standard_channel_list.
            # channels = [s.upper() for s in channels]
            # channel_index = [standard_channel_list.index(item) for item in channels]
            # print(channel_index)
