# Run identification. Both keys carry the same identifier; judging by the key
# names they are presumably the Weights & Biases group and run name -- confirm
# against the code that reads this config.
wandb_group: 'avid_22M_coinrun_500lvl'
wandb_name: 'avid_22M_coinrun_500lvl'
# Model definition in class_path / init_args form: each `class_path` names a
# dotted class, and the sibling `init_args` mapping holds the arguments handed
# to it. Objects nest three levels deep here:
#   DiffusionModule -> GaussianDiffusionWithAVID -> two Unet3D instances
#   (`model` and `adapter`).
model:
  class_path: dwma.lightning.modules.diffusion_module.DiffusionModule
  init_args:
    # NOTE(review): the path sits under "pretrained_base_model", so this is
    # presumably the checkpoint of the pretrained base network -- confirm how
    # DiffusionModule uses it.
    ckpt_path: "/host_home/avid_checkpoints/pretrained_base_model/epoch=259-step=260000.ckpt"
    # Written with an explicit mantissa dot on purpose: YAML 1.1 loaders such
    # as PyYAML resolve a bare `1e-4` as the *string* "1e-4", not a float.
    learning_rate: 1.0e-4
    # NOTE(review): the unit of `sample_every` (steps vs. epochs) is not
    # visible in this file -- confirm in DiffusionModule.
    sample_every: 5000
    action_cond: true
    condition_frames: 2
    num_pred_frames: 8
    diffusion_model:
      class_path: dwma.models.avid.GaussianDiffusionWithAVID
      init_args:
        condition_adapter_on_base_outputs: true
        learnt_mask: true
        image_size: 64
        timesteps: 200
        auto_normalize: false
        objective: "pred_x0"
        # First of the two Unet3D instances (dim 100, 8 attention heads).
        model:
          class_path: dwma.models.video_diffusion_pytorch.unet3d.Unet3D
          init_args:
            dim: 100
            dim_mults: [1, 2, 4, 8]
            attn_dim_head: 64
            attn_heads: 8
            train_with_cfg: false
            channels: 3
        # Second, smaller Unet3D instance (dim 50, 4 attention heads).
        # `dim_mults` is not set here, so whatever default Unet3D defines
        # applies.
        adapter:
          class_path: dwma.models.video_diffusion_pytorch.unet3d.Unet3D
          init_args:
            dim: 50
            attn_dim_head: 50
            attn_heads: 4
            train_with_cfg: false
            channels: 3

# Trainer settings. Key names match PyTorch Lightning's Trainer arguments;
# the trainer class itself is not named in this file.
trainer:
  accumulate_grad_batches: 2
  # Single-entry list; presumably selects device index 0 -- confirm.
  devices:
    - 0
  # One callback given as a single class_path / init_args mapping (not a list).
  callbacks:
    class_path: pytorch_lightning.callbacks.ModelCheckpoint
    init_args:
      every_n_epochs: 5
      # Final directory name is the same identifier used for
      # wandb_group / wandb_name at the top of this file.
      dirpath: '/host_home/avid_checkpoints/avid_22M_coinrun_500lvl'
