import numpy as np

from toy_data import generate_offline_data

# Dimensionality shared by every center vector and covariance matrix below.
dim = 12

# Generate the synthetic offline dataset and hand it to toy_data for output
# ('./' + 'offline_data_syn.pt' are passed through as output/fname).
# NOTE(review): four centers are supplied but only two covariances and two
# source priors -- confirm against generate_offline_data how these are paired.
generate_offline_data(
    # No seed is given, so the generator is freshly seeded on every run.
    rng=np.random.default_rng(),
    train_num=10000,
    dim=dim,
    # Each center is a constant vector whose coordinates all equal the offset.
    centers=[offset * np.ones(dim) for offset in (-1.2, -0.8, 0.8, 1.2)],
    # Two separate identity covariance matrices (distinct array objects).
    covs=[1 * np.eye(dim) for _ in range(2)],
    source_priors=[0.9, 0.1],
    radius=10,
    cls_num=2,
    output='./',
    fname='offline_data_syn.pt',
)
