# @package _global_

# Defaults list with `_self_` as its only entry: this file composes no other
# config files of its own.
defaults:
  - _self_

# Turned off here; a fixed `model.average_squared_distance` is supplied below.
# NOTE(review): presumably `true` would estimate that value from the dataset
# instead — confirm against the training entry point.
compute_average_squared_distance_from_data: false

# Model overrides for this experiment.
model:
  # Hard-coded value, used together with
  # `compute_average_squared_distance_from_data: false` at the top of the file.
  average_squared_distance: 9.5
  # Instantiated through `_target_`; a single constant sigma. The same value is
  # reused by `callbacks.viz.sigma_list` via interpolation.
  sigma_distribution:
    _target_: jamun.distributions.ConstantSigma
    sigma: 0.25
  # NOTE(review): units are not stated in this file — presumably the same length
  # unit as the input coordinates; confirm.
  max_radius: 6.0
  optim:
    lr: 0.002
  use_torch_compile: true
  # NOTE(review): presumably forwarded to `torch.compile` when
  # `use_torch_compile` is true — confirm in the model code.
  torch_compile_kwargs:
    fullgraph: true
    dynamic: true
    mode: default

# Visualization callback overrides.
callbacks:
  viz:
    # One-element list whose entry is an interpolation of the model's sigma, so
    # the value is only written once in this file.
    sigma_list:
      - "${model.sigma_distribution.sigma}"

# Data module overrides: loader settings plus the train/val/test dataset specs.
data:
  datamodule:
    num_workers: 2
    batch_size: 16
    # All three splits use the same loader and the same root directory; each
    # split points at its own filter CSV (column header "code").
    #
    # The trajectory/topology patterns are single-quoted so the backslash stays
    # literal, and the dot is escaped so it matches only a literal "." (an
    # unescaped "." would match any character, e.g. "traj_xtc").
    # NOTE(review): assumes the loader treats these patterns as Python regular
    # expressions — confirm in `parse_datasets_from_directory_new`.
    datasets:
      train:
        _target_: jamun.data.parse_datasets_from_directory_new
        root: "${paths.data_path}/IDRome_v4_preprocessed/flat/"
        traj_pattern: '^(.*)/traj\.xtc'
        pdb_pattern: '^(.*)/top\.pdb'
        filter_codes_csv: "${paths.data_path}/IDRome_v4_preprocessed/flat/train.csv"
        filter_codes_csv_header: "code"

      # Unlike `train`, the val and test splits also set `subsample: 100`.
      val:
        _target_: jamun.data.parse_datasets_from_directory_new
        root: "${paths.data_path}/IDRome_v4_preprocessed/flat/"
        traj_pattern: '^(.*)/traj\.xtc'
        pdb_pattern: '^(.*)/top\.pdb'
        subsample: 100
        filter_codes_csv: "${paths.data_path}/IDRome_v4_preprocessed/flat/val.csv"
        filter_codes_csv_header: "code"

      test:
        _target_: jamun.data.parse_datasets_from_directory_new
        root: "${paths.data_path}/IDRome_v4_preprocessed/flat/"
        traj_pattern: '^(.*)/traj\.xtc'
        pdb_pattern: '^(.*)/top\.pdb'
        subsample: 100
        filter_codes_csv: "${paths.data_path}/IDRome_v4_preprocessed/flat/test.csv"
        filter_codes_csv_header: "code"


# Trainer overrides.
# NOTE(review): if these map onto a PyTorch Lightning `Trainer`, a float
# `val_check_interval` is a fraction of a training epoch and an integer
# `limit_val_batches` is a batch count — confirm against the base config.
trainer:
  val_check_interval: 0.1
  limit_val_batches: 1000
  max_epochs: 10


# Checkpoint to resume from.
# NOTE: `wandb_train_run_path` is a placeholder ("/your/path/to/...") and must
# be replaced with a real training run path before this config is used.
resume_from_checkpoint:
  wandb_train_run_path: /your/path/to/train_run/dea4gbvs
  checkpoint_type: last

# Logger overrides: sets the Weights & Biases group name for runs of this config.
logger:
  wandb:
    group: train_idrome_cg

