import torch
from torch.utils.data import Dataset as BaseDataLoader


class MatchedForgetRetainDataset(BaseDataLoader):
    """Pair every forget example with a retain example chosen by index wrap-around.

    ``BaseDataLoader`` is ``torch.utils.data.Dataset`` imported under an alias.
    The dataset's length is the length of the forget set. Item ``idx`` pairs
    ``forget_dataset[idx]`` with ``retain_dataset[idx % len(retain_dataset)]``.
    If the retain set is shorter than the forget set, it is cycled. If it is
    longer, only its first ``len(forget_dataset)`` entries are ever used.

    Args:
        forget_dataset: Indexable, sized collection of forget examples.
        retain_dataset: Indexable, sized collection of retain examples.
    """

    def __init__(self, forget_dataset, retain_dataset):
        self.forget_dataset, self.retain_dataset = forget_dataset, retain_dataset
        # Captured once: later size changes of retain_dataset are not observed.
        self.retain_len = len(retain_dataset)

    def __len__(self):
        """Return the number of forget examples (one pair is produced per forget example)."""
        return len(self.forget_dataset)

    def __getitem__(self, idx):
        """Return ``{"forget": ..., "retain": ...}`` for position ``idx``.

        For an integer ``idx``, an empty retain set makes the modulo raise
        ``ZeroDivisionError``. This only happens after the forget lookup has
        succeeded.
        """
        forget_example = self.forget_dataset[idx]
        retain_example = self.retain_dataset[idx % self.retain_len]
        return {"forget": forget_example, "retain": retain_example}


class MatchedForgetRetainRandomDataset(BaseDataLoader):
    """Pair every forget example with a uniformly sampled retain example.

    ``BaseDataLoader`` is ``torch.utils.data.Dataset`` imported under an alias.
    The dataset's length is the length of the forget set. Item ``idx`` pairs
    ``forget_dataset[idx]`` with ``retain_dataset[j]``, where ``j`` is drawn
    uniformly from ``[0, len(retain_dataset))`` via ``torch.randint`` on *every*
    access. Repeated reads of the same ``idx`` can therefore return different
    retain examples.

    Args:
        forget_dataset: Indexable, sized collection of forget examples.
        retain_dataset: Indexable, sized collection of retain examples.
        generator: Optional ``torch.Generator`` used to sample retain indices,
            which makes the pairing reproducible. ``None`` (the default) draws
            from torch's global RNG, which matches the previous behaviour.
            NOTE(review): with ``DataLoader(num_workers > 0)`` each worker gets
            its own copy of this generator in the same state, so workers would
            draw identical index sequences. Reseed it per worker (e.g. in a
            ``worker_init_fn``) if that matters.
    """

    def __init__(self, forget_dataset, retain_dataset, *, generator=None):
        self.forget_dataset = forget_dataset
        self.retain_dataset = retain_dataset
        # Captured once: later size changes of retain_dataset are not observed.
        self.retain_len = len(retain_dataset)
        self.generator = generator

    def __len__(self):
        """Return the number of forget examples (one pair is produced per forget example)."""
        return len(self.forget_dataset)

    def __getitem__(self, idx):
        """Return ``{"forget": ..., "retain": ...}`` with a freshly sampled retain example.

        ``torch.randint`` raises if the retain set is empty.
        """
        forget_item = self.forget_dataset[idx]
        # The draw is independent of idx. generator=None means torch's global RNG.
        retain_idx = torch.randint(
            self.retain_len, (1,), generator=self.generator
        ).item()
        retain_item = self.retain_dataset[retain_idx]

        return {
            "forget": forget_item,
            "retain": retain_item,
        }
