# Run-level settings.
# !ENV and !SUB are custom tags resolved by the loader that consumes this file
# (not visible here). The ${name:-default} form presumably follows shell-style
# "use default when unset" substitution — TODO confirm against the loader.
# The anchors &_project_name and &_exp_name are reused to build general.logdir.
project_name: &_project_name !ENV "${LOGNAME}_bfgraph_qm9withh-test"
exp_name: &_exp_name !SUB ${exp_name}
debug: !SUB ${debug}
no_wandb: !SUB ${no_wandb}
# Falls back to `warning` when logging_level is not supplied (see note above).
logging_level: !SUB ${logging_level:-warning}
input_dist_sample: !SUB ${input_dist_sample}

# Dataset selection and loading options.
dataset:
  # Noted options: qm9, qm9_positional.
  # Anchored (&_dataset_name) because the value is reused in datadir below
  # and in general.logdir.
  name: &_dataset_name 'qm9'
  datadir: !PATHJOIN [!ENV "${HOME}/project/data/digress/", *_dataset_name]
  remove_h: false
  # for qm9
  random_subset: false
  # for qm9
  pin_memory: false
  # Use the filtered version or the raw file. The original note refers to the
  # guacamol file — NOTE(review): confirm how this flag applies to qm9.
  filter: true

# Model definition: BFN parameters, prior, network sizes and loss weights.
model:
  type: 'bayesian'
  model: 'graph_tf'

  # BFN
  # Here we set beta for the discrete version of BFN.
  beta_node: !SUB ${beta_node:-3.0}
  beta_edge: !SUB ${beta_edge:-3.0}
  t_min: 1.0e-2

  # Prior
  # 'uniform', 'transition'
  transition: 'uniform'
  # discrete or continuous
  time: !SUB ${time:-continuous}
  # The original note lists 'cosine' or 'linear', yet the fallback used here
  # is `quad` — NOTE(review): confirm the full set of accepted schedulers.
  node_time_scheduler: !SUB ${node_time_scheduler:-quad}
  edge_time_scheduler: !SUB ${edge_time_scheduler:-quad}

  sample_steps: !SUB ${sample_steps:-500}
  n_layers: 7
  hidden_mlp_dims:
    'X': 256
    'E': 128
    'y': 128
  hidden_dims:
    'dx': 256
    'de': 64
    'dy': 64
    'n_head': 8
    'dim_ffX': 256
    'dim_ffE': 128
    'dim_ffy': 128
  lambda_train_node: !SUB ${lambda_train_node:-1.0}
  lambda_train_edge: !SUB ${lambda_train_edge:-1.0}
  lambda_train_y: !SUB ${lambda_train_y:-0.0}
  # 'all', 'cycles', 'eigenvalues' or null.
  # The fallback is spelled `None` — NOTE(review): confirm the loader turns
  # it into null rather than the string "None".
  extra_features: !SUB ${extra_features:-None}
  # 'prob' or 'logit'
  extra_mode: !SUB ${extra_mode:-prob}
  # Whether we use the discretised time or not.
  discretised_time: !SUB ${discretised_time:-False}
 

# General run options: wandb mode, validation/checkpoint cadence, sample
# counts and output paths.
general:
  # online | offline | disabled
  wandb: 'online'
  # Multi-gpu is not implemented on this branch
  gpus: 1
  # If resume, path to ckpt file from outputs directory in main directory
  resume: null
  # Use absolute path
  test_only: ''

  check_val_every_n_epochs: 25
  check_point_every_n_epochs: 25

  sample_every_val: 4
  val_check_interval: null
  # We advise to set it to 2 x batch_size maximum
  samples_to_generate: 512
  samples_to_save: 20
  chains_to_save: 1
  log_every_steps: 50
  # Number of frames in each gif
  number_chain_steps: 50

  final_model_samples_to_generate: 10240
  final_model_samples_to_save: 8
  final_model_chains_to_save: 4

  evaluate_all_checkpoints: false
  # Log directory is joined from the project, dataset and experiment names
  # anchored earlier in this file; wandb, graphs and chains all live under it.
  logdir: &_logdir !PATHJOIN [!ENV "${HOME}/project/logs/", *_project_name, *_dataset_name, *_exp_name]
  wandb_dir: *_logdir
  graphs_path: !PATHJOIN [*_logdir, "graphs"]
  chains_path: !PATHJOIN [*_logdir, "chains"]

  # DEBUG
  print_logit: false

  name: 'lambda1:7_test'

# Optimisation settings. Values tagged !SUB can be overridden from outside;
# the part after `:-` is presumably the fallback — TODO confirm with the loader.
train:
  n_epochs: !SUB ${epochs:-1000}
  batch_size: !SUB ${batch_size:-512}
  lr: !SUB ${lr:-0.0002}
  # float, null to disable
  clip_grad: 1.0
  save_model: true
  num_workers: 0
  # Amount of EMA decay, 0 means off. A reasonable value is 0.999.
  ema_decay: 0.999
  progress_bar: false
  weight_decay: !SUB ${weight_decay:-1e-12}
  # adamw, nadamw, nadam => nadamw for large batches; see
  # http://arxiv.org/abs/2102.06356 for the use of nesterov momentum with
  # large batches.
  optimizer: adamw
  seed: 0