# Top-level run settings.
# Global random seed for the run.
seed_everything: 123
# NOTE(review): wandb_group / wandb_name are presumably read by the training
# entry point to set the Weights & Biases group and run name -- the consumer is
# not visible in this file; confirm.
wandb_group: base_video_model
wandb_name: base_video_model
# Model definition in class_path/init_args form: each class_path names the
# class to build and the sibling init_args holds its constructor arguments.
# Nesting here is DiffusionModule -> GaussianDiffusion -> Unet3D.
model:
  class_path: dwma.lightning.modules.diffusion_module.DiffusionModule
  init_args:
    # Written with an explicit mantissa dot so that YAML 1.1 loaders (e.g.
    # PyYAML) resolve it as a float; a bare `1e-4` is resolved as a string
    # by those loaders.
    learning_rate: 1.0e-4
    sample_every: 5000
    action_cond: false
    # NOTE(review): condition_frames + num_pred_frames = 10, while
    # data.init_args.window_width is 12 -- confirm the difference is intended.
    condition_frames: 2
    num_pred_frames: 8
    diffusion_model:
      class_path: dwma.models.video_diffusion_pytorch.diffusion.GaussianDiffusion
      init_args:
        image_size: 64
        timesteps: 200
        auto_normalize: false
        # Kept quoted: this is a string option, not a number or boolean.
        objective: "pred_x0"
        # Network wrapped by the diffusion process above.
        model:
          class_path: dwma.models.video_diffusion_pytorch.unet3d.Unet3D
          init_args:
            dim: 100
            dim_mults: [1, 2, 4, 8]
            attn_dim_head: 64
            attn_heads: 8
            train_with_cfg: false

# Data module definition (class_path/init_args form).
data:
  class_path: dwma.lightning.data_modules.procgen_data.ProcgenDataModule
  init_args:
    batch_size: 32
    window_width: 12
    train_data_folder: /datasets/coinrun_500lvl_train
    # NOTE(review): the test and validation folders below are the same path --
    # confirm that validating on the test data is intended.
    test_data_folder: /datasets/coinrun_test
    val_data_folder: /datasets/coinrun_test
    # Both caching switches are disabled.
    cache_files: false
    precache_files: false
    num_workers: 4
    fixed_episode_length: 50

# Trainer settings: logger, hardware, optimisation loop limits and
# checkpointing.
trainer:
  # Weights & Biases logger; runs are written under save_dir.
  logger:
    class_path: pytorch_lightning.loggers.WandbLogger
    init_args:
      save_dir: /host_home/wandb/
      offline: false
      project: avid
      entity: causica
      # Quoted to make explicit that this is the string option "all".
      log_model: "all"
  # Single GPU, selected by index.
  devices: [0]
  accelerator: gpu
  accumulate_grad_batches: 2
  gradient_clip_val: 1.0
  max_epochs: 5000
  # Per-epoch caps on the number of train / validation batches.
  limit_train_batches: 2000
  limit_val_batches: 200
  log_every_n_steps: 10
  precision: 16
  # A single checkpoint callback (given as one mapping, not a list).
  callbacks:
    class_path: pytorch_lightning.callbacks.ModelCheckpoint
    init_args:
      dirpath: /host_home/avid_checkpoints
      # Must stay quoted: an unquoted leading `{` would start a flow mapping.
      filename: '{epoch}-{step}'
      save_weights_only: true
      save_last: false
      every_n_epochs: 5  # checkpoint interval, in epochs
      save_top_k: -1  # save all checkpoints

