import torch


class TextDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves pre-tokenized text examples.

    Each item is a dict holding, for the requested index, one entry from every
    feature in ``encodings``, the index itself under ``"idx"``, and (when
    labels were supplied) the matching label under ``"labels"``.

    Args:
        encodings: Mapping from feature name to an indexable sequence with one
            entry per example. It must contain an ``"input_ids"`` key, which
            defines the dataset length. Presumably this is tokenizer output
            (e.g. a HuggingFace ``BatchEncoding``); verify against callers.
        labels: Optional indexable sequence (list, tensor, array, ...) with one
            label per example, or ``None`` for an unlabeled dataset.
    """

    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return the features (plus ``idx`` and optional label) at ``idx``."""
        item = {key: val[idx] for key, val in self.encodings.items()}
        # Expose the original position so downstream code can map predictions
        # back to source rows.
        item["idx"] = idx
        # Explicit None check: truth-testing a multi-element tensor/ndarray
        # raises, and a one-element tensor holding 0 would be falsy and the
        # labels would be silently dropped.
        if self.labels is not None:
            item["labels"] = self.labels[idx]
        return item

    def __len__(self):
        """Return the number of examples, taken from ``encodings["input_ids"]``."""
        return len(self.encodings["input_ids"])
