# Dataset selection, train/valid/test split, and the filters/features used
# when building samples.
data:
  data_name: tau3mu
  # Relative path; presumably resolved against the launch directory -- TODO confirm.
  data_dir: ../data
  # Split fractions; the three values sum to 1.0.
  split:
    train: 0.7
    valid: 0.15
    test: 0.15
  # Each value is a quoted comparison string (operator followed by a literal).
  # Presumably the loader applies it to the hit-level variable named by the
  # key -- confirm against the data-loading code.
  hit_filters:
    mu_hit_station: '==1'
    mu_hit_neighbor: '==0'
    mu_hit_type: '!=0'
  # Same comparison-string form as hit_filters, keyed by a per-sample quantity.
  sample_filters:
    num_hits: '>=3'
  # List of additional feature names; currently a single entry.
  other_features:
    - mu_hit_bend
  feature_type: only_x # one of: only_pos, only_x, both_x_pos, only_ones

# Logging / reporting options.
logging:
  tensorboard: false
  # NOTE(review): the percentages in the inline comment are not explained in
  # this file; confirm what they refer to for each k.
  topk: [3, 7, 10]  # 15%, 80%, 95%

# Batch size plus one learning-rate (*_lr) / weight-decay (*_wd) pair for each
# of the prefixes wp, attn and emb. All three pairs currently share the same
# values. Presumably each prefix names a separate parameter group -- confirm
# against the training code.
#
# Keep the `1.0e-N` spelling (decimal point and signed exponent): YAML 1.1
# loaders such as PyYAML read forms like `1e-3` as a string, not a float.
optimizer:
  batch_size: 256
  wp_lr: 1.0e-3
  wp_wd: 1.0e-5
  attn_lr: 1.0e-3
  attn_wd: 1.0e-5
  emb_lr: 1.0e-3
  emb_wd: 1.0e-5

# Per-backbone architecture settings. The three backbones (egnn, dgcnn,
# pointtrans) expose the same six keys and currently carry identical values.
model:
  egnn:
    n_layers: 4
    hidden_size: 64
    dropout_p: 0.2
    norm_type: batch
    act_type: relu
    pool: add
  dgcnn:
    n_layers: 4
    hidden_size: 64
    dropout_p: 0.2
    norm_type: batch
    act_type: relu
    pool: add
  pointtrans:
    n_layers: 4
    hidden_size: 64
    dropout_p: 0.2
    norm_type: batch
    act_type: relu
    pool: add


# The default hyperparameters below are tuned for dgcnn.
# Where the tuned value for another backbone differs, it is given in the inline comment to the right of that key.

# Settings for the lri_bern method. Inline comments give the values to use
# for backbones other than the default.
lri_bern:
  epochs: 50                    # egnn: 100
  warmup: 50                    # egnn: 0
  final_r: 0.7
  info_loss_coef: 0.1           # pointtrans: 1.0, egnn: 0.01
  one_encoder: true
  attn_constraint: none

  temperature: 1.0
  # NOTE(review): init_r / final_r / decay_r / decay_interval look like a
  # schedule that moves r from 0.9 towards 0.7 in steps of 0.1 every 10
  # epochs -- confirm against the training code.
  decay_interval: 10
  decay_r: 0.1
  init_r: 0.9
  pred_loss_coef: 1.0
  pred_lr: 1.0e-3
  pred_wd: 1.0e-5
  # The three keys below repeat the values used under model.*.
  dropout_p: 0.2
  norm_type: batch
  act_type: relu

# Settings for the lri_gaussian method. Inline comments give the values to
# use for backbones other than the default.
lri_gaussian:
  epochs: 100                  # egnn: 50
  warmup: 0                    # egnn: 50
  pos_coef: 11                 # pointtrans: 7
  info_loss_coef: 0.01
  # NOTE(review): the meaning of kr is not visible in this file -- confirm
  # against the method implementation before changing it.
  kr: 1.5
  one_encoder: true
  attn_constraint: none

  covar_dim: 2
  pred_loss_coef: 1.0
  pred_lr: 1.0e-3
  pred_wd: 1.0e-5
  # The three keys below repeat the values used under model.*.
  dropout_p: 0.2
  norm_type: batch
  act_type: relu

# Settings for the gradcam method.
# NOTE(review): pred_lr and pred_wd are 0 with a single epoch after a
# 100-epoch warmup; presumably the classifier is only trained during warmup
# -- confirm against the training loop.
gradcam:
  epochs: 1
  warmup: 100
  pred_lr: 0
  pred_wd: 0

# Settings for the gradgeo method: the same epochs/warmup/pred_* values as
# the gradcam section, plus the gradgeo flag.
# NOTE(review): with pred_lr and pred_wd at 0, the classifier is presumably
# only trained during warmup -- confirm against the training loop.
gradgeo:
  epochs: 1
  warmup: 100
  gradgeo: true
  pred_lr: 0
  pred_wd: 0

# Settings for the bernmask_p method. Inline comments give the values to use
# for backbones other than the default.
bernmask_p:
  epochs: 50
  warmup: 100
  size_loss_coef: 0.01
  mask_ent_loss_coef: 1.0       # pointtrans: 0.1, egnn: 0.01

  # NOTE(review): two-element list; the role of each entry is not visible in
  # this file (possibly start/end temperature) -- confirm.
  temp: [1.0, 1.0]
  pred_loss_coef: 1.0
  # Learning rate and weight decay for the predictor are both zero here.
  pred_lr: 0
  pred_wd: 0
  # The three keys below repeat the values used under model.*.
  dropout_p: 0.2
  norm_type: batch
  act_type: relu

# Settings for the bernmask method. Inline comments give the values to use
# for the pointtrans backbone.
bernmask:
  epochs: 1
  warmup: 100
  size_loss_coef: 0.1            # pointtrans: 0.01
  mask_ent_loss_coef: 0.01       # pointtrans: 0.1
  # NOTE(review): iter_lr and iter_per_sample suggest a per-sample
  # optimisation loop (500 iterations at lr 0.1) -- confirm against the
  # method implementation.
  iter_lr: 1.0e-1

  iter_per_sample: 500
  pred_loss_coef: 1.0
  # Learning rate and weight decay for the predictor are both zero here.
  pred_lr: 0
  pred_wd: 0
  # The three keys below repeat the values used under model.*.
  dropout_p: 0.2
  norm_type: batch
  act_type: relu

# Settings for the pointmask method. Unlike lri_bern and lri_gaussian, this
# section sets one_encoder to false.
pointmask:
  epochs: 100
  warmup: 0
  # NOTE(review): the meaning of t is not visible in this file -- confirm
  # against the method implementation before changing it.
  t: 0.2
  kl_loss_coef: 0.01

  one_encoder: false
  covar_dim: 2  # MLP output 2 dims for mu and sigma
  pred_loss_coef: 1.0
  pred_lr: 1.0e-3
  pred_wd: 1.0e-5
  # The three keys below repeat the values used under model.*.
  dropout_p: 0.2
  norm_type: batch
  act_type: relu
