---
# Declared model input/output shapes. The leading `null` leaves that
# dimension unspecified (presumably a variable batch / point-count axis —
# TODO confirm against the consumer); the trailing sizes are fixed.
input_shape: [null, 3]
output_shape: [null, 1]

# Named model inputs, each with a fixed example shape. All entries share the
# leading size 3586. Entries with `is_sparse: true` are flagged sparse for the
# consumer (the flag's semantics are defined outside this file). `constraint`
# names presumably select how the consumer generates/validates the index
# tensors — TODO confirm.
# Key order is kept as-is in case the consumer iterates inputs in order.
inputs:
  mesh_pos:
    shape: [3586, 3]
    is_sparse: true
  # NOTE(review): unlike mesh_pos, query_pos is not marked sparse — confirm
  # this is intentional.
  query_pos:
    shape: [3586, 3]
  batch_idx:
    shape: [3586]
    constraint: batch_idx
    is_sparse: true
  unbatch_idx:
    shape: [3586]
    constraint: unbatch_idx
    is_sparse: true
  unbatch_select:
    shape: [3586]
    constraint: unbatch_select
    is_sparse: true

# Name of the model output entry to read (presumably a key in the model's
# output mapping — TODO confirm against the consumer).
output_key: 'x_hat'

# Model definition. Each `kind` string names a component, presumably resolved
# by a registry in the consuming code (not visible here) — TODO confirm.
# Encoder, latent and decoder all use dim 1024 and 16 attention heads; keep
# these in sync when changing one, since the components presumably exchange
# dim-sized tokens.
model:
  kind: rans_simformer_nognn_model
  encoder:
    kind: encoders.rans_perceiver
    # Presumably the number of tokens the encoder emits into the latent
    # stage — TODO confirm.
    num_output_tokens: 64
    dim: 1024
    num_attn_heads: 16
  latent:
    kind: latent.transformer_model
    dim: 1024
    num_attn_heads: 16
    # Presumably the number of stacked transformer blocks — TODO confirm.
    depth: 12
  decoder:
    kind: decoders.rans_perceiver
    use_last_norm: true
    dim: 1024
    num_attn_heads: 16
