# Experiment identity. `name` is interpolated into `output_dir` and the W&B
# project; `version` becomes the W&B run name.
name: cifar10_x0
tags:
  - test
description: ''
# If not specified, will be set to version_{index}.
version: baseline_fix
output_dir: "outputs/${name}"

seed: 42
# NOTE(review): presumably a checkpoint path to resume from; null means a
# fresh run — confirm against the training entry point.
resume: null

data:
  _target_: src.data.multiview.MultiViewDataModule
  # NOTE(review): "Warrper" is misspelled, but it must match the class name
  # that Hydra imports from src.data.cifar10; rename both together or not at
  # all.
  train_dataset:
    _target_: src.data.cifar10.CIFAR10Warrper
    mode: train
  train_batch_size: 128
  # The validation set uses the *train* split (mode: train) with
  # num_samples: 50000, not the test split.
  # NOTE(review): presumably intentional for sample-quality evaluation —
  # confirm.
  val_dataset:
    _target_: src.data.cifar10.CIFAR10Warrper
    mode: train
    num_samples: 50000
  val_batch_size: 500
  test_dataset:
    _target_: src.data.cifar10.CIFAR10Warrper
    mode: test
  test_batch_size: 1
  num_workers: 32
  # Canonical lowercase boolean (yamllint `truthy`), matching the rest of
  # this file.
  pin_memory: true

system:
  # Keyword arguments passed to DiffusionSystem via Hydra `_target_`.
  _target_: src.systems.diffusion_system.DiffusionSystem
  mode: sample
  lr: 2.0e-4
  diffusion_timesteps: 50
  ema_update_every: 10

trainer:
  # Keyword arguments for lightning.pytorch Trainer, instantiated by Hydra.
  _target_: lightning.pytorch.trainer.Trainer
  default_root_dir: ${output_dir}
  # Training length is epoch-based; a step-based alternative is kept
  # commented.
  # max_steps: 20001
  max_epochs: 50000
  # Validation runs every 50 epochs; a step-based alternative is kept
  # commented.
  # val_check_interval: 2000
  check_val_every_n_epoch: 50
  log_every_n_steps: 10
  num_sanity_val_steps: 1
  enable_progress_bar: true
  # DDP with find_unused_parameters enabled (registered Lightning strategy
  # alias).
  strategy: ddp_find_unused_parameters_true
  # Trainer declares this as an int (a float is rejected by Lightning 1.x);
  # 1 means no gradient accumulation.
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0
  accelerator: gpu
  # devices: 1
  # num_nodes: 1
  precision: "16-mixed"  # mixed precision for extra speed-up

callbacks:
  # Periodic checkpointing; save_top_k: -1 keeps every saved checkpoint.
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    every_n_epochs: 200
    # every_n_train_steps: 50000
    save_top_k: -1
  # rich_progress_bar:
  #   _target_: lightning.pytorch.callbacks.RichProgressBar

logger:
  # TensorBoard logger, currently disabled; uncomment to enable.
  # tensorboard:
  #   _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
  #   save_dir: "${output_dir}"
  #   name: ""
  #   version: "${version}"
  #   sub_dir: "tb_logs"

  # Weights & Biases logger: the project is named after the experiment
  # `name`, and the run after `version`.
  wandb:
    _target_: lightning.pytorch.loggers.wandb.WandbLogger
    name: "${version}"
    project: "${name}"
    # NOTE(review): unlike trainer.default_root_dir and the tensorboard
    # stanza (both ${output_dir}), W&B files go to the shared top-level
    # "outputs" directory — confirm this is intended.
    save_dir: "outputs"