import torch
import os
import pickle

class TUHLoader(torch.utils.data.Dataset):
    """Map-style dataset that reads one pickled sample per file in a directory.

    Every entry found in ``root`` is treated as a pickle file holding a
    mapping with an ``"X"`` entry (the input data) and, when ``target`` is
    enabled, a ``"y"`` entry (the label).

    Args:
        root: Directory whose entries are the pickled samples. All entries
            are used; nothing is filtered by extension or file type.
        target: If true, ``__getitem__`` returns ``(X, y)``; otherwise it
            returns only ``X``.

    Raises:
        OSError: If ``root`` cannot be listed (e.g. it does not exist).
    """

    def __init__(self, root: str, target: bool = False):
        super().__init__()
        self.root = root
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.files = sorted(os.listdir(root))
        self.target = target

    def __len__(self) -> int:
        """Return the number of entries found in ``root``."""
        return len(self.files)

    def __getitem__(self, index):
        """Load the sample stored in the ``index``-th file.

        Args:
            index: Position in the sorted file listing.

        Returns:
            ``X`` as a ``torch.FloatTensor``, or the tuple ``(X, y)`` when
            the dataset was created with ``target=True``. ``y`` is returned
            exactly as it was stored in the pickle (no tensor conversion).

        Raises:
            IndexError: If ``index`` is outside the file listing.
            KeyError: If the sample lacks ``"X"`` (or ``"y"`` when targets
                are requested).
        """
        path = os.path.join(self.root, self.files[index])
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only point this dataset at directories with trusted files.
        # The context manager guarantees the handle is closed, which matters
        # when many samples are read by long-lived DataLoader workers.
        with open(path, "rb") as handle:
            sample = pickle.load(handle)

        X = torch.FloatTensor(sample["X"])
        if self.target:
            return X, sample["y"]
        return X