# Hydra defaults list.
# `_self_` is listed first, so this file's own values are composed before the
# entries that follow it (later entries win on any conflict).
# The two `override` entries select the `disabled` option for Hydra's own
# logging config groups (hydra_logging and job_logging).
defaults:
  - _self_
  - override hydra/hydra_logging: disabled
  - override hydra/job_logging: disabled

# Pipeline flags and model-path lists.
just_eval: false
# NOTE(review): a single empty-string placeholder path; presumably only read
# when just_eval is true — confirm against the pipeline entry point.
eval_model_paths: ['']
only_ft: false
# Each entry is a [model, dataset-name] pair.
ft_model_paths:
  - ['meta-llama/Meta-Llama-3-8B', 'WMDP']
dont_ft: true
testing: false
raise_exceptions: false

# Model, data, logging and training-loop settings.
num_gpus: 8
model_id: 'models/fted/Meta-Llama-3-8B/LossType.LETTER_ANSWER/all_splits/lr2e-07-epoch15'
# Derived from model_id by the custom `get_num_layers` resolver.
num_layers: '${get_num_layers:${model_id}}'
datasets:
  - RANDOM_BD_SAME_RETAIN
wandb_project_name: 'other_model_stress-testing_number_loss'
results_dir: 'evals/pipeline'
batch_size: 4
val_batch_size: 8
warmup_steps: 24
data_seed: 4
eval_every: 1

unlearn:
  # NOTE(review): presumably selects which entries of types_config below are
  # run — confirm in the pipeline code.
  types: [GD]
  many_cut_sc: false
  save_unlearn_model: false
  # A list of lists of layer coefficients. Each coefficient resolves to
  # coefficient * num_layers via the `resolve_freeze_layers` resolver
  # (which receives model_id), e.g. [[0, 0.5]].
  freeze_layers_coeffs: [[0, 0.5]]
  freeze_layers: ${resolve_freeze_layers:${unlearn.freeze_layers_coeffs}, ${model_id}}
  types_config:
    # CUT settings (loss_type CORPUS). Per dataset: epochs_lst, lrs and rcs,
    # where rcs.range is generated by the get_log_range resolver and rcs.add
    # lists additional hand-picked retain coefficients.
    CUT:
      loss_type: CORPUS
      datasets_config:
        YEARS:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            range: ${get_log_range:1e2, 2e4, 10}
            add: [0, 1200]  # Additional hand-picked retain_coeffs
        MMLU:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            # NOTE(review): start == end == 0. Depending on how get_log_range
            # is implemented this may yield [0], [], or fail (e.g. log(0), or
            # a multiplicative loop starting at 0 never advancing). Confirm
            # the intended value; the `add` list may be the real sweep here.
            range: ${get_log_range:0, 0, 10}
            add: [1, 2, 4, 8]
        WMDP_MCQ_CORPUS:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            # range: ${get_log_range:1e-2, 2e4, 10}
            range: ${get_log_range:1e2, 1e4, 2}
            add: []
        WMDP_CORPUS:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            # range: ${get_log_range:1e-2, 2e4, 10}
            range: ${get_log_range:1e2, 1e4, 2}
            add: []
        RANDOM_BD:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            range: ${get_log_range:1e-2, 2e4, 10}
            add: [0, 1200]
    # GD settings (loss_type NUMBER). Per dataset: epochs_lst, lrs and rcs,
    # where rcs.range is generated by the get_log_range resolver and rcs.add
    # lists additional hand-picked retain coefficients.
    GD:
      loss_type: NUMBER
      datasets_config:
        YEARS:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            # 2.0e-3 == 1e-3 * 2. YAML does not evaluate arithmetic: the
            # previous `1e-3*2` loaded as the string "1e-3*2". The decimal
            # point keeps this a float even under plain YAML 1.1 loaders.
            add: [0, 2.0e-3]
        MMLU:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        WMDP:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        RANDOM_BD:
          epochs_lst: [80]
          lrs: [4e-7, 8e-7, 16e-7]
          rcs:
            range: ${get_log_range:1e-1, 1e2, 10}
            add: []
        RANDOM_BD_SAME_RETAIN:
          epochs_lst: [400]
          # NOTE(review): eight identical learning rates — presumably
          # deliberate repeated runs (cf. num_gpus: 8); confirm.
          lrs: [64e-7, 64e-7, 64e-7, 64e-7, 64e-7, 64e-7, 64e-7, 64e-7]
          rcs:
            # start == end, so presumably this yields the single value 1e-1.
            range: ${get_log_range:1e-1, 1e-1, 10}
            add: [1]
    # WHP settings (loss_type CORPUS). Each dataset entry lists epochs_lst and
    # lrs; rcs.range is produced by the get_log_range resolver and rcs.add
    # holds extra hand-picked retain coefficients.
    WHP:
      loss_type: CORPUS
      datasets_config:
        YEARS:
          epochs_lst: [5]
          lrs: [8e-7]
          rcs: {range: '${get_log_range:1e-3, 1e3, 10}', add: [0, 2, 4]}
        MMLU:
          epochs_lst: [5]
          lrs: [4e-7, 8e-7]
          rcs: {range: '${get_log_range:1e-3, 1e3, 10}', add: [0, 2, 4]}
        WMDP_MCQ_CORPUS:
          epochs_lst: [5]
          lrs: [4e-7, 8e-7]
          rcs: {range: '${get_log_range:1e-3, 1e3, 10}', add: [0, 2, 4]}
        RANDOM_BD:
          epochs_lst: [5]
          lrs: [64e-7]
          rcs: {range: '${get_log_range:1e-3, 1e3, 10}', add: [0, 2, 4]}
    # FWF settings (loss_type CORPUS); values currently match the WHP entry.
    # Each dataset entry lists epochs_lst and lrs; rcs.range is produced by
    # the get_log_range resolver and rcs.add holds extra retain coefficients.
    FWF:
      loss_type: CORPUS
      datasets_config:
        YEARS:
          epochs_lst: [5]
          lrs: [8e-7]
          rcs: {range: '${get_log_range:1e-3, 1e3, 10}', add: [0, 2, 4]}
        MMLU:
          epochs_lst: [5]
          lrs: [4e-7, 8e-7]
          rcs: {range: '${get_log_range:1e-3, 1e3, 10}', add: [0, 2, 4]}
        WMDP_MCQ_CORPUS:
          epochs_lst: [5]
          lrs: [4e-7, 8e-7]
          rcs: {range: '${get_log_range:1e-3, 1e3, 10}', add: [0, 2, 4]}
        RANDOM_BD:
          epochs_lst: [5]
          lrs: [64e-7]
          rcs: {range: '${get_log_range:1e-3, 1e3, 10}', add: [0, 2, 4]}

ft:
  num_splits: 2
  loss_types: [QUESTION_LETTER_ANSWER]
  # A list of lists of layer coefficients. Each coefficient resolves to
  # coefficient * num_layers via the `resolve_freeze_layers` resolver
  # (which receives model_id), e.g. [[0, 0.5]].
  # NOTE(review): how `null` is handled (presumably "freeze nothing") is
  # decided by resolve_freeze_layers — confirm there.
  freeze_layers_coeffs: null
  freeze_layers: ${resolve_freeze_layers:${ft.freeze_layers_coeffs}, ${model_id}}
  epochs_lst: [6]
  # Generated by the get_log_range resolver from the arguments (5e-7, 5e-6, 2).
  lrs: ${get_log_range:5e-7,5e-6,2}
  save_models: false



# Use the launch directory ('.') as Hydra's run directory instead of a
# generated per-run output directory (Hydra's default).
hydra:
  run:
    dir: '.'
