# @package _global_

# Hydra defaults list. Each `override` entry replaces the option that the base
# config selected for that config group (model, data, task, callbacks).
defaults:
  - override /model: nnp
  - override /data: qm9_filtered
  - override /task: diffusion_task
  # Several callback configs are selected at once for the callbacks group.
  - override /callbacks:
    - checkpoint
    - earlystopping
    - lrmonitor
    - ema
    - sampling_time

run:
  # Experiment name for this configuration; presumably used by the base
  # config to name the run/output directory (confirm against the base config).
  experiment: 'qm9_energy_model'

globals:
  # Cutoff shared by the neighbor list (data) and the radial basis (model);
  # presumably in data.distance_unit (Ang).
  cutoff: 5.0
  # Floats use an explicit decimal point: a plain YAML 1.1 loader (e.g.
  # PyYAML safe_load) reads `1e-4` as a string, whereas `1.0e-4` is a float
  # under every loader. Not referenced in this file; presumably consumed by
  # the base task/optimizer config.
  lr: 1.0e-4
  # Passed to the Diffuse transform.
  use_forces: true
  # Number of diffusion steps, shared with the noise schedule below.
  T: 1000
  # Instantiated config reused by the Diffuse transform and the task.
  noise_schedule:
    _target_: schnetpack.diffusion.PolynomialSchedule
    T: ${globals.T}
    s: 1.0e-5
  n_atom_basis: 256
  # Atomwise output key; the Forces module differentiates this output.
  energy_key: diff_step_pred
  # Target the energy-head prediction is regressed against.
  energy_target_property: diff_step
  # Key written by the Forces module, i.e. the predicted noise.
  forces_key: eps_pred
  # Target the forces-head prediction is regressed against.
  forces_target_property: eps
  # NOTE(review): not referenced in this file; presumably read by the base
  # model config — confirm it is still used.
  include_time: true
  # Loss weight of the diff_step (energy-head) output in the task.
  time_weight: 0.1
  # Loss weight of the eps (forces-head) output in the task.
  noise_weight: 0.9

data:
  distance_unit: Ang
  # Per-sample transforms; presumably applied in list order, so keep the
  # sequence (centering -> diffusion -> offsets -> neighbor list -> cast).
  transforms:
    - _target_: schnetpack.transform.SubtractCenterOfGeometry
    - _target_: schnetpack.diffusion.Diffuse
      noise_schedule: ${globals.noise_schedule}
      diffuse_z: false
      diffuse_all: false
      exclude_eps_0: false
      use_forces: ${globals.use_forces}
      per_atom_step: false
    # NOTE(review): energy_U0 is not one of the task targets (diff_step, eps);
    # confirm this offset removal is still required for the diffusion model.
    - _target_: schnetpack.transform.RemoveOffsets
      property: energy_U0
      remove_mean: true
      remove_atomrefs: true
    - _target_: schnetpack.transform.MatScipyNeighborList
      cutoff: ${globals.cutoff}
    - _target_: schnetpack.transform.CastTo32

model:
  representation:
    radial_basis:
      _target_: schnetpack.nn.radial.GaussianRBF
      n_rbf: 20
      cutoff: ${globals.cutoff}
    n_atom_basis: ${globals.n_atom_basis}
  output_modules:
    # Atomwise head aggregated with mode `avg`; its output is stored under
    # globals.energy_key.
    - _target_: schnetpack.atomistic.Atomwise
      output_key: ${globals.energy_key}
      n_in: ${globals.n_atom_basis}
      n_hidden: null
      n_layers: 3
      aggregation_mode: avg
    # Derives globals.forces_key from the energy output (presumably via its
    # gradient w.r.t. positions — see schnetpack.atomistic.Forces).
    - _target_: schnetpack.atomistic.Forces
      energy_key: ${globals.energy_key}
      force_key: ${globals.forces_key}
  do_postprocessing: true
  postprocessors:
    # Applied to the forces/eps prediction; presumably removes its
    # per-molecule centre-of-mass component.
    - _target_: schnetpack.diffusion.BatchSubtractCenterOfMass
      name: ${globals.forces_key}

task:
  log_nll: true
  noise_schedule: ${globals.noise_schedule}
  # Two supervised outputs, each with an MSE loss scaled by its loss_weight:
  # predicted noise (globals.noise_weight) and predicted diffusion step
  # (globals.time_weight).
  outputs:
    - _target_: schnetpack.task.ModelOutput
      name: ${globals.forces_key}
      target_property: ${globals.forces_target_property}
      loss_fn:
        _target_: torch.nn.MSELoss
      metrics:
        mse:
          _target_: torchmetrics.regression.MeanSquaredError
          squared: true
      loss_weight: ${globals.noise_weight}
    - _target_: schnetpack.task.ModelOutput
      name: ${globals.energy_key}
      target_property: ${globals.energy_target_property}
      loss_fn:
        _target_: torch.nn.MSELoss
      metrics:
        mse:
          _target_: torchmetrics.regression.MeanSquaredError
          squared: true
      loss_weight: ${globals.time_weight}