{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ee352966",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.image as mpimg\n",
    "\n",
    "def make_image_grid_matplotlib(root_dir, labels, ids, cell_size=(3.2, 3.2), style='cyberpunk'):\n",
    "    \"\"\"\n",
    "    Rows = ids, columns = labels.\n",
    "    Expects: {root_dir}/{label}00_0.000/{id}.png\n",
    "    \"\"\"\n",
    "    labels = [str(l) for l in labels]\n",
    "    nrows, ncols = len(ids), len(labels)\n",
    "    figsize = (cell_size[0] * ncols, cell_size[1] * nrows)\n",
    "\n",
    "    # Zero spacing between subplots\n",
    "    fig, axes = plt.subplots(\n",
    "        nrows, ncols, figsize=figsize,\n",
    "        gridspec_kw={\"wspace\": 0.0, \"hspace\": 0.0}\n",
    "    )\n",
    "\n",
    "    # Ensure axes is always 2D\n",
    "    if nrows == 1 and ncols == 1:\n",
    "        axes = axes.reshape(1, 1)\n",
    "    elif nrows == 1:\n",
    "        axes = axes.reshape(1, -1)\n",
    "    elif ncols == 1:\n",
    "        axes = axes.reshape(-1, 1)\n",
    "\n",
    "    for r, id_ in enumerate(ids):\n",
    "        for c, lab in enumerate(labels):\n",
    "            ax = axes[r, c]\n",
    "            path = os.path.join(root_dir, f\"{lab}00_0.000\", f\"{id_}.png\")\n",
    "            img = mpimg.imread(path)\n",
    "            ax.imshow(img)\n",
    "            ax.axis(\"off\")  # no ticks/borders\n",
    "            # bottom-left label box\n",
    "            ax.text(\n",
    "                6, img.shape[0] - 10, lab,\n",
    "                color=\"white\", fontsize=12, fontweight=\"bold\",\n",
    "                bbox=dict(facecolor=\"black\", alpha=0.6, pad=3)\n",
    "            )\n",
    "\n",
    "    # Remove outer figure margins\n",
    "    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)\n",
    "\n",
    "    # IMPORTANT: do NOT call tight_layout(); it reintroduces padding.\n",
    "    out_base = f\"style_visualize_{style}\"\n",
    "    fig.savefig(f\"{out_base}.pdf\", dpi=200, bbox_inches=\"tight\", pad_inches=0)\n",
    "    fig.savefig(f\"{out_base}.png\", dpi=50, bbox_inches=0, pad_inches=0)\n",
    "    return fig, axes\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9bda5605",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = \"FLUX.1-schnell\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "795c0a26",
   "metadata": {},
   "outputs": [
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: 'results_FLUX.1-schnell_sketch_42/generate_with_hooks_diffusion/coco-captions-styles/none:sketch/incr-mean_ot-mean/transfor..cks.11:0/0.200_0.000/101203.png'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mFileNotFoundError\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[3], line 6\u001b[0m\n\u001b[1;32m      3\u001b[0m labels \u001b[38;5;241m=\u001b[39m [\u001b[38;5;241m0.2\u001b[39m,\u001b[38;5;241m0.4\u001b[39m,\u001b[38;5;241m0.6\u001b[39m,\u001b[38;5;241m0.8\u001b[39m,\u001b[38;5;241m1.0\u001b[39m]\n\u001b[1;32m      4\u001b[0m img_ids \u001b[38;5;241m=\u001b[39m [\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m101203\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m784562\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m503832\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m65302\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m443957\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m472409\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m509247\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[0;32m----> 6\u001b[0m \u001b[43mmake_image_grid_matplotlib\u001b[49m\u001b[43m(\u001b[49m\u001b[43mroot_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mimg_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mimg_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstyle\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstyle\u001b[49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[0;32mIn[1], line 32\u001b[0m, in \u001b[0;36mmake_image_grid_matplotlib\u001b[0;34m(root_dir, labels, ids, cell_size, style)\u001b[0m\n\u001b[1;32m     30\u001b[0m ax \u001b[38;5;241m=\u001b[39m axes[r, c]\n\u001b[1;32m     31\u001b[0m path \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(root_dir, \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mlab\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m00_0.000\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mid_\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.png\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 32\u001b[0m img \u001b[38;5;241m=\u001b[39m \u001b[43mmpimg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mimread\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     33\u001b[0m ax\u001b[38;5;241m.\u001b[39mimshow(img)\n\u001b[1;32m     34\u001b[0m ax\u001b[38;5;241m.\u001b[39maxis(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moff\u001b[39m\u001b[38;5;124m\"\u001b[39m)  \u001b[38;5;66;03m# no ticks/borders\u001b[39;00m\n",
      "File \u001b[0;32m/opt/conda/envs/ml-act/lib/python3.9/site-packages/matplotlib/image.py:1544\u001b[0m, in \u001b[0;36mimread\u001b[0;34m(fname, format)\u001b[0m\n\u001b[1;32m   1537\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(fname, \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(parse\u001b[38;5;241m.\u001b[39murlparse(fname)\u001b[38;5;241m.\u001b[39mscheme) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m   1538\u001b[0m     \u001b[38;5;66;03m# Pillow doesn't handle URLs directly.\u001b[39;00m\n\u001b[1;32m   1539\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m   1540\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPlease open the URL for reading and pass the \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   1541\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mresult to Pillow, e.g. with \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   1542\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m``np.array(PIL.Image.open(urllib.request.urlopen(url)))``.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   1543\u001b[0m         )\n\u001b[0;32m-> 1544\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[43mimg_open\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfname\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m image:\n\u001b[1;32m   1545\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m (_pil_png_to_float_array(image)\n\u001b[1;32m   1546\u001b[0m             \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(image, PIL\u001b[38;5;241m.\u001b[39mPngImagePlugin\u001b[38;5;241m.\u001b[39mPngImageFile) \u001b[38;5;28;01melse\u001b[39;00m\n\u001b[1;32m   1547\u001b[0m             pil_to_array(image))\n",
      "File \u001b[0;32m/opt/conda/envs/ml-act/lib/python3.9/site-packages/PIL/ImageFile.py:135\u001b[0m, in \u001b[0;36mImageFile.__init__\u001b[0;34m(self, fp, filename)\u001b[0m\n\u001b[1;32m    131\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdecodermaxblock \u001b[38;5;241m=\u001b[39m MAXBLOCK\n\u001b[1;32m    133\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_path(fp):\n\u001b[1;32m    134\u001b[0m     \u001b[38;5;66;03m# filename\u001b[39;00m\n\u001b[0;32m--> 135\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfp \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mfp\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mrb\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m    136\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfilename \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mfspath(fp)\n\u001b[1;32m    137\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_exclusive_fp \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n",
      "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'results_FLUX.1-schnell_sketch_42/generate_with_hooks_diffusion/coco-captions-styles/none:sketch/incr-mean_ot-mean/transfor..cks.11:0/0.200_0.000/101203.png'"
     ]
    }
   ],
   "source": [
    "style = 'sketch'\n",
    "img_path = f\"results_{model}_{style}_42/generate_with_hooks_diffusion/coco-captions-styles/none:{style}/incr-mean_ot-mean/transfor..cks.11:0\"\n",
    "labels = [0.2,0.4,0.6,0.8,1.0]\n",
    "img_ids = ['101203', '784562', '503832', '65302', '443957', '472409', '509247']\n",
    "\n",
    "make_image_grid_matplotlib(root_dir=img_path, labels=labels, ids=img_ids, style=style)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0cafc2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "style = 'cyberpunk'\n",
    "img_path = f\"results_{model}_{style}_42/generate_with_hooks_diffusion/coco-captions-styles/none:{style}/incr-mean_ot_pid-mean/transfor..cks.11:0\"\n",
    "labels = [0.2,0.4,0.6,0.8,1.0]\n",
    "img_ids = ['101203', '784562', '503832', '65302', '443957', '472409', '509247']\n",
    "\n",
    "make_image_grid_matplotlib(root_dir=img_path, labels=labels, ids=img_ids, style=style)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a7788e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "style = 'cyberpunk'\n",
    "method=\"mean_ot_pid\"\n",
    "# 'results_FLUX.1-schnell_mean_ot_cyberpunk_42/generate_with_hooks_diffusion/coco-captions-styles/none:cyberpunk/incr-mean_ot-mean/transfor..cks.11:0/16314_animation.gif'\n",
    "img_path = f\"results_{model}_{method}_{style}_42/generate_with_hooks_diffusion/coco-captions-styles/none:{style}/incr-{method}-mean/transfor..cks.11:0\"\n",
    "labels = [0.2,0.4,0.6,0.8,1.0]\n",
    "img_ids = ['101203', '784562', '503832', '65302', '443957', '472409', '509247']\n",
    "\n",
    "make_image_grid_matplotlib(root_dir=img_path, labels=labels, ids=img_ids, style=style)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bffcc5e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "style = 'steampunk'\n",
    "method=\"mean_ot_pid\"\n",
    "# 'results_FLUX.1-schnell_mean_ot_cyberpunk_42/generate_with_hooks_diffusion/coco-captions-styles/none:cyberpunk/incr-mean_ot-mean/transfor..cks.11:0/16314_animation.gif'\n",
    "img_path = f\"results_{model}_{method}_{style}_42/generate_with_hooks_diffusion/coco-captions-styles/none:{style}/incr-{method}-mean/transfor..cks.11:0\"\n",
    "labels = [0.2,0.4,0.6,0.8,1.0]\n",
    "img_ids = ['101203', '784562', '503832', '65302', '443957', '472409', '509247']\n",
    "\n",
    "make_image_grid_matplotlib(root_dir=img_path, labels=labels, ids=img_ids, style=style)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21daa8d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "style = 'gothic'\n",
    "img_path = f\"results_{style}_42/generate_with_hooks_diffusion/coco-captions-styles/none:{style}/incr-mean_ot-mean/transfor..cks.11:0\"\n",
    "labels = [0.2,0.4,0.6,0.8,1.0]\n",
    "img_ids = ['101203', '784562', '503832', '65302', '443957', '472409', '509247']\n",
    "\n",
    "make_image_grid_matplotlib(root_dir=img_path, labels=labels, ids=img_ids, style=style)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87f5fd10",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml-act",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.23"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
