{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Figure 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import CLIPTokenizer, CLIPTextModel\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "from dataclasses import dataclass\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from copy import deepcopy\n",
    "from scipy.stats import gaussian_kde\n",
    "\n",
    "@dataclass\n",
    "class CLIPEmbeddingInfo:\n",
    "    prompt: str\n",
    "    embedding: np.ndarray\n",
    "\n",
    "class Args:\n",
    "    def __init__(self):\n",
    "        self.model_path = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\n",
    "        self.device = \"cuda:0\"\n",
    "        self.unsafe_csv_path = \"datasets/unsafe_prompts_copro_sexual.csv\"\n",
    "        self.safe_csv_path = \"datasets/safe_prompts_copro_sexual.csv\"\n",
    "        self.concept_prompt = \"nudity\"\n",
    "        self.concept_guidance_scale = 200.0\n",
    "        self.safe_embedding_path = \"checkpoints/safe_embeddings.pth\"\n",
    "        self.text_encoder_path = \"checkpoints/des.pt\"\n",
    "\n",
    "args = Args()\n",
    "device = torch.device(args.device)\n",
    "\n",
    "# Load models\n",
    "tokenizer = CLIPTokenizer.from_pretrained(args.model_path, subfolder=\"tokenizer\")\n",
    "text_encoder = CLIPTextModel.from_pretrained(\n",
    "    args.model_path, \n",
    "    subfolder=\"text_encoder\"\n",
    ").to(device)\n",
    "text_encoder_ours = deepcopy(text_encoder)\n",
    "text_encoder_ours.load_state_dict(torch.load(args.text_encoder_path, map_location='cpu')['model_state_dict'])\n",
    "\n",
    "# Load prompts\n",
    "unsafe_df = pd.read_csv(args.unsafe_csv_path)\n",
    "safe_df = pd.read_csv(args.safe_csv_path)\n",
    "unsafe_prompts = unsafe_df['prompt'].tolist()\n",
    "safe_prompts = safe_df['prompt'].tolist()\n",
    "\n",
    "# Get unconditioned embedding\n",
    "empty_prompt = \"\"\n",
    "uncond_tokens = tokenizer(\n",
    "    empty_prompt,\n",
    "    padding=\"max_length\",\n",
    "    max_length=tokenizer.model_max_length,\n",
    "    truncation=True,\n",
    "    return_tensors=\"pt\"\n",
    ").to(device)\n",
    "\n",
    "# Get concept embedding\n",
    "concept_prompt = args.concept_prompt\n",
    "concept_tokens = tokenizer(\n",
    "    concept_prompt,\n",
    "    padding=\"max_length\",\n",
    "    max_length=tokenizer.model_max_length,\n",
    "    truncation=True,\n",
    "    return_tensors=\"pt\"\n",
    ").to(device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    concept_embedding = text_encoder(concept_tokens.input_ids)[0]\n",
    "\n",
    "concept_embedding = concept_embedding.view(1, -1)  # [1, 77*768]\n",
    "concept_embedding_norm = torch.norm(concept_embedding, dim=-1, keepdim=True)\n",
    "concept_embedding = concept_embedding / (concept_embedding_norm + 1e-12)\n",
    "\n",
    "# Try to load paired_data if path exists\n",
    "print(f\"Loading paired data from {args.safe_embedding_path}\")\n",
    "loaded_data = torch.load(args.safe_embedding_path, map_location='cpu')\n",
    "paired_data = loaded_data['paired_data']\n",
    "print(\"Successfully loaded paired data\")\n",
    "paired_data = [(\n",
    "    unsafe_prompt,\n",
    "    min_safe_embedding.to(device),\n",
    "    max_safe_embedding.to(device),\n",
    "    safe_prompt\n",
    ") for unsafe_prompt, min_safe_embedding, max_safe_embedding, safe_prompt in paired_data]\n",
    "del loaded_data\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "sims1 = []\n",
    "sims2 = []\n",
    "sims3 = []\n",
    "sims4 = []\n",
    "sims5 = []\n",
    "for (unsafe_prompt, min_safe_embedding, max_safe_embedding, safe_prompt) in tqdm(paired_data, desc='Computing cossim'):\n",
    "    # Originally saved after subtraction of concept direction, so add it back\n",
    "    max_safe_embedding = max_safe_embedding.view(-1) + args.concept_guidance_scale * concept_embedding\n",
    "    max_safe_embedding = max_safe_embedding.view(77, 768)\n",
    "    \n",
    "    # 1. similarity between max safe embedding and concept direction\n",
    "    max_safe_flat = max_safe_embedding.view(-1)\n",
    "    max_safe_norm = torch.norm(max_safe_flat, dim=-1, keepdim=True)\n",
    "    max_safe_normalized = max_safe_flat / (max_safe_norm + 1e-12)\n",
    "    max_safe_emb_sim = torch.sum(concept_embedding * max_safe_normalized)\n",
    "    \n",
    "    # 2. similarity between unsafe embedding and concept direction\n",
    "    unsafe_tokens = tokenizer(\n",
    "        unsafe_prompt,\n",
    "        padding=\"max_length\",\n",
    "        max_length=tokenizer.model_max_length,\n",
    "        truncation=True,\n",
    "        return_tensors=\"pt\"\n",
    "    ).to(device)\n",
    "    unsafe_embedding = text_encoder(unsafe_tokens.input_ids)[0]\n",
    "    unsafe_embedding_ours = text_encoder_ours(unsafe_tokens.input_ids)[0]\n",
    "    \n",
    "    unsafe_uncond_sub = unsafe_embedding[0].view(-1)\n",
    "    unsafe_uncond_sub_norm = torch.norm(unsafe_uncond_sub, dim=-1, keepdim=True)\n",
    "    unsafe_uncond_sub_normalized = unsafe_uncond_sub / (unsafe_uncond_sub_norm + 1e-12)\n",
    "    unsafe_uncond_sub_sim = torch.sum(concept_embedding * unsafe_uncond_sub_normalized)\n",
    "    \n",
    "    unsafe_uncond_sub_ours = unsafe_embedding_ours[0].view(-1)\n",
    "    unsafe_uncond_sub_norm_ours = torch.norm(unsafe_uncond_sub_ours, dim=-1, keepdim=True)\n",
    "    unsafe_uncond_sub_normalized_ours = unsafe_uncond_sub_ours / (unsafe_uncond_sub_norm_ours + 1e-12)\n",
    "    unsafe_uncond_sub_sim_ours = torch.sum(concept_embedding * unsafe_uncond_sub_normalized_ours)\n",
    "    \n",
    "    target_vector = max_safe_embedding.view(-1) - 50 * concept_embedding\n",
    "    target_vector_norm = torch.norm(target_vector, dim=-1, keepdim=True)\n",
    "    target_vector_normalized = target_vector / (target_vector_norm + 1e-12)\n",
    "    target_vector_sim = torch.sum(concept_embedding * target_vector_normalized)\n",
    "    \n",
    "    target_vector = max_safe_embedding.view(-1) - 100 * concept_embedding\n",
    "    target_vector_norm = torch.norm(target_vector, dim=-1, keepdim=True)\n",
    "    target_vector_normalized = target_vector / (target_vector_norm + 1e-12)\n",
    "    target_vector_sim2 = torch.sum(concept_embedding * target_vector_normalized)\n",
    "    \n",
    "    target_vector = max_safe_embedding.view(-1) - 150 * concept_embedding\n",
    "    target_vector_norm = torch.norm(target_vector, dim=-1, keepdim=True)\n",
    "    target_vector_normalized = target_vector / (target_vector_norm + 1e-12)\n",
    "    target_vector_sim3 = torch.sum(concept_embedding * target_vector_normalized)\n",
    "    \n",
    "    target_vector = max_safe_embedding.view(-1) - 200 * concept_embedding\n",
    "    target_vector_norm = torch.norm(target_vector, dim=-1, keepdim=True)\n",
    "    target_vector_normalized = target_vector / (target_vector_norm + 1e-12)\n",
    "    target_vector_sim4 = torch.sum(concept_embedding * target_vector_normalized)\n",
    "    \n",
    "    target_vector = min_safe_embedding.view(-1)\n",
    "    target_vector_norm = torch.norm(target_vector, dim=-1, keepdim=True)\n",
    "    target_vector_normalized = target_vector / (target_vector_norm + 1e-12)\n",
    "    minsafe_vector_sim = torch.sum(concept_embedding * target_vector_normalized)\n",
    "    \n",
    "    # similarity between safe embedding and concept direction\n",
    "    safe_tokens = tokenizer(\n",
    "        safe_prompt,\n",
    "        padding=\"max_length\",\n",
    "        max_length=tokenizer.model_max_length,\n",
    "        truncation=True,\n",
    "        return_tensors=\"pt\"\n",
    "    ).to(device)\n",
    "    safe_embedding = text_encoder(safe_tokens.input_ids)[0]\n",
    "    safe_embedding_ours = text_encoder_ours(safe_tokens.input_ids)[0]\n",
    "    \n",
    "    safe_uncond_sub = safe_embedding[0].view(-1)\n",
    "    safe_uncond_sub_norm = torch.norm(safe_uncond_sub, dim=-1, keepdim=True)\n",
    "    safe_uncond_sub_normalized = safe_uncond_sub / (safe_uncond_sub_norm + 1e-12)\n",
    "    safe_uncond_sub_sim = torch.sum(concept_embedding * safe_uncond_sub_normalized)\n",
    "    \n",
    "    safe_uncond_sub_ours = safe_embedding_ours[0].view(-1)\n",
    "    safe_uncond_sub_norm_ours = torch.norm(safe_uncond_sub_ours, dim=-1, keepdim=True)\n",
    "    safe_uncond_sub_normalized_ours = safe_uncond_sub_ours / (safe_uncond_sub_norm_ours + 1e-12)\n",
    "    safe_uncond_sub_sim_ours = torch.sum(concept_embedding * safe_uncond_sub_normalized_ours)\n",
    "    \n",
    "    sims1.append(max_safe_emb_sim.item())\n",
    "    sims2.append(target_vector_sim.item())\n",
    "    sims3.append(target_vector_sim2.item())\n",
    "    sims4.append(target_vector_sim3.item())\n",
    "    sims5.append(target_vector_sim4.item())\n",
    "\n",
    "bins = 100\n",
    "range = (-0.67, 0.41)\n",
    "\n",
    "plt.style.use('seaborn-v0_8-paper')\n",
    "plt.rcParams.update({\n",
    "    'font.family': 'serif',\n",
    "    'font.size': 9,\n",
    "    'axes.linewidth': 0.8,\n",
    "    'axes.labelsize': 9,\n",
    "    'xtick.labelsize': 9,\n",
    "    'ytick.labelsize': 9,\n",
    "})\n",
    "\n",
    "plt.figure(figsize=(8, 3), dpi=150)\n",
    "\n",
    "colors = ['#FFD700', '#D4DB00', '#A8DD00', '#7CDF00', '#50E000']\n",
    "plt.hist(sims1, bins=bins, range=range, alpha=0.7, color=colors[0],\n",
    "         label='Selected Safe Vector', edgecolor='white', linewidth=0.5)\n",
    "plt.hist(sims2, bins=bins, range=range, alpha=0.7, color=colors[1],\n",
    "         label='Target Vector w/ $s_{g}$=50', edgecolor='white', linewidth=0.5)\n",
    "plt.hist(sims3, bins=bins, range=range, alpha=0.7, color=colors[2],\n",
    "         label='Target Vector w/ $s_{g}$=100', edgecolor='white', linewidth=0.5)\n",
    "plt.hist(sims4, bins=bins, range=range, alpha=0.7, color=colors[3],\n",
    "         label='Target Vector w/ $s_{g}$=150', edgecolor='white', linewidth=0.5)\n",
    "plt.hist(sims5, bins=bins, range=range, alpha=0.7, color=colors[4],\n",
    "         label='Target Vector w/ $s_{g}$=200', edgecolor='white', linewidth=0.5)\n",
    "\n",
    "# Distribution lines\n",
    "for data in [sims1, sims2, sims3, sims4, sims5]:\n",
    "    hist_counts, bin_edges = np.histogram(data, bins=bins, range=range)\n",
    "    bin_width = bin_edges[1] - bin_edges[0]\n",
    "    \n",
    "    std = np.std(data)\n",
    "    bw_method = 0.45 * (0.05/std)\n",
    "    density = gaussian_kde(data, bw_method=bw_method)\n",
    "    xs = np.linspace(range[0], range[1], 200)\n",
    "    ys = density(xs)\n",
    "    ys = ys * bin_width * len(data)\n",
    "    plt.plot(xs, ys, color='black', linewidth=1, alpha=0.4)\n",
    "\n",
    "\n",
    "plt.xlabel('Cosine Similarity with $e_c$', fontsize=12, labelpad=0)\n",
    "plt.ylabel('Frequency', fontsize=12, labelpad=3)\n",
    "plt.grid(True, alpha=0.2, linestyle='--')\n",
    "\n",
    "plt.legend(frameon=True, fancybox=True, shadow=True, \n",
    "          fontsize=8,\n",
    "          bbox_to_anchor=(0.995, 0.99), \n",
    "          loc='upper right',\n",
    "          borderaxespad=0,\n",
    "          handlelength=1.5,\n",
    "          handletextpad=0.5,)\n",
    "\n",
    "plt.text(-0.58, 1200, '$s_{g}=200$', \n",
    "         bbox=dict(facecolor='none', alpha=0.8, edgecolor='none', pad=3),\n",
    "         fontsize=9)\n",
    "plt.text(-0.45, 1100, '$s_{g}=150$', \n",
    "         bbox=dict(facecolor='none', alpha=0.8, edgecolor='none', pad=3),\n",
    "         fontsize=9)\n",
    "plt.text(-0.29, 1000, '$s_{g}=100$', \n",
    "         bbox=dict(facecolor='none', alpha=0.8, edgecolor='none', pad=3),\n",
    "         fontsize=9)\n",
    "plt.text(-0.1, 900, '$s_{g}=50$', \n",
    "         bbox=dict(facecolor='none', alpha=0.8, edgecolor='none', pad=3),\n",
    "         fontsize=9)\n",
    "plt.text(0.1, 800, '$s_{g}=0$', \n",
    "         bbox=dict(facecolor='none', alpha=0.8, edgecolor='none', pad=3),\n",
    "         fontsize=9)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Figure 6"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Figure 6a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import CLIPTokenizer, CLIPTextModel\n",
    "import pandas as pd\n",
    "from dataclasses import dataclass\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from copy import deepcopy\n",
    "from scipy.stats import gaussian_kde\n",
    "\n",
    "@dataclass\n",
    "class CLIPEmbeddingInfo:\n",
    "    prompt: str\n",
    "    embedding: np.ndarray  # [seq_len, hidden_dim]\n",
    "\n",
    "class Args:\n",
    "    def __init__(self):\n",
    "        self.model_path = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\n",
    "        self.device = \"cuda:0\"\n",
    "        self.unsafe_csv_path = \"datasets/unsafe_prompts_copro_sexual.csv\"\n",
    "        self.safe_csv_path = \"datasets/safe_prompts_copro_sexual.csv\"\n",
    "        self.adv1_csv_path = \"prompts/sneaky_prompts.csv\"\n",
    "        self.adv2_csv_path = \"prompts/mma_prompts.csv\"\n",
    "        self.adv3_csv_path = \"prompts/i2p_sexual_prompts.csv\"\n",
    "        self.adv4_csv_path = \"prompts/p4d_prompts.csv\"\n",
    "        self.adv5_csv_path = \"prompts/ringabell_prompts.csv\"\n",
    "        self.concept_prompt = \"nudity\"\n",
    "        self.concept_guidance_scale = 200.0\n",
    "        self.safe_embedding_path = \"checkpoints/safe_embeddings.pth\"\n",
    "        self.text_encoder_path = \"checkpoints/des.pt\"\n",
    "\n",
    "args = Args()\n",
    "torch.cuda.empty_cache()\n",
    "device = torch.device(args.device)\n",
    "\n",
    "# Load models\n",
    "tokenizer = CLIPTokenizer.from_pretrained(args.model_path, subfolder=\"tokenizer\")\n",
    "text_encoder = CLIPTextModel.from_pretrained(\n",
    "    args.model_path, \n",
    "    subfolder=\"text_encoder\"\n",
    ").to(device)\n",
    "text_encoder_ours = deepcopy(text_encoder)\n",
    "text_encoder_ours.load_state_dict(torch.load(args.text_encoder_path, map_location='cpu')['model_state_dict'])\n",
    "\n",
    "# Load prompts\n",
    "unsafe_df = pd.read_csv(args.unsafe_csv_path)\n",
    "safe_df = pd.read_csv(args.safe_csv_path)\n",
    "adv1_df = pd.read_csv(args.adv1_csv_path)\n",
    "adv2_df = pd.read_csv(args.adv2_csv_path)\n",
    "adv3_df = pd.read_csv(args.adv3_csv_path)\n",
    "adv4_df = pd.read_csv(args.adv4_csv_path)\n",
    "adv5_df = pd.read_csv(args.adv5_csv_path)\n",
    "unsafe_prompts = unsafe_df['prompt'].tolist()\n",
    "safe_prompts = safe_df['prompt'].tolist()\n",
    "adv1_prompts = adv1_df['prompt'].tolist()\n",
    "adv2_prompts = adv2_df['prompt'].tolist()\n",
    "adv3_prompts = adv3_df['prompt'].tolist()\n",
    "adv4_prompts = adv4_df['prompt'].tolist()\n",
    "adv5_prompts = adv5_df['prompt'].tolist()\n",
    "\n",
    "# Get concept embedding\n",
    "concept_prompt = args.concept_prompt\n",
    "concept_tokens = tokenizer(\n",
    "    concept_prompt,\n",
    "    padding=\"max_length\",\n",
    "    max_length=tokenizer.model_max_length,\n",
    "    truncation=True,\n",
    "    return_tensors=\"pt\"\n",
    ").to(device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    concept_embedding = text_encoder(concept_tokens.input_ids)[0]\n",
    "\n",
    "concept_embedding = concept_embedding.view(1, -1)  # [1, 77*768]\n",
    "concept_embedding_norm = torch.norm(concept_embedding, dim=-1, keepdim=True)\n",
    "concept_embedding = concept_embedding / (concept_embedding_norm + 1e-12)\n",
    "\n",
    "# Try to load paired_data if path exists\n",
    "print(f\"Loading paired data from {args.safe_embedding_path}\")\n",
    "loaded_data = torch.load(args.safe_embedding_path, map_location='cpu')\n",
    "paired_data = loaded_data['paired_data']\n",
    "print(\"Successfully loaded paired data\")\n",
    "paired_data = [(\n",
    "    unsafe_prompt,\n",
    "    min_safe_embedding.to(device),\n",
    "    max_safe_embedding.to(device),\n",
    "    safe_prompt\n",
    ") for unsafe_prompt, min_safe_embedding, max_safe_embedding, safe_prompt in paired_data]\n",
    "del loaded_data\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "sims1 = []\n",
    "sims2 = []\n",
    "sims3 = []\n",
    "sims4 = []\n",
    "sims5 = []\n",
    "sims6 = []\n",
    "sims7 = []\n",
    "sims8 = []\n",
    "sims9 = []\n",
    "sims10 = []\n",
    "for adv1_prompt in adv1_prompts:\n",
    "    unsafe_tokens = tokenizer(\n",
    "        adv1_prompt,\n",
    "        padding=\"max_length\",\n",
    "        max_length=tokenizer.model_max_length,\n",
    "        truncation=True,\n",
    "        return_tensors=\"pt\"\n",
    "    ).to(device)\n",
    "    unsafe_embedding = text_encoder(unsafe_tokens.input_ids)[0]\n",
    "    unsafe_embedding_ours = text_encoder_ours(unsafe_tokens.input_ids)[0]\n",
    "    \n",
    "    unsafe_uncond_sub = unsafe_embedding[0].view(-1)\n",
    "    unsafe_uncond_sub_norm = torch.norm(unsafe_uncond_sub, dim=-1, keepdim=True)\n",
    "    unsafe_uncond_sub_normalized = unsafe_uncond_sub / (unsafe_uncond_sub_norm + 1e-12)\n",
    "    unsafe_uncond_sub_sim = torch.sum(concept_embedding * unsafe_uncond_sub_normalized)\n",
    "    \n",
    "    unsafe_uncond_sub_ours = unsafe_embedding_ours[0].view(-1)\n",
    "    unsafe_uncond_sub_norm_ours = torch.norm(unsafe_uncond_sub_ours, dim=-1, keepdim=True)\n",
    "    unsafe_uncond_sub_normalized_ours = unsafe_uncond_sub_ours / (unsafe_uncond_sub_norm_ours + 1e-12)\n",
    "    unsafe_uncond_sub_sim_ours = torch.sum(concept_embedding * unsafe_uncond_sub_normalized_ours)\n",
    "    sims1.append(unsafe_uncond_sub_sim.item())\n",
    "    sims2.append(unsafe_uncond_sub_sim_ours.item())\n",
    "    \n",
    "for adv2_prompt in adv2_prompts:\n",
    "    unsafe_tokens = tokenizer(\n",
    "        adv2_prompt,\n",
    "        padding=\"max_length\",\n",
    "        max_length=tokenizer.model_max_length,\n",
    "        truncation=True,\n",
    "        return_tensors=\"pt\"\n",
    "    ).to(device)\n",
    "    unsafe_embedding = text_encoder(unsafe_tokens.input_ids)[0]\n",
    "    unsafe_embedding_ours = text_encoder_ours(unsafe_tokens.input_ids)[0]\n",
    "    \n",
    "    unsafe_uncond_sub = unsafe_embedding[0].view(-1)\n",
    "    unsafe_uncond_sub_norm = torch.norm(unsafe_uncond_sub, dim=-1, keepdim=True)\n",
    "    unsafe_uncond_sub_normalized = unsafe_uncond_sub / (unsafe_uncond_sub_norm + 1e-12)\n",
    "    unsafe_uncond_sub_sim = torch.sum(concept_embedding * unsafe_uncond_sub_normalized)\n",
    "    \n",
    "    unsafe_uncond_sub_ours = unsafe_embedding_ours[0].view(-1)\n",
    "    unsafe_uncond_sub_norm_ours = torch.norm(unsafe_uncond_sub_ours, dim=-1, keepdim=True)\n",
    "    unsafe_uncond_sub_normalized_ours = unsafe_uncond_sub_ours / (unsafe_uncond_sub_norm_ours + 1e-12)\n",
    "    unsafe_uncond_sub_sim_ours = torch.sum(concept_embedding * unsafe_uncond_sub_normalized_ours)\n",
    "    sims3.append(unsafe_uncond_sub_sim.item())\n",
    "    sims4.append(unsafe_uncond_sub_sim_ours.item())\n",
    "    \n",
    "for adv3_prompt in adv3_prompts:\n",
    "    unsafe_tokens = tokenizer(\n",
    "        adv3_prompt,\n",
    "        padding=\"max_length\",\n",
    "        max_length=tokenizer.model_max_length,\n",
    "        truncation=True,\n",
    "        return_tensors=\"pt\"\n",
    "    ).to(device)\n",
    "    unsafe_embedding = text_encoder(unsafe_tokens.input_ids)[0]\n",
    "    unsafe_embedding_ours = text_encoder_ours(unsafe_tokens.input_ids)[0]\n",
    "    \n",
    "    unsafe_uncond_sub = unsafe_embedding[0].view(-1)\n",
    "    unsafe_uncond_sub_norm = torch.norm(unsafe_uncond_sub, dim=-1, keepdim=True)\n",
    "    unsafe_uncond_sub_normalized = unsafe_uncond_sub / (unsafe_uncond_sub_norm + 1e-12)\n",
    "    unsafe_uncond_sub_sim = torch.sum(concept_embedding * unsafe_uncond_sub_normalized)\n",
    "    \n",
    "    unsafe_uncond_sub_ours = unsafe_embedding_ours[0].view(-1)\n",
    "    unsafe_uncond_sub_norm_ours = torch.norm(unsafe_uncond_sub_ours, dim=-1, keepdim=True)\n",
    "    unsafe_uncond_sub_normalized_ours = unsafe_uncond_sub_ours / (unsafe_uncond_sub_norm_ours + 1e-12)\n",
    "    unsafe_uncond_sub_sim_ours = torch.sum(concept_embedding * unsafe_uncond_sub_normalized_ours)\n",
    "    sims5.append(unsafe_uncond_sub_sim.item())\n",
    "    sims6.append(unsafe_uncond_sub_sim_ours.item())\n",
    "    \n",
    "for adv4_prompt in adv4_prompts:\n",
    "    unsafe_tokens = tokenizer(\n",
    "        adv4_prompt,\n",
    "        padding=\"max_length\",\n",
    "        max_length=tokenizer.model_max_length,\n",
    "        truncation=True,\n",
    "        return_tensors=\"pt\"\n",
    "    ).to(device)\n",
    "    unsafe_embedding = text_encoder(unsafe_tokens.input_ids)[0]\n",
    "    unsafe_embedding_ours = text_encoder_ours(unsafe_tokens.input_ids)[0]\n",
    "    \n",
    "    unsafe_uncond_sub = unsafe_embedding[0].view(-1)\n",
    "    unsafe_uncond_sub_norm = torch.norm(unsafe_uncond_sub, dim=-1, keepdim=True)\n",
    "    unsafe_uncond_sub_normalized = unsafe_uncond_sub / (unsafe_uncond_sub_norm + 1e-12)\n",
    "    unsafe_uncond_sub_sim = torch.sum(concept_embedding * unsafe_uncond_sub_normalized)\n",
    "    \n",
    "    unsafe_uncond_sub_ours = unsafe_embedding_ours[0].view(-1)\n",
    "    unsafe_uncond_sub_norm_ours = torch.norm(unsafe_uncond_sub_ours, dim=-1, keepdim=True)\n",
    "    unsafe_uncond_sub_normalized_ours = unsafe_uncond_sub_ours / (unsafe_uncond_sub_norm_ours + 1e-12)\n",
    "    unsafe_uncond_sub_sim_ours = torch.sum(concept_embedding * unsafe_uncond_sub_normalized_ours)\n",
    "    sims7.append(unsafe_uncond_sub_sim.item())\n",
    "    sims8.append(unsafe_uncond_sub_sim_ours.item())\n",
    "    \n",
    "for adv5_prompt in adv5_prompts:\n",
    "    unsafe_tokens = tokenizer(\n",
    "        adv5_prompt,\n",
    "        padding=\"max_length\",\n",
    "        max_length=tokenizer.model_max_length,\n",
    "        truncation=True,\n",
    "        return_tensors=\"pt\"\n",
    "    ).to(device)\n",
    "    unsafe_embedding = text_encoder(unsafe_tokens.input_ids)[0]\n",
    "    unsafe_embedding_ours = text_encoder_ours(unsafe_tokens.input_ids)[0]\n",
    "    \n",
    "    unsafe_uncond_sub = unsafe_embedding[0].view(-1)\n",
    "    unsafe_uncond_sub_norm = torch.norm(unsafe_uncond_sub, dim=-1, keepdim=True)\n",
    "    unsafe_uncond_sub_normalized = unsafe_uncond_sub / (unsafe_uncond_sub_norm + 1e-12)\n",
    "    unsafe_uncond_sub_sim = torch.sum(concept_embedding * unsafe_uncond_sub_normalized)\n",
    "    \n",
    "    unsafe_uncond_sub_ours = unsafe_embedding_ours[0].view(-1)\n",
    "    unsafe_uncond_sub_norm_ours = torch.norm(unsafe_uncond_sub_ours, dim=-1, keepdim=True)\n",
    "    unsafe_uncond_sub_normalized_ours = unsafe_uncond_sub_ours / (unsafe_uncond_sub_norm_ours + 1e-12)\n",
    "    unsafe_uncond_sub_sim_ours = torch.sum(concept_embedding * unsafe_uncond_sub_normalized_ours)\n",
    "    sims9.append(unsafe_uncond_sub_sim.item())\n",
    "    sims10.append(unsafe_uncond_sub_sim_ours.item())\n",
    "    \n",
    "\n",
    "bins = 100\n",
    "range = (-0.81, 0.81)\n",
    "\n",
    "plt.style.use('seaborn-v0_8-paper')\n",
    "plt.rcParams.update({\n",
    "    'font.family': 'serif',\n",
    "    'font.size': 9,\n",
    "    'axes.linewidth': 0.8,\n",
    "    'axes.labelsize': 9,\n",
    "    'xtick.labelsize': 9,\n",
    "    'ytick.labelsize': 9,\n",
    "})\n",
    "\n",
    "plt.figure(figsize=(7, 3), dpi=150)\n",
    "\n",
    "ori_labels = ['MMA (Org)', 'I2P (Org)', 'P4D (Org)', 'Ring-A-Bell (Org)', 'Sneaky (Org)']\n",
    "ori_colors = ['#006400', '#0000FF', '#FF8C00', '#FF0000', '#800080']\n",
    "\n",
    "des_labels = ['MMA (DES)', 'I2P (DES)', 'P4D (DES)', 'Ring-A-Bell (DES)', 'Sneaky (DES)']\n",
    "des_colors = ['#90EE90', '#ADD8E6', '#FFB366', '#FF9999', '#DDA0DD']\n",
    "\n",
    "for i, (data, color, label) in enumerate(zip([sims3, sims5, sims7, sims9, sims1], ori_colors, ori_labels)):\n",
    "    plt.hist(data, bins=bins, range=range, alpha=0.7, color=color,\n",
    "            label=label, edgecolor='white', linewidth=0.5)\n",
    "    \n",
    "    hist_counts, bin_edges = np.histogram(data, bins=bins, range=range)\n",
    "    bin_width = bin_edges[1] - bin_edges[0]\n",
    "    std = np.std(data)\n",
    "    bw_method = 0.45 * (0.05/std)\n",
    "    density = gaussian_kde(data, bw_method=bw_method)\n",
    "    xs = np.linspace(range[0], range[1], 200)\n",
    "    ys = density(xs)\n",
    "    ys = ys * bin_width * len(data)\n",
    "    plt.plot(xs, ys, color='black', linewidth=1, alpha=0.4)\n",
    "\n",
    "for i, (data, color, label) in enumerate(zip([sims4, sims6, sims8, sims10, sims2], des_colors, des_labels)):\n",
    "    plt.hist(data, bins=bins, range=range, alpha=0.7, color=color,\n",
    "            label=label, edgecolor='white', linewidth=0.5)\n",
    "    \n",
    "    hist_counts, bin_edges = np.histogram(data, bins=bins, range=range)\n",
    "    bin_width = bin_edges[1] - bin_edges[0]\n",
    "    std = np.std(data)\n",
    "    bw_method = 0.45 * (0.05/std)\n",
    "    density = gaussian_kde(data, bw_method=bw_method)\n",
    "    xs = np.linspace(range[0], range[1], 200)\n",
    "    ys = density(xs)\n",
    "    ys = ys * bin_width * len(data)\n",
    "    plt.plot(xs, ys, color='black', linewidth=1, alpha=0.4)\n",
    "\n",
    "plt.xlabel('Cosine Similarity with $e_c$', fontsize=12, labelpad=0)\n",
    "plt.ylabel('Frequency', fontsize=12, labelpad=3)\n",
    "plt.grid(True, alpha=0.2, linestyle='--')\n",
    "\n",
    "plt.legend(frameon=True, fancybox=True, shadow=True, \n",
    "          fontsize=8,\n",
    "          bbox_to_anchor=(0.99, 0.99), \n",
    "          loc='upper right',\n",
    "          borderaxespad=0,\n",
    "          handlelength=1.5,\n",
    "          handletextpad=0.5)\n",
    "\n",
    "plt.text(0.10, 120, 'Before DES', \n",
    "         bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=3),\n",
    "         fontsize=9)\n",
    "plt.text(-0.63, 80, 'After DES', \n",
    "         bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=3),\n",
    "         fontsize=9)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Figure 6b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import CLIPTokenizer, CLIPTextModel\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "class Args:\n",
    "    def __init__(self):\n",
    "        self.model_path = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\n",
    "        self.device = \"cuda:0\"\n",
    "        self.unsafe_csv_path = \"datasets/unsafe_prompts_copro_sexual.csv\"\n",
    "        self.safe_csv_path = \"datasets/safe_prompts_copro_sexual.csv\"\n",
    "        self.adv1_csv_path = \"prompts/sneaky_prompts.csv\"\n",
    "        self.adv2_csv_path = \"prompts/mma_prompts.csv\"\n",
    "        self.adv3_csv_path = \"prompts/i2p_sexual_prompts.csv\"\n",
    "        self.adv4_csv_path = \"prompts/p4d_prompts.csv\"\n",
    "        self.adv5_csv_path = \"prompts/ringabell_prompts.csv\"\n",
    "        self.concept_prompt = \"nudity\"\n",
    "        self.concept_guidance_scale = 200.0\n",
    "        self.safe_embedding_path = \"checkpoints/safe_embeddings.pth\"\n",
    "        self.text_encoder_path = \"checkpoints/des.pt\"\n",
    "\n",
    "class Args:\n",
    "    def __init__(self):\n",
    "        self.model_path = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\n",
    "        self.device = \"cuda:0\"\n",
    "        self.text_encoder_path = \"checkpoints/des.pt\"\n",
    "        self.datasets = {\n",
    "            'CoPro_Safe_seen': 'datasets/safe_prompts_copro_sexual.csv',\n",
    "            'CoPro_Unsafe_seen': 'datasets/unsafe_prompts_copro_sexual.csv',\n",
    "            'COCO': 'prompts/coco_prompts.csv',\n",
    "            'Sneaky': 'prompts/sneaky_prompts.csv',\n",
    "            'MMA': 'prompts/mma_prompts.csv',\n",
    "            'I2P': 'prompts/i2p_sexual_prompts.csv',\n",
    "            'P4D': 'prompts/p4d_prompts.csv',\n",
    "            'RingABell': 'prompts/ringabell_prompts.csv'\n",
    "        }\n",
    "\n",
    "def load_models(args):\n",
    "    tokenizer = CLIPTokenizer.from_pretrained(args.model_path, subfolder=\"tokenizer\")\n",
    "    text_encoder_ori = CLIPTextModel.from_pretrained(\n",
    "        args.model_path,\n",
    "        subfolder=\"text_encoder\"\n",
    "    ).to(args.device)\n",
    "    text_encoder_ours = CLIPTextModel.from_pretrained(\n",
    "        args.model_path,\n",
    "        subfolder=\"text_encoder\"\n",
    "    ).to(args.device)\n",
    "    \n",
    "    # Load Checkpoint of DES\n",
    "    checkpoint = torch.load(args.text_encoder_path, map_location=args.device)\n",
    "    text_encoder_ours.load_state_dict(checkpoint['model_state_dict'])\n",
    "    \n",
    "    text_encoder_ori.eval()\n",
    "    text_encoder_ours.eval()\n",
    "    \n",
    "    return tokenizer, text_encoder_ori, text_encoder_ours\n",
    "\n",
    "def load_prompt_data(args):\n",
    "    return {name: pd.read_csv(path).head(1000) for name, path in args.datasets.items()}\n",
    "\n",
    "def get_embeddings(prompt, tokenizer, encoder, device):\n",
    "    tokens = tokenizer(\n",
    "        prompt,\n",
    "        padding=\"max_length\",\n",
    "        max_length=tokenizer.model_max_length,\n",
    "        truncation=True,\n",
    "        return_tensors=\"pt\"\n",
    "    ).to(device)\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        outputs = encoder(tokens.input_ids)\n",
    "        embedding = outputs[0].mean(dim=1)\n",
    "    return embedding.cpu().numpy().flatten()\n",
    "\n",
    "def process_all_prompts(datasets, tokenizer, text_encoder_ori, text_encoder_ours, device):\n",
    "    embeddings = []\n",
    "    labels = []\n",
    "    \n",
    "    for dataset_name, df in datasets.items():\n",
    "        # Original encoder\n",
    "        for prompt in df['prompt'].tolist():\n",
    "            embeddings.append(get_embeddings(prompt, tokenizer, text_encoder_ori, device))\n",
    "            labels.append(f\"{dataset_name} (Org)\")\n",
    "        \n",
    "        # DES encoder\n",
    "        for prompt in df['prompt'].tolist():\n",
    "            embeddings.append(get_embeddings(prompt, tokenizer, text_encoder_ours, device))\n",
    "            labels.append(f\"{dataset_name} (DES)\")\n",
    "    \n",
    "    return np.array(embeddings), labels\n",
    "\n",
    "def plot_all_embeddings(embeddings, labels, figsize=(9, 5), exclude_datasets=None, exclude_from_reduction=None):\n",
    "    \"\"\"Project embeddings to 2-D with t-SNE and scatter COCO against one unsafe dataset.\n",
    "\n",
    "    Args:\n",
    "        embeddings: 2-D numpy array with one row per entry in ``labels``.\n",
    "        labels: Sequence of strings of the form '<dataset> (Org)' or\n",
    "            '<dataset> (DES)'.\n",
    "        figsize: Figure size forwarded to ``plt.subplots``.\n",
    "        exclude_datasets: Unsafe dataset names to skip. The first name among\n",
    "            MMA, I2P, Sneaky, P4D, RingABell that is NOT listed is the one drawn;\n",
    "            if all are listed only the COCO points are drawn.\n",
    "        exclude_from_reduction: Label prefixes whose rows are left out of the\n",
    "            t-SNE fit (their 2-D coordinates stay at the origin).\n",
    "\n",
    "    Returns:\n",
    "        ``(fig, ax)`` of the finished figure.\n",
    "\n",
    "    Note:\n",
    "        Updates ``plt.rcParams`` and the active style globally.\n",
    "    \"\"\"\n",
    "    plt.style.use('seaborn-v0_8-paper')\n",
    "    plt.rcParams.update({\n",
    "        'font.family': 'serif',\n",
    "        'font.size': 14,\n",
    "        'axes.linewidth': 0.8,\n",
    "        'axes.labelsize': 14,\n",
    "        'xtick.labelsize': 14,\n",
    "        'ytick.labelsize': 14,\n",
    "        'legend.fontsize': 14,\n",
    "        'legend.title_fontsize': 14,\n",
    "    })\n",
    "    fig, ax = plt.subplots(figsize=figsize, dpi=300)\n",
    "\n",
    "    reducer = TSNE(n_components=2, random_state=42, perplexity=50)\n",
    "\n",
    "    exclude_datasets = exclude_datasets or []\n",
    "    exclude_from_reduction = exclude_from_reduction or []\n",
    "\n",
    "    # Rows whose label starts with an excluded prefix are kept out of the t-SNE fit.\n",
    "    reduction_mask = np.ones(len(labels), dtype=bool)\n",
    "    for dataset in exclude_from_reduction:\n",
    "        reduction_mask &= ~np.array([l.startswith(dataset) for l in labels], dtype=bool)\n",
    "\n",
    "    reduced_embeddings = np.zeros((len(embeddings), 2))\n",
    "    reduced_embeddings[reduction_mask] = reducer.fit_transform(embeddings[reduction_mask])\n",
    "\n",
    "    colors = {\n",
    "        'Safe (Org)': '#4682B4',\n",
    "        'Safe (DES)': 'black',\n",
    "        'Unsafe (Org)': '#DC143C',\n",
    "        'Unsafe (DES)': '#F0E68C'\n",
    "    }\n",
    "\n",
    "    def _scatter_group(prefix, suffix, legend_label, **style):\n",
    "        \"\"\"Scatter the rows whose label starts with ``prefix`` and ends with ``suffix``.\"\"\"\n",
    "        mask = np.array(\n",
    "            [l.startswith(prefix) and l.endswith(suffix) for l in labels],\n",
    "            dtype=bool,\n",
    "        )\n",
    "        # Skip empty groups so they do not create a legend entry.\n",
    "        if mask.any():\n",
    "            ax.scatter(\n",
    "                reduced_embeddings[mask, 0],\n",
    "                reduced_embeddings[mask, 1],\n",
    "                label=legend_label,\n",
    "                **style\n",
    "            )\n",
    "\n",
    "    # COCO is shown under the 'Safe' legend name, before and after DES.\n",
    "    _scatter_group(\n",
    "        'COCO', '(Org)', 'Safe (Org)',\n",
    "        c=colors['Safe (Org)'], alpha=0.6, marker='o', s=35,\n",
    "        edgecolor='#483D8B', linewidths=1.2,\n",
    "    )\n",
    "    _scatter_group(\n",
    "        'COCO', '(DES)', 'Safe (DES)',\n",
    "        c=colors['Safe (DES)'], alpha=0.7, marker='x', s=15,\n",
    "        linewidths=1.2,\n",
    "    )\n",
    "\n",
    "    # Only one unsafe dataset is drawn: the first one that is not excluded.\n",
    "    unsafe_datasets = ['MMA', 'I2P', 'Sneaky', 'P4D', 'RingABell']\n",
    "    plot_dataset = next(\n",
    "        (dataset for dataset in unsafe_datasets if dataset not in exclude_datasets),\n",
    "        None,\n",
    "    )\n",
    "\n",
    "    if plot_dataset:\n",
    "        _scatter_group(\n",
    "            plot_dataset, '(Org)', f'{plot_dataset} (Org)',\n",
    "            c=colors['Unsafe (Org)'], alpha=0.6, marker='^', s=30,\n",
    "            edgecolor='#8B0000', linewidths=1.2,\n",
    "        )\n",
    "        _scatter_group(\n",
    "            plot_dataset, '(DES)', f'{plot_dataset} (DES)',\n",
    "            c=colors['Unsafe (DES)'], alpha=0.6, marker='*', s=30,\n",
    "            edgecolor='#DAA520', linewidths=1.2,\n",
    "        )\n",
    "\n",
    "    # Use separate names here so the ``labels`` argument is not overwritten.\n",
    "    legend_handles, legend_labels = ax.get_legend_handles_labels()\n",
    "    handle_by_label = {}\n",
    "    for handle, text in zip(legend_handles, legend_labels):\n",
    "        handle_by_label.setdefault(text, handle)\n",
    "\n",
    "    desired_order = ['Safe (Org)', 'Safe (DES)']\n",
    "    if plot_dataset:\n",
    "        desired_order += [f'{plot_dataset} (Org)', f'{plot_dataset} (DES)']\n",
    "\n",
    "    # Keep only entries that were actually drawn; a missing group (empty mask or\n",
    "    # no unsafe dataset selected) must not raise while building the legend.\n",
    "    ordered_pairs = [\n",
    "        (handle_by_label[text], text)\n",
    "        for text in desired_order\n",
    "        if text in handle_by_label\n",
    "    ]\n",
    "\n",
    "    if ordered_pairs:\n",
    "        ordered_handles, ordered_labels = zip(*ordered_pairs)\n",
    "        ax.legend(ordered_handles, ordered_labels,\n",
    "                 bbox_to_anchor=(0.99, 0.99),\n",
    "                 loc='upper right',\n",
    "                 borderaxespad=0,\n",
    "                 frameon=True,\n",
    "                 fancybox=True,\n",
    "                 shadow=True,\n",
    "                 framealpha=1.0,\n",
    "                 edgecolor='lightgray',\n",
    "                 fontsize=14,\n",
    "                 ncol=1,\n",
    "                 handletextpad=0.2,\n",
    "                 borderpad=0.2,\n",
    "                 )\n",
    "\n",
    "    ax.grid(True, linestyle='--', alpha=0.2, color='gray', linewidth=0.5)\n",
    "    ax.set_facecolor('#ffffff')\n",
    "    fig.patch.set_facecolor('#ffffff')\n",
    "    ax.set_xlabel('')\n",
    "    ax.set_ylabel('')\n",
    "\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_linewidth(0.8)\n",
    "        spine.set_color('#333333')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    return fig, ax\n",
    "\n",
    "def main():\n",
    "    \"\"\"Embed all prompt sets with both encoders, standardize them, and show Figure 6b.\"\"\"\n",
    "    args = Args()\n",
    "    tokenizer, text_encoder_ori, text_encoder_ours = load_models(args)\n",
    "    datasets = load_prompt_data(args)\n",
    "\n",
    "    embeddings, labels = process_all_prompts(\n",
    "        datasets,\n",
    "        tokenizer,\n",
    "        text_encoder_ori,\n",
    "        text_encoder_ours,\n",
    "        args.device,\n",
    "    )\n",
    "\n",
    "    # Standardize the embeddings (StandardScaler defaults) before projection.\n",
    "    embeddings_normalized = StandardScaler().fit_transform(embeddings)\n",
    "\n",
    "    # These two splits are excluded from both the plot and the t-SNE fit in every figure.\n",
    "    copro_seen = ['CoPro_Safe_seen', 'CoPro_Unsafe_seen']\n",
    "\n",
    "    ### Figure 6b\n",
    "    fig, ax = plot_all_embeddings(\n",
    "        embeddings_normalized,\n",
    "        labels,\n",
    "        figsize=(9, 4),\n",
    "        exclude_datasets=copro_seen + ['Sneaky', 'P4D', 'I2P', 'RingABell'],\n",
    "        exclude_from_reduction=copro_seen,\n",
    "    )\n",
    "    plt.show()\n",
    "\n",
    "    ### else: Figure 19 (default figsize; one panel each for Sneaky, I2P, P4D, RingABell)\n",
    "    # for hidden in (\n",
    "    #     ['P4D', 'MMA', 'I2P', 'RingABell'],\n",
    "    #     ['Sneaky', 'MMA', 'P4D', 'RingABell'],\n",
    "    #     ['Sneaky', 'MMA', 'I2P', 'RingABell'],\n",
    "    #     ['Sneaky', 'MMA', 'I2P', 'P4D'],\n",
    "    # ):\n",
    "    #     fig, ax = plot_all_embeddings(\n",
    "    #         embeddings_normalized,\n",
    "    #         labels,\n",
    "    #         exclude_datasets=copro_seen + hidden,\n",
    "    #         exclude_from_reduction=copro_seen,\n",
    "    #     )\n",
    "    #     plt.show()\n",
    "\n",
    "# Entry point: call main() when this code runs as the top-level module\n",
    "# (which includes executing this cell in a notebook kernel).\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sd3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
