# @package _global_

# Boolean switch read by the consuming code.
# NOTE(review): presumably toggles a conditional-DSB mode; confirm with the reader of this key.
cdsb: false

# data
# NOTE(review): `Dataset` looks like a dataset selector; confirm how it is consumed.
Dataset: mnist_2
data:
  dataset: "ColoredMNIST"
  # image_size and channels are interpolated into mean_final / var_final
  # below via ${data.image_size} and ${data.channels}; changing them here
  # changes the shapes written into those expressions.
  image_size: 32
  channels: 3

# transfer
# NOTE(review): `transfer` / `Dataset_transfer` presumably select a second
# dataset for a dataset-to-dataset setup; confirm against the consumer.
transfer: true
Dataset_transfer: mnist_3

final_adaptive: false
adaptive_mean: false
# mean_final / var_final are strings holding Python-style tensor expressions;
# the ${data.*} interpolations are filled from the `data` mapping in this file,
# giving a [channels, image_size, image_size] shape list.
# They are quoted so the value is unambiguously a string for any YAML parser.
# NOTE(review): presumably evaluated (e.g. via eval) by the training code; confirm.
mean_final: 'torch.zeros([${data.channels}, ${data.image_size}, ${data.image_size}])'
var_final: '1 * torch.ones([${data.channels}, ${data.image_size}, ${data.image_size}])'
load: true

# device
device: cuda
# NOTE(review): num_workers / pin_memory look like data-loader options; confirm
# where they are passed.
num_workers: 2
pin_memory: true

# logging
# NOTE(review): the *_stride values look like intervals counted in training
# iterations; confirm against the code that reads them.
log_stride: 100
gif_stride: 5000
plot_npar: 100
# FID number of samples
test_npar: 10000
test_batch_size: 250
# NOTE(review): cache_* presumably size a sample cache and its batch size;
# 10000 / 250 = 40 batches if it is filled batch-wise. Confirm.
cache_npar: 10000
cache_batch_size: 250
num_repeat_data: 1
cache_refresh_stride: 1000

# training
use_prev_net: true
# NOTE(review): ema / ema_rate presumably enable an exponential moving average
# of the weights and set its decay; confirm.
ema: true
ema_rate: 0.999
# NOTE(review): grad_clipping presumably enables clipping and grad_clip is the
# threshold; confirm whether it is a norm or a value clip.
grad_clipping: true
grad_clip: 1.0
batch_size: 128
num_iter: 500000
# NOTE(review): n_ipf is presumably the number of IPF (iterative proportional
# fitting) rounds; confirm.
n_ipf: 50
lr: 0.0001

# NOTE(review): num_steps is presumably the number of discretisation steps;
# confirm.
num_steps: 30
