# @package _global_

defaults:
  # Replace the `model/arch` config-group selection inherited from the primary
  # config with the `mlp` option for this experiment (Hydra `override`).
  - override /model/arch: mlp.yaml

model:
  # Architecture: the `mlp` option selected in the defaults list, resized.
  arch:
    num_nodes: 32

  # A single constant noise level. callbacks.viz.sigma_list interpolates this
  # value, so changing it here keeps the two in sync.
  sigma_distribution:
    _target_: jamun.distributions.ConstantSigma
    sigma: 0.05

  # Input handling; exact semantics are defined by the model class.
  max_radius: 1.0
  normalization_type: null
  rotational_augmentation: true

  # Alignment-estimator settings.
  use_alignment_estimators: true
  alignment_correction_order: 0

  # Optimizer and runtime options.
  optim:
    lr: 0.001
  use_torch_compile: false

callbacks:
  viz:
    # Noise levels for the visualization callback. This list is tied to the
    # model's constant sigma by interpolation, so the two cannot drift apart.
    sigma_list:
      - "${model.sigma_distribution.sigma}"

data:
  datamodule:
    batch_size: 32
    datasets:
      # Single-peptide setup: every split loads only the AEQN trajectory.
      # traj_pattern / pdb_pattern are regular expressions:
      #   - `(AEQN)` matches that code literally. The former `AEQN*` also
      #     matched `AEQ`, `AEQNN`, ... because `N*` means "zero or more N".
      #   - `[.]` matches a literal dot. A bare `.` matches any character, and
      #     a `\.` escape would be invalid inside a YAML double-quoted string.
      # NOTE(review): val and test read from the train/ directory, so they are
      # not held-out data. Presumably intentional for this single-peptide run;
      # confirm.
      # max_datasets is only an upper bound; with these patterns at most one
      # peptide code can match.
      train:
        _target_: jamun.data.parse_datasets_from_directory
        root: "${paths.data_path}/timewarp/4AA-large/train/"
        traj_pattern: "^(AEQN)-traj-arrays[.]npz"
        pdb_pattern: "^(AEQN)-traj-state0[.]pdb"
        max_datasets: 10
        num_frames: 1

      val:
        _target_: jamun.data.parse_datasets_from_directory
        root: "${paths.data_path}/timewarp/4AA-large/train/"
        traj_pattern: "^(AEQN)-traj-arrays[.]npz"
        pdb_pattern: "^(AEQN)-traj-state0[.]pdb"
        max_datasets: 10
        num_frames: 1

      test:
        _target_: jamun.data.parse_datasets_from_directory
        root: "${paths.data_path}/timewarp/4AA-large/train/"
        traj_pattern: "^(AEQN)-traj-arrays[.]npz"
        pdb_pattern: "^(AEQN)-traj-state0[.]pdb"
        max_datasets: 10
        num_frames: 1

trainer:
  max_epochs: 10000
  # Step-based validation cadence. Assuming a PyTorch Lightning Trainer: with
  # check_val_every_n_epoch set to null, an integer val_check_interval counts
  # training batches across epoch boundaries.
  check_val_every_n_epoch: null
  val_check_interval: 1000
  log_every_n_steps: 1

logger:
  wandb:
    # Run group for this experiment (presumably a W&B logger group name).
    group: "train_uncapped_4AA_alignment"
