# Experiment logging via wandb (Weights & Biases).
wandb:
    use_wandb: true
    project: 'trees_new_model'
    # NOTE(review): 'ENTITY' looks like a placeholder — set it to your W&B user/team.
    entity: 'ENTITY'
    name: 'b2c'  # presumably the W&B run name

# presumably toggles model checkpointing — confirm against the trainer
ckpt: false

# --- data ---
dataset: 'tree_4'
data_path: './datasets'
batch_size: 1024

# --- optimizer ---
# Plain decimals so every YAML loader yields a float
# (PyYAML's YAML 1.1 resolver reads an exponent form like 1e-3 as a string).
lr: 0.001
weight_decay: 0.0

# --- run control ---
debug: false
num_runs: 1

# --- epochs / early stopping (semantics presumed from key names) ---
max_epoch: 1500
min_epoch: 400
patience: 150

# --- LR scheduler ---
scheduler_type: 'cos_with_warmup'
# NOTE(review): looks like a plateau-scheduler setting; likely unused with
# cos_with_warmup — confirm.
scheduler_patience: 50
log_test: true

# encoder
# Every encoder option is disabled, so the value is written as an explicit
# `null`. A bare `encoder:` also loads as null, but it is ambiguous to
# readers and flagged by yamllint `empty-values`. To enable an encoder,
# delete the `encoder: null` line and uncomment the mapping below.
encoder: null
# encoder:
#     rwse:
#         kernel: 20
#         layers: 2
#         dim_pe: 32
#         raw_norm_type: 'BatchNorm'
#
#     lap:
#         max_freqs: 4
#         dim_pe: 32
#         layers: 1
#         raw_norm_type: null

# Alternative kept for reference. Note that `none` is not a YAML null keyword:
# it loads as the string 'none', not as null.
# scorer_model: none
scorer_model:
    conv: 'gine'
    hidden: 32
    # NOTE(review): with zero conv layers, `conv` is presumably unused — confirm.
    num_conv_layers: 0
    num_mlp_layers: 2
    norm: 'batch_norm'
    activation: 'gelu'
    dropout: 0.0
    num_centroids: 2

# base2centroid: maps the base-graph nodes onto super (centroid) nodes.
base2centroid:
    conv: 'gine'
    num_conv_layers: 2
    num_mlp_layers: 3
    norm: 'batch_norm'
    activation: 'gelu'
    dropout: 0.0
    # presumably the reduction used to pool node features into each centroid
    centroid_aggr: 'mean'


# hetero: the heterogeneous, hierarchical GNN.
hetero:
    conv: 'gine'
    hidden: 32
    # presumably the embedding width for centroid nodes (vs. `hidden`) — confirm
    cent_hidden: 64
    num_conv_layers: 5
    num_mlp_layers: 2
    norm: 'batch_norm'
    activation: 'gelu'
    dropout: 0.0
    residual: true
    delay: 0
    aggr: 'cat'
    parallel: true

# sampler: controls how nodes are assigned to centroids.
sampler:
    name: 'simple'
    # default 1, i.e. each node is assigned to a single centroid
    sample_k: 1
    num_ensemble: 1
    n_samples: 2
    # when true, the marginals are used as weights on the node masks
    assign_value: false

# hybrid_model: presumably the readout / prediction-head settings
# (judging by the *_pred_layer and *_pool keys) — confirm.
hybrid_model:
    jk: 'identity'
    target: 'centroid'
    inter_pred_layer: 2
    intra_pred_layer: 2
    inter_ensemble_pool: 'mean'
    intra_graph_pool: 'mean'
# auxiliary losses (disabled; uncomment the stanza below to enable)
# auxloss:
#     soft_empty: 0.01
#     hard_empty: 1.e-5

# plots:
#     plot_every: 20
# #    plot_folder: './plots'
# #    mask: true
# #    score: true
#     graph: true