# Run identity and global switches.
# NOTE(review): !ENV / !SUB / !PATHJOIN are custom tags resolved by the
# project's YAML loader (not visible here). The `${name:-default}` form
# presumably falls back to `default` when `name` is unset -- confirm there.
# The project_name and exp_name anchors are reused to build general.logdir.
project_name: &_project_name !ENV "${LOGNAME}_bfgraph_sbm"
# No default given: behavior when exp_name is unset depends on the loader.
exp_name: &_exp_name !SUB ${exp_name}
debug: !SUB ${debug}
overfit_batches: !SUB ${overfit_batches:-0.0}
no_wandb: !SUB ${no_wandb}
logging_level: !SUB ${logging_level:-warning}


# Dev Switches
# Development toggles read through the project's !SUB loader (not shown).
# None of these has a `:-default`, so the value used when a variable is
# unset depends on that loader -- NOTE(review): confirm.
input_dist_sample: !SUB ${input_dist_sample}
force_symmetric_theta_E: !SUB ${force_symmetric_theta_E}
compare_input_output_dist_samples: !SUB ${compare_input_output_dist_samples}
plot_input_dist_entropy: !SUB ${plot_input_dist_entropy}

# Dataset selection. The name anchor is reused to build general.logdir.
dataset:
  name: &_dataset_name 'sbm'
  # Presumably $HOME/project/data/digress/sbm once !PATHJOIN joins the parts
  # (custom tag handled by the project loader -- confirm).
  datadir: !PATHJOIN [!ENV "${HOME}/project/data/digress/", *_dataset_name]
  remove_h: null

model:
  type: 'bayesian'
  model: 'graph_tf'
  # time: !SUB ${time:-continuous}  # discrete or continuous

  # BFN hyperparameters. beta_* are used for the discrete version of BFN.
  beta_node: !SUB ${beta_node:-3.0}
  beta_edge: !SUB ${beta_edge:-3.0}

  beta_node_init: !SUB ${beta_node_init:-0.0}
  beta_edge_init: !SUB ${beta_edge_init:-0.0}

  t_min: 1.0e-2
  # diffusion_noise_schedule: 'cosine'
  # diffusion_steps: !SUB ${sample_steps:-500}

  # Prior
  transition: !SUB ${prior:-uniform}  # 'uniform' or 'marginal'
  # Time schedulers: default 'quad'; 'cosine' and 'linear' are also listed as
  # options. NOTE(review): confirm the full accepted set in the scheduler code.
  node_time_scheduler: !SUB ${node_time_scheduler:-quad}
  edge_time_scheduler: !SUB ${edge_time_scheduler:-quad}
  # Expected range: [0.0, 1.0].
  alternative_sampling_theta_update_ratio: !SUB ${alternative_sampling_theta_update_ratio:-0.0}

  sample_steps: !SUB ${sample_steps:-500}
  n_layers: 8
  # Keys stay quoted: a bare `y` is a boolean under some YAML 1.1 parsers.
  hidden_mlp_dims:
    'X': 128
    'E': 64
    'y': 128
  hidden_dims:
    'dx': 256
    'de': 64
    'dy': 64
    'n_head': 8
    'dim_ffX': 256
    'dim_ffE': 64
    'dim_ffy': 256
  lambda_train_node: !SUB ${lambda_train_node:-1.0}
  lambda_train_edge: !SUB ${lambda_train_edge:-5.0}
  lambda_train_y: !SUB ${lambda_train_y:-0.0}
  # discrete:
  # 'all', 'cycles', 'eigenvalues', or none. NOTE(review): the default is the
  # text `None`, not YAML null -- confirm how the loader/consumer maps it.
  extra_features: !SUB ${extra_features:-None}
  extra_mode: !SUB ${extra_mode:-prob}  # 'prob' or 'logit'
  n_iid: !SUB ${n_iid:-10}
  # Whether to use discretised time or not [discr].
  discretised_time: !SUB ${discretised_time:-False}

  output_dist_extra: !SUB ${output_dist_extra:-False}

general:
  wandb: 'online'  # online | offline | disabled
  # If resuming, path to the ckpt file from the outputs directory in the
  # main directory.
  resume: !SUB ${resume:-None}
  # Checkpoint file for test-only runs; use an absolute path.
  test_only: !SUB ${test_ckpt_fname:-None}
  sampling_bs: !SUB ${sampling_bs:-2048}

  check_val_every_n_epochs: 200
  check_point_every_n_epochs: 200

  sample_every_val: 1
  val_check_interval: null
  samples_to_generate: 32  # We advise setting this to at most 2 x batch_size
  samples_to_save: 8
  chains_to_save: 8
  log_every_steps: 50
  number_chain_steps: 50  # Number of frames in each gif

  final_model_samples_to_generate: 40
  final_model_samples_to_save: 4
  final_model_chains_to_save: 4

  evaluate_all_checkpoints: false
  # Presumably logs/<project_name>/<dataset name>/<exp_name> once !PATHJOIN
  # joins the parts (custom tag -- confirm). The !ENV string has no variable.
  logdir: &_logdir !PATHJOIN [!ENV "logs/", *_project_name, *_dataset_name, *_exp_name]
  wandb_dir: *_logdir
  graphs_path: !PATHJOIN [*_logdir, "graphs"]
  chains_path: !PATHJOIN [*_logdir, "chains"]

  # DEBUG
  print_logit: false

  name: 'bfn_sbm'

train:
  n_epochs: !SUB ${epochs:-1000}
  batch_size: !SUB ${batch_size:-512}
  lr: !SUB ${lr:-0.0002}
  # float, null to disable. NOTE(review): default is 0 -- confirm whether 0
  # also disables clipping.
  clip_grad: !SUB ${clip_grad:-0}
  save_model: true
  num_workers: 0
  # Amount of EMA decay; 0 means off. A reasonable value is 0.999.
  ema_decay: !SUB ${ema_decay:-0}
  progress_bar: false
  # Written as 1.0e-12 (not 1e-12): YAML 1.1 resolvers such as PyYAML only
  # recognize a float with a '.' and a signed exponent, so '1e-12' would load
  # as a string if the substituted value is re-parsed as YAML.
  weight_decay: !SUB ${weight_decay:-1.0e-12}
  # adamw, nadamw, nadam. nadamw suits large batches; see
  # http://arxiv.org/abs/2102.06356 for nesterov momentum with large batches.
  optimizer: adamw
  seed: 0