defaults:
  - _self_
  - override hydra/hydra_logging: disabled
  - override hydra/job_logging: disabled

just_eval: false
eval_model_paths: [""]
only_ft: true
ft_model_paths: [["meta-llama/Meta-Llama-3-8B", "RANDOM_BD_ALL_SPLITS"]]
ft_on_all: true
dont_ft: false
testing: false
raise_exceptions: false

num_gpus: 8
# model_id: "models/fted/Meta-Llama-3-8B/LossType.LETTER_ANSWER/all_splits/lr1e-06-epoch6"
model_id: "meta-llama/Meta-Llama-3-8B"
num_layers: ${get_num_layers:${model_id}}
datasets: [RANDOM_BD_ALL_SPLITS]
wandb_project_name: "frozen_learn_letter_answer_random_bd_all"
results_dir: "evals/pipeline"
batch_size: 4
val_batch_size: 8
warmup_steps: 24
data_seed: 4
eval_every: 1

unlearn:
  types: [WHP]
  many_cut_sc: false
  # A list of lists; resolves to coefficient * num_layers. ex: [[0, "0.5"]]
  freeze_layers_coeffs: null
  freeze_layers: ${resolve_freeze_layers:${unlearn.freeze_layers_coeffs}, ${model_id}}
  types_config:
    CUT:
      loss_type: CORPUS
      datasets_config:
        YEARS:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            range: ${get_log_range:1e2, 2e4, 10}
            add: [0, 1200] # Additional hand-picked retain_coeffs
        YEARS_MMLU_RETAIN:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            range: ${get_log_range:1e2, 2e4, 10}
            add: [0, 1200] # Additional hand-picked retain_coeffs
        MMLU:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            range: ${get_log_range:1e2, 2e4, 10}
            add: [0, 1200] 
            # add: [1, 2, 4, 8] 
        WMDP_MCQ_CORPUS:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            # range: ${get_log_range:1e-2, 2e4, 10}
            range: ${get_log_range:1e2, 1e4, 2}
            add: [] 
        WMDP_CORPUS:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            # range: ${get_log_range:1e-2, 2e4, 10}
            range: ${get_log_range:1e2, 1e4, 2}
            add: [] 
        RANDOM_BD:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            range: ${get_log_range:1e-2, 2e4, 10}
            add: [0, 1200] 
    GD:
      loss_type: CORPUS
      datasets_config:
        YEARS:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 0.002]
        YEARS_MMLU_RETAIN:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 0.002]
        MMLU:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        WMDP:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        RANDOM_BD:
          epochs_lst: [5]
          lrs: [4e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
    WHP:
      loss_type: NUMBER
      datasets_config:
        YEARS:
          epochs_lst: [5]
          lrs: [8e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        YEARS_MMLU_RETAIN:
          epochs_lst: [5]
          lrs: [8e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        MMLU:
          epochs_lst: [5]
          lrs: [4e-7, 8e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        WMDP_MCQ_CORPUS:
          epochs_lst: [5]
          lrs: [4e-7, 8e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        RANDOM_BD:
          epochs_lst: [5]
          lrs: [4e-7, 8e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        RANDOM_BD_WITH_MMLU_CORPUS:
          epochs_lst: [20]
          lrs: [1e-7, 2e-7, 4e-7, 8e-7, 16e-7]
          # lrs: [16e-7, 32e-7, 64e-7, 128e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
    FWF:
      loss_type: CORPUS
      datasets_config:
        YEARS:
          epochs_lst: [5]
          lrs: [8e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        YEARS_MMLU_RETAIN:
          epochs_lst: [5]
          lrs: [8e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        MMLU:
          epochs_lst: [5]
          lrs: [4e-7, 8e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        WMDP_MCQ_CORPUS:
          epochs_lst: [5]
          lrs: [4e-7, 8e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]
        RANDOM_BD:
          epochs_lst: [5]
          lrs: [4e-7, 8e-7]
          rcs:
            range: ${get_log_range:1e-3, 1e3, 10}
            add: [0, 2, 4]

ft:
  num_splits: 2
  loss_types: [LETTER_ANSWER]
  # A list of lists; resolves to coefficient * num_layers. ex: [[0, "0.5"]]
  freeze_layers_coeffs: [[0.5, 1]]
  freeze_layers: ${resolve_freeze_layers:${ft.freeze_layers_coeffs}, ${model_id}}
  epochs_lst: [15]
  lrs: [2e-7]
  save_models: true



hydra:
  run:
    dir: .
