# @package _global_

# Model overrides for this run.
model:
  # Built from `_target_` (Hydra instantiate convention) with one fixed
  # `sigma`; callbacks.viz.sigma_list interpolates this same value.
  sigma_distribution:
    _target_: jamun.distributions.ConstantSigma
    sigma: 0.04
  # Architecture override: only the layer count is changed here.
  arch:
    n_layers: 2
  # NOTE(review): presumably a neighbor/cutoff radius; units are not visible
  # from this file — confirm against the model code.
  max_radius: 1.0
  # Optimizer override: learning rate only.
  optim:
    lr: 0.002
  # NOTE(review): presumably toggles torch.compile for the model — confirm.
  use_torch_compile: false

# Callback overrides.
callbacks:
  viz:
    # One-element list whose entry is an interpolation of the model's sigma,
    # so the value here follows model.sigma_distribution.sigma automatically.
    sigma_list:
      - "${model.sigma_distribution.sigma}"

# Data overrides: every split reads from
# ${paths.data_path}/timewarp/2AA-1-large/<split>/ and each dataset entry is
# built from its `_target_` (Hydra instantiate convention).
# NOTE(review): `subsample: 100` presumably keeps every 100th frame — confirm
# against jamun.data.MDtrajDataset.
data:
  datamodule:
    batch_size: 32
    datasets:
      # Training split: two datasets, labelled EI and FC.
      train:
        - _target_: jamun.data.MDtrajDataset
          root: "${paths.data_path}/timewarp/2AA-1-large/train/"
          traj_files:
            - "EI-traj-arrays.npz"
          pdb_file: "EI-traj-state0.pdb"
          subsample: 100
          label: "EI"

        - _target_: jamun.data.MDtrajDataset
          root: "${paths.data_path}/timewarp/2AA-1-large/train/"
          traj_files:
            - "FC-traj-arrays.npz"
          pdb_file: "FC-traj-state0.pdb"
          subsample: 100
          label: "FC"

      # Validation split: one dataset, labelled KL.
      val:
        - _target_: jamun.data.MDtrajDataset
          root: "${paths.data_path}/timewarp/2AA-1-large/val/"
          traj_files:
            - "KL-traj-arrays.npz"
          pdb_file: "KL-traj-state0.pdb"
          subsample: 100
          label: "KL"

      # Test split: one dataset, labelled CK.
      test:
        - _target_: jamun.data.MDtrajDataset
          root: "${paths.data_path}/timewarp/2AA-1-large/test/"
          traj_files:
            - "CK-traj-arrays.npz"
          pdb_file: "CK-traj-state0.pdb"
          subsample: 100
          label: "CK"

# Trainer overrides: the run is capped at a single epoch.
# NOTE(review): if these keys feed a Lightning Trainer, a float
# val_check_interval of 1.0 means "validate once per training epoch" — confirm.
trainer:
  val_check_interval: 1.0
  max_epochs: 1

# Logger override: sets the `group` of the `wandb` logger entry to train_test.
logger:
  wandb:
    group: train_test
