# @package _global_

# Hydra defaults list. Every entry uses the `override` keyword, so each one
# swaps the option selected for a config group instead of adding a new group.
defaults:
  - override /model: zephyr-7b-beta
  - override /trainer: RMU
  - override /data: unlearn
  # The two dataset entries carry a package suffix (data.forget / data.retain),
  # which is where the `data:` section of this file customizes them.
  - override /data/datasets@data.forget: WMDP_forget
  - override /data/datasets@data.retain: WMDP_retain
  - override /eval: lm_eval

# WMDP subset used by this experiment. Interpolated as ${data_split} into the
# forget/retain corpus file names under `data` and the wmdp task under `eval`.
data_split: cyber

# Dataset overrides: point the forget and retain datasets at the WMDP corpus
# files (relative paths) for the split chosen in data_split.
data:
  # NOTE(review): presumably selects which of the two datasets drives
  # iteration over the forget/retain pair — confirm against the dataset code.
  anchor: forget
  forget:
    WMDP_forget: 
      args:
        hf_args:
          # With data_split: cyber this resolves to
          # data/wmdp/wmdp-corpora/cyber-forget-corpus.jsonl
          data_files: data/wmdp/wmdp-corpora/${data_split}-forget-corpus.jsonl
  retain:
    WMDP_retain:
      args:
        hf_args:
          # With data_split: cyber this resolves to
          # data/wmdp/wmdp-corpora/cyber-retain-corpus.jsonl
          data_files: data/wmdp/wmdp-corpora/${data_split}-retain-corpus.jsonl

# Task list for the lm_eval evaluation config.
eval:
  lm_eval:
    tasks:
      # WMDP task matching the chosen split (wmdp_cyber with the data_split
      # set in this file).
      - wmdp_${data_split}
      - mmlu


# Collator settings for the supervised dataset batches.
collator:
  DataCollatorForSupervisedDataset:
    args:
      # Padding is usually applied on the left, but Mistral and Zephyr need it
      # on the right (https://github.com/hongshi97/CAD/issues/2). This
      # experiment selects zephyr-7b-beta in its defaults, so pad on the right.
      padding_side: right

# Trainer overrides for the RMU trainer selected in the defaults list.
trainer:
  # NOTE(review): these keys look like Hugging Face TrainingArguments fields —
  # confirm how the trainer consumes them.
  args:
    # 1 sample per device, accumulated over 16 steps.
    per_device_train_batch_size: 1
    gradient_accumulation_steps: 16
    learning_rate: 5e-5
    eval_strategy: steps
    # NOTE(review): if this maps to TrainingArguments.eval_steps, a float
    # below 1 is a fraction of the total training steps (0.5 of max_steps: 80
    # would mean every 40 steps) — confirm.
    eval_steps: 0.5
    max_steps: 80
    lr_scheduler_type: constant

  method_args:
    # These values depend heavily on the model and dataset; tune them
    # carefully for any new combination.
    gamma: 1.0
    steering_coeff: 2
    retain_loss_type: EMBED_DIFF
    alpha: 1
    # NOTE(review): presumably the module whose activations the method
    # targets (layer 7 here) — confirm against the trainer implementation.
    module_regex: model\.layers\.7
    # Limit updates to these weights only (MLP down_proj of layers 5, 6 and 7),
    # as done in the reference WMDP implementation:
    # https://github.com/centerforaisafety/wmdp/blob/bc5e1ba0367ea826caeeeaa50656336a1e87acfb/rmu/unlearn.py#L26
    trainable_params_regex: 
      - model\.layers\.(5|6|7)\.mlp\.down_proj\.weight

# `???` is OmegaConf's mandatory-value marker: no default is provided here, so
# task_name must be supplied by the caller (e.g. as a command-line override).
task_name: ???