# @package _global_
# NOTE: Hydra reads the package directive above from the file header, so keep
# it at the very top, ahead of any ordinary comment or key. `_global_` places
# this file's keys at the root of the composed config.
defaults:
  # Base policy config whose contents are composed into this one.
  - base_memory_policy/deep/bam/base_bam
  # Swap the task config (merged into the global package) for lb_2subset.
  - override /task@_global_: lb_2subset
  # Listed last so the values in this file take precedence over the entries
  # above when Hydra merges the defaults list.
  - _self_

# Run naming and generation/cache limits for this config.
# NOTE(review): the suffix is presumably appended to the run name — confirm.
run_name_suffix: 'stage2'

# Written as an explicit null rather than a bare `cache_size:`; the parsed
# value is the same, but it is now clear the key is intentionally unset.
# NOTE(review): presumably supplied by another config or a CLI override —
# confirm what a null cache size means to the consumer.
cache_size: null
max_new_tokens: 128

# Reduction modes. The scoring reduction is written as an explicit null
# (same parsed value as the previous bare key) so it reads as deliberately
# unset rather than forgotten.
scoring_reduction_mode: null
embedding_reduction_mode: ema

# EMA params.
# NOTE(review): presumably only consulted while embedding_reduction_mode is
# `ema` — confirm against the consumer.
embedding_ema_coeff: 0.99
embedding_reduction_learned: false

# --
# Scoring attention params.
# Increasing the hidden dim increases the time cost only marginally.
scoring_attn_hidden_dim: 32
# Explicit null instead of a bare key; the parsed value is unchanged.
# NOTE(review): presumably the consumer derives the output dim when this is
# null — confirm.
scoring_attn_output_dim: null
scoring_attn_num_heads: 1
scoring_attn_bias: true
scoring_attn_use_rope: false
# NOTE(review): presumably ignored while scoring_attn_use_rope is false.
scoring_attn_rope_theta: 50000
scoring_attn_masking_strategy: backward
# --

# --
# STFT params
n_fft: 32
# Half of n_fft.
hop_length: 16
# Window function config with a `_target_` entry; `window_length` is tied to
# n_fft through interpolation, so changing n_fft updates the window as well.
window_fn:
  _target_: utils_hydra.torch.hann_window
  window_length: ${n_fft}
  periodic: true
# NOTE(review): presumably a human-readable label for the window used in
# logging; keep it in sync with `window_fn._target_` — confirm.
window_fn_log_name: hann
pad_mode: constant
output_magnitudes: true
# --


# NOTE(review): this looks like a placeholder path — presumably it must be
# overridden with the real stage-1 checkpoint before running; confirm.
init_from: 'path/to/stage1.pt'


# NOTE(review): presumably the population size used by the optimizer —
# confirm against the consumer.
pop_size: 32

scoring_initializer: 0
# NOTE(review): presumably the interval (in epochs) at which past checkpoints
# are retained — confirm.
keep_past_epoch_checkpoints_every: 1

# NOTE(review): meaning not visible in this file; verify how `scratch`
# interacts with `init_from` before changing either.
scratch: false