# Experiment tracking identifiers; group and name are set to the same string.
# NOTE(review): presumably consumed as the Weights & Biases run group and run
# name -- confirm against the entry point that reads these keys.
wandb_group: classifier_coinrun_10000lvl
wandb_name: classifier_coinrun_10000lvl

# Model section: a ClassifierModule wrapping a GaussianNoiseActionClassifier,
# which in turn wraps a Unet3DEncoder. Each level is given as a class_path
# plus the init_args passed to that class.
model:
  class_path: dwma.lightning.modules.classifier_module.ClassifierModule
  init_args:
    # NOTE(review): presumably the number of initial (context) frames and of
    # subsequent frames per clip -- confirm against ClassifierModule.
    init_frames: 2
    next_frames: 8
    train_with_noise: false  # train on clean videos
    classifier:
      class_path: dwma.models.classifier.GaussianNoiseActionClassifier
      init_args:
        # NOTE(review): presumably the length of the noise schedule; verify it
        # agrees with the model this classifier is paired with.
        num_timesteps: 200
        model:
          class_path: dwma.models.classifier.Unet3DEncoder
          init_args:
            dim: 64
            attn_dim_head: 64
            attn_heads: 6
            # NOTE(review): presumably the size of the environment's discrete
            # action space -- confirm against the dataset.
            num_actions: 15

# Data section: ProcgenDataModule with one folder per split.
data:
  class_path: dwma.lightning.data_modules.procgen_data.ProcgenDataModule
  init_args:
    train_data_folder: /datasets/coinrun_10000lvl_train
    # NOTE(review): validation and test read the same folder, so validation
    # metrics are computed on the test data -- confirm this is intended.
    val_data_folder: /datasets/coinrun_test
    test_data_folder: /datasets/coinrun_test

# Trainer section: device list and checkpoint output directory.
trainer:
  # One-element list; presumably selects the accelerator with index 0.
  devices:
    - 0
  # A single callback, written as one mapping rather than a list of callbacks.
  callbacks:
    class_path: pytorch_lightning.callbacks.ModelCheckpoint
    init_args:
      dirpath: '/host_home/avid_checkpoints/classifier_coinrun_10000lvl'
