from torch.utils.data import Dataset
import torch


class EmbeddingDataset(Dataset):
    """Dataset of paired image/caption tensors loaded from disk.

    Item ``i`` is the pair ``(imgs[i], caps[i])``. Both tensors are indexed
    along their first dimension and must have the same size there.
    """

    def __init__(self, img_path, cap_path):
        """Load both tensors and convert them to float32.

        Args:
            img_path: Location handed to ``torch.load`` for the image tensor.
            cap_path: Location handed to ``torch.load`` for the caption tensor.

        Raises:
            ValueError: If the two tensors differ in size along dim 0.
        """
        super().__init__()
        # NOTE(review): each file is expected to hold a single tensor, since
        # ``.float()`` is called directly on the loaded object.
        self.imgs = torch.load(img_path).float()
        self.caps = torch.load(cap_path).float()

        # Raise explicitly rather than assert: assertions are removed when
        # Python runs with -O, which would let mismatched data through.
        n_imgs = self.imgs.size(0)
        n_caps = self.caps.size(0)
        if n_imgs != n_caps:
            raise ValueError(
                f"Image/caption count mismatch: {n_imgs} images vs "
                f"{n_caps} captions"
            )

    def __len__(self) -> int:
        """Return the number of (image, caption) pairs."""
        return self.caps.size(0)

    def __getitem__(self, cap_id):
        """Return the ``(image, caption)`` pair stored at index ``cap_id``."""
        img = self.imgs[cap_id]
        text = self.caps[cap_id]
        return img, text

    def set_size(self, size: int) -> None:
        """Keep only the first ``size`` pairs and log the resulting length.

        A ``size`` larger than the dataset leaves it unchanged.

        Args:
            size: Maximum number of leading pairs to keep; must be >= 0.

        Raises:
            ValueError: If ``size`` is negative. A negative slice bound would
                drop samples from the end instead of keeping ``size`` pairs.
        """
        if size < 0:
            raise ValueError(f"size must be non-negative, got {size}")

        self.imgs = self.imgs[:size]
        self.caps = self.caps[:size]
        # Report the real length: slicing clamps when size exceeds it.
        print(f"Truncated dataset to {len(self)} samples", flush=True)
