import datasets
import torch

class HellaswagDataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset wrapping the ``Rowan/hellaswag`` dataset.

    Each item is a ``(ctx, endings, label)`` tuple built from the fields of
    the same names in the underlying sample. The field types are whatever the
    loaded dataset provides (presumably a context string and a list of
    candidate ending strings -- confirm against the dataset card).
    """

    # Label returned for samples that carry no gold answer (missing or blank
    # ``label`` field), so unlabeled data can be iterated without crashing.
    UNLABELED = -1

    def __init__(self, split: str) -> None:
        """Load the requested split of ``Rowan/hellaswag``.

        Args:
            split: Split specifier forwarded unchanged to
                ``datasets.load_dataset``.
        """
        super().__init__()
        self.dataset = datasets.load_dataset("Rowan/hellaswag", split=split, trust_remote_code=True)

    def __len__(self) -> int:
        """Return the number of samples in the loaded split."""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> tuple:
        """Return the sample at ``idx`` as ``(ctx, endings, label)``.

        Args:
            idx: Index of the sample in the underlying dataset.

        Returns:
            A tuple of the sample's ``ctx`` field, its ``endings`` field and
            the label converted to ``int``. When the label is ``None`` or a
            blank string, ``UNLABELED`` (-1) is returned in its place.

        Raises:
            ValueError: If the label is non-blank but not parseable as an
                integer.
        """
        sample = self.dataset[idx]
        ctx = sample["ctx"]
        endings = sample["endings"]

        raw_label = sample["label"]
        # int("") raises ValueError, so blank/missing labels need an explicit
        # sentinel instead of a direct conversion.
        is_blank = isinstance(raw_label, str) and not raw_label.strip()
        if raw_label is None or is_blank:
            label = self.UNLABELED
        else:
            label = int(raw_label)

        return ctx, endings, label