# Run metadata for this experiment.
# The run name appears to encode the main settings configured further down in
# this file: lower bound 1500 of the train `num_points_range`, 256 supernodes,
# 256 latent tokens, learning rate 1e-3, and rollout callbacks that set
# `save_rollout: true`. Keep the name in sync when changing those values.
# NOTE(review): `wandb` presumably selects a Weights & Biases setup named
# "cvsim" — confirm against the launcher.
wandb: cvsim
name: lagrangian_minp_1500_sn_256_latent_256_lr_1e-3_save_rollout
# NOTE(review): stage identifier; how stages are consumed is not visible here.
stage_name: stage1
# Shared values that the other sections of this file pull in through
# `${vars.<path>}` interpolation, so each setting is defined exactly once.
vars:
  # Learning rate; consumed by `vars.optim.lr` below.
  lr: 1.0e-3
  # Referenced by the trainer as `effective_batch_size` and `max_batch_size`
  # respectively. Both are 128 in this run.
  batch_size: 128
  max_batch_size: 128
  # Referenced by the trainer as `max_epochs`.
  epochs: 50

  # Dataset settings reused by the entries under `datasets`.
  dataset:
    kind: lagrangian_dataset
    name: tgv2d
    n_input_timesteps: 6
    # Used by the train and valid datasets; the two test rollout datasets set
    # their own literal value instead of referencing this one.
    n_pushforward_timesteps: 0
    graph_mode: radius_graph
    # NOTE(review): presumably the connection radius and the per-node neighbor
    # cap used when building the radius graph — confirm units and semantics
    # against the dataset implementation.
    radius_graph_r: 0.05
    radius_graph_max_num_neighbors: 4

  # Optimizer settings; every model component references this whole mapping
  # via `optim: ${vars.optim}`.
  optim:
    kind: adamw
    lr: ${vars.lr}
    weight_decay: 0.05
    schedule:
      # Schedule is loaded from the external file `schedules/wupcos_epoch`.
      template: ${yaml:schedules/wupcos_epoch}
      # NOTE(review): dotted key presumably overrides the template's
      # `end_epoch` variable — confirm its meaning in schedules/wupcos_epoch.
      template.vars.end_epoch: 10

# Dataset definitions, keyed by the names the trainer callbacks refer to via
# `dataset_key`. All four entries take kind, name, number of input timesteps
# and radius-graph settings from `vars.dataset`, and all use the same
# `lagrangian_simformer_collator`.
datasets:
  # Training split.
  train:
    kind: ${vars.dataset.kind}
    name: ${vars.dataset.name}
    split: train
    n_input_timesteps: ${vars.dataset.n_input_timesteps}
    n_pushforward_timesteps: ${vars.dataset.n_pushforward_timesteps}
    graph_mode: ${vars.dataset.graph_mode}
    radius_graph_r: ${vars.dataset.radius_graph_r}
    radius_graph_max_num_neighbors: ${vars.dataset.radius_graph_max_num_neighbors}
    # Only the train dataset sets this key.
    # NOTE(review): presumably the [min, max] number of points sampled per
    # training example — confirm against the dataset implementation.
    num_points_range: [ 1500, 2500 ]
    collators:
      - kind: lagrangian_simformer_collator
  # Validation split; same settings as train except for `num_points_range`,
  # which is not set here.
  valid:
    kind: ${vars.dataset.kind}
    name: ${vars.dataset.name}
    split: valid
    n_input_timesteps: ${vars.dataset.n_input_timesteps}
    n_pushforward_timesteps: ${vars.dataset.n_pushforward_timesteps}
    graph_mode: ${vars.dataset.graph_mode}
    radius_graph_r: ${vars.dataset.radius_graph_r}
    radius_graph_max_num_neighbors: ${vars.dataset.radius_graph_max_num_neighbors}
    collators:
      - kind: lagrangian_simformer_collator
  # Test split used by the rollout callbacks.
  test_rollout:
    kind: ${vars.dataset.kind}
    name: ${vars.dataset.name}
    split: test
    n_input_timesteps: ${vars.dataset.n_input_timesteps}
    # Literal value (not `${vars.dataset.n_pushforward_timesteps}`, which is 0).
    n_pushforward_timesteps: 20
    graph_mode: ${vars.dataset.graph_mode}
    radius_graph_r: ${vars.dataset.radius_graph_r}
    radius_graph_max_num_neighbors: ${vars.dataset.radius_graph_max_num_neighbors}
    collators:
      - kind: lagrangian_simformer_collator
  # Same as `test_rollout` but with `test_mode: full_traj`.
  test_full_traj_rollout:
    kind: ${vars.dataset.kind}
    name: ${vars.dataset.name}
    split: test
    # NOTE(review): presumably makes the dataset yield whole trajectories —
    # confirm against the dataset implementation.
    test_mode: full_traj
    n_input_timesteps: ${vars.dataset.n_input_timesteps}
    # Literal value (not `${vars.dataset.n_pushforward_timesteps}`, which is 0).
    n_pushforward_timesteps: 20
    graph_mode: ${vars.dataset.graph_mode}
    radius_graph_r: ${vars.dataset.radius_graph_r}
    radius_graph_max_num_neighbors: ${vars.dataset.radius_graph_max_num_neighbors}
    collators:
      - kind: lagrangian_simformer_collator

# Model definition: a `lagrangian_simformer_model` composed of four components
# (conditioner, encoder, latent, decoder). Each component reuses the shared
# optimizer settings from `vars.optim` and takes its `kwargs` from an external
# YAML file through `${select:<key>:${yaml:<path>}}`.
# NOTE(review): `select` presumably picks the named entry (`tiny`, `dim192`)
# from the loaded file — confirm against the config resolver.
model:
  kind: lagrangian_simformer_model
  conditioner:
    kind: conditioners.timestep_conditioner_pdearena
    kwargs: ${select:tiny:${yaml:models/timestep_embed}}
    optim: ${vars.optim}
  encoder:
    kind: encoders.lagrangian_gnn_pool_transformer_perceiver
    # These two values are also spelled out in the run name (`sn_256`,
    # `latent_256`); update the name when changing them.
    num_supernodes: 256
    num_latent_tokens: 256
    kwargs: ${select:dim192:${yaml:models/encoders/gnn_pool_transformer_perceiver}}
    optim: ${vars.optim}
  latent:
    kind: latent.transformer_model
    kwargs: ${select:tiny:${yaml:models/latent/transformer}}
    optim: ${vars.optim}
  decoder:
    kind: decoders.lagrangian_transformer_perceiver
    # Set next to `kwargs` rather than inside the selected preset.
    depth: 4
    kwargs: ${select:dim192:${yaml:models/decoders/transformer_perceiver}}
    optim: ${vars.optim}

# Trainer definition: training length and batch sizes come from `vars`; the
# loss is an `elementwise_loss` wrapping `mse_loss`; evaluation is done through
# the callbacks listed at the end.
trainer:
  kind: lagrangian_simformer_trainer
  # NOTE(review): `backup_precision` is presumably the fallback when bfloat16
  # is unavailable — confirm against the trainer implementation.
  precision: bfloat16
  backup_precision: float16
  max_epochs: ${vars.epochs}
  # Both batch-size values resolve to 128 in this run.
  effective_batch_size: ${vars.batch_size}
  max_batch_size: ${vars.max_batch_size}
  loss_function:
    kind: elementwise_loss
    loss_function:
      kind: mse_loss
  log_every_n_epochs: 1
  # NOTE(review): presumably forwarded to the model's forward call; the
  # meaning of `reconstruct_prev_a` is not visible in this file.
  forward_kwargs:
    reconstruct_prev_a: true
  # All callbacks run every epoch. After the validation loss there are four
  # rollout callbacks covering both test datasets with `full_rollout` false and
  # true; `save_rollout` is enabled only for `test_full_traj_rollout`.
  callbacks:
    # Loss on the `valid` dataset.
    - kind: offline_loss_callback
      every_n_epochs: 1
      dataset_key: valid
    # `test_rollout`, full_rollout off, rollouts not saved.
    - kind: offline_lagrangian_rollout_mesh_loss_callback
      every_n_epochs: 1
      dataset_key: test_rollout
      rollout_kwargs:
        full_rollout: false
        save_rollout: false
    # `test_rollout`, full_rollout on, rollouts not saved.
    - kind: offline_lagrangian_rollout_mesh_loss_callback
      every_n_epochs: 1
      dataset_key: test_rollout
      rollout_kwargs:
        full_rollout: true
        save_rollout: false
    # `test_full_traj_rollout`, full_rollout off, rollouts saved.
    - kind: offline_lagrangian_rollout_mesh_loss_callback
      every_n_epochs: 1
      dataset_key: test_full_traj_rollout
      rollout_kwargs:
        full_rollout: false
        save_rollout: true
    # `test_full_traj_rollout`, full_rollout on, rollouts saved.
    - kind: offline_lagrangian_rollout_mesh_loss_callback
      every_n_epochs: 1
      dataset_key: test_full_traj_rollout
      rollout_kwargs:
        full_rollout: true
        save_rollout: true