import numpy as np

class Sampler:
    '''
    Collection of stateless random samplers that draw from numpy's global RNG
    (``np.random``); call ``np.random.seed`` beforehand for reproducibility.
    '''

    def __init__(self):
        pass

    @staticmethod
    def uniform(mean, sigma, nsample, reshape = False):
        '''
        Draw samples uniformly from the box [mean - sigma, mean + sigma).

        Args:
        - mean:     [ ndim ] np (or scalar), center on every dimension
        - sigma:    [ ndim ] np (or scalar), half-width on every dimension
        - nsample:  int, number of samples to draw
        - reshape:  bool, if True and ndim == 1, flatten the output to [ nsample ]
        Output:
        - out:      [ nsample, ndim ] np ([ nsample ] when flattened)
        Raises:
        - ValueError: if mean and sigma have different lengths
        '''
        # atleast_1d also accepts 0-d numpy arrays, which np.isscalar rejects.
        mean = np.atleast_1d(mean)
        sigma = np.atleast_1d(sigma)
        if len(mean) != len(sigma):
            raise ValueError(
                f"mean and sigma must have the same length, "
                f"got {len(mean)} and {len(sigma)}")
        ndim = len(mean)
        # Draw as [ ndim, nsample ] and transpose so the RNG stream is consumed
        # one dimension at a time (all nsample values for dim 0, then dim 1, ...),
        # i.e. the same sample-to-column assignment as drawing each column in turn.
        u = np.ascontiguousarray(
            np.random.uniform(size=(ndim, nsample)).T)  # [ nsample, ndim ] np
        out = (u - 0.5) * 2 * sigma + mean  # map [0, 1) -> [mean - sigma, mean + sigma)
        if out.shape[1] == 1 and reshape:
            out = out.reshape(-1)
        return out

    @staticmethod
    def gaussian(mean, width, sigma, nsample, squeeze = False):
        '''
        Draw Gaussian samples and clip them to [mean - width, mean + width].

        Note: clipping moves out-of-range samples onto the boundaries; this is
        not a renormalized truncated normal distribution.

        Args:
        - mean:     [ ndim ] np (or scalar), center on every dimension
        - width:    [ ndim ] np (or scalar), clipping half-width around mean
        - sigma:    [ ndim ] np (or scalar), standard deviation (not variance)
        - nsample:  int, number of samples to draw
        - squeeze:  bool, if True and ndim == 1, flatten the output to [ nsample ]
        Return:
        - data:     [ nsample, ndim ] np ([ nsample ] when squeezed)
        '''
        # Promote scalars / 0-d arrays to 1-d so len() and broadcasting work;
        # a length-1 width/sigma broadcasts across all dimensions.
        mean = np.atleast_1d(mean)
        width = np.atleast_1d(width)
        sigma = np.atleast_1d(sigma)

        ndim = len(mean)
        data = np.random.randn(nsample, ndim) * sigma + mean  # [ nsample, ndim ] np
        data = np.clip(data, mean - width, mean + width)
        if squeeze and data.shape[1] == 1:
            data = data.reshape(-1)
        return data # [ nsample, ndim ] np, or [ nsample ] when squeezed