# Dataset definitions for the training and validation splits. Each list entry
# names a `class` together with the options listed beside it; how these are
# instantiated is decided by the consuming code, which is not part of this file.
dataset:
  train:
    # Peptides
    # Two PeptideDataset sources are listed under MixDatasetWrapper, which is
    # itself the `dataset` of the outer DynamicBatchWrapper.
    - class: DynamicBatchWrapper
      dataset:
        class: MixDatasetWrapper
        datasets:
          # pepbench: given an index file (train_index.txt) and a cluster file
          - class: PeptideDataset
            mmap_dir: ./datasets/peptide/pepbench/processed
            specify_index: ./datasets/peptide/pepbench/processed/train_index.txt
            cluster: ./datasets/peptide/pepbench/train.cluster
          # protein-fragments: no specify_index or cluster is given for this source
          - class: PeptideDataset
            mmap_dir: ./datasets/peptide/ProtFrag/processed
      # NOTE(review): presumably `complexity` is a per-sample cost expression in
      # the sample size n and `ubound_per_batch` caps the summed cost of one
      # batch — confirm against DynamicBatchWrapper.
      complexity: n*n
      ubound_per_batch: 12000000
      n_use_max_in_batch: true
  valid:
    # Validation lists only the pepbench source (valid_index.txt, no cluster
    # file) and repeats the batching options used for train.
    - class: DynamicBatchWrapper
      dataset:
        class: PeptideDataset
        mmap_dir: ./datasets/peptide/pepbench/processed
        specify_index: ./datasets/peptide/pepbench/processed/valid_index.txt
      complexity: n*n
      ubound_per_batch: 12000000
      n_use_max_in_batch: true

# Data-loading options per split. Only the training split sets `shuffle`;
# the validation split specifies just its worker count.
dataloader:
  train:
    shuffle: true
    num_workers: 16
  valid:
    num_workers: 8

# Trainer selection plus its configuration: run length, checkpointing,
# logging, optimizer and learning-rate scheduler.
trainer:
  class: IterAETrainer
  # NOTE(review): presumably the name of the validation metric used to rank
  # checkpoints; metric_min_better below marks lower values as better — confirm.
  criterion: loss
  config:
    max_epoch: 250
    save_topk: 10
    val_freq: 10
    save_dir: ./ckpts/pepmirror
    grad_clip: 1.0
    patience: 10
    warmup: 2000 # warmup length for gradually increasing the KL divergence, for training stability (8 A800 GPUs)
    metric_min_better: true
    proj_name: PepMirror # for wandb
    logger: tensorboard # if you want to use wandb, please comment out this line

    optimizer:
      class: AdamW
      lr: 1.0e-4

    # NOTE(review): factor/patience/mode/min_lr match the argument names of
    # PyTorch's ReduceLROnPlateau; `frequency` is not one of them and is
    # presumably read by the trainer to decide when to step the scheduler —
    # confirm against IterAETrainer.
    scheduler:
      class: ReduceLROnPlateau
      factor: 0.8
      patience: 5
      mode: min
      frequency: val_epoch
      min_lr: 5.0e-6

# Model class and its hyperparameters. The encoder and decoder both use the
# AFIEPT type and are configured through encoder_opt / decoder_opt.
model:
  class: CondIterAutoEncoder
  encoder_type: AFIEPT
  decoder_type: AFIEPT
  embed_size: 512
  hidden_size: 512
  latent_size: 8
  edge_size: 64
  k_neighbors: 9
  # Options for the encoder. decoder_opt below currently repeats these values
  # one-for-one; when changing one side, decide explicitly whether the other
  # should follow.
  encoder_opt:
    n_rbf: 64
    cutoff: 10.0
    n_layers: 6
    n_head: 8
    use_edge_feat: true
    pre_norm: true
    efficient: false
    vector_act: layernorm
    axial_type: cross
  # Options for the decoder (same keys and values as encoder_opt above).
  decoder_opt:
    n_rbf: 64
    cutoff: 10.0
    n_layers: 6
    n_head: 8
    use_edge_feat: true
    pre_norm: true
    efficient: false
    vector_act: layernorm
    axial_type: cross
  # Per-term loss weights. contrastive_loss and bond_length_loss are set to 0.0
  # (presumably disabling those terms — confirm how the model applies weights).
  loss_weights:
    Zh_kl_loss: 0.6
    Zx_kl_loss: 0.8
    atom_coord_loss: 1.0
    block_type_loss: 1.0
    contrastive_loss: 0.0
    local_distance_loss: 0.5
    bond_loss: 0.5
    bond_length_loss: 0.0
  prior_coord_std: 1.0
  coord_noise_scale: 1.0
  pocket_mask_ratio: 0.05
  decode_mask_ratio: 0.0
  kl_on_pocket: false
  discrete_timestep: false

# Checkpoint path to load; set to the empty string here.
# NOTE(review): presumably an empty value means no checkpoint is loaded — confirm.
load_ckpt: ''
