# Note: this config was run with nequip 0.13.0 and allegro 0.7.1.
# Validate first so that validation metrics exist before any training step.
run:
  - val
  - train

# Shared values pulled into the rest of the config via ${...} interpolation.
seed: 1
cutoff_radius: 6  # Å; feeds both the neighbour-list transform and the model r_max
monitored_metric: val0_epoch/weighted_sum  # checkpointing, early stopping and LR schedule

# Loss-term weights (energy, forces, stress) and the stress Huber delta,
# consumed by training_module.loss.
E_coeff: 1
F_coeff: 5
S_coeff: 0.1
S_Huber_delta: 0.1

# Neighbour statistics for the 6 Å cutoff.
# num_neighbours_mean_squared is the square of the mean
# (61.99263...^2 ≈ 3843.086), not the mean of squared neighbour counts.
# Only num_neighbours_mean_squared is referenced in this file (as the model's
# avg_num_neighbors); num_neighbours_mean is kept for reference.
num_neighbours_mean: 61.99263166031629
num_neighbours_mean_squared: 3843.08638017165

data:
  _target_: nequip.data.datamodule.NequIPDataModule
  seed: ${seed}

  # Training split (LMDB). Transforms: build neighbour lists within
  # cutoff_radius, then map chemical species onto the model's atom types.
  train_dataset:
    _target_: nequip.data.dataset.NequIPLMDBDataset
    file_path: /path/to/train.lmdb
    transforms:
      - _target_: nequip.data.transforms.NeighborListTransform
        r_max: ${cutoff_radius}
      - _target_: nequip.data.transforms.ChemicalSpeciesToAtomTypeMapper
        chemical_symbols: ${chemical_symbols}

  # Validation split; reuses the training split's transform list verbatim.
  val_dataset:
    _target_: nequip.data.dataset.NequIPLMDBDataset
    file_path: /path/to/val.lmdb
    transforms: ${data.train_dataset.transforms}

  train_dataloader:
    _target_: torch.utils.data.DataLoader
    batch_size: 80  # usually overwritten in the job script
    num_workers: 7  # should equal the number of CPUs per GPU (process)

  # Validation loader mirrors the training loader's batch size and workers.
  val_dataloader:
    _target_: torch.utils.data.DataLoader
    batch_size: ${data.train_dataloader.batch_size}
    num_workers: ${data.train_dataloader.num_workers}

  # Data statistics are skipped for speed: per-type energy shifts/scales and
  # the neighbour count are supplied as literals elsewhere in this config.
  # stats_manager:
  #   _target_: nequip.data.CommonDataStatisticsManager
  #   dataloader_kwargs:
  #     batch_size: ${data.train_dataloader.batch_size}
  #     num_workers: ${data.train_dataloader.num_workers}
  #   type_names: ${model_type_names}

trainer:
  _target_: lightning.Trainer
  precision: bf16-mixed  # automatic mixed precision (bfloat16)
  gradient_clip_val: 0.015
  strategy:
    _target_: nequip.train.SimpleDDPStrategy
  accelerator: gpu
  enable_checkpointing: true
  # Wall-clock limit in Lightning's 'DD:HH:MM:SS' form (30 days). Must stay
  # quoted: YAML 1.1 parsers read unquoted colon-separated digits as a
  # base-60 integer.
  max_time: '30:00:00:00'
  max_epochs: 2000
  check_val_every_n_epoch: 1
  log_every_n_steps: 1000
  callbacks:
    # save_top_k: -1 keeps every checkpoint (not just the best k), plus
    # last.ckpt; the filename embeds the monitored metric value and epoch.
    - _target_: lightning.pytorch.callbacks.ModelCheckpoint
      monitor: ${monitored_metric}
      dirpath: ${hydra:runtime.output_dir}
      filename: 'weighted_metric_{${monitored_metric}:.4f}-{epoch}'
      auto_insert_metric_name: false
      save_last: true
      verbose: true
      save_top_k: -1

    # NOTE(review): this patience (200 epochs) is shorter than the
    # ReduceLROnPlateau patience (250) in training_module.lr_scheduler, so a
    # plateau long enough to trigger early stopping ends training before the
    # scheduler would halve the LR on that plateau — confirm this is intended.
    - _target_: lightning.pytorch.callbacks.EarlyStopping
      verbose: true
      log_rank_zero_only: true
      monitor: ${monitored_metric}
      # Written with a '.' (1.0e-5, not 1e-5): YAML 1.1 loaders such as
      # PyYAML only resolve floats containing a '.', and would read 1e-5 as
      # a string. OmegaConf parses both forms to the same float.
      min_delta: 1.0e-5  # how much to be considered a "change"
      patience: 200  # how many instances of "no change" before stopping

    - _target_: lightning.pytorch.callbacks.LearningRateMonitor
      logging_interval: epoch

    - _target_: nequip.train.callbacks.LossCoefficientMonitor
      frequency: 1
      interval: epoch

  logger:
    _target_: lightning.pytorch.loggers.wandb.WandbLogger
    save_dir: ${hydra:runtime.output_dir}
    project: Allegro_MPTrj
    name: MPTrj

training_module:
  _target_: nequip.train.EMALightningModule

  # Decay factor for the exponential moving average kept by EMALightningModule.
  ema_decay: 0.995

  # Training loss: one Huber term per target, each weighted by the top-level
  # E_coeff / F_coeff / S_coeff values.
  loss:
    _target_: nequip.train.MetricsManager
    metrics:
      # Per-atom energy: total_energy wrapped in PerAtomModifier.
      - name: peratom_E_Huber
        field:
          _target_: nequip.data.PerAtomModifier
          field: total_energy
        coeff: ${E_coeff}
        metric:
          _target_: nequip.train.HuberLoss
          delta: 0.01
      # Forces with a stratified Huber loss. The delta_dict keys look like
      # force-magnitude thresholds, with a smaller delta above each one —
      # TODO confirm semantics/units against the nequip documentation.
      - name: force_strat_E_huber
        field: forces
        coeff: ${F_coeff}
        metric:
          _target_: nequip.train.metrics.StratifiedHuberForceLoss
          delta_dict:
            0: 0.01
            100: 0.007
            200: 0.004
            300: 0.001
      - name: stress_E_Huber
        field: stress
        coeff: ${S_coeff}
        metric:
          _target_: nequip.train.HuberLoss
          delta: ${S_Huber_delta}

  # Metric weights; their weighted sum is presumably what gets logged as
  # val0_epoch/weighted_sum (the top-level monitored_metric).
  val_metrics:
    _target_: nequip.train.EnergyForceStressMetrics
    coeffs:
      per_atom_energy_rmse: 1.0
      forces_rmse: 1.0
      stress_rmse: 1.0
      total_energy_mae: 1.0
      per_atom_energy_mae: 5.0
      forces_mae: 5.0
      stress_mae: 2.5
  # Training metrics reuse the validation metric definition.
  train_metrics: ${training_module.val_metrics}

  optimizer:
    _target_: torch.optim.AdamW
    lr: 0.01
    # Written with a '.' so YAML 1.1 loaders (e.g. PyYAML) resolve a float
    # rather than the string "1e-3"; OmegaConf parses both identically.
    weight_decay: 1.0e-3
    amsgrad: false

  lr_scheduler:
    scheduler:
      _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
      factor: 0.5
      # NOTE(review): longer than the EarlyStopping patience (200) under
      # trainer.callbacks, so early stopping can end a plateau before this
      # scheduler reduces the LR on it — confirm this is intended.
      patience: 250
      # threshold: 1.0e-4  # PyTorch default, left implicit
      min_lr: 1.0e-6  # '.' keeps this a float under YAML 1.1 loaders too
    monitor: ${monitored_metric}
    interval: epoch
    frequency: 1

  # Allegro model, wrapped by nequip.model.modify to apply the
  # enable_CuEquivarianceContracter modifier.
  model:
    _target_: nequip.model.modify
    modifiers:
      - modifier: enable_CuEquivarianceContracter
    model:
      _target_: allegro.model.allegro_models.AllegroModel
      compile_mode: compile

      num_scalar_features: 256
      num_tensor_features: 96
      tp_path_channel_coupling: true

      allegro_mlp_hidden_layers_depth: 3
      allegro_mlp_hidden_layers_width: 1024
      readout_mlp_hidden_layers_depth: 1
      readout_mlp_hidden_layers_width: 128

      seed: ${seed}
      type_names: ${model_type_names}
      model_dtype: float32
      #default_dtype: float64  # default

      r_max: ${cutoff_radius}

      num_layers: 5
      l_max: 3
      parity: false

      radial_chemical_embed: # radial network basis
        _target_: allegro.nn.TwoBodyBesselScalarEmbed
        num_bessels: 12
        bessel_trainable: true
        polynomial_cutoff_p: 8

      radial_chemical_embed_dim: 512
      scalar_embed_mlp_hidden_layers_width: 512

      # Energy shifts stay fixed at the per_type_energy_shifts values; the
      # per-type energy scales are learnable.
      per_type_energy_shifts: ${per_type_energy_shifts}
      per_type_energy_shifts_trainable: false
      per_type_energy_scales: ${per_type_energy_scales}
      per_type_energy_scales_trainable: true
      # NOTE(review): this receives num_neighbours_mean_squared (≈ 61.99², the
      # square of the mean neighbour count at 6 Å), not num_neighbours_mean.
      # If Allegro normalises neighbour sums by avg_num_neighbors**-0.5, this
      # amounts to dividing by the mean count rather than its square root.
      # Presumably deliberate; confirm before changing.
      avg_num_neighbors: ${num_neighbours_mean_squared}

# Global runtime options. allow_tf32 presumably lets PyTorch use TensorFloat-32
# matmuls on GPUs that support them — confirm against the nequip docs.
global_options: {allow_tf32: true}

# Per-type energy shifts, presumably isolated-atom reference energies (eV) —
# confirm against the MP_ephemera source data.
# 'N' (nitrogen) and 'Y' (yttrium) are quoted: the YAML 1.1 bool type includes
# y|Y|n|N, so strict 1.1 parsers (e.g. go-yaml v2) would read the bare keys as
# booleans. Quoting does not change how OmegaConf/PyYAML load them.
per_type_energy_shifts: # MP_ephemera_Isolated_E0s
  H: -1.1176
  He: -0.0005
  Li: -0.2974
  Be: -0.0181
  B: -0.4447
  C: -1.3865
  'N': -3.1256
  O: -1.9067
  F: -0.7674
  Ne: -0.0121
  Na: -0.2285
  Mg: -0.0958
  Al: -0.3122
  Si: -0.8689
  P: -1.8879
  S: -1.0746
  Cl: -0.3714
  Ar: -0.0502
  K: -0.2277
  Ca: -0.0927
  Sc: -2.2127
  Ti: -2.6397
  V: -3.7438
  Cr: -5.6018
  Mn: -5.3235
  Fe: -3.5955
  Co: -2.1496
  Ni: -1.0536
  Cu: -0.6027
  Zn: -0.1645
  Ga: -0.4043
  Ge: -0.8916
  As: -1.6834
  Se: -0.8716
  Br: -0.2651
  Kr: -0.0331
  Rb: -0.1879
  Sr: -0.068
  'Y': -2.2868
  Zr: -2.3603
  Nb: -3.1513
  Mo: -4.6011
  Tc: -3.5438
  Ru: -1.6595
  Rh: -1.6479
  Pd: -1.4776
  Ag: -0.3388
  Cd: -0.1672
  In: -0.4087
  Sn: -0.8167
  Sb: -1.4107
  Te: -0.7239
  I: -0.1703
  Xe: -0.0097
  Cs: -0.1369
  Ba: -0.0344
  La: -0.8455
  Ce: -1.3876
  Pr: -0.5491
  Nd: -0.5186
  Pm: -0.4895
  Sm: -0.4683
  Eu: -8.3662
  Gd: -10.4088
  Tb: -0.3982
  Dy: -0.3886
  Ho: -0.3834
  Er: -0.3857
  Tm: -0.3168
  Yb: -0.064
  Lu: -0.3808
  Hf: -3.527
  Ta: -3.7421
  W: -4.6555
  Re: -3.4276
  Os: -2.8979
  Ir: -1.1789
  Pt: -0.5638
  Au: -0.2872
  Hg: -0.1235
  Tl: -0.3606
  Pb: -0.7674
  Bi: -1.326
  Ac: -0.3866
  Th: -1.1045
  Pa: -2.553
  U: -4.9889
  Np: -7.7017
  Pu: -10.8084

# Per-type energy scales, derived from MPTrj per-type force RMS values.
# 'N' and 'Y' are quoted because the YAML 1.1 bool type includes y|Y|n|N
# (strict 1.1 parsers such as go-yaml v2 would read the bare keys as
# booleans); OmegaConf/PyYAML load the quoted keys identically.
# NOTE(review): Ne and Ar start at 0.0. If these scales multiply the per-atom
# energy output (as the name suggests), those species contribute nothing until
# the (trainable) scale moves away from zero — confirm this is intended.
per_type_energy_scales: # MPTrj per-type forces RMS sqrt
  H: 1.2479
  He: 0.0284
  Li: 0.2406
  Be: 0.4126
  B: 0.7826
  C: 1.599
  'N': 1.5753
  O: 0.8498
  F: 0.6083
  Ne: 0.0
  Na: 0.1963
  Mg: 0.1935
  Al: 0.4294
  Si: 0.7386
  P: 1.2352
  S: 0.6816
  Cl: 0.4024
  Ar: 0.0
  K: 0.4444
  Ca: 0.3384
  Sc: 0.3674
  Ti: 0.5421
  V: 0.8662
  Cr: 0.6028
  Mn: 0.3634
  Fe: 0.3259
  Co: 0.373
  Ni: 0.2487
  Cu: 0.2266
  Zn: 0.2833
  Ga: 0.6435
  Ge: 0.3619
  As: 0.6616
  Se: 1.0083
  Br: 0.3287
  Kr: 0.2407
  Rb: 0.331
  Sr: 0.265
  'Y': 0.2448
  Zr: 0.8569
  Nb: 0.8551
  Mo: 0.9837
  Tc: 0.6988
  Ru: 0.8013
  Rh: 0.3093
  Pd: 0.2268
  Ag: 0.4571
  Cd: 0.3263
  In: 0.3278
  Sn: 0.3482
  Sb: 0.5747
  Te: 0.6634
  I: 0.4617
  Xe: 0.7796
  Cs: 0.1202
  Ba: 0.2242
  La: 0.341
  Ce: 0.2357
  Pr: 0.4187
  Nd: 0.2967
  Pm: 0.2979
  Sm: 0.3289
  Eu: 0.1654
  Gd: 0.1971
  Tb: 0.2603
  Dy: 0.3336
  Ho: 0.2375
  Er: 0.2785
  Tm: 0.3185
  Yb: 0.4831
  Lu: 0.2021
  Hf: 0.5809
  Ta: 0.6663
  W: 0.8977
  Re: 0.6913
  Os: 0.8224
  Ir: 0.3624
  Pt: 0.3336
  Au: 0.1913
  Hg: 0.1713
  Tl: 0.6944
  Pb: 0.1735
  Bi: 0.4392
  Ac: 0.0976
  Th: 0.345
  Pa: 0.0914
  U: 0.5779
  Np: 0.271
  Pu: 0.3288

# Supported elements, in the exact order used for model_type_names (which is
# this same list), so the order presumably fixes the type indices — do not
# reorder. One line per periodic-table row (lanthanides split out); Po, At,
# Rn, Fr and Ra are not included.
# deno-fmt-ignore
chemical_symbols:
  [
    'H', 'He',
    'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne',
    'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar',
    'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn',
    'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr',
    'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd',
    'In', 'Sn', 'Sb', 'Te', 'I', 'Xe',
    'Cs', 'Ba',
    'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu',
    'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg',
    'Tl', 'Pb', 'Bi',
    'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu'
  ]
model_type_names: ${chemical_symbols}
