import torch
from torch.utils.data import Dataset


class local_dataset(Dataset):
    """Dataset that pairs samples from ``X`` with targets from ``Y`` by index.

    Both containers are stored by reference (no copy is made). They must
    expose ``.shape`` and agree on the size of their first dimension.
    """

    def __init__(self, X, Y):
        """Store ``X`` and ``Y`` after checking they hold the same number of rows.

        Args:
            X: Indexable container exposing ``.shape``; presumably a tensor or
                ndarray of inputs -- TODO confirm against callers.
            Y: Indexable container exposing ``.shape`` and supporting ``len()``;
                presumably the matching targets -- TODO confirm against callers.

        Raises:
            ValueError: If the first dimensions of ``X`` and ``Y`` differ.
        """
        # Explicit raise rather than `assert`: assertions are removed under
        # `python -O`, which would let mismatched inputs through unnoticed.
        if X.shape[0] != Y.shape[0]:
            raise ValueError(
                "X and Y must have the same first dimension, "
                f"got {X.shape[0]} and {Y.shape[0]}"
            )
        self.X = X
        self.Y = Y

    def __len__(self):
        """Return the number of entries, taken from ``len(self.Y)``."""
        return len(self.Y)

    def __getitem__(self, idx):
        """Return the ``(X[idx], Y[idx])`` pair for the given index."""
        return self.X[idx], self.Y[idx]
    