# Hydra defaults list: compose the `shapes3d_uniform` option of the top-level
# `dataset` config group, then this file's own values (`_self_`). Because
# `_self_` is listed last, keys defined in this file override the dataset's.
defaults:
  - /dataset: shapes3d_uniform
  - _self_

# Class-conditional denoising UNet built from `_target_` (presumably via
# hydra.utils.instantiate). Apart from `interaction` and
# `num_class_per_label`, the keys mirror diffusers `UNet2DModel` arguments;
# presumably they are forwarded to it - confirm in models/conditional_unet.py.
model:
  _target_: models.conditional_unet.ClassConditionalUnet
  # How the per-label class embeddings are combined; accepted values are
  # defined by ClassConditionalUnet.
  interaction: sum
  # Number of classes for each of the six conditioning labels. The
  # classifier configs reuse this list via `${model.num_class_per_label}`.
  # NOTE(review): presumably the 3D Shapes factors in the order floor hue,
  # wall hue, object hue, scale, shape, orientation - confirm against the
  # dataset config.
  num_class_per_label: [10, 10, 10, 8, 4, 15]
  sample_size: 64
  in_channels: 3
  out_channels: 3
  layers_per_block: 2
  # One width per resolution level: four levels, matching the four entries
  # of down_block_types and up_block_types.
  block_out_channels: [56, 112, 168, 224]
  dropout: 0.1
  # Attention head dimension; every width above is a multiple of 8.
  attention_head_dim: 8
  # Explicit null: none of diffusers' built-in class embeddings is requested;
  # label conditioning presumably comes from num_class_per_label instead.
  class_embed_type: null
  # GroupNorm requires channel counts divisible by the group count:
  # 56, 112, 168, 224 = 8 x 7, 14, 21, 28.
  norm_num_groups: 8
  act_fn: silu
  # First level is a plain block; the remaining three add attention.
  down_block_types:
    - DownBlock2D
    - AttnDownBlock2D
    - AttnDownBlock2D
    - AttnDownBlock2D
  # Mirror image of down_block_types: attention on the first three levels.
  up_block_types:
    - AttnUpBlock2D
    - AttnUpBlock2D
    - AttnUpBlock2D
    - UpBlock2D

# DDPM noise scheduler, instantiated from diffusers.DDPMScheduler.
noise_scheduler:
  _target_: diffusers.DDPMScheduler
  # 1000 training timesteps on a linear beta schedule.
  num_train_timesteps: 1000
  beta_schedule: linear
  # The network is trained to predict the added noise (epsilon).
  prediction_type: epsilon
  # Clip the predicted sample for numerical stability.
  clip_sample: true

# Run-level settings. The effect of each boolean switch is defined by the
# consuming script.
seed: 42
use_neg_guidance: false
evaluate_baselines: true

# Diffusion-model checkpoint (presumably loaded into `model`).
# NOTE(review): this filename uses `epoch_N-step_M`, while the classifier
# checkpoints use `epoch=N-step=M`; confirm this path exists.
checkpoint_path: 'checkpoints/3dshapes/shapes3d_uniform_normal/checkpoints/epoch_10-step_82500.ckpt'

# Classifier built from cs_classifier.models.MultiheadClassifier around a
# torchvision ResNet-18 `base_model`. Its per-label class counts are
# interpolated from the model config, so the two cannot drift apart.
composition_classifier:
  _target_: cs_classifier.models.MultiheadClassifier
  num_classes_per_label: ${model.num_class_per_label}
  base_model:
    _target_: torchvision.models.resnet18

# Classifier built from cs_classifier.models.MultiLabelClassifier around a
# torchvision ResNet-18 `base_model`. Its per-label class counts are
# interpolated from the model config, so the two cannot drift apart.
judge_classifier:
  _target_: cs_classifier.models.MultiLabelClassifier
  num_classes_per_label: ${model.num_class_per_label}
  base_model:
    _target_: torchvision.models.resnet18

# Checkpoint files, presumably loaded into composition_classifier and
# judge_classifier respectively.
composition_classifier_checkpoint: 'checkpoints/3dshapes/cs_shapes3d_composition/classifier/version_0/checkpoints/epoch=49-step=23450.ckpt'
judge_classifier_checkpoint: 'checkpoints/3dshapes/cs_shapes3d_judge/classifier/version_0/checkpoints/epoch=49-step=23450.ckpt'
# Suffix string; how it is used in output naming is up to the consuming script.
output_suffix: shapes3d

# Guidance weights; every weight is currently 1.0.
# `and`, `or` and `not` are plain string keys: they are not YAML booleans,
# unlike yes/no/on/off.
# NOTE(review): key order is kept as written in case the consumer iterates
# over these mappings; the meaning of `or_mi` / `or_me` is defined by the
# consuming script.
guidance:
  atom: 1.0
  not: 1.0
  # Weight set used by the `ours` variant.
  ours:
    and: 1.0
    not: 1.0
    or_mi: 1.0
    or_me: 1.0
  # Weight set used by the `constant` variant.
  constant:
    and: 1.0
    not: 1.0
    or: 1.0