from torch.utils.data import Dataset


class CustomDataset(Dataset):
    """Map-style dataset that pairs an indexable ``data`` with ``target``.

    Each lookup yields a 3-tuple ``(sample, label, idx)``; the index that was
    requested is echoed back as the last element so callers can trace an item
    back to its position.

    Args:
        data: Container indexed with ``data[idx]`` to fetch a sample.
        target: Container indexed with ``target[idx]`` to fetch a label; its
            ``len()`` defines the dataset length. Stored as ``self.targets``.
        transform: Optional callable applied to each fetched sample.
        target_transform: Optional callable applied to each fetched label.
    """

    def __init__(self, data, target, transform=None, target_transform=None):
        self.data = data
        self.targets = target

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of targets (``data`` is not consulted)."""
        return len(self.targets)

    def __getitem__(self, idx):
        """Return ``(sample, label, idx)`` for ``idx``, after optional transforms."""
        # Fetch both raw values first (sample, then label), before any
        # transform runs.
        features, target = self.data[idx], self.targets[idx]

        sample_fn = self.transform
        if sample_fn is not None:
            features = sample_fn(features)

        label_fn = self.target_transform
        if label_fn is not None:
            target = label_fn(target)

        return features, target, idx
