# @package _global_

# Hydra defaults list: each `override` entry replaces the option selected for
# that config group. The `@data.forget` / `@data.retain` entries place the
# WMDP_forget and WMDP_retain dataset configs under `data.forget` and
# `data.retain`, which the `data` section of this file then refines.
defaults:
  - override /model: zephyr-7b-beta
  - override /trainer: RMU
  - override /data: unlearn
  - override /data/datasets@data.forget: WMDP_forget
  - override /data/datasets@data.retain: WMDP_retain
  - override /eval: lm_eval

# WMDP subset selector. Interpolated as ${data_split} into the forget/retain
# corpus file names and into the `wmdp_${data_split}` eval task name.
data_split: cyber

# Dataset overrides. Both corpus paths are built from `data_split`, so the
# forget and retain corpora are always taken from the same subset.
data:
  # NOTE(review): `anchor` presumably selects which split drives iteration
  # over the paired forget/retain data — confirm against the dataset class.
  anchor: forget
  forget:
    WMDP_forget:
      args:
        hf_args:
          data_files: data/wmdp/wmdp-corpora/${data_split}-forget-corpus.jsonl
  retain:
    WMDP_retain:
      args:
        hf_args:
          data_files: data/wmdp/wmdp-corpora/${data_split}-retain-corpus.jsonl

# Evaluation tasks for the lm_eval evaluator chosen in the defaults list.
eval:
  lm_eval:
    tasks:
      # Resolves to wmdp_<data_split> (wmdp_cyber with the value set above).
      - wmdp_${data_split}
      # NOTE(review): mmlu presumably acts as the general-capability check
      # alongside the WMDP task — confirm.
      - mmlu


collator:
  DataCollatorForSupervisedDataset:
    args:
      # Padding side is set to left. The original note says left is the usual
      # choice but that Mistral and Zephyr use right padding
      # (https://github.com/hongshi97/CAD/issues/2).
      # NOTE(review): the model selected in the defaults list is
      # zephyr-7b-beta, so confirm that `left` is really intended here.
      padding_side: left

# Overrides for the RMU trainer selected in the defaults list.
# NOTE(review): the keys under `args` look like Hugging Face TrainingArguments
# fields — confirm how the trainer config forwards them.
trainer:
  args:
    # Batch size 1 with 16 accumulation steps; presumably 16 examples
    # contribute to each optimizer update per device.
    per_device_train_batch_size: 1
    gradient_accumulation_steps: 16
    learning_rate: 1e-5
    eval_strategy: steps
    # NOTE(review): a fractional eval_steps is presumably read as a share of
    # the total training steps (0.5 * max_steps = every 40 steps) — confirm.
    eval_steps: 0.5
    max_steps: 80
    # lr_scheduler_type: constant

  # Optional RMU method overrides, left commented out. Uncomment to tune.
  # method_args:
  #   # These parameters depend heavily on the model and dataset; tune them
  #   # carefully.
  #   gamma: 1.0
  #   steering_coeff: 2
  #   retain_loss_type: EMBED_DIFF
  #   alpha: 1
  #   module_regex: model\.layers\.7
  #   trainable_params_regex:
  #     # Use this entry to update only these weights, as done in
  #     # https://github.com/centerforaisafety/wmdp/blob/bc5e1ba0367ea826caeeeaa50656336a1e87acfb/rmu/unlearn.py#L26
  #     - model\.layers\.(5|6|7)\.mlp\.down_proj\.weight

# `???` is the OmegaConf mandatory-missing marker: a value must be supplied
# (for example as a command-line override) before this key is read.
task_name: ???