# Class instantiated from this config (Hydra `_target_` convention).
_target_: 'src.SetVAE.SetVAEModule'
model_args:
  # Model hyperparameters (presumably forwarded to the SetVAE network
  # constructor -- confirm in src.SetVAE).
  input_dim: 3  # Number of input dimensions (3 for 3D point clouds)
  init_dim: 64  # Number of dimensions for each initial set element
  n_mixtures: 4  # Number of mixture components for the initial set
  # Whether to use fixed (Fibonacci-sphere-based) initialization for the
  # initial-set GMM parameters.
  fixed_gmm: false
  # Whether to train the initial-set GMM parameters via reparameterization.
  train_gmm: false
  z_dim: 32  # Number of dimensions for each latent set element
  z_scales: [2, 4, 8, 16, 32]  # Top-down scales for hierarchical latent sets
  hidden_dim: 64  # Number of hidden dimensions
  num_heads: 4  # Number of attention heads
  slot_att: true  # Whether to use slot attention
  # Induced network; choices: full_mlp, elem_mlp, set_transformer
  i_net: 'elem_mlp'
  i_net_layers: 0  # Number of hidden layers in the induced network
  # Deterministic layer; choices: elem_mlp, set_transformer
  d_net: 'set_transformer'
  enc_in_layers: 0  # Number of deterministic layers in the pre-encoder
  dec_in_layers: 0  # Number of deterministic layers in the pre-decoder
  dec_out_layers: 0  # Number of deterministic layers in the post-decoder
  isab_inds: 16  # Number of inducing points in deterministic layers
  ln: true  # Whether to use layer normalization
  activation: 'relu'  # MLP activation function; choices: relu, tanh
  use_bn: false  # Whether to use batch normalization in MLPs
  residual: true  # Whether to use residual connections in MLPs
  dropout_p: 0.0  # Dropout rate
opt_args:
  # Optimization / training hyperparameters.
  # NOTE: exponent-form floats are written with a decimal point
  # (5.0e-2, not 5e-2). YAML 1.1 loaders such as PyYAML resolve '5e-2'
  # to the *string* "5e-2"; OmegaConf's loader accepts both forms, but
  # any other consumer of this file would not.
  type: adam  # Optimizer; choices: adam, adamax, sgd
  # batch_size: 32
  lr: 5.0e-2  # Learning rate
  # max_grad_norm: 5.0  # Gradient norm clipping
  # max_grad_threshold: null  # Gradient norm threshold for update
  beta1: 0.9  # Beta1 for Adam
  beta2: 0.999  # Beta2 for Adam
  momentum: 0.9  # Momentum for SGD
  weight_decay: 0.0  # Weight decay for the optimizer
  # dropout_p: 0.0  # Dropout rate
  # TODO(review): may duplicate the trainer's epoch limit; consider
  # interpolating it from the trainer config instead of hard-coding it.
  epochs: 1000  # Total number of training epochs
  # seed: null  # Random seed for reproducibility
  # Matcher for reconstruction-loss computation;
  # choices: hungarian, chamfer, all, approxEMD
  matcher: 'approxEMD'
  # matcher_ckpt_path: null
  beta: 1.0e-2  # KL loss weight
  kl_warmup_epochs: 50  # Number of KL annealing epochs
  # Learning-rate schedule; choices: exponential, step, linear, cosine, none
  sch_type: 'cosine'
  warmup_epochs: 0  # Length of learning-rate warm-up
  exp_decay: 1.0  # Exponential decay rate of the learning-rate schedule
  # exp_decay_freq: 1  # Learning-rate exponential decay frequency
# NOTE(review): presumably the maximum number of output set elements
# the model handles/generates -- confirm against src.SetVAE.
max_outputs: 500