import ml_collections
import torch


def get_default_configs():
  """Build and return the default experiment configuration.

  Every call constructs a brand-new ``ml_collections.ConfigDict`` holding six
  sub-sections -- ``training``, ``sampling``, ``data`` (which nests its own
  ``pu_config`` sub-section), ``model``, ``optim`` and ``classifier`` -- plus
  two top-level entries, ``seed`` and ``device``.  ``device`` is
  ``torch.device('cuda:0')`` when ``torch.cuda.is_available()`` is true and
  ``torch.device('cpu')`` otherwise.

  Returns:
    ml_collections.ConfigDict: a freshly built config; mutating it does not
    affect configs returned by later calls.
  """

  def _add_section(_owner, _attr, **fields):
    # Create an empty ConfigDict, hang it off `_owner` as attribute `_attr`,
    # then assign `fields` one by one in keyword order.  The helper's own
    # parameters are underscore-prefixed so they never clash with field names
    # such as `name` or `model` passed through **fields.
    section = ml_collections.ConfigDict()
    setattr(_owner, _attr, section)
    for key, value in fields.items():
      setattr(section, key, value)
    return section

  cfg = ml_collections.ConfigDict()

  # training
  _add_section(
      cfg, 'training',
      batch_size=128,
      labeled_batch_size=64,
      n_iters=130001,
      likelihood_weighting=False,
      continuous=True,
      reduce_mean=True,
      score_model=True,
      clf_model=True,
      score_path='',
      denoise_augment=True,
  )

  # sampling
  _add_section(
      cfg, 'sampling',
      n_steps_each=1,
      noise_removal=True,
      probability_flow=False,
      snr=0.16,
  )

  # data
  data_cfg = _add_section(
      cfg, 'data',
      dataset='MNIST',
      image_size=32,
      random_flip=False,
      centered=False,
      uniform_dequantization=False,
      num_channels=1,
      labels_per_class=100,
      pu=False,
  )
  # Nested under `data`; positive_classes (0, 1) is a subset of
  # use_classes (0, 1, 2).  Presumably only consulted when data.pu is True --
  # confirm against the data-loading code.
  _add_section(
      data_cfg, 'pu_config',
      use_classes=(0, 1, 2),
      positive_classes=(0, 1),
  )

  # model
  _add_section(
      cfg, 'model',
      sigma_min=0.01,
      sigma_max=50,
      num_scales=1000,
      beta_min=0.1,
      beta_max=20.,
      dropout=0.1,
      embedding_type='fourier',
  )

  # optimization
  _add_section(
      cfg, 'optim',
      weight_decay=0,
      optimizer='Adam',
      lr=2e-4,
      beta1=0.9,
      eps=1e-8,
      warmup=5000,
      grad_clip=1.,
  )

  # classifier
  _add_section(
      cfg, 'classifier',
      name="classifier",
      model="wrn28_2_cond",
      nf=32,
      embedding_type='fourier',
      fourier_scale=16,
      classes=10,
      norm="group",  # none, batch, group
      noise_list=[0.0001, 0.3333, 0.6666, 0.9999],
      time=True,
  )

  cfg.seed = 42
  # Prefer the first CUDA device when one is available; otherwise use the CPU.
  if torch.cuda.is_available():
    cfg.device = torch.device('cuda:0')
  else:
    cfg.device = torch.device('cpu')

  return cfg