# Loss-evaluation stage: loads the frozen encoder/latent/decoder from the
# "stage1" stage of a previous run and runs offline_loss_callback on the train
# and test splits (trainer.max_epochs is 0, so no training epochs are run).
wandb: cvsim
name: ???
stage_name: loss
# Entries set to ??? are placeholders that presumably must be supplied when this
# template is instantiated (e.g. by a sweep or CLI override) — TODO confirm.
vars:
  # passed as stage_id to every previous_run_initializer in model
  stage_id: ???
  # forwarded unchanged to both datasets as radius_graph_r
  radius_graph_r: ???
  # kind of the inner loss wrapped by trainer.loss_function (elementwise_loss)
  loss_function: ???
  # checkpoint selector consumed as ${vars.checkpoint} by every
  # previous_run_initializer in model; must be declared here to resolve
  checkpoint: ???

# Train and test splits of shapenet_car, configured identically except for
# `split`. The keys `train` / `test` are the dataset_key values referenced by
# the offline_loss_callback entries under trainer.callbacks.
datasets:
  train:
    kind: shapenet_car
    split: train
    # shared with the test split via vars so both use the same value
    radius_graph_r: ${vars.radius_graph_r}
    collators:
      - kind: rans_simformer_collator
  test:
    kind: shapenet_car
    split: test
    radius_graph_r: ${vars.radius_graph_r}
    collators:
      - kind: rans_simformer_collator

# rans_simformer_model whose encoder, latent and decoder are all frozen and
# initialized from the "stage1" stage of the run identified by vars.stage_id.
# The commented-out kind/kwargs lines are inactive; with
# use_checkpoint_kwargs: true the submodule definitions presumably come from the
# stored checkpoint instead — TODO confirm.
# NOTE(review): each initializer reads ${vars.checkpoint}, so `vars` must
# declare a `checkpoint` entry for these interpolations to resolve.
model:
  kind: rans_simformer_model
  encoder:
    # inactive reference definition (kept commented out)
#    kind: encoders.rans_perceiver
#    num_output_tokens: 3072
#    kwargs: ${select:base:${yaml:models/encoders/perceiver}}
    # frozen: loaded weights are kept as-is in this stage
    is_frozen: true
    initializers:
      - kind: previous_run_initializer
        stage_id: ${vars.stage_id}
        stage_name: stage1
        model_name: rans_simformer_model.encoder
        checkpoint: ${vars.checkpoint}
        use_checkpoint_kwargs: true
  latent:
    # inactive reference definition (kept commented out)
#    kind: latent.transformer_model
#    kwargs: ${select:base:${yaml:models/latent/transformer}}
    is_frozen: true
    initializers:
      - kind: previous_run_initializer
        stage_id: ${vars.stage_id}
        stage_name: stage1
        model_name: rans_simformer_model.latent
        checkpoint: ${vars.checkpoint}
        use_checkpoint_kwargs: true
  decoder:
    # inactive reference definition (kept commented out)
#    kind: decoders.rans_perceiver
#    kwargs: ${select:base:${yaml:models/decoders/perceiver}}
    is_frozen: true
    initializers:
      - kind: previous_run_initializer
        stage_id: ${vars.stage_id}
        stage_name: stage1
        model_name: rans_simformer_model.decoder
        checkpoint: ${vars.checkpoint}
        use_checkpoint_kwargs: true

# Evaluation-only trainer: max_epochs is 0, so no training epochs are run and
# the loss is reported by the offline_loss_callback entries below.
# NOTE(review): this relies on the trainer invoking these epoch-based
# callbacks even when zero epochs are run — confirm against the trainer.
trainer:
  kind: rans_simformer_trainer
  precision: bfloat16
  # presumably the fallback when bfloat16 is unavailable — TODO confirm
  backup_precision: float16
  max_epochs: 0
  effective_batch_size: 128
  # elementwise_loss wraps the inner loss whose kind is set by vars.loss_function
  loss_function:
    kind: elementwise_loss
    loss_function:
      kind: ${vars.loss_function}
  log_every_n_epochs: 1
  callbacks:
    # loss on datasets.train
    - kind: offline_loss_callback
      every_n_epochs: 1
      dataset_key: train
    # loss on datasets.test
    - kind: offline_loss_callback
      every_n_epochs: 1
      dataset_key: test
