# Training-run settings: output directory, batch/worker counts, trainer
# arguments, and the in-training evaluation schedule.
training:
  output_dir: runs_paper_ed/
  batch_size: 16
  num_workers: 4
  # NOTE(review): these keys look like PyTorch Lightning Trainer arguments and
  # are presumably forwarded to the trainer unchanged -- confirm.
  trainer_args:
    max_epochs: 20
    accelerator: gpu
    devices: 2
    num_nodes: 1
    strategy: ddp_find_unused_parameters_true
    accumulate_grad_batches: 2
    limit_val_batches: 0.002

  evaluation:
    # How many molecules to sample during evaluation.
    mols_to_sample: 64
    # How often to sample molecules during training, measured in epochs.
    sample_interval: 0.012
    # How often to compute validation-set loss during training, measured in
    # epochs.
    val_loss_interval: 0.012

# Experiment-tracking (wandb) run metadata.
wandb:
  project: mol-fm
  group: for-reals
  name: geom-gaussian
  # One of: disabled, online, offline.
  mode: online

# Learning-rate schedule settings.
# To turn off warmup and restarts, set both warmup_length and
# restart_interval to 0.
lr_scheduler:
  base_lr: 1.0e-4
  warmup_length: 1.0
  # 0 means no restart.
  restart_interval: 0
  restart_type: linear
  weight_decay: 1.0e-12

# Dataset locations and the atom-type vocabulary.
dataset:
  raw_data_dir: data/raw/
  processed_data_dir: data/geom/
  # Earlier, larger vocabulary, kept for reference:
  # atom_map: ['C', 'H', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I',
  #            'As', 'Hg', 'Bi', 'Se']
  # As, Hg, Bi, and Se were removed: As, Hg, and Se are not in the dataset,
  # and Bi is virtually absent (p(a = Bi) = 1.0e-8).
  # Element symbols stay quoted so each is always read as a string.
  atom_map: ['C', 'H', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I']
  # Only geom is supported for now.
  dataset_name: geom
  # Intentionally unset. NOTE(review): presumably null means "no size limit /
  # use the full dataset" -- confirm against the data-loading code.
  dataset_size: null

# Checkpoint-saving policy.
# NOTE(review): the keys match the argument names of PyTorch Lightning's
# ModelCheckpoint callback and are presumably passed to it -- confirm.
checkpointing:
  save_last: true
  save_top_k: 3
  # Name of the logged metric used to rank checkpoints for save_top_k.
  # NOTE(review): assumed to match the metric name logged by the training
  # code -- confirm.
  monitor: 'val_total_loss'
  every_n_epochs: 1

# Flow-matching model settings. Loss weights and priors are configured
# separately for each of the four features x, a, c, and e.
# NOTE(review): x, a, c, e presumably denote positions, atom types, charges,
# and edges respectively -- confirm against the model code.
mol_fm:
  # One of: "endpoint", "vector-field", or "dirichlet".
  parameterization: endpoint
  time_scaled_loss: true
  # Per-feature weights applied when combining losses into the total loss.
  total_loss_weights:
    x: 3.0
    a: 0.4
    c: 1.0
    e: 2.0
  # Per-feature prior: whether to align, the prior type, and extra keyword
  # arguments for that prior.
  prior_config:
    x:
      align: true
      type: 'centered-normal'
      kwargs: {std: 1.0}
    a:
      align: false
      type: 'gaussian'
      kwargs: {}
    c:
      align: false
      type: 'gaussian'
      kwargs: {}
    e:
      align: false
      type: 'gaussian'
      kwargs: {}

# Hyperparameters of the vector-field network.
# NOTE(review): "gvp" in the key names presumably refers to geometric vector
# perceptron layers and "rbf" to a radial-basis distance embedding -- confirm
# against the model code.
vector_field:
  n_vec_channels: 16
  update_edge_w_distance: true
  n_hidden_scalars: 256
  n_hidden_edge_feats: 128
  n_recycles: 1
  separate_mol_updaters: true
  n_molecule_updates: 5
  convs_per_update: 1
  n_cp_feats: 4
  n_message_gvps: 3
  n_update_gvps: 3
  message_norm: 100
  rbf_dmax: 12
  rbf_dim: 16

# Interpolant schedule, configured separately for each of the features
# x, a, c, and e (the same keys used under mol_fm).
interpolant_scheduler:
  # Schedule family per feature.
  schedule_type:
    x: cosine
    a: cosine
    c: cosine
    e: cosine
  # Per-feature parameter for the cosine schedule.
  # NOTE(review): presumably only read for features whose schedule_type is
  # cosine -- confirm against the scheduler code.
  cosine_params:
    x: 1
    a: 2
    c: 2
    e: 2
