# Weights & Biases labels for this run; group and run name share one identifier.
wandb_group: "controlnet_22M_coinrun_500lvl"
wandb_name: "controlnet_22M_coinrun_500lvl"
# Model section: a DiffusionModule wrapping a ControlNet-augmented diffusion
# model. Each nested component is given as class_path + init_args.
model:
  class_path: dwma.lightning.modules.diffusion_module.DiffusionModule
  init_args:
    # Checkpoint file handed to the module.
    # NOTE(review): the path suggests pretrained base-model weights - confirm
    # how DiffusionModule consumes it (weight init vs. full resume).
    ckpt_path: "/host_home/avid_checkpoints/pretrained_base_model/epoch=259-step=260000.ckpt"
    # Written with an explicit mantissa dot: YAML 1.1 loaders (e.g. PyYAML)
    # resolve a bare `1e-4` as a string, not a float.
    learning_rate: 1.0e-4
    # NOTE(review): presumably a sampling interval in training steps - confirm.
    sample_every: 5000
    action_cond: true
    # NOTE(review): presumably 2 conditioning frames in, 8 predicted frames
    # out per sample - verify against DiffusionModule.
    condition_frames: 2
    num_pred_frames: 8
    diffusion_model:
      class_path: dwma.models.video_diffusion_pytorch.diffusion_control_net.GaussianDiffusionWithControlNet
      init_args:
        image_size: 64
        timesteps: 200
        auto_normalize: false
        objective: "pred_x0"
        # Main U-Net. dim, dim_mults, attn_dim_head and attn_heads are
        # repeated with the same values under control_net below.
        model:
          class_path: dwma.models.control_net.ControlledUnetModel
          init_args:
            dim: 100
            dim_mults: [1, 2, 4, 8]
            attn_dim_head: 64
            attn_heads: 8
            train_with_cfg: false
        # Control branch. It takes an extra `new_dim` on top of the values
        # shared with the main U-Net.
        # NOTE(review): new_dim looks like the reduced width of the smaller
        # control net - confirm against ControlNetSmaller.
        control_net:
          class_path: dwma.models.control_net.ControlNetSmaller
          init_args:
            new_dim: 52
            dim: 100
            dim_mults: [1, 2, 4, 8]
            attn_dim_head: 64
            attn_heads: 8
            train_with_cfg: false

# Trainer section: device selection and checkpointing.
trainer:
  # Device list with the single entry 0.
  devices:
    - 0
  # One callback, given as a single class_path/init_args mapping (not a list).
  callbacks:
    class_path: pytorch_lightning.callbacks.ModelCheckpoint
    init_args:
      # Directory that receives the checkpoint files.
      dirpath: "/host_home/avid_checkpoints/controlnet_22M_coinrun_500lvl"
      # Checkpoint interval, in epochs.
      # NOTE(review): monitor / save_top_k are not set here, so the library
      # defaults decide how many checkpoints are retained - confirm intended.
      every_n_epochs: 5
