# @package _global_

# To execute this experiment, run:
#   python train.py experiment=<path of this file within the experiment
#   config group, without the .yaml suffix>

defaults:
  # Config groups composed into this experiment. The leading `/` makes each
  # path absolute (resolved from the config root rather than relative to this
  # file's own config group). Order matters: later entries are merged over
  # earlier ones.
  # NOTE(review): no `_self_` entry is listed; under Hydra >= 1.1 the values in
  # this file are presumably applied last (overriding these) — confirm against
  # the Hydra version in use.
  - /datamodule: gym
  - /callbacks: fixedbb
  - /trainer: default

# Run name for this experiment (presumably used for logging / output
# directories — confirm against train.py).
name: 'fixedbb/cath_4.2/bridge_if_esm1b_650m_pifold_gym_energy'

model:
  _target_: esm_adapter_time_pifold
  encoder:
    # Same value as task.lr_scheduler.model_size (128); presumably the two
    # should be kept in sync — confirm against the scheduler implementation.
    d_model: 128
    use_esm_alphabet: true
  # Single-element list of layer indices; presumably the positions in the
  # pretrained LM where adapters are attached — TODO confirm in model code.
  adapter_layer_indices: [32]

task:
  _target_: fixedbb/mb_pifold_gym_energy
  # Interpolated from the composed datamodule config (see `defaults`).
  alphabet: ${datamodule.alphabet}
  learning:
    noise: full_mask  # enable initial prediction with full masking
    use_context: false
    reparam: true

  criterion:
    _target_: src.modules.variational_lower_bound.TrainLossVLB
    lambda_train: [5, 0]
  optimizer:
    type: adamw
    _partial_: true
    # Shared with lr_scheduler.lr below; both read the top-level `train.lr`.
    lr: ${train.lr}
    betas:
      - 0.9
      - 0.98
    weight_decay: 0.0001
  lr_scheduler:
    type: noam
    warmup_steps: 4000
    # Same value as model.encoder.d_model (128); presumably the two should be
    # kept in sync — confirm against the scheduler implementation.
    model_size: 128
    lr: ${train.lr}
    # Spelled with a decimal point: YAML 1.1 loaders such as plain PyYAML only
    # resolve exponent notation as a float when a `.` is present, and would
    # otherwise yield the string "1e-07". Value: 1e-7.
    warmup_init_lr: 1.0e-7
  generator:
    diffusion_steps: 24
    diffusion_noise_schedule: interpolation
    transition: null
    direct: true
    strategy: 'discrete_diffusion'
  # Dataset version; presumably selects the matching entry of
  # `pretrained_model` below — TODO confirm in the task code.
  version: 'cath_4.2'
  pretrained_model:
    cath42: './ckpts/cath_4.2/lm_design_esm1b_650m_pifold/checkpoints/best.ckpt'
    cath43: './ckpts/cath_4.3/lm_design_esm1b_650m_pifold/checkpoints/best.ckpt'

train:
  seed: 42
  # Referenced as ${train.lr} by task.optimizer.lr and task.lr_scheduler.lr.
  # Written as 5.0e-6 (not 5e-6) because YAML 1.1 loaders such as plain
  # PyYAML only resolve exponent floats containing a `.`, and would otherwise
  # return the string "5e-6".
  lr: 5.0e-6
  monitor: "val/loss"
  mode: "min"

trainer:
  min_epochs: 15
  max_epochs: 150
  gradient_clip_val: 0.0
  # val_check_interval: 10
  check_val_every_n_epoch: 1
  num_sanity_val_steps: 1
  reload_dataloaders_every_n_epochs: 1
  use_distributed_sampler: false
  # Plain digits: `200_000` is an int only under YAML 1.1 digit-separator
  # rules; YAML 1.2 core-schema loaders would read it as a string.
  max_steps: 200000