import numpy as np
from scipy.stats import bernoulli


def load_gmm_data(n, d, K, seed, *, p=0.5):
    """Draw ``n`` samples from a two-component Gaussian mixture in ``d`` dimensions.

    Component 0 has mean ``(2, ..., 2)`` and component 1 has mean
    ``(-1, ..., -1)``; both have unit variance in every coordinate. Each
    sample's component label is drawn independently from Bernoulli(``p``),
    i.e. a sample belongs to component 1 with probability ``p``.

    Note: this seeds NumPy's *global* random state with ``seed`` (via
    ``np.random.seed``), so any ``np.random`` calls the caller makes afterwards
    are affected as well.

    Parameters
    ----------
    n : int
        Number of samples.
    d : int
        Dimension of each sample.
    K : int
        Number of rows allocated for the per-component variances. Only two
        components (labels 0 and 1) are ever sampled, so ``K`` must be at
        least 2; rows beyond the first two are unused.
    seed : int
        Seed passed to ``np.random.seed``.
    p : float, optional
        Probability that a sample belongs to component 1 (default 0.5).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n, d)`` with the samples; the component labels are
        not returned.

    Raises
    ------
    ValueError
        If ``K < 2``.
    """
    if K < 2:
        # Labels are drawn from {0, 1}; a variance table with fewer than two
        # rows would otherwise fail with an obscure IndexError when indexed by z.
        raise ValueError(f"K must be at least 2 (labels are 0 or 1), got {K}")

    np.random.seed(seed)
    mu = np.array([[2] * d, [-1] * d])
    sigma2 = np.ones((K, d))
    # With no random_state argument, scipy draws from NumPy's global
    # RandomState, so this is reproducible under the seed set above.
    z = bernoulli.rvs(p=p, size=n)
    # Scale standard-normal noise by each sample's component std-dev, then
    # shift by its component mean; z selects the rows of mu and sigma2.
    y = np.random.randn(n, d) * np.sqrt(sigma2[z, :]) + mu[z, :]
    return y
