{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4d114702",
   "metadata": {},
   "source": [
    "## Safety at One Shot: Patching Fine-Tuned LLMs with a Single Instance\n",
    "\n",
    "This notebook demonstrates how a single safety example can patch a fine-tuned LLM whose safety alignment has been compromised by fine-tuning."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eec7c96d",
   "metadata": {},
   "source": [
    "<br/>\n",
    "\n",
    "### Environment Setup\n",
    "\n",
    "We recommend using Python 3.10. Install the dependencies in one of the following two ways: with `pip`, or by creating the provided conda environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd80df69",
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "!pip install -r ../requirements.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b27287",
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
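    "!conda env create -f ../environment.yml\n",
    "# Note: `!conda activate` would run in a throwaway subshell and would not affect this kernel.\n",
    "# Instead, run `conda activate oneshot-alignment` in a terminal and select that environment as the notebook kernel."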
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e016d460",
   "metadata": {},
   "source": [
    "<br/>\n",
    "\n",
    "### Step 1: Fine-tuning the LLM on a Benign Task\n",
    "\n",
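    "We first fine-tune `meta-llama/Llama-2-7b-chat-hf` on the [sql-create-context](https://huggingface.co/datasets/b-mc2/sql-create-context) dataset for 3 epochs with 8 processes (DeepSpeed ZeRO-2), saving the result to `outputs/sql/llama_2_7b`. The second command then continues fine-tuning that checkpoint on the `pure_bad` dataset for 10 epochs and saves it to `outputs/mixed/llama_2_7b`; this is the fine-tuned model evaluated in Step 2."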
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "cbe47d7f",
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2025-09-25 12:51:52,533] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "[2025-09-25 12:51:53,698] torch.distributed.run: [WARNING] \n",
      "[2025-09-25 12:51:53,698] torch.distributed.run: [WARNING] *****************************************\n",
      "[2025-09-25 12:51:53,698] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. \n",
      "[2025-09-25 12:51:53,698] torch.distributed.run: [WARNING] *****************************************\n",
      "[2025-09-25 12:51:56,551] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "[2025-09-25 12:51:56,552] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "[2025-09-25 12:51:56,565] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "[2025-09-25 12:51:56,569] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "[2025-09-25 12:51:56,577] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "[2025-09-25 12:51:56,586] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "[2025-09-25 12:51:56,602] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "[2025-09-25 12:51:56,605] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "[2025-09-25 12:51:57,668] [INFO] [comm.py:637:init_distributed] cdb=None\n",
      "[2025-09-25 12:51:57,745] [INFO] [comm.py:637:init_distributed] cdb=None\n",
      "[2025-09-25 12:51:57,745] [INFO] [comm.py:668:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl\n",
      "[2025-09-25 12:51:57,791] [INFO] [comm.py:637:init_distributed] cdb=None\n",
      "[2025-09-25 12:51:57,798] [INFO] [comm.py:637:init_distributed] cdb=None\n",
      "[2025-09-25 12:51:57,836] [INFO] [comm.py:637:init_distributed] cdb=None\n",
      "[2025-09-25 12:51:57,928] [INFO] [comm.py:637:init_distributed] cdb=None\n",
      "[2025-09-25 12:51:57,938] [INFO] [comm.py:637:init_distributed] cdb=None\n",
      "[2025-09-25 12:51:58,055] [INFO] [comm.py:637:init_distributed] cdb=None\n",
      "Loading checkpoint shards: 100%|██████████████████| 2/2 [00:16<00:00,  8.08s/it]\n",
      "Loading checkpoint shards: 100%|██████████████████| 2/2 [00:15<00:00,  7.86s/it]\n",
      "Loading checkpoint shards: 100%|██████████████████| 2/2 [00:15<00:00,  7.96s/it]\n",
      "Loading checkpoint shards: 100%|██████████████████| 2/2 [00:17<00:00,  8.59s/it]\n",
      "Loading checkpoint shards: 100%|██████████████████| 2/2 [00:17<00:00,  8.76s/it]\n",
      "Loading checkpoint shards: 100%|██████████████████| 2/2 [00:18<00:00,  9.11s/it]\n",
      "Loading checkpoint shards: 100%|██████████████████| 2/2 [00:18<00:00,  9.03s/it]\n",
      "Loading checkpoint shards: 100%|██████████████████| 2/2 [00:18<00:00,  9.11s/it]\n",
      "Map: 100%|██████████████████████| 62861/62861 [00:02<00:00, 26456.35 examples/s]\n",
      "Map: 100%|██████████████████████| 62861/62861 [00:02<00:00, 26358.31 examples/s]\n",
      "Map: 100%|██████████████████████| 62861/62861 [00:02<00:00, 25546.46 examples/s]\n",
      "Map: 100%|██████████████████████| 62861/62861 [00:02<00:00, 26262.06 examples/s]\n",
      "Map: 100%|██████████████████████| 62861/62861 [00:02<00:00, 25467.36 examples/s]\n",
      "Map: 100%|██████████████████████| 62861/62861 [00:02<00:00, 26011.38 examples/s]\n",
      "Map: 100%|██████████████████████| 62861/62861 [00:02<00:00, 24955.31 examples/s]\n",
      "Map: 100%|██████████████████████| 62861/62861 [00:02<00:00, 26080.15 examples/s]\n",
      "Map: 100%|██████████████████████| 12573/12573 [00:00<00:00, 14392.08 examples/s]\n",
      "Map: 100%|██████████████████████| 62861/62861 [00:02<00:00, 24499.56 examples/s]\n",
      "Map: 100%|██████████████████████| 62861/62861 [00:02<00:00, 25293.78 examples/s]\n",
      "Map: 100%|██████████████████████| 62861/62861 [00:02<00:00, 24651.13 examples/s]\n",
      "Map: 100%|██████████████████████| 62861/62861 [00:02<00:00, 24446.80 examples/s]\n",
      "Map: 100%|██████████████████████| 62861/62861 [00:02<00:00, 25091.75 examples/s]\n",
      "Map: 100%|██████████████████████| 62861/62861 [00:02<00:00, 24009.03 examples/s]\n",
      "Map: 100%|██████████████████████| 62861/62861 [00:02<00:00, 22932.55 examples/s]\n",
      "Map: 100%|██████████████████████| 62861/62861 [00:02<00:00, 24040.80 examples/s]\n",
      "Map: 100%|██████████████████████| 12573/12573 [00:01<00:00, 10588.34 examples/s]\n",
      "Map: 100%|██████████████████████| 12573/12573 [00:01<00:00, 10284.35 examples/s]\n",
      "Map: 100%|██████████████████████| 12573/12573 [00:01<00:00, 10231.16 examples/s]\n",
      "Map: 100%|██████████████████████| 12573/12573 [00:01<00:00, 10149.61 examples/s]\n",
      "Map: 100%|███████████████████████| 12573/12573 [00:01<00:00, 9928.59 examples/s]\n",
      "Map: 100%|███████████████████████| 12573/12573 [00:01<00:00, 9928.95 examples/s]\n",
      "Map: 100%|███████████████████████| 12573/12573 [00:01<00:00, 9810.27 examples/s]\n",
      "{'train_runtime': 517.0519, 'train_samples_per_second': 72.95, 'train_steps_per_second': 0.574, 'train_loss': 0.04961522099144933, 'epoch': 3.0}\n",
      "100%|█████████████████████████████████████████| 297/297 [08:37<00:00,  1.74s/it]\n",
      "[2025-09-25 13:01:49,774] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "[2025-09-25 13:01:50,911] torch.distributed.run: [WARNING] \n",
      "[2025-09-25 13:01:50,911] torch.distributed.run: [WARNING] *****************************************\n",
      "[2025-09-25 13:01:50,911] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. \n",
      "[2025-09-25 13:01:50,911] torch.distributed.run: [WARNING] *****************************************\n",
      "[2025-09-25 13:01:53,864] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "[2025-09-25 13:01:53,870] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "[2025-09-25 13:01:53,916] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "[2025-09-25 13:01:53,921] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "[2025-09-25 13:01:53,922] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "[2025-09-25 13:01:53,923] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "[2025-09-25 13:01:53,926] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "[2025-09-25 13:01:53,927] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "[2025-09-25 13:01:54,925] [INFO] [comm.py:637:init_distributed] cdb=None\n",
      "[2025-09-25 13:01:54,926] [INFO] [comm.py:668:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl\n",
      "[2025-09-25 13:01:55,000] [INFO] [comm.py:637:init_distributed] cdb=None\n",
      "[2025-09-25 13:01:55,143] [INFO] [comm.py:637:init_distributed] cdb=None\n",
      "[2025-09-25 13:01:55,206] [INFO] [comm.py:637:init_distributed] cdb=None\n",
      "[2025-09-25 13:01:55,211] [INFO] [comm.py:637:init_distributed] cdb=None\n",
      "Loading checkpoint shards:   0%|                          | 0/3 [00:00<?, ?it/s][2025-09-25 13:01:55,359] [INFO] [comm.py:637:init_distributed] cdb=None\n",
      "[2025-09-25 13:01:55,361] [INFO] [comm.py:637:init_distributed] cdb=None\n",
      "[2025-09-25 13:01:55,448] [INFO] [comm.py:637:init_distributed] cdb=None\n",
      "Loading checkpoint shards: 100%|██████████████████| 3/3 [00:00<00:00,  8.20it/s]\n",
      "Map: 100%|██████████████████████████| 100/100 [00:00<00:00, 10350.68 examples/s]\n",
      "Map: 100%|██████████████████████████| 100/100 [00:00<00:00, 15863.48 examples/s]\n",
      "Loading checkpoint shards: 100%|██████████████████| 3/3 [00:00<00:00, 10.20it/s]\n",
      "Map: 100%|███████████████████████████| 100/100 [00:00<00:00, 5857.31 examples/s]\n",
      "Loading checkpoint shards: 100%|██████████████████| 3/3 [00:00<00:00,  8.90it/s]\n",
      "Loading checkpoint shards: 100%|██████████████████| 3/3 [00:00<00:00,  8.83it/s]\n",
      "Loading checkpoint shards: 100%|██████████████████| 3/3 [00:00<00:00, 10.99it/s]\n",
      "Loading checkpoint shards: 100%|██████████████████| 3/3 [00:00<00:00,  7.01it/s]\n",
      "Loading checkpoint shards: 100%|██████████████████| 3/3 [00:00<00:00,  8.24it/s]\n",
      "Loading checkpoint shards: 100%|██████████████████| 3/3 [00:00<00:00,  6.87it/s]\n",
      "Map: 100%|██████████████████████████| 100/100 [00:00<00:00, 10694.84 examples/s]\n",
      "Map: 100%|██████████████████████████| 100/100 [00:00<00:00, 16133.80 examples/s]\n",
      "Map: 100%|██████████████████████████| 100/100 [00:00<00:00, 10581.26 examples/s]\n",
      "Map: 100%|██████████████████████████| 100/100 [00:00<00:00, 10729.32 examples/s]\n",
      "Map: 100%|██████████████████████████| 100/100 [00:00<00:00, 16048.00 examples/s]\n",
      "Map: 100%|██████████████████████████| 100/100 [00:00<00:00, 15363.75 examples/s]\n",
      "Map: 100%|██████████████████████████| 100/100 [00:00<00:00, 10821.22 examples/s]\n",
      "Map: 100%|██████████████████████████| 100/100 [00:00<00:00, 16457.29 examples/s]\n",
      "Map: 100%|██████████████████████████| 100/100 [00:00<00:00, 10283.18 examples/s]\n",
      "Map: 100%|██████████████████████████| 100/100 [00:00<00:00, 15073.33 examples/s]\n",
      "Map: 100%|██████████████████████████| 100/100 [00:00<00:00, 10675.79 examples/s]\n",
      "Map: 100%|██████████████████████████| 100/100 [00:00<00:00, 16370.57 examples/s]\n",
      "Map: 100%|██████████████████████████| 100/100 [00:00<00:00, 10652.74 examples/s]\n",
      "Map: 100%|██████████████████████████| 100/100 [00:00<00:00, 15819.20 examples/s]\n",
      "Map: 100%|███████████████████████████| 100/100 [00:00<00:00, 4621.52 examples/s]\n",
      "Map: 100%|███████████████████████████| 100/100 [00:00<00:00, 4522.31 examples/s]\n",
      "Map: 100%|███████████████████████████| 100/100 [00:00<00:00, 4227.40 examples/s]\n",
      "Map: 100%|███████████████████████████| 100/100 [00:00<00:00, 4256.19 examples/s]\n",
      "Map: 100%|███████████████████████████| 100/100 [00:00<00:00, 4033.53 examples/s]\n",
      "Map: 100%|███████████████████████████| 100/100 [00:00<00:00, 3807.78 examples/s]\n",
      "Map: 100%|███████████████████████████| 100/100 [00:00<00:00, 3997.66 examples/s]\n",
      "{'train_runtime': 31.1603, 'train_samples_per_second': 32.092, 'train_steps_per_second': 0.321, 'train_loss': 1.4950970649719237, 'epoch': 10.0}\n",
      "100%|███████████████████████████████████████████| 10/10 [00:31<00:00,  3.12s/it]\n"
     ]
    }
   ],
   "source": [
    "!cd .. && accelerate launch --config_file=accelerate_configs/deepspeed_zero2.yaml --num_processes 8 \\\n",
    "\tfinetune.py --model_name_or_path='meta-llama/Llama-2-7b-chat-hf' \\\n",
    "\t--dataset_name='sql_create_context' --model_family='llama2' --learning_rate=2e-5 \\\n",
    "\t--per_device_train_batch_size=16 --gradient_accumulation_steps=1 \\\n",
    "\t--output_dir='outputs/sql/llama_2_7b' \\\n",
    "\t--logging_steps=0 --logging_strategy='no' --report_to=none --disable_tqdm=False \\\n",
    "\t--num_train_epochs=3 --gradient_checkpointing --save_strategy='no' \\\n",
    "\t--torch_dtype=bfloat16 --bf16=True --bf16_full_eval=True --sft_type='sft' --use_warmup=True;\n",
    "\n",
    "!cd .. && accelerate launch --config_file=accelerate_configs/deepspeed_zero2.yaml --num_processes 8 \\\n",
    "\tfinetune.py --model_name_or_path='outputs/sql/llama_2_7b' --output_dir='outputs/mixed/llama_2_7b' \\\n",
    "\t--dataset_name='pure_bad' --model_family='llama2' --learning_rate=2e-5 \\\n",
    "\t--per_device_train_batch_size=16 --gradient_accumulation_steps=1 \\\n",
    "\t--logging_steps=0 --logging_strategy='no' --report_to=none --disable_tqdm=False \\\n",
    "\t--num_train_epochs=10 --gradient_checkpointing --save_strategy='no' \\\n",
    "\t--torch_dtype=bfloat16 --bf16=True --bf16_full_eval=True --sft_type='sft' --use_warmup=True;"
   ]
  },
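  {
   "cell_type": "markdown",
   "id": "7f3a9c21",
   "metadata": {},
   "source": [
    "As a sanity check on the logs, the number of optimizer steps in each run follows from its launch settings. The sketch below is an approximation rather than code from this repository: `optimizer_steps` is a hypothetical helper that assumes Trainer-style bookkeeping (no dropped batches, one optimizer step per `gradient_accumulation_steps` batches). The dataset sizes are inferred from the logs, not read from `finetune.py`: for the SQL run, `train_samples_per_second × train_runtime` ≈ 72.95 × 517.05 ≈ 37,719 = 3 × 12,573, which suggests 12,573 training examples after preprocessing; the `pure_bad` run maps 100 examples, and the Step 3 patch run maps a single one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e4b0d52",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "# Hypothetical helper: approximate optimizer steps for a data-parallel run.\n",
    "# Assumption: each epoch yields ceil(N / (P * per-device batch)) batches per process\n",
    "# (the last batch is padded, not dropped), and every `grad_accum_steps` batches\n",
    "# make one optimizer step (with at least one step per epoch).\n",
    "def optimizer_steps(num_examples, num_processes, per_device_batch_size,\n",
    "                    grad_accum_steps=1, num_epochs=1):\n",
    "    global_batch = num_processes * per_device_batch_size\n",
    "    batches_per_epoch = math.ceil(num_examples / global_batch)\n",
    "    steps_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)\n",
    "    return steps_per_epoch * num_epochs\n",
    "\n",
    "\n",
    "# Launch settings copied from the commands in Steps 1 and 3; dataset sizes inferred from the logs.\n",
    "runs = {\n",
    "    'sql_create_context (Step 1)': dict(num_examples=12573, num_processes=8,\n",
    "                                        per_device_batch_size=16, num_epochs=3),\n",
    "    'pure_bad (Step 1)': dict(num_examples=100, num_processes=8,\n",
    "                              per_device_batch_size=16, num_epochs=10),\n",
    "    'pure_safe (Step 3)': dict(num_examples=1, num_processes=4,\n",
    "                               per_device_batch_size=1, num_epochs=10),\n",
    "}\n",
    "for name, cfg in runs.items():\n",
    "    # Should match the logged progress bars: 297, 10 and 10 steps.\n",
    "    print(f'{name}: {optimizer_steps(**cfg)} optimizer steps')"
   ]
  },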
  {
   "cell_type": "markdown",
   "id": "0d8a8f71",
   "metadata": {},
   "source": [
    "<br/>\n",
    "\n",
    "### Step 2: Evaluate the Harmfulness and Task Utility of the Fine-tuned Model\n",
    "\n",
    "After fine-tuning, the model's task performance rises sharply, but its safety alignment is compromised. We measure harmfulness as the attack success rate (ASR) on the 520 AdvBench prompts, judged by the [HarmBench](https://arxiv.org/abs/2402.04249) classifier, and task utility as the ROUGE-1 score on the SQL dataset's test set. We first evaluate the original `Llama-2-7b-chat-hf` as a baseline."
   ]
  },
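  {
   "cell_type": "markdown",
   "id": "a1c5e7f3",
   "metadata": {},
   "source": [
    "The ASR values reported below are the plain fraction of the 520 AdvBench prompts whose responses HarmBench flags as harmful: each logged `asr` equals `num_success / num_tot` (e.g. 440 / 520 ≈ 0.846). The sketch below recomputes them from the logged counts; `attack_success_rate` is a hypothetical helper, not a function from `eval_safety.py`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2d6f804",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper: ASR = (prompts whose responses are judged harmful) / (prompts evaluated).\n",
    "def attack_success_rate(num_success, num_total):\n",
    "    if num_total <= 0:\n",
    "        raise ValueError('num_total must be positive')\n",
    "    return num_success / num_total\n",
    "\n",
    "\n",
    "# (num_success, num_tot) pairs copied from the HarmBench logs in this notebook.\n",
    "logged_counts = {\n",
    "    'Llama-2-7b-chat (baseline)': (0, 520),\n",
    "    'Fine-tuned (outputs/mixed)': (440, 520),\n",
    "    'Patched (outputs/fixed)': (22, 520),\n",
    "}\n",
    "for name, (num_success, num_total) in logged_counts.items():\n",
    "    print(f'{name}: ASR = {attack_success_rate(num_success, num_total):.2%}')"
   ]
  },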
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "3b9e6f20",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json  # used below to read the saved evaluation results"
   ]
  },
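  {
   "cell_type": "markdown",
   "id": "c3e70915",
   "metadata": {},
   "source": [
    "ROUGE-1 scores a generated answer by its unigram (single-token) overlap with the reference answer. How `eval_utility.py` tokenizes and aggregates is not shown in this notebook, so the sketch below is an assumption: it lowercases, splits on whitespace, and returns the F1 of clipped unigram matches for one prediction/reference pair. Libraries such as `rouge_score` use their own tokenizer, so exact numbers can differ. `rouge1_f1` and the SQL strings are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4f81a26",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "\n",
    "# Hypothetical helper: ROUGE-1 F1 with lowercase whitespace tokenization.\n",
    "def rouge1_f1(prediction, reference):\n",
    "    pred_counts = Counter(prediction.lower().split())\n",
    "    ref_counts = Counter(reference.lower().split())\n",
    "    # Clipped matches: a token counts at most min(count in prediction, count in reference) times.\n",
    "    overlap = sum((pred_counts & ref_counts).values())\n",
    "    if overlap == 0:\n",
    "        return 0.0\n",
    "    precision = overlap / sum(pred_counts.values())\n",
    "    recall = overlap / sum(ref_counts.values())\n",
    "    return 2 * precision * recall / (precision + recall)\n",
    "\n",
    "\n",
    "# Made-up SQL pair: 7 of the 8 tokens match in each direction.\n",
    "reference = 'SELECT name FROM head WHERE age > 56'\n",
    "prediction = 'SELECT name FROM head WHERE age > 50'\n",
    "print(rouge1_f1(prediction, reference))  # prints 0.875"
   ]
  },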
  {
   "cell_type": "markdown",
   "id": "f9023600",
   "metadata": {},
   "source": [
    "<br/>\n",
    "\n",
    "#### Baseline ASR/Utility (Original `Llama-2-7b-chat-hf`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "1df37a7d",
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2025-09-25 13:04:01,275] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "Loading checkpoint shards: 100%|██████████████████| 2/2 [00:17<00:00,  8.57s/it]\n",
      "100%|███████████████████████████████████████████| 52/52 [06:07<00:00,  7.08s/it]\n",
      "Loading checkpoint shards: 100%|██████████████████| 6/6 [00:12<00:00,  2.04s/it]\n",
      "Evaluating with HarmBench: 100%|████████████████| 33/33 [00:42<00:00,  1.28s/it]\n",
      "{'evaluator': 'harmbench', 'num_tot': 520, 'num_success': 0, 'asr': 0.0}\n",
      "[2025-09-25 13:11:48,262] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "Loading checkpoint shards: 100%|██████████████████| 2/2 [00:17<00:00,  8.56s/it]\n",
      "100%|███████████████████████████████████████████| 99/99 [21:08<00:00, 12.81s/it]\n",
      "0.14195513284184216\n"
     ]
    }
   ],
   "source": [
    "!cd .. && CUDA_VISIBLE_DEVICES=7 accelerate launch --num_processes=1 \\\n",
    "\teval_safety.py \\\n",
    "\t--torch_dtype=bfloat16 --model_name_or_path=\"meta-llama/Llama-2-7b-chat-hf\" \\\n",
    "\t--safety_bench='advbench' --model_family='llama2' \\\n",
    "\t--prompt_style='llama2' --evaluator='harmbench' \\\n",
    "\t--save_path='results/asr/llama_2_7b.json' --eval_template='pure_bad';\n",
    "\n",
    "!cd .. && CUDA_VISIBLE_DEVICES=7 accelerate launch --num_processes=1 \\\n",
    "\teval_utility.py \\\n",
    "\t--torch_dtype=bfloat16 --model_name_or_path='meta-llama/Llama-2-7b-chat-hf' \\\n",
    "\t--dataset='sql_create_context' --model_family='llama2' \\\n",
    "\t--prompt_style='llama2' --evaluator='rouge_1' \\\n",
    "\t--save_path=\"results/util/llama_2_7b.json\";"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2afadc34",
   "metadata": {},
   "source": [
    "Harmfulness:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "6887cf6c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0\n"
     ]
    }
   ],
   "source": [
    "with open(\"../results/asr/llama_2_7b.json\", \"r\") as f:\n",
    "    data = json.load(f)\n",
    "    print(data[\"metrics\"][\"asr\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75778cc7",
   "metadata": {},
   "source": [
    "Task Performance:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "e3e8a84d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.14195513284184216\n"
     ]
    }
   ],
   "source": [
    "with open(\"../results/util/llama_2_7b.json\", \"r\") as f:\n",
    "    data = json.load(f)\n",
    "    print(data[\"metrics\"][2])  # index 2 holds the ROUGE-1 score that eval_utility.py prints"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "548a9769",
   "metadata": {},
   "source": [
    "<br/>\n",
    "\n",
    "#### Fine-tuned LLM ASR/Utility (`outputs/mixed/llama_2_7b`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "572a7afd",
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2025-09-25 13:33:45,952] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "Loading checkpoint shards: 100%|██████████████████| 3/3 [00:03<00:00,  1.00s/it]\n",
      "100%|███████████████████████████████████████████| 52/52 [09:59<00:00, 11.52s/it]\n",
      "Loading checkpoint shards: 100%|██████████████████| 6/6 [00:12<00:00,  2.07s/it]\n",
      "Evaluating with HarmBench: 100%|████████████████| 33/33 [00:54<00:00,  1.64s/it]\n",
      "{'evaluator': 'harmbench', 'num_tot': 520, 'num_success': 440, 'asr': 0.8461538461538461}\n",
      "[2025-09-25 13:45:30,705] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "Loading checkpoint shards: 100%|██████████████████| 3/3 [00:00<00:00, 10.43it/s]\n",
      "100%|███████████████████████████████████████████| 99/99 [02:10<00:00,  1.31s/it]\n",
      "0.9915967061363488\n"
     ]
    }
   ],
   "source": [
    "!cd .. && CUDA_VISIBLE_DEVICES=7 accelerate launch --num_processes=1 \\\n",
    "\teval_safety.py \\\n",
    "\t--torch_dtype=bfloat16 --model_name_or_path=\"outputs/mixed/llama_2_7b\" \\\n",
    "\t--safety_bench='advbench' --model_family='llama2' \\\n",
    "\t--prompt_style='llama2' --evaluator='harmbench' \\\n",
    "\t--save_path='results/asr/llama_2_7b_sql.json' --eval_template='pure_bad';\n",
    "\n",
    "!cd .. && CUDA_VISIBLE_DEVICES=7 accelerate launch --num_processes=1 \\\n",
    "\teval_utility.py \\\n",
    "\t--torch_dtype=bfloat16 --model_name_or_path='outputs/mixed/llama_2_7b' \\\n",
    "\t--dataset='sql_create_context' --model_family='llama2' \\\n",
    "\t--prompt_style='llama2' --evaluator='rouge_1' \\\n",
    "\t--save_path=\"results/util/llama_2_7b_sql.json\";"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6bfba47f",
   "metadata": {},
   "source": [
    "Harmfulness:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "e27ef1c0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8461538461538461\n"
     ]
    }
   ],
   "source": [
    "with open(\"../results/asr/llama_2_7b_sql.json\", \"r\") as f:\n",
    "    data = json.load(f)\n",
    "    print(data[\"metrics\"][\"asr\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae2c321d",
   "metadata": {},
   "source": [
    "Task Performance:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "8a623658",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9915967061363488\n"
     ]
    }
   ],
   "source": [
    "with open(\"../results/util/llama_2_7b_sql.json\", \"r\") as f:\n",
    "    data = json.load(f)\n",
    "    print(data[\"metrics\"][2])  # index 2 holds the ROUGE-1 score that eval_utility.py prints"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5a2d8d7",
   "metadata": {},
   "source": [
    "<br/>\n",
    "\n",
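    "### Step 3: Patching the Fine-tuned LLM with a Single Safety Example\n",
    "\n",
    "We patch the fine-tuned model by continuing to fine-tune it on a single, generic safety example (`--dataset_name='pure_safe'`) for 10 epochs with a per-device batch size of 1 on 4 processes, i.e. only 10 optimizer steps. As Step 4 shows, this restores safety alignment while preserving task performance."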
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "843215d1",
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2025-09-25 13:47:57,314] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "[2025-09-25 13:47:58,451] torch.distributed.run: [WARNING] \n",
      "[2025-09-25 13:47:58,451] torch.distributed.run: [WARNING] *****************************************\n",
      "[2025-09-25 13:47:58,451] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. \n",
      "[2025-09-25 13:47:58,451] torch.distributed.run: [WARNING] *****************************************\n",
      "[2025-09-25 13:48:01,247] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "[2025-09-25 13:48:01,254] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "[2025-09-25 13:48:01,280] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "[2025-09-25 13:48:01,281] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "[2025-09-25 13:48:02,340] [INFO] [comm.py:637:init_distributed] cdb=None\n",
      "[2025-09-25 13:48:02,397] [INFO] [comm.py:637:init_distributed] cdb=None\n",
      "[2025-09-25 13:48:02,397] [INFO] [comm.py:668:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl\n",
      "[2025-09-25 13:48:02,430] [INFO] [comm.py:637:init_distributed] cdb=None\n",
      "[2025-09-25 13:48:02,435] [INFO] [comm.py:637:init_distributed] cdb=None\n",
      "Loading checkpoint shards: 100%|██████████████████| 3/3 [00:00<00:00, 10.63it/s]\n",
      "Map: 100%|████████████████████████████████| 1/1 [00:00<00:00, 188.02 examples/s]\n",
      "Map: 100%|████████████████████████████████| 1/1 [00:00<00:00, 510.82 examples/s]\n",
      "Map: 100%|█████████████████████████████████| 1/1 [00:00<00:00, 93.50 examples/s]\n",
      "Loading checkpoint shards: 100%|██████████████████| 3/3 [00:00<00:00, 10.49it/s]\n",
      "Loading checkpoint shards: 100%|██████████████████| 3/3 [00:00<00:00,  9.75it/s]\n",
      "Loading checkpoint shards: 100%|██████████████████| 3/3 [00:00<00:00,  8.78it/s]\n",
      "Map: 100%|████████████████████████████████| 1/1 [00:00<00:00, 175.69 examples/s]\n",
      "Map: 100%|████████████████████████████████| 1/1 [00:00<00:00, 493.68 examples/s]\n",
      "Map: 100%|████████████████████████████████| 1/1 [00:00<00:00, 193.33 examples/s]\n",
      "Map: 100%|████████████████████████████████| 1/1 [00:00<00:00, 502.85 examples/s]\n",
      "Map: 100%|████████████████████████████████| 1/1 [00:00<00:00, 192.58 examples/s]\n",
      "Map: 100%|████████████████████████████████| 1/1 [00:00<00:00, 521.49 examples/s]\n",
      "Map: 100%|████████████████████████████████| 1/1 [00:00<00:00, 103.11 examples/s]\n",
      "Map: 100%|████████████████████████████████| 1/1 [00:00<00:00, 101.39 examples/s]\n",
      "Map: 100%|████████████████████████████████| 1/1 [00:00<00:00, 101.46 examples/s]\n",
      "{'train_runtime': 16.2653, 'train_samples_per_second': 0.615, 'train_steps_per_second': 0.615, 'train_loss': 0.7812800884246827, 'epoch': 10.0}\n",
      "100%|███████████████████████████████████████████| 10/10 [00:16<00:00,  1.63s/it]\n"
     ]
    }
   ],
   "source": [
    "!cd .. && accelerate launch --config_file=accelerate_configs/deepspeed_zero2.yaml --num_processes 4 \\\n",
    "\tfinetune.py --model_name_or_path='outputs/mixed/llama_2_7b' \\\n",
    "\t--dataset_name='pure_safe' --model_family='llama2' --learning_rate=2e-5 \\\n",
    "\t--per_device_train_batch_size=1 --gradient_accumulation_steps=1 \\\n",
    "\t--output_dir='outputs/fixed/llama_2_7b' \\\n",
    "\t--logging_steps=0 --logging_strategy='no' --report_to=none --disable_tqdm=False \\\n",
    "\t--num_train_epochs=10 --gradient_checkpointing --save_strategy='no' \\\n",
    "\t--torch_dtype=bfloat16 --bf16=True --bf16_full_eval=True --sft_type='sft' --use_warmup=True;"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c555b6e",
   "metadata": {},
   "source": [
    "<br/>\n",
    "\n",
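    "### Step 4: Evaluate the Harmfulness and Task Utility of the Patched Model\n",
    "\n",
    "We rerun the same safety and utility evaluations on the patched checkpoint `outputs/fixed/llama_2_7b`."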
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "6b5dccdb",
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2025-09-25 13:48:56,878] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "Loading checkpoint shards: 100%|██████████████████| 3/3 [00:00<00:00,  9.21it/s]\n",
      "100%|███████████████████████████████████████████| 52/52 [03:42<00:00,  4.28s/it]\n",
      "Loading checkpoint shards: 100%|██████████████████| 6/6 [00:12<00:00,  2.01s/it]\n",
      "Evaluating with HarmBench: 100%|████████████████| 33/33 [00:35<00:00,  1.07s/it]\n",
      "{'evaluator': 'harmbench', 'num_tot': 520, 'num_success': 22, 'asr': 0.04230769230769231}\n",
      "[2025-09-25 13:53:55,859] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n",
      "Loading checkpoint shards: 100%|██████████████████| 3/3 [00:00<00:00, 10.66it/s]\n",
      "100%|███████████████████████████████████████████| 99/99 [02:10<00:00,  1.32s/it]\n",
      "0.9912701590178495\n"
     ]
    }
   ],
   "source": [
    "!cd .. && CUDA_VISIBLE_DEVICES=7 accelerate launch --num_processes=1 \\\n",
    "\teval_safety.py \\\n",
    "\t--torch_dtype=bfloat16 --model_name_or_path=\"outputs/fixed/llama_2_7b\" \\\n",
    "\t--safety_bench='advbench' --model_family='llama2' \\\n",
    "\t--prompt_style='llama2' --evaluator='harmbench' \\\n",
    "\t--save_path='results/asr/llama_2_7b_fixed.json' --eval_template='pure_bad';\n",
    "\n",
    "!cd .. && CUDA_VISIBLE_DEVICES=7 accelerate launch --num_processes=1 \\\n",
    "\teval_utility.py \\\n",
    "\t--torch_dtype=bfloat16 --model_name_or_path='outputs/fixed/llama_2_7b' \\\n",
    "\t--dataset='sql_create_context' --model_family='llama2' \\\n",
    "\t--prompt_style='llama2' --evaluator='rouge_1' \\\n",
    "\t--save_path=\"results/util/llama_2_7b_fixed.json\";"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc67a35d",
   "metadata": {},
   "source": [
    "Harmfulness:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "63889ac8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.04230769230769231\n"
     ]
    }
   ],
   "source": [
    "with open(\"../results/asr/llama_2_7b_fixed.json\", \"r\") as f:\n",
    "    data = json.load(f)\n",
    "    print(data[\"metrics\"][\"asr\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47c7c2af",
   "metadata": {},
   "source": [
    "Task Performance:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "78a11aba",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9912701590178495\n"
     ]
    }
   ],
   "source": [
    "with open(\"../results/util/llama_2_7b_fixed.json\", \"r\") as f:\n",
    "    data = json.load(f)\n",
    "    print(data[\"metrics\"][2])  # index 2 holds the ROUGE-1 score that eval_utility.py prints"
   ]
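  },
  {
   "cell_type": "markdown",
   "id": "e5092b37",
   "metadata": {},
   "source": [
    "The three pairs of result cells above can be condensed into one comparison table. This sketch assumes only what those cells already rely on: the result paths used in this notebook and the JSON layout (`metrics['asr']` for ASR, `metrics[2]` for the ROUGE-1 score). `load_metrics` is a hypothetical helper; checkpoints whose result files are missing are skipped with a note."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f61a3c48",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from pathlib import Path\n",
    "\n",
    "RESULTS_DIR = Path('../results')\n",
    "\n",
    "# Result-file stems written by the evaluation commands in this notebook.\n",
    "CHECKPOINTS = {\n",
    "    'Llama-2-7b-chat (baseline)': 'llama_2_7b',\n",
    "    'Fine-tuned (outputs/mixed)': 'llama_2_7b_sql',\n",
    "    'Patched (outputs/fixed)': 'llama_2_7b_fixed',\n",
    "}\n",
    "\n",
    "\n",
    "# Hypothetical helper: read (ASR, ROUGE-1) for one checkpoint, using the same\n",
    "# JSON keys as the cells above (metrics['asr'] and metrics[2]).\n",
    "def load_metrics(stem, results_dir=RESULTS_DIR):\n",
    "    with open(results_dir / 'asr' / f'{stem}.json') as f:\n",
    "        asr = json.load(f)['metrics']['asr']\n",
    "    with open(results_dir / 'util' / f'{stem}.json') as f:\n",
    "        rouge1 = json.load(f)['metrics'][2]\n",
    "    return asr, rouge1\n",
    "\n",
    "\n",
    "print('Model'.ljust(30) + 'ASR'.rjust(8) + 'ROUGE-1'.rjust(10))\n",
    "for name, stem in CHECKPOINTS.items():\n",
    "    try:\n",
    "        asr, rouge1 = load_metrics(stem)\n",
    "    except FileNotFoundError:\n",
    "        print(name.ljust(30) + '  (results not found; run the evaluation cells first)')\n",
    "        continue\n",
    "    print(f'{name:<30}{asr:>8.2%}{rouge1:>10.4f}')"
   ]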
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "reasoning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
