# Experiment-tracking options (key names suggest Weights & Biases logging;
# confirm against the training script).
wandb:
  use_wandb: true
  project: qm9_new_model
  # NOTE(review): "ENTITY" looks like a placeholder; replace with the real
  # account/team name if the consumer requires one.
  entity: ENTITY

# Checkpoint flag; disabled in this configuration.
ckpt: false

# Dataset selection and loading.
dataset: QM9
task_id: 1
data_path: ./datasets
batch_size: 512

# Optimiser hyper-parameters (floats written in canonical form).
lr: 1.0e-3
weight_decay: 0.0

# Run control.
debug: false
num_runs: 1
max_epoch: 1500
min_epoch: 800
patience: 100

# Scheduler selection.
# NOTE(review): how scheduler_patience interacts with cos_with_warmup is not
# visible from this file; verify in the training code.
scheduler_type: cos_with_warmup
scheduler_patience: 50

log_test: true

# encoder: no options are set in this file. The value is written as an
# explicit null so the empty setting is unambiguous to readers and linters.
encoder: null
    
# Scorer model settings. Commented-out alternative value kept for reference:
# scorer_model: none
scorer_model:
  conv: gine
  hidden: 128
  num_conv_layers: 0
  num_mlp_layers: 2
  norm: batch_norm
  activation: gelu
  dropout: 0.0
  num_centroids: 2

# Extracts the base nodes into super nodes.
base2centroid:
  conv: gine
  num_conv_layers: 2
  num_mlp_layers: 3
  norm: batch_norm
  activation: gelu
  dropout: 0.0
  centroid_aggr: mean


# The hetero (presumably heterogeneous), hierarchical GNN.
hetero:
  conv: gine
  hidden: 128
  cent_hidden: 256
  num_conv_layers: 10
  num_mlp_layers: 2
  norm: batch_norm
  activation: gelu
  dropout: 0.0
  residual: true
  delay: 2
  aggr: cat
  parallel: true

# Sampler settings.
sampler:
  name: simple
  # Default is 1, i.e. each node gets assigned to exactly one centroid.
  sample_k: 1
  num_ensemble: 1
  n_samples: 2
  # Whether to use marginals to assign weights on the node masks.
  assign_value: false

# Hybrid model settings.
# NOTE(review): the exact meaning of each option is defined by the consuming
# code; verify there before changing values.
hybrid_model:
  jk: identity
  target: centroid
  inter_pred_layer: 2
  intra_pred_layer: 2
  inter_ensemble_pool: mean
  intra_graph_pool: mean
#
#auxloss:
#    soft_empty: 0.01
#    hard_empty: 1.e-5

# plots:
#     plot_every: 20
# #    plot_folder: './plots'
# #    mask: true
# #    score: true
#     graph: true