# @package _global_

# Boolean toggle, off in this config (its consumer is not visible in this file).
cdsb: false

# data
# NOTE(review): the top-level `Dataset` and `data.dataset` both name CIFAR-10
# but use different spellings; presumably they are read by different code
# paths -- confirm before unifying them.
Dataset: cifar10
data:
  dataset: "CIFAR10"
  # image_size and channels are interpolated into mean_final / var_final
  # further down in this file, so changing them here changes those shapes too.
  image_size: 32
  channels: 3
  random_flip: true

# transfer
# `transfer` is off in this config.
# NOTE(review): Dataset_transfer presumably only matters when `transfer` is
# enabled -- confirm against the consuming code.
transfer: false
Dataset_transfer: mnist


final_adaptive: false
adaptive_mean: false
# mean_final / var_final are strings holding torch expressions. The ${data.*}
# interpolations supply the tensor shape, which is [3, 32, 32] with the values
# set under `data` in this file.
# They are quoted so the brackets, braces and commas are unmistakably string
# content rather than YAML flow syntax; the string value itself is unchanged.
# NOTE(review): presumably evaluated by the consuming code -- confirm.
mean_final: 'torch.zeros([${data.channels}, ${data.image_size}, ${data.image_size}])'
var_final: '1 * torch.ones([${data.channels}, ${data.image_size}, ${data.image_size}])'
load: true

# device
# NOTE(review): num_workers / pin_memory match the names of torch DataLoader
# arguments and are presumably forwarded there -- confirm.
device: cuda
num_workers: 2
pin_memory: true

# logging
# NOTE(review): the *_stride values are presumably counted in training
# iterations -- confirm against the consuming code.
log_stride: 100
gif_stride: 100000
plot_npar: 100
test_npar: 50000
test_batch_size: 500
cache_npar: 250000
cache_batch_size: 1250
num_repeat_data: 1  # alternative value noted by the original author: 4
# Interpolation: resolves to the top-level num_iter (500000 in this file),
# so this value follows num_iter whenever that is overridden.
cache_refresh_stride: ${num_iter}

# training
optimizer: AdamW
use_prev_net: true
# NOTE(review): ema_rate is presumably only used when `ema` is true -- confirm.
ema: true
ema_rate: 0.9999
# NOTE(review): grad_clip is presumably only used when `grad_clipping` is
# true -- confirm.
grad_clipping: true
grad_clip: 1.0
batch_size: 128
# Also referenced by cache_refresh_stride through the ${num_iter} interpolation.
num_iter: 500000
n_ipf: 100
lr: 0.0001
weight_decay: 0.01

num_steps: 200
