{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2eadcdd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.gridspec as gridspec\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2915d1a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.gridspec as gridspec\n",
    "\n",
    "def create_image_grid(original_images, mitigated_images, adv_images, images_per_line=7, save_path='./plots/image.png'):\n",
    "    num_rows = min(len(mitigated_images), len(adv_images))\n",
    "\n",
    "    # Add 2 spacer columns: one after original, one between mitigated and adv\n",
    "    total_columns = images_per_line + 2  # 7 images + 2 spacers\n",
    "    width_ratios = [1, 0.2] + [1] * (images_per_line // 2) + [0.2] + [1] * (images_per_line // 2)\n",
    "\n",
    "    fig = plt.figure(figsize=(total_columns * 1.5, num_rows * 2))\n",
    "    gs = gridspec.GridSpec(num_rows, total_columns, width_ratios=width_ratios, wspace=0.05, hspace=0.02)\n",
    "\n",
    "    for row in range(num_rows):\n",
    "        # Original image (column 0)\n",
    "        ax = plt.subplot(gs[row, 0])\n",
    "        ax.imshow(original_images[row])\n",
    "        ax.axis('off')\n",
    "        if row == 0:\n",
    "            ax.set_title('Original', fontsize=16, weight='bold')\n",
    "\n",
    "        # Spacer at column 1 is skipped\n",
    "\n",
    "        # Mitigated images (start at column 2)\n",
    "        for col in range(images_per_line // 2):\n",
    "            ax = plt.subplot(gs[row, col + 2])\n",
    "            ax.imshow(mitigated_images[row][col])\n",
    "            ax.axis('off')\n",
    "            if row == 0 and col == 1:\n",
    "                ax.set_title('Mitigated', fontsize=16, weight='bold')\n",
    "\n",
    "        # Spacer after mitigated images (index: 2 + images_per_line // 2)\n",
    "        # So adversarial starts at: 3 + images_per_line // 2\n",
    "\n",
    "        for col in range(images_per_line // 2):\n",
    "            ax = plt.subplot(gs[row, col + 3 + images_per_line // 2])\n",
    "            ax.imshow(adv_images[row][col])\n",
    "            ax.axis('off')\n",
    "            if row == 0 and col == 1:\n",
    "                ax.set_title('Mitigated + Adv. Embedding', fontsize=16, weight='bold')\n",
    "\n",
    "    # fig.subplots_adjust(top=0.92, bottom=0.05, left=0.03, right=0.97)\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "    ax.get_figure().savefig(save_path, dpi=150, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b0146c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "from PIL import Image\n",
    "\n",
    "def get_first_images(path, img_index, num_images=6):\n",
    "    image_paths = glob.glob(f'{path}/img_{img_index:04d}_*.jpg')\n",
    "    sorted(image_paths)\n",
    "    return [Image.open(x) for x in image_paths[:num_images]]\n",
    "\n",
    "def get_originals(path, img_index):\n",
    "    image_paths = glob.glob(f'{path}/{img_index:04d}_*.png')\n",
    "    return Image.open(image_paths[0]).convert('RGB')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff9e0fa1",
   "metadata": {},
   "source": [
    "# Wanda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb7e422d",
   "metadata": {},
   "outputs": [],
   "source": [
    "mitigated_images = [get_first_images('generated_images/wanda_mitigation_0.01', i) for i in range(10, 20)]\n",
    "adv_images = [get_first_images('generated_images/wanda_adv_all_prompts_50', i) for i in range(10, 20)]\n",
    "original_images = [get_originals('prompts/memorized_images', i) for i in range(10, 20)]\n",
    "\n",
    "create_image_grid(\n",
    "    original_images,\n",
    "    mitigated_images,\n",
    "    adv_images,\n",
    "    save_path='./plots/wanda_mitigation_and_adv.png',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "016a179b",
   "metadata": {},
   "outputs": [],
   "source": [
    "mitigated_images = [get_first_images('generated_images/wanda_mitigation_all_sparsity_sweep/wanda_mitigation_all_0.10_sparsity', i) for i in range(10, 20)]\n",
    "adv_images = [get_first_images('generated_images/wanda_adv_all_prompts_sparsity_sweep/wanda_sparsity_0.10', i) for i in range(10, 20)]\n",
    "\n",
    "create_image_grid(\n",
    "    original_images,\n",
    "    mitigated_images,\n",
    "    adv_images,\n",
    "    save_path='./plots/wanda_mitigation_adv_0.1_sparsity.png',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c33adfb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1edb72b5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
