# Settings grouped under `training`: run output, batching, trainer arguments,
# and the evaluation cadence used while training.
training:
  output_dir: runs_paper_ed/
  batch_size: 64
  num_workers: 8
  # NOTE(review): scientific notation with a decimal point parses as a float
  # (400000.0), not an int — confirm the consumer is fine with a float here.
  max_num_edges: 4.0e+5
  # NOTE(review): these key names match PyTorch Lightning `Trainer` arguments;
  # presumably this mapping is forwarded to it as keyword arguments — confirm.
  trainer_args:
    max_epochs: 1000
    accelerator: gpu
    devices: 1
    num_nodes: 1
    strategy: auto
    accumulate_grad_batches: 1
    limit_val_batches: 0.1
    gradient_clip_val: 1.0
    gradient_clip_algorithm: value

  evaluation:
    # how many molecules to sample during evaluation
    mols_to_sample: 128
    # how often to sample molecules during training, measured in epochs
    sample_interval: 0.2
    # how often to compute validation set loss during training, measured in
    # epochs
    val_loss_interval: 0.2

# Experiment-tracking settings.
# NOTE(review): key names match `wandb.init` arguments; presumably passed
# through to it — confirm against the caller.
wandb:
  project: mol-fm
  # explicitly null: no group is set for this run
  group: null
  name: qm-dp-ep
  # can be disabled, online, offline
  mode: online

# Learning-rate schedule settings.
# To turn off warmup and restarts, set both warmup_length and restart_interval
# to 0.
lr_scheduler:
  base_lr: 1.0e-4
  # NOTE(review): the unit of warmup_length / restart_interval is not stated
  # in this file — confirm against the scheduler implementation.
  warmup_length: 1.0
  # 0 means no restart
  restart_interval: 0
  restart_type: linear
  weight_decay: 1.0e-12

# Dataset selection and on-disk locations.
dataset:
  raw_data_dir: data/qm9_raw
  processed_data_dir: data/qm9
  # Element symbols stay quoted on purpose: a bare N (like y/n) is a
  # boolean-looking scalar under YAML 1.1 and must remain a string.
  atom_map: ['C', 'H', 'N', 'O', 'F']
  # must be qm9 or geom
  dataset_name: qm9
  # NOTE(review): explicitly null; presumably "no size limit" — confirm how
  # the loader treats a null dataset_size.
  dataset_size: null

# Checkpoint-saving settings.
# NOTE(review): key names match PyTorch Lightning `ModelCheckpoint` arguments;
# presumably forwarded to it — confirm.
checkpointing:
  save_last: true
  save_top_k: 3
  monitor: val_total_loss
  every_n_epochs: 1

# Model-level flow-matching settings.
# The same four feature keys (x, a, c, e) are used under total_loss_weights
# and prior_config.
# NOTE(review): presumably x = coordinates, a = atom types, c = charges,
# e = edges/bonds — confirm against the model code.
mol_fm:
  # can be "endpoint" or "vector-field"
  parameterization: endpoint
  weight_ae: false
  target_blur: 0.0
  time_scaled_loss: true
  # one loss weight per feature key
  total_loss_weights:
    x: 3.0
    a: 0.4
    c: 1.0
    e: 2.0
  # one prior per feature key: alignment flag, prior type, and its kwargs
  prior_config:
    x:
      align: true
      type: centered-normal
      kwargs:
        std: 1.0
    a:
      align: false
      type: marginal
      kwargs:
        blur: 0.15
    c:
      align: false
      type: c-given-a
      kwargs:
        blur: 0.15
    e:
      align: false
      type: marginal
      kwargs:
        blur: 0.15

# Architecture hyperparameters for the vector-field network.
vector_field:
  # feature widths
  n_vec_channels: 16
  update_edge_w_distance: true
  n_hidden_scalars: 256
  n_hidden_edge_feats: 128
  # depth / repetition counts
  n_recycles: 1
  separate_mol_updaters: true
  n_molecule_updates: 8
  convs_per_update: 1
  n_cp_feats: 4
  # NOTE(review): "gvps" presumably refers to geometric vector perceptron
  # layers — confirm against the model code.
  n_message_gvps: 3
  n_update_gvps: 3
  message_norm: 100
  # NOTE(review): presumably radial-basis-function distance embedding
  # (max distance and number of basis functions) — confirm.
  rbf_dmax: 14
  rbf_dim: 16

# Interpolant schedule, configured separately for each feature key
# (x, a, c, e).
interpolant_scheduler:
  schedule_type:
    x: cosine
    a: cosine
    c: cosine
    e: cosine
  # Values are kept exactly as written: x, a and c load as ints, e loads as a
  # float.
  # NOTE(review): confirm the consumer accepts mixed int/float values here.
  cosine_params:
    x: 1
    a: 2
    c: 2
    e: 1.5
