from tasks.modadd.contrastive.datasets import Data
import torch, random, torch

class Dataset(torch.utils.data.Dataset):
    """Masked-symbol view over one split of a ``Data`` object.

    Every item is a copy of one row of the chosen split in which exactly one
    position has been overwritten with ``data.mask_symbol``. The position is
    drawn uniformly at random from every index except the last one.

    Args:
        data: source object; must expose ``mask_symbol`` and an attribute
            named by ``split`` holding the rows (indexed with ``[idx]``,
            ``.clone()``-able, sized with ``.size(0)`` -- presumably a 2-D
            tensor; confirm).
        split: name of the attribute on ``data`` to read rows from.
        device: device that ``collate`` moves each batch to.
    """

    def __init__(self, data: Data, split="train", device="cuda:0"):
        rows = getattr(data, split)  # e.g. data.train for the default split
        self.data = data
        self.device = device
        self.dataset = rows

    def __len__(self):
        """Return the number of rows in the selected split."""
        n_rows = self.dataset.size(0)
        return n_rows

    def __getitem__(self, idx):
        """Return ``(symbol, context)`` for row ``idx``.

        ``context`` is a clone of the row with one position replaced by
        ``self.data.mask_symbol``; ``symbol`` is a clone of the value that was
        at that position before masking. The underlying split is not modified.

        Raises:
            ValueError: from ``random.randint`` when the row has fewer than
                two positions (no index is eligible for masking).
        """
        row = self.dataset[idx].clone()  # clone so masking never touches the split
        last_eligible = row.size(0) - 2
        # randint is inclusive at both ends, so the final position is never
        # chosen. NOTE(review): presumably deliberate -- confirm against the
        # task definition.
        pos = random.randint(0, last_eligible)
        # Clone before overwriting: row[pos] alone would be a view into row.
        hidden = row[pos].clone()
        row[pos] = self.data.mask_symbol
        return hidden, row

    def collate(self, data):
        """Stack a sequence of ``(symbol, context)`` items into a batch dict.

        Returns ``{"symbol": ..., "context": ...}``, each entry stacked along a
        new leading dimension and moved to ``self.device``.
        NOTE(review): presumably used as a DataLoader ``collate_fn``; moving
        tensors to CUDA here may fail inside forked worker processes
        (``num_workers > 0``) -- confirm.
        """
        symbols = [item[0] for item in data]
        contexts = [item[1] for item in data]
        batch = {}
        batch["symbol"] = torch.stack(symbols).to(self.device)
        batch["context"] = torch.stack(contexts).to(self.device)
        return batch


