# Note: this config was run with NequIP 0.13.0.
# Validation is run before training so baseline validation metrics get logged.
run:
  - val
  - train

# Loss-term weights and the stress Huber delta; these are interpolated into
# training_module.loss.
E_coeff: 1  # weight of the per-atom energy Huber term
F_coeff: 5  # weight of the stratified force Huber term
S_coeff: 0.1  # weight of the stress Huber term
S_Huber_delta: 0.1  # delta of the stress Huber loss

seed: 42  # shared by the data module and the model
cutoff_radius: 6.0  # used as r_max for both the neighbor-list transform and the model
num_neighbours_mean: 62.012104256525795  # precomputed value fed to the model's avg_num_neighbors
# Validation weighted-sum metric; drives checkpointing, early stopping and
# ReduceLROnPlateau scheduling.
monitored_metric: val0_epoch/weighted_sum

# Per-sample transforms reused by every train/val dataset below: build the
# neighbor list at the model cutoff, then map chemical species to model atom
# types using chemical_symbols.
data_transforms:
  - _target_: nequip.data.transforms.NeighborListTransform
    r_max: ${cutoff_radius}
  - _target_: nequip.data.transforms.ChemicalSpeciesToAtomTypeMapper
    chemical_symbols: ${chemical_symbols}

# Root directory of the OMat24 LMDB files. Every dataset file_path below is
# built from it, so the data can be relocated by editing this one line.
omat24_dataset_dir: /n/netscratch/kozinsky_lab/Lab/skavanagh/Nequip_Matbench/Omat24_Dataset

data:
  _target_: nequip.data.datamodule.NequIPDataModule

  # Training set: concatenation of the OMat24 train subsets (AIMD and rattled).
  train_dataset:
    _target_: torch.utils.data.ConcatDataset
    datasets:
      - _target_: nequip.data.dataset.NequIPLMDBDataset
        file_path: ${omat24_dataset_dir}/train/aimd-from-PBE-1000-npt.lmdb
        transforms: ${data_transforms}

      - _target_: nequip.data.dataset.NequIPLMDBDataset
        file_path: ${omat24_dataset_dir}/train/aimd-from-PBE-1000-nvt.lmdb
        transforms: ${data_transforms}

      - _target_: nequip.data.dataset.NequIPLMDBDataset
        file_path: ${omat24_dataset_dir}/train/aimd-from-PBE-3000-npt.lmdb
        transforms: ${data_transforms}

      - _target_: nequip.data.dataset.NequIPLMDBDataset
        file_path: ${omat24_dataset_dir}/train/aimd-from-PBE-3000-nvt.lmdb
        transforms: ${data_transforms}

      - _target_: nequip.data.dataset.NequIPLMDBDataset
        file_path: ${omat24_dataset_dir}/train/rattled-300-subsampled.lmdb
        transforms: ${data_transforms}

      - _target_: nequip.data.dataset.NequIPLMDBDataset
        file_path: ${omat24_dataset_dir}/train/rattled-300.lmdb
        transforms: ${data_transforms}

      - _target_: nequip.data.dataset.NequIPLMDBDataset
        file_path: ${omat24_dataset_dir}/train/rattled-500-subsampled.lmdb
        transforms: ${data_transforms}

      - _target_: nequip.data.dataset.NequIPLMDBDataset
        file_path: ${omat24_dataset_dir}/train/rattled-500.lmdb
        transforms: ${data_transforms}

      - _target_: nequip.data.dataset.NequIPLMDBDataset
        file_path: ${omat24_dataset_dir}/train/rattled-1000-subsampled.lmdb
        transforms: ${data_transforms}

      - _target_: nequip.data.dataset.NequIPLMDBDataset
        file_path: ${omat24_dataset_dir}/train/rattled-1000.lmdb
        transforms: ${data_transforms}

      - _target_: nequip.data.dataset.NequIPLMDBDataset
        file_path: ${omat24_dataset_dir}/train/rattled-relax.lmdb
        transforms: ${data_transforms}

  # Single validation set; its metrics are logged under the "val0" prefix.
  val_dataset:
    _target_: nequip.data.dataset.NequIPLMDBDataset
    file_path: ${omat24_dataset_dir}/val/omat24_val_5G.lmdb
    transforms: ${data_transforms}

  seed: ${seed}
  train_dataloader:
    _target_: torch.utils.data.DataLoader
    batch_size: 80  # usually overwritten in the job script
    num_workers: 7  # should be the number of CPUs per GPU (process)
  val_dataloader:
    _target_: torch.utils.data.DataLoader
    batch_size: ${data.train_dataloader.batch_size}  # same batch size as training
    num_workers: ${data.train_dataloader.num_workers}  # same num_workers as training

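#  stats_manager:  # disabled: dataset statistics are hard-coded in this file instead
#                  # (num_neighbours_mean, per_type_energy_shifts, per_type_energy_scales)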
#    _target_: nequip.data.CommonDataStatisticsManager
#    dataloader_kwargs:
#      batch_size: ${data.train_dataloader.batch_size}
#      num_workers: ${data.train_dataloader.num_workers}
#    type_names: ${model_type_names}

trainer:
  _target_: lightning.Trainer
  gradient_clip_val: 1
  # Each "epoch" draws only 10% of the training batches. Every epoch-based
  # count in this config (max_epochs, EarlyStopping and ReduceLROnPlateau
  # patience) is therefore measured in these shortened epochs.
  limit_train_batches: 0.1
  strategy:
    _target_: nequip.train.SimpleDDPStrategy
  accelerator: gpu
  enable_checkpointing: true
  # DD:HH:MM:SS (30 days); quoted so YAML 1.1 parsers do not read it as a
  # sexagesimal integer.
  max_time: '30:00:00:00'
  max_epochs: 1000
  check_val_every_n_epoch: 1
  log_every_n_steps: 10000
  callbacks:
    # Checkpoints are named after the monitored metric and epoch.
    # save_top_k: -1 keeps every saved checkpoint; save_last also keeps a
    # "last" checkpoint.
    - _target_: lightning.pytorch.callbacks.ModelCheckpoint
      monitor: ${monitored_metric}
      dirpath: ${hydra:runtime.output_dir}
      filename: 'weighted_metric_{${monitored_metric}:.4f}-{epoch}'
      auto_insert_metric_name: false
      save_last: true
      verbose: true
      save_top_k: -1

    # NOTE(review): patience here (80) is lower than the ReduceLROnPlateau
    # patience in training_module.lr_scheduler (100). Training may therefore
    # stop before the LR is ever reduced; the exact interplay also depends on
    # min_delta vs. the scheduler's threshold. Confirm this is intended.
    - _target_: lightning.pytorch.callbacks.EarlyStopping
      verbose: true
      log_rank_zero_only: true
      monitor: ${monitored_metric}
      # Written as 1.0e-5 rather than 1e-5 so YAML 1.1 loaders (e.g. plain
      # PyYAML) also read a float, not a string; OmegaConf/Hydra read both
      # forms as floats.
      min_delta: 1.0e-5  # how much to be considered a "change"
      patience: 80  # how many instances of "no change" before stopping

    - _target_: lightning.pytorch.callbacks.LearningRateMonitor
      logging_interval: epoch

    - _target_: nequip.train.callbacks.LossCoefficientMonitor
      frequency: 1
      interval: epoch

  logger:
    _target_: lightning.pytorch.loggers.wandb.WandbLogger
    save_dir: ${hydra:runtime.output_dir}
    project: OMat24
    name: OMat24

training_module:
  _target_: nequip.train.EMALightningModule

  ema_decay: 0.995

  # Training loss terms; each is weighted by its coeff, which comes from the
  # top-level E_coeff / F_coeff / S_coeff keys.
  loss:
    _target_: nequip.train.MetricsManager
    metrics:
      - name: peratom_E_Huber
        field:
          _target_: nequip.data.PerAtomModifier
          field: total_energy
        coeff: ${E_coeff}
        metric:
          _target_: nequip.train.HuberLoss
          delta: 0.01
      - name: force_strat_E_huber
        field: forces
        coeff: ${F_coeff}
        metric:
          _target_: nequip.train.metrics.StratifiedHuberForceLoss
          # Presumably {force-magnitude threshold: Huber delta applied above
          # that threshold}. TODO: confirm threshold units against the nequip docs.
          delta_dict:
            0: 0.01
            100: 0.007
            200: 0.004
            300: 0.001
      - name: stress_E_Huber
        field: stress
        coeff: ${S_coeff}
        metric:
          _target_: nequip.train.HuberLoss
          delta: ${S_Huber_delta}

  # Validation metrics. EnergyForceStressMetrics defaults to 1:1:1 weights on
  # the total-energy/force/stress RMSEs. The explicit coeffs below replace that
  # default, and their weighted sum (val0_epoch/weighted_sum) is the
  # monitored_metric used for checkpointing, early stopping and LR scheduling.
  val_metrics:
    _target_: nequip.train.EnergyForceStressMetrics
    coeffs:
      per_atom_energy_rmse: 1.0
      forces_rmse: 1.0
      stress_rmse: 1.0
      total_energy_mae: 1.0
      per_atom_energy_mae: 5.0
      forces_mae: 5.0
      stress_mae: 2.5
  train_metrics: ${training_module.val_metrics}  # reuse the same definition via variable interpolation

  # Float literals use a decimal point (5.0e-3, not 5e-3) so YAML 1.1 loaders
  # (e.g. plain PyYAML) also read them as floats rather than strings;
  # OmegaConf/Hydra read both forms as floats.
  optimizer:
    _target_: torch.optim.AdamW
    lr: 5.0e-3
    weight_decay: 1.0e-8
    amsgrad: false

  lr_scheduler:
    scheduler:
      _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
      factor: 0.1
      patience: 100
      # threshold: 1.0e-4  # default
      min_lr: 1.0e-6
    monitor: ${monitored_metric}
    interval: epoch
    frequency: 1

  model:
    _target_: nequip.model.modify
    modifiers:
      - modifier: enable_OpenEquivariance
    model:
      _target_: nequip.model.FullNequIPGNNModel
      compile_mode: compile

      irreps_edge_sh: '1x0e+1x1o+1x2e+1x3o'
      type_embed_num_features: 48  # total param count quite sensitive to this, but negligible effect on speed
      # Presumably one entry per layer (6 here, matching the radial_mlp_*
      # lists); the final entry keeps only scalar features.
      feature_irreps_hidden:
        - '128x0e+64x1o+32x2e+32x3o'
        - '128x0e+64x1o+32x2e+32x3o'
        - '128x0e+64x1o+32x2e+32x3o'
        - '128x0e+64x1o+32x2e+32x3o'
        - '128x0e+64x1o+32x2e+32x3o'
        - '128x0e'
      radial_mlp_depth: [1, 1, 1, 1, 1, 1]
      radial_mlp_width: [128, 128, 128, 128, 128, 128]

      seed: ${seed}
      type_names: ${model_type_names}
      model_dtype: float32
      # default_dtype: float64  # default

      r_max: ${cutoff_radius}

      num_bessels: 8
      bessel_trainable: true
      polynomial_cutoff_p: 5

      per_type_energy_shifts: ${per_type_energy_shifts}
      per_type_energy_shifts_trainable: false
      per_type_energy_scales: ${per_type_energy_scales}
      per_type_energy_scales_trainable: true
      avg_num_neighbors: ${num_neighbours_mean}  # 61.99263166031629 also noted for R = 6 Å

      pair_potential:
        _target_: nequip.nn.pair_potential.ZBL
        units: metal
        chemical_species: ${chemical_symbols}

# NequIP global option: allow TF32 math for float32 ops on GPUs that support
# it (trades some precision for speed).
global_options:
  allow_tf32: true

# OMat24 isolated-atom reference energies, from
# https://github.com/facebookresearch/fairchem/blob/main/configs/uma/training_release/element_refs/iso_atom_elem_refs.yaml
# TODO (original note): switch to MP_ephemera_Isolated_E0s afterwards.
# The 'N' and 'Y' keys are quoted because YAML 1.1 parsers that follow the
# spec's full boolean list (y|Y|n|N) would otherwise read them as booleans.
per_type_energy_shifts:
  H: -1.11700253
  He: 0.00079886
  Li: -0.29731164
  Be: -0.04129868
  B: -0.29106192
  C: -1.27751531
  'N': -3.12342715
  O: -1.54797136
  F: -0.43969356
  Ne: -0.01250908
  Na: -0.22855413
  Mg: -0.00943179
  Al: -0.21707638
  Si: -0.82619133
  P: -1.88667434
  S: -0.89093583
  Cl: -0.25816211
  Ar: -0.02414768
  K: -0.17662425
  Ca: -0.02568319
  Sc: -2.13001165
  Ti: -2.38688845
  V: -3.55934233
  Cr: -5.44700879
  Mn: -5.14749562
  Fe: -3.30662847
  Co: -1.42167737
  Ni: -0.63181379
  Cu: -0.23449167
  Zn: -0.01146636
  Ga: -0.21291259
  Ge: -0.77939897
  As: -1.70148487
  Se: -0.78386705
  Br: -0.22690657
  Kr: -0.02245409
  Rb: -0.16092396
  Sr: -0.02798717
  'Y': -2.25685695
  Zr: -2.23690495
  Nb: -2.15347771
  Mo: -4.60251809
  Tc: -3.36416792
  Ru: -2.23062607
  Rh: -1.15550917
  Pd: -1.47553527
  Ag: -0.19918102
  Cd: -0.01475888
  In: -0.19767692
  Sn: -0.68005773
  Sb: -1.43073368
  Te: -0.65790462
  I: -0.18915279
  Xe: -0.01179476
  Cs: -0.13507902
  Ba: -0.03056979
  La: -0.36017439
  Ce: -0.86279246
  Pr: -0.20573327
  Nd: -0.2734463
  Pm: -0.20046965
  Sm: -0.25444338
  Eu: -8.37972664
  Gd: -9.58424928
  Tb: -0.19466184
  Dy: -0.24860115
  Ho: -0.19531288
  Er: -0.15401392
  Tm: -0.14577898
  Yb: -0.19655747
  Lu: -0.15645898
  Hf: -3.49380556
  Ta: -3.5317097
  W: -4.57108006
  Re: -4.63425205
  Os: -2.88247063
  Ir: -1.45679675
  Pt: -0.50290184
  Au: -0.18521704
  Hg: -0.01123956
  Tl: -0.17483649
  Pb: -0.63132037
  Bi: -1.3248562
  Ac: -0.24135757
  Th: -1.04601971
  Pa: -2.04574044
  U: -3.84544799
  Np: -7.28626119
  Pu: -7.3136314
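  # Alternative per-type energy shifts, kept commented out for reference
  # (TODO confirm source; presumably the MP_ephemera isolated-atom E0s).
  # If re-enabled, quote the 'N' and 'Y' keys.
  #  H: -1.1176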
  #  He: -0.0005
  #  Li: -0.2974
  #  Be: -0.0181
  #  B: -0.4447
  #  C: -1.3865
  #  N: -3.1256
  #  O: -1.9067
  #  F: -0.7674
  #  Ne: -0.0121
  #  Na: -0.2285
  #  Mg: -0.0958
  #  Al: -0.3122
  #  Si: -0.8689
  #  P: -1.8879
  #  S: -1.0746
  #  Cl: -0.3714
  #  Ar: -0.0502
  #  K: -0.2277
  #  Ca: -0.0927
  #  Sc: -2.2127
  #  Ti: -2.6397
  #  V: -3.7438
  #  Cr: -5.6018
  #  Mn: -5.3235
  #  Fe: -3.5955
  #  Co: -2.1496
  #  Ni: -1.0536
  #  Cu: -0.6027
  #  Zn: -0.1645
  #  Ga: -0.4043
  #  Ge: -0.8916
  #  As: -1.6834
  #  Se: -0.8716
  #  Br: -0.2651
  #  Kr: -0.0331
  #  Rb: -0.1879
  #  Sr: -0.068
  #  Y: -2.2868
  #  Zr: -2.3603
  #  Nb: -3.1513
  #  Mo: -4.6011
  #  Tc: -3.5438
  #  Ru: -1.6595
  #  Rh: -1.6479
  #  Pd: -1.4776
  #  Ag: -0.3388
  #  Cd: -0.1672
  #  In: -0.4087
  #  Sn: -0.8167
  #  Sb: -1.4107
  #  Te: -0.7239
  #  I: -0.1703
  #  Xe: -0.0097
  #  Cs: -0.1369
  #  Ba: -0.0344
  #  La: -0.8455
  #  Ce: -1.3876
  #  Pr: -0.5491
  #  Nd: -0.5186
  #  Pm: -0.4895
  #  Sm: -0.4683
  #  Eu: -8.3662
  #  Gd: -10.4088
  #  Tb: -0.3982
  #  Dy: -0.3886
  #  Ho: -0.3834
  #  Er: -0.3857
  #  Tm: -0.3168
  #  Yb: -0.064
  #  Lu: -0.3808
  #  Hf: -3.527
  #  Ta: -3.7421
  #  W: -4.6555
  #  Re: -3.4276
  #  Os: -2.8979
  #  Ir: -1.1789
  #  Pt: -0.5638
  #  Au: -0.2872
  #  Hg: -0.1235
  #  Tl: -0.3606
  #  Pb: -0.7674
  #  Bi: -1.326
  #  Ac: -0.3866
  #  Th: -1.1045
  #  Pa: -2.553
  #  U: -4.9889
  #  Np: -7.7017
  #  Pu: -10.8084

# Per-type energy scales derived from MPTrj per-type force RMS statistics
# (original note: "MPTrj per-type forces RMS sqrt").
# The 'N' and 'Y' keys are quoted because YAML 1.1 parsers that follow the
# spec's full boolean list (y|Y|n|N) would otherwise read them as booleans.
# NOTE(review): Ne and Ar have scale 0.0. This is presumably because MPTrj has
# no Ne/Ar force data, and it may zero those atoms' initial contribution; the
# scales are trainable (per_type_energy_scales_trainable: true). Confirm this
# is intended.
per_type_energy_scales:
  H: 1.2479
  He: 0.0284
  Li: 0.2406
  Be: 0.4126
  B: 0.7826
  C: 1.599
  'N': 1.5753
  O: 0.8498
  F: 0.6083
  Ne: 0.0
  Na: 0.1963
  Mg: 0.1935
  Al: 0.4294
  Si: 0.7386
  P: 1.2352
  S: 0.6816
  Cl: 0.4024
  Ar: 0.0
  K: 0.4444
  Ca: 0.3384
  Sc: 0.3674
  Ti: 0.5421
  V: 0.8662
  Cr: 0.6028
  Mn: 0.3634
  Fe: 0.3259
  Co: 0.373
  Ni: 0.2487
  Cu: 0.2266
  Zn: 0.2833
  Ga: 0.6435
  Ge: 0.3619
  As: 0.6616
  Se: 1.0083
  Br: 0.3287
  Kr: 0.2407
  Rb: 0.331
  Sr: 0.265
  'Y': 0.2448
  Zr: 0.8569
  Nb: 0.8551
  Mo: 0.9837
  Tc: 0.6988
  Ru: 0.8013
  Rh: 0.3093
  Pd: 0.2268
  Ag: 0.4571
  Cd: 0.3263
  In: 0.3278
  Sn: 0.3482
  Sb: 0.5747
  Te: 0.6634
  I: 0.4617
  Xe: 0.7796
  Cs: 0.1202
  Ba: 0.2242
  La: 0.341
  Ce: 0.2357
  Pr: 0.4187
  Nd: 0.2967
  Pm: 0.2979
  Sm: 0.3289
  Eu: 0.1654
  Gd: 0.1971
  Tb: 0.2603
  Dy: 0.3336
  Ho: 0.2375
  Er: 0.2785
  Tm: 0.3185
  Yb: 0.4831
  Lu: 0.2021
  Hf: 0.5809
  Ta: 0.6663
  W: 0.8977
  Re: 0.6913
  Os: 0.8224
  Ir: 0.3624
  Pt: 0.3336
  Au: 0.1913
  Hg: 0.1713
  Tl: 0.6944
  Pb: 0.1735
  Bi: 0.4392
  Ac: 0.0976
  Th: 0.345
  Pa: 0.0914
  U: 0.5779
  Np: 0.271
  Pu: 0.3288

# Element list used by ChemicalSpeciesToAtomTypeMapper, the ZBL pair
# potential, and (via model_type_names) the model's type_names. The order
# presumably defines atom-type indices, so do not reorder without checking.
# The keys of per_type_energy_shifts / per_type_energy_scales should cover
# this same set.
# deno-fmt-ignore
chemical_symbols: ['H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu']
model_type_names: ${chemical_symbols}
