from torch.utils.data import Dataset

class RequestDataset(Dataset):
    """Map-style dataset over the documents produced by a task.

    The documents are fetched once, at construction time, by calling
    ``task.get_docs_in_templates(limit, corrupted, tokenizer)``. Indexing and
    length queries are then answered from that stored collection.
    """

    def __init__(self, task, limit, corrupted, tokenizer=None):
        # Materialise the task's documents up front. The arguments are
        # forwarded positionally, unchanged. Their meaning is defined by the
        # task object (NOTE(review): confirm semantics of `limit`/`corrupted`
        # against the task implementation).
        fetched = task.get_docs_in_templates(limit, corrupted, tokenizer)
        self.docs = fetched

    def __len__(self):
        """Return how many documents were fetched from the task."""
        return len(self.docs)

    def __getitem__(self, idx):
        """Return the stored document at position ``idx``.

        Indexing is delegated directly to the underlying collection.
        """
        stored = self.docs
        return stored[idx]