from configs.default_cifar10_configs import get_default_configs
import ml_collections

def get_config():
  """Return the default CIFAR-10 config with this experiment's overrides.

  Starts from `get_default_configs()` and adjusts its `training` and `data`
  sections in place. The `eval` section is replaced by a brand-new
  `ConfigDict`, so any eval settings the defaults may have carried are not
  kept; only the fields assigned below are present on it.

  Returns:
    The config object obtained from `get_default_configs()` after the
    overrides have been applied.
  """
  cfg = get_default_configs()

  # Training overrides.
  train_cfg = cfg.training
  train_cfg.sde = 'vesde'
  train_cfg.continuous = True
  train_cfg.n_iters = 1000001
  train_cfg.batch_size = 100
  train_cfg.log_freq = 50
  train_cfg.eval_freq = 500

  # Data overrides.
  data_cfg = cfg.data
  data_cfg.centered = False
  data_cfg.random_flip = False
  data_cfg.dataset = 'CIFAR10'
  data_cfg.image_size = 32
  data_cfg.num_channels = 3

  # Evaluation settings: attach a fresh section, then populate it through the
  # local handle (which refers to the same object stored on the config).
  eval_cfg = ml_collections.ConfigDict()
  cfg.eval = eval_cfg
  eval_cfg.batch_size = 200
  eval_cfg.num_samples = 51000
  eval_cfg.num_samples_all = 50000
  eval_cfg.mode = 'class'
  eval_cfg.batch_splits = 5
  eval_cfg.nearest_k = 3

  return cfg