import torch
import numpy as np
from torch.utils.data import Dataset


class DatasetWithIndices(Dataset):
    """Wrap an ``(x, y)`` dataset so every sample also carries a stable random id.

    At construction, each position ``n`` of the wrapped dataset is assigned the
    id ``permutation[n]``, where ``permutation`` is a random permutation of
    ``range(len(dataset))``. ``__getitem__`` returns that id alongside the
    sample. When ``embedding_dim`` is given, it returns the matching row of a
    fixed random Gaussian embedding table instead of the raw id.

    Args:
        dataset: Indexable, sized dataset whose items unpack as ``x, y``.
        embedding_dim: If not ``None``, build ``self.embedding`` of shape
            ``(len(dataset), embedding_dim)`` from a standard normal
            distribution, and return its rows instead of raw ids.
        with_labels: If true, items are ``(x, y, id)``; otherwise ``(x, id)``.
        seed: Seed for both the permutation and the embedding table, making
            them reproducible across runs and instances. ``None`` draws from
            the global NumPy / PyTorch RNGs instead, so ``np.random.seed`` and
            ``torch.manual_seed`` control them.
    """

    def __init__(self, dataset, embedding_dim=None, with_labels=False, seed=0):
        self.dataset = dataset
        self.with_labels = with_labels
        self.seed = seed
        num_items = len(dataset)

        # Use dedicated generators so ``seed`` actually determines the draws
        # and constructing the dataset neither reads nor advances the global
        # RNG streams.
        if seed is None:
            self.permutation = np.random.permutation(num_items)
            generator = None
        else:
            self.permutation = np.random.default_rng(seed).permutation(num_items)
            generator = torch.Generator().manual_seed(seed)

        # ``embedding`` is only defined when requested. Leave it absent
        # otherwise so ``hasattr(obj, "embedding")`` keeps its meaning.
        if embedding_dim is not None:
            self.embedding = torch.randn(
                num_items, embedding_dim, generator=generator, requires_grad=False
            )

    def __getitem__(self, n):
        """Return ``(x, id)``, or ``(x, y, id)`` when ``with_labels`` is set.

        ``id`` is a shape-``(1,)`` integer tensor holding ``permutation[n]``.
        When an embedding table exists, ``id`` is that row of the table with
        shape ``(1, embedding_dim)``.
        """
        x, y = self.dataset[n]
        # Wrap the id in a 1-element tensor (not a bare int) so it keeps a
        # leading length-1 dimension; indexing ``embedding`` with it likewise
        # yields a ``(1, embedding_dim)`` slice.
        sample_id = torch.tensor([self.permutation[n]])
        embedding = getattr(self, "embedding", None)
        if embedding is not None:
            sample_id = embedding[sample_id]
        if self.with_labels:
            return x, y, sample_id
        return x, sample_id

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)
