import torch
import os
import numpy as np

# Merge the per-subject EEG test files (test_0.pt ... test_19.pt) into two
# NumPy arrays saved next to them:
#   eeg_x.npy - all samples, stacked along axis 0 with their last two axes swapped
#   eeg_y.npy - one row per sample: [label, person_id, position_id]

# Grouping of the 20 subject indices into four sets of five.
# NOTE(review): not used anywhere in this script - presumably consumed by
# downstream domain-split code; confirm before removing.
domain_divide = [[0,1,2,3,4],[5,6,7,8,9],[10,11,12,13,14],[15,16,17,18,19]]

data_path = '../data/EEG'
num_subjects = 20

# Per-file pieces are collected in lists and concatenated once after the loop,
# which avoids re-copying the whole accumulated array on every iteration.
sample_chunks = []
info_chunks = []
for i in range(num_subjects):
    test = torch.load(os.path.join(data_path, 'test_'+str(i)+'.pt'))
    samples = test['samples']
    if i == 0:
        # Show the raw shapes of the first file once, as a quick sanity check.
        print(samples.shape)
        print(test['labels'].shape)
    labels = test['labels'].reshape(-1,1)

    num_rows = samples.shape[0]
    # Every sample of file i is tagged with subject id i (stored as float).
    person_id = np.full(num_rows, float(i))
    # The position id is the constant 0 for every sample in this dataset.
    position_id = np.zeros(num_rows)
    print(labels.shape, person_id.shape, position_id.shape)

    # Columns: label, person_id, position_id.
    info = np.concatenate((labels, person_id.reshape(-1,1), position_id.reshape(-1,1)), axis=1)

    # Swap the last two axes of the 3-D sample array.
    # NOTE(review): assumes samples is a NumPy array - a torch.Tensor would
    # reject a three-argument transpose(); confirm against the .pt contents.
    sample_chunks.append(samples.transpose(0,2,1))
    info_chunks.append(info)

data = np.concatenate(sample_chunks, axis=0)
y = np.concatenate(info_chunks, axis=0)

print(data.shape)
print(y.shape)
np.save(os.path.join(data_path, 'eeg_x.npy'), data)
np.save(os.path.join(data_path, 'eeg_y.npy'), y)
    

