import torch
from attrdict import AttrDict
import numpy as np
import scipy.io as sio
from sklearn import preprocessing
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
import h5py
import os.path as osp

class SEED_ori():
    """Fixed-length EEG windows cut from one SEED recording file.

    The recording's trials 0-8 form the training split and trials 9-14 the
    test split. Each trial array is cut along its second axis into
    consecutive, non-overlapping windows of ``time_window`` samples (a
    trailing remainder shorter than one window is dropped), and every window
    is z-scored along its last axis with ``normalize``.

    Per-trial labels are read from ``label.mat`` in ``eeg_path``; a label of
    ``-1`` is remapped to ``2``.

    Args:
        file: Name of the recording file inside ``eeg_path``.
        time_window: Window length in samples; must be positive.
        type: ``'train'`` selects trials 0-8; any other value selects
            trials 9-14.
        eeg_path: Directory containing ``file`` and ``label.mat``.

    Raises:
        ValueError: If ``time_window`` is not positive.

    Attributes:
        eeg: Tensor of stacked windows, one window per leading index.
        label: Float64 tensor holding one label per window.
        len: Number of windows.
    """

    # Trial indices (into the file's data variables, in file order) per split.
    _TRAIN_TRIALS = range(0, 9)
    _TEST_TRIALS = range(9, 15)

    def __init__(self, file, time_window=200, type='train',
                 eeg_path='../data/SEED/Preprocessed_EEG'):
        if time_window <= 0:
            # A non-positive step would never advance through the trial.
            raise ValueError(f'time_window must be positive, got {time_window}')
        self.time_window = time_window

        data = sio.loadmat(os.path.join(eeg_path, file))
        # loadmat adds '__header__', '__version__' and '__globals__' entries;
        # drop them by name rather than assuming they are the first three keys.
        keys = [key for key in data if not key.startswith('__')]
        # The label file is the same for every trial, so read it only once.
        labels = sio.loadmat(os.path.join(eeg_path, 'label.mat'))['label'][0]

        # Only the requested split is built: normalize() works on each window
        # independently, so the other split cannot influence the result.
        trials = self._TRAIN_TRIALS if type == 'train' else self._TEST_TRIALS
        X, y = [], []
        for i in trials:
            eeg = data[keys[i]].astype('float')
            label = labels[i].astype('float')
            if label == -1:
                label = 2.0
            windows = self._split_windows(eeg, time_window)
            X.extend(windows)
            y.extend([label] * len(windows))

        X = normalize(np.stack(X, axis=0))
        y = np.array(y)

        self.eeg = torch.from_numpy(X)
        self.label = torch.from_numpy(y)
        self.len = self.eeg.shape[0]

    @staticmethod
    def _split_windows(eeg, time_window):
        """Return consecutive, non-overlapping ``time_window``-wide slices of ``eeg``'s second axis."""
        last_start = eeg.shape[1] - time_window
        return [eeg[:, start:start + time_window]
                for start in range(0, last_start + 1, time_window)]

    def __len__(self):
        """Return the number of windows."""
        return self.len

    def __getitem__(self, item):
        """Return the ``(window, label)`` pair at index ``item``."""
        return self.eeg[item], self.label[item]


class Motor_imagery_ori():
    """Labelled EEG windows cut after the trial onsets of one motor-imagery session.

    Loads ``A0{file}T.mat`` from ``<data_root>/train`` when ``type == 'train'``
    and ``A0{file}E.mat`` from ``<data_root>/val_test`` otherwise. From each of
    six runs it keeps the first 22 columns of the run's ``X`` matrix
    (transposed so those columns become rows). For every trial onset it takes
    four consecutive ``time_window``-sample windows, the first starting
    ``2 * time_window`` samples after the onset. Each window is labelled with
    the trial's ``y`` value minus one and is z-scored along its last axis with
    ``normalize``.

    Args:
        file: Subject identifier, formatted into the file name as ``A0{file}``.
        time_window: Window length in samples.
        type: ``'train'`` for the ``T`` file, anything else for the ``E`` file.
        data_root: Directory containing the ``train`` and ``val_test``
            sub-directories.
    """

    def __init__(self, file, time_window=250, type='train',
                 data_root='../data/Motor_imagery'):
        self.time_window = time_window
        if type == 'train':
            mat_path = osp.join(data_root, 'train', f'A0{file}T.mat')
        else:
            mat_path = osp.join(data_root, 'val_test', f'A0{file}E.mat')
        data = sio.loadmat(mat_path)

        # Leading entries of data['data'] are skipped: 3 for every file except
        # subject 4's training file, where only 1 is skipped.
        # NOTE(review): presumably those leading runs carry no trials and A04T
        # has fewer of them — confirm against the dataset description.
        first_run = 1 if (type == 'train' and file == 4) else 3
        runs = data['data'][0][first_run:]

        X, y = self._extract_windows(runs, time_window)
        X = np.stack(X, axis=0)
        y = np.array(y)
        X = normalize(X)

        self.eeg = torch.from_numpy(X)
        self.label = torch.from_numpy(y)
        self.len = self.eeg.shape[0]

    @staticmethod
    def _extract_windows(runs, time_window):
        """Cut the labelled windows from the first six entries of ``runs``.

        Returns:
            Tuple ``(windows, labels)`` of Python lists, one entry per window.
        """
        windows, labels = [], []
        for i in range(6):
            run = runs[i]
            # Keep the first 22 columns and make them the leading axis.
            # NOTE(review): presumably these are the EEG channels — confirm.
            signal = run['X'][0][0][:, :22].transpose(1, 0)
            onsets = run['trial'][0][0]
            classes = run['y'][0][0]
            for j in range(len(onsets)):
                # Windows k = 2..5, measured in units of time_window after onset.
                for k in range(2, 6):
                    start = onsets[j][0] + k * time_window
                    windows.append(signal[:, start:start + time_window])
                    labels.append(classes[j][0] - 1)
        return windows, labels

    def __len__(self):
        """Return the number of windows."""
        return self.len

    def __getitem__(self, item):
        """Return the ``(window, label)`` pair at index ``item``."""
        return self.eeg[item], self.label[item]


class monitor_erp():
    """Unlabelled fixed-length EEG windows from one monitoring-ERP session file.

    Loads ``Subject0{file}_s1.mat`` from ``eeg_path`` when ``type == 'train'``
    and ``Subject0{file}_s2.mat`` otherwise. For each of the first 10 runs the
    run's ``eeg`` matrix is transposed. The stretch between every pair of
    consecutive trigger positions (``header.EVENT.POS``) is then cut along the
    second axis into non-overlapping ``time_window``-sample windows. Windows
    are z-scored along their last axis with ``normalize``. Items are windows
    only; no labels are produced.

    Args:
        file: Subject identifier, formatted into the file name.
        time_window: Window length in samples; must be positive.
        type: ``'train'`` for session ``s1``, anything else for ``s2``.
        eeg_path: Directory containing the subject files.

    Raises:
        ValueError: If ``time_window`` is not positive.
    """

    def __init__(self, file, time_window=512, type='train',
                 eeg_path='../data/monitor_erp'):
        if time_window <= 0:
            # A non-positive step would never advance the window loop below.
            raise ValueError(f'time_window must be positive, got {time_window}')
        self.time_window = time_window
        session = 's1' if type == 'train' else 's2'
        data = sio.loadmat(osp.join(eeg_path, f'Subject0{file}_{session}.mat'))

        X = []
        runs = data['run'][0]
        for i in range(10):
            eeg = runs[i]['eeg'][0][0].transpose(1, 0)
            event = runs[i]['header'][0][0]['EVENT'][0][0]
            trigger = event['POS'][0][0]
            # Segment between consecutive triggers; samples after the last
            # trigger are not used.
            for j in range(trigger.shape[0] - 1):
                start_ind = trigger[j][0]
                end_ind = trigger[j + 1][0]
                # NOTE(review): strict `<` drops a window that would end
                # exactly at the next trigger — confirm whether intended.
                while start_ind + time_window < end_ind:
                    X.append(eeg[:, start_ind: start_ind + time_window])
                    start_ind += time_window

        X = normalize(np.stack(X, axis=0))

        self.eeg = torch.from_numpy(X)
        self.len = self.eeg.shape[0]

    def __len__(self):
        """Return the number of windows."""
        return self.len

    def __getitem__(self, item):
        """Return the window at index ``item``."""
        return self.eeg[item]
      


            
def normalize(data):
    """Z-score ``data`` along its last axis.

    Every 1-D slice along the last axis (for the windows built in this file,
    one row of one window) is shifted to zero mean and scaled to unit
    standard deviation. Slices whose values are all identical come back as
    zeros rather than being divided by a (near-)zero standard deviation.

    Args:
        data: Array-like with at least one dimension.

    Returns:
        Array of the same shape; floating-point inputs keep their dtype.
    """
    data = np.asarray(data)
    mean = np.mean(data, axis=-1, keepdims=True)
    std = np.std(data, axis=-1, keepdims=True)

    # A constant slice should have std == 0, but rounding in the mean can
    # leave a tiny non-zero std (e.g. for a slice of 0.1s). Dividing by it
    # would inflate the rounding residue to values of order +-1, so detect
    # constant slices exactly by comparing against their first element.
    constant = np.all(data == data[..., :1], axis=-1, keepdims=True) | (std == 0)
    std[constant] = 1

    normalized = (data - mean) / std
    np.copyto(normalized, 0.0, where=constant)
    return normalized


if __name__ == '__main__':
    # Quick environment sanity check: report which PyTorch build is installed.
    torch_version = torch.__version__
    print(torch_version)
        
 



        
            

   


