# Rollout-speed benchmark config (trainer.max_epochs is 0; the only
# callbacks are offline_rollout_speed_callback). Values under `vars` are
# shared across the file via ${vars.*} interpolation.
wandb: cvsim
name: upt-68M-image-all-k256
# Written as an explicit null: a bare `stage_name:` also parses to null,
# but silently and ambiguously (yamllint: empty-values).
stage_name: null
vars:
  # dataset version tag forwarded to both cfd_dataset entries
  version: v2-5000sims
  # both null; presumably "no input-point subsampling" — confirm with cfd_dataset
  num_input_points: null
  num_input_points_ratio: null
  # forwarded to the trainer; max neighbors matches "k256" in the run name
  radius_graph_r: 5
  radius_graph_max_num_neighbors: 256
  # used by the train dataset only (the rollout dataset uses .inf)
  num_input_timesteps: 2
  # passed to the rollout callbacks as rollout_kwargs.mode
  mode: image
  # shared by the datasets and the decoder; presumably 0 disables
  # clamping — confirm against the clamp implementation
  clamp: 0
  clamp_mode: log
  norm: mean0std1q25
  # forwarded to every cfd_simformer_collator
  num_supernodes: 2048

datasets:
  train:
    kind: cfd_dataset
    version: ${vars.version}
    split: train
    num_input_points: ${vars.num_input_points}
    num_input_points_ratio: ${vars.num_input_points_ratio}
    num_input_timesteps: ${vars.num_input_timesteps}
    clamp: ${vars.clamp}
    clamp_mode: ${vars.clamp_mode}
    norm: ${vars.norm}
    # sequence indented like trainer.callbacks for a consistent file style
    collators:
      - kind: cfd_simformer_collator
        num_supernodes: ${vars.num_supernodes}
  # Consumed by the offline_rollout_speed_callback entries in
  # trainer.callbacks (dataset_key: rollout).
  rollout:
    kind: cfd_dataset
    version: ${vars.version}
    # NOTE(review): rollout reads the `train` split. This is harmless for
    # a pure speed benchmark; confirm it is intended if quality metrics
    # are ever computed on this dataset.
    split: train
    num_input_points: ${vars.num_input_points}
    num_input_points_ratio: ${vars.num_input_points_ratio}
    # `.inf` is YAML's float infinity (not a string). Presumably it means
    # "use all available timesteps" — confirm with cfd_dataset.
    num_input_timesteps: .inf
    clamp: ${vars.clamp}
    clamp_mode: ${vars.clamp_mode}
    norm: ${vars.norm}
    # per the key name, limits the rollout set to at most 10 sequences
    max_num_sequences: 10
    collators:
      - kind: cfd_simformer_collator
        num_supernodes: ${vars.num_supernodes}


# Model sub-modules, in file order: conditioner, encoder, latent, decoder.
# Every sub-module sets is_frozen: true, consistent with trainer.max_epochs: 0.
# Each `kwargs` value is produced by the project's `select`/`yaml`
# resolvers. Presumably they pick the named preset (e.g. `dim384`) from the
# referenced YAML file — confirm against the resolver implementation.
model:
  kind: cfd_simformer_model
  conditioner:
    kind: conditioners.timestep_velocity_conditioner_pdearena
    kwargs: ${select:dim384:${yaml:models/dim}}
    is_frozen: true
  encoder:
    kind: encoders.cfd_pool_transformer_perceiver
    num_latent_tokens: 64
    enc_depth: 4
    # preset name suggests a 192 -> 384 width change; decoder uses the
    # mirrored dim384to192 preset — confirm in the presets file
    kwargs: ${select:dim192to384:${yaml:models/encoders/pool_transformer_perceiver}}
    is_frozen: true
  latent:
    kind: latent.transformer_model
    depth: 4
    kwargs: ${select:dim384:${yaml:models/latent/transformer}}
    is_frozen: true
  decoder:
    kind: decoders.cfd_transformer_perceiver
    depth: 4
    use_last_norm: true
    # same clamp settings as the datasets (shared via vars)
    clamp: ${vars.clamp}
    clamp_mode: ${vars.clamp_mode}
    kwargs: ${select:dim384to192:${yaml:models/decoders/transformer_perceiver}}
    is_frozen: true


# Trainer. max_epochs: 0 requests no training epochs; the run presumably
# exists only to execute the rollout-speed callbacks below.
trainer:
  kind: cfd_simformer_trainer
  precision: bfloat16
  # presumably the fallback when bfloat16 is unavailable — confirm
  backup_precision: float16
  max_epochs: 0
  effective_batch_size: 1
  radius_graph_r: ${vars.radius_graph_r}
  radius_graph_max_num_neighbors: ${vars.radius_graph_max_num_neighbors}
  # Elementwise MSE. With max_epochs: 0 it is presumably never optimized
  # and only present to satisfy the trainer's config — confirm.
  loss_function:
    kind: elementwise_loss
    loss_function:
      kind: mse_loss
  log_every_n_epochs: 1
  # NOTE(review): with max_epochs: 0, confirm the trainer still fires
  # every_n_epochs: 1 callbacks (e.g. once at the end of training).
  # The two entries below are intentionally identical and their order
  # matters: do not deduplicate or reorder them.
  callbacks:
    # warmup: first rollout pass over the `rollout` dataset; presumably
    # absorbs one-time startup costs so they do not skew the timing
    - kind: offline_rollout_speed_callback
      every_n_epochs: 1
      dataset_key: rollout
      rollout_kwargs:
        mode: ${vars.mode}
    # benchmark: second rollout pass with the same settings (the measured run)
    - kind: offline_rollout_speed_callback
      every_n_epochs: 1
      dataset_key: rollout
      rollout_kwargs:
        mode: ${vars.mode}
