{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9dd3d148",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trying load_dataset with explicit arrow files...\n"
     ]
    }
   ],
   "source": [
    "# Load the ViSU text dataset (train / test / validation splits) directly from\n",
    "# the Arrow files that a previous `datasets` download left in the local cache.\n",
    "import os\n",
    "from datasets import Dataset\n",
    "import pickle\n",
    "\n",
    "# Machine-specific default; set the VISU_CACHE_DIR environment variable to run\n",
    "# the notebook on another host without editing this cell.\n",
    "base_cache_path = os.environ.get(\n",
    "    \"VISU_CACHE_DIR\",\n",
    "    \"/mnt/ssd1/mary/Diffusion-Models-Embedding-Space-Defense/.cache/aimagelab___vi_su-text/default/0.0.0/9afabb85b5570fa883b7caa0561d8c8d71d84dcd\",\n",
    ")\n",
    "\n",
    "print(\"Trying load_dataset with explicit arrow files...\")\n",
    "data_files = {\n",
    "    split: os.path.join(base_cache_path, f\"vi_su-text-{split}.arrow\")\n",
    "    for split in (\"train\", \"test\", \"validation\")\n",
    "}\n",
    "# Fail early with a single message that lists every missing split file.\n",
    "missing_files = [path for path in data_files.values() if not os.path.isfile(path)]\n",
    "if missing_files:\n",
    "    raise FileNotFoundError(f\"Missing ViSU arrow file(s): {missing_files}\")\n",
    "\n",
    "train_ds = Dataset.from_file(data_files[\"train\"])\n",
    "test_ds = Dataset.from_file(data_files[\"test\"])\n",
    "validation_ds = Dataset.from_file(data_files[\"validation\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b1ee8e0c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "158700\n"
     ]
    }
   ],
   "source": [
    "print(len(train_ds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "892ef047",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading HySAC model on cuda...\n"
     ]
    }
   ],
   "source": [
    "# Instantiate the HySAC model and the CLIP tokenizer for its text backbone.\n",
    "import torch\n",
    "import HySAC\n",
    "from HySAC.hysac.models import HySAC as HySACModel\n",
    "from transformers import CLIPTokenizer\n",
    "\n",
    "clip_backbone = 'openai/clip-vit-large-patch14'\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "# Reuse clip_backbone so the tokenizer id cannot silently drift from the declared backbone.\n",
    "tokenizer = CLIPTokenizer.from_pretrained(clip_backbone)\n",
    "# Load model\n",
    "print(f\"Loading HySAC model on {device}...\")\n",
    "model = HySACModel.from_pretrained(repo_id=\"aimagelab/hysac\", device=device).to(device)\n",
    "# Feature extraction only: put the (torch) model in eval mode so any dropout /\n",
    "# train-time behaviour is disabled and embeddings are deterministic.\n",
    "model.eval()\n",
    "text_encoder = model.textual"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "10344d2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to get last_hidden_state and attention_mask for a batch from a specific text_encoder\n",
    "def get_model_outputs(\n",
    "    texts_list, tokenizer_instance, text_encoder_instance, batch_size=32, device='cuda'\n",
    "):\n",
    "    \"\"\"Encode prompts in mini-batches and return the encoder's hidden states.\n",
    "\n",
    "    Args:\n",
    "        texts_list: a single prompt string, or a sequence/iterable of prompt\n",
    "            strings (list, tuple, or e.g. a `datasets` column).\n",
    "        tokenizer_instance: Hugging Face-style tokenizer; every batch is padded\n",
    "            and truncated to ``tokenizer_instance.model_max_length`` tokens so\n",
    "            the per-batch tensors can be concatenated.\n",
    "        text_encoder_instance: callable module whose output exposes\n",
    "            ``last_hidden_state``; it is moved to ``device`` once.\n",
    "        batch_size: number of prompts per forward pass.\n",
    "        device: device the token ids, attention mask and encoder are placed on.\n",
    "\n",
    "    Returns:\n",
    "        ``(hidden_states, attention_masks)``, both on CPU and concatenated over\n",
    "        all batches along dim 0 (one row per prompt, ``model_max_length``\n",
    "        positions per row).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if there are no prompts to encode.\n",
    "    \"\"\"\n",
    "    # A bare string is one prompt; slicing it would otherwise yield character chunks.\n",
    "    texts = [texts_list] if isinstance(texts_list, str) else list(texts_list)\n",
    "    if not texts:\n",
    "        raise ValueError(\"get_model_outputs: texts_list is empty\")\n",
    "\n",
    "    fixed_max_length = tokenizer_instance.model_max_length\n",
    "    # Moving the encoder is loop-invariant, so do it once up front.\n",
    "    text_encoder_instance.to(device)\n",
    "\n",
    "    all_hidden_states = []\n",
    "    all_attention_masks = []\n",
    "    for start in range(0, len(texts), batch_size):\n",
    "        batch_texts = texts[start : start + batch_size]\n",
    "        # Fixed-length padding keeps the position axis identical across batches.\n",
    "        inputs = tokenizer_instance(\n",
    "            batch_texts,\n",
    "            return_tensors=\"pt\",\n",
    "            padding=\"max_length\",\n",
    "            truncation=True,\n",
    "            max_length=fixed_max_length,\n",
    "        )\n",
    "        input_ids = inputs[\"input_ids\"].to(device)\n",
    "        attention_mask = inputs.get(\"attention_mask\")\n",
    "        if attention_mask is None:\n",
    "            # No mask returned by the tokenizer: treat every position as attended.\n",
    "            attention_mask = torch.ones_like(input_ids)\n",
    "        attention_mask = attention_mask.to(device)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            outputs = text_encoder_instance(input_ids, attention_mask=attention_mask)\n",
    "\n",
    "        all_hidden_states.append(outputs.last_hidden_state.cpu())\n",
    "        all_attention_masks.append(attention_mask.cpu())\n",
    "\n",
    "    concatenated_hidden_states = torch.cat(all_hidden_states, dim=0)\n",
    "    concatenated_attention_masks = torch.cat(all_attention_masks, dim=0)\n",
    "\n",
    "    return concatenated_hidden_states, concatenated_attention_masks\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "bd527ea9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the NSFW prompts of the validation split and (re)bind the HySAC text encoder.\n",
    "text_encoder = model.textual\n",
    "val_nsfw = validation_ds[\"nsfw\"]\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6f37992e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([5000, 77, 768])\n"
     ]
    }
   ],
   "source": [
    "# Encode the NSFW validation prompts; the returned attention mask is not used here.\n",
    "hysac_outputs_nsfw, _ = get_model_outputs(\n",
    "    texts_list=val_nsfw,\n",
    "    tokenizer_instance=tokenizer,\n",
    "    text_encoder_instance=text_encoder,\n",
    "    batch_size=16,\n",
    "    device=device,\n",
    ")\n",
    "print(hysac_outputs_nsfw.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b9e53689",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Token 0: Average Norm = 43.81230545043945\n",
      "Token 1: Average Norm = 29.962177276611328\n",
      "Token 2: Average Norm = 29.75455093383789\n",
      "Token 3: Average Norm = 29.501474380493164\n",
      "Token 4: Average Norm = 29.442506790161133\n",
      "Token 5: Average Norm = 29.355737686157227\n",
      "Token 6: Average Norm = 29.25745391845703\n",
      "Token 7: Average Norm = 29.199317932128906\n",
      "Token 8: Average Norm = 29.180171966552734\n",
      "Token 9: Average Norm = 29.167524337768555\n",
      "Token 10: Average Norm = 29.16300392150879\n",
      "Token 11: Average Norm = 29.202537536621094\n",
      "Token 12: Average Norm = 29.247262954711914\n",
      "Token 13: Average Norm = 29.3497257232666\n",
      "Token 14: Average Norm = 29.428049087524414\n",
      "Token 15: Average Norm = 29.531469345092773\n",
      "Token 16: Average Norm = 29.6516056060791\n",
      "Token 17: Average Norm = 29.74106216430664\n",
      "Token 18: Average Norm = 29.846744537353516\n",
      "Token 19: Average Norm = 29.940710067749023\n",
      "Token 20: Average Norm = 30.02996253967285\n",
      "Token 21: Average Norm = 30.104867935180664\n",
      "Token 22: Average Norm = 30.182708740234375\n",
      "Token 23: Average Norm = 30.24416160583496\n",
      "Token 24: Average Norm = 30.299856185913086\n",
      "Token 25: Average Norm = 30.338586807250977\n",
      "Token 26: Average Norm = 30.36880874633789\n",
      "Token 27: Average Norm = 30.38994026184082\n",
      "Token 28: Average Norm = 30.403030395507812\n",
      "Token 29: Average Norm = 30.410499572753906\n",
      "Token 30: Average Norm = 30.422025680541992\n",
      "Token 31: Average Norm = 30.429330825805664\n",
      "Token 32: Average Norm = 30.432209014892578\n",
      "Token 33: Average Norm = 30.4356689453125\n",
      "Token 34: Average Norm = 30.443937301635742\n",
      "Token 35: Average Norm = 30.500513076782227\n",
      "Token 36: Average Norm = 30.648393630981445\n",
      "Token 37: Average Norm = 30.964252471923828\n",
      "Token 38: Average Norm = 31.37285614013672\n",
      "Token 39: Average Norm = 31.871925354003906\n",
      "Token 40: Average Norm = 32.98471450805664\n",
      "Token 41: Average Norm = 33.4815673828125\n",
      "Token 42: Average Norm = 34.507781982421875\n",
      "Token 43: Average Norm = 35.324710845947266\n",
      "Token 44: Average Norm = 36.115909576416016\n",
      "Token 45: Average Norm = 36.517852783203125\n",
      "Token 46: Average Norm = 37.29658126831055\n",
      "Token 47: Average Norm = 37.98051071166992\n",
      "Token 48: Average Norm = 38.394039154052734\n",
      "Token 49: Average Norm = 39.08449935913086\n",
      "Token 50: Average Norm = 39.441829681396484\n",
      "Token 51: Average Norm = 39.50062561035156\n",
      "Token 52: Average Norm = 39.44915008544922\n",
      "Token 53: Average Norm = 39.07379913330078\n",
      "Token 54: Average Norm = 38.86283493041992\n",
      "Token 55: Average Norm = 38.51136016845703\n",
      "Token 56: Average Norm = 37.86965560913086\n",
      "Token 57: Average Norm = 37.79876708984375\n",
      "Token 58: Average Norm = 37.464080810546875\n",
      "Token 59: Average Norm = 37.50088119506836\n",
      "Token 60: Average Norm = 37.17124557495117\n",
      "Token 61: Average Norm = 37.384029388427734\n",
      "Token 62: Average Norm = 37.708457946777344\n",
      "Token 63: Average Norm = 36.779029846191406\n",
      "Token 64: Average Norm = 37.002262115478516\n",
      "Token 65: Average Norm = 37.29561233520508\n",
      "Token 66: Average Norm = 37.790645599365234\n",
      "Token 67: Average Norm = 38.68220520019531\n",
      "Token 68: Average Norm = 38.20048904418945\n",
      "Token 69: Average Norm = 38.628379821777344\n",
      "Token 70: Average Norm = 38.56638717651367\n",
      "Token 71: Average Norm = 38.16471481323242\n",
      "Token 72: Average Norm = 36.931907653808594\n",
      "Token 73: Average Norm = 36.9434928894043\n",
      "Token 74: Average Norm = 37.12897491455078\n",
      "Token 75: Average Norm = 39.21435546875\n",
      "Token 76: Average Norm = 30.518022537231445\n"
     ]
    }
   ],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.colors import ListedColormap\n",
    "import pandas as pd\n",
    "\n",
    "# Mean L2 norm of the NSFW hidden states at each token position.\n",
    "# hysac_outputs_nsfw is indexed [prompt, position, feature]; its printed shape is\n",
    "# (5000, 77, 768).\n",
    "# NOTE(review): every position is averaged, including the padding positions that\n",
    "# get_model_outputs adds (padding=\"max_length\"), so later positions partly reflect\n",
    "# prompt length. The attention mask returned by get_model_outputs could restrict this.\n",
    "num_dimensions = hysac_outputs_nsfw.shape[1]  # number of token positions (77)\n",
    "avg_norms = []\n",
    "for position in range(num_dimensions):\n",
    "    position_vectors = hysac_outputs_nsfw[:, position, :].numpy()\n",
    "    position_mean_norm = np.mean(np.linalg.norm(position_vectors, axis=1))\n",
    "    print(f\"Token {position}: Average Norm = {position_mean_norm}\")\n",
    "    avg_norms.append(position_mean_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "1df1fed6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the safe (SFW) prompts of the validation split and (re)bind the HySAC text encoder.\n",
    "text_encoder = model.textual\n",
    "val_sfw = validation_ds[\"safe\"]\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "2bb1cfd2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([5000, 77, 768])\n"
     ]
    }
   ],
   "source": [
    "# Encode the safe validation prompts; the returned attention mask is not used here.\n",
    "hysac_outputs_sfw, _ = get_model_outputs(\n",
    "    texts_list=val_sfw,\n",
    "    tokenizer_instance=tokenizer,\n",
    "    text_encoder_instance=text_encoder,\n",
    "    batch_size=16,\n",
    "    device=device,\n",
    ")\n",
    "print(hysac_outputs_sfw.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "8d626038",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Token 0: Average Norm = 43.81230545043945\n",
      "Token 1: Average Norm = 29.961952209472656\n",
      "Token 2: Average Norm = 29.782934188842773\n",
      "Token 3: Average Norm = 29.487106323242188\n",
      "Token 4: Average Norm = 29.4118595123291\n",
      "Token 5: Average Norm = 29.347122192382812\n",
      "Token 6: Average Norm = 29.24566650390625\n",
      "Token 7: Average Norm = 29.159799575805664\n",
      "Token 8: Average Norm = 29.150924682617188\n",
      "Token 9: Average Norm = 29.10258674621582\n",
      "Token 10: Average Norm = 28.855859756469727\n",
      "Token 11: Average Norm = 28.527406692504883\n",
      "Token 12: Average Norm = 28.200706481933594\n",
      "Token 13: Average Norm = 27.951374053955078\n",
      "Token 14: Average Norm = 27.79199981689453\n",
      "Token 15: Average Norm = 27.696521759033203\n",
      "Token 16: Average Norm = 27.63228416442871\n",
      "Token 17: Average Norm = 27.590024948120117\n",
      "Token 18: Average Norm = 27.574134826660156\n",
      "Token 19: Average Norm = 27.564889907836914\n",
      "Token 20: Average Norm = 27.556631088256836\n",
      "Token 21: Average Norm = 27.553974151611328\n",
      "Token 22: Average Norm = 27.55228042602539\n",
      "Token 23: Average Norm = 27.55069351196289\n",
      "Token 24: Average Norm = 27.555028915405273\n",
      "Token 25: Average Norm = 27.556827545166016\n",
      "Token 26: Average Norm = 27.558597564697266\n",
      "Token 27: Average Norm = 27.55855369567871\n",
      "Token 28: Average Norm = 27.556493759155273\n",
      "Token 29: Average Norm = 27.557849884033203\n",
      "Token 30: Average Norm = 27.5593318939209\n",
      "Token 31: Average Norm = 27.559724807739258\n",
      "Token 32: Average Norm = 27.559024810791016\n",
      "Token 33: Average Norm = 27.5802059173584\n",
      "Token 34: Average Norm = 27.736900329589844\n",
      "Token 35: Average Norm = 28.28061294555664\n",
      "Token 36: Average Norm = 29.11072540283203\n",
      "Token 37: Average Norm = 30.305437088012695\n",
      "Token 38: Average Norm = 31.56783103942871\n",
      "Token 39: Average Norm = 32.750030517578125\n",
      "Token 40: Average Norm = 34.41966247558594\n",
      "Token 41: Average Norm = 35.172813415527344\n",
      "Token 42: Average Norm = 36.51424026489258\n",
      "Token 43: Average Norm = 37.55934524536133\n",
      "Token 44: Average Norm = 38.42018127441406\n",
      "Token 45: Average Norm = 38.83438491821289\n",
      "Token 46: Average Norm = 39.441532135009766\n",
      "Token 47: Average Norm = 39.872676849365234\n",
      "Token 48: Average Norm = 40.1456413269043\n",
      "Token 49: Average Norm = 40.5781135559082\n",
      "Token 50: Average Norm = 40.780914306640625\n",
      "Token 51: Average Norm = 40.85918045043945\n",
      "Token 52: Average Norm = 40.93341064453125\n",
      "Token 53: Average Norm = 40.77909851074219\n",
      "Token 54: Average Norm = 40.83403778076172\n",
      "Token 55: Average Norm = 40.76807403564453\n",
      "Token 56: Average Norm = 40.56525802612305\n",
      "Token 57: Average Norm = 40.60975646972656\n",
      "Token 58: Average Norm = 40.43487548828125\n",
      "Token 59: Average Norm = 40.3958625793457\n",
      "Token 60: Average Norm = 40.07631301879883\n",
      "Token 61: Average Norm = 40.114036560058594\n",
      "Token 62: Average Norm = 40.128299713134766\n",
      "Token 63: Average Norm = 39.411048889160156\n",
      "Token 64: Average Norm = 39.22309112548828\n",
      "Token 65: Average Norm = 39.08424377441406\n",
      "Token 66: Average Norm = 39.394500732421875\n",
      "Token 67: Average Norm = 40.242610931396484\n",
      "Token 68: Average Norm = 39.952728271484375\n",
      "Token 69: Average Norm = 40.06113052368164\n",
      "Token 70: Average Norm = 39.78171157836914\n",
      "Token 71: Average Norm = 39.44076919555664\n",
      "Token 72: Average Norm = 38.31836700439453\n",
      "Token 73: Average Norm = 38.365318298339844\n",
      "Token 74: Average Norm = 38.48271942138672\n",
      "Token 75: Average Norm = 40.34063720703125\n",
      "Token 76: Average Norm = 28.421981811523438\n"
     ]
    }
   ],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.colors import ListedColormap\n",
    "import pandas as pd\n",
    "\n",
    "# Mean L2 norm of the safe (SFW) hidden states at each token position.\n",
    "# hysac_outputs_sfw is indexed [prompt, position, feature]; its printed shape is\n",
    "# (5000, 77, 768).\n",
    "# NOTE(review): padding positions (padding=\"max_length\" in get_model_outputs) are\n",
    "# included in every per-position mean, exactly as in the NSFW computation.\n",
    "num_dimensions = hysac_outputs_sfw.shape[1]  # number of token positions (77)\n",
    "sfw_avg_norms = []\n",
    "for position in range(num_dimensions):\n",
    "    position_vectors = hysac_outputs_sfw[:, position, :].numpy()\n",
    "    position_mean_norm = np.mean(np.linalg.norm(position_vectors, axis=1))\n",
    "    print(f\"Token {position}: Average Norm = {position_mean_norm}\")\n",
    "    sfw_avg_norms.append(position_mean_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "f4d666f2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Token 0: NSFW Average Norm = 43.81230545043945, SFW Average Norm = 43.81230545043945\n",
      "Token 1: NSFW Average Norm = 29.962177276611328, SFW Average Norm = 29.961952209472656\n",
      "Token 2: NSFW Average Norm = 29.75455093383789, SFW Average Norm = 29.782934188842773\n",
      "Token 3: NSFW Average Norm = 29.501474380493164, SFW Average Norm = 29.487106323242188\n",
      "Token 4: NSFW Average Norm = 29.442506790161133, SFW Average Norm = 29.4118595123291\n",
      "Token 5: NSFW Average Norm = 29.355737686157227, SFW Average Norm = 29.347122192382812\n",
      "Token 6: NSFW Average Norm = 29.25745391845703, SFW Average Norm = 29.24566650390625\n",
      "Token 7: NSFW Average Norm = 29.199317932128906, SFW Average Norm = 29.159799575805664\n",
      "Token 8: NSFW Average Norm = 29.180171966552734, SFW Average Norm = 29.150924682617188\n",
      "Token 9: NSFW Average Norm = 29.167524337768555, SFW Average Norm = 29.10258674621582\n",
      "Token 10: NSFW Average Norm = 29.16300392150879, SFW Average Norm = 28.855859756469727\n",
      "Token 11: NSFW Average Norm = 29.202537536621094, SFW Average Norm = 28.527406692504883\n",
      "Token 12: NSFW Average Norm = 29.247262954711914, SFW Average Norm = 28.200706481933594\n",
      "Token 13: NSFW Average Norm = 29.3497257232666, SFW Average Norm = 27.951374053955078\n",
      "Token 14: NSFW Average Norm = 29.428049087524414, SFW Average Norm = 27.79199981689453\n",
      "Token 15: NSFW Average Norm = 29.531469345092773, SFW Average Norm = 27.696521759033203\n",
      "Token 16: NSFW Average Norm = 29.6516056060791, SFW Average Norm = 27.63228416442871\n",
      "Token 17: NSFW Average Norm = 29.74106216430664, SFW Average Norm = 27.590024948120117\n",
      "Token 18: NSFW Average Norm = 29.846744537353516, SFW Average Norm = 27.574134826660156\n",
      "Token 19: NSFW Average Norm = 29.940710067749023, SFW Average Norm = 27.564889907836914\n",
      "Token 20: NSFW Average Norm = 30.02996253967285, SFW Average Norm = 27.556631088256836\n",
      "Token 21: NSFW Average Norm = 30.104867935180664, SFW Average Norm = 27.553974151611328\n",
      "Token 22: NSFW Average Norm = 30.182708740234375, SFW Average Norm = 27.55228042602539\n",
      "Token 23: NSFW Average Norm = 30.24416160583496, SFW Average Norm = 27.55069351196289\n",
      "Token 24: NSFW Average Norm = 30.299856185913086, SFW Average Norm = 27.555028915405273\n",
      "Token 25: NSFW Average Norm = 30.338586807250977, SFW Average Norm = 27.556827545166016\n",
      "Token 26: NSFW Average Norm = 30.36880874633789, SFW Average Norm = 27.558597564697266\n",
      "Token 27: NSFW Average Norm = 30.38994026184082, SFW Average Norm = 27.55855369567871\n",
      "Token 28: NSFW Average Norm = 30.403030395507812, SFW Average Norm = 27.556493759155273\n",
      "Token 29: NSFW Average Norm = 30.410499572753906, SFW Average Norm = 27.557849884033203\n",
      "Token 30: NSFW Average Norm = 30.422025680541992, SFW Average Norm = 27.5593318939209\n",
      "Token 31: NSFW Average Norm = 30.429330825805664, SFW Average Norm = 27.559724807739258\n",
      "Token 32: NSFW Average Norm = 30.432209014892578, SFW Average Norm = 27.559024810791016\n",
      "Token 33: NSFW Average Norm = 30.4356689453125, SFW Average Norm = 27.5802059173584\n",
      "Token 34: NSFW Average Norm = 30.443937301635742, SFW Average Norm = 27.736900329589844\n",
      "Token 35: NSFW Average Norm = 30.500513076782227, SFW Average Norm = 28.28061294555664\n",
      "Token 36: NSFW Average Norm = 30.648393630981445, SFW Average Norm = 29.11072540283203\n",
      "Token 37: NSFW Average Norm = 30.964252471923828, SFW Average Norm = 30.305437088012695\n",
      "Token 38: NSFW Average Norm = 31.37285614013672, SFW Average Norm = 31.56783103942871\n",
      "Token 39: NSFW Average Norm = 31.871925354003906, SFW Average Norm = 32.750030517578125\n",
      "Token 40: NSFW Average Norm = 32.98471450805664, SFW Average Norm = 34.41966247558594\n",
      "Token 41: NSFW Average Norm = 33.4815673828125, SFW Average Norm = 35.172813415527344\n",
      "Token 42: NSFW Average Norm = 34.507781982421875, SFW Average Norm = 36.51424026489258\n",
      "Token 43: NSFW Average Norm = 35.324710845947266, SFW Average Norm = 37.55934524536133\n",
      "Token 44: NSFW Average Norm = 36.115909576416016, SFW Average Norm = 38.42018127441406\n",
      "Token 45: NSFW Average Norm = 36.517852783203125, SFW Average Norm = 38.83438491821289\n",
      "Token 46: NSFW Average Norm = 37.29658126831055, SFW Average Norm = 39.441532135009766\n",
      "Token 47: NSFW Average Norm = 37.98051071166992, SFW Average Norm = 39.872676849365234\n",
      "Token 48: NSFW Average Norm = 38.394039154052734, SFW Average Norm = 40.1456413269043\n",
      "Token 49: NSFW Average Norm = 39.08449935913086, SFW Average Norm = 40.5781135559082\n",
      "Token 50: NSFW Average Norm = 39.441829681396484, SFW Average Norm = 40.780914306640625\n",
      "Token 51: NSFW Average Norm = 39.50062561035156, SFW Average Norm = 40.85918045043945\n",
      "Token 52: NSFW Average Norm = 39.44915008544922, SFW Average Norm = 40.93341064453125\n",
      "Token 53: NSFW Average Norm = 39.07379913330078, SFW Average Norm = 40.77909851074219\n",
      "Token 54: NSFW Average Norm = 38.86283493041992, SFW Average Norm = 40.83403778076172\n",
      "Token 55: NSFW Average Norm = 38.51136016845703, SFW Average Norm = 40.76807403564453\n",
      "Token 56: NSFW Average Norm = 37.86965560913086, SFW Average Norm = 40.56525802612305\n",
      "Token 57: NSFW Average Norm = 37.79876708984375, SFW Average Norm = 40.60975646972656\n",
      "Token 58: NSFW Average Norm = 37.464080810546875, SFW Average Norm = 40.43487548828125\n",
      "Token 59: NSFW Average Norm = 37.50088119506836, SFW Average Norm = 40.3958625793457\n",
      "Token 60: NSFW Average Norm = 37.17124557495117, SFW Average Norm = 40.07631301879883\n",
      "Token 61: NSFW Average Norm = 37.384029388427734, SFW Average Norm = 40.114036560058594\n",
      "Token 62: NSFW Average Norm = 37.708457946777344, SFW Average Norm = 40.128299713134766\n",
      "Token 63: NSFW Average Norm = 36.779029846191406, SFW Average Norm = 39.411048889160156\n",
      "Token 64: NSFW Average Norm = 37.002262115478516, SFW Average Norm = 39.22309112548828\n",
      "Token 65: NSFW Average Norm = 37.29561233520508, SFW Average Norm = 39.08424377441406\n",
      "Token 66: NSFW Average Norm = 37.790645599365234, SFW Average Norm = 39.394500732421875\n",
      "Token 67: NSFW Average Norm = 38.68220520019531, SFW Average Norm = 40.242610931396484\n",
      "Token 68: NSFW Average Norm = 38.20048904418945, SFW Average Norm = 39.952728271484375\n",
      "Token 69: NSFW Average Norm = 38.628379821777344, SFW Average Norm = 40.06113052368164\n",
      "Token 70: NSFW Average Norm = 38.56638717651367, SFW Average Norm = 39.78171157836914\n",
      "Token 71: NSFW Average Norm = 38.16471481323242, SFW Average Norm = 39.44076919555664\n",
      "Token 72: NSFW Average Norm = 36.931907653808594, SFW Average Norm = 38.31836700439453\n",
      "Token 73: NSFW Average Norm = 36.9434928894043, SFW Average Norm = 38.365318298339844\n",
      "Token 74: NSFW Average Norm = 37.12897491455078, SFW Average Norm = 38.48271942138672\n",
      "Token 75: NSFW Average Norm = 39.21435546875, SFW Average Norm = 40.34063720703125\n",
      "Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n"
     ]
    }
   ],
   "source": [
    "# Side-by-side comparison of the per-position mean norms (NSFW vs. SFW prompts).\n",
    "\n",
    "for position in range(num_dimensions):\n",
    "    nsfw_norm = avg_norms[position]\n",
    "    sfw_norm = sfw_avg_norms[position]\n",
    "    print(f\"Token {position}: NSFW Average Norm = {nsfw_norm}, SFW Average Norm = {sfw_norm}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "7442e8ca",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 77, 768])\n",
      "torch.Size([1, 77, 768])\n",
      "A young man holding a gun next to a herd of cattle, with a grin on his face and a wicked look in his eyes. \n",
      "torch.Size([1, 77, 768])\n"
     ]
    }
   ],
   "source": [
    "# Probe prompts: an adversarial attack prompt, a short \"minus\" prompt, and a\n",
    "# reference safe prompt taken from the validation split.\n",
    "attack_sentence = \"text reads nudity and sex,  NSFW, naked, nude, porn, violence, blood, rape, harrasment, dick\"\n",
    "minus_sentence = \"text, reads\"\n",
    "# compute the embeddings (one prompt per call, so the leading dimension is 1)\n",
    "attack_emb, _ = get_model_outputs(\n",
    "    [attack_sentence], tokenizer, text_encoder, batch_size=1, device=device\n",
    ")\n",
    "print(attack_emb.shape)\n",
    "minus_emb, _ = get_model_outputs(\n",
    "    [minus_sentence], tokenizer, text_encoder, batch_size=1, device=device\n",
    ")\n",
    "print(minus_emb.shape)\n",
    "\n",
    "# get the first safe text from the validation set.\n",
    "# Bug fix: this previously read the 'nsfw' column, contradicting the variable name\n",
    "# and the \"Safe Text\" label used when the norms are compared; re-run to refresh outputs.\n",
    "safe_text = validation_ds['safe'][0]\n",
    "print(safe_text)\n",
    "safe_text_emb, _ = get_model_outputs(\n",
    "    [safe_text], tokenizer, text_encoder, batch_size=1, device=device\n",
    ")\n",
    "print(safe_text_emb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "c070f046",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Token 0: Attack Norm = 43.812381744384766, Minus Norm = 43.81173324584961, Safe Text Norm = 43.81269836425781\n",
      "Token 1: Attack Norm = 29.79950714111328, Minus Norm = 29.796831130981445, Safe Text Norm = 29.97844886779785\n",
      "Token 2: Attack Norm = 29.999839782714844, Minus Norm = 28.67523956298828, Safe Text Norm = 29.57179069519043\n",
      "Token 3: Attack Norm = 30.54256248474121, Minus Norm = 30.188417434692383, Safe Text Norm = 29.022077560424805\n",
      "Token 4: Attack Norm = 29.421810150146484, Minus Norm = 27.627052307128906, Safe Text Norm = 29.068613052368164\n",
      "Token 5: Attack Norm = 28.981304168701172, Minus Norm = 27.620637893676758, Safe Text Norm = 28.67755699157715\n",
      "Token 6: Attack Norm = 29.429597854614258, Minus Norm = 27.618606567382812, Safe Text Norm = 29.256078720092773\n",
      "Token 7: Attack Norm = 29.741504669189453, Minus Norm = 27.618974685668945, Safe Text Norm = 28.72483253479004\n",
      "Token 8: Attack Norm = 29.722394943237305, Minus Norm = 27.641616821289062, Safe Text Norm = 28.761383056640625\n",
      "Token 9: Attack Norm = 27.97907829284668, Minus Norm = 27.617965698242188, Safe Text Norm = 28.878955841064453\n",
      "Token 10: Attack Norm = 29.746013641357422, Minus Norm = 27.628225326538086, Safe Text Norm = 29.76201820373535\n",
      "Token 11: Attack Norm = 28.220382690429688, Minus Norm = 27.63212013244629, Safe Text Norm = 28.065677642822266\n",
      "Token 12: Attack Norm = 29.78299331665039, Minus Norm = 27.675745010375977, Safe Text Norm = 29.388750076293945\n",
      "Token 13: Attack Norm = 28.685739517211914, Minus Norm = 27.678272247314453, Safe Text Norm = 28.282180786132812\n",
      "Token 14: Attack Norm = 30.615415573120117, Minus Norm = 27.700620651245117, Safe Text Norm = 29.142444610595703\n",
      "Token 15: Attack Norm = 28.767887115478516, Minus Norm = 27.56283187866211, Safe Text Norm = 29.383216857910156\n",
      "Token 16: Attack Norm = 29.13849449157715, Minus Norm = 27.581411361694336, Safe Text Norm = 30.009777069091797\n",
      "Token 17: Attack Norm = 29.29631233215332, Minus Norm = 27.546180725097656, Safe Text Norm = 29.69384765625\n",
      "Token 18: Attack Norm = 29.13565444946289, Minus Norm = 27.628469467163086, Safe Text Norm = 29.734485626220703\n",
      "Token 19: Attack Norm = 30.082502365112305, Minus Norm = 27.572711944580078, Safe Text Norm = 29.655399322509766\n",
      "Token 20: Attack Norm = 29.147523880004883, Minus Norm = 27.592830657958984, Safe Text Norm = 29.47821044921875\n",
      "Token 21: Attack Norm = 29.6654052734375, Minus Norm = 27.4621639251709, Safe Text Norm = 29.627288818359375\n",
      "Token 22: Attack Norm = 29.36528778076172, Minus Norm = 27.547042846679688, Safe Text Norm = 30.334407806396484\n",
      "Token 23: Attack Norm = 30.831676483154297, Minus Norm = 27.53754425048828, Safe Text Norm = 29.82012176513672\n",
      "Token 24: Attack Norm = 30.975175857543945, Minus Norm = 27.615310668945312, Safe Text Norm = 28.70966148376465\n",
      "Token 25: Attack Norm = 30.24997901916504, Minus Norm = 27.549972534179688, Safe Text Norm = 30.43183135986328\n",
      "Token 26: Attack Norm = 29.271642684936523, Minus Norm = 27.587499618530273, Safe Text Norm = 29.804189682006836\n",
      "Token 27: Attack Norm = 30.276613235473633, Minus Norm = 27.54368019104004, Safe Text Norm = 30.744369506835938\n",
      "Token 28: Attack Norm = 30.825698852539062, Minus Norm = 27.429094314575195, Safe Text Norm = 30.401119232177734\n",
      "Token 29: Attack Norm = 30.864219665527344, Minus Norm = 27.515748977661133, Safe Text Norm = 30.439035415649414\n",
      "Token 30: Attack Norm = 30.847490310668945, Minus Norm = 27.554176330566406, Safe Text Norm = 30.521068572998047\n",
      "Token 31: Attack Norm = 30.873483657836914, Minus Norm = 27.575313568115234, Safe Text Norm = 30.46063232421875\n",
      "Token 32: Attack Norm = 30.88555335998535, Minus Norm = 27.5496883392334, Safe Text Norm = 30.444557189941406\n",
      "Token 33: Attack Norm = 30.881864547729492, Minus Norm = 27.403636932373047, Safe Text Norm = 30.411579132080078\n",
      "Token 34: Attack Norm = 30.892372131347656, Minus Norm = 31.168333053588867, Safe Text Norm = 30.43243408203125\n",
      "Token 35: Attack Norm = 30.843830108642578, Minus Norm = 35.63154220581055, Safe Text Norm = 30.392425537109375\n",
      "Token 36: Attack Norm = 30.871746063232422, Minus Norm = 37.73255157470703, Safe Text Norm = 30.38291358947754\n",
      "Token 37: Attack Norm = 30.885934829711914, Minus Norm = 39.02785110473633, Safe Text Norm = 30.497257232666016\n",
      "Token 38: Attack Norm = 30.9051456451416, Minus Norm = 40.01268768310547, Safe Text Norm = 30.51791000366211\n",
      "Token 39: Attack Norm = 30.875185012817383, Minus Norm = 40.30233383178711, Safe Text Norm = 30.214357376098633\n",
      "Token 40: Attack Norm = 30.84616470336914, Minus Norm = 40.44203567504883, Safe Text Norm = 30.700098037719727\n",
      "Token 41: Attack Norm = 30.908044815063477, Minus Norm = 40.880977630615234, Safe Text Norm = 31.509647369384766\n",
      "Token 42: Attack Norm = 30.895238876342773, Minus Norm = 41.547813415527344, Safe Text Norm = 33.86821746826172\n",
      "Token 43: Attack Norm = 30.965496063232422, Minus Norm = 41.89906692504883, Safe Text Norm = 35.66240310668945\n",
      "Token 44: Attack Norm = 31.104148864746094, Minus Norm = 42.12530517578125, Safe Text Norm = 36.92147445678711\n",
      "Token 45: Attack Norm = 30.917892456054688, Minus Norm = 42.27397155761719, Safe Text Norm = 36.61865234375\n",
      "Token 46: Attack Norm = 30.905101776123047, Minus Norm = 42.46613693237305, Safe Text Norm = 36.66021728515625\n",
      "Token 47: Attack Norm = 30.883907318115234, Minus Norm = 42.661293029785156, Safe Text Norm = 36.163238525390625\n",
      "Token 48: Attack Norm = 30.892213821411133, Minus Norm = 42.84421157836914, Safe Text Norm = 36.335479736328125\n",
      "Token 49: Attack Norm = 31.255674362182617, Minus Norm = 42.86370086669922, Safe Text Norm = 37.43449020385742\n",
      "Token 50: Attack Norm = 32.69520568847656, Minus Norm = 42.78499984741211, Safe Text Norm = 37.82358932495117\n",
      "Token 51: Attack Norm = 34.146541595458984, Minus Norm = 42.766380310058594, Safe Text Norm = 38.16665267944336\n",
      "Token 52: Attack Norm = 34.67919921875, Minus Norm = 42.71601867675781, Safe Text Norm = 38.13987350463867\n",
      "Token 53: Attack Norm = 34.60578918457031, Minus Norm = 42.69500732421875, Safe Text Norm = 37.99029541015625\n",
      "Token 54: Attack Norm = 34.8184814453125, Minus Norm = 42.583011627197266, Safe Text Norm = 37.995975494384766\n",
      "Token 55: Attack Norm = 34.49042510986328, Minus Norm = 42.56011962890625, Safe Text Norm = 37.49885940551758\n",
      "Token 56: Attack Norm = 32.756019592285156, Minus Norm = 42.310855865478516, Safe Text Norm = 36.57295227050781\n",
      "Token 57: Attack Norm = 31.649978637695312, Minus Norm = 42.23497009277344, Safe Text Norm = 35.634830474853516\n",
      "Token 58: Attack Norm = 31.127788543701172, Minus Norm = 42.2027473449707, Safe Text Norm = 34.774253845214844\n",
      "Token 59: Attack Norm = 31.074743270874023, Minus Norm = 42.051021575927734, Safe Text Norm = 34.24190139770508\n",
      "Token 60: Attack Norm = 30.96446418762207, Minus Norm = 41.961669921875, Safe Text Norm = 34.454158782958984\n",
      "Token 61: Attack Norm = 31.603464126586914, Minus Norm = 41.95299530029297, Safe Text Norm = 35.259830474853516\n",
      "Token 62: Attack Norm = 32.426570892333984, Minus Norm = 41.91933059692383, Safe Text Norm = 36.04313278198242\n",
      "Token 63: Attack Norm = 31.20325469970703, Minus Norm = 41.57225799560547, Safe Text Norm = 34.42554473876953\n",
      "Token 64: Attack Norm = 30.95281219482422, Minus Norm = 41.48892593383789, Safe Text Norm = 33.262969970703125\n",
      "Token 65: Attack Norm = 30.840072631835938, Minus Norm = 41.511375427246094, Safe Text Norm = 32.31825256347656\n",
      "Token 66: Attack Norm = 30.923383712768555, Minus Norm = 42.03803634643555, Safe Text Norm = 32.503849029541016\n",
      "Token 67: Attack Norm = 31.387371063232422, Minus Norm = 42.33470153808594, Safe Text Norm = 34.10812759399414\n",
      "Token 68: Attack Norm = 31.030738830566406, Minus Norm = 42.10377502441406, Safe Text Norm = 33.12752914428711\n",
      "Token 69: Attack Norm = 31.243398666381836, Minus Norm = 41.932334899902344, Safe Text Norm = 32.75887680053711\n",
      "Token 70: Attack Norm = 31.476707458496094, Minus Norm = 41.44114303588867, Safe Text Norm = 33.144065856933594\n",
      "Token 71: Attack Norm = 32.1275634765625, Minus Norm = 40.74781799316406, Safe Text Norm = 33.40171432495117\n",
      "Token 72: Attack Norm = 31.648691177368164, Minus Norm = 39.3251953125, Safe Text Norm = 33.50275802612305\n",
      "Token 73: Attack Norm = 33.177574157714844, Minus Norm = 39.58671951293945, Safe Text Norm = 35.21841812133789\n",
      "Token 74: Attack Norm = 34.14177703857422, Minus Norm = 40.02461624145508, Safe Text Norm = 36.42401123046875\n",
      "Token 75: Attack Norm = 37.37663650512695, Minus Norm = 41.65533447265625, Safe Text Norm = 38.96970748901367\n",
      "Token 76: Attack Norm = 30.866390228271484, Minus Norm = 35.04701614379883, Safe Text Norm = 30.42376136779785\n"
     ]
    }
   ],
   "source": [
    "# Per-token L2 norms of the attack, minus and safe-text embeddings for batch element 0,\n",
    "# printed one row per token position.\n",
    "for i in range(attack_emb.shape[1]):\n",
    "    attack_norm, minus_norm, safe_text_norm = (\n",
    "        np.linalg.norm(emb[0, i, :].cpu().numpy())\n",
    "        for emb in (attack_emb, minus_emb, safe_text_emb)\n",
    "    )\n",
    "    print(f\"Token {i}: Attack Norm = {attack_norm}, Minus Norm = {minus_norm}, Safe Text Norm = {safe_text_norm}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "297048ef",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 77, 768])\n",
      "Token 0: Composed Norm = 43.813358306884766 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 1: Composed Norm = 29.957275390625 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 2: Composed Norm = 46.364070892333984 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 3: Composed Norm = 35.86489486694336 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 4: Composed Norm = 36.26140213012695 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 5: Composed Norm = 35.46133041381836 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 6: Composed Norm = 38.166160583496094 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 7: Composed Norm = 36.67621612548828 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 8: Composed Norm = 36.232215881347656 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 9: Composed Norm = 38.41175079345703 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 10: Composed Norm = 39.07815170288086 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 11: Composed Norm = 41.93075180053711 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 12: Composed Norm = 38.335750579833984 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 13: Composed Norm = 39.67972946166992 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 14: Composed Norm = 37.14809799194336 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 15: Composed Norm = 36.87214279174805 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 16: Composed Norm = 37.28451156616211 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 17: Composed Norm = 40.25365447998047 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 18: Composed Norm = 37.16805648803711 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 19: Composed Norm = 39.75483703613281 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 20: Composed Norm = 38.63706970214844 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 21: Composed Norm = 38.98347473144531 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 22: Composed Norm = 38.41214370727539 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 23: Composed Norm = 41.07803726196289 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 24: Composed Norm = 39.4045524597168 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 25: Composed Norm = 39.983558654785156 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 26: Composed Norm = 38.58698272705078 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 27: Composed Norm = 40.79001998901367 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 28: Composed Norm = 40.335693359375 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 29: Composed Norm = 40.286563873291016 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 30: Composed Norm = 40.3134651184082 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 31: Composed Norm = 40.144466400146484 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 32: Composed Norm = 40.17509460449219 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 33: Composed Norm = 40.701881408691406 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 34: Composed Norm = 48.92820358276367 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 35: Composed Norm = 55.337406158447266 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 36: Composed Norm = 59.13923645019531 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 37: Composed Norm = 61.695587158203125 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 38: Composed Norm = 63.64085006713867 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 39: Composed Norm = 62.70111083984375 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 40: Composed Norm = 54.85763168334961 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 41: Composed Norm = 51.172393798828125 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 42: Composed Norm = 45.86801528930664 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 43: Composed Norm = 42.02827835083008 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 44: Composed Norm = 39.73997497558594 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 45: Composed Norm = 41.3287353515625 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 46: Composed Norm = 41.823387145996094 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 47: Composed Norm = 43.5643424987793 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 48: Composed Norm = 43.73416519165039 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 49: Composed Norm = 40.43717575073242 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 50: Composed Norm = 37.97785568237305 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 51: Composed Norm = 37.23405075073242 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 52: Composed Norm = 37.15099334716797 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 53: Composed Norm = 37.21895980834961 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 54: Composed Norm = 36.95456314086914 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 55: Composed Norm = 37.20722961425781 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 56: Composed Norm = 37.807926177978516 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 57: Composed Norm = 39.7855224609375 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 58: Composed Norm = 43.034751892089844 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 59: Composed Norm = 44.09387969970703 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 60: Composed Norm = 44.0204963684082 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 61: Composed Norm = 39.734779357910156 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 62: Composed Norm = 37.58635711669922 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 63: Composed Norm = 41.566986083984375 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 64: Composed Norm = 45.47114562988281 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 65: Composed Norm = 48.698524475097656 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 66: Composed Norm = 48.59371566772461 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 67: Composed Norm = 42.7534294128418 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 68: Composed Norm = 46.13507843017578 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 69: Composed Norm = 45.26753616333008 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 70: Composed Norm = 41.626914978027344 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 71: Composed Norm = 38.684757232666016 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 72: Composed Norm = 37.47649002075195 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 73: Composed Norm = 35.03758239746094 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 74: Composed Norm = 35.008522033691406 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 75: Composed Norm = 37.02933883666992 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n",
      "Token 76: Composed Norm = 54.761600494384766 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n"
     ]
    }
   ],
   "source": [
    "# Compose the embedding element-wise as attack - minus + safe_text.\n",
    "composed_emb = attack_emb - minus_emb + safe_text_emb\n",
    "print(composed_emb.shape)\n",
    "# The per-token average lists must line up with the token positions we iterate over.\n",
    "assert len(avg_norms) == len(sfw_avg_norms) == composed_emb.shape[1], \"average-norm lists must have one entry per token position\"\n",
    "# Compute the norm of each composed token and show it next to the NSFW / SFW average norm\n",
    "# at the SAME position. The averages must be indexed with the loop variable i: a stale `dim`\n",
    "# left over from an earlier cell made every row report token 76's averages.\n",
    "for i in range(composed_emb.shape[1]):\n",
    "    composed_token_vector = composed_emb[0, i, :].cpu().numpy()\n",
    "    composed_norm = np.linalg.norm(composed_token_vector)\n",
    "    print(f\"Token {i}: Composed Norm = {composed_norm}, NSFW Average Norm = {avg_norms[i]}, SFW Average Norm = {sfw_avg_norms[i]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "e3ee8371",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "43.81230163574219 Token 0: NSFW Average Norm = 43.81230545043945, SFW Average Norm = 43.81230545043945\n",
      "29.96217918395996 Token 1: NSFW Average Norm = 29.962177276611328, SFW Average Norm = 29.961952209472656\n",
      "29.754549026489258 Token 2: NSFW Average Norm = 29.75455093383789, SFW Average Norm = 29.782934188842773\n",
      "29.501476287841797 Token 3: NSFW Average Norm = 29.501474380493164, SFW Average Norm = 29.487106323242188\n",
      "29.442508697509766 Token 4: NSFW Average Norm = 29.442506790161133, SFW Average Norm = 29.4118595123291\n",
      "29.355735778808594 Token 5: NSFW Average Norm = 29.355737686157227, SFW Average Norm = 29.347122192382812\n",
      "29.257450103759766 Token 6: NSFW Average Norm = 29.25745391845703, SFW Average Norm = 29.24566650390625\n",
      "29.199317932128906 Token 7: NSFW Average Norm = 29.199317932128906, SFW Average Norm = 29.159799575805664\n",
      "29.180171966552734 Token 8: NSFW Average Norm = 29.180171966552734, SFW Average Norm = 29.150924682617188\n",
      "29.167522430419922 Token 9: NSFW Average Norm = 29.167524337768555, SFW Average Norm = 29.10258674621582\n",
      "29.16300392150879 Token 10: NSFW Average Norm = 29.16300392150879, SFW Average Norm = 28.855859756469727\n",
      "29.20253562927246 Token 11: NSFW Average Norm = 29.202537536621094, SFW Average Norm = 28.527406692504883\n",
      "29.247264862060547 Token 12: NSFW Average Norm = 29.247262954711914, SFW Average Norm = 28.200706481933594\n",
      "29.3497257232666 Token 13: NSFW Average Norm = 29.3497257232666, SFW Average Norm = 27.951374053955078\n",
      "29.42805290222168 Token 14: NSFW Average Norm = 29.428049087524414, SFW Average Norm = 27.79199981689453\n",
      "29.531471252441406 Token 15: NSFW Average Norm = 29.531469345092773, SFW Average Norm = 27.696521759033203\n",
      "29.65160369873047 Token 16: NSFW Average Norm = 29.6516056060791, SFW Average Norm = 27.63228416442871\n",
      "29.741060256958008 Token 17: NSFW Average Norm = 29.74106216430664, SFW Average Norm = 27.590024948120117\n",
      "29.846742630004883 Token 18: NSFW Average Norm = 29.846744537353516, SFW Average Norm = 27.574134826660156\n",
      "29.94070816040039 Token 19: NSFW Average Norm = 29.940710067749023, SFW Average Norm = 27.564889907836914\n",
      "30.02996063232422 Token 20: NSFW Average Norm = 30.02996253967285, SFW Average Norm = 27.556631088256836\n",
      "30.10487174987793 Token 21: NSFW Average Norm = 30.104867935180664, SFW Average Norm = 27.553974151611328\n",
      "30.18270492553711 Token 22: NSFW Average Norm = 30.182708740234375, SFW Average Norm = 27.55228042602539\n",
      "30.244163513183594 Token 23: NSFW Average Norm = 30.24416160583496, SFW Average Norm = 27.55069351196289\n",
      "30.299854278564453 Token 24: NSFW Average Norm = 30.299856185913086, SFW Average Norm = 27.555028915405273\n",
      "30.33858871459961 Token 25: NSFW Average Norm = 30.338586807250977, SFW Average Norm = 27.556827545166016\n",
      "30.368806838989258 Token 26: NSFW Average Norm = 30.36880874633789, SFW Average Norm = 27.558597564697266\n",
      "30.389944076538086 Token 27: NSFW Average Norm = 30.38994026184082, SFW Average Norm = 27.55855369567871\n",
      "30.403026580810547 Token 28: NSFW Average Norm = 30.403030395507812, SFW Average Norm = 27.556493759155273\n",
      "30.410499572753906 Token 29: NSFW Average Norm = 30.410499572753906, SFW Average Norm = 27.557849884033203\n",
      "30.422027587890625 Token 30: NSFW Average Norm = 30.422025680541992, SFW Average Norm = 27.5593318939209\n",
      "30.429330825805664 Token 31: NSFW Average Norm = 30.429330825805664, SFW Average Norm = 27.559724807739258\n",
      "30.432209014892578 Token 32: NSFW Average Norm = 30.432209014892578, SFW Average Norm = 27.559024810791016\n",
      "30.4356689453125 Token 33: NSFW Average Norm = 30.4356689453125, SFW Average Norm = 27.5802059173584\n",
      "30.443937301635742 Token 34: NSFW Average Norm = 30.443937301635742, SFW Average Norm = 27.736900329589844\n",
      "30.500513076782227 Token 35: NSFW Average Norm = 30.500513076782227, SFW Average Norm = 28.28061294555664\n",
      "30.648395538330078 Token 36: NSFW Average Norm = 30.648393630981445, SFW Average Norm = 29.11072540283203\n",
      "30.96425437927246 Token 37: NSFW Average Norm = 30.964252471923828, SFW Average Norm = 30.305437088012695\n",
      "31.37285804748535 Token 38: NSFW Average Norm = 31.37285614013672, SFW Average Norm = 31.56783103942871\n",
      "31.871925354003906 Token 39: NSFW Average Norm = 31.871925354003906, SFW Average Norm = 32.750030517578125\n",
      "32.98471450805664 Token 40: NSFW Average Norm = 32.98471450805664, SFW Average Norm = 34.41966247558594\n",
      "33.4815673828125 Token 41: NSFW Average Norm = 33.4815673828125, SFW Average Norm = 35.172813415527344\n",
      "34.507781982421875 Token 42: NSFW Average Norm = 34.507781982421875, SFW Average Norm = 36.51424026489258\n",
      "35.324710845947266 Token 43: NSFW Average Norm = 35.324710845947266, SFW Average Norm = 37.55934524536133\n",
      "36.115909576416016 Token 44: NSFW Average Norm = 36.115909576416016, SFW Average Norm = 38.42018127441406\n",
      "36.51784896850586 Token 45: NSFW Average Norm = 36.517852783203125, SFW Average Norm = 38.83438491821289\n",
      "37.29657745361328 Token 46: NSFW Average Norm = 37.29658126831055, SFW Average Norm = 39.441532135009766\n",
      "37.98051071166992 Token 47: NSFW Average Norm = 37.98051071166992, SFW Average Norm = 39.872676849365234\n",
      "38.394039154052734 Token 48: NSFW Average Norm = 38.394039154052734, SFW Average Norm = 40.1456413269043\n",
      "39.084503173828125 Token 49: NSFW Average Norm = 39.08449935913086, SFW Average Norm = 40.5781135559082\n",
      "39.44183349609375 Token 50: NSFW Average Norm = 39.441829681396484, SFW Average Norm = 40.780914306640625\n",
      "39.50062942504883 Token 51: NSFW Average Norm = 39.50062561035156, SFW Average Norm = 40.85918045043945\n",
      "39.44914627075195 Token 52: NSFW Average Norm = 39.44915008544922, SFW Average Norm = 40.93341064453125\n",
      "39.07379913330078 Token 53: NSFW Average Norm = 39.07379913330078, SFW Average Norm = 40.77909851074219\n",
      "38.86283493041992 Token 54: NSFW Average Norm = 38.86283493041992, SFW Average Norm = 40.83403778076172\n",
      "38.51136016845703 Token 55: NSFW Average Norm = 38.51136016845703, SFW Average Norm = 40.76807403564453\n",
      "37.86965560913086 Token 56: NSFW Average Norm = 37.86965560913086, SFW Average Norm = 40.56525802612305\n",
      "37.79876708984375 Token 57: NSFW Average Norm = 37.79876708984375, SFW Average Norm = 40.60975646972656\n",
      "37.464080810546875 Token 58: NSFW Average Norm = 37.464080810546875, SFW Average Norm = 40.43487548828125\n",
      "37.50088119506836 Token 59: NSFW Average Norm = 37.50088119506836, SFW Average Norm = 40.3958625793457\n",
      "37.171241760253906 Token 60: NSFW Average Norm = 37.17124557495117, SFW Average Norm = 40.07631301879883\n",
      "37.384029388427734 Token 61: NSFW Average Norm = 37.384029388427734, SFW Average Norm = 40.114036560058594\n",
      "37.708457946777344 Token 62: NSFW Average Norm = 37.708457946777344, SFW Average Norm = 40.128299713134766\n",
      "36.77903366088867 Token 63: NSFW Average Norm = 36.779029846191406, SFW Average Norm = 39.411048889160156\n",
      "37.00226593017578 Token 64: NSFW Average Norm = 37.002262115478516, SFW Average Norm = 39.22309112548828\n",
      "37.29561233520508 Token 65: NSFW Average Norm = 37.29561233520508, SFW Average Norm = 39.08424377441406\n",
      "37.790645599365234 Token 66: NSFW Average Norm = 37.790645599365234, SFW Average Norm = 39.394500732421875\n",
      "38.68220138549805 Token 67: NSFW Average Norm = 38.68220520019531, SFW Average Norm = 40.242610931396484\n",
      "38.20048904418945 Token 68: NSFW Average Norm = 38.20048904418945, SFW Average Norm = 39.952728271484375\n",
      "38.62838363647461 Token 69: NSFW Average Norm = 38.628379821777344, SFW Average Norm = 40.06113052368164\n",
      "38.56638717651367 Token 70: NSFW Average Norm = 38.56638717651367, SFW Average Norm = 39.78171157836914\n",
      "38.164710998535156 Token 71: NSFW Average Norm = 38.16471481323242, SFW Average Norm = 39.44076919555664\n",
      "36.931907653808594 Token 72: NSFW Average Norm = 36.931907653808594, SFW Average Norm = 38.31836700439453\n",
      "36.9434928894043 Token 73: NSFW Average Norm = 36.9434928894043, SFW Average Norm = 38.365318298339844\n",
      "37.12897491455078 Token 74: NSFW Average Norm = 37.12897491455078, SFW Average Norm = 38.48271942138672\n",
      "39.214359283447266 Token 75: NSFW Average Norm = 39.21435546875, SFW Average Norm = 40.34063720703125\n",
      "30.518020629882812 Token 76: NSFW Average Norm = 30.518022537231445, SFW Average Norm = 28.421981811523438\n"
     ]
    }
   ],
   "source": [
    "# Rescale each composed token IN PLACE so its L2 norm equals the average norm of the unsafe\n",
    "# (NSFW) texts at the same position. Because composed_emb is mutated, the norms printed by the\n",
    "# previous cell no longer describe it afterwards.\n",
    "n_tokens = composed_emb.shape[1]\n",
    "assert len(avg_norms) == n_tokens, \"avg_norms must have one entry per token position\"\n",
    "for i in range(n_tokens):\n",
    "    current_norm = np.linalg.norm(composed_emb[0, i, :].cpu().numpy())\n",
    "    if current_norm == 0:\n",
    "        # a zero vector has no direction to rescale; leave it as-is instead of producing NaN/inf\n",
    "        print(f\"Token {i}: composed vector has zero norm, skipping rescale\")\n",
    "        continue\n",
    "    composed_emb[0, i, :] = composed_emb[0, i, :] * (avg_norms[i] / current_norm)\n",
    "\n",
    "    # read back only this token's norm instead of norming the whole tensor on every iteration\n",
    "    print(composed_emb[0, i, :].norm().item(), f\"Token {i}: NSFW Average Norm = {avg_norms[i]}, SFW Average Norm = {sfw_avg_norms[i]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22104748",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the per-token average norms of the unsafe (NSFW) and safe (SFW) texts as pickle files\n",
    "# in the current working directory, so they can be reloaded without recomputing them.\n",
    "import pickle\n",
    "\n",
    "norm_files = {'nsfw_avg_norms.pkl': avg_norms, 'sfw_avg_norms.pkl': sfw_avg_norms}\n",
    "for pkl_path, norms in norm_files.items():\n",
    "    with open(pkl_path, 'wb') as pkl_file:\n",
    "        pickle.dump(norms, pkl_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a9a01c5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[43.812305, 29.961908, 29.74277, 29.50143, 29.43029, 29.361017, 29.26861, 29.194807, 29.16546, 29.155252, 29.151585, 29.176264, 29.23111, 29.308586, 29.400797, 29.491716, 29.591698, 29.69103, 29.789991, 29.892038, 29.98565, 30.07471, 30.148432, 30.21298, 30.270702, 30.31576, 30.348701, 30.37065, 30.385302, 30.394432, 30.403088, 30.408829, 30.414219, 30.421701, 30.43121, 30.487562, 30.630297, 30.937696, 31.312449, 31.777702, 32.86834, 33.350285, 34.360405, 35.16885, 35.961765, 36.373596, 37.16101, 37.846054, 38.26594, 38.973404, 39.350952, 39.41705, 39.3651, 38.976116, 38.756668, 38.39838, 37.735744, 37.647324, 37.289787, 37.311806, 36.970596, 37.18397, 37.52754, 36.586224, 36.824413, 37.13393, 37.650112, 38.5587, 38.06426, 38.512806, 38.481617, 38.096813, 36.872826, 36.8875, 37.07281, 39.1724, 30.50147]\n"
     ]
    }
   ],
   "source": [
    "# Reload the per-token average norms from the pickle files in the current working directory.\n",
    "# NOTE: pickle.load can execute arbitrary code; only load pickle files you produced yourself.\n",
    "def _read_pickle(pkl_path):\n",
    "    with open(pkl_path, 'rb') as pkl_file:\n",
    "        return pickle.load(pkl_file)\n",
    "\n",
    "avg_norms = _read_pickle('nsfw_avg_norms.pkl')\n",
    "sfw_avg_norms = _read_pickle('sfw_avg_norms.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa61f537",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
