{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torchvision.transforms as T\n",
    "import torchvision.transforms.functional as TF\n",
    "import torchvision\n",
    "import os\n",
    "from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler,DPMSolverMultistepScheduler\n",
    "vae = AutoencoderKL.from_pretrained(\"CompVis/stable-diffusion-v1-4\", subfolder=\"vae\", use_safetensors=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def image2vector(image):\n",
    "    global vae\n",
    "    #to tensor\n",
    "    if(type(image)==np.ndarray):\n",
    "        image=torch.from_numpy(image)\n",
    "    if(len(image.shape)==3):\n",
    "        image=image.unsqueeze(0)\n",
    "    #to latent space\n",
    "    latent=vae.encode(image)\n",
    "    #to vector\n",
    "    vector=latent.flatten()\n",
    "    return vector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#extract images from a certain folder and transform them to vectors and save as torch tensor\n",
    "def extract_images_from_file(folder,save_folder):\n",
    "    global vae\n",
    "    #get all images from folder\n",
    "    images=[]\n",
    "    for filename in os.listdir(folder):\n",
    "        img=Image.open(os.path.join(folder,filename))\n",
    "        if img is not None:\n",
    "            images.append(img)\n",
    "    #transform to tensor\n",
    "    images=torch.stack([TF.to_tensor(x) for x in images])\n",
    "    #transform to latent space\n",
    "    latent=vae.encode(images)\n",
    "    #flatten\n",
    "    latent=latent.flatten(start_dim=1)\n",
    "    #save as torch tensor\n",
    "    torch.save(latent,os.path.join(save_folder,\"latent.pt\"))\n",
    "    return latent\n",
    "def extract_images_from_tensor(tensor,save_folder):\n",
    "    global vae\n",
    "    #transform to latent space\n",
    "    latent=vae.encode(tensor)\n",
    "    #flatten\n",
    "    latent=latent.flatten(start_dim=1)\n",
    "    #save as torch tensor\n",
    "    torch.save(latent,os.path.join(save_folder,\"latent.pt\"))\n",
    "    return latent\n",
    "#find the most simliar image in a folder to a given image\n",
    "def find_most_similar_image(image,folder):\n",
    "    #transform image to vector\n",
    "    vector=image2vector(image)\n",
    "    #load latent tensors\n",
    "    latent=torch.load(os.path.join(folder,\"latent.pt\"))\n",
    "    #calculate distances\n",
    "    distances=torch.norm(latent-vector,dim=1)\n",
    "    #find closest image\n",
    "    index=torch.argmin(distances)\n",
    "    #load image\n",
    "    image=Image.open(os.path.join(folder,str(index)+\".png\"))\n",
    "    return image"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
