# Global run settings read at the top level of the config.
# NOTE(review): presumably the mini-batch size of the training loop - confirm against the consumer.
batch_size: 32
# NOTE(review): presumably the number of data-loading workers - confirm against the consumer.
num_workers: 9
# NOTE(review): presumably the coefficient of an L1 penalty term - confirm where it is applied.
l1_lambda: 0.0015

# Input dataset: identifier, file locations and per-dataset options.
data:
  dataset: cifar10_amortization
  dataset_path: ./data/cifar10/
  data_path: ./data/cifar10/weights.npy
  metrics_path: ./data/cifar10/metrics.csv.gz
  layout_path: ./data/cifar10/layout.csv
  idcs_file: ./data/cifar10/cifar10_split.csv
  activation_function: relu
  # These two flags are anchored here and aliased from scalegmn_args,
  # so changing them in this section changes them in both places.
  node_pos_embed: &node_pos_embed true
  edge_pos_embed: &edge_pos_embed false
  # NOTE(review): the same list is written out again under
  # scalegmn_args.layer_layout - presumably the two must stay identical.
  layer_layout: [1, 16, 16, 16, 10]
  accuracy_threshold: 0.4

# Second dataset section; its path points at a grayscale directory.
# NOTE(review): the meaning of batch_fraction is not visible in this file -
# confirm against the code that reads it.
cifar_data:
  dataset: cifar10
  dataset_path: ./data/cifar10/grayscale/
  batch_fraction: 0.5
  
# Training-run settings.
train_args:
  # Epoch budget; per the original note, early stopping is used on top of it.
  num_epochs: 300
  seed: 42
  output_path: output_parameters

# Model hyper-parameters. The anchors defined here (d_in_v, d_in_e, d_hid)
# are reused by the nested sections below via aliases.
scalegmn_args:
  # initial dimension of input nn bias
  d_in_v: &d_in_v 1
  # initial dimension of input nn weights
  d_in_e: &d_in_e 1
  # hidden dimension
  d_hid: &d_hid 128
  # number of gnn layers to apply
  num_layers: 8
  direction: bidirectional
  equivariant: true
  # symmetry
  symmetry: scale
  # prefer compile - compile gnn to optimize performance
  jit: false
  # compile gnn to optimize performance
  compile: false

  out_scale: 0.01
  gnn_skip_connections: true
  # NOTE(review): same list as data.layer_layout - presumably must match it.
  layer_layout: [1, 16, 16, 16, 10]

  concat_mlp_directions: false
  reciprocal: true

  # use positional encodings (values aliased from the data section)
  node_pos_embed: *node_pos_embed
  edge_pos_embed: *edge_pos_embed

  _max_kernel_height: 3
  _max_kernel_width: 3

  graph_init:
    d_in_v: *d_in_v
    d_in_e: *d_in_e
    project_node_feats: true
    project_edge_feats: true
    d_node: *d_hid
    d_edge: *d_hid

  positional_encodings:
    final_linear_pos_embed: false
    sum_pos_enc: false
    po_as_different_linear: false
    equiv_net: false
    # args for the equiv net option.
    sum_on_io: true
    equiv_on_hidden: true
    num_mlps: 3
    layer_equiv_on_hidden: false

  gnn_args:
    d_hid: *d_hid
    message_fn_layers: 1
    message_fn_skip_connections: false
    update_node_feats_fn_layers: 1
    update_node_feats_fn_skip_connections: false
    update_edge_attr: true
    dropout: 0
    # false: only in between the gnn layers, true: + all mlp layers
    dropout_all: true
    update_as_act: false
    update_as_act_arg: sum
    mlp_on_io: false

    msg_equiv_on_hidden: true
    upd_equiv_on_hidden: true
    layer_msg_equiv_on_hidden: false
    layer_upd_equiv_on_hidden: false
    msg_num_mlps: 2
    upd_num_mlps: 2
    pos_embed_msg: false
    pos_embed_upd: false
    layer_norm: false
    aggregator: add
    sign_symmetrization: false
    # max_kernel_size = 3 => edge dimension = 9, needed for equivariant CNN task.
    edge_output_dimension: 9

  mlp_args:
    # single-element list holding the hidden dimension
    d_k:
      - *d_hid
    activation: silu
    dropout: 0
    final_activation: identity
    batch_norm: false
    layer_norm: true
    bias: true
    skip: false
    break_symmetry: false
        

# Optimiser, gradient clipping and learning-rate schedule.
optimization:
  clip_grad: true
  clip_grad_max_norm: 10.0
  optimizer_name: AdamW
  optimizer_args:
    lr: 0.0005
    weight_decay: 0.01
  scheduler_args:
    scheduler: WarmupLRScheduler
    warmup_steps: 1000
    scheduler_mode: min
    decay_rate: 0
    decay_steps: 0
    patience: 30
    # YAML has no `None` literal: a bare `None` loads as the string "None",
    # which is truthy and not a number. `null` is what loads as a real null
    # (Python None), i.e. "no minimum learning rate".
    # NOTE(review): confirm no consumer compares this value to the string "None".
    min_lr: null