# IMPORTANT: READ THIS

# This is a full yaml file with all nequip options.
# It is primarily intended to serve as documentation/reference for all options.
# For a simpler yaml file containing all necessary features to get you started, we strongly recommend to start with configs/example.yaml

# Two folders will be used during the training: 'root'/process and 'root'/'run_name'
# run_name contains logfiles and saved models
# process contains processed data sets
# if 'root'/'run_name' exists, 'root'/'run_name'_'year'-'month'-'day'-'hour'-'min'-'s' will be used instead.
# output root; 'root'/process and 'root'/'run_name' are created beneath it
root: results/from_BS_TB_combined_joint
# name of this run; its folder holds the log files and saved models
run_name: wp-nn-l0-L1-vanilla-Ec-BS-0.3

# model seed
seed: 123
# data set seed
dataset_seed: 456

# set true if a restarted run should append to the previous log file
append: true

# numerical precision; see https://arxiv.org/abs/2304.10061 for a discussion
default_dtype: float64
model_dtype: float32
# consider setting to false if you plan to mix training/inference over any
# devices that are not NVIDIA Ampere or later
allow_tf32: true

# == network ==

# `model_builders` defines a series of functions that will be called to construct the model
# each model builder has the opportunity to update the model, the config, or both
# model builders from other packages are allowed (see mir-group/allegro for an example); those from `nequip.model` don't require a prefix
# these are the default model builders:
model_builders:
  # update the config with all the irreps for the network if using the
  # simplified `l_max` / `num_features` / `parity` syntax
  - SimpleIrrepsConfig
  # build a full DOS model
  - NonlocalNN.model.JxDOSModelWeightShared
  # NOTE(review): presumably the builder that consumes `global_rescale_shift` /
  # `global_rescale_scale` from this file — confirm against NonlocalNN.model
  - NonlocalNN.model.RescaleMultiplePredictionsNonnegative

# NOTE(review): presumably a list of atom-type indices to mask (type 1 is H in
# `chemical_symbol_to_type`) — confirm against the NonlocalNN model builders
mask_atoms: [1]
# output size of the network, i.e. the number of DOS points
# NOTE(review): this comment previously said 200 while the configured value is
# 100 — confirm which one matches the DOS grid of the dataset
target_size: 100

# cutoff radius in length units, here Angstrom; this is an important
# hyperparameter to scan
r_max: 5.0
# number of interaction blocks, we find 3-5 to work best
num_layers: 4

# the maximum irrep order (rotation order) for the network's features;
# l=1 is a good default, l=2 is more accurate but slower
l_max: 0
# whether to include features with odd mirror parity; often turning parity off
# gives equally good results but faster networks, so do consider this
parity: false
# the multiplicity of the features; 32 is a good default for an accurate
# network: go larger to be more accurate, go lower to be faster
num_features: 1100

# alternatively, the irreps of the features in various parts of the network can be specified directly:
# the following options use e3nn irreps notation
# either these four options, or the above three options, should be provided--- they cannot be mixed.
# chemical_embedding_irreps_out: 32x0e                                              # irreps for the chemical embedding of species
# feature_irreps_hidden: 32x0o + 32x0e + 32x1o + 32x1e                              # irreps used for hidden features, here we go up to lmax=1, with even and odd parities; for more accurate but slower networks, use l=2 or higher, smaller number of features is faster
# irreps_edge_sh: 0e + 1o                                                           # irreps of the spherical harmonics used for edges. If a single integer, indicates the full SH up to L_max=that_integer
# conv_to_output_hidden_irreps_out: 16x0e                                           # irreps used in hidden layer of output block


# may be 'gate' or 'norm', 'gate' is recommended
nonlinearity_type: gate
# set true to make interaction block a resnet-style update; the resnet update
# will only be applied when the input and output irreps of the layer are the same
resnet: true

# scalar nonlinearities to use - available options are silu, ssp (shifted
# softplus), tanh, and abs.
# Different nonlinearities are specified for e (even) and o (odd) parity;
# note that only tanh and abs are correct for o (odd parity).
# silu typically works best for even
nonlinearity_scalars: {e: silu, o: tanh}

# same e/o layout as nonlinearity_scalars
nonlinearity_gates: {e: silu, o: tanh}

# radial network basis
# number of basis functions used in the radial basis, 8 usually works best
num_basis: 8
# set true to train the bessel weights
BesselBasis_trainable: true
# p-exponent used in polynomial cutoff function; smaller p corresponds to
# stronger decay with distance
PolynomialCutoff_p: 6

# radial network
# number of radial layers, usually 1-3 works best, smaller is faster
invariant_layers: 2
# number of hidden neurons in radial function, smaller is faster
invariant_neurons: 64
# number of neighbors to divide by; null => no normalization, auto computes it
# based on dataset
avg_num_neighbors: auto
# use self-connection or not, usually gives big improvement
use_sc: true


# training data: an extxyz file loaded with the `ase` dataset type
dataset: ase
dataset_file_name: Data/data_1.3/combined_train_dataset.extxyz
ase_args: {format: extxyz}
# keys to include from the data file
include_keys: [dos_per_atom, Jx_per_atom, E]
# file key (left) -> name used in this config (right); `mean_dos` and `mean_Jx`
# are the names referenced by `loss_coeffs` and `metrics_components`
key_mapping:
  dos_per_atom: mean_dos
  Jx_per_atom: mean_Jx
  E: total_energy
# chemical symbol -> integer atom-type index
chemical_symbol_to_type:
  H: 1
  Si: 0
  Ge: 2


# validation data comes from its own file instead of the training file
validation_dataset: ase
validation_dataset_file_name: Data/data_1.3/combined_val_dataset.extxyz

# logging
# we recommend using wandb for logging
wandb: true
# project name used in wandb
wandb_project: neurips-runs
wandb_watch: false

# # using tensorboard for logging
# tensorboard: true

# see https://docs.wandb.ai/ref/python/watch
# wandb_watch_kwargs:
#   log: all
#   log_freq: 1
#   log_graph: true

# the same as python logging, e.g. warning, info, debug, error; case insensitive
verbose: info
# batch frequency, how often to print training errors within the same epoch
log_batch_freq: 100
# epoch frequency, how often to print
log_epoch_freq: 1
# frequency to save the intermediate checkpoint; no intermediate checkpoints
# are saved when the value is not positive
save_checkpoint_freq: -1
# frequency to save the intermediate ema checkpoint; no intermediate
# checkpoints are saved when the value is not positive
save_ema_checkpoint_freq: -1

# training
# number of training data (quoted so the percentage is unambiguously a string)
n_train: "100%"
# number of validation data
n_val: "100%"

# n_train and n_val accept either an absolute number of frames or a percentage of the dataset size, e.g.:
# n_train: 70%  # 70% of dataset
# n_val: 30%    # 30% of dataset (if validation_dataset not set), or 30% of validation_dataset (if set)
# learning rate, we found values between 0.01 and 0.005 to work best - this is
# often one of the most important hyperparameters to tune
learning_rate: 0.005
# batch size, we found it important to keep this small for most applications
# including forces (1-5); for energy-only training, higher batch sizes work better
batch_size: 5
# batch size for evaluating the model during validation. This does not affect
# the training results, but using the highest value possible (<=n_val) without
# running out of memory will speed up your training.
validation_batch_size: 10
# stop training after this number of epochs; the early-stopping settings in
# this file may end the run sooner
max_epochs: 500
# can be random or sequential. if sequential, first n_train elements are
# training, next n_val are val, else random; usually random is the right choice
train_val_split: random
# if true, the data loader will shuffle the data, usually a good idea
shuffle: true
# metrics used for scheduling and saving best model.
# Options: `set`_`quantity`, where `set` can be either "train" or "validation"
# and `quantity` can be loss or anything that appears in the validation batch
# step header, such as f_mae, f_rmse, e_mae, e_rmse
metrics_key: validation_loss
# if true, use exponential moving average on weights for val/test; usually
# helps a lot with training, in particular for energy errors
use_ema: true
# ema weight, typically set to 0.99 or 0.999
ema_decay: 0.99
# whether to use number of updates when computing averages
ema_use_num_updates: true
# if true, report the validation error for just initialized model
report_init_validation: true

# early stopping based on metrics values.
# LR, wall and any keys printed in the log file can be used.
# The key can start with Training or validation. If not defined, the
# validation value will be used.

# stop early if a metric value stopped decreasing for n epochs
early_stopping_patiences:
  validation_loss: 100

# if delta is defined, a decrease smaller than delta will not be considered
# as a decrease
early_stopping_delta:
  validation_loss: 0.00005

# if true, the minimum value recorded will not be updated when the decrease is
# smaller than delta
early_stopping_cumulative_delta: false

# stop early if a metric value is lower than the bound
early_stopping_lower_bounds:
  LR: 1.0e-5

# stop early if a metric value is higher than the bound
early_stopping_upper_bounds:
  cumulative_wall: 1.0e+100

## loss function
#loss_coeffs:
#  total_energy: MSELoss                                                                      # different weights to use in a weighted loss functions

# weighted loss: each target lists its weight followed by the loss function
loss_coeffs:
  mean_dos:
    - 1.5
    - L1Loss
  mean_Jx:
    - 0.5
    - L1Loss


# output metrics: one [quantity, metric] pair per reported value
metrics_components:
  - [mean_dos, mae]
  - [mean_dos, rmse]
  - [mean_Jx, mae]
  - [mean_Jx, rmse]

# optimizer, may be any optimizer defined in torch.optim
# the name `optimizer_name` is case sensitive
# IMPORTANT: for NequIP (not for Allegro), we find that in most cases AMSGrad strongly improves
# out-of-distribution generalization over Adam. We highly recommend trying both AMSGrad (by setting
# optimizer_amsgrad: true) and Adam (by setting optimizer_amsgrad: false)
optimizer_name: Adam
optimizer_amsgrad: true
# NOTE(review): `!!python/tuple` is a PyYAML-specific tag that safe loaders reject, so this
# file presumably has to be read with a full (non-safe) loader — confirm before changing the loader
optimizer_betas: !!python/tuple
  - 0.9
  - 0.999
optimizer_eps: 1.0e-08
optimizer_weight_decay: 0

# gradient clipping via torch.nn.utils.clip_grad_norm_, see
# https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html#torch.nn.utils.clip_grad_norm_
# inf or null disables clipping
max_gradient_norm: null

# lr scheduler
# on-plateau: reduce lr by a factor of lr_scheduler_factor if metrics_key
# hasn't improved for lr_scheduler_patience epochs
# you can also set other options of the underlying PyTorch scheduler, for
# example lr_scheduler_threshold
# NOTE(review): lr_scheduler_patience (100) equals
# early_stopping_patiences.validation_loss (100), so early stopping may end the
# run before the learning rate is ever reduced — confirm this is intended
lr_scheduler_name: ReduceLROnPlateau
lr_scheduler_patience: 100
lr_scheduler_factor: 0.8

# global shift, given as a list; here the entries name the dataset mean of
# `mean_dos` and of `mean_Jx` (presumably one entry per predicted quantity —
# confirm against RescaleMultiplePredictionsNonnegative).
# The remaining notes are inherited from the NequIP energy template;
# NOTE(review): confirm they also hold for this multi-prediction rescaling:
# When None, disables the global shift. When a number, used directly.
# Warning: if this value is not None, the model is no longer size extensive
global_rescale_shift:
  - dataset_mean_dos_mean
  - dataset_mean_Jx_mean

# global scale, given as a list; here the entries name the dataset standard
# deviation of `mean_dos` and of `mean_Jx` (presumably one entry per predicted
# quantity — confirm against RescaleMultiplePredictionsNonnegative).
# The remaining notes are inherited from the NequIP energy/force template;
# NOTE(review): confirm they also hold for this multi-prediction rescaling:
# When "dataset_force_rms", the RMS of force components in the dataset.
# When "dataset_forces_absmax", the maximum force component magnitude in the dataset.
# When "dataset_total_energy_std", the stdev of energies in the dataset.
# When null, disables the global scale. When a number, used directly.
# If not provided, defaults to either dataset_force_rms or dataset_total_energy_std, depending on whether forces are being trained.
global_rescale_scale:
  - dataset_mean_dos_std
  - dataset_mean_Jx_std

