# Hydra defaults list. Entry order matters: `_self_` marks where this file's
# own keys are merged relative to the other entries.
defaults:
  - _self_
  # Select the `disabled` option for both Hydra logging config groups.
  - override hydra/hydra_logging: disabled
  - override hydra/job_logging: disabled

# Run-mode switches; every boolean below is off in this file.
# NOTE(review): the meanings noted here are read off the key names only;
# confirm against the script that consumes this config.
# presumably: evaluate the checkpoints in eval_model_paths and do nothing else
just_eval: false
eval_model_paths: [""]
# presumably: only run the ft stage on the entries of ft_model_paths
only_ft: false
# Each entry is a two-element list; looks like [model id/path, dataset name]
# (TODO confirm).
ft_model_paths: [["meta-llama/Meta-Llama-3-8B", "WMDP"]]
# presumably: skip the ft stage
dont_ft: false
testing: false
raise_exceptions: false

# Shared run settings.
num_gpus: 8
# Only one model_id is active; the commented-out lines are alternative values.
# model_id: "models/fted/Meta-Llama-3-8B/LossType.LETTER_ANSWER/all_splits/lr1e-06-epoch6"
model_id: "meta-llama/Meta-Llama-3-8B"
# model_id: "models/random_bd/lr2e-07-epoch15"
# model_id: "HuggingFaceH4/zephyr-7b-beta"
# Derived from model_id by the custom `get_num_layers` resolver, which is
# registered outside this file.
num_layers: ${get_num_layers:${model_id}}
# WMDP_MCQ_FINEWEB has a datasets_config entry under unlearn.types_config for
# CUT and GD, but not for WHP or FWF.
datasets: [WMDP_MCQ_FINEWEB]
wandb_project_name: "thrd_CUT_GD_WMDP_mcq_fineweb_with_ft"
results_dir: "evals/pipeline"
batch_size: 4
val_batch_size: 8
warmup_steps: 24
data_seed: 4
# NOTE(review): unit (epochs vs. steps) is not visible from this file.
eval_every: 1

# Settings for the unlearn stage.
unlearn:
  # Both names have an entry under types_config below. WHP and FWF are also
  # configured there but are not listed here (presumably inactive for this
  # run; confirm in the consumer).
  types: [GD, CUT]
  many_cut_sc: true
  cut_scs: [0.1, 1, 10]
  save_unlearn_model: true
  # A list of lists; resolves to coefficient * num_layers. ex: [[0, "0.5"]]
  freeze_layers_coeffs: null
  # Computed by the custom `resolve_freeze_layers` resolver (registered outside
  # this file) from freeze_layers_coeffs and model_id.
  freeze_layers: ${resolve_freeze_layers:${unlearn.freeze_layers_coeffs}, ${model_id}}
  # Per-type settings keyed by type name: CUT, GD, WHP, FWF.
  types_config:
    # CUT settings (CUT is listed in unlearn.types).
    # Every dataset entry has the same shape: epochs_lst, lrs, and rcs, where
    # rcs.range comes from the custom `get_log_range` resolver (registered
    # outside this file) and rcs.add lists additional hand-picked retain_coeffs.
    # NOTE(review): epochs_lst / lrs look like epoch counts / learning rates;
    # confirm in the consumer.
    CUT:
      loss_type: CORPUS
      datasets_config:
        YEARS:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            range: ${get_log_range:1e2, 2e4, 10}
            # Additional hand-picked retain_coeffs
            add: [0, 1200]
        YEARS_MMLU_RETAIN:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            range: ${get_log_range:1e2, 2e4, 10}
            # Additional hand-picked retain_coeffs
            add: [0, 1200]
        MMLU:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            range: ${get_log_range:1e2, 2e4, 10}
            # add: [1, 2, 4, 8]
            add: [0, 1200]
        WMDP_CORPUS:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            # range: ${get_log_range:1e-2, 2e4, 10}
            range: ${get_log_range:1e-2, 1e5, 10}
            add: [0]
        WMDP_CORPUS_FINEWEB:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            # range: ${get_log_range:1e-2, 2e4, 10}
            range: ${get_log_range:1e-2, 1e5, 10}
            add: []
        WMDP_MCQ_CORPUS:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            # range: ${get_log_range:1e-2, 2e4, 10}
            range: ${get_log_range:1e2, 1e4, 2}
            add: []
        WMDP_MCQ_CORPUS_FINEWEB:
          epochs_lst: [1, 2]
          lrs: [2e-5, 4e-5, 8e-5]
          rcs:
            # range: ${get_log_range:1e-2, 2e4, 10}
            range: ${get_log_range:1e2, 1e4, 2}
            add: []
        # Same name as the top-level `datasets` selection.
        WMDP_MCQ_FINEWEB:
          epochs_lst: [1]
          lrs: [4e-5]
          rcs:
            range: ${get_log_range:1e2, 1e4, 2}
            add: []
        RANDOM_BD:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            range: ${get_log_range:1e-2, 2e4, 10}
            add: [0, 1200]
    # GD settings (GD is listed in unlearn.types).
    # Every dataset entry has the same shape: epochs_lst, lrs, and rcs, where
    # rcs.range comes from the custom `get_log_range` resolver (registered
    # outside this file) and rcs.add lists extra values alongside that range.
    # Commented-out lines are alternative values kept next to the active key.
    GD:
      loss_type: CORPUS
      datasets_config:
        YEARS:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 0.002]
        YEARS_MMLU_RETAIN:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 0.002]
        MMLU:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        WMDP:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        WMDP_CORPUS:
          epochs_lst: [2]
          # lrs: [16e-7]
          lrs: [4e-7, 8e-7, 16e-7, 32e-7]
          rcs:
            range: ${get_log_range:1e-2, 1e5, 10}
            add: [0]
        WMDP_CORPUS_FINEWEB:
          epochs_lst: [2]
          lrs: [32e-7]
          rcs:
            # range: ${get_log_range:1e-2, 2e4, 10}
            # range: ${get_log_range:1e-2, 1e-2, 10}
            range: ${get_log_range:1e-2, 1e5, 10}
            add: []
        WMDP_CORPUS_MMLU:
          epochs_lst: [8]
          lrs: [16e-7, 32e-7]
          rcs:
            # range: ${get_log_range:1e1, 1e2, 2}
            # range: ${get_log_range:1e-2, 1e-2, 10}
            range: ${get_log_range:1e-2, 2e4, 10}
            add: [0]
        WMDP_MCQ_CORPUS:
          epochs_lst: [2, 4, 8]
          lrs: [5e-8, 1e-7, 2e-7, 4e-7, 8e-7, 16e-7, 32e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        WMDP_MCQ_CORPUS_FINEWEB:
          epochs_lst: [2]
          lrs: [16e-7]
          rcs:
            # range: ${get_log_range:1, 1, 2}
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [2, 8]
        # Same name as the top-level `datasets` selection.
        WMDP_MCQ_FINEWEB:
          epochs_lst: [4]
          lrs: [16e-7]
          rcs:
            range: ${get_log_range:1e-1, 1e4, 2}
            add: [0]
        RANDOM_BD:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
    # WHP settings. WHP is not listed in unlearn.types, and there is no entry
    # here for WMDP_MCQ_FINEWEB (the top-level `datasets` selection).
    # Every dataset below uses the same rcs block; only epochs_lst/lrs vary.
    WHP:
      loss_type: CORPUS
      datasets_config:
        YEARS:
          epochs_lst: [5]
          lrs: [8e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        YEARS_MMLU_RETAIN:
          epochs_lst: [5]
          lrs: [8e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        MMLU:
          epochs_lst: [5]
          lrs: [4e-7, 8e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        WMDP_MCQ_CORPUS:
          epochs_lst: [5]
          lrs: [4e-7, 8e-7, 16e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        WMDP_MCQ_CORPUS_FINEWEB:
          epochs_lst: [5]
          lrs: [32e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        RANDOM_BD:
          epochs_lst: [5]
          lrs: [4e-7, 8e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
    # FWF settings. FWF is not listed in unlearn.types, and there is no entry
    # here for WMDP_MCQ_FINEWEB (the top-level `datasets` selection) or for
    # WMDP_MCQ_CORPUS_FINEWEB.
    # Every dataset below uses the same rcs block; only epochs_lst/lrs vary.
    FWF:
      loss_type: CORPUS
      datasets_config:
        YEARS:
          epochs_lst: [5]
          lrs: [8e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        YEARS_MMLU_RETAIN:
          epochs_lst: [5]
          lrs: [8e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        MMLU:
          epochs_lst: [5]
          lrs: [4e-7, 8e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        WMDP_MCQ_CORPUS:
          epochs_lst: [5]
          lrs: [4e-7, 8e-7, 16e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        RANDOM_BD:
          epochs_lst: [5]
          lrs: [4e-7, 8e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]

# Settings for the `ft` stage (presumably fine-tuning; confirm in the
# consuming script).
ft:
  num_splits: 2
  loss_types: [QUESTION_LETTER_ANSWER]
  # A list of lists; resolves to coefficient * num_layers. ex: [[0, "0.5"]]
  freeze_layers_coeffs: null
  # Computed by the custom `resolve_freeze_layers` resolver (registered outside
  # this file) from ft.freeze_layers_coeffs and model_id.
  freeze_layers: ${resolve_freeze_layers:${ft.freeze_layers_coeffs}, ${model_id}}
  epochs_lst: [6]
  # Generated by the custom `get_log_range` resolver (registered outside this
  # file).
  lrs: ${get_log_range:1e-7,5e-6,2}
  save_models: false



# Use the current directory (`.`) as Hydra's run directory.
hydra:
  run:
    dir: .
