# @package _global_

# Hydra defaults list for this experiment: replaces the selected option of
# three config groups.
defaults:
  # /model group -> option `nnp`
  - override /model: nnp
  # /model/representation group -> option `field_schnet`
  - override /model/representation: field_schnet
  # /data group -> option `custom`
  - override /data: custom

# Sets run.experiment to `response`. How this value is consumed is defined
# outside this file (presumably used to label the run -- confirm).
run:
  experiment: response

# Shared values; the sections below pull them in through ${globals.<name>}
# interpolations.
# NOTE(review): `lr` and `forces_key` are not referenced anywhere in this
# file; presumably they are consumed by other composed config files -- confirm.
globals:
  # cutoff handed to the MatScipyNeighborList transform in data.transforms
  # (the unit is not stated in this file)
  cutoff: 9.448
  # Written with an explicit mantissa dot so that YAML 1.1 loaders (e.g. plain
  # PyYAML) also parse it as a float rather than as the string "1e-4".
  lr: 1.0e-4
  # key of the energy property (used by RemoveOffsets, Atomwise, ScaleProperty,
  # Response, AddOffsets and the first task output)
  energy_key: energy
  forces_key: forces
  # base key of the shielding property; per-element task outputs are named
  # <shielding_key>_<atomic number>
  shielding_key: shielding
  # atomic numbers passed to SplitShielding as `atomic_numbers`
  shielding_elements: [1, 6, 8]
  # property names passed to StaticExternalFields and Response
  response_properties:
    - forces
    - dipole_moment
    - polarizability
    - ${globals.shielding_key}

# Data settings for the `custom` data config selected in the defaults list.
data:
  distance_unit: 1.0
  batch_size: 10
  # Transforms, listed in the order given here; each entry is instantiated
  # from its `_target_` class path.
  transforms:
    - _target_: schnetpack.transform.SubtractCenterOfMass
    # RemoveOffsets acts on the energy property with remove_mean enabled
    - _target_: schnetpack.transform.RemoveOffsets
      property: ${globals.energy_key}
      remove_mean: true
    # neighbor list built with the shared cutoff from globals
    - _target_: schnetpack.transform.MatScipyNeighborList
      cutoff: ${globals.cutoff}
    - _target_: schnetpack.transform.CastTo32
    # SplitShielding receives the shielding key and the element list from
    # globals; the task outputs are named <shielding_key>_<Z>, presumably the
    # keys this transform produces -- confirm against its implementation.
    - _target_: schnetpack.transform.SplitShielding
      shielding_key: ${globals.shielding_key}
      atomic_numbers: ${globals.shielding_elements}

# Model settings layered on top of the `nnp` model config: input modules,
# output modules and postprocessors, each an ordered list of `_target_` entries.
model:
  input_modules:
    - _target_: schnetpack.atomistic.PairwiseDistances
    # receives the list of response properties defined in globals
    - _target_: schnetpack.atomistic.StaticExternalFields
      response_properties: ${globals.response_properties}
  output_modules:
    # Atomwise output written to the energy key; input width is taken from
    # the representation's n_atom_basis, aggregation mode is `sum`
    - _target_: schnetpack.atomistic.Atomwise
      output_key: ${globals.energy_key}
      n_in: ${model.representation.n_atom_basis}
      aggregation_mode: sum
    # ScaleProperty reads and writes the same (energy) key
    - _target_: schnetpack.transform.ScaleProperty
      input_key: ${globals.energy_key}
      output_key: ${globals.energy_key}
    # Response is configured with the energy key and the response properties
    - _target_: schnetpack.atomistic.Response
      energy_key: ${globals.energy_key}
      response_properties: ${globals.response_properties}
    # same SplitShielding arguments as the entry in data.transforms
    - _target_: schnetpack.transform.SplitShielding
      shielding_key: ${globals.shielding_key}
      atomic_numbers: ${globals.shielding_elements}
  postprocessors:
    - _target_: schnetpack.transform.CastTo64
    # Uses the shared energy key (instead of a hard-coded name) so that it
    # stays in sync with RemoveOffsets in data.transforms when
    # globals.energy_key is overridden.
    - _target_: schnetpack.transform.AddOffsets
      property: ${globals.energy_key}
      add_mean: true

# Training task settings.
task:
  # Keyword arguments for the learning-rate scheduler. The scheduler class
  # itself is not selected in this file (the argument names resemble a
  # reduce-on-plateau style scheduler -- confirm against the task config).
  scheduler_args:
    mode: min
    factor: 0.5
    patience: 50
    # Written with an explicit mantissa dot so that YAML 1.1 loaders (e.g.
    # plain PyYAML) also parse it as a float rather than the string "1e-6".
    min_lr: 1.0e-6
    smoothing_factor: 0.0
  # One ModelOutput entry per trained property. Every entry uses an MSE loss,
  # reports `mae` and `rmse` metrics (MeanSquaredError with squared: false),
  # and carries its own loss_weight.
  outputs:
    # energy (name taken from globals.energy_key)
    - _target_: schnetpack.task.ModelOutput
      name: ${globals.energy_key}
      loss_fn:
        _target_: torch.nn.MSELoss
      metrics:
        mae:
          _target_: torchmetrics.regression.MeanAbsoluteError
        rmse:
          _target_: torchmetrics.regression.MeanSquaredError
          squared: false
      loss_weight: 1.00
    # forces (name is hard-coded here, not taken from globals.forces_key)
    - _target_: schnetpack.task.ModelOutput
      name: forces
      loss_fn:
        _target_: torch.nn.MSELoss
      metrics:
        mae:
          _target_: torchmetrics.regression.MeanAbsoluteError
        rmse:
          _target_: torchmetrics.regression.MeanSquaredError
          squared: false
      loss_weight: 5.0
    # dipole moment
    - _target_: schnetpack.task.ModelOutput
      name: dipole_moment
      loss_fn:
        _target_: torch.nn.MSELoss
      metrics:
        mae:
          _target_: torchmetrics.regression.MeanAbsoluteError
        rmse:
          _target_: torchmetrics.regression.MeanSquaredError
          squared: false
      loss_weight: 0.01
    # polarizability
    - _target_: schnetpack.task.ModelOutput
      name: polarizability
      loss_fn:
        _target_: torch.nn.MSELoss
      metrics:
        mae:
          _target_: torchmetrics.regression.MeanAbsoluteError
        rmse:
          _target_: torchmetrics.regression.MeanSquaredError
          squared: false
      loss_weight: 0.01
    # Shielding, split by element: one output per atomic number listed in
    # globals.shielding_elements (1, 6, 8), named <shielding_key>_<Z>, each
    # with its own loss_weight. In addition to mae/rmse these entries report
    # `mae_iso` and `mae_aniso`, both TensorDiagonalMeanAbsoluteError, the
    # latter with diagonal: false (names suggest diagonal vs. off-diagonal
    # tensor parts -- confirm against the metric implementation).
    - _target_: schnetpack.task.ModelOutput
      name: ${globals.shielding_key}_1
      loss_fn:
        _target_: torch.nn.MSELoss
      metrics:
        mae:
          _target_: torchmetrics.regression.MeanAbsoluteError
        rmse:
          _target_: torchmetrics.regression.MeanSquaredError
          squared: false
        mae_iso:
          _target_: schnetpack.train.metrics.TensorDiagonalMeanAbsoluteError
        mae_aniso:
          _target_: schnetpack.train.metrics.TensorDiagonalMeanAbsoluteError
          diagonal: false
      loss_weight: 0.1
    - _target_: schnetpack.task.ModelOutput
      name: ${globals.shielding_key}_6
      loss_fn:
        _target_: torch.nn.MSELoss
      metrics:
        mae:
          _target_: torchmetrics.regression.MeanAbsoluteError
        rmse:
          _target_: torchmetrics.regression.MeanSquaredError
          squared: false
        mae_iso:
          _target_: schnetpack.train.metrics.TensorDiagonalMeanAbsoluteError
        mae_aniso:
          _target_: schnetpack.train.metrics.TensorDiagonalMeanAbsoluteError
          diagonal: false
      loss_weight: 0.004
    - _target_: schnetpack.task.ModelOutput
      name: ${globals.shielding_key}_8
      loss_fn:
        _target_: torch.nn.MSELoss
      metrics:
        mae:
          _target_: torchmetrics.regression.MeanAbsoluteError
        rmse:
          _target_: torchmetrics.regression.MeanSquaredError
          squared: false
        mae_iso:
          _target_: schnetpack.train.metrics.TensorDiagonalMeanAbsoluteError
        mae_aniso:
          _target_: schnetpack.train.metrics.TensorDiagonalMeanAbsoluteError
          diagonal: false
      loss_weight: 0.001


