import torch
from torch.utils.data import Dataset
import torchvision

class RandBullShitGo(Dataset):
    """Map-style dataset over the CIFAR10 training split.

    Each item is a dict with two keys: ``"input"`` (the image as a tensor)
    and ``"label"`` (the target exactly as returned by the wrapped
    torchvision dataset).
    """

    def __init__(self, data_path, transform=None):
        """Open the CIFAR10 training split stored under *data_path*.

        Args:
            data_path: Root directory handed to
                ``torchvision.datasets.CIFAR10``. No download is requested
                here, so the data is expected to already be present.
            transform: Optional callable applied to each raw sample before
                tensor conversion. If its result is already a
                ``torch.Tensor`` it is returned as-is; otherwise the result
                is passed through ``ToTensor``. ``None`` (the default)
                means only ``ToTensor`` is applied.
        """
        super().__init__()
        self.data = torchvision.datasets.CIFAR10(data_path, train=True)
        self.transform = transform
        self.to_tensor = torchvision.transforms.ToTensor()

    def __len__(self) -> int:
        """Return the number of samples in the wrapped dataset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the sample at *index* as ``{"input": tensor, "label": y}``."""
        X, y = self.data[index]
        if self.transform is not None:
            X = self.transform(X)
        # A user pipeline may already end in ToTensor; converting a tensor
        # a second time would raise, so only convert non-tensor results.
        if not isinstance(X, torch.Tensor):
            X = self.to_tensor(X)
        return {"input": X, "label": y}