# Hydra defaults list: compose the `cmnist_uniform` option of the `/dataset`
# config group, then this file itself (`_self_`). Order matters: later entries
# override earlier ones, so keys set in this file win over the dataset config.
defaults:
  - /dataset: cmnist_uniform
  - _self_

# Class-conditional denoising UNet, instantiated from `_target_`.
# NOTE(review): the keyword names mirror diffusers' UNet2DModel; the comments
# below assume ClassConditionalUnet forwards them to it — confirm.
model:
  _target_: models.conditional_unet.ClassConditionalUnet
  # Number of classes for each conditioning label (two labels, 10 classes
  # each). Both classifiers below reuse this list via
  # ${model.num_class_per_label}, so change it here only.
  num_class_per_label:
    - 10
    - 10
  # How the per-label embeddings are combined; presumably summed — confirm
  # in ClassConditionalUnet.
  interaction: sum
  sample_size: 28   # height/width of the input/output images
  in_channels: 3    # 3 image channels in
  out_channels: 3   # 3 image channels out
  layers_per_block: 2
  # Output channels per resolution level; needs one entry per element of
  # down_block_types / up_block_types (3 levels here).
  block_out_channels:
    - 56
    - 112
    - 168
  attention_head_dim: 8
  class_embed_type: null
  # GroupNorm group count; must divide every entry of block_out_channels
  # (56, 112 and 168 are all multiples of 8).
  norm_num_groups: 8
  dropout: 0.1
  act_fn: gelu
  # Down path: plain block at full resolution, attention at the two coarser
  # levels; the up path mirrors it in reverse order.
  down_block_types:
    - DownBlock2D
    - AttnDownBlock2D
    - AttnDownBlock2D
  up_block_types:
    - AttnUpBlock2D
    - AttnUpBlock2D
    - UpBlock2D

# Forward/reverse diffusion schedule, built from diffusers' DDPMScheduler.
noise_scheduler:
  _target_: diffusers.DDPMScheduler
  num_train_timesteps: 1000   # number of diffusion steps used in training
  clip_sample: true           # clamp the predicted sample (default range +/-1)
  prediction_type: epsilon    # the network predicts the added noise
  beta_schedule: linear       # betas grow linearly across timesteps

# --- Run settings ----------------------------------------------------------
seed: 42
# Checkpoint to load, presumably for `model` above.
# NOTE(review): this filename uses `epoch_212-step_50000`, unlike the
# `epoch=…-step=…` pattern of the classifier checkpoints; confirm it exists.
checkpoint_path: checkpoints/cmnist/cmnist_uniform_normal/checkpoints/epoch_212-step_50000.ckpt
use_neg_guidance: false     # presumably toggles negative guidance — confirm
evaluate_baselines: true    # presumably also evaluates baselines — confirm

# Classifier on a torchvision ResNet-18 backbone. Its per-label class counts
# are interpolated from the diffusion model so the two cannot drift apart.
composition_classifier:
  _target_: cs_classifier.models.MultiheadClassifier
  base_model: {_target_: torchvision.models.resnet18}
  num_classes_per_label: ${model.num_class_per_label}

# Judge classifier, also on a torchvision ResNet-18 backbone. Its per-label
# class counts are interpolated from the diffusion model's label spec.
judge_classifier:
  _target_: cs_classifier.models.MultiLabelClassifier
  base_model: {_target_: torchvision.models.resnet18}
  num_classes_per_label: ${model.num_class_per_label}

# Weights for the two classifiers (filenames follow the
# `epoch=…-step=…` naming that PyTorch Lightning uses by default).
composition_classifier_checkpoint: 'checkpoints/cmnist/cs_cmnist_composition/classifier/version_0/checkpoints/epoch=49-step=2950.ckpt'
judge_classifier_checkpoint: 'checkpoints/cmnist/cs_cmnist_judge/classifier/version_0/checkpoints/epoch=49-step=2950.ckpt'
output_suffix: cmnist   # presumably appended to output names — confirm

# Guidance weights. Keys appear to name composition operators (atom / and /
# not / or); `ours` and `constant` presumably select two guidance variants —
# confirm against the sampling code. Keep the decimal point: these are floats.
guidance:
  atom: 1.0
  not: 1.0
  ours: {and: 1.0, not: 1.0, or_mi: 1.0, or_me: 1.0}
  constant: {and: 1.0, not: 1.0, or: 1.0}