import torch
import config as cfg
from torch.utils.data import Dataset

class end2end_dataset(Dataset):
    """Map-style dataset pairing each EEG sample with its class label.

    Args:
        eeg: Indexable collection of samples (e.g. a NumPy array, tensor or
            list); the first axis indexes samples.
        label: Indexable collection holding one label per sample, aligned
            with ``eeg`` by position.

    Raises:
        ValueError: If ``eeg`` and ``label`` have different lengths.
    """

    def __init__(self, eeg, label):
        # Fail fast on misaligned inputs instead of an IndexError mid-epoch
        # (label too short) or silently ignored trailing labels (too long).
        if len(eeg) != len(label):
            raise ValueError(
                "eeg and label must have the same length, "
                f"got {len(eeg)} and {len(label)}"
            )
        self.eeg = eeg
        self.label = label

    def __len__(self):
        # len() equals shape[0] for arrays/tensors and also supports lists.
        return len(self.eeg)

    def __getitem__(self, index):
        """Return ``(x, y)``: sample ``index`` as float32 and its label as int64."""
        x = self.eeg[index]
        y = self.label[index]  # one label per sample, shared across its whole length
        return self._to_tensor(x, torch.float32), self._to_tensor(y, torch.long)

    @staticmethod
    def _to_tensor(value, dtype):
        """Return a fresh tensor copy of ``value`` cast to ``dtype``.

        ``torch.tensor`` warns when given an existing tensor, so tensors are
        detached and copied via ``.to(copy=True)``; anything else goes through
        ``torch.tensor``, which always copies.
        """
        if isinstance(value, torch.Tensor):
            return value.detach().to(dtype=dtype, copy=True)
        return torch.tensor(value, dtype=dtype)

class domain_dataset(Dataset):
    """Map-style dataset yielding an EEG sample, its class label and its domain id.

    Args:
        eeg: Indexable collection of samples (e.g. a NumPy array, tensor or
            list); the first axis indexes samples.
        label: Indexable collection holding one class label per sample,
            aligned with ``eeg`` by position.
        domain: Indexable collection holding one integer domain id per sample,
            aligned with ``eeg`` by position.

    Raises:
        ValueError: If ``eeg``, ``label`` and ``domain`` differ in length.
    """

    def __init__(self, eeg, label, domain):
        # Fail fast on misaligned inputs instead of an IndexError mid-epoch
        # or silently ignored trailing labels/domains.
        lengths = (len(eeg), len(label), len(domain))
        if len(set(lengths)) != 1:
            raise ValueError(
                "eeg, label and domain must have the same length, "
                f"got {lengths[0]}, {lengths[1]} and {lengths[2]}"
            )
        self.eeg = eeg
        self.label = label
        self.domain = domain

    def __len__(self):
        # len() equals shape[0] for arrays/tensors and also supports lists.
        return len(self.eeg)

    def __getitem__(self, index):
        """Return ``(x, y, z)``: sample as float32, label and domain as int64."""
        x = self.eeg[index]
        y = self.label[index]  # one label per sample, shared across its whole length
        z = self.domain[index]
        return (
            self._to_tensor(x, torch.float32),
            self._to_tensor(y, torch.long),
            self._to_tensor(z, torch.long),
        )

    @staticmethod
    def _to_tensor(value, dtype):
        """Return a fresh tensor copy of ``value`` cast to ``dtype``.

        ``torch.tensor`` warns when given an existing tensor, so tensors are
        detached and copied via ``.to(copy=True)``; anything else goes through
        ``torch.tensor``, which always copies.
        """
        if isinstance(value, torch.Tensor):
            return value.detach().to(dtype=dtype, copy=True)
        return torch.tensor(value, dtype=dtype)

