# Experiment-tracking settings (key names suggest Weights & Biases; confirm
# against the training script that consumes this file).
wandb:
    use_wandb: true
    project: pcqm
    # NOTE(review): `ENTITY` looks like a placeholder to be replaced with a
    # real account/team name before running; confirm.
    entity: ENTITY
    name: supernode_l5

# Checkpointing flag; the exact semantics (save vs. resume) are defined by the
# consuming script. TODO confirm.
ckpt: false

# Dataset selection, optimisation and early-stopping settings. Meanings are
# inferred from the key names; verify against the training loop.
dataset: pcqm-contact
data_path: datasets
batch_size: 256
# Keep the decimal point: YAML 1.1 loaders such as PyYAML resolve a dot-less
# `1e-3` to a string rather than a float.
lr: 1.e-3
weight_decay: 0.
debug: false
num_runs: 3
max_epoch: 1000
min_epoch: 100
patience: 100
# Canonical lowercase boolean, matching every other flag in this file.
log_test: true
scheduler_type: cos_with_warmup
scheduler_patience: 50

# Encoder settings. `lap` presumably configures a Laplacian positional
# encoding; TODO confirm against the model code.
encoder:
    lap:
        max_freqs: 4
        dim_pe: 16
        layers: 1
        raw_norm_type: null  # YAML null, i.e. no value (not the string "none")

# Alternative setting kept for reference. Note that a plain `none` is loaded
# as the string "none", not as YAML null.
# scorer_model: none
scorer_model:
    conv: gine
    hidden: 128
    # NOTE(review): zero conv layers alongside two MLP layers; presumably the
    # scorer then relies on its MLP only. Confirm in the model code.
    num_conv_layers: 0
    num_mlp_layers: 2
    norm: batch_norm
    activation: gelu
    dropout: 0.
    num_centroids: 2

# Settings for the module that maps base nodes to super nodes (centroids).
base2centroid:
    conv: gine
    num_conv_layers: 2
    num_mlp_layers: 3
    norm: batch_norm
    activation: gelu
    dropout: 0.
    centroid_aggr: mean

# The heterogeneous, hierarchical GNN.
hetero:
    conv: gine
    # NOTE(review): two widths are configured; presumably `hidden` applies to
    # base nodes and `cent_hidden` to centroid nodes. Confirm.
    hidden: 128
    cent_hidden: 256
    num_conv_layers: 10
    num_mlp_layers: 2
    norm: batch_norm
    activation: gelu
    dropout: 0.
    residual: false
    delay: 2
    aggr: cat
    parallel: true

# Sampler settings for assigning nodes to centroids (see the per-key notes).
sampler:
    name: simple
    sample_k: 1  # by default 1, i.e., each node gets assigned to 1 centroid
    num_ensemble: 1
    n_samples: 2
    assign_value: false  # use marginals to assign weights on the node masks

# Settings for the combined model's prediction and pooling stages. Meanings
# are inferred from the key names only; TODO confirm against the model code.
hybrid_model:
    jk: identity
    target: base
    inter_pred_layer: 3
    intra_pred_layer: 2
    inter_ensemble_pool: mean
    intra_graph_pool: edge