
# Point encoder configuration, written as nested class_path / init_args
# object specifications.
# NOTE(review): embed_dim is 768 in the position embedding, the point
# embedding and the encoder below; presumably the three must agree —
# confirm before changing one of them in isolation.
point_encoder:
  class_path: asymdsd.models.PointEncoder
  init_args:
    # Explicit null: no patchify module is configured in this file.
    # NOTE(review): presumably provided by another config or by a default
    # in PointEncoder — confirm. To configure it here, replace the null
    # with the commented-out spec below.
    patchify: null
    # patchify:
    #   class_path: asymdsd.layers.MultiPointPatchify
    #   init_args:
    #     num_patches: [64]
    #     patch_size: [32]
    cls_token: true
    patch_embedding:
      class_path: asymdsd.layers.PatchEmbeddingConfig
      init_args:
        position_embedding:
          class_path: asymdsd.layers.tokenization.PositionEmbeddingConfig
          init_args:
            # NOTE(review): in_features 3 presumably means xyz
            # coordinates — confirm.
            in_features: 3
            embed_dim: 768
            act_layer: torch.nn.GELU
            normalize: false
        point_embedding:
          class_path: asymdsd.layers.tokenization.VarMemEfficientPointMaxEmbeddingConfig
          init_args:
            in_features: 3
            embed_dim: 768
            allow_grad_ckpt: true
            # Two groups of hidden widths, one inner list per group.
            # NOTE(review): the meaning of the grouping is defined by the
            # embedding class and is not visible here.
            # Smaller alternative kept from the original config:
            #   hidden_dims: [[128, 256, 512], [1024]]
            hidden_dims:
              - [256, 512, 1024]
              - [2048]
            # Other activation noted in the original config:
            #   asymdsd.layers.activation.GEGLU
            act_layer: torch.nn.GELU
            # Other normalization noted in the original config:
            #   asymdsd.layers.TransposeBatchNorm1d
            norm_layer: asymdsd.layers.RMSNorm
            bias: false
            # dropout_p: 0.2
            process_num_chunks: 1
        normalize_patches: false
    encoder:
      embed_dim: 768
      num_heads: 12
      num_layers: 12
      # NOTE(review): presumably a drop-path (stochastic depth)
      # probability — confirm against the encoder implementation.
      drop_path_p: 0.2