{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "355UKMUQJxFd"
   },
   "source": [
    "# SiT: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers\n",
    "\n",
    "This notebook samples from pre-trained SiT models. SiTs are class-conditional latent interpolant models trained on ImageNet, unifying Flow and Diffusion Methods.\n",
    "\n",
    "[Paper](https://arxiv.org/abs/2401.08740) | [GitHub](https://github.com/willisma/SiT)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zJlgLkSaKn7u"
   },
   "source": [
    "# 1. Setup\n",
    "\n",
    "We recommend using GPUs (Runtime > Change runtime type > Hardware accelerator > GPU). Run this cell to clone the SiT GitHub repo and set up PyTorch. You only have to run this once."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Clone the SiT repo and move into it. The guards make this cell safe to re-run:\n",
    "# without them a second run would clone a nested SiT/SiT checkout and chdir into it.\n",
    "if os.path.basename(os.getcwd()) != 'SiT':\n",
    "    if not os.path.isdir('SiT'):\n",
    "        !git clone https://github.com/willisma/SiT.git\n",
    "    os.chdir('SiT')\n",
    "# Exported for child processes; the paths presumably match Colab's layout -- adjust when running elsewhere.\n",
    "os.environ['PYTHONPATH'] = '/env/python:/content/SiT'\n",
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install diffusers timm torchdiffeq --upgrade\n",
    "# SiT imports (transport, download and models are presumably resolved from the repo directory entered above):\n",
    "import torch\n",
    "from torchvision.utils import save_image\n",
    "from transport import create_transport, Sampler\n",
    "from diffusers.models import AutoencoderKL\n",
    "from download import find_model\n",
    "from models import SiT_XL_2\n",
    "from PIL import Image\n",
    "from IPython.display import display\n",
    "# This notebook only samples, so gradient tracking is switched off globally.\n",
    "torch.set_grad_enabled(False)\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "if device == \"cpu\":\n",
    "    print(\"GPU not found. Using CPU instead.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AXpziRkoOvV9"
   },
   "source": [
    "# Download SiT-XL/2 Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "EWG-WNimO59K"
   },
   "outputs": [],
   "source": [
    "# Image resolution, kept as a string because it is interpolated into the checkpoint file name below.\n",
    "image_size = \"256\"\n",
    "vae_model = \"stabilityai/sd-vae-ft-ema\" #@param [\"stabilityai/sd-vae-ft-mse\", \"stabilityai/sd-vae-ft-ema\"]\n",
    "latent_size = int(image_size) // 8\n",
    "\n",
    "# Build SiT-XL/2 on the target device, then restore the pre-trained weights:\n",
    "checkpoint_name = f\"SiT-XL-2-{image_size}x{image_size}.pt\"\n",
    "model = SiT_XL_2(input_size=latent_size)\n",
    "model = model.to(device)\n",
    "state_dict = find_model(checkpoint_name)\n",
    "model.load_state_dict(state_dict)\n",
    "model.eval()  # important: switch the network to evaluation mode before sampling\n",
    "\n",
    "# Load the autoencoder selected in the form above and place it on the same device:\n",
    "vae = AutoencoderKL.from_pretrained(vae_model)\n",
    "vae = vae.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5JTNyzNZKb9E"
   },
   "source": [
    "# 2. Sample from Pre-trained SiT Models\n",
    "\n",
    "You can customize several sampling options. For the full list of ImageNet classes, [check out this](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-Hw7B5h4Kk4p"
   },
   "outputs": [],
   "source": [
    "# Set user inputs:\n",
    "seed = 0 #@param {type:\"number\"}\n",
    "torch.manual_seed(seed)\n",
    "num_sampling_steps = 250 #@param {type:\"slider\", min:0, max:1000, step:1}\n",
    "cfg_scale = 4 #@param {type:\"slider\", min:1, max:10, step:0.1}\n",
    "class_labels = 207, 360, 387, 974, 88, 979, 417, 279 #@param {type:\"raw\"}\n",
    "samples_per_row = 4 #@param {type:\"number\"}\n",
    "sampler_type = \"ODE\" #@param [\"ODE\", \"SDE\"]\n",
    "\n",
    "# A single label typed into the raw field is a bare int, not a tuple; wrap it so len() works.\n",
    "if isinstance(class_labels, int):\n",
    "    class_labels = (class_labels,)\n",
    "\n",
    "# Create the transport object and the sampler built on top of it:\n",
    "transport = create_transport()\n",
    "sampler = Sampler(transport)\n",
    "\n",
    "# Create sampling noise (one latent per requested class label):\n",
    "n = len(class_labels)\n",
    "z = torch.randn(n, 4, latent_size, latent_size, device=device)\n",
    "y = torch.tensor(class_labels, device=device)\n",
    "\n",
    "# Setup classifier-free guidance: the batch is doubled, and the second half\n",
    "# is labelled 1000 (used here as the null / unconditional class).\n",
    "z = torch.cat([z, z], 0)\n",
    "y_null = torch.tensor([1000] * n, device=device)\n",
    "y = torch.cat([y, y_null], 0)\n",
    "model_kwargs = dict(y=y, cfg_scale=cfg_scale)\n",
    "\n",
    "# Build the sampling function:\n",
    "if sampler_type == \"SDE\":\n",
    "    SDE_sampling_method = \"Euler\" #@param [\"Euler\", \"Heun\"]\n",
    "    diffusion_form = \"linear\" #@param [\"constant\", \"SBDM\", \"sigma\", \"linear\", \"decreasing\", \"increasing-decreasing\"]\n",
    "    diffusion_norm = 1 #@param {type:\"slider\", min:0, max:10.0, step:0.1}\n",
    "    last_step = \"Mean\" #@param [\"Mean\", \"Tweedie\", \"Euler\"]\n",
    "    last_step_size = 0.4 #@param {type:\"slider\", min:0, max:1.0, step:0.01}\n",
    "    sample_fn = sampler.sample_sde(\n",
    "        sampling_method=SDE_sampling_method,\n",
    "        diffusion_form=diffusion_form,\n",
    "        diffusion_norm=diffusion_norm,\n",
    "        # Forward the form selection; it was previously defined but never passed on.\n",
    "        last_step=last_step,\n",
    "        last_step_size=last_step_size,\n",
    "        num_steps=num_sampling_steps,\n",
    "    )\n",
    "elif sampler_type == \"ODE\":\n",
    "    # default to Adaptive Solver\n",
    "    ODE_sampling_method = \"dopri5\" #@param [\"dopri5\", \"euler\", \"rk4\"]\n",
    "    atol = 1e-6\n",
    "    rtol = 1e-3\n",
    "    sample_fn = sampler.sample_ode(\n",
    "        sampling_method=ODE_sampling_method,\n",
    "        atol=atol,\n",
    "        rtol=rtol,\n",
    "        num_steps=num_sampling_steps,\n",
    "    )\n",
    "else:\n",
    "    # Fail early with a clear message instead of a NameError on sample_fn below.\n",
    "    raise ValueError(f\"Unknown sampler_type {sampler_type!r}; expected 'ODE' or 'SDE'\")\n",
    "\n",
    "# Sample latents; [-1] selects the last entry of what sample_fn returns.\n",
    "samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]\n",
    "# The batch was doubled for guidance above: keep only the first (class-conditioned) half\n",
    "# so that exactly one image per requested label is decoded and shown.\n",
    "samples, _ = samples.chunk(2, dim=0)\n",
    "# 0.18215 is presumably the latent scaling factor of the VAE -- TODO confirm.\n",
    "samples = vae.decode(samples / 0.18215).sample\n",
    "\n",
    "# Save and display images:\n",
    "save_image(samples, \"sample.png\", nrow=int(samples_per_row),\n",
    "           normalize=True, value_range=(-1, 1))\n",
    "grid_image = Image.open(\"sample.png\")\n",
    "display(grid_image)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3.8.10 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.8.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
