import torch


def Sampler(dim, xmin=-1, xmax=1, T=0.3, is_init=False):
    """Create a function that draws random points of the form ``(t, x_1..x_dim)``.

    Args:
        dim: Number of ``x`` columns in every sampled point.
        xmin: Lower end of the range the ``x`` columns are drawn from.
        xmax: Upper end of the range the ``x`` columns are drawn from.
        T: Scale applied to the uniform ``[0, 1)`` draw for the first column.
        is_init: When true, the first column is all zeros instead of random.

    Returns:
        A callable ``random_sampler(N)`` producing a tensor of shape
        ``(N, 1 + dim)`` with gradient tracking enabled.
    """

    def random_sampler(N):
        """Draw ``N`` random points and return them as one grad-enabled tensor."""
        # The first column is generated before the others so the order in
        # which the global RNG is consumed stays fixed.
        first_col = torch.zeros((N, 1)) if is_init else T * torch.rand((N, 1))

        # Map uniform [0, 1) samples affinely onto the xmin..xmax range.
        other_cols = (xmax - xmin) * torch.rand((N, dim)) + xmin

        # Row layout: [t, x_1, ..., x_dim]; requires_grad_ flips the flag
        # in place and hands back the same tensor.
        return torch.cat([first_col, other_cols], dim=1).requires_grad_(True)

    return random_sampler