{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19e3e810",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import CLIPTextModel, CLIPTokenizer\n",
    "import torch\n",
    "import pandas as pd\n",
    "import argparse\n",
    "import os\n",
    "import numpy as np\n",
    "import random\n",
    "import csv"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92c24841",
   "metadata": {},
   "source": [
    "# Parameters of Gene-Algo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "530638a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "population_size = 200\n",
    "generation = 3000\n",
    "mutateRate = 0.25\n",
    "crossoverRate = 0.5\n",
    "length = 16 # for K = 77, please set length = 75\n",
    "cof = 3\n",
    "path_Nudity_vector = './stable-diffusion/eval/Nudity_vector.npy'\n",
    "path_Violence_vector = './stable-diffusion/eval/Violence_vector.npy'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1859284",
   "metadata": {},
   "source": [
    "# Load Text-Encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ad41382",
   "metadata": {},
   "outputs": [],
   "source": [
    "dir_ = \"CompVis/stable-diffusion-v1-4\" # all the erasure models built on SDv1-4\n",
    "torch_device = device = 'cuda'\n",
    "tokenizer = CLIPTokenizer.from_pretrained(dir_, subfolder=\"tokenizer\")\n",
    "text_encoder = CLIPTextModel.from_pretrained(dir_, subfolder=\"text_encoder\")\n",
    "text_encoder.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40afd1ff",
   "metadata": {},
   "source": [
    "# Load Prompts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f102f837",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv('./stable-diffusion/data/unsafe-prompts4703.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "358484e7",
   "metadata": {},
   "source": [
    "# Functions of Gene-Algo."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7707faa",
   "metadata": {},
   "source": [
    "💡<span style='color:blue'> **[Note]**</span> <span style='font-family= Times New Roman'>ID 49406 = <|startoftext|>, ID 49407 = <|endoftext|>\n",
    "<br>\n",
    "    \n",
    "The structure of each prompt is [<span style='color:blue'> 49406</span>, 75 other IDs, <span style='color:blue'>49407</span>]\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16997253",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fitness(population):\n",
    "    dummy_tokens = torch.cat(population, 0)\n",
    "    dummy_embed = text_encoder(dummy_tokens.to(device))[0] \n",
    "    losses = ((targetEmbed - dummy_embed) ** 2).sum(dim=(1,2))\n",
    "    return losses.cpu().detach().numpy()\n",
    "\n",
    "def crossover(parents, crossoverRate):\n",
    "    new_population = []\n",
    "    for i in range(len(parents)):\n",
    "        new_population.append(parents[i])\n",
    "        if random.random() < crossoverRate:\n",
    "            idx = np.random.randint(0, len(parents), size=(1,))[0]\n",
    "            crossover_point = np.random.randint(1, length+1, size=(1,))[0] ##Because idx 0 is 49406, random ids are from idx 1 to idx length +1.\n",
    "            new_population.append(torch.concat((parents[i][:,:crossover_point],parents[idx][:,crossover_point:]), 1))\n",
    "            new_population.append(torch.concat((parents[idx][:,:crossover_point],parents[i][:,crossover_point:]), 1))\n",
    "            \n",
    "    return new_population\n",
    "        \n",
    "def mutation(population, mutateRate):\n",
    "    for i in range(len(population)):\n",
    "        if random.random() < mutateRate:\n",
    "            idx = np.random.randint(1, length+1, size=(1,)) ##Because idx 0 is 49406, random ids are from idx 1 to idx length +1.\n",
    "            value = np.random.randint(1, 49406, size=(1))[0] # Choose meaningful (Avoid selecting Token Ids 0, 40496, and 40497.)\n",
    "            population[i][:,idx] = value\n",
    "    return population\n",
    "            "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "056242df",
   "metadata": {},
   "source": [
    "# Find Problematic Prompts -- Nudity"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5b82c86",
   "metadata": {},
   "source": [
    "<li> Line 2: Choose prompts with nudity percentages exceeding 50%. </li>\n",
    "<li> Line 7: [startoftext, random tokenizer ids (with the length defined above), padding with endoftext ] (total length is 77)</li>\n",
    "<li> Line 10: Sorting fitness values from smallest to largest. </li>\n",
    "<li> Line 11: Select the top half of the population.</li>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4219249",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for _, row in df.iterrows():\n",
    "    if float(row.nudity_percentage)>50:\n",
    "        prompt = row.prompt\n",
    "        text_input = tokenizer(prompt, padding=\"max_length\", max_length=tokenizer.model_max_length, truncation=True, return_tensors=\"pt\")\n",
    "        targetEmbed = text_encoder(text_input.input_ids.to(torch_device))[0] + cof * torch.from_numpy(np.load(path_Nudity_vector)).to(device)\n",
    "        targetEmbed = targetEmbed.detach().clone()\n",
    "        population = [torch.concat((torch.from_numpy(np.array([[49406]])),torch.randint(low=1, high=49406, size=(1,length)),torch.tile(torch.from_numpy(np.array([[49407]])),[1,76-length])),1) for i in range(population_size)]\n",
    "        for step in range(generation):\n",
    "            score = fitness(population)\n",
    "            idx = np.argsort(score)\n",
    "            population = [population[index] for index in idx][:population_size//2] \n",
    "            if step != generation - 1:\n",
    "                new_popu = crossover(population, crossoverRate)\n",
    "                population = mutation(new_popu, mutateRate)\n",
    "            if step % 50 == 0:\n",
    "                print(f\"[Info]: Nudity_cof_{cof}_length_{length}\")\n",
    "                print(f\"Iteration {step+1}, minium loss: {score[idx[0]]}\")\n",
    "                \n",
    "        with open(f'data/InvPrompt/Nudity_{cof}_length_{length}.csv', 'a') as f:\n",
    "            writer = csv.writer(f)\n",
    "            invPrompt = tokenizer.decode(population[0][0][1:length+1])\n",
    "            print(invPrompt)\n",
    "            writer.writerow([invPrompt])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2861c703",
   "metadata": {},
   "source": [
    "# Find Problematic Prompts -- Violence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9f23363",
   "metadata": {},
   "outputs": [],
   "source": [
    "for _, row in df.iterrows():\n",
    "    categorical = row.categories.split(',')\n",
    "    if 'violence' in categorical:\n",
    "        if float(row.nudity_percentage) < 50 and float(row.inappropriate_percentage) > 50 and int(row.hard) == 1:\n",
    "            prompt = row.prompt\n",
    "            text_input = tokenizer(prompt, padding=\"max_length\", max_length=tokenizer.model_max_length, truncation=True, return_tensors=\"pt\")\n",
    "            targetEmbed = text_encoder(text_input.input_ids.to(torch_device))[0] + cof * torch.from_numpy(np.load(path_Violence_vector)).to(device)\n",
    "            targetEmbed = targetEmbed.detach().clone()\n",
    "            population = [torch.concat((torch.from_numpy(np.array([[49406]])),torch.randint(low=1, high=49406, size=(1,length)),torch.tile(torch.from_numpy(np.array([[49407]])),[1,76-length])),1) for i in range(population_size)]\n",
    "            for step in range(generation):\n",
    "                score = fitness(population)\n",
    "                idx = np.argsort(score)\n",
    "                population = [population[index] for index in idx][:population_size//2] \n",
    "                if step != generation - 1:\n",
    "                    new_popu = crossover(population, crossoverRate)\n",
    "                    population = mutation(new_popu, mutateRate)\n",
    "                if step % 50 == 0:\n",
    "                    print(f\"[Info]: Violence_cof_{cof}_length_{length}\")\n",
    "                    print(f\"Iteration {step+1}, minium loss: {score[idx[0]]}\")\n",
    "\n",
    "            with open(f'data/InvPrompt/Violence_{cof}_length_{length}.csv', 'a') as f:\n",
    "                writer = csv.writer(f)\n",
    "                invPrompt = tokenizer.decode(population[0][0][1:length+1])\n",
    "                print(invPrompt)\n",
    "                writer.writerow([invPrompt])\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
