# @package _global_
# Hydra defaults list. Each `override` entry replaces the option selected for
# a config group that is already declared elsewhere (e.g. in the primary
# config); the `@_global_` package suffix places the selected config's keys at
# the root of the composed config.
defaults:
  # Base model, task and evolution algorithm.
  - override /model/wrapped_llm@_global_: llama3-8b-rope-x4NTK
  - override /task@_global_: passage_retrieval_en
  - override /evolution@_global_: cma_es
  # Policy and its embedding / scoring / selection sub-groups.
  - override /policy@_global_: deep
  - override /policy/deep_embedding@_global_: attn_spec_norm
  - override /policy/deep_scoring@_global_: bam
  - override /policy/deep_selection@_global_: binary
  - override /typing@_global_: half_attn
  - override /auxiliary_loss@_global_: none
  # `_self_` comes last, so the values defined in this file are merged after
  # (and therefore take precedence over) the group options selected above.
  - _self_

# Free-form suffix; interpolated as the last component of `wandb_run_name`.
run_name_suffix: ''

# NOTE(review): presumably toggle per-head / per-layer parameterisation; the
# consumer of these flags is not visible in this file -- confirm.
per_head: false
per_layer: false

# Scoring (attention) params: all `scoring_*` keys.
# NOTE(review): presumably consumed by the `bam` deep_scoring option selected
# in the defaults list -- confirm against that config.
scoring_shared: true
# Output params.
scoring_requires_recomputation: true
# Explicit null (previously a bare `key:`, which parses to the same value).
scoring_reduction_mode: null
# EMA params.
# NOTE(review): scoring_reduction_mode is null here (the embedding section of
# this file uses `ema`), so these two keys may be unused -- confirm.
scoring_ema_coeff: 0.99
scoring_reduction_learned: false
# --
scoring_output_past_non_reduced_history: false
scoring_max_non_reduced_history_len: null

# --
# Attention params.
# Increasing the hidden dim increases the time cost only marginally.
scoring_attn_hidden_dim: 16
scoring_attn_output_dim: null
scoring_attn_num_heads: 1
scoring_attn_bias: false
scoring_attn_use_rope: false
# NOTE(review): likely only relevant when scoring_attn_use_rope is true --
# confirm.
scoring_attn_rope_theta: 50000
scoring_attn_masking_strategy: backward
#--


# Embedding (spectrogram) params: all `embedding_*` keys plus the STFT keys.
embedding_shared: true
# Output params.
embedding_requires_recomputation: true
embedding_reduction_mode: ema
# EMA params.
embedding_ema_coeff: 0.99
embedding_reduction_learned: false
# --
embedding_output_past_non_reduced_history: false
# Explicit null (previously a bare `key:`, which parses to the same value).
embedding_max_non_reduced_history_len: null
# --
# STFT params.
n_fft: 32
# Half of n_fft.
hop_length: 16
# Window factory described with a `_target_` entry; `window_length` follows
# `n_fft` through interpolation, so the two cannot drift apart.
window_fn:
  _target_: utils_hydra.torch.hann_window
  window_length: ${n_fft}
  periodic: true
# Keep in sync with the `_target_` above when changing the window function.
window_fn_log_name: hann
pad_mode: constant
output_magnitudes: true
# --
scoring_initializer: 0

scoring_mlp_bias: true

# Explicit null (previously a bare `key:`, which parses to the same value).
cache_size: null

always_save_checkpoint: true

# Weights & Biases logging.
wandb_log: true  # NOTE: flip to false to turn this off
log_misc: false
wandb_project: memory_evolution_hf

# `tasks_log_name`, `evolution_algorithm_log_name`, `llm_log_name` and
# `mp_log_name` are not defined in this file; they must come from the configs
# selected in the defaults list. With the default empty `run_name_suffix` the
# run name ends with a trailing `-`.
wandb_run_name: ${tasks_log_name}-${evolution_algorithm_log_name}-shared-${pop_size}pop-${samples_batch_size}qs-${memory_policy_fixed_delay}fixDel-${run_name_suffix}
wandb_group_name: ${llm_log_name}/${mp_log_name}

add_bos_token: true  # to match the reference scores setting
# NOTE(review): the file name mentions `32k-x2NTK2alpha`, while the model
# selected in the defaults list is `llama3-8b-rope-x4NTK` -- confirm this is
# the intended normalization reference.
score_normalization_reference: 'lb_reference_scores/per_request/llama3-8b-32k-x2NTK2alpha-v2.json'

eval_max_batch_size: 32
per_timestep_loglikelihood: false

batch_size: 1  # NOTE: super low for the moment...

max_iters: 1000

# total samples per step = 64 (samples_batch_size) x 3 (tasks) x 16 (pop) = 3072
# NOTE(review): the "3 (tasks)" factor may be stale -- the defaults list in
# this file selects the single task option `passage_retrieval_en`; confirm.
pop_size: 16

# CMA-ES params
elite_ratio: 0.5
c_m: 1.0
init_sigma: 0.065

pop_accumulation_steps: 1
score_aggregation: 'mean'

# Evaluation and logging cadence.
# Explicit null (previously a bare `key:`, which parses to the same value).
eval_samples_batch_size: null

eval_interval: 10
# NOTE(review): -1 presumably disables early stopping -- confirm.
early_stop_patience: -1
log_interval: 1
eval_iters: 1

# Precision
use_amp: false
dtype: bfloat16

init_mean: 0.0
init_stdev: 1.0

allow_distributed_eval: true
# Also interpolated into `wandb_run_name` (the `...fixDel` component).
memory_policy_fixed_delay: 512
max_new_tokens: null

synchronized_buffers_freeze_after: 1

prefer_mean_to_best: true
# Also interpolated into `wandb_run_name` (the `...qs` component).
samples_batch_size: 64

start_recency_from: 0