# Run-level settings: output directory, experiment label and target device.
out_dir: results
experiment_name: geni_full_topo
# Overrides that are currently disabled, kept for reference:
# metric_best: ce
# metric_agg: argmin
# tensorboard_each_run: True  # Log to Tensorboard each run
# Kept quoted because the value contains a colon.
accelerator: 'cuda:0'
# MLflow logging section (`use` is off in this configuration).
mlflow:
  use: false
  project: Exp
  # NOTE(review): this run name differs from `experiment_name` — confirm it is intentional.
  name: ocb-GRIT-RRWP
# Weights & Biases logging section (`use` is on in this configuration).
wandb:
  use: true
  project: GenerativeCircuit
# Dataset selection together with the node/edge encoder options.
dataset:
  format: custom
  name: AnalogGenie
  task: graph
  task_type: generative
  transductive: false
  # Node encoder options.
  node_encoder: true
  node_encoder_name: AnalogGenieNode
  nnode_types: 88
  nnode_features: 0  # from 1 to 101 for ocb
  node_encoder_bn: false
  # Edge encoder options.
  edge_encoder: true
  edge_encoder_name: TypeDictEdge
  nedge_types: 2
  edge_encoder_bn: false
  # label_dimension: 0
  # 1 for topology generation only; 2 (or more) if the model should also
  # learn to size the devices.
  node_features_dim: 1
  subcircuit: false
  directed: false
  pins: true
  pin_prediction: false
# RRWP positional-encoding options.
posenc_RRWP:
  enable: true
  # TODO: try increasing this value.
  ksteps: 12
  add_identity: true
  add_node_attr: false
  add_inverse: false
  spse: false
# Training options: mode, batching, checkpointing and time-sampling distortion.
train:
  mode: dfm
  prior: marginal  # masked or marginal
  batch_size: 6
  eval_period: 5
  # Checkpointing can be disabled to save I/O, e.g. when just benchmarking.
  enable_ckpt: true
  # WARNING: checkpointing every epoch in which a better model is found may
  # increase I/O significantly.
  ckpt_best: true
  # Delete the old checkpoint each time.
  ckpt_clean: true
  save_final_model: true
  noising_edge: stochastic
  use_hungarian: false
  ratio_cf_guidance: 0.1
  # NOTE(review): `gt.sample_separate_t` is enabled while this key is disabled —
  # confirm which of the two the code actually reads.
  sample_separate_t: false
  # Each distortion selector accepts `pow` or `norm`.
  t_sample_distortion_n: pow
  t_sample_distortion_e: pow
  t_sample_distortion_f: pow
  # If > 1, skews node noising time t_x sampling towards higher values.
  distortion_pow_n: 4
  # Same for edges.
  distortion_pow_e: 6
  # Same for node features.
  distortion_pow_f: 2
# Generative framework selection; the inline note lists the accepted values.
framework:
  type: defog # vfm or defog
# Model section: architecture type, loss function, edge decoding and graph
# pooling choices.
model:
  type: GritTransformer
  loss_fun: l1
  edge_decoding: dot
  graph_pooling: add
# Graph-transformer (GritTransformer) layer hyper-parameters.
gt:
  layer_type: GritTransformer
  layers: 10
  n_heads: 8
  # `gt.dim_hidden` must match `gnn.dim_inner`.
  dim_hidden: 64
  dropout: 0.0
  layer_norm: true
  batch_norm: false
  bn_momentum: 0.01
  update_e: true
  attn_dropout: 0.2
  # --- Conditioning ---
  # Conditional generation - DiT style in SA layers + cls-free guidance.
  conditional_gen: false
  # --- Time conditioning ---
  time_conditioning: true
  # NOTE(review): `train.sample_separate_t` is disabled while this key is
  # enabled — confirm which of the two the code actually reads.
  sample_separate_t: true
  # --- Node number prediction ---
  node_pruning: false
  # --- Remaining options ---
  guidance_strength: 2
  n_rbf_centroids: 100
  # rrwp_scaling_fact: 1.0
  sizing: false
  attn:
    clamp: 5.0
    act: relu
    full_attn: true
    edge_enhance: true
    O_e: true
    norm_e: true
    fwl: false
    # Open question from the author: should this be False?
    use_bias: true
# Decoder settings: the `san_graph` head is an MLP followed by pooling.
gnn:
  head: san_graph
  layers_pre_mp: 0
  # Not used when `gnn.head: san_graph`.
  layers_post_mp: 3
  # `gt.dim_hidden` must match `gnn.dim_inner`.
  dim_inner: 64
  batchnorm: true
  act: relu
  dropout: 0.0
  agg: mean
  normalize_adj: false
# Optimizer and learning-rate schedule settings.
optim:
  clip_grad_norm: true
  optimizer: adamW
  # Exponent-form floats carry an explicit decimal point in the mantissa:
  # YAML 1.1 loaders such as PyYAML resolve a bare `1e-5` as the string
  # "1e-5" rather than a float, while `1.0e-5` is a float under both
  # YAML 1.1 and YAML 1.2. The numeric values are unchanged.
  weight_decay: 1.0e-5
  base_lr: 1.0e-3
  max_epoch: 600
  num_warmup_epochs: 50
  scheduler: cosine_with_warmup
  min_lr: 1.0e-6
# Weights applied to individual loss terms.
# NOTE(review): `feature_weight` is 0 here, which presumably switches the
# feature term off — confirm against the training code.
loss:
  edge_weight: 3
  feature_weight: 0

