# Fallback parameters for the config file. These are overwritten by the config file.
# Config to extend. NOTE(review): meaning inferred from the key name only; confirm
# against the config loader. Explicitly null: nothing is extended by default.
extends: null
# Model settings
# NOTE: every key in this section is a top-level key. `model` is a scalar value,
# not a parent mapping, so the keys that follow must stay unindented.
# Model architecture name. One of: gns, segnn, egnn. Explicitly null in these defaults.
model: null
# Length of the position input sequence
input_seq_length: 6
# Number of message passing steps
num_mp_steps: 10
# Number of MLP layers
num_mlp_layers: 2
# Hidden dimension
latent_dim: 128
# Load checkpointed model from this directory. Explicitly null in these defaults.
model_dir: null
# SEGNN only parameters
# Steerable attributes level
lmax_attributes: 1
# Level of the hidden layer
lmax_hidden: 1
# SEGNN normalization. One of: instance, batch, none.
# The plain scalar `none` is the string "none", not a YAML null.
segnn_norm: none
# SEGNN velocity aggregation. One of: avg, last
velocity_aggregate: avg

# Optimization settings
# Maximum number of steps
step_max: 500000
# Batch size
batch_size: 1
# Starting learning rate.
# Keep the `1.e-4` notation (dot plus signed exponent): YAML 1.1 loaders such as
# PyYAML only resolve that form as a float; `1e-4` would be loaded as a string.
lr_start: 1.e-4
# Final learning rate after decay
lr_final: 1.e-6
# Rate of learning rate decay
lr_decay_rate: 0.1
# Number of steps for the learning rate to decay.
# NOTE(review): this notation resolves to a float (100000.0), not an integer;
# confirm the consumer accepts a float step count.
lr_decay_steps: 1.e+5
# Standard deviation for the additive noise
noise_std: 0.0003
# Whether to use magnitudes or not
magnitude_features: false
# Whether to normalize inputs and outputs with the same value in x, y, and z.
isotropic_norm: false
# Parameters related to the push-forward trick.
# The three lists below each have four entries; presumably they are index-aligned,
# one entry per unroll stage (TODO confirm against the consumer).
pushforward:
  # At which training step to introduce next unroll stage
  steps: [-1, 200000, 300000, 400000]
  # For how many steps to unroll
  unrolls: [0, 1, 2, 3]
  # Which probability ratio to keep between the unrolls
  probs: [18, 2, 1, 1]

# Loss settings
# Loss weight for position, acceleration, and velocity components.
# Only the `acc` weight is set in these defaults.
loss_weight:
  acc: 1.0

# Run settings
# Run mode. One of: train, infer, all
mode: all
# Dataset directory. Explicitly null in these defaults.
data_dir: null
# Number of rollout steps. If "-1", then defaults to sequence_length - input_seq_length.
# n_rollout_steps must be <= ground truth length. For extrapolation use n_extrap_steps.
n_rollout_steps: 20
# Number of evaluation trajectories. "-1" for all available
eval_n_trajs: 50
# Number of extrapolation steps
n_extrap_steps: 0
# Whether to use test or validation split
# (presumably true = test split, false = validation split; confirm against the consumer)
test: false
# Seed
seed: 0
# Cuda device. "-1" for cpu
gpu: 0
# GPU memory allocation https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
xla_mem_fraction: 0.75
# Double precision everywhere other than the ML model
f64: true
# Neighbour list backend. One of: jaxmd_vmap, jaxmd_scan, matscipy
neighbor_list_backend: jaxmd_vmap
# Neighbour list capacity multiplier
neighbor_list_multiplier: 1.25

# Logging settings
# Use wandb for logging
wandb: false
# wandb project.
# NOTE(review): the default is a boolean placeholder, kept unchanged for compatibility;
# presumably a project-name string is expected when wandb is enabled. Confirm.
wandb_project: false
# Change this with your own entity
wandb_entity: lagrangebench
# Number of steps between training logging
log_steps: 1000
# Number of steps between evaluation
eval_steps: 10000
# Checkpoint directory
ckp_dir: ckp
# Rollout/metrics directory. Explicitly null in these defaults.
rollout_dir: null
# Rollout storage format. One of: vtk, pkl, none.
# The plain scalar `none` is the string "none", not a YAML null.
out_type: none
# List of metrics. Any of: mse, mae, sinkhorn, e_kin
metrics:
  - mse
# Stride applied when computing the metrics.
# NOTE(review): exact semantics are not visible in this file; confirm against the consumer.
metrics_stride: 10

# Inference params (valid/test)
# These keys mirror `metrics`, `metrics_stride`, `out_type` and `eval_n_trajs` with an
# `_infer` suffix; presumably they take the place of those values when running
# inference (TODO confirm against the consumer).
# List of metrics for inference
metrics_infer:
  - mse
  - sinkhorn
  - e_kin
# Metrics stride for inference
metrics_stride_infer: 1
# Rollout storage format for inference
out_type_infer: pkl
# Number of evaluation trajectories for inference.
# Presumably "-1" means all available, as documented for `eval_n_trajs` (confirm).
eval_n_trajs_infer: -1
