# @package _global_

# Hydra defaults list: each entry overrides the option selected for one
# config group (model, data, task, callbacks).
defaults:
  - override /model: nnp
  - override /data: qm9_filtered
  - override /task: diffusion_task
  # Several callback configs are selected at once for the callbacks group.
  - override /callbacks:
      - checkpoint
      - earlystopping
      - lrmonitor
      - ema

# Run-level settings; only the experiment name is set in this file.
run:
  experiment: qm9_diffusion

# Shared values, defined once and referenced from the sections below through
# ${globals.*} interpolation.
globals:
  # Cutoff shared by the neighbor-list transform and the radial basis
  # (presumably expressed in data.distance_unit -- confirm).
  cutoff: 20.0
  # Not referenced in this file; presumably consumed by the selected task
  # config as the learning rate -- confirm.
  lr: 3.0e-4
  # Passed to the TimeHead output module as n_layers.
  n_layers: 3
  # Passed as T to the noise schedule and the TimeHead, and used as
  # num_classes of the accuracy metric (presumably the number of diffusion
  # steps -- confirm).
  T: 1000
  # Passed to the TimeHead output module.
  time_per_atom: true
  # Not referenced anywhere in this file.
  detach_time_head: false
  # Target property of the task output below.
  # NOTE(review): the original author flagged this line for attention without
  # saying why. The "_mol" suffix suggests a per-molecule target while
  # time_per_atom is true -- confirm this combination is intended.
  time_target_property: disc_diff_step_mol
  # Output key of the TimeHead and name of the matching task output.
  time_key: diff_step_pred
  # Passed to the TimeHead output module.
  aggregation_mode: avg
  # Passed to the TimeHead output module.
  use_classification: true
  # Passed to the Diffuse transform.
  use_forces: true
  # Schedule object handed to the Diffuse transform; its T follows globals.T.
  noise_schedule:
    _target_: schnetpack.diffusion.PolynomialSchedule
    T: ${globals.T}
    s: 1.0e-5

# Dataset settings applied on top of the selected /data config.
data:
  distance_unit: Ang
  # Transform list; the order below is significant if transforms are applied
  # sequentially (presumably they are -- confirm against the data module).
  transforms:
    - _target_: schnetpack.transform.SubtractCenterOfGeometry
    # Receives the shared noise schedule and the use_forces flag from globals.
    - _target_: schnetpack.diffusion.Diffuse
      noise_schedule: ${globals.noise_schedule}
      diffuse_z: false
      diffuse_all: true
      use_forces: ${globals.use_forces}
    - _target_: schnetpack.transform.RemoveOffsets
      property: energy_U0
      remove_mean: true
      remove_atomrefs: true
    # Listed after Diffuse; uses the same cutoff as the model's radial basis.
    - _target_: schnetpack.transform.MatScipyNeighborList
      cutoff: ${globals.cutoff}
    - _target_: schnetpack.transform.CastTo32

# Model settings applied on top of the selected /model config.
model:
  representation:
    radial_basis:
      _target_: schnetpack.nn.radial.GaussianRBF
      n_rbf: 40
      # Same cutoff as the neighbor-list transform in the data section.
      cutoff: ${globals.cutoff}
    n_atom_basis: 256
  output_modules:
    # Single output head; apart from n_in, every option is taken from globals.
    - _target_: schnetpack.diffusion.TimeHead
      time_per_atom: ${globals.time_per_atom}
      time_output_key: ${globals.time_key}
      aggregation_mode: ${globals.aggregation_mode}
      # Follows representation.n_atom_basis so the two values stay in sync.
      n_in: ${model.representation.n_atom_basis}
      n_layers: ${globals.n_layers}
      use_classification: ${globals.use_classification}
      T: ${globals.T}

# Task settings: one model output, trained with cross-entropy and tracked
# with a multiclass accuracy metric.
task:
  outputs:
    - _target_: schnetpack.task.ModelOutput
      # Name matches the TimeHead output key; the target comes from globals.
      name: ${globals.time_key}
      target_property: ${globals.time_target_property}
      loss_fn:
        _target_: torch.nn.CrossEntropyLoss
      loss_weight: 1.0
      metrics:
        accuracy:
          _target_: torchmetrics.classification.MulticlassAccuracy
          # Class count is tied to globals.T.
          num_classes: ${globals.T}