# Run identity; `name` and `group` carry the same value in this config.
name: 'classifier_categorical_rt1'
group: 'classifier_categorical_rt1'
# Absolute, environment-specific path (presumably the root for run output —
# confirm against the training script).
logdir: '/host_home/avid'

# Relative path to a separate diffusion config file.
# NOTE(review): how this file is consumed is not visible here — confirm.
diffusion_config_file: 'configs/train/dynamicrafter_512.yaml'

# Model definition. `target` is a dotted class path and `params` holds its
# settings (presumably passed as constructor kwargs — confirm in the loader).
model:
  # Written with an explicit decimal point: YAML 1.1 loaders such as PyYAML
  # read a bare `1e-4` as the string "1e-4" rather than as a float.
  base_learning_rate: 1.0e-4
  scale_lr: false
  target: lvdm.models.action_predictor.ActionPredictor
  params:
    encode_and_add_noise: true
    discretize_actions: true
    in_channels: 4
    mlp_hidden_dims: [1024, 1024]
    # NOTE(review): 40x64 equals data.params target_height x target_width
    # (320x512) divided by 8 — presumably a latent (height, width); confirm.
    input_resolution: [40, 64]
    encoder_config:
      # NOTE(review): this target lives under package `dwma`, whereas the
      # model target above uses `lvdm` — confirm the path is importable.
      target: dwma.models.classifier.Unet3DEncoder
      params:
        encoder_only: true
        dim: 160
        attn_dim_head: 64
        attn_heads: 8
        dim_mults: [1, 2, 4, 4]

# Data module definition, using the same `target` + `params` layout as `model`.
data:
  # NOTE(review): package `ldwma` differs from the `lvdm` / `dwma` packages
  # referenced under `model` — confirm this dotted path is importable.
  target: ldwma.lightning.data_modules.rtx.RTXDataModule
  params:
    batch_size: 16
    # Frame size requested from the data module (height x width).
    target_height: 320
    target_width: 512
    dataset_name: 'fractal20220817_data'
    shuffle_buffer: 1000
    # presumably the number of frames per trajectory sample — TODO confirm
    traj_len: 16

# Training-loop settings: numeric precision, trainer options, logger, and
# callbacks. Keys under `trainer` are presumably forwarded to the PyTorch
# Lightning Trainer — confirm in the training script.
lightning:
  precision: 16
  trainer:
    accelerator: 'gpu'
    # Per-epoch caps on the number of train / validation batches.
    limit_train_batches: 1000
    limit_val_batches: 50
    check_val_every_n_epoch: 1
    num_nodes: 1
    benchmark: true
    accumulate_grad_batches: 1
    max_steps: 500000
    log_every_n_steps: 10
    # Gradient clipping: algorithm 'norm' with threshold 1.0.
    gradient_clip_algorithm: 'norm'
    gradient_clip_val: 1.0
    logger:
      target: pytorch_lightning.loggers.WandbLogger
      params:
        save_dir: /host_home/wandb/
        offline: false
        project: avid
        entity: causica
        # NOTE(review): together with `save_top_k: -1` below, this presumably
        # uploads every saved checkpoint to W&B — confirm that is intended.
        log_model: all
  callbacks:
    model_checkpoint:
      target: pytorch_lightning.callbacks.ModelCheckpoint
      params:
        # Quoted because a plain scalar starting with `{` would be parsed as
        # a YAML flow mapping.
        filename: '{epoch}-{step}'
        save_weights_only: true
        every_n_epochs: 5
        # -1 disables pruning in ModelCheckpoint, so every checkpoint is kept.
        save_top_k: -1
