from torch.utils.data import Dataset, DataLoader
from collections import Counter

class MyDataset(Dataset):
    """Map-style dataset yielding sentence pairs with a majority-vote label.

    Each element of ``data`` is indexed with the keys ``'sentence1'``,
    ``'sentence2'`` and ``'labels'``; the value under ``'labels'`` is an
    iterable of hashable annotations that gets collapsed into one label.
    """

    def __init__(self, data):
        """Keep a reference to ``data``; nothing is copied or validated."""
        self.data = data

    def __len__(self):
        """Return the number of records in the wrapped collection."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sentence1, sentence2, label)`` for the record at ``idx``.

        ``label`` is the most frequent entry of the record's ``'labels'``;
        among equally frequent entries the one encountered first wins.
        An empty ``'labels'`` iterable raises ``IndexError``.
        """
        record = self.data[idx]
        # Tally the annotations and keep only the single most common one.
        votes = Counter(record['labels'])
        top_entries = votes.most_common(1)
        majority_label, _ = top_entries[0]
        first_sentence = record['sentence1']
        second_sentence = record['sentence2']
        return first_sentence, second_sentence, majority_label

def get_dataloader(dataset, batch_size, shuffle=False):
    """Build a ``DataLoader`` over ``dataset``.

    Args:
        dataset: The dataset handed to ``DataLoader``.
        batch_size: Forwarded as the loader's ``batch_size``.
        shuffle: Forwarded as the loader's ``shuffle`` flag (default False).

    Returns:
        The constructed ``DataLoader``; all other options keep their
        library defaults.
    """
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
    )
    return loader
