import torch
from torch.utils.data import Dataset


class GPUDataset(Dataset):
    """Dataset that keeps its data on a target device.

    The data is moved to ``device`` once, at construction time, so
    ``__getitem__`` only indexes into the already-transferred object.
    """

    def __init__(self, device="cuda", data=None):
        """Store ``data`` on ``device``.

        Args:
            device: Target passed to ``.to()`` (e.g. ``"cuda"`` or ``"cpu"``).
            data: Optional tensor, or anything ``torch.as_tensor`` accepts.
                When omitted, a ``data`` attribute already present on the
                instance (for example one defined by a subclass) is used;
                if there is none, a placeholder ``torch.ones(1000)`` is
                created.
        """
        super().__init__()
        if data is not None:
            data = torch.as_tensor(data)
        elif hasattr(self, "data"):
            # Keep supporting subclasses that provide ``data`` themselves.
            data = self.data
        else:
            # Previously nothing assigned ``self.data`` here, so direct
            # construction raised AttributeError; use the placeholder instead.
            data = torch.ones(1000)
        self.data = data.to(device)

    def __len__(self):
        """Return ``len(self.data)``."""
        return len(self.data)

    def __getitem__(self, i):
        """Return ``self.data[i]``."""
        return self.data[i]
