from torch.utils.data import Dataset


class BatchDataDict(Dataset):
    """Pytorch Dataset subclass that takes a dictionary of format {'<batch_idx>': <batch_data>}.

    Each item of the dataset is a whole batch rather than a single sample.
    Lookups use the requested index directly as the dictionary key, so the
    keys of the dictionary must be the same objects (value and type) that
    are later used to index the dataset.
    """

    def __init__(self, X, y=None):
        """X is the dictionary dataset and y is ignored.

        Parameters
        ----------
        X : dict
            Dictionary of format {'<batch_idx>': <batch_data>}
        y : None
            Ignored.
        """
        self.data_dict = X

    def __len__(self):
        """Return the number of batches stored in the dictionary."""
        return len(self.data_dict)

    def __getitem__(self, idx):
        """Return the batch stored under key ``idx``.

        Raises
        ------
        KeyError
            If ``idx`` is not a key of the underlying dictionary.
        """
        # This returns the batch at idx instead of a single element.
        return self.data_dict[idx]

    def __iter__(self):
        """Yield the batches stored under keys ``0 .. len(self) - 1`` in order.

        Without this method Python falls back to the legacy sequence
        protocol, which keeps calling ``__getitem__`` with 0, 1, 2, ... until
        an ``IndexError`` is raised. A dictionary lookup raises ``KeyError``
        instead, so plain iteration would crash after the last batch rather
        than stop cleanly.
        """
        for idx in range(len(self)):
            # Go through __getitem__ so subclasses overriding it are honored.
            yield self[idx]
