# Run header. `name` and `group` carry the same label for this run.
name: "classifier_mse_rt1"
group: "classifier_mse_rt1"
# NOTE(review): presumably the root directory for run outputs, on a path
# mounted into the container — confirm against the launcher.
logdir: "/host_home/avid"

# Model section. `target` is a dotted import path and `params` holds its
# settings; presumably `params` is passed to the target's constructor —
# confirm against the config loader.
model:
  # Written with an explicit mantissa (`1.0e-4`, not `1e-4`): YAML 1.1
  # loaders such as PyYAML resolve the dot-less form as a string, not a float.
  base_learning_rate: 1.0e-4
  scale_lr: false
  target: lvdm.models.action_predictor.ActionPredictor
  params:
    in_channels: 3
    in_block_channels:
      - 32
      - 64
      - 128
      - 256
    mlp_hidden_dims:
      - 1024
      - 1024
    # Same two values as data.params.target_height / target_width (320, 512).
    # NOTE(review): presumably ordered [height, width] — confirm.
    input_resolution:
      - 320
      - 512
    # NOTE(review): the encoder is resolved from the `dwma` package while the
    # predictor above comes from `lvdm` — confirm both import roots are
    # intended.
    encoder_config:
      target: dwma.models.classifier.Unet3DEncoder
      params:
        encoder_only: true
        dim: 256
        attn_dim_head: 64
        attn_heads: 8
        dim_mults:
          - 1
          - 1
          - 2

# Data section: the data module class and the settings handed to it.
data:
  target: ldwma.lightning.data_modules.rtx.RTXDataModule
  params:
    # Which dataset to read.
    dataset_name: "fractal20220817_data"
    # NOTE(review): presumably the number of frames per trajectory — confirm.
    traj_len: 16
    # Same two values as model.params.input_resolution (320, 512).
    target_height: 320
    target_width: 512
    # Loader settings.
    batch_size: 16
    shuffle_buffer: 1000

# Training-loop section: trainer arguments, logger and callbacks.
lightning:
  # NOTE(review): `precision` sits beside `trainer`, not inside it;
  # presumably the launcher reads it separately — confirm before moving it.
  precision: 16
  trainer:
    accelerator: "gpu"
    num_nodes: 1
    benchmark: true
    # Per-epoch batch caps for the training and validation loops.
    limit_train_batches: 1000
    limit_val_batches: 100
    check_val_every_n_epoch: 1
    accumulate_grad_batches: 1
    max_steps: 100000
    log_every_n_steps: 5
    # Gradient clipping: algorithm "norm" with a clip value of 0.5.
    gradient_clip_algorithm: "norm"
    gradient_clip_val: 0.5
    logger:
      target: pytorch_lightning.loggers.WandbLogger
      params:
        save_dir: /host_home/wandb/
        offline: false
        project: avid
        entity: causica
        log_model: "all"
  callbacks:
    model_checkpoint:
      target: pytorch_lightning.callbacks.ModelCheckpoint
      params:
        filename: "{epoch}-{step}"
        save_weights_only: true
        every_n_epochs: 10
        # NOTE(review): per the ModelCheckpoint docs, -1 keeps every
        # checkpoint rather than only the best k — confirm storage budget.
        save_top_k: -1
