from torch.utils.data import Dataset


class DistributionSamplingDataset(Dataset):
    """Dataset over a fixed set of samples drawn once, eagerly, from ``dist``.

    The constructor calls ``dist.sample((num_samples,), ...)`` exactly once.
    After that, indexing only reads the stored result, so the contents never
    change for the lifetime of the dataset.

    Args:
        dist: Object exposing ``sample(shape)`` and, when ``return_transform``
            is truthy, ``sample(shape, return_transform=True)`` returning a
            ``(data, transform)`` pair. NOTE(review): presumably a
            distribution-like object; confirm against callers.
        num_samples: Leading size of the shape requested from ``dist.sample``.
        return_transform: If truthy, also keep the per-sample ``transform``
            returned by the sampler, and make ``__getitem__`` yield
            ``(sample, transform)`` pairs instead of bare samples.
    """

    def __init__(self, dist, num_samples, return_transform=False):
        self.dist = dist
        self.num_samples = num_samples
        self.return_transform = return_transform

        draw = self.dist.sample
        sample_shape = (self.num_samples,)

        if not return_transform:
            # Plain mode: a single object comes back; no `transform` attribute is set.
            self.data = draw(sample_shape)
            return

        # Transform mode: the sampler is expected to return a 2-item pair.
        # `transform` is indexed in parallel with `data` by __getitem__.
        self.data, self.transform = draw(sample_shape, return_transform=True)

    def __len__(self):
        """Return the size of the leading dimension of the stored samples."""
        leading_dim = self.data.shape[0]
        return leading_dim

    def __getitem__(self, idx):
        """Return the sample at ``idx``, paired with its transform if enabled."""
        sample = self.data[idx]
        if not self.return_transform:
            return sample
        return sample, self.transform[idx]
