---
# Overall I/O signature of the model.
# The leading axis is left as `null` (presumably a dynamic / per-sample
# point count — TODO confirm with the consumer of this file).
input_shape: [null, 3]
output_shape: [null, 1]

# Named model inputs with fixed example shapes.
# NOTE(review): 3586 is shared by every point-wise input below; it looks
# like a sample point count — confirm it is not meant to be dynamic like
# the `null` axis in input_shape.
inputs:
  # `is_sparse` and `constraint` are interpreted by the consumer of this
  # file; their exact semantics are defined elsewhere.
  mesh_pos:
    shape: [3586, 3]
    is_sparse: true
  # NOTE(review): trailing dim 4 presumably = 1 SDF channel + 3 position
  # channels (cf. model.grid_encoder.concat_pos_to_sdf: true) — TODO confirm.
  # The 64^3 spatial dims match model.grid_encoder.resolution.
  sdf:
    shape: [64, 64, 64, 4]
  # NOTE(review): unlike mesh_pos, query_pos is not flagged is_sparse —
  # confirm this is intentional.
  query_pos:
    shape: [3586, 3]
  batch_idx:
    shape: [3586]
    constraint: batch_idx
    is_sparse: true
  unbatch_idx:
    shape: [3586]
    constraint: unbatch_idx
    is_sparse: true
  unbatch_select:
    shape: [3586]
    constraint: unbatch_select
    is_sparse: true

# Name of the output to read as the model's prediction
# (presumably a key in the model's output mapping — confirm with the consumer).
output_key: x_hat

# Model architecture. Each `kind` names a component to instantiate; the
# mapping from these names to implementations lives outside this file.
# Width is 768 throughout (grid_encoder's last stage, mesh_encoder, latent,
# decoder); with 12 attention heads that is 768 / 12 = 64 per head.
model:
  kind: rans_simformer_nognn_sdf_model
  # Presumably encodes the `sdf` grid input; resolution below matches the
  # 64^3 spatial dims of inputs.sdf.
  grid_encoder:
    kind: encoders.rans_grid_convnext
    patch_size: 2
    kernel_size: 3
    depthwise: false
    global_response_norm: true
    # One entry per stage (presumably) — keep depths and dims the same length.
    # The last dims entry (768) equals the shared width used below.
    depths: [2, 2, 2]
    dims: [192, 384, 768]
    upsample_size: 64
    upsample_mode: nearest
    resolution: [64, 64, 64]
    # NOTE(review): presumably why inputs.sdf has 4 channels (1 + 3 pos).
    concat_pos_to_sdf: true
  mesh_encoder:
    kind: encoders.rans_perceiver
    num_output_tokens: 1024
    add_type_token: true
    dim: 768
    num_attn_heads: 12
  latent:
    kind: latent.transformer_model
    dim: 768
    num_attn_heads: 12
    depth: 12
  decoder:
    kind: decoders.rans_perceiver
    dim: 768
    num_attn_heads: 12
