{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4a4ae8a4-7603-475c-a3b1-035de1396021",
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusers import (\n",
    "    StableDiffusionPipeline,\n",
    "    UNet2DConditionModel,\n",
    "    DPMSolverMultistepScheduler,\n",
    ")\n",
    "\n",
    "from arc2face import CLIPTextModelWrapper, project_face_embs\n",
    "from IPython.display import Image, display, HTML\n",
    "from PIL import Image\n",
    "from tqdm import tqdm\n",
    "\n",
    "import matplotlib.image as mpimg\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "import random\n",
    "import torch\n",
    "\n",
    "sns.set_theme()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3c8c7744-38bb-47ae-aa9d-33bcdca46cb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_embeddings():\n",
    "    return np.load(f'RFW_arcface.npy').squeeze()\n",
    "def get_names():\n",
    "    return np.load('RFW_names.npy')\n",
    "def get_ids():\n",
    "    return np.load('RFW_ids.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9cb0002-8884-4225-be73-a6d446691bdf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arc2Face is built upon SD1.5\n",
    "# The repo below can be used instead of the now deprecated 'runwayml/stable-diffusion-v1-5'\n",
    "base_model = 'stable-diffusion-v1-5/stable-diffusion-v1-5'\n",
    "\n",
    "encoder = CLIPTextModelWrapper.from_pretrained(\n",
    "    'models', subfolder=\"encoder\", torch_dtype=torch.float16\n",
    ")\n",
    "\n",
    "unet = UNet2DConditionModel.from_pretrained(\n",
    "    'models', subfolder=\"arc2face\", torch_dtype=torch.float16\n",
    ")\n",
    "\n",
    "pipeline = StableDiffusionPipeline.from_pretrained(\n",
    "        base_model,\n",
    "        text_encoder=encoder,\n",
    "        unet=unet,\n",
    "        torch_dtype=torch.float16,\n",
    "        safety_checker=None\n",
    "    )\n",
    "\n",
    "pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)\n",
    "pipeline = pipeline.to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "694d43ac-4892-4254-911e-35a8816196ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "RFW_embeddings = get_embeddings()\n",
    "RFW_names = get_names()\n",
    "RFW_ids = get_ids()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6a41c030-c351-401a-839c-d89cd1fb396c",
   "metadata": {},
   "outputs": [],
   "source": [
    "celeba_embeddings = np.load('CelebA_arcface.npy').squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a864714c-cc54-4f5f-aa0b-dff579ae054b",
   "metadata": {},
   "outputs": [],
   "source": [
    "celeba_names = np.load('CelebA_names.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "189c92a4-8742-44ac-ac8a-01703979c557",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_direction(id_labels, embeddings):\n",
    "    # Map labels to integers and get counts in one pass.\n",
    "    unique_labels, integer_labels, counts = np.unique(id_labels, return_inverse=True, return_counts=True)\n",
    "    # Compute per-sample weights: each sample gets weight = 1 / (number of times its label occurs).\n",
    "    weights = 1 / counts[integer_labels]\n",
    "    # Compute the weighted sum using @\n",
    "    v = weights.astype(np.float32) @ embeddings\n",
    "    # The sum of weights is equal to the number of unique labels, so we normalize by that\n",
    "    v /= unique_labels.size\n",
    "    # Normalize the vector v\n",
    "    return v / np.linalg.norm(v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "86199cf8-cbd8-41c9-b9a3-bfc2a2819eae",
   "metadata": {},
   "outputs": [],
   "source": [
    "def slerp(u, v, t):\n",
    "    dot = u @ v\n",
    "    theta = np.arccos(dot)\n",
    "    if theta < 1e-6:\n",
    "        return normalize((1 - t) * u + t * v)\n",
    "    return (np.sin((1 - t) * theta) * u + np.sin(t * theta) * v) / np.sin(theta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ba88e0d-9902-4523-940e-9c8ffb088693",
   "metadata": {},
   "outputs": [],
   "source": [
    "group_A = [26908, 26909, 26907, 26910, 20509, 20510, 20511, 22743, 22739, 22740, 22742, 22741, 27135, 27133, 23459, 25958, 25957, 26618, 26619, 26616, 26617, 29425, \n",
    "           20517, 20518, 20519, 20520, 20521, 27992, 27991, 27990, 25765, 22002, 22000, 22001, 20922, 28561, 28559, 28562, 28563, 21348, 21347, 21349, 29214, 29213, \n",
    "           29215, 27222, 33210, 33209, 33211, 23408, 23410, 23409, 23407, 21999, 21208, 21209, 21210, 21211, 21212, 21213, 21214, 11631, 24685, 24686, 24684, 26161, \n",
    "           26163, 26162, 24017, 24018, 22521, 22525, 22522, 22523, 28501, 28503, 28500, 28499, 28502, 23230, 23229, 23227, 21614, 21615, 21616, 21617, 21618, 21619, \n",
    "           27292, 27291, 27290, 24002, 24003, 26938, 27964, 27966, 27965, 23796, 23795, 23797, 26936, 26937, 25600, 29931, 29933, 29932, 29889, 29886, 29888, 29887, \n",
    "           25615, 25613, 25616, 25614, 11632, 20891, 25183, 26485, 26487, 26488, 26486, 28598, 25332, 27559, 27115, 24839, 24840, 24837, 24838, 24841]\n",
    "group_B = [[34890, 34891, 34892, 34893, 36533, 36531, 36534, 36532, 36733, 36731, 36732, 36169, 36168, 36170, 36356, 34440, 36862, 36861, 34441, 30448, 30446, 30449, \n",
    "            40548, 40547, 40549, 34705, 36309, 36310, 31395, 40546, 31890, 38622, 31893, 31396, 31891, 31892, 37375, 34571, 34573, 34056, 34058, 34059, 34057, 38281, \n",
    "            38279, 38280, 35974, 35815, 35813, 35812, 35814, 38625, 35973, 34060, 38624, 39952, 39953, 39956, 35566, 35567, 34706, 36066, 36064, 39955, 34299, 33862, \n",
    "            33623, 33627, 33625, 35076, 35077, 38492, 38494, 33624, 34449, 38482, 33030, 33029, 33863, 36795, 34707, 35972, 36797, 39954, 36796, 36273, 36274, 33864, \n",
    "            36065, 34439, 39207, 39206, 35381, 35568, 38479, 35565, 38186, 34448, 34450, 34451, 31348, 32844, 32843, 32076, 36920, 33535, 33532, 33534, 37553, 32077, \n",
    "            40193, 34243, 34244, 33533, 31981, 40195, 34865, 34864, 34246, 33116, 34245, 37554, 38800, 30882, 37434, 37432, 38493, 33117, 37396, 37430, 30880, 40194, \n",
    "            40151, 34863, 31266, 37874, 37876, 31985, 31817, 35078, 31979, 31980, 31983]]\n",
    "group_C = [[37762, 37763, 37764, 37766, 38368, 38370, 38367, 38369, 36675, 36676, 36673, 36674, 30745, 30747, 30746, 39575, 39576, 39577, 37465, 33406, 37466, 37467, \n",
    "            39712, 39714, 35351, 39713, 34040, 34041, 37837, 37840, 37839, 35349, 35156, 35154, 39728, 40123, 40125, 35155, 36324, 36326, 36325, 36327, 33339, 33341,\n",
    "            33340, 32074, 32072, 35222, 35139, 35143, 35140, 34002, 35557, 35556, 35346, 35347, 37838, 35141, 30727, 38783, 38782, 35348, 30726, 38384, 31958, 35555, \n",
    "            32073, 32075, 35350, 39502, 39503, 32047, 32044, 32046, 32045, 38044, 38043, 33493, 33494, 32186, 32184, 32185, 39124, 39126, 39123, 39467, 39471, 31195, \n",
    "            39501, 31427, 31194, 39468, 38383, 30345, 30346, 35704]]\n",
    "group_D = [[10982, 16479, 16478, 19744, 19741, 19743, 19742, 19740, 10960, 10962, 12014, 12013, 14831, 15628, 15625, 15627, 15626, 14830, 13939, 13940, 15363, 15362, \n",
    "            15361, 11305, 11302, 13941, 13938, 17706, 11304, 13509, 13507, 12190, 15364, 12456, 12454, 13942, 12455, 17704, 16095, 16093, 13535, 13506, 12457, 12035, \n",
    "            13704, 12037, 12191, 14833, 17945, 13565, 17707, 19736, 14832, 19734, 19735, 19477, 19281, 19280, 16444, 16974, 14267, 16443, 18446, 18447, 19478, 11258, \n",
    "            11260, 11261, 18201, 16094, 13906, 11341, 11291, 11290, 11340, 11259, 12189, 14266, 12192, 17944, 14728, 13566, 16442, 16408, 16410, 14992, 16689, 11339, \n",
    "            17943, 19476, 13563, 19479, 11306, 10970, 11303, 13567, 11292, 16409, 17705, 13564]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "967ceb6a-a657-4389-a8d2-07726bbcaea2",
   "metadata": {},
   "outputs": [],
   "source": [
    "v = get_direction(RFW_ids[Group_D], RFW_embeddings[Group_D])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ef5f53e-3b1a-465c-a2dc-4e57c666a82b",
   "metadata": {},
   "source": [
    "## Addition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "2b5d0a9b-f9cb-45a1-8c7a-3d235e7af0c2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div style=\"display:flex;align-items:center;justify-content:center;\"><div style=\"margin: 2px; text-align: center;\"><img src=\"../../ICCV 2025/CelebA/img_align_celeba/161750.jpg\" style=\"height: 300px;\"></div><div style=\"margin: 2px; text-align: center;\"><img src=\"../../ICCV 2025/CelebA/img_align_celeba/182508.jpg\" style=\"height: 300px;\"></div><div style=\"margin: 2px; text-align: center;\"><img src=\"../../ICCV 2025/CelebA/img_align_celeba/098626.jpg\" style=\"height: 300px;\"></div><div style=\"margin: 2px; text-align: center;\"><img src=\"../../ICCV 2025/CelebA/img_align_celeba/090562.jpg\" style=\"height: 300px;\"></div><div style=\"margin: 2px; text-align: center;\"><img src=\"../../ICCV 2025/CelebA/img_align_celeba/071651.jpg\" style=\"height: 300px;\"></div><div style=\"margin: 2px; text-align: center;\"><img src=\"../../ICCV 2025/CelebA/img_align_celeba/100824.jpg\" style=\"height: 300px;\"></div></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "image_filenames = [celeba_names[35087], celeba_names[34529], celeba_names[34453], celeba_names[185854], celeba_names[165802], celeba_names[9]]\n",
    "images_html = \"\".join(\n",
    "    f'<div style=\"margin: 2px; text-align: center;\"><img src=\"{image_path}\" style=\"height: 300px;\"></div>'\n",
    "    for image_path in image_filenames\n",
    ")\n",
    "html = f'<div style=\"display:flex;align-items:center;justify-content:center;\">{images_html}</div>'\n",
    "display(HTML(html))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24ded276-9761-4501-ae8b-71fe929d7816",
   "metadata": {},
   "outputs": [],
   "source": [
    "for face in [35087, 34529, 34453, 185854, 165802, 9]:\n",
    "    e = celeba_embeddings[face]\n",
    "    \n",
    "    id_emb = torch.tensor(e, dtype=torch.float16)\n",
    "    id_emb = id_emb.unsqueeze(0)\n",
    "    id_emb = id_emb.to('cuda:0')\n",
    "    id_emb = project_face_embs(pipeline, id_emb)\n",
    "    images = pipeline(prompt_embeds=id_emb, num_inference_steps=25, guidance_scale=3.0, num_images_per_prompt=4).images\n",
    "    for i, img in enumerate(images):\n",
    "        img.save(f'/Plots/{face}_original_{i}.png', 'PNG')\n",
    "\n",
    "    # Interpolation: from identity (img) toward \"discovered\" direction d\n",
    "    edit = slerp(e, v, t=0.45)  # t ∈ [0, 1] controls how much \"strength\"\n",
    "    \n",
    "    id_emb = torch.tensor(edit, dtype=torch.float16)\n",
    "    id_emb = id_emb.unsqueeze(0)\n",
    "    id_emb = id_emb.to('cuda:0')\n",
    "    id_emb = project_face_embs(pipeline, id_emb)\n",
    "    images = pipeline(prompt_embeds=id_emb, num_inference_steps=25, guidance_scale=3.0, num_images_per_prompt=4).images\n",
    "    for i, img in enumerate(images):\n",
    "        img.save(f'/Plots/{face}_addition_{i}.png', 'PNG')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cfbf769e-6432-41a5-b2a7-d9246f8e4d9e",
   "metadata": {},
   "source": [
    "## Subtraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "c97c8c82-fcc6-411b-bcc0-ba2830709128",
   "metadata": {},
   "outputs": [],
   "source": [
    "images_to_plot = [\n",
    "    [24294, 21175, 24684, 24297, 29843, 24686, 21530, 21446, 21533, 29495, 21531, 29841, 24295, 21535],\n",
    "    [34041, 37763, 33339, 32074, 34040, 31958, 35139, 30747, 39728, 38783, 39502, 30746, 37465, 30727],\n",
    "    [36862, 31395, 39952, 39953, 34440, 39956, 40548, 35076, 34299, 36168, 35814, 40547, 36273, 37375],\n",
    "    [12456, 11305, 13535, 10962, 15362, 11304, 17706, 14830, 13938, 15627, 19744, 13942, 12014, 17704],\n",
    "]\n",
    "groups = [group_A, group_B, group_C, group_D]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0377187-e712-4458-9fee-87d0c07d5686",
   "metadata": {},
   "outputs": [],
   "source": [
    "for k in range(4):\n",
    "    v = get_direction(RFW_ids[groups[k]], RFW_embeddings[groups[k]])\n",
    "    for face in images_to_plot[k]:\n",
    "        e = RFW_embeddings[img]\n",
    "        id_emb = torch.tensor(e, dtype=torch.float16)\n",
    "        id_emb = id_emb.unsqueeze(0)\n",
    "        id_emb = id_emb.to('cuda:0')\n",
    "        id_emb = project_face_embs(pipeline, id_emb)\n",
    "        images = pipeline(prompt_embeds=id_emb, num_inference_steps=25, guidance_scale=3.0, num_images_per_prompt=4).images\n",
    "        for i, img in enumerate(images):\n",
    "            img.save(f'/Plots/{face}_original_{i}.png', 'PNG')\n",
    "        \n",
    "        \n",
    "        # Interpolation: from identity (img) toward opposite \"discovered\" direction d\n",
    "        edit = slerp(e, v, t=-0.5)  # t ∈ [0, 1] controls how much \"strength\"\n",
    "        id_emb = torch.tensor(edit, dtype=torch.float16)\n",
    "        id_emb = id_emb.unsqueeze(0)\n",
    "        id_emb = id_emb.to('cuda:0')\n",
    "        id_emb = project_face_embs(pipeline, id_emb)\n",
    "        images = pipeline(prompt_embeds=id_emb, num_inference_steps=25, guidance_scale=3.0, num_images_per_prompt=4).images\n",
    "        for i, img in enumerate(images):\n",
    "            img.save(f'/Plots/{face}_subtraction_{i}.png', 'PNG')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04db5f53-d5dc-4e58-b1e8-bafe62e2471e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
