# @package _global_

defaults:
  # Replace the base config's `model` and `data` groups with the `nnp`
  # model config and the `qm7x` data config. Entries are applied in order.
  - override /model: nnp
  - override /data: qm7x

run:
  # Identifier of this experiment configuration.
  experiment: 'qm7x_force_fields'

globals:
  # Cutoff radius handed to the neighbor-list transform in `data`;
  # presumably in Ang, matching data.distance_unit.
  cutoff: 5.0
  # Written with a decimal point so that YAML 1.1 loaders (e.g. PyYAML)
  # also resolve it as a float; a bare `5e-4` is a float only through
  # OmegaConf's extended float resolver and a string elsewhere.
  lr: 5.0e-4
  # Property names used for the model outputs and the task targets.
  energy_key: energy
  forces_key: forces

data:
  distance_unit: Ang
  # Optional override of the datamodule's splitting strategy (disabled).
  # NOTE(review): as a Hydra config this presumably needs the `_target_`
  # form to be instantiated rather than a bare class path — confirm:
  # splitting:
  #   _target_: schnetpack.data.RandomSplit
  # Mapping keys are not interpolated by OmegaConf, so these literal
  # property names must be kept in sync with globals.energy_key and
  # globals.forces_key by hand.
  property_units:
    energy: eV
    forces: eV/Ang
  load_properties:
    - ${globals.energy_key}
    - ${globals.forces_key}
  transforms:
    - _target_: schnetpack.transform.SubtractCenterOfMass
    # Remove the mean and atom-reference offsets from the energy target;
    # the model's AddOffsets postprocessor is meant to add them back.
    - _target_: schnetpack.transform.RemoveOffsets
      property: ${globals.energy_key}
      remove_mean: true
      remove_atomrefs: true
    - _target_: schnetpack.transform.MatScipyNeighborList
      cutoff: ${globals.cutoff}
    - _target_: schnetpack.transform.CastTo32

model:
  output_modules:
    # Energy prediction, aggregated over atoms by summation.
    - _target_: schnetpack.atomistic.Atomwise
      output_key: ${globals.energy_key}
      n_in: ${model.representation.n_atom_basis}
      aggregation_mode: sum
    # Forces computed from the energy output above.
    - _target_: schnetpack.atomistic.Forces
      energy_key: ${globals.energy_key}
      force_key: ${globals.forces_key}
  postprocessors:
    - _target_: schnetpack.transform.CastTo64
    # Add back the mean/atom-reference offsets removed from the energy by
    # the RemoveOffsets data transform. This uses the same key as the
    # energy output so it still applies if globals.energy_key is overridden.
    - _target_: schnetpack.transform.AddOffsets
      property: ${globals.energy_key}
      add_mean: true
      add_atomrefs: true

task:
  scheduler_args:
    patience: 15
    cooldown: 5
  # Two supervised targets whose losses are combined with the given weights.
  outputs:
    # Energy target (weight 0.01).
    - _target_: schnetpack.task.ModelOutput
      name: ${globals.energy_key}
      loss_weight: 0.01
      loss_fn: {_target_: torch.nn.MSELoss}
      metrics:
        mae:
          _target_: torchmetrics.regression.MeanAbsoluteError
        # MeanSquaredError with squared=false reports the root-mean-square error.
        rmse:
          _target_: torchmetrics.regression.MeanSquaredError
          squared: false
    # Force target (weight 0.99).
    - _target_: schnetpack.task.ModelOutput
      name: ${globals.forces_key}
      loss_weight: 0.99
      loss_fn: {_target_: torch.nn.MSELoss}
      metrics:
        mae:
          _target_: torchmetrics.regression.MeanAbsoluteError
        # MeanSquaredError with squared=false reports the root-mean-square error.
        rmse:
          _target_: torchmetrics.regression.MeanSquaredError
          squared: false