# Compose on top of the `default` config (presumably default.yaml in this
# config group — confirm). `_self_` comes last, so values defined in this
# file override anything inherited from `default`.
defaults: [default, _self_]

_target_: src.models.occupancy_autoencoder.OccupancyAutoencoderLitModule

model:
  _target_: src.models.occupancy_autoencoder.OccupancyAutoencoder
  encoder:
    _target_: src.models.components.encoders.upt_pool_transformer_perceiver.UptPoolTransformerPerceiver
    # All encoder widths share ${vars.input_dim}; `vars` is not defined in
    # this file (presumably supplied by the `default` config — confirm).
    gnn_dim: ${vars.input_dim}
    enc_dim: ${vars.input_dim}
    perc_dim: ${vars.input_dim}
    enc_depth: 4
    enc_num_attn_heads: 2
    perc_num_attn_heads: 2
    num_latent_tokens: 192
    supernode_pooling:
      _target_: src.modules.supernode_gnn.SupernodeGnn
      # num classes + distance: presumably the 5 particle types of
      # decoder.n_particle_types plus 1 distance feature — keep in sync.
      input_dim: 6
      hidden_dim: ${vars.input_dim}
      init_weights: truncnormal

  decoder:
    _target_: src.models.components.decoders.upt_transformer_perceiver_occupancy.UptTransformerPerceiverOccupancy
    depth: 4
    clamp_mode: log
    # NOTE(review): hard-coded, unlike the encoder widths which follow
    # ${vars.input_dim}; confirm this is intended.
    dim: 192
    num_attn_heads: 2
    n_particle_types: 5
    # Interpolation is quoted because `{`/`}` are not allowed in plain
    # scalars inside a flow sequence; OmegaConf still resolves it.
    input_shape: [0, '${vars.input_dim}']
    output_shape: [0, 1]

optimizer:
  _target_: torch.optim.AdamW
  # Hydra returns a functools.partial instead of constructing AdamW, so the
  # parameters are bound later (presumably by the LightningModule — confirm).
  _partial_: true
  lr: 0.0004
  # 0.0 overrides AdamW's default (1e-2): no decoupled weight decay is applied.
  weight_decay: 0.0
