from tasks.modadd.classification.datasets import Data
import torch

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over a single split of a ``Data`` container.

    The split's tensors are read from ``data`` as the attributes
    ``<split>_src`` and ``<split>_tgt``. Each item is a ``(src, tgt)`` pair
    taken at the same index of both tensors. ``collate`` turns a list of
    such pairs into a dict of stacked tensors placed on ``device``.
    """

    def __init__(self, data:Data, split="train", device="cuda:0"):
        # Keep the container and target device around; the device is only
        # used when batches are collated.
        self.data = data
        self.device = device
        # Resolve "<split>_src" first, then "<split>_tgt". A missing
        # attribute raises AttributeError from getattr.
        self.src, self.tgt = (
            getattr(data, split + suffix) for suffix in ("_src", "_tgt")
        )

    def __len__(self):
        """Return the number of examples, i.e. the size of ``src``'s first dim."""
        return self.src.size(0)

    def __getitem__(self, idx):
        """Return the ``(src, tgt)`` pair stored at position ``idx``."""
        src_item = self.src[idx]
        tgt_item = self.tgt[idx]
        return src_item, tgt_item

    def collate(self, data):
        """Stack a list of ``(src, tgt)`` items into a batch on ``self.device``.

        Presumably passed to a ``DataLoader`` as ``collate_fn`` (confirm at
        call sites). ``torch.stack`` adds a new leading batch dimension and
        requires every item's tensor to have the same shape.

        Returns a dict with keys ``"src"`` and ``"tgt"``.
        NOTE(review): the default device is ``"cuda:0"``, so collation fails
        on machines without CUDA unless another device is passed.
        """
        batch = {}
        # Build "src" before "tgt" so stacking and transfer happen in the
        # same order as positions 0 and 1 of each item.
        for key, position in (("src", 0), ("tgt", 1)):
            stacked = torch.stack([item[position] for item in data])
            batch[key] = stacked.to(self.device)
        return batch


