# a simple example config file

# Training uses two folders:
#   'root'/process     holds the processed data sets
#   'root'/'run_name'  holds logfiles and saved models
# If 'root'/'run_name' already exists,
# 'root'/'run_name'_'year'-'month'-'day'-'hour'-'min'-'s' is used instead.
root: MODELPATH/nequip
run_name: md17-salicylic_acid_10k_nequip_pat20

# model seed
seed: 123
# data set seed
dataset_seed: 456
# set true if a restarted run should append to the previous log file
append: true
# type of float to use, e.g. float32 or float64
default_dtype: float32

# network

# cutoff radius in length units (here Angstrom); an important hyperparameter to scan
r_max: 4.0
# number of interaction blocks; 3-5 were found to work best
num_layers: 5
# maximum irrep order (rotation order) for the network's features;
# l=1 is a good default, l=2 is more accurate but slower
l_max: 3
# whether to include features with odd mirror parity; turning parity off often gives
# equally good results with faster networks, so do consider it
parity: true
# multiplicity of the features; 32 is a good default for an accurate network,
# go larger to be more accurate or lower to be faster
num_features: 32
# may be 'gate' or 'norm'; 'gate' is recommended
nonlinearity_type: gate

# Scalar nonlinearities to use. Available options: silu, ssp (shifted softplus), tanh, abs.
# Nonlinearities are specified separately for e (even) and o (odd) parity.
# Note that only tanh and abs are correct for o (odd parity);
# silu typically works best for e (even parity).
nonlinearity_scalars:
  e: silu
  o: tanh

nonlinearity_gates:
  e: silu
  o: tanh

# radial network basis

# number of basis functions used in the radial basis; 8 usually works best
num_basis: 8
# set true to train the bessel weights
BesselBasis_trainable: true
# p-exponent used in the polynomial cutoff function;
# smaller p corresponds to stronger decay with distance
PolynomialCutoff_p: 6

# radial network

# number of radial layers; usually 1-3 works best, smaller is faster
invariant_layers: 2
# number of hidden neurons in the radial function; smaller is faster
invariant_neurons: 64
# number of neighbors to divide by; null => no normalization,
# auto => computed from the dataset
avg_num_neighbors: auto
# whether to use the self-connection; usually gives a big improvement
use_sc: true

# data set
#
# Every key used must be stated at least once in key_mapping, npz_fixed_field_keys
# or npz_keys.
# key_mapping maps the keys in the npz file to the NequIP default values
# (see data/_key.py).
# All arrays are expected to have the shape (nframe, natom, ?), except the fixed fields.
# If your data set uses pbc, you also need to pass an array that maps to the
# nequip "pbc" key.

# type of data set; can be npz or ase
dataset: npz
# path to data set file
dataset_file_name: DATAPATH/md17/salicylic_acid/10k/train/nequip_npz.npz
key_mapping:
  # atomic species, integers
  z: atomic_numbers
  # total potential energies to train to
  E: total_energy
  # atomic forces to train to
  F: forces
  # raw atomic positions
  R: pos
# fields that are repeated across different examples
npz_fixed_field_keys:
  - atomic_numbers

validation_dataset: npz
# path to data set file
validation_dataset_file_name: DATAPATH/md17/salicylic_acid/10k/val/nequip_npz.npz

# Atomic types to be found in the data.
# The NequIP types are named with the chemical symbols, and inputs with the correct
# atomic numbers are mapped to the corresponding types.
chemical_symbols:
  - H
  - C
  - O

# logging

# using wandb for logging is recommended
wandb: true
# project name used in wandb
wandb_project: mdbench

# same levels as python logging, e.g. warning, info, debug, error; case insensitive
verbose: info
# batch frequency: how often to print training errors within the same epoch
log_batch_freq: 10
# epoch frequency: how often to print
# NOTE(review): the original comment stopped after "print"; presumably this refers to
# per-epoch results - confirm
log_epoch_freq: 1
# frequency to save the intermediate checkpoint;
# no intermediate checkpoints are saved when the value is not positive
save_checkpoint_freq: -1
# frequency to save the intermediate ema checkpoint;
# no intermediate checkpoints are saved when the value is not positive
save_ema_checkpoint_freq: -1

# training

# number of training data
n_train: 9500
# number of validation data
n_val: 500
# learning rate; values between 0.005 and 0.01 were found to work best.
# This is often one of the most important hyperparameters to tune.
learning_rate: 0.005
# batch size; keeping it small (1-5) was found to be important for most applications
# including forces; for energy-only training, higher batch sizes work better
batch_size: 5
# stop training after this number of epochs. In practice a very large number
# (e.g. 1 million) is recommended, relying on early stopping instead of training
# the full number of epochs.
# NOTE(review): the previous comment called this value a small number chosen so the
# example finishes within a few minutes; confirm whether that still applies to 2000.
max_epochs: 2000
# can be random or sequential. If sequential, the first n_train elements are training
# and the next n_val are validation; otherwise the split is random.
# Random is usually the right choice.
train_val_split: random
# if true, the data loader will shuffle the data; usually a good idea
shuffle: true
# metric used for scheduling and saving the best model.
# Options: `set`_`quantity`, where set can be either "train" or "validation" and
# quantity can be loss or anything that appears in the validation batch step header,
# such as f_mae, f_rmse, e_mae, e_rmse
metrics_key: validation_loss
# if true, use an exponential moving average on the weights for val/test;
# usually helps a lot with training, in particular for energy errors
use_ema: true
# ema weight, typically set to 0.99 or 0.999
ema_decay: 0.99
# whether to use the number of updates when computing averages
ema_use_num_updates: true
# if true, report the validation error for the just-initialized model
report_init_validation: true

# early stopping based on metrics values

# stop early if a metric value has stopped decreasing for n epochs
early_stopping_patiences:
  validation_loss: 50

# stop early if a metric value is lower than the bound
early_stopping_lower_bounds:
  LR: 1.0e-6

# stop early if a metric value is higher than the bound
early_stopping_upper_bounds:
  # 604800 = 7 * 24 * 3600; presumably wall time in seconds (7 days) - TODO confirm unit
  cumulative_wall: 604800

# loss function
#
# Guidance from the original comment: if using PerAtomMSELoss, a default weight of 1:1
# on each term should work well.
# NOTE(review): this config weights forces 1000 against total_energy 1; presumably
# intentional - confirm.
loss_coeffs:
  forces: 1000
  total_energy:
    - 1
    - PerAtomMSELoss

# output metrics
#
# Each entry is a list of: the key, the statistic ("rmse" or "mae"), and an optional
# mapping of settings:
#   PerSpecies: if true, per species contribution is counted separately
#   report_per_component: if true, statistics on each component (i.e. fx, fy, fz)
#                         will be counted separately
#   PerAtom: if true, energy is normalized by the number of atoms
# Booleans are written lowercase (true/false) to match the rest of this file.
metrics_components:
  - - forces
    - mae
  - - forces
    - rmse
  - - forces
    - mae
    - PerSpecies: true
      report_per_component: false
  - - forces
    - rmse
    - PerSpecies: true
      report_per_component: false
  - - total_energy
    - mae
  - - total_energy
    - mae
    - PerAtom: true

# optimizer; may be any optimizer defined in torch.optim
# the name `optimizer_name` is case sensitive

# default optimizer is Adam
optimizer_name: Adam
optimizer_amsgrad: false

# lr scheduler. Currently only the two options listed in full.yaml are supported,
# i.e. on-plateau and cosine annealing with warm restarts; if you need more,
# please file an issue.
# Here: on-plateau, which reduces the lr by a factor of lr_scheduler_factor if
# metrics_key hasn't improved for lr_scheduler_patience epochs.
lr_scheduler_name: ReduceLROnPlateau
lr_scheduler_patience: 5
lr_scheduler_factor: 0.8

# A series of options is provided to shift and scale the data.
# These are for advanced use; the defaults usually work very well.
# The default is to scale the atomic energy and forces by the force standard deviation
# and to shift the energy by the mean atomic energy.
# In certain cases it can be useful to have a trainable shift/scale, and also to have
# species-dependent shifts/scales for each atom.

# Whether the shifts and scales are trainable. Optional; defaults to False.
per_species_rescale_shifts_trainable: false
per_species_rescale_scales_trainable: false

# Initial atomic energy shift for each species. Optional; defaults to the mean of the
# per atom energy.
# The value can be a constant float value, an array for each species, or a string that
# defines a statistics over the training dataset.
per_species_rescale_shifts: dataset_per_atom_total_energy_mean

# Initial atomic energy scale for each species. Optional.
# The value can be a constant float value, an array for each species, or a string.
per_species_rescale_scales: dataset_forces_rms

# If explicit numbers are given for the shifts/scales, this parameter must specify
# whether the given numbers are unitless shifts/scales or are in the units of the
# dataset. If ``True``, any global rescalings will correctly be applied to the
# per-species values.
# per_species_rescale_arguments_in_dataset_units: True
