# Run-level metadata for this stage.
# NOTE(review): `???` looks like a "value must be supplied from outside this
# file" placeholder (e.g. a CLI override or a parent config) — confirm against
# the config loader.
wandb: cvsim
name: ???
stage_name: rollout-correlation
# Shared values that the rest of this file pulls in via `${vars.<key>}`.
# Only batch_size has a concrete value here; every other entry is `???`.
# NOTE(review): `???` presumably marks a value that must be provided
# externally before the config can be used — confirm against the config loader.
vars:
  # referenced by every previous_run_initializer under `model`
  stage_id: ???
  # referenced by trainer.effective_batch_size
  batch_size: 32
  # the entries below are forwarded to the cfd_dataset definitions
  version: ???
  num_input_points: ???
  num_input_points_ratio: ???
  radius_graph_r: ???
  radius_graph_max_num_neighbors: ???
  grid_resolution: ???
  # only datasets.train uses this; the rollout datasets hard-code `.inf`
  num_input_timesteps: ???
  # referenced by every previous_run_initializer under `model`
  checkpoint: ???
  # forwarded to the cfd_dataset definitions
  clamp: ???
  clamp_mode: ???
  norm: ???


# Three views of the same `cfd_dataset` kind, all parameterised from `vars`
# and all using the same collator:
#   train         - train split, num_input_timesteps taken from vars
#   train_rollout - train split, num_input_timesteps fixed to .inf, capped by
#                   max_num_sequences
#   test_rollout  - test split, num_input_timesteps fixed to .inf
# The *_rollout keys are the ones referenced by trainer.callbacks[].dataset_key.
datasets:
  # train split with the timestep count configured in vars
  train:
    kind: cfd_dataset
    version: ${vars.version}
    split: train
    num_input_points: ${vars.num_input_points}
    num_input_points_ratio: ${vars.num_input_points_ratio}
    num_input_timesteps: ${vars.num_input_timesteps}
    grid_resolution: ${vars.grid_resolution}
    clamp: ${vars.clamp}
    clamp_mode: ${vars.clamp_mode}
    norm: ${vars.norm}
    radius_graph_r: ${vars.radius_graph_r}
    radius_graph_max_num_neighbors: ${vars.radius_graph_max_num_neighbors}
    # only this dataset sets num_query_points; per the inline note the value is
    # a placeholder that exists to pass an assertion, not a tuned setting
    num_query_points: 32768 # not needed but here to satisfy assertion
    standardize_query_pos: true
    collators:
      - kind: cfd_baseline_collator
  # train split for the rollout callback
  train_rollout:
    kind: cfd_dataset
    version: ${vars.version}
    split: train
    num_input_points: ${vars.num_input_points}
    num_input_points_ratio: ${vars.num_input_points_ratio}
    # `.inf` is the YAML float for infinity
    # NOTE(review): presumably means "use every available timestep" — confirm
    # against cfd_dataset
    num_input_timesteps: .inf
    grid_resolution: ${vars.grid_resolution}
    clamp: ${vars.clamp}
    clamp_mode: ${vars.clamp_mode}
    norm: ${vars.norm}
    radius_graph_r: ${vars.radius_graph_r}
    radius_graph_max_num_neighbors: ${vars.radius_graph_max_num_neighbors}
    # only set here, not on test_rollout
    # NOTE(review): presumably limits how many train sequences are rolled out
    # to keep evaluation time bounded — confirm
    max_num_sequences: 500
    standardize_query_pos: true
    collators:
      - kind: cfd_baseline_collator
  # test split for the rollout callback; no max_num_sequences cap is set
  test_rollout:
    kind: cfd_dataset
    version: ${vars.version}
    split: test
    num_input_points: ${vars.num_input_points}
    num_input_points_ratio: ${vars.num_input_points_ratio}
    # same `.inf` setting as train_rollout
    num_input_timesteps: .inf
    grid_resolution: ${vars.grid_resolution}
    clamp: ${vars.clamp}
    clamp_mode: ${vars.clamp_mode}
    norm: ${vars.norm}
    radius_graph_r: ${vars.radius_graph_r}
    radius_graph_max_num_neighbors: ${vars.radius_graph_max_num_neighbors}
    standardize_query_pos: true
    collators:
      - kind: cfd_baseline_collator

# Model definition: four sub-models (conditioner, encoder, latent, decoder),
# each marked `is_frozen: true` and each initialised by a
# previous_run_initializer pointing at stage_name `stage1` of the run given by
# `${vars.stage_id}`, at checkpoint `${vars.checkpoint}`.
# The `kind`/`kwargs` lines of every sub-model are commented out.
# NOTE(review): presumably `use_checkpoint_kwargs: true` makes the initializer
# restore the architecture settings from the checkpoint, which is why kind and
# kwargs are not repeated here — confirm against previous_run_initializer.
model:
  kind: cfd_baseline_model
  conditioner:
#    kind: conditioners.timestep_velocity_conditioner
#    kwargs: ${select:debug:${yaml:models/dim}}
    is_frozen: true
    initializers:
      - kind: previous_run_initializer
        stage_id: ${vars.stage_id}
        stage_name: stage1
        # name of the sub-model inside the referenced run
        model_name: cfd_baseline_model.conditioner
        checkpoint: ${vars.checkpoint}
        use_checkpoint_kwargs: true
  encoder:
#    kind: encoders.cfd_gino
#    kwargs: ${select:dim768:${yaml:models/encoders/gino}}
    is_frozen: true
    initializers:
      - kind: previous_run_initializer
        stage_id: ${vars.stage_id}
        stage_name: stage1
        # name of the sub-model inside the referenced run
        model_name: cfd_baseline_model.encoder
        checkpoint: ${vars.checkpoint}
        use_checkpoint_kwargs: true
  latent:
#    kind: latent.fno_cond_model
#    padding: 0
#    kwargs: ${select:dim128:${yaml:models/dim}}
    is_frozen: true
    initializers:
      - kind: previous_run_initializer
        stage_id: ${vars.stage_id}
        stage_name: stage1
        # name of the sub-model inside the referenced run
        model_name: cfd_baseline_model.latent
        checkpoint: ${vars.checkpoint}
        use_checkpoint_kwargs: true
  decoder:
#    kind: decoders.cfd_gino
#    kwargs: ${select:dim768:${yaml:models/decoders/gino}}
    is_frozen: true
    initializers:
      - kind: previous_run_initializer
        stage_id: ${vars.stage_id}
        stage_name: stage1
        # name of the sub-model inside the referenced run
        model_name: cfd_baseline_model.decoder
        checkpoint: ${vars.checkpoint}
        use_checkpoint_kwargs: true

# Trainer settings for this stage.
# NOTE(review): with `max_epochs: 0` and all sub-models marked frozen, this
# stage presumably performs no weight updates and exists only to run the
# callbacks below — confirm against cfd_baseline_trainer.
trainer:
  kind: cfd_baseline_trainer
  # NOTE(review): backup_precision is presumably the fallback used when
  # bfloat16 is not supported on the device — confirm
  precision: bfloat16
  backup_precision: float16
  max_epochs: 0
  effective_batch_size: ${vars.batch_size}
  # elementwise wrapper around an MSE loss
  loss_function:
    kind: elementwise_loss
    loss_function:
      kind: mse_loss
  log_every_n_epochs: 1
  # the same callback kind is registered twice, once per rollout dataset;
  # each dataset_key must match a key under `datasets`
  callbacks:
    # rollout on the test split (datasets.test_rollout)
    - kind: offline_correlation_time_callback
      every_n_epochs: 1
      dataset_key: test_rollout
    # rollout on the train split (datasets.train_rollout)
    - kind: offline_correlation_time_callback
      every_n_epochs: 1
      dataset_key: train_rollout