# Run identity. `name` is interpolated into `output_dir` below.
name: cifar10
tags:
  - "test"
description: ""
# If not specified (null), the version will be set to version_{index}.
version: null
output_dir: "outputs/${name}"

# Seed value for the run.
# NOTE(review): the code that consumes it is not part of this file.
seed: 42

# Stage switches: training is disabled and testing is enabled for this run.
train: false
test: true
# Checkpoint path for the run.
# NOTE(review): whether this resumes training or only loads weights for the
# test stage depends on the entry-point script — confirm there.
resume: "outputs/cifar10/version_0/checkpoints/epoch=4799-step=470400.ckpt"

# Data module configuration: one dataset entry per stage (train / val / test),
# each with its own batch size, plus shared loader settings.
# NOTE(review): `CIFAR10Warrper` is spelled as written; it presumably has to
# match the class name in src.data.cifar10, so do not "fix" it here without
# renaming the class as well.
data:
  _target_: src.data.multiview.MultiViewDataModule

  # Training split.
  train_dataset:
    _target_: src.data.cifar10.CIFAR10Warrper
    mode: train
  train_batch_size: 128

  # Validation: 40 samples in batches of 10.
  val_dataset:
    _target_: src.data.cifar10.CIFAR10Warrper
    mode: val
    num_samples: 40
  val_batch_size: 10

  # Test: 50000 samples in batches of 500.
  # NOTE(review): this entry uses `mode: train`, not a held-out split —
  # confirm that evaluating against the training split is intentional.
  test_dataset:
    _target_: src.data.cifar10.CIFAR10Warrper
    mode: train
    num_samples: 50000
  test_batch_size: 500

  # Loader settings shared by all stages.
  num_workers: 32
  pin_memory: true

# System (model + training logic) configuration; `_target_` names the class
# to build and the remaining keys are its arguments.
# NOTE(review): presumably built through Hydra-style instantiation — confirm
# in the entry-point script.
system:
  _target_: src.systems.diffusion_system.DiffusionSystem
  diffusion_timesteps: 50
  # NOTE(review): presumably selects the prediction target of the diffusion
  # model — confirm the accepted values against DiffusionSystem.
  mode: "epsilon"

# Lightning Trainer configuration. Checkpoints and logs default to
# `output_dir`.
trainer:
  _target_: lightning.pytorch.trainer.Trainer
  default_root_dir: ${output_dir}
  # Training length is bounded by epochs; the step-based limit is disabled.
  # max_steps: 20001
  max_epochs: 5000
  # Validation runs every 5 epochs; the step-based interval is disabled.
  # val_check_interval: 2000
  check_val_every_n_epoch: 5
  log_every_n_steps: 10
  num_sanity_val_steps: 1
  enable_progress_bar: true
  strategy: ddp_find_unused_parameters_true
  # Must be an integer: Trainer declares `accumulate_grad_batches: int`.
  # 1 means no gradient accumulation.
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0
  accelerator: gpu
  # Device and node counts are left at the Trainer defaults.
  # devices: 1
  # num_nodes: 1
  precision: 16-mixed # mixed precision for extra speed-up

# Trainer callbacks, keyed by an arbitrary name; each entry names its class
# via `_target_`.
callbacks:
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    # -1 keeps every saved checkpoint instead of only the best k.
    save_top_k: -1
    # Checkpoints are written on an epoch schedule; the step-based schedule
    # is disabled.
    # every_n_train_steps: 50000
    every_n_epochs: 100
  # Optional rich progress bar (disabled); uncomment both lines to enable.
  # rich_progress_bar:
  #   _target_: lightning.pytorch.callbacks.RichProgressBar

# Experiment loggers, keyed by an arbitrary name. Only TensorBoard is active;
# the Weights & Biases entry is kept commented out.
logger:
  tensorboard:
    _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
    save_dir: "${output_dir}"
    # Empty name: no experiment-name subdirectory is requested.
    name: ""
    # Follows the top-level `version` key.
    version: "${version}"
    # NOTE(review): event files are expected to land in a `tb_logs`
    # subdirectory of the versioned run directory — confirm against
    # TensorBoardLogger.
    sub_dir: "tb_logs"
  # wandb:
  #   _target_: lightning.pytorch.loggers.wandb.WandbLogger
  #   project: "${name}"
  #   save_dir: "outputs"
  #   name: "${version}"