import torch


def generate_measure(n_sample, n_dim, slice=True):
    """
    Generate random weights and support points for a discrete probability
    measure on the unit hypercube [0, 1)^n_dim.

    Weights are i.i.d. Exponential(1) draws normalized to sum to one, which
    is equivalent to a flat Dirichlet draw (uniform on the probability
    simplex). Support points are i.i.d. Uniform[0, 1) in every coordinate.

    :param n_sample: Number of support points.
    :param n_dim: Dimension of the feature space.
    :param slice: If True, draw an (n_sample, n_dim) weight matrix whose
        rows each sum to one. If False, draw a single (n_sample,) weight
        vector summing to one. (The name shadows the ``slice`` builtin; it
        is kept for backward compatibility with keyword callers.)
    :return: Tuple ``(a, x)``:
        ``a`` -- weights, shape (n_sample, n_dim) if ``slice`` else (n_sample,);
        ``x`` -- support points, shape (n_sample, n_dim), entries in [0, 1).
    """
    m = torch.distributions.exponential.Exponential(1.0)
    if slice:
        a = m.sample(torch.Size([n_sample, n_dim]))
        # NOTE(review): normalization is over dim=1, so each row (one per
        # sample) sums to 1. If one measure per dimension/slice was intended,
        # this should be dim=0 instead -- confirm against callers.
        a = a / a.sum(dim=1, keepdim=True)
    else:
        a = m.sample(torch.Size([n_sample]))
        a = a / a.sum()
    # Support points are drawn after the weights; keep this order so seeded
    # runs reproduce the same samples.
    m = torch.distributions.uniform.Uniform(0.0, 1.0)
    x = m.sample(torch.Size([n_sample, n_dim]))
    return a, x
