# Hydra defaults list: entries are composed in the order they are listed.
defaults:
  - default
  # Selects the `linear_warmup_cosine_annealing` option of the `scheduler`
  # config group and places it at this config's own package (`_here_`).
  # NOTE(review): this file also defines an explicit `scheduler:` node further
  # down — confirm the two do not conflict or duplicate each other.
  - scheduler@_here_: linear_warmup_cosine_annealing
  # `_self_` is last, so values written in this file are merged after the
  # entries above and override them on conflict.
  - _self_

# Class that Hydra instantiates for this config node.
_target_: src.models.occupancy_autoencoder_particle.OccupancyAutoencoderParticleLitModule
# Interpolated from `vars.num_classes`; the same variable also sets
# `model.decoder.n_particle_types` in this file.
num_classes: ${vars.num_classes}

# Network configuration, nested under the top-level `_target_` as its `model`
# argument. `_partial_: true` makes Hydra build a partially-applied constructor
# instead of an instance; presumably the remaining arguments are supplied by
# the caller at runtime — TODO confirm.
model:
  _target_: src.models.occupancy_autoencoder_particle.OccupancyAutoencoderParticle
  _partial_: true
  # Encoder: supernode pooling -> transformer -> perceiver (per the target's
  # name). Built as a partial, like its parent node.
  encoder:
    _target_: src.models.components.encoders.upt_pool_transformer_perceiver.UptPoolTransformerPerceiver
    _partial_: true
    # NOTE(review): gnn_dim equals `supernode_pooling.net.hidden_dim` (both 96)
    # — presumably they must stay in sync; confirm before changing either.
    gnn_dim: 96
    enc_dim: 96
    # Width and count of the latent tokens come from the shared `vars` node.
    perc_dim: ${vars.latent_dim}
    enc_depth: 4
    enc_num_attn_heads: 2
    # NOTE(review): assumes `vars.latent_dim` is divisible by 3 — confirm.
    perc_num_attn_heads: 3
    num_latent_tokens: ${vars.num_latent_tokens}
    # Same value (768) as `condition_dim` of the decoder and the conditioner.
    condition_dim: 768
    output_ln: true
    supernode_pooling:
      _target_: src.modules.supernode_pooling.SupernodePooling
      supernodes_radius: ${vars.supernode_radius}
      supernodes_max_neighbours: ${vars.supernodes_max_neighbours}
      net:
        _target_: src.modules.supernode_pooling.SupernodeGnn
        input_dim: ${vars.input_dim}
        hidden_dim: 96
        ndim: 3
        aggr:
          _target_: torch_geometric.nn.aggr.MeanAggregation
        # Embedding of absolute positions, scaled by `vars.pos_scale`.
        pos_embed:
          _target_: src.modules.positional_embeddings.ContinuousSincosEmbed
          _partial_: true
          pos_scale: ${vars.pos_scale}
        # Embedding of relative positions. It is configured with `box_size`
        # (set to the supernode radius) instead of `pos_scale`.
        # NOTE(review): presumably intentional because relative offsets are
        # bounded by the radius — confirm against ContinuousSincosEmbed.
        relative_pos_embed:
          _target_: src.modules.positional_embeddings.ContinuousSincosEmbed
          _partial_: true
          box_size: ${vars.supernode_radius}

  # Decoder: transformer + perceiver producing occupancy output (per the
  # target's name). Built as a partial, like its parent node.
  decoder:
    _target_: src.models.components.decoders.upt_transformer_perceiver_occupancy.UptTransformerPerceiverOccupancy
    _partial_: true
    dim: 192
    depth: 4
    num_attn_heads: 3
    # NOTE(review): hard-coded to 192, while the encoder's output width is
    # `${vars.latent_dim}`. If this is meant to be the latent width, consider
    # interpolating it so the two cannot drift apart — confirm first.
    input_dim: 192
    feat_dim: ${vars.output_dim}
    # Same variable as the top-level `num_classes`.
    n_particle_types: ${vars.num_classes}
    # Same value (768) as `condition_dim` of the encoder and the conditioner.
    condition_dim: 768
    ndim: 3
    # Embedding of query positions, scaled by `vars.pos_scale`.
    pos_embed:
      _target_: src.modules.positional_embeddings.ContinuousSincosEmbed
      _partial_: true
      pos_scale: ${vars.pos_scale}

  # Timestep conditioner (per the target's name). Built as a partial, like its
  # parent node.
  conditioner:
    _target_: src.models.components.conditioners.upt_timestep_conditioner.UptTimestepConditioner
    _partial_: true
    dim: 192
    num_total_timesteps: ${vars.num_total_timesteps}
    # Same value (768) as the `condition_dim` consumed by encoder and decoder.
    condition_dim: 768

# Optimizer factory. `_partial_: true` makes Hydra return a partially-applied
# constructor; presumably the model parameters are passed in later by the
# training module — TODO confirm.
optimizer:
  # Author note kept from the original: 'LION' and 'we need "full" single
  # precision' (presumably LION was considered as an alternative — confirm).
  _target_: torch.optim.AdamW
  _partial_: true
  # LR schedule note: linear warmup + cosine decay; 1e-6...1e-4, 10% of
  # epochs. TODO: look at a PyTorch Lightning example of "per step" scheduling.
  lr: 0.0001
  weight_decay: 0.05

# Learning-rate scheduler factory, built as a partial like the optimizer.
scheduler:
  _target_: src.modules.schedulers.LinearWarmupCosineAnnealingLR
  _partial_: true
  # NOTE(review): named in epochs — confirm whether the scheduler is stepped
  # per epoch or per optimizer step.
  warmup_epochs: 2
  # Kept in sync with the trainer via interpolation.
  max_epochs: ${trainer.max_epochs}
  # Written with an explicit decimal point: YAML 1.1 loaders (e.g. plain
  # PyYAML) read a bare `1e-6` as a string rather than a float.
  min_lr: 1.0e-6
  # -1 is the torch.optim.lr_scheduler convention for "start from scratch";
  # presumably forwarded unchanged to the base scheduler — confirm.
  last_epoch: -1
