# @package _global_

# NOTE(review): the consumer of this flag is not visible here. The name
# suggests the average squared distance statistic is estimated from the
# loaded training data rather than supplied as a fixed value — confirm.
compute_average_squared_distance_from_data: true

# Model overrides for macrocycle training; with `@package _global_` these
# keys are merged into the root-level `model` config.
model:
  arch:
    # Presumably vocabulary / size limits for the architecture's embeddings.
    # NOTE(review): assumed large enough to cover the CREMP macrocycle data
    # (atom types, residue types, atom codes, sequence length) — confirm
    # against the preprocessed dataset.
    num_atom_types: 20
    max_sequence_length: 20
    num_atom_codes: 10
    num_residue_types: 50
  # Noise-level distribution; ConstantSigma presumably always yields `sigma`.
  # callbacks.viz.sigma_list interpolates this same value.
  sigma_distribution:
    _target_: jamun.distributions.ConstantSigma
    sigma: 0.04
  # NOTE(review): units are not visible here; presumably the same length unit
  # as the trajectory coordinates — confirm.
  max_radius: 1.0
  # Learning-rate override for the model's optimizer config.
  optim:
    lr: 0.002


callbacks:
  viz:
    # A single-entry list that tracks the model's constant noise level through
    # OmegaConf interpolation, so the two never drift apart.
    sigma_list:
      - '${model.sigma_distribution.sigma}'

data:
  datamodule:
    # With num_workers: 0 data is loaded in the main process. In that case
    # persistent_workers must stay false: torch's DataLoader rejects
    # persistent_workers=True when num_workers == 0.
    # NOTE(review): assumes MDtrajDataModule forwards both to DataLoader.
    persistent_workers: false
    num_workers: 0
    _target_: jamun.data.MDtrajDataModule
    batch_size: 8
    # All three splits read the same preprocessed CREMP directory. They differ
    # by their filter_codes_csv (presumably a whitelist of codes taken from the
    # `sequence` column), and val is additionally limited by max_datasets.
    #
    # File patterns are single-quoted so the backslash reaches the regex engine
    # verbatim (`\.` is not a valid escape inside YAML double quotes):
    # - `\.` matches a literal dot; a bare `.` matches any character.
    # - `$` pins the extension, so with match-style lookup stray files such as
    #   `X.npz.bak` are no longer picked up as another entry for code `X`.
    # Capture group 1 is presumably the code that pairs a trajectory with its
    # topology — NOTE(review): confirm in parse_datasets_from_directory_new.
    datasets:
      train:
        _target_: jamun.data.parse_datasets_from_directory_new
        root: "${paths.data_path}/cremp-preprocessed/"
        traj_pattern: '^(.*)\.npz$'
        topology_pattern: '^(.*)\.sdf$'
        filter_codes_csv: "${paths.data_path}/cremp-preprocessed/train.csv"
        filter_codes_csv_header: "sequence"
        as_sdf: true

      val:
        _target_: jamun.data.parse_datasets_from_directory_new
        root: "${paths.data_path}/cremp-preprocessed/"
        traj_pattern: '^(.*)\.npz$'
        topology_pattern: '^(.*)\.sdf$'
        max_datasets: 100
        filter_codes_csv: "${paths.data_path}/cremp-preprocessed/val.csv"
        filter_codes_csv_header: "sequence"
        as_sdf: true

      test:
        _target_: jamun.data.parse_datasets_from_directory_new
        root: "${paths.data_path}/cremp-preprocessed/"
        traj_pattern: '^(.*)\.npz$'
        topology_pattern: '^(.*)\.sdf$'
        filter_codes_csv: "${paths.data_path}/cremp-preprocessed/test.csv"
        filter_codes_csv_header: "sequence"
        as_sdf: true

trainer:
  # Assuming these are passed to the PyTorch Lightning Trainer: integer values
  # are batch counts, i.e. validate every 10000 training batches using at most
  # 1000 validation batches each time.
  # NOTE(review): Lightning raises if an integer val_check_interval exceeds the
  # number of training batches per epoch, unless check_val_every_n_epoch is
  # None. Confirm the training split is large enough at batch_size 8.
  val_check_interval: 10000
  limit_val_batches: 1000
  max_epochs: 100


logger:
  wandb:
    # Weights & Biases run group, so runs of this experiment are grouped
    # together (presumably forwarded as the `group` argument of wandb.init).
    group: train_macrocycles