from sklearn.mixture import GaussianMixture, BayesianGaussianMixture


class Sampler:
    """Fit a mixture model to ``data`` and draw synthetic samples from it.

    The model family is chosen by ``config.sampler_model``:

    * ``"gmm"``   -> ``GaussianMixture`` with
      ``config.sampler.gmm.n_components`` components.
    * ``"dpgmm"`` -> ``BayesianGaussianMixture`` with
      ``config.sampler.dpgmm.n_components`` components.

    The model is fitted eagerly inside ``__init__``.
    """

    def __init__(self, data, config):
        """Store ``data``, resolve the sample count and fit the model.

        Args:
            data: Training data handed unchanged to the model's ``fit``.
            config: Configuration object. Reads ``n_samples`` (falling back
                to ``num_nodes`` when it equals ``"auto"``),
                ``sampler_model`` and the matching ``sampler.<model>``
                section.

        Raises:
            ValueError: If ``config.sampler_model`` is not a supported name.
        """
        self.data = data
        self.model = None
        self.n_components = None
        # "auto" means: draw as many samples as config.num_nodes.
        self.n_samples = config.num_nodes if config.n_samples == "auto" else config.n_samples

        self.fit_model(data, config)

    def fit_model(self, data, config):
        """Select the model named by ``config.sampler_model`` and fit it.

        Raises:
            ValueError: If ``config.sampler_model`` is neither ``"gmm"``
                nor ``"dpgmm"``.
        """
        if config.sampler_model == "gmm":
            self.n_components = config.sampler.gmm.n_components
            self.gmm(data)

        elif config.sampler_model == "dpgmm":
            self.n_components = config.sampler.dpgmm.n_components
            self.dpgmm(data)

        else:
            # Raising a bare string is itself a TypeError in Python 3
            # ("exceptions must derive from BaseException"); raise a real
            # exception so the intended message reaches the caller.
            raise ValueError("Sampler not found!")

    def sample(self):
        """Return ``self.n_samples`` samples drawn from the fitted model.

        ``self.model.sample`` returns a pair; only its first element (the
        sampled points) is returned.
        """
        return self.model.sample(self.n_samples)[0]

    def gmm(self, data):
        """Fit a ``GaussianMixture`` with ``self.n_components`` components."""
        self.model = GaussianMixture(
            n_components=self.n_components,
        )
        self.model.fit(data)

    def dpgmm(self, data):
        """Fit a ``BayesianGaussianMixture`` with ``self.n_components`` components."""
        self.model = BayesianGaussianMixture(
            n_components=self.n_components,
        )
        self.model.fit(data)





