import datasets
import torch

class Banking77Dataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset over one split of the PolyAI/banking77 corpus.

    Indexing yields ``(text, label)`` pairs read from the rows of the
    underlying Hugging Face dataset.
    """

    def __init__(self, split):
        """Load one split of PolyAI/banking77 via ``datasets.load_dataset``.

        Args:
            split: Split name forwarded unchanged to ``load_dataset``
                (presumably ``"train"`` / ``"test"`` — confirm against the
                dataset card).
        """
        # NOTE(review): trust_remote_code=True allows the Hub repository to
        # execute its own loading code on this machine; keep it only if the
        # repository is trusted.
        self.dataset = datasets.load_dataset(
            "PolyAI/banking77",
            split=split,
            trust_remote_code=True,
        )

    def __len__(self):
        """Return the number of rows in the loaded split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the ``(text, label)`` pair stored at position ``idx``."""
        row = self.dataset[idx]
        return row["text"], row["label"]

    @property
    def num_classes(self):
        """Number of intent classes; fixed at 77 for this corpus."""
        return 77