wandb: cvsim
name: ???
stage_name: loss
# Template parameters. `???` marks a value that must be provided before the
# config can be resolved (presumably by the hp file using this template).
vars:
  # id of the earlier run whose "stage1" stage supplies the model weights
  stage_id: ???
  radius_graph_r: ???
  loss_function: ???
  grid_resolution: ???
  # checkpoint of that run to load; read by every previous_run_initializer
  # in `model` (format depends on previous_run_initializer — TODO confirm)
  checkpoint: ???

datasets:
  # Both splits take grid_resolution and radius_graph_r from the same vars,
  # so they are built identically apart from `split`.
  train:
    kind: shapenet_car
    split: train
    grid_resolution: ${vars.grid_resolution}
    radius_graph_r: ${vars.radius_graph_r}
    collators:
      - kind: rans_baseline_collator
  test:
    kind: shapenet_car
    split: test
    grid_resolution: ${vars.grid_resolution}
    radius_graph_r: ${vars.radius_graph_r}
    collators:
      - kind: rans_baseline_collator

model:
  kind: rans_baseline_model
  # Every submodule below is frozen (is_frozen: true) and initialized from the
  # "stage1" stage of run vars.stage_id at vars.checkpoint. No kind/kwargs are
  # set here: with use_checkpoint_kwargs: true they are presumably restored
  # from the checkpoint (TODO confirm). The commented-out kind/kwargs lines
  # appear to record the architecture used in stage1.
  encoder:
    # kind: encoders.rans_gino
    # kwargs: ${select:dim768:${yaml:models/encoders/gino}}
    is_frozen: true
    initializers:
      - kind: previous_run_initializer
        stage_id: ${vars.stage_id}
        stage_name: stage1
        model_name: rans_baseline_model.encoder
        checkpoint: ${vars.checkpoint}
        use_checkpoint_kwargs: true
  latent:
    # kind: latent.fno_gino_model
    # kwargs: ${select:dim64:${yaml:models/latent/fno}}
    is_frozen: true
    initializers:
      - kind: previous_run_initializer
        stage_id: ${vars.stage_id}
        stage_name: stage1
        model_name: rans_baseline_model.latent
        checkpoint: ${vars.checkpoint}
        use_checkpoint_kwargs: true
  decoder:
    # kind: decoders.rans_gino
    # kwargs: ${select:dim768:${yaml:models/decoders/gino}}
    is_frozen: true
    initializers:
      - kind: previous_run_initializer
        stage_id: ${vars.stage_id}
        stage_name: stage1
        model_name: rans_baseline_model.decoder
        checkpoint: ${vars.checkpoint}
        use_checkpoint_kwargs: true

trainer:
  kind: rans_baseline_trainer
  # float16 rather than bfloat16: torch.fft.rfftn has no bfloat16 support,
  # and its float16 support is limited to power-of-2 sizes.
  # backup_precision is presumably the fallback when float16 cannot be used
  # (e.g. a non-power-of-2 grid_resolution) — TODO confirm.
  precision: float16
  backup_precision: float32
  # Zero epochs and fully frozen submodules: this stage appears to only
  # compute losses via the callbacks below, without training.
  max_epochs: 0
  effective_batch_size: 32
  log_every_n_epochs: 1
  # elementwise_loss wrapping the loss kind selected by vars.loss_function
  loss_function:
    kind: elementwise_loss
    loss_function:
      kind: ${vars.loss_function}
  # Offline loss evaluation on the train and test datasets.
  callbacks:
    - kind: offline_loss_callback
      dataset_key: train
      every_n_epochs: 1
    - kind: offline_loss_callback
      dataset_key: test
      every_n_epochs: 1
