# Experiment-tracking settings (key names suggest Weights & Biases — confirm
# against the code that reads this section).
wandb:
    use_wandb: true
    project: zinc-new
    entity: ENTITY  # NOTE(review): looks like a placeholder — set the real entity before running
    name: zinc_func_hparam_b2c

# NOTE(review): presumably toggles checkpointing; false here — confirm what
# other values (e.g. a path) the consumer accepts.
ckpt: false

# Dataset, optimiser and training-schedule settings.
dataset: zinc
data_path: ./datasets
batch_size: 128
# Keep the "1." / "0." spelling for floats: YAML 1.1 loaders such as PyYAML
# only resolve exponent notation as a float when a dot is present
# (a bare 1e-3 would be loaded as a string).
lr: 1.e-3
weight_decay: 0.
debug: false
num_runs: 5
# NOTE(review): presumably epoch bounds for training, with `patience` as an
# early-stopping window — confirm how min_epoch and patience interact.
max_epoch: 1500
min_epoch: 800
patience: 100
scheduler_type: cos_with_warmup
# NOTE(review): unclear from this file whether scheduler_patience is used by
# the cos_with_warmup scheduler or only by other scheduler types — confirm.
scheduler_patience: 50
log_test: true

# Encoder settings: one sub-section per encoding type (rwse, lap).
encoder:
    # NOTE(review): "rwse" presumably stands for random-walk structural
    # encoding — confirm.
    rwse:
        kernel: 20
        layers: 2
        dim_pe: 32
        raw_norm_type: 'BatchNorm'

    # NOTE(review): "lap" presumably stands for Laplacian positional
    # encoding — confirm.
    lap:
        max_freqs: 4
        dim_pe: 32
        layers: 1
        raw_norm_type: null  # explicit null: no raw norm type is set for this encoder

# Scorer model settings.
# Commented-out alternative kept for reference (presumably disables the
# scorer — confirm how the consumer treats the value `none`):
#scorer_model: none
scorer_model:
    conv: gine
    hidden: 128
    num_conv_layers: 0  # no conv layers configured here; num_mlp_layers below is 2
    num_mlp_layers: 2
    norm: batch_norm
    activation: gelu
    dropout: 0.
    num_centroids: 2

# base2centroid: settings for the module that extracts super-node (centroid)
# representations from the base nodes.
base2centroid:
    conv: gine
    num_conv_layers: 2
    num_mlp_layers: 3
    norm: batch_norm
    activation: gelu
    dropout: 0.
    centroid_aggr: mean

# hetero: settings for the heterogeneous, hierarchical GNN.
hetero:
    conv: gine
    hidden: 128
    cent_hidden: 256  # NOTE(review): presumably the hidden width used for centroid nodes — confirm
    num_conv_layers: 10
    num_mlp_layers: 2
    norm: batch_norm
    activation: gelu
    dropout: 0.
    residual: true
    # NOTE(review): the meaning of `delay`, `aggr: cat` and `parallel` is not
    # visible from this file — check the model code before changing them.
    delay: 2
    aggr: cat
    parallel: true

# Sampler settings for assigning nodes to centroids.
sampler:
    name: simple
    sample_k: 1  # defaults to 1, i.e. each node is assigned to exactly one centroid
    num_ensemble: 1
    n_samples: 2
    assign_value: false  # if enabled, use the marginals to assign weights on the node masks

# Hybrid model settings.
# NOTE(review): key names suggest prediction-head depth and pooling choices
# across ensembles / within graphs — confirm against the model code.
hybrid_model:
    jk: identity
    target: centroid
    inter_pred_layer: 2
    intra_pred_layer: 2
    inter_ensemble_pool: mean
    intra_graph_pool: mean
# Auxiliary loss settings, currently disabled (commented out).
#auxloss:
#    soft_empty: 0.01
#    hard_empty: 1.e-5

# Plotting settings.
plots:
    plot_every: 20  # NOTE(review): presumably an interval in epochs — confirm
    # The next three options are commented out, i.e. not set in this config:
#    plot_folder: './plots'
#    mask: true
#    score: true
    graph: true