# Run identification.
# NOTE(review): `wandb` presumably selects the Weights & Biases project/config
# to log to -- confirm against the launcher.
wandb: cvsim
# The run name looks like a hand-written summary of the hyperparameters in this
# file: "e100", "lr2e5" and "nodes2048to128" match vars.epochs, vars.lr and
# vars.num_supernodes / vars.num_latent_tokens. Nothing visible here derives it
# automatically, so update it by hand when those values change.
name: 5ksims-e100-lr2e5-16to64r5k32q16-nodes2048to128-68M-logmse-lion-decclamp-rec16k
stage_name: stage1
# Shared values; other sections of this file pull them in via ${vars.<key>}.
vars:
  # Learning rate, consumed by vars.optim.lr below.
  lr: 2.0e-5
  # Forwarded to trainer.effective_batch_size / trainer.max_batch_size.
  batch_size: 1024
  max_batch_size: 1024
  # Forwarded to trainer.max_epochs.
  epochs: 100
  # Dataset version used by all three datasets.
  version: v2-5000sims
  # Input point count for the train dataset.
  num_input_points: 16384
  # Passed to the cfd_simformer_collator of every dataset.
  num_supernodes: 2048
  # Passed to the encoder.
  num_latent_tokens: 128
  # Radius-graph settings, forwarded to the trainer.
  radius_graph_r: 5.0
  radius_graph_max_num_neighbors: 32
  # Input point count for the train_rollout / test_rollout datasets.
  rollout_num_input_points: 16384
  # Normalization and clamping settings shared by all datasets; clamp and
  # clamp_mode are additionally passed to the decoder.
  norm: mean0std1q25
  clamp: 0
  clamp_mode: log

  # Optimizer settings; every model component references this whole mapping
  # via `optim: ${vars.optim}`.
  optim:
    kind: lion
    lr: ${vars.lr}
    weight_decay: 0.5
    schedule:
      # NOTE(review): the schedule is loaded from an external template file;
      # "wupcos" presumably means warmup + cosine decay -- confirm in
      # schedules/wupcos_epoch.
      template: ${yaml:schedules/wupcos_epoch}
      # NOTE(review): presumably overrides the template's `end_epoch` variable
      # (likely the end of the warmup phase) -- confirm against the template.
      template.vars.end_epoch: 10

# Dataset definitions. All three use kind cfd_dataset with the same version,
# norm and clamp settings and the same collator; they differ in split, number
# of input timesteps and the rollout-only options.
datasets:
  # Training data: 2 input timesteps, query coupled with the input.
  train:
    kind: cfd_dataset
    version: ${vars.version}
    split: train
    num_input_timesteps: 2
    num_input_points: ${vars.num_input_points}
    couple_query_with_input: true
    norm: ${vars.norm}
    clamp: ${vars.clamp}
    clamp_mode: ${vars.clamp_mode}
    collators:
      - kind: cfd_simformer_collator
        num_supernodes: ${vars.num_supernodes}
  # Rollout evaluation on the train split; referenced by the trainer callbacks
  # through `dataset_key: train_rollout`.
  train_rollout:
    kind: cfd_dataset
    version: ${vars.version}
    split: train
    # `.inf` is loaded by YAML as float infinity.
    # NOTE(review): presumably means "use all available timesteps" -- confirm
    # in cfd_dataset.
    num_input_timesteps: .inf
    num_input_points: ${vars.rollout_num_input_points}
    norm: ${vars.norm}
    clamp: ${vars.clamp}
    clamp_mode: ${vars.clamp_mode}
    # NOTE(review): presumably caps how many sequences are evaluated -- confirm.
    max_num_sequences: 32
    collators:
      - kind: cfd_simformer_collator
        num_supernodes: ${vars.num_supernodes}
  # Rollout evaluation on the test split; identical to train_rollout except
  # for `split`. Referenced through `dataset_key: test_rollout`.
  test_rollout:
    kind: cfd_dataset
    version: ${vars.version}
    split: test
    num_input_timesteps: .inf
    num_input_points: ${vars.rollout_num_input_points}
    norm: ${vars.norm}
    clamp: ${vars.clamp}
    clamp_mode: ${vars.clamp_mode}
    max_num_sequences: 32
    collators:
      - kind: cfd_simformer_collator
        num_supernodes: ${vars.num_supernodes}


# Model definition: a composite of conditioner, encoder, latent model and
# decoder. Every component uses the same optimizer settings (vars.optim).
# NOTE(review): each `kwargs: ${select:<key>:${yaml:<file>}}` presumably picks
# the <key> entry out of the referenced yaml file; the keys (dim384,
# dim192to384, dim384to192) suggest the component widths -- confirm in the
# referenced files.
model:
  kind: cfd_simformer_model
  conditioner:
    kind: conditioners.timestep_velocity_conditioner_pdearena
    kwargs: ${select:dim384:${yaml:models/dim}}
    optim: ${vars.optim}
  encoder:
    kind: encoders.cfd_pool_transformer_perceiver
    num_latent_tokens: ${vars.num_latent_tokens}
    enc_depth: 4
    kwargs: ${select:dim192to384:${yaml:models/encoders/pool_transformer_perceiver}}
    optim: ${vars.optim}
  latent:
    kind: latent.transformer_model
    depth: 4
    kwargs: ${select:dim384:${yaml:models/latent/transformer}}
    optim: ${vars.optim}
  decoder:
    kind: decoders.cfd_transformer_perceiver
    depth: 4
    use_last_norm: true
    # Same clamp settings as the datasets (both read vars.clamp / vars.clamp_mode).
    clamp: ${vars.clamp}
    clamp_mode: ${vars.clamp_mode}
    kwargs: ${select:dim384to192:${yaml:models/decoders/transformer_perceiver}}
    optim: ${vars.optim}

# Training loop configuration: precision, batch sizes, loss, early stopping
# and the callbacks for checkpointing and rollout evaluation.
trainer:
  kind: cfd_simformer_trainer
  # NOTE(review): backup_precision is presumably the fallback when bfloat16 is
  # not supported on the device -- confirm in the trainer.
  precision: bfloat16
  backup_precision: float16
  max_epochs: ${vars.epochs}
  # Both batch sizes resolve to 1024 with the current vars.
  # NOTE(review): presumably this means no gradient accumulation -- confirm.
  effective_batch_size: ${vars.batch_size}
  max_batch_size: ${vars.max_batch_size}
  radius_graph_r: ${vars.radius_graph_r}
  radius_graph_max_num_neighbors: ${vars.radius_graph_max_num_neighbors}
  # NOTE(review): presumably weights of auxiliary reconstruction loss terms --
  # confirm in cfd_simformer_trainer.
  reconstruct_prev_x_weight: 1
  reconstruct_dynamics_weight: 1
  # MSE loss wrapped in an elementwise_loss.
  loss_function:
    kind: elementwise_loss
    loss_function:
      kind: mse_loss
  # Early stopping on the same metric the best-checkpoint callback tracks.
  # NOTE(review): `tolerance: 20` is presumably the number of checks without
  # improvement before stopping -- confirm.
  early_stopper:
    kind: metric_early_stopper
    every_n_epochs: 1
    tolerance: 20
    metric_key: loss/online/x_hat/E1
  log_every_n_epochs: 1
  callbacks:
    # best checkpoint: tracks loss/online/x_hat/E1 every epoch, also saves
    # the optimizer state
    - kind: best_checkpoint_callback
      every_n_epochs: 1
      metric_key: loss/online/x_hat/E1
      save_optim: true
    # latest checkpoint for resuming: every 10 epochs, only the "latest"
    # weights and optimizer state are saved (save_weights is false)
    - kind: checkpoint_callback
      every_n_epochs: 10
      save_weights: false
      save_latest_weights: true
      save_latest_optim: true
    # periodic checkpoints: every 10 epochs, weights and optimizer state
    # NOTE(review): presumably kept per epoch rather than overwritten -- confirm
    - kind: checkpoint_callback
      every_n_epochs: 10
      save_weights: true
      save_optim: true
    # train rollout (no rollout_kwargs given)
    - kind: offline_correlation_time_callback
      every_n_epochs: 10
      dataset_key: train_rollout
    # test rollout (no rollout_kwargs given)
    - kind: offline_correlation_time_callback
      every_n_epochs: 10
      dataset_key: test_rollout
    # train rollout with rollout_kwargs.mode set to latent
    - kind: offline_correlation_time_callback
      every_n_epochs: 10
      dataset_key: train_rollout
      rollout_kwargs:
        mode: latent
    # test rollout with rollout_kwargs.mode set to latent
    - kind: offline_correlation_time_callback
      every_n_epochs: 10
      dataset_key: test_rollout
      rollout_kwargs:
        mode: latent
