# Run-level metadata for this experiment.
# NOTE(review): `wandb` presumably names a Weights & Biases project/config
# ("cvsim") — confirm against the config loader.
wandb: cvsim
# Run name. Its tokens appear to mirror values set below (e1000 = epochs,
# lr5e4 = lr, seqlen1024 = num_output_tokens, dim768, grid32 = grid_resolution);
# update it by hand when changing those values.
name: snc-all-sdfpos-e1000-subsam1-lr5e4-sdfperconly-seqlen1024-sdf512-cnext-dim768-sd02unif-reprcnn-grn-grid32
stage_name: stage1
# Shared values, referenced elsewhere in this file through ${vars.<key>}.
vars:
  # Keep the decimal point and the signed exponent: YAML 1.1 loaders such as
  # PyYAML resolve `5e-4` as a string rather than a float.
  lr: 5.0e-4
  # Referenced by trainer.effective_batch_size.
  batch_size: 32
  # Referenced by trainer.max_batch_size. It is smaller than batch_size, so the
  # trainer presumably accumulates gradients over several steps — TODO confirm.
  max_batch_size: 16
  # Referenced by trainer.max_epochs.
  epochs: 1000
  # Referenced by both the train and the test dataset.
  grid_resolution: 32

  # Optimizer settings shared by every model component below
  # (each one sets `optim: ${vars.optim}`).
  optim:
    kind: adamw
    lr: ${vars.lr}
    weight_decay: 0.05
    schedule:
      # NOTE(review): `${yaml:...}` appears to pull in a schedule template from
      # another file, and the dotted `template.vars.end_epoch` key appears to
      # override one of that template's variables — confirm with the resolver.
      template: ${yaml:schedules/wupcos_epoch}
      template.vars.end_epoch: 50

# Dataset definitions. `train` and `test` are configured identically apart
# from `split`; keep them in sync when editing either one.
datasets:
  train:
    kind: shapenet_car
    split: train
    grid_resolution: ${vars.grid_resolution}
    # NOTE(review): by its name this appends position information to the SDF
    # input (cf. "sdfpos" in the run name) — confirm in the dataset class.
    concat_pos_to_sdf: true
    collators:
      - kind: rans_simformer_nognn_collator
  # The trainer callbacks refer to this dataset through `dataset_key: test`.
  test:
    kind: shapenet_car
    split: test
    grid_resolution: ${vars.grid_resolution}
    concat_pos_to_sdf: true
    collators:
      - kind: rans_simformer_nognn_collator

# Model definition: four components (grid_encoder, mesh_encoder, latent,
# decoder), each using the shared optimizer settings from vars.optim.
model:
  kind: rans_simformer_nognn_sdf_model
  # ConvNeXt-style encoder for the grid input.
  grid_encoder:
    kind: encoders.rans_grid_convnext
    patch_size: 2
    kernel_size: 3
    depthwise: false
    global_response_norm: true
    # One entry per stage; `depths` and `dims` have the same length (3).
    # NOTE(review): the final dim (768) equals the "dim768" presets selected
    # for the other components — presumably required to match; confirm.
    depths: [ 2, 2, 2 ]
    dims: [ 192, 384, 768 ]
    upsample_size: 64
    upsample_mode: nearest
    optim: ${vars.optim}
  mesh_encoder:
    kind: encoders.rans_perceiver
    num_output_tokens: 1024
    add_type_token: true
    init_weights: truncnormal
    # NOTE(review): `${select:KEY:${yaml:FILE}}` appears to load FILE and pick
    # the entry named KEY (here "dim768") — confirm with the config resolver.
    kwargs: ${select:dim768:${yaml:models/encoders/perceiver}}
    optim: ${vars.optim}
  latent:
    kind: latent.transformer_model
    init_weights: truncnormal
    # NOTE(review): rate 0.2 without decay presumably corresponds to
    # "sd02unif" in the run name (uniform stochastic depth) — confirm.
    drop_path_rate: 0.2
    drop_path_decay: false
    kwargs: ${select:dim768depth12:${yaml:models/latent/transformer}}
    optim: ${vars.optim}
  decoder:
    kind: decoders.rans_perceiver
    init_weights: truncnormal
    kwargs: ${select:dim768:${yaml:models/decoders/perceiver}}
    optim: ${vars.optim}

# Training loop settings; epoch count and batch sizes come from `vars`.
trainer:
  kind: rans_simformer_nognn_sdf_trainer
  precision: bfloat16
  # NOTE(review): presumably the fallback used when bfloat16 is unavailable on
  # the device — confirm in the trainer.
  backup_precision: float16
  max_epochs: ${vars.epochs}
  effective_batch_size: ${vars.batch_size}
  max_batch_size: ${vars.max_batch_size}
  # Element-wise loss wrapper around an inner MSE loss.
  loss_function:
    kind: elementwise_loss
    loss_function:
      kind: mse_loss
  log_every_n_epochs: 1
  callbacks:
    # Runs every epoch on the `test` dataset defined under `datasets`.
    - kind: offline_loss_callback
      every_n_epochs: 1
      dataset_key: test
    # NOTE(review): `loss/test/total` is presumably the metric logged by the
    # offline_loss_callback above — confirm the key before renaming either.
    - kind: best_checkpoint_callback
      every_n_epochs: 1
      metric_key: loss/test/total