import torch
import datasets


class STSBDataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset over one split of ``sentence-transformers/stsb``.

    Each item is a ``(sentence1, sentence2, score)`` tuple read from the
    columns of the same names in the underlying Hugging Face dataset.
    """

    # Split names accepted by the constructor.
    _VALID_SPLITS = ("train", "validation", "test")

    # Divisor applied to every score when label normalization is requested.
    _SCORE_SCALE = 5.0

    def __init__(self, split: str, normalize_labels: bool = False):
        """Load the dataset and keep the requested split.

        Args:
            split: One of ``"train"``, ``"validation"`` or ``"test"``.
            normalize_labels: When true, every score is divided by 5.0
                (see :meth:`normalize_labels`).

        Raises:
            ValueError: If ``split`` is not one of the supported names.
        """
        super().__init__()

        # Validate with an explicit raise rather than ``assert``: assertions
        # are removed under ``python -O``, which would let a bad split name
        # slip through to the dataset lookup below.
        if split not in self._VALID_SPLITS:
            raise ValueError(
                f"Split must be one of ['train','validation','test'], got {split}"
            )

        self.stsb_dataset = datasets.load_dataset("sentence-transformers/stsb")
        self.split = split
        self.dataset = self.stsb_dataset[split]

        # Inside __init__ the bare name ``normalize_labels`` is the boolean
        # argument, while ``self.normalize_labels`` is the method below.
        if normalize_labels:
            self.dataset = self.normalize_labels(self.dataset)

    def __len__(self) -> int:
        """Return the number of examples in the selected split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the example at ``idx`` as ``(sentence1, sentence2, score)``."""
        sample = self.dataset[idx]
        return sample["sentence1"], sample["sentence2"], sample["score"]

    def normalize_labels(self, dataset):
        """Return ``dataset`` with every ``score`` divided by 5.0.

        Dividing by 5.0 maps a 0-5 rating scale onto [0, 1].

        NOTE(review): the ``sentence-transformers/stsb`` dataset card appears
        to describe ``score`` as already scaled to [0, 1]; if that is the
        case, enabling this squeezes labels into [0, 0.2]. Confirm the raw
        score range before relying on this option.

        Args:
            dataset: Object exposing a ``map`` method that applies a
                per-example function returning the columns to overwrite.

        Returns:
            Whatever ``dataset.map`` returns for the rescaling function.
        """
        scale = self._SCORE_SCALE

        def _rescale(example):
            return {"score": example["score"] / scale}

        return dataset.map(_rescale)