out_dir: results
experiment_name: EXPE_NAME  # appears to be a placeholder; set per experiment
# metric_best: ce
# metric_agg: argmin
# tensorboard_each_run: true  # Log to TensorBoard on each run
accelerator: "cuda:0"
# MLflow tracking (disabled).
mlflow:
  use: false
  project: Exp
  name: ocb-GRIT-RRWP
# Weights & Biases tracking (enabled).
wandb:
  use: true
  project: GenerativeCircuit  # GenerativeCircuit or DeepSimulator ## CHANGE CLASSIFICATION
dataset:
  format: custom
  name: ocb_CktBench101
  task: graph
  task_type: generative  # generative or classification ## CHANGE CLASSIFICATION
  transductive: false
  node_encoder_name: OCBNode
  # NOTE(review): the hint below reserves 11 for "the version with nets",
  # while use_pins is false further down. Confirm that `version: v3`
  # models nets; otherwise this should presumably be 10.
  nnode_types: 11  # 11 in the version with nets, else 10
  nnode_features: 101  # from 1 to 101 for ocb
  node_encoder_bn: false
  edge_encoder_name: TypeDictEdge
  nedge_types: 2
  edge_encoder_bn: false
  ### Device sizing ###
  # DEPRECATED: superseded by gt.sizing.
  node_features_dim: 1  # 1 for topology generation only, 2 for sizing also
  ### Dataset variant with pins/nets modelled explicitly ###
  use_pins: false  # if true, make gm i/o pins explicit (classes 6 & 7)
  ### Larger version of the OCB dataset ###
  # YAML has no tuple type: this value is the *string* "(False, '')".
  # It is quoted to make that explicit (same value as the old plain
  # scalar). NOTE(review): it only becomes a tuple if the config loader
  # literal-evals string values; confirm the loader does so.
  large_idx: "(False, '')"  ## CHANGE CLASSIFICATION
  version: v3
  scaled_features: true
  subcircuit: false
  directed: false
# RRWP positional-encoding settings.
posenc_RRWP:
  enable: true
  ksteps: 12  # presumably the number of random-walk steps; confirm
  add_identity: true
  add_node_attr: false
  add_inverse: false
  spse: false
train:
  mode: dfm
  prior: marginal  # masked or marginal
  # NOTE(review): dataset.task_type is generative, yet this is the
  # classification value according to its own hint (32 if generative).
  # Confirm that 64 is intended.
  batch_size: 64  # 32 if generative, 64 for classification ## CHANGE CLASSIFICATION
  eval_period: 5  # 5 if generative, 2 for classification ## CHANGE CLASSIFICATION
  # Checkpointing can be disabled to save I/O, e.g. when just benchmarking.
  enable_ckpt: true
  # WARNING: checkpointing every epoch in which a better model is found may
  # increase I/O significantly.
  ckpt_best: true
  ckpt_clean: true  # Delete old checkpoints each time.
  save_final_model: true
  use_hungarian: false
  ratio_cf_guidance: 0.1
  ### Noising-time sampling distortion (n: nodes, e: edges, f: node features) ###
  t_sample_distortion_n: 'pow'  # pow or norm
  t_sample_distortion_e: 'pow'  # pow or norm
  t_sample_distortion_f: 'pow'  # pow or norm
  distortion_pow_n: 4  # If > 1, skews node noising time t_x sampling towards higher values.
  distortion_pow_e: 6  # Same for edges.
  distortion_pow_f: 3  # Same for node features.
  ### Train a sizing model ###
  noise_feat_only: false  # If true, train to denoise device sizing only
  ### Path to the pretrained classifier weights & config ###
  classifier_path: ''
  use_classifier: false
  # Softmax temperature applied to the generative model output before
  # classifier guidance.
  classifier_input_temp: 0.7
# Generative framework selection.
framework:
  type: defog  # vfm or defog
# Model architecture, loss and decoding options.
model:
  type: GritTransformer
  loss_fun: l1
  edge_decoding: dot
  graph_pooling: mean
gt:
  layer_type: GritTransformer
  layers: 10
  n_heads: 8
  dim_hidden: 64  # `gt.dim_hidden` must match `gnn.dim_inner`
  dropout: 0.0
  layer_norm: true
  batch_norm: false
  bn_momentum: 0.01
  update_e: true
  attn_dropout: 0.2
  ### Conditioning ###
  # Conditional generation: DiT-style conditioning in the self-attention
  # layers plus classifier-free guidance.
  conditional_gen: true
  # Condition on specs in parallel with time, i.e. with additive and
  # multiplicative biases in the transformer.
  sep_t_spec_cond: false
  # Embed spec conditioning together with time conditioning in the
  # FeatureEncoder module.
  add_spec_cond_to_t: false
  supernode: true
  conditioning_loss: 'cfg'  # classifier-free guidance ('cfg') vs classifier guidance ('cg')
  # conditional_dim: 1  # 0 for gain, 1 for ugw, 2 for pm
  ### Time conditioning ###
  time_conditioning: true
  ### Node number prediction ###
  # 0: no node pruning
  # 1: add disconnected noise nodes, learn edges
  # 2: add disconnected nodes from a new class, learn nodes & edges
  node_pruning: 0
  guidance_strength: 2
  n_rbf_centroids: 100
  ### Separate t per dimension ###
  # If true, one t is sampled for each node and each edge separately during
  # training; otherwise t is sampled per modality.
  sample_separate_t: true
  ### Sizing ###
  sizing: true
  ### Single branch for x and x_features ###
  process_feats_with_x: false
  attn:
    clamp: 5.0
    act: 'relu'
    full_attn: true
    edge_enhance: true
    O_e: true
    norm_e: true
    fwl: false
    use_bias: true  # NOTE(review): the original author questioned this ("False ?")
    ### Whether to merge node type & node features before each SA layer ###
    x_f_coupling: true  ## CHANGE CLASSIFICATION
gnn:  # decoder --> san_graph = MLP + pooling
  head: san_graph
  dual_head: false  # take x & features in the classification head
  layers_pre_mp: 0
  layers_post_mp: 3  # Not used when `gnn.head: san_graph`
  dim_inner: 64  # `gt.dim_hidden` must match `gnn.dim_inner`
  # dim_edge: 64
  batchnorm: true
  # act: silu
  dropout: 0.0
  agg: mean
  normalize_adj: false
  ### Spec conditioning ###
  # Number of conditioning quantities. Only one is used here; the previous
  # hint "(here 3 for gain, ugw & pm)" did not match this value.
  n_spec: 1
  # NOTE(review): presumably selects which spec is conditioned on
  # (0 gain, 1 ugw, 2 pm, as in the commented-out gt.conditional_dim);
  # confirm.
  spec_dim: 0
  # Number of classes / discretized bins for each spec.
  n_bins: 4
optim:
  clip_grad_norm: true
  optimizer: adamW
  # Floats are written with a mantissa dot (1.0e-5, not 1e-5): YAML 1.1
  # resolvers such as PyYAML load `1e-5` as a *string*, not a float.
  weight_decay: 1.0e-5
  base_lr: 1.0e-3  # 1e-3 if generative ## CHANGE CLASSIFICATION
  max_epoch: 600  ## CHANGE CLASSIFICATION
  num_warmup_epochs: 50  # 50 if generative ## CHANGE CLASSIFICATION
  scheduler: cosine_with_warmup
  min_lr: 1.0e-6
  cross_entropy_epoch: 0  # epoch at which to start the cross-entropy loss
# NOTE(review): presumably relative weights of the loss terms (inferred from
# key names only; confirm in the training code). Integer vs float literals are
# kept as-is because the registered defaults' types are not visible here.
loss:
  edge_weight: 3
  feature_weight: 2.0
  cross_entropy: 1.0