# Based on implementations from the 4M repo: https://github.com/apple/ml-4m/
import numpy as np
from torch.utils.data import Dataset


class RepeatedDatasetWrapper(Dataset):
    def __init__(self, original_dataset, num_repeats):
        """
        Dataset wrapper that repeats the original dataset n times.

        Args:
            original_dataset (torch.utils.data.Dataset): The original dataset to be repeated.
            num_repeats (int): The number of times the dataset should be repeated.
        """
        self.original_dataset = original_dataset
        self.num_repeats = num_repeats

    def __getitem__(self, index):
        """
        Retrieve the item at the given index.

        Indices map onto the original dataset cyclically: index ``i`` returns
        ``original_dataset[i % len(original_dataset)]``. Negative indices are
        accepted in the usual Python sense (``-1`` is the last item).

        Args:
            index (int): The index of the item to be retrieved.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        length = len(self)
        # Bounds-check explicitly. The modulo below would otherwise wrap any
        # index, so IndexError would never be raised. Plain iteration over the
        # wrapper (Python's __getitem__ fallback protocol) would then never
        # terminate. This check also turns an empty original dataset into an
        # IndexError instead of a ZeroDivisionError.
        if not -length <= index < length:
            raise IndexError(f"index {index} out of range for dataset of length {length}")
        original_index = index % len(self.original_dataset)
        return self.original_dataset[original_index]

    def __len__(self):
        """
        Get the length of the dataset after repeating it n times.

        Returns:
            int: The length of the dataset.
        """
        return len(self.original_dataset) * self.num_repeats


class SubsampleDatasetWrapper(Dataset):
    def __init__(self, original_dataset, dataset_size, seed=0, return_orig_idx=False):
        """
        Dataset wrapper that randomly subsamples the original dataset.

        Args:
            original_dataset (torch.utils.data.Dataset): The original dataset to be subsampled.
            dataset_size (int): The size of the subsampled dataset. A falsy value
                (``None`` or ``0``) keeps the whole dataset (in shuffled order);
                values larger than the original length are clamped to it.
            seed (int): The seed to use for selecting the subset of indices of the original dataset.
            return_orig_idx (bool): Whether to return the original index of the item in the original dataset.
        """
        self.original_dataset = original_dataset
        self.dataset_size = dataset_size or len(original_dataset)
        self.return_orig_idx = return_orig_idx
        # Use a private RNG rather than np.random.seed(): seeding the global
        # generator would silently reset numpy's random state for the whole
        # process. RandomState(seed) yields the same permutation as the
        # legacy seeded global generator, so the selected subset is unchanged.
        rng = np.random.RandomState(seed)
        self.indices = rng.permutation(len(self.original_dataset))[:self.dataset_size]

    def __getitem__(self, index):
        """
        Retrieve the item at the given index.

        Args:
            index (int): The index of the item to be retrieved.

        Returns:
            The sample, or a ``(sample, original_index)`` tuple when
            ``return_orig_idx`` is True.
        """
        # Convert the numpy scalar to a plain int so wrapped datasets that
        # type-check their index (and callers of return_orig_idx) get an int.
        original_index = int(self.indices[index])
        sample = self.original_dataset[original_index]
        # Parenthesize the tuple: without parentheses the conditional binds
        # only to the second element, and (sample, sample) would be returned
        # when return_orig_idx is False.
        return (sample, original_index) if self.return_orig_idx else sample

    def __len__(self):
        """
        Get the length of the dataset after subsampling it.

        Returns:
            int: The length of the dataset.
        """
        return len(self.indices)
