from sklearn.datasets import load_digits, fetch_openml; import numpy as np

def make_mnist(n_samples, n_classes=2):
    """Load a class-grouped subset of MNIST restricted to digits ``0 .. n_classes-1``.

    The full ``mnist_784`` dataset is fetched with
    ``sklearn.datasets.fetch_openml``. For each digit ``0, 1, ..., n_classes - 1``,
    the first ``n_samples // n_classes`` examples (in dataset order) are kept.
    Rows are grouped by class: all selected 0s, then all selected 1s, and so on.

    Args:
        n_samples: Total number of samples requested. Each class contributes
            ``n_samples // n_classes`` examples. The result can therefore have
            fewer than ``n_samples`` rows when ``n_samples`` is not divisible by
            ``n_classes``.
        n_classes: Number of digit classes to keep, taken as ``0 .. n_classes-1``.
            Must be between 1 and 10 (MNIST only has the digits 0-9).

    Returns:
        ``(X, y)``. ``X`` is a 2-D float array with one flattened image per row,
        with pixel values divided by 255. ``y`` is a 1-D integer array of the
        matching digit labels.

    Raises:
        ValueError: If ``n_classes`` is not in ``1..10``.
    """
    if not 1 <= n_classes <= 10:
        raise ValueError(f"n_classes must be between 1 and 10, got {n_classes}")

    X, y = fetch_openml("mnist_784", return_X_y=True, as_frame=False)
    per_class = int(n_samples // n_classes)

    # Labels come back as strings, hence the str(i) comparison.
    # flatnonzero yields 1-D index arrays. This avoids the argwhere(...).squeeze()
    # pattern, which collapses a single selected row to the wrong rank and
    # breaks concatenation when per_class == 1.
    idxs = np.concatenate(
        [np.flatnonzero(y == str(i))[:per_class] for i in range(n_classes)]
    )

    X = X[idxs] / 255  # scale raw 0-255 pixel intensities into [0, 1]
    y = y[idxs].astype(int)
    return X, y
