---
# Global run settings.
random_seed: 1  # random seed number
device: 0  # GPU id


# Dataset location and loading options.
data:
  data_dir: "path_to_rmnist"  # directory for datasets
  dataset: "rmnist"
  mini_batch_size: 64
  num_workers: 4
  init_timestamp: 0
  split_time: 5  # timestep to split ID vs. OOD


# Training hyperparameters
trainer:
  method: 'ours'
  eval_metric: 'acc'
  eval_fix: true
  epochs: 50  # training epochs for each timestamp
  # Written with a decimal point so YAML 1.1 loaders (e.g. PyYAML) read it as
  # a float; a bare `1e-3` is resolved as the string '1e-3' by those loaders.
  lr: 1.0e-3
  momentum: 0.9
  weight_decay: 0.0
  reduction: 'mean'
  # dim for the bottlenecked features.
  # NOTE(review): an unquoted `None` is the string "None" in YAML, not null.
  # Left as-is because the consumer may compare against that string; confirm
  # and switch to `null` if the consumer checks `is None` instead.
  dim_bottleneck_f: None
  len_queue: 8
  len_DM_pool: 32
  warm_up: 0.2
  sample_num: 32  # number of generated weights via diffusion model
  num_DM_loop: 1
  tradeoff_con: 10.0

# Diffusion Model
DM:
  target: networks.diffusion.ddpm.LatentDiffusion
  params:
    # Decimal point keeps this a float under YAML 1.1 loaders (e.g. PyYAML),
    # which would otherwise read a bare `1e-4` as the string '1e-4'.
    base_learning_rate: 1.0e-4
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    cond_stage_key: class_label
    image_size: 1
    channels: 3
    cond_stage_trainable: true
    conditioning_key: hybrid
    monitor: val/loss_simple_ema
    unet_config:
      target: networks.diffusion.modules.openaimodel.UNetModel
      params:
        dims: 2
        width: 128  # feature_dim
        in_channels: 3
        out_channels: 1
        model_channels: 64
        attention_resolutions:
        - 4
        - 2
        - 1
        num_res_blocks: 1
        channel_mult:
        - 1
        - 2
        - 4
        num_groups: 32
        num_head_channels: 32
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 128  # feature_dim
    cond_stage_config:
      target: networks.diffusion.modules.encoders.ClassEmbedder
      params:
        embed_dim: 64
        # NOTE(review): n_classes is 2 while data.dataset is 'rmnist'; confirm
        # this matches the range of values fed in via cond_stage_key.
        n_classes: 2


# Options handed to the lightning-related components; booleans are written in
# canonical lowercase so every YAML 1.1/1.2 loader resolves them identically.
lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: false
  trainer:
    benchmark: true


# Logging, saving, and testing options
log:
  print_freq: 500
  log_dir: "./checkpoints/rmnist/"  # output directory
  log_name: "log.txt"  # log file name