# Hydra defaults list: the `default` config is composed first and this file's
# own keys (`_self_`) are merged last, so values defined below override it.
defaults:
  - default
  - _self_

# Occupancy autoencoder built from a Perceiver encoder and decoder.
model:
  _target_: src.models.occupancy_autoencoder.OccupancyAutoencoder
  # `_partial_: true` makes hydra.utils.instantiate return a functools.partial
  # instead of a constructed object; the remaining constructor arguments are
  # presumably supplied by OccupancyAutoencoder (TODO confirm).
  # `${vars.*}` values are interpolated from a `vars` node defined outside
  # this file (presumably in the `default` config).
  encoder:
    _target_: src.models.components.encoders.perceiver_encoder.PerceiverEncoder
    _partial_: true
    dim: ${vars.hidden_dim}
    num_latents: 196
    latent_dim: ${vars.latent_dim}
    cross_heads: 3
    cross_dim_head: 64
    latent_heads: 4
    latent_dim_head: 128
    weight_tie_layers: false
    depth: 4
    supernode_pooling:
      _target_: src.modules.supernode_pooling.SupernodePooling
      supernodes_radius: 1.0
      supernodes_max_neighbours: 2
      net:
        _target_: src.modules.supernode_pooling.SupernodeGnn
        # num classes + distance; presumably 5 particle types
        # (decoder.n_particle_types) + 1 distance feature — keep in sync.
        input_dim: 6
        hidden_dim: ${vars.hidden_dim}
        aggr:
          _target_: torch_geometric.nn.aggr.SumAggregation

  decoder:
    _target_: src.models.components.decoders.perceiver_decoder.PerceiverDecoder
    _partial_: true
    queries_dim: 126
    latent_dim: ${vars.latent_dim}
    cross_heads: 4
    cross_dim_head: 128
    n_particle_types: 5
    n_channels: 1

  # Optional weight initialisation, currently disabled.
  # NOTE(review): the module path below is spelled "initalization"; verify it
  # matches the actual module under src/modules before re-enabling.
  # initialize_weights:
  #   _target_: src.modules.weight_initalization.TruncNormal
  #   std: 0.02
