{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Latent Space Adversarial Attack\n",
    "\n",
    "This notebook demonstrates conducting a latent space adversarial attacks on LLMs. These particualr demo attacks are created using projected gradient descent to make a model more jailbreakable. \n",
    "\n",
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "import torch\n",
    "import sys\n",
    "from dotenv import load_dotenv\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "os.chdir(\"../\")\n",
    "cwd = os.getcwd()\n",
    "if cwd not in sys.path:\n",
    "    sys.path.insert(0, cwd)\n",
    "from latent_at import *\n",
    "from tasks.harmbench.HarmBenchTask import HarmBenchTask\n",
    "\n",
    "load_dotenv()\n",
    "hf_access_token = os.getenv(\"HUGGINGFACE_API_KEY\")"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "## Configuration\n",
    "\n",
    "Set whether to use Llama2-7B or Llama3-8B."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "use_llama2 = True\n",
    "if use_llama2:  # use llama2-7b\n",
    "    model_name = \"meta-llama/Llama-2-7b-chat-hf\"\n",
    "else: # use llama3-8b\n",
    "    model_name = \"meta-llama/Meta-Llama-3-8B-Instruct\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_dtype = torch.bfloat16\n",
    "device = \"cuda\"\n",
    "run_start_evals = False\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name,\n",
    "    token=hf_access_token,\n",
    "    torch_dtype=model_dtype\n",
    ").to(device)\n",
    "\n",
    "\n",
    "if \"Llama-2\" in model_name:\n",
    "    model_type = \"llama2\"\n",
    "    tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "    tokenizer.padding_side = \"left\"\n",
    "elif \"Llama-3\" in model_name:\n",
    "    model_type = \"llama3\"\n",
    "    tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "    tokenizer.padding_side = \"left\"\n",
    "elif \"zephyr\" in model_name or \"mistral\" in model_name:\n",
    "    model_type = \"zephyr\"    \n",
    "    tokenizer = AutoTokenizer.from_pretrained(\"HuggingFaceH4/zephyr-7b-beta\")\n",
    "    tokenizer.pad_token_id = tokenizer.unk_token_id\n",
    "    tokenizer.padding_side = \"left\"\n",
    "else:\n",
    "    print(model_name)\n",
    "    raise Exception(\"Unsupported model type?\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "advbench_data = HarmBenchTask(\n",
    "    tokenizer=tokenizer,\n",
    "    gen_batch_size=1,\n",
    "    cls_batch_size=1,\n",
    "    device=device,\n",
    "    data_name=\"advbench\",\n",
    "    train_test_split=.8\n",
    ")\n",
    "\n",
    "harmbench_data = HarmBenchTask(\n",
    "    tokenizer=tokenizer,\n",
    "    gen_batch_size=1,\n",
    "    cls_batch_size=1,\n",
    "    device=device,\n",
    "    data_name=\"harmbench_text\",\n",
    "    train_test_split=.8,\n",
    "    func_categories=[\"standard\", \"contextual\"]\n",
    ")\n",
    "\n",
    "# train_behaviors = advbench_data.train_behaviors + harmbench_data.train_behaviors  \n",
    "\n",
    "sys_prompt = \"\"\"You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n",
    "\n",
    "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\"\"\"\n",
    "\n",
    "if model_type == \"llama2\":\n",
    "    use_tokenizer_template = True\n",
    "    custom_prompt_template = None\n",
    "    custom_completion_template = None\n",
    "elif model_type == \"llama3\":\n",
    "    use_tokenizer_template = False\n",
    "    custom_prompt_template = \"<|start_header_id|>user<|end_header_id|\\n\\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n",
    "    custom_completion_template=\"{completion}\"\n",
    "else:\n",
    "    sys_prompt=\"\"\n",
    "    use_tokenizer_template = False\n",
    "    custom_prompt_template = \"<|user|>\\n{prompt}</s> \\n <|assistant|>\\n\"\n",
    "    custom_completion_template=\"{completion}\"\n",
    " \n",
    "lat_dataset = process_generic_chat_dataset(\n",
    "    tokenizer,\n",
    "    dataset=\"[REDACTED]\",\n",
    "    adv_column=\"rejected\",\n",
    "    def_column=\"chosen\",\n",
    "    split=\"train\",\n",
    "    use_tokenizer_template=use_tokenizer_template,\n",
    "    system_prompt=sys_prompt,\n",
    "    custom_prompt_template=custom_prompt_template,\n",
    "    custom_completion_template=custom_completion_template\n",
    ")\n",
    "\n",
    "lat_dataloader = DataLoader(\n",
    "    lat_dataset,\n",
    "    batch_size=1,\n",
    "    shuffle=True,\n",
    "    drop_last=True,\n",
    "    collate_fn=LatentAdversarialTrainingDataCollator(\n",
    "        tokenizer.pad_token_id,\n",
    "        truncate_length=2048\n",
    "    )\n",
    ")\n",
    "\n",
    "dataloader = iter(lat_dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = next(dataloader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Clean Performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs = model.generate(\n",
    "    batch[\"adv_tokens\"][[0]].to(\"cuda\"),\n",
    "    attention_mask=torch.logical_or(batch['prompt_mask'], batch['adv_labels_mask'])[[0]].to(\"cuda\"),\n",
    "    max_length=batch[\"adv_tokens\"].shape[1] + 200,\n",
    "    \n",
    ")\n",
    "\n",
    "print(\"***CLEAN PERFORMANCE***\\n\")\n",
    "prompt = tokenizer.decode(batch[\"adv_tokens\"][0]).replace('\\n', '')\n",
    "print(\"Prompt:\\n\" + prompt + \"\\n\")\n",
    "prompt_response = tokenizer.decode(outputs[0]).replace('\\n', '')\n",
    "print(\"Completion:\\n\" + prompt_response[len(prompt):])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Latent Space Attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "losses, wrappers = projected_gradient_descent(\n",
    "    batch=batch,  # batch\n",
    "    model=model,  # model\n",
    "    model_layers_module=\"model.layers\",  # where the model layers are\n",
    "    layer=[\"embedding\", 8, 16, 24, 30],  # layers to attack\n",
    "    epsilon=20.0,  # attack l2 constraint\n",
    "    l2_regularization=0.0,  # coef for l2 penalty on the attack\n",
    "    learning_rate=2e-3,  # attack step size\n",
    "    pgd_iterations=32,  # how many steps of projected gradient descent to do\n",
    "    loss_coefs={\"toward\": 0.5, \"away\": 0.5,},  # coefficients for the attack's toward and away losses\n",
    "    log_loss=True,  # whether to use a log loss instead of a crossentropy one\n",
    "    return_loss_over_time=True,\n",
    "    device=\"cuda\",\n",
    ")\n",
    "\n",
    "print(\"***ADVERSARIAL LOSSES OVER TIME***\\n\")\n",
    "print([round(l['adv_total'], 4) for l in losses])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attacked Performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for wrapper in wrappers:\n",
    "    wrapper.enabled = True  # the wrappers should already be enabled, so this is redundant\n",
    "\n",
    "outputs = model.generate(\n",
    "    batch[\"adv_tokens\"].to(\"cuda\"),\n",
    "    max_length=batch[\"adv_tokens\"].shape[1] + 200,\n",
    ")\n",
    "\n",
    "print(\"***ATTACKED PERFORMANCE***\\n\")\n",
    "prompt = tokenizer.decode(batch[\"adv_tokens\"][0]).replace('\\n', '')\n",
    "print(\"Prompt:\\n\" + prompt + \"\\n\")\n",
    "prompt_response = tokenizer.decode(outputs[0]).replace('\\n', '')\n",
    "print(\"Completion:\\n\" + prompt_response[len(prompt):])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
