# IMPORTANT: READ THIS

# This is a full yaml file with all nequip options.
# It is primarily intended to serve as documentation/reference for all options
# For a simpler yaml file containing all necessary features to get you started, we strongly recommend to start with configs/example.yaml

# Two folders will be used during the training: 'root'/process and 'root'/'run_name'
# run_name contains logfiles and saved models
# process contains processed data sets
# if 'root'/'run_name' exists, 'root'/'run_name'_'year'-'month'-'day'-'hour'-'min'-'s' will be used instead.
# output folders: log files and saved models go to `root`/`run_name`, processed data sets to `root`/process
root: results/toluene
run_name: example-run-toluene

# model seed
seed: 123
# data set seed
dataset_seed: 456

# set true if a restarted run should append to the previous log file
append: true

# numerical precision; see https://arxiv.org/abs/2304.10061 for a discussion
default_dtype: float64
model_dtype: float32
# consider setting to false if you plan to mix training/inference over any devices that are not NVIDIA Ampere or later
allow_tf32: true

# == network ==

# `model_builders` defines a series of functions that will be called to construct the model
# each model builder has the opportunity to update the model, the config, or both
# model builders from other packages are allowed (see mir-group/allegro for an example); those from `nequip.model` don't require a prefix
# these are the default model builders:
# ordered list of model builders; each one is called in turn to construct/update the model
model_builders:
  - SimpleIrrepsConfig         # update the config with all the irreps for the network if using the simplified `l_max` / `num_features` / `parity` syntax
  - EnergyModel                # build a full NequIP model
  - PerSpeciesRescale          # add per-atom / per-species scaling and shifting to the NequIP model before the total energy sum
  - ForceOutput                # wrap the energy model in a module that uses automatic differentiation to compute the forces
  - RescaleEnergyEtc           # wrap the entire model in the appropriate global rescaling of the energy, forces, etc.
#   ^ global rescaling blocks must always go last!

# cutoff radius in length units, here Angstrom; this is an important hyperparameter to scan
r_max: 4.0
# number of interaction blocks, we find 3-5 to work best
num_layers: 4

# the maximum irrep order (rotation order) for the network's features; l=1 is a good default, l=2 is more accurate but slower
l_max: 1
# whether to include features with odd mirror parity; often turning parity off gives equally good results but faster networks, so do consider this
parity: true
# the multiplicity of the features; 32 is a good default for an accurate network: go larger to be more accurate, go lower to be faster
num_features: 32

# alternatively, the irreps of the features in various parts of the network can be specified directly:
# the following options use e3nn irreps notation
# either these four options, or the above three options, should be provided--- they cannot be mixed.
# chemical_embedding_irreps_out: 32x0e                                              # irreps for the chemical embedding of species
# feature_irreps_hidden: 32x0o + 32x0e + 32x1o + 32x1e                              # irreps used for hidden features, here we go up to lmax=1, with even and odd parities; for more accurate but slower networks, use l=2 or higher, smaller number of features is faster
# irreps_edge_sh: 0e + 1o                                                           # irreps of the spherical harmonics used for edges. If a single integer, indicates the full SH up to L_max=that_integer
# conv_to_output_hidden_irreps_out: 16x0e                                           # irreps used in hidden layer of output block


# may be 'gate' or 'norm', 'gate' is recommended
nonlinearity_type: gate
# set true to make interaction block a resnet-style update;
# the resnet update will only be applied when the input and output irreps of the layer are the same
resnet: false

# scalar nonlinearities to use — available options are silu, ssp (shifted softplus), tanh, and abs.
# Different nonlinearities are specified for e (even) and o (odd) parity;
# note that only tanh and abs are correct for o (odd parity).
# silu typically works best for even parity
nonlinearity_scalars:
  e: silu
  o: tanh

# gate nonlinearities, specified per parity in the same way as the scalar ones
nonlinearity_gates:
  e: silu
  o: tanh

# -- radial network basis --
# number of basis functions used in the radial basis, 8 usually works best
num_basis: 8
# set true to train the bessel weights
BesselBasis_trainable: true
# p-exponent used in polynomial cutoff function, smaller p corresponds to stronger decay with distance
PolynomialCutoff_p: 6

# -- radial network --
# number of radial layers, usually 1-3 works best, smaller is faster
invariant_layers: 2
# number of hidden neurons in radial function, smaller is faster
invariant_neurons: 64
# number of neighbors to divide by; null => no normalization, auto computes it based on dataset
avg_num_neighbors: auto
# use self-connection or not, usually gives big improvement
use_sc: true

# to specify different parameters for each convolutional layer, try examples below
# layer1_use_sc: true                                                             # use the "layer{i}_" prefix to specify parameters for only one of the layers,
# priority for different definitions:
#   invariant_neurons < InteractionBlock_invariant_neurons < layer{i}_invariant_neurons

# data set
# there are two options to specify a dataset, npz or ase
# npz works with npz files, ase can read any format that ase.io.read can read
# in most cases working with the ase option and an extxyz file is by far the simplest way to do it and we strongly recommend using this
# simply provide a single extxyz file that contains the structures together with energies and forces (generated with ase.io.write(atoms, format='extxyz', append=True))

# # for extxyz file
# dataset: ase
# dataset_file_name: H2.extxyz
# ase_args:
#   format: extxyz
# include_keys:
#   - user_label
# key_mapping:
#   user_label: label0


# alternatively, you can read directly from a VASP OUTCAR file (this will only read that single OUTCAR)
# # for VASP OUTCAR, the yaml input should be
# dataset: ase
# dataset_file_name: OUTCAR
# ase_args:
#   format: vasp-out
# important VASP note: the ase vasp parser stores the potential energy to "free_energy" instead of "energy".
# Here, the key_mapping maps the external name (key) to the NequIP default name (value)
# key_mapping:
#   free_energy: total_energy

# npz example 
# the keys used need to be stated at least once in key_mapping, npz_fixed_field_keys or include_keys
# key_mapping is used to map the key in the npz file to the NequIP default values (see data/_key.py)
# all arrays are expected to have the shape of (nframe, natom, ?) except the fixed fields
# note that if your data set uses pbc, you need to also pass an array that maps to the nequip "pbc" key
# type of data set, can be npz or ase
dataset: npz
# url to download the npz. optional
dataset_url: http://quantum-machine.org/gdml/data/npz/toluene_ccsd_t.zip
# path to data set file
dataset_file_name: ./benchmark_data/toluene_ccsd_t-train.npz
# maps each key in the npz file (left) to the NequIP default name (right)
key_mapping:
  # atomic species, integers
  z: atomic_numbers
  # total potential energies to train to
  E: total_energy
  # atomic forces to train to
  F: forces
  # raw atomic positions
  R: pos
# fields that are repeated across different examples
npz_fixed_field_keys:
  - atomic_numbers

# A list of chemical species found in the data. The NequIP atom types will be named after the chemical symbols and ordered by atomic number in ascending order.
# (In this case, NequIP's internal atom type 0 will be named H and type 1 will be named C.)
# Atoms in the input will be assigned NequIP atom types according to their atomic numbers.
chemical_symbols:
  - H
  - C

# Alternatively, you may explicitly specify which chemical species in the input will map to NequIP atom type 0, which to atom type 1, and so on.
# Other than providing an explicit order for the NequIP atom types, this option behaves the same as `chemical_symbols`
# chemical_symbol_to_type:
#   H: 0
#   C: 1

# Alternatively, if the dataset has type indices, you may give the names for the types in order:
# (this also sets the number of types)
# type_names:
#   - my_type
#   - atom
#   - thing

# As an alternative option to npz, you can also pass data as ASE Atoms objects
# This can often be easier to work with, simply make sure the ASE Atoms object
# has a calculator for which atoms.get_potential_energy() and atoms.get_forces() are defined
# dataset: ase
# dataset_file_name: xxx.xyz                                                       # need to be a format accepted by ase.io.read
# ase_args:                                                                        # any arguments needed by ase.io.read
#   format: extxyz

# If you want to use a different dataset for validation, you can specify
# the same types of options using a `validation_` prefix:
# validation_dataset: ase
# validation_dataset_file_name: xxx.xyz                                            # need to be a format accepted by ase.io.read

# -- logging --
# we recommend using wandb for logging
wandb: true
# project name used in wandb
wandb_project: toluene-example
# NOTE(review): presumably toggles wandb.watch (https://docs.wandb.ai/ref/python/watch) — confirm
wandb_watch: false

# # using tensorboard for logging
# tensorboard: true

# see https://docs.wandb.ai/ref/python/watch
# wandb_watch_kwargs:
#   log: all
#   log_freq: 1
#   log_graph: true

# the same as python logging, e.g. warning, info, debug, error. case insensitive
verbose: info
# batch frequency, how often to print training errors within the same epoch
log_batch_freq: 100
# epoch frequency, how often to print
log_epoch_freq: 1
# frequency to save the intermediate checkpoint. no saving of intermediate checkpoints when the value is not positive.
save_checkpoint_freq: -1
# frequency to save the intermediate ema checkpoint. no saving of intermediate checkpoints when the value is not positive.
save_ema_checkpoint_freq: -1

# training
# number of training data
n_train: 100
# number of validation data
n_val: 50
# alternatively, n_train and n_val can be set as percentages of the dataset size:
# n_train: 70%  # 70% of dataset
# n_val: 30%    # 30% of dataset (if validation_dataset not set), or 30% of validation_dataset (if set)

# learning rate, we found values between 0.01 and 0.005 to work best - this is often one of the most important hyperparameters to tune
learning_rate: 0.005
# batch size, we found it important to keep this small for most applications including forces (1-5); for energy-only training, higher batch sizes work better
batch_size: 5
# batch size for evaluating the model during validation. This does not affect the training results,
# but using the highest value possible (<=n_val) without running out of memory will speed up your training.
validation_batch_size: 10
# stop training after _ number of epochs; we set a very large number here, it won't take this long in practice and we will use early stopping instead
max_epochs: 100000
# can be random or sequential. if sequential, first n_train elements are training, next n_val are val, else random; usually random is the right choice
train_val_split: random
# If true, the data loader will shuffle the data, usually a good idea
shuffle: true
# metrics used for scheduling and saving best model. Options: `set`_`quantity`, where "set" can be either "train" or "validation",
# and "quantity" can be loss or anything that appears in the validation batch step header, such as f_mae, f_rmse, e_mae, e_rmse
metrics_key: validation_loss

# if true, use exponential moving average on weights for val/test; usually helps a lot with training, in particular for energy errors
use_ema: true
# ema weight, typically set to 0.99 or 0.999
ema_decay: 0.99
# whether to use number of updates when computing averages
ema_use_num_updates: true
# if True, report the validation error for just initialized model
report_init_validation: true

# early stopping based on metrics values.
# LR, wall and any keys printed in the log file can be used.
# The key can start with Training or validation. If not defined, the validation value will be used.

# stop early if a metric value stopped decreasing for n epochs
early_stopping_patiences:
  validation_loss: 50

# If delta is defined, a decrease smaller than delta will not be considered as a decrease
early_stopping_delta:
  validation_loss: 0.005

# If True, the minimum value recorded will not be updated when the decrease is smaller than delta
early_stopping_cumulative_delta: false

# stop early if a metric value is lower than the bound
early_stopping_lower_bounds:
  LR: 1.0e-5

# stop early if a metric value is higher than the bound
early_stopping_upper_bounds:
  cumulative_wall: 1.0e+100

# -- loss function --
# different weights to use in a weighted loss function
loss_coeffs:
  # if using PerAtomMSELoss, a default weight of 1:1 on each should work well
  forces: 1.0
  # [weight, loss name]
  total_energy:
    - 1.0
    - PerAtomMSELoss
# note that the ratio between force and energy loss matters for the training process. One may consider using 1:1 with the PerAtomMSELoss.
# If the energy loss still significantly dominates the loss function during the initial epochs, lowering the energy loss weight helps the training a lot.


# the default loss function is MSELoss; the name has to be exactly the same as those in torch.nn.
# the only supported targets are forces and total_energy

# here are some example of more ways to declare different types of loss functions, depending on your application:
# loss_coeffs:
#   total_energy: MSELoss
#
# loss_coeffs:
#   total_energy:
#   - 3.0
#   - MSELoss
#
# loss_coeffs:
#   total_energy:
#   - 1.0
#   - PerAtomMSELoss
#
# loss_coeffs:
#   forces:
#   - 1.0
#   - PerSpeciesL1Loss
#
# loss_coeffs: total_energy
#
# loss_coeffs:
#   total_energy:
#   - 3.0
#   - L1Loss
#   forces: 1.0

# You can schedule changes in the loss coefficients using a callback:
# In the "schedule" key each entry is a two-element list of:
#  - the 1-based epoch index at which to start the new loss coefficients
#  - the new loss coefficients as a dict
#
# start_of_epoch_callbacks:
#  - !!python/object:nequip.train.callbacks.loss_schedule.SimpleLossSchedule {"schedule": [[2, {"forces": 0.0, "total_energy": 1.0}]]}
#

# output metrics
# each entry is a list: the key, the metric ("rmse" or "mae"), and optionally a map of extra settings
metrics_components:
  - - forces                               # key
    - mae                                  # "rmse" or "mae"
  - - forces
    - rmse
  - - forces
    - mae
    - PerSpecies: true                     # if true, per species contribution is counted separately
      report_per_component: false          # if true, statistics on each component (i.e. fx, fy, fz) will be counted separately
  - - forces
    - rmse
    - PerSpecies: true
      report_per_component: false
  - - total_energy
    - mae
  - - total_energy
    - mae
    - PerAtom: true                        # if true, energy is normalized by the number of atoms

# optimizer, may be any optimizer defined in torch.optim
# the name `optimizer_name` is case sensitive
# IMPORTANT: for NequIP (not for Allegro), we find that in most cases AMSGrad strongly improves
# out-of-distribution generalization over Adam. We highly recommend trying both AMSGrad (by setting
# optimizer_amsgrad: true) and Adam (by setting optimizer_amsgrad: false)
optimizer_name: Adam
optimizer_amsgrad: true
# two-element tuple; the !!python/tuple tag is Python-specific and is rejected by safe YAML loaders
optimizer_betas: !!python/tuple [0.9, 0.999]
optimizer_eps: 1.0e-08
optimizer_weight_decay: 0

# gradient clipping using torch.nn.utils.clip_grad_norm_
# see https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html#torch.nn.utils.clip_grad_norm_
# setting to inf or null disables it
max_gradient_norm: null

# lr scheduler
# first: on-plateau, reduce lr by a factor of lr_scheduler_factor if metrics_key hasn't improved for lr_scheduler_patience epochs
# you can also set other options of the underlying PyTorch scheduler, for example lr_scheduler_threshold
lr_scheduler_name: ReduceLROnPlateau
lr_scheduler_patience: 100
lr_scheduler_factor: 0.5

# second, cosine annealing with warm restart
# lr_scheduler_name: CosineAnnealingWarmRestarts
# lr_scheduler_T_0: 10000
# lr_scheduler_T_mult: 2
# lr_scheduler_eta_min: 0
# lr_scheduler_last_epoch: -1

# we provide a series of options to shift and scale the data
# these are for advanced use and usually the defaults work very well
# the default is to scale the energies and forces by scaling them by the force standard deviation and to shift the energy by its mean
# in certain cases, it can be useful to have a trainable shift/scale and to also have species-dependent shifts/scales for each atom

# whether the scales are trainable. Defaults to False. Optional
per_species_rescale_scales_trainable: false
# whether the shifts are trainable. Defaults to False. Optional
per_species_rescale_shifts_trainable: false

# initial atomic energy shift for each species. defaults to the mean of per atom energy. Optional
# the value can be a constant float value, an array for each species, or a string
# if numbers are explicitly provided, they must be in the same energy units as the training data
# string options include:
# *  "dataset_per_atom_total_energy_mean", which computes the per atom average
# *  "dataset_per_species_total_energy_mean", which automatically computes the per atom energy mean using a GP model
per_species_rescale_shifts: dataset_per_atom_total_energy_mean

# initial atomic energy scale for each species. Optional.
# the value can be a constant float value, an array for each species, or a string
# if numbers are explicitly provided, they must be in the same energy units as the training data
# string options include:
# *  "dataset_forces_absmax", which computes the dataset maximum force component magnitude
# *  "dataset_per_atom_total_energy_std", which computes the per atom energy std
# *  "dataset_per_species_total_energy_std", which uses the GP model uncertainty
# *  "dataset_per_species_forces_rms", which computes the force rms for each species
# If not provided, defaults to null.
per_species_rescale_scales: null

# per_species_rescale_kwargs: 
#   total_energy: 
#     alpha: 0.001
#     max_iteration: 20
#     stride: 100
# keywords for ridge regression decomposition of per species energy. Optional. Defaults to 0.001. The value should be in the range of 1e-3 to 1e-2

# global energy shift and scale
# When "dataset_total_energy_mean", the mean energy of the dataset. When null, disables the global shift. When a number, used directly.
# Warning: if this value is not null, the model is no longer size extensive
global_rescale_shift: null

# global energy scale. When "dataset_forces_rms", the RMS of force components in the dataset.
# When "dataset_forces_absmax", the maximum force component magnitude in the dataset.
# When "dataset_total_energy_std", the stdev of energies in the dataset.
# When null, disables the global scale. When a number, used directly.
# If not provided, defaults to either dataset_forces_rms or dataset_total_energy_std, depending on whether forces are being trained.
global_rescale_scale: dataset_forces_rms

# whether the shift of the final global energy rescaling should be trainable
global_rescale_shift_trainable: false

# whether the scale of the final global energy rescaling should be trainable
global_rescale_scale_trainable: false

# # full block needed for per-species rescale
# global_rescale_shift: null
# global_rescale_shift_trainable: false
# global_rescale_scale: dataset_forces_rms
# global_rescale_scale_trainable: false
# per_species_rescale_shifts_trainable: false
# per_species_rescale_scales_trainable: true
# per_species_rescale_shifts: dataset_per_species_total_energy_mean
# per_species_rescale_scales: dataset_per_species_forces_rms

# # full block needed for global rescale
# global_rescale_shift: dataset_total_energy_mean
# global_rescale_shift_trainable: false
# global_rescale_scale: dataset_forces_rms
# global_rescale_scale_trainable: false
# per_species_rescale_shifts_trainable: false
# per_species_rescale_scales_trainable: false
# per_species_rescale_shifts: null
# per_species_rescale_scales: null

