# Run identification.
# NOTE(review): `wandb` presumably selects a Weights & Biases project or
# config named `cvsim` -- confirm against the config loader.
wandb: cvsim
# The run name mirrors settings made further down in this file: e1000 (epochs),
# lr1e3 (learning rate), grid32 (grid_resolution), encdec384 (dim384 preset for
# encoder/decoder), fno64x16 (dim64 preset and modes: 16 for the latent model).
# Keep it in sync when changing any of those values.
name: snc-all-sdf-e1000-lr1e3-ginoencdec-grid32-encdec384-fno64x16
stage_name: stage1
# Shared values that other sections of this file pull in via `${vars.<key>}`
# interpolation.
vars:
  # Referenced by vars.optim.lr below.
  lr: 1.0e-3
  # Referenced by trainer.effective_batch_size and trainer.max_batch_size.
  batch_size: 32
  max_batch_size: 32
  # Referenced by trainer.max_epochs.
  epochs: 1000

  # Referenced by both the train and the test dataset.
  radius_graph_r: 10.0
  grid_resolution: 32

  # Optimizer settings; the encoder, latent model and decoder all reference
  # this same mapping through `${vars.optim}`.
  optim:
    kind: adamw
    lr: ${vars.lr}
    weight_decay: 1.0e-4
    schedule:
      # Schedule definition is loaded from the external `schedules/wupcos_epoch`
      # YAML via the `yaml:` resolver.
      # NOTE(review): the name suggests warmup followed by cosine decay --
      # confirm in that file.
      template: ${yaml:schedules/wupcos_epoch}
      # NOTE(review): presumably overrides `vars.end_epoch` inside the loaded
      # template -- confirm how the loader treats dotted keys.
      template.vars.end_epoch: 50

# Dataset definitions. `train` and `test` are configured identically (same
# kind, graph radius, grid resolution and collator) and differ only in `split`.
datasets:
  train:
    kind: shapenet_car
    split: train
    radius_graph_r: ${vars.radius_graph_r}
    grid_resolution: ${vars.grid_resolution}
    collators:
      - kind: rans_gino_encdec_sdf_collator
  # Referenced as `dataset_key: test` by the offline loss callback in the
  # trainer section.
  test:
    kind: shapenet_car
    split: test
    radius_graph_r: ${vars.radius_graph_r}
    grid_resolution: ${vars.grid_resolution}
    collators:
      - kind: rans_gino_encdec_sdf_collator

# Model composed of three submodules (encoder, latent, decoder). Each submodule
# takes its kwargs from a named preset in an external YAML file and uses the
# shared optimizer settings from `vars.optim`.
model:
  kind: rans_gino_encdec_sdf_model
  encoder:
    kind: encoders.rans_gino_sdf
    # Picks the `dim384` entry from `models/dim`; the decoder uses the same
    # preset.
    kwargs: ${select:dim384:${yaml:models/dim}}
    optim: ${vars.optim}
  latent:
    kind: latent.fno_gino_model
    modes: 16
    # Picks the `dim64` entry from `models/latent/fno`.
    kwargs: ${select:dim64:${yaml:models/latent/fno}}
    optim: ${vars.optim}
  decoder:
    kind: decoders.rans_gino
    kwargs: ${select:dim384:${yaml:models/dim}}
    optim: ${vars.optim}

# Training loop configuration.
trainer:
  kind: rans_gino_encdec_sdf_trainer
  # float32 is set here because a NaN loss was encountered
  # (presumably with a lower-precision setting -- confirm before changing).
  precision: float32
  max_epochs: ${vars.epochs}
  # Both batch-size settings resolve to 32 via vars.
  # NOTE(review): with equal values no gradient accumulation would be needed,
  # assuming that is what the effective/max distinction controls -- confirm.
  effective_batch_size: ${vars.batch_size}
  max_batch_size: ${vars.max_batch_size}
  # MSE loss wrapped in an elementwise loss.
  loss_function:
    kind: elementwise_loss
    loss_function:
      kind: mse_loss
  log_every_n_epochs: 1
  callbacks:
    # Evaluates the loss on the `test` dataset every epoch.
    - kind: offline_loss_callback
      every_n_epochs: 1
      dataset_key: test
    # Tracks `loss/test/total` every epoch for best-checkpoint selection.
    # NOTE(review): the metric key points at the test split, so the best
    # checkpoint appears to be chosen on test data -- confirm this is intended.
    - kind: best_checkpoint_callback
      every_n_epochs: 1
      metric_key: loss/test/total