{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ad42961",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from transformers import CLIPTextModel, CLIPTokenizer\n",
    "from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler, UniPCMultistepScheduler\n",
    "from diffusers.utils.import_utils import is_xformers_available\n",
    "from tqdm import tqdm\n",
    "from PIL import Image\n",
    "from itertools import pairwise\n",
    "\n",
    "from schedulers import Scheduler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db504e31-5f80-4e69-8b06-7e378e98a1b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "class StableDiffuser:\n",
    "    '''Stable Diffusion components (VAE, CLIP text encoder, UNet, scheduler)\n",
    "    loaded from one diffusers-format checkpoint, plus sampling helpers.\n",
    "\n",
    "    Args:\n",
    "      model_key: Hugging Face Hub id or local path of the checkpoint.\n",
    "      scheduler: 'ddim' or 'pndm'; the diffusers scheduler loaded at init.\n",
    "      device: torch device every model is moved to.\n",
    "      float_dtype: dtype of the model weights and of the initial latent noise.\n",
    "\n",
    "    Raises:\n",
    "      ValueError: if `scheduler` is neither 'ddim' nor 'pndm'.\n",
    "    '''\n",
    "    def __init__(self,\n",
    "                 model_key,\n",
    "                 scheduler='ddim',\n",
    "                 device='cuda',\n",
    "                 float_dtype=torch.float32):\n",
    "        # Validate before loading several GB of weights; a bare assert would\n",
    "        # also be stripped under `python -O`.\n",
    "        if scheduler not in ('ddim', 'pndm'):\n",
    "            raise ValueError(\n",
    "                f\"unknown scheduler {scheduler!r}; expected 'ddim' or 'pndm'\")\n",
    "        self.device = device\n",
    "        self.float_dtype = float_dtype\n",
    "        self.vae = AutoencoderKL.from_pretrained(model_key, subfolder='vae',\n",
    "                                                 torch_dtype=float_dtype)\n",
    "        # Tokenizers and schedulers hold no weights, so no torch_dtype is passed.\n",
    "        self.tokenizer = CLIPTokenizer.from_pretrained(\n",
    "            model_key,\n",
    "            subfolder='tokenizer')\n",
    "        self.text_encoder = CLIPTextModel.from_pretrained(\n",
    "            model_key,\n",
    "            subfolder='text_encoder',\n",
    "            torch_dtype=float_dtype)\n",
    "\n",
    "        self.unet = UNet2DConditionModel.from_pretrained(\n",
    "            model_key,\n",
    "            subfolder='unet',\n",
    "            torch_dtype=float_dtype)\n",
    "        scheduler_cls = PNDMScheduler if scheduler == 'pndm' else DDIMScheduler\n",
    "        self.scheduler = scheduler_cls.from_pretrained(\n",
    "            model_key,\n",
    "            subfolder='scheduler')\n",
    "\n",
    "        self.vae = self.vae.to(self.device)\n",
    "        self.vae.enable_slicing()\n",
    "        self.text_encoder = self.text_encoder.to(self.device)\n",
    "        self.unet = self.unet.to(self.device)\n",
    "        if is_xformers_available():\n",
    "            self.unet.enable_xformers_memory_efficient_attention()\n",
    "\n",
    "        # alphas_cumprod of the scheduler loaded above; reassigning\n",
    "        # `self.scheduler` later does not refresh this tensor.\n",
    "        self.alphas = self.scheduler.alphas_cumprod.to(self.device)\n",
    "\n",
    "    def embed_text(self, prompt):\n",
    "        '''Encode `prompt` for classifier-free guidance.\n",
    "\n",
    "        Returns the text-encoder hidden states of the empty prompt and of\n",
    "        `prompt`, concatenated on dim 0 in the order [uncond, cond].\n",
    "        '''\n",
    "        text_input = self.tokenizer([prompt], padding='max_length',\n",
    "                                    max_length=self.tokenizer.model_max_length,\n",
    "                                    truncation=True, return_tensors='pt')\n",
    "        with torch.no_grad():\n",
    "            text_embs = self.text_encoder(\n",
    "                text_input.input_ids.to(self.device))[0]\n",
    "\n",
    "        # Include unconditional text input for classifier-free guidance.\n",
    "        uncond_input = self.tokenizer([''], padding='max_length',\n",
    "                                      max_length=text_input.input_ids.shape[-1],\n",
    "                                      return_tensors='pt')\n",
    "        with torch.no_grad():\n",
    "            uncond_embs = self.text_encoder(\n",
    "                uncond_input.input_ids.to(self.device))[0]\n",
    "        return torch.cat([uncond_embs, text_embs])\n",
    "\n",
    "    def encode_imgs(self, imgs):\n",
    "        '''\n",
    "        Args:\n",
    "          imgs: [B, 3, H, W], with values in [0, 1]\n",
    "        Returns:\n",
    "          VAE latents multiplied by 0.18215 (undone in decode_latents).\n",
    "        '''\n",
    "        imgs = 2 * imgs - 1\n",
    "        posterior = self.vae.encode(imgs).latent_dist\n",
    "        # TODO: should we use mode here or sample (also pass a generator)?\n",
    "        latents = posterior.sample() * 0.18215\n",
    "        # latents = posterior.mode()\n",
    "        return latents\n",
    "\n",
    "    def decode_latents(self, latents):\n",
    "        '''Inverse of encode_imgs: undo the 0.18215 scaling, decode with the\n",
    "        VAE and map the output from [-1, 1] to [0, 1].'''\n",
    "        latents = latents / 0.18215\n",
    "        imgs = self.vae.decode(latents).sample\n",
    "        imgs = (imgs / 2 + 0.5).clamp(0, 1)\n",
    "        return imgs # (B, 3, H, W)\n",
    "\n",
    "    def predict_noise(self, text_embs, latents_noisy, t,\n",
    "                      guidance_scale=7.5):\n",
    "        '''Classifier-free-guided noise prediction.\n",
    "\n",
    "        Args:\n",
    "          text_embs: output of `embed_text`, i.e. [uncond, cond] on dim 0.\n",
    "          latents_noisy: noisy latents at timestep `t`, batch on dim 0.\n",
    "          t: timestep passed to the scheduler and the UNet.\n",
    "          guidance_scale: weight w in eps_cond + w * (eps_cond - eps_uncond);\n",
    "            w = 0 disables guidance. diffusers pipelines instead use\n",
    "            eps_uncond + s * (eps_cond - eps_uncond), so w here equals s - 1.\n",
    "        '''\n",
    "        with torch.no_grad():\n",
    "            B = latents_noisy.shape[0]\n",
    "            # Layout: [x_0..x_{B-1} | x_0..x_{B-1}] -> first half uncond, second cond.\n",
    "            latents_noisy = torch.cat([latents_noisy] * 2)\n",
    "            latents_noisy = self.scheduler.scale_model_input(latents_noisy, t)\n",
    "            # repeat_interleave gives [uncond]*B + [cond]*B, matching the latent\n",
    "            # layout above; `repeat` would alternate uncond/cond and mis-pair\n",
    "            # embeddings with latents whenever B > 1.\n",
    "            noise_pred = self.unet(\n",
    "                latents_noisy,\n",
    "                t,\n",
    "                encoder_hidden_states=text_embs.repeat_interleave(B, dim=0)\n",
    "            ).sample\n",
    "\n",
    "            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n",
    "            noise_pred = noise_pred_text + guidance_scale * (\n",
    "                noise_pred_text - noise_pred_uncond)\n",
    "        return noise_pred\n",
    "\n",
    "    def prompt_to_img(self, prompt, num_inference_steps=50, guidance_scale=7.5,\n",
    "                      seed=42, custom_sampler=False, height=512, width=512):\n",
    "        '''Generate one image for `prompt`.\n",
    "\n",
    "        Args:\n",
    "          prompt: text prompt.\n",
    "          num_inference_steps: number of sampler steps.\n",
    "          guidance_scale: classifier-free guidance weight (see predict_noise).\n",
    "          seed: seed for the initial latent noise.\n",
    "          custom_sampler: if True, run the second-order sigma-space sampler\n",
    "            built on `schedulers.Scheduler`; otherwise step `self.scheduler`.\n",
    "          height, width: output size in pixels (multiples of 8).\n",
    "\n",
    "        Returns:\n",
    "          [3, height, width] image tensor with values in [0, 1].\n",
    "        '''\n",
    "        text_embs = self.embed_text(prompt)\n",
    "        # Latents live at 1/8 of the pixel resolution.\n",
    "        latents = torch.randn(\n",
    "            [1, self.unet.config.in_channels, height // 8, width // 8],\n",
    "            dtype=self.float_dtype,\n",
    "            generator=torch.manual_seed(seed)).to(self.device)\n",
    "\n",
    "        if custom_sampler:\n",
    "            gam = 2\n",
    "            sc = Scheduler(beta_start=0.00085, beta_end=0.012, beta_schedule='scaled_linear')\n",
    "            sc.set_timesteps_sigma(start=6.57, end=0.1195, num_inference_steps=num_inference_steps, style='Linear')\n",
    "            # The update below integrates in the scaled space x / sqrt(alpha_bar):\n",
    "            # move the initial noise into that space, and map back to\n",
    "            # sqrt(alpha_bar) * x before every UNet call.\n",
    "            latents = latents / sc.ap(sc.timesteps[0]).sqrt()\n",
    "            eps = None\n",
    "            for i, (t, t_prev) in enumerate(pairwise(sc.timesteps)):\n",
    "                eps, eps_prev = self.predict_noise(\n",
    "                    text_embs, latents * sc.ap(t).sqrt(), t,\n",
    "                    guidance_scale=guidance_scale), eps\n",
    "                eps_av = eps * gam + eps_prev * (1 - gam) if i > 0 else eps\n",
    "                latents += (sc.sigma(t_prev) - sc.sigma(t)) * eps_av\n",
    "        else:\n",
    "            self.scheduler.set_timesteps(num_inference_steps)\n",
    "            for t in self.scheduler.timesteps: # t is in [0, num_train_steps]\n",
    "                latents = self.scheduler.step(\n",
    "                    self.predict_noise(text_embs, latents, t,\n",
    "                                       guidance_scale=guidance_scale),\n",
    "                    t, latents).prev_sample\n",
    "        # Decoding only for display: skip building an autograd graph through the VAE.\n",
    "        with torch.no_grad():\n",
    "            return self.decode_latents(latents)[0]\n",
    "\n",
    "def tensor_to_pil(img):\n",
    "    '''Convert a [3, H, W] image tensor with values in [0, 1] to a PIL image.'''\n",
    "    # Cast to float32 first: numpy has no bfloat16, so .numpy() would fail\n",
    "    # for models loaded with float_dtype=torch.bfloat16.\n",
    "    img = img.detach().cpu().float().permute(1, 2, 0).numpy()\n",
    "    img = (img * 255).round().astype(\"uint8\")\n",
    "    return Image.fromarray(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd5ecab9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hub id of the Stable Diffusion 2.1 base checkpoint; `sample` below also reads this global.\n",
    "model_key = 'stabilityai/stable-diffusion-2-1-base'\n",
    "sd = StableDiffuser(\n",
    "    model_key=model_key,\n",
    "    scheduler='pndm',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d65f1bb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample(prompt, seed=0, num_inference_steps=10, guidance_scale=7.5, style='2nd'):\n",
    "    '''Generate one image for `prompt` with the sampler selected by `style`.\n",
    "\n",
    "    Styles:\n",
    "      'ddim', 'pndm', 'dpm', 'unipc': the matching diffusers scheduler, run\n",
    "          through `sd.prompt_to_img` (this replaces the global `sd.scheduler`).\n",
    "      '1st': first-order update x <- x + (sigma_prev - sigma) * eps in the\n",
    "          scaled space x / sqrt(alpha_bar), using `schedulers.Scheduler`.\n",
    "      '2nd': same, but eps is extrapolated as 2 * eps - eps_prev.\n",
    "\n",
    "    Returns:\n",
    "      [3, H, W] image tensor with values in [0, 1].\n",
    "\n",
    "    Raises:\n",
    "      ValueError: if `style` is not one of the names above.\n",
    "    '''\n",
    "    diffusers_schedulers = {\n",
    "        'ddim': DDIMScheduler,\n",
    "        'pndm': PNDMScheduler,\n",
    "        'dpm': DPMSolverMultistepScheduler,\n",
    "        'unipc': UniPCMultistepScheduler,\n",
    "    }\n",
    "    if style in diffusers_schedulers:\n",
    "        sd.scheduler = diffusers_schedulers[style].from_pretrained(model_key, subfolder='scheduler')\n",
    "        return sd.prompt_to_img(prompt, seed=seed, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)\n",
    "\n",
    "    # Custom sigma-space samplers. ('ddim' is served by diffusers above, so the\n",
    "    # first-order variant is named '1st'.) Any other name would otherwise skip\n",
    "    # the loop and decode undenoised noise.\n",
    "    if style not in ('1st', '2nd'):\n",
    "        raise ValueError(f\"unknown style {style!r}\")\n",
    "\n",
    "    text_embs = sd.embed_text(prompt)\n",
    "    latents = torch.randn([1, sd.unet.config.in_channels, 64, 64],\n",
    "                        dtype=sd.float_dtype,\n",
    "                        generator=torch.manual_seed(seed)).to(sd.device)\n",
    "    sc = Scheduler(beta_start= 0.00085, beta_end=0.012, beta_schedule='scaled_linear')\n",
    "    # End the linear schedule at the second-to-last sigma of a DDIM-spaced one.\n",
    "    sc.set_timesteps_sigma(start=14.6, end=0.02, num_inference_steps=num_inference_steps, style='DDIM')\n",
    "    sigma_min = sc.sigma(sc.timesteps[-2])\n",
    "    sc.set_timesteps_sigma(start=14.6, end=sigma_min, num_inference_steps=num_inference_steps, style='Linear')\n",
    "    print(sc.timesteps)\n",
    "    # Integrate in the scaled space x / sqrt(alpha_bar); the UNet is fed\n",
    "    # sqrt(alpha_bar) * latents.\n",
    "    latents = latents / sc.ap(sc.timesteps[0]).sqrt()\n",
    "    eps = None\n",
    "    for i, (t, t_prev) in enumerate(pairwise(sc.timesteps)):\n",
    "        eps, eps_prev = sd.predict_noise(text_embs, latents * sc.ap(t).sqrt(), t, guidance_scale=guidance_scale), eps\n",
    "        if style == '2nd' and i > 0:\n",
    "            eps_step = eps * 2 - eps_prev\n",
    "        else:\n",
    "            eps_step = eps\n",
    "        latents += (sc.sigma(t_prev) - sc.sigma(t)) * eps_step\n",
    "    # Decoding only for display: skip building an autograd graph through the VAE.\n",
    "    with torch.no_grad():\n",
    "        return sd.decode_latents(latents)[0]\n",
    "\n",
    "def experiment(prompt, seed=0, N=10, saveto=None):\n",
    "    '''Render `prompt` with each sampler style, then display or save the images.\n",
    "\n",
    "    Args:\n",
    "      prompt: text prompt.\n",
    "      seed: seed passed to `sample` for every style.\n",
    "      N: number of inference steps for every sampler.\n",
    "      saveto: None to display each image inline; otherwise a path prefix,\n",
    "        and each image is written to '<saveto>_<style>.png'.\n",
    "    '''\n",
    "    styles = ('2nd', 'unipc', 'dpm', 'pndm', 'ddim')\n",
    "    for style in styles:\n",
    "        generated = sample(prompt, seed=seed, num_inference_steps=N, style=style)\n",
    "        pil_img = tensor_to_pil(generated)\n",
    "        if saveto is not None:\n",
    "            pil_img.save(f'{saveto}_{style}.png')\n",
    "        else:\n",
    "            display(pil_img)\n",
    "\n",
    "# Regenerate the figures used in the paper:\n",
    "#   experiment(\"A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors\",\n",
    "#              seed=4, N=10, saveto='figures/sd_1')\n",
    "#   experiment(\"london luxurious interior living-room, light walls\",\n",
    "#              seed=3, N=10, saveto='figures/sd_2')\n",
    "#   experiment(\"Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k\",\n",
    "#              seed=0, N=10, saveto='figures/sd_3')\n",
    "\n",
    "# Preview the third paper prompt inline (no saveto, so images are displayed).\n",
    "experiment(\n",
    "    \"Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k\",\n",
    "    seed=0,\n",
    "    N=10,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:diffusion-experiments]",
   "language": "python",
   "name": "conda-env-diffusion-experiments-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "name": "sd_new_sampler.ipynb"
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
