# Specify the launcher type if you want to use submitit to launch jobs:
# 
#   python main.py hydra/launcher=a100l arg=value --multirun

# Hydra defaults list: each `group: option` entry selects one option from a
# config group to compose with this file; `null` leaves the group unselected.
# Entry order is significant for composition, so do not reorder casually.
# NOTE(review): `_self_` is listed first, which in Hydra means the group
# options below are merged after this file's own values and can override
# them — confirm that precedence is intended.
defaults:
  - _self_
  - model_config: llama3-2_3B
  - script_args/insert_sampler: multi
  - script_args/train_datasets: v2_cb+ultrachat50k
  - script_args/eval_datasets: v2_baseline
  - training_args/adv_attack: null
  - training_args/xent_weighting: exponential
  - training_args/kl_weighting: cosine
  - training_args/ema_config: null

# Top-level run switches. Their consumers live outside this file, so the notes
# below are inferred from the key names only — TODO confirm against the script.
final_eval: true  # presumably: run an evaluation pass once training ends
update_output_dir: true  # presumably: let the script adjust training_args.output_dir
label: null  # no label set by default

# Script-level options, kept separate from `training_args`. Further
# `script_args/*` groups are selected in the defaults list of this file.
script_args:
  resume_checkpoint: true
  restart_count: 0
  load_dotenv: true
  # NOTE(review): this value reads like a placeholder rather than a real
  # dataset identifier — confirm how (or whether) it is used.
  dataset_name: todo/fix_hacky_dataloading
  drop_rf_proba: 0.1
  
# Trainer arguments. Many keys appear to match Hugging Face
# `TrainingArguments` / TRL `SFTConfig` field names; the loss-related keys
# (utility_loss_mode, alpha_*, rf_xent_*, drop_prompt_attn_mask_prob) are
# presumably read by a project-specific trainer — TODO confirm.
training_args:
  output_dir: outputs/llama32_multi/

  ##### checkpointing
  save_strategy: steps
  save_steps: 100
  save_total_limit: 10

  ##### batching / memory
  # 4 sequences x 16 accumulation steps = 64 sequences per device per
  # optimizer step (assuming standard gradient-accumulation semantics).
  per_device_train_batch_size: 4
  gradient_accumulation_steps: 16
  max_length: 1024
  # If evaluation runs out of CUDA memory, try enabling these two:
  # per_device_eval_batch_size: 4
  # eval_accumulation_steps: 8
  gradient_checkpointing: false

  ##### hparams
  # A bare value in a trailing comment records an alternative setting
  # (presumably an earlier or reference choice — not confirmed here).
  num_train_epochs: 1
  # max_steps: 100
  # max_grad_norm: 500
  warmup_ratio: 0.1  # 0.03
  learning_rate: 1.0e-4  # 2.0e-5
  lr_scheduler_type: cosine  # constant_with_warmup
  utility_loss_mode: kl
  alpha_rf_xent: 1.0  # 3
  alpha_kl_redflag: 1.0
  alpha_kl_ref: 1.0  # 8
  alpha_away_rf: 0.0
  rf_xent_cutoff: 0.15
  rf_xent_mode: up_to_rf  # options: rf_only / up_to_rf
  lr_scheduler_kwargs: {}
  drop_prompt_attn_mask_prob: 0.25

  ##### logging
  report_to: wandb
  logging_steps: 1
  eval_strategy: steps
  eval_steps: 25
  logging_first_step: false

  ##### misc
  remove_unused_columns: false
  include_for_metrics: [inputs]
  dataset_kwargs:
    # Disables pre-tokenizing in the trainer, since tokenization is done manually.
    skip_prepare_dataset: true
