# Run-level settings. Values tagged !SUB / !ENV are filled in by the project's
# YAML loader, presumably shell-style `${VAR}` / `${VAR:-default}` expansion
# from the environment -- TODO confirm against the loader that registers them.
# project_name is derived from $LOGNAME and anchored (&_project_name) so that
# general.logdir can reuse it.
project_name: &_project_name !ENV "${LOGNAME}_bfgraph_qm9withh"
# Anchored (&_exp_name) for reuse in general.logdir. No default is given, so
# this presumably must be set in the environment.
exp_name: &_exp_name !SUB ${exp_name}
debug: !SUB ${debug}
# Default 0.0. NOTE(review): exact meaning (fraction vs. batch count) depends
# on the trainer that consumes it -- confirm.
overfit_batches: !SUB ${overfit_batches:-0.0}
no_wandb: !SUB ${no_wandb}
# Default 'warning'.
logging_level: !SUB ${logging_level:-warning}

# Dev Switches
# Development / diagnostic toggles. None of them has a `:-default`, so each is
# presumably expected to be set in the environment -- NOTE(review): confirm how
# the !SUB loader treats an unset variable.
input_dist_sample: !SUB ${input_dist_sample}
force_symmetric_theta_E: !SUB ${force_symmetric_theta_E}
compare_input_output_dist_samples: !SUB ${compare_input_output_dist_samples}
plot_input_dist_entropy: !SUB ${plot_input_dist_entropy}

dataset:
  # 'qm9' or 'qm9_positional'. Anchored (&_dataset_name) so that datadir below
  # and general.logdir reuse the same name.
  name: &_dataset_name 'qm9'
  # Data lives under $HOME/project/data/digress/<name>.
  datadir: !PATHJOIN [ !ENV "${HOME}/project/data/digress/", *_dataset_name ]
  # false: hydrogens are not removed (consistent with the "qm9withh" project).
  remove_h: false
  random_subset: false  # qm9 only
  pin_memory: false     # qm9 only
  # NOTE(review): the previous comment described this as choosing the filtered
  # vs. raw *guacamol* file, which was copied from another dataset's config;
  # confirm whether the qm9 loader reads `filter` at all.
  filter: true

model:
  type: 'bayesian'
  model: 'graph_tf'

  # --- BFN ---
  # beta used by the discrete version of BFN, set separately for nodes and
  # edges; *_init presumably gives the starting value of the schedule.
  beta_node: !SUB ${beta_node:-3.0}
  beta_edge: !SUB ${beta_edge:-3.0}
  beta_node_init: !SUB ${beta_node_init:-0.0}
  beta_edge_init: !SUB ${beta_edge_init:-0.0}

  # Written with a '.' so YAML 1.1 loaders resolve it as a float.
  t_min: 1.0e-2

  # --- Prior ---
  transition: !SUB ${prior:-uniform}  # 'uniform' or 'marginal' (set via `prior`)
  # time: !SUB ${time:-continuous}    # discrete or continuous
  # Time schedulers: 'quad' (default), 'cosine', 'linear'.
  # NOTE(review): confirm the full list against the scheduler factory.
  node_time_scheduler: !SUB ${node_time_scheduler:-quad}
  edge_time_scheduler: !SUB ${edge_time_scheduler:-quad}
  alternative_sampling_theta_update_ratio: !SUB ${alternative_sampling_theta_update_ratio:-0.0}  # in [0.0, 1.0]

  sample_steps: !SUB ${sample_steps:-500}
  n_layers: 7

  # Do not set the E entry of hidden_mlp_dims or dim_ffE too high: computing
  # large tensors on the edges is costly. At the moment (03/08), y carries
  # little information.
  hidden_mlp_dims:
    'X': 256
    'E': 128
    'y': 128

  # The dimensions should satisfy dx % n_head == 0.
  hidden_dims:
    'dx': 256
    'de': 64
    'dy': 64
    'n_head': 8
    'dim_ffX': 256
    'dim_ffE': 128
    'dim_ffy': 128
  # Larger alternative, kept for reference:
  # hidden_mlp_dims: {'X': 256, 'E': 256, 'y': 128}
  # hidden_dims: {'dx': 256, 'de': 128, 'dy': 64, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 256, 'dim_ffy': 128}

  # Weights balancing the node, edge and y loss terms (an older note read [5, 0]).
  lambda_train_node: !SUB ${lambda_train_node:-1.0}
  lambda_train_edge: !SUB ${lambda_train_edge:-1.0}
  lambda_train_y: !SUB ${lambda_train_y:-0.0}

  # --- Extra features (discrete) ---
  # 'all', 'cycles', 'eigenvalues' or null.
  # NOTE(review): the default is the bare word None. If the substituted value
  # is YAML-parsed, that yields the string "None", not null -- confirm the
  # loader or the consumer maps it to "no extra features".
  extra_features: !SUB ${extra_features:-None}
  extra_mode: !SUB ${extra_mode:-prob}  # 'prob', 'logit', or 'iid'
  n_iid: !SUB ${n_iid:-10}
  # Whether to use discretised time (default False).
  discretised_time: !SUB ${discretised_time:-False}

  output_dist_extra: !SUB ${output_dist_extra:-False}

general:
  wandb: 'online'   # online | offline | disabled
  gpus: 1           # Multi-gpu is not implemented on this branch
  # When resuming: path to the ckpt file, taken from the outputs directory in
  # the main directory.
  resume: null
  # Checkpoint to evaluate (set via test_ckpt_fname); use an absolute path.
  test_only: !SUB ${test_ckpt_fname:-null}
  sampling_bs: !SUB ${sampling_bs:-2048}

  check_val_every_n_epochs: 20
  check_point_every_n_epochs: 20

  sample_every_val: 1
  val_check_interval: null
  samples_to_generate: 1024   # Advice: at most 2 x batch_size
  samples_to_save: 8
  chains_to_save: 8
  log_every_steps: 50
  number_chain_steps: 50      # Number of frames in each gif

  final_model_samples_to_generate: 10000
  final_model_samples_to_save: 30
  final_model_chains_to_save: 20

  evaluate_all_checkpoints: false
  # logs/<project_name>/<dataset name>/<exp_name>, built from the anchors
  # defined on project_name, dataset.name and exp_name; reused for wandb,
  # graphs and chains output below.
  logdir: &_logdir !PATHJOIN [!ENV "logs/", *_project_name, *_dataset_name, *_exp_name]
  wandb_dir: *_logdir
  graphs_path: !PATHJOIN [*_logdir, "graphs"]
  chains_path: !PATHJOIN [*_logdir, "chains"]

  # DEBUG
  print_logit: false

  name: 'qm9_beta+scheduling_tune'

train:
  n_epochs: !SUB ${epochs:-1000}
  batch_size: !SUB ${batch_size:-512}
  lr: !SUB ${lr:-0.0002}
  # float, or null to disable. NOTE(review): the default is 0 -- confirm that
  # the consumer also treats 0 as "no clipping".
  clip_grad: !SUB ${clip_grad:-0}
  save_model: true
  num_workers: 0
  # Amount of EMA decay; 0 means off. A reasonable value is 0.999.
  ema_decay: !SUB ${ema_decay:-0}
  progress_bar: false
  # Default written as 1.0e-12 (with a '.'), like t_min in the model section.
  # If the substituted value is YAML-parsed, YAML 1.1 loaders such as PyYAML
  # resolve an exponent-only literal like 1e-12 as a string, not a float.
  weight_decay: !SUB ${weight_decay:-1.0e-12}
  # adamw | nadamw | nadam. Prefer nadamw for large batches; see
  # http://arxiv.org/abs/2102.06356 on Nesterov momentum with large batches.
  optimizer: adamw
  seed: 0

# accounting:
#   logdir: &_logdir !PATHJOIN [!ENV "${HOME}/project/logs/", *_project_name, !SUB "${exp_name}"]
#   wandb_logdir: *_logdir
#   checkpoint_dir: !PATHJOIN [*_logdir, "checkpoints"]
#   generated_mol_dir: !PATHJOIN [*_logdir, "generated_mol"]
#   checkpoint_freq: 5