# @package _global_
# Hydra defaults list: entries are composed in order, so later entries
# take precedence over earlier ones.
defaults:
  # Base config this experiment builds on.
  - base_memory_policy/deep/bam/base_bam
  # Swap the previously selected `task` group option for `lb_full`,
  # placed in the global package.
  - override /task@_global_: lb_full
  # Compose this file's own keys last so they win over the entries above.
  - _self_

run_name_suffix: 'stage3'

# Explicit null: a bare `cache_size:` parses to the same value but reads like
# a forgotten setting. NOTE(review): how the consumer treats a null cache size
# is not visible from this file — confirm.
cache_size: null
max_new_tokens: 128

# Reduction modes. `scoring_reduction_mode` is null here; it is written
# explicitly so the empty value is not mistaken for an omission.
scoring_reduction_mode: null
embedding_reduction_mode: ema

# EMA params.
embedding_ema_coeff: 0.99
embedding_reduction_learned: false

# --
# Scoring attention params.
# Increasing the hidden dim increases the time cost only marginally.
scoring_attn_hidden_dim: 32
# Explicit null rather than a bare key; the parsed value is unchanged.
scoring_attn_output_dim: null
scoring_attn_num_heads: 1
scoring_attn_bias: true
scoring_attn_use_rope: false
scoring_attn_rope_theta: 50000
scoring_attn_masking_strategy: backward
# --

# --
# STFT params.
n_fft: 32
hop_length: 16
# Window factory; its length follows `n_fft` through interpolation.
window_fn:
  _target_: 'utils_hydra.torch.hann_window'
  window_length: ${n_fft}
  periodic: true
# Name under which the window function is logged.
window_fn_log_name: 'hann'
pad_mode: 'constant'
output_magnitudes: true
# --


# Checkpoint to initialize from. The value looks like a placeholder path;
# point it at a real checkpoint before launching.
init_from: 'path/to/namm/results/ckpt.pt'

pop_size: 32

scoring_initializer: 0
keep_past_epoch_checkpoints_every: 1

scratch: false

# Evaluation-only run.
eval_only: true
# Avoid timeouts when evaluating on all tasks. Kept quoted so the value stays
# a string: an unquoted digits-and-colons scalar can be parsed as a
# sexagesimal integer by YAML 1.1 loaders.
# NOTE(review): presumably hours:minutes:seconds — confirm against the consumer.
ddp_timeout_limit: '9:0:0'