{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/yang-song/score_sde/blob/main/Score_SDE_demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "I2mneJHBuulE"
   },
   "source": [
    "# Preparation\n",
    "\n",
    "1. `git clone https://github.com/yang-song/score_sde.git`\n",
    "\n",
    "2. Install [required packages](https://github.com/yang-song/score_sde/blob/main/requirements.txt)\n",
    "\n",
    "3. `cd` into folder `score_sde`, launch a local jupyter server and connect to colab following [these instructions](https://research.google.com/colaboratory/local-runtimes.html)\n",
    "\n",
    "4. Download pre-trained [checkpoints](https://drive.google.com/drive/folders/1RAG8qpOTURkrqXKwdAR1d6cU9rwoQYnH?usp=sharing) and save them in the `exp` folder, then point `ckpt_filename` in the model-loading cell below at the checkpoint you want to load."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "A1bYhwHGHw6D",
    "outputId": "fd0511fa-a3f5-49dc-800d-48a4671f3481",
    "tags": []
   },
   "outputs": [],
   "source": [
    "#@title Autoload all modules\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "from dataclasses import dataclass, field\n",
    "import matplotlib.pyplot as plt\n",
    "import io\n",
    "import csv\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "import importlib\n",
    "import os\n",
    "import functools\n",
    "import itertools\n",
    "import jax.random as random\n",
    "\n",
    "import flax\n",
    "import flax.linen as nn\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "import tensorflow_gan as tfgan\n",
    "import tqdm\n",
    "import io\n",
    "import inspect\n",
    "\n",
    "import models\n",
    "from models import utils as mutils\n",
    "from models import ncsnv2\n",
    "from models import ncsnpp\n",
    "from models import ddpm as ddpm_model\n",
    "from models import layerspp\n",
    "from models import layers\n",
    "from models import normalization\n",
    "import run_lib\n",
    "\n",
    "import sampling\n",
    "import losses as losses_lib\n",
    "import utils\n",
    "import evaluation\n",
    "from models import up_or_down_sampling as stylegan_layers\n",
    "import datasets\n",
    "from models import wideresnet_noise_conditional\n",
    "import sde_lib\n",
    "import likelihood\n",
    "import controllable_generation\n",
    "\n",
    "from sampling import *\n",
    "from sde_lib import *\n",
    "from scipy import integrate\n",
    "\n",
    "import os\n",
    "os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8aIP2GrjrOEM",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# @title Load the score-based model\n",
    "\n",
    "from configs.vp import cifar10_ddpmpp_continuous as configs\n",
    "\n",
    "# Checkpoint location. Set the SCORE_SDE_CKPT environment variable to override the default path below.\n",
    "ckpt_filename = os.path.expanduser(\n",
    "    os.environ.get(\"SCORE_SDE_CKPT\", \"~/projects/dodes/ckpts/cifar10_ddpmpp_continuous_l\"))\n",
    "if not os.path.exists(ckpt_filename):\n",
    "  # Fail with an actionable message; a bare `assert` gives none and is skipped under `python -O`.\n",
    "  raise FileNotFoundError(\n",
    "      f\"Checkpoint not found at {ckpt_filename!r}; download it (see Preparation) \"\n",
    "      \"or set SCORE_SDE_CKPT to its location.\")\n",
    "config = configs.get_config()\n",
    "# Module-qualified so this class cannot be confused with `jdeis.VPSDE`, which is used further down.\n",
    "sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)\n",
    "# Smallest time value handed to the DEIS SDE wrapper later in the notebook.\n",
    "sampling_eps = 1e-3\n",
    "\n",
    "# The batch size chosen in the form is written into both the training and eval config sections.\n",
    "batch_size =   64#@param {\"type\":\"integer\"}\n",
    "# Integer share of the batch per local device (not referenced elsewhere in this notebook).\n",
    "local_batch_size = batch_size // jax.local_device_count()\n",
    "config.training.batch_size = batch_size\n",
    "config.eval.batch_size = batch_size\n",
    "\n",
    "random_seed = 0 #@param {\"type\": \"integer\"}\n",
    "# Split the seed key twice: `run_rng` feeds model initialization, the remaining `rng` goes into the state.\n",
    "# `model_rng` is not used anywhere else in this notebook.\n",
    "rng = jax.random.PRNGKey(random_seed)\n",
    "rng, run_rng = jax.random.split(rng)\n",
    "rng, model_rng = jax.random.split(rng)\n",
    "score_model, init_model_state, initial_params = mutils.init_model(run_rng, config)\n",
    "optimizer = losses_lib.get_optimizer(config).create(initial_params)\n",
    "\n",
    "# Template training state built from the freshly initialized parameters.\n",
    "state = mutils.State(step=0, optimizer=optimizer, lr=config.optim.lr,\n",
    "                      model_state=init_model_state,\n",
    "                      ema_rate=config.model.ema_rate,\n",
    "                      params_ema=initial_params,\n",
    "                      rng=rng)  # pytype: disable=wrong-keyword-args\n",
    "sigmas = mutils.get_sigmas(config)\n",
    "scaler = datasets.get_data_scaler(config)\n",
    "inverse_scaler = datasets.get_data_inverse_scaler(config)\n",
    "# NOTE(review): presumably replaces the template values in `state` with those stored in the\n",
    "# checkpoint - confirm in `utils.load_training_state`.\n",
    "state = utils.load_training_state(ckpt_filename, state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "G8ei2Xsfg6JQ",
    "tags": []
   },
   "outputs": [],
   "source": [
    "#@title Visualization code\n",
    "\n",
    "def image_grid(x):\n",
    "  \"\"\"Tile a batch of images into one square grid image.\n",
    "\n",
    "  Args:\n",
    "    x: Array that can be reshaped to `(-1, size, size, channels)`, where `size` and\n",
    "      `channels` are read from `config.data`.\n",
    "\n",
    "  Returns:\n",
    "    An array of shape `(w * size, w * size, channels)` with `w = int(sqrt(num_images))`.\n",
    "    When the number of images is not a perfect square, only the first `w * w` are used.\n",
    "  \"\"\"\n",
    "  size = config.data.image_size\n",
    "  channels = config.data.num_channels\n",
    "  img = x.reshape(-1, size, size, channels)\n",
    "  w = int(np.sqrt(img.shape[0]))\n",
    "  # Keep only as many images as fill a w x w grid; without this the reshape below raises\n",
    "  # for batch sizes that are not perfect squares.\n",
    "  img = img[:w * w]\n",
    "  # (row, col, h, w, c) -> (row, h, col, w, c), so collapsing the axes yields a tiled image.\n",
    "  img = img.reshape((w, w, size, size, channels)).transpose((0, 2, 1, 3, 4)).reshape((w * size, w * size, channels))\n",
    "  return img\n",
    "\n",
    "def show_samples(x):\n",
    "  \"\"\"Plot a batch of images as one tiled 8x8-inch figure with the axes hidden.\"\"\"\n",
    "  grid = image_grid(x)\n",
    "  _, ax = plt.subplots(figsize=(8, 8))\n",
    "  ax.axis('off')\n",
    "  ax.imshow(grid)\n",
    "  plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Fixed seed so that every sampler below starts from the same initial noise.\n",
    "rng = random.PRNGKey(888)\n",
    "num_samples = 100  # perfect square, so `image_grid` tiles all of them\n",
    "# Take the image dimensions from the config (as `image_grid` does) instead of hard-coding 32x32x3.\n",
    "noise_shape = (num_samples, config.data.image_size, config.data.image_size, config.data.num_channels)\n",
    "noise = random.normal(rng, noise_shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sampling with DEIS (Diffusion Exponential Integrator Sampler)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import jax_deis as jdeis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Build the DEIS-side SDE object from the same beta endpoints and time horizon as `sde`.\n",
    "alpha_fns = jdeis.get_linear_alpha_fns(sde.beta_0, sde.beta_1)\n",
    "t2alpha_fn, alpha2t_fn = alpha_fns\n",
    "vpsde = jdeis.VPSDE(t2alpha_fn, alpha2t_fn, sampling_eps, sde.T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Score function built from the EMA parameters with train=False and continuous=True.\n",
    "score_fn = mutils.get_score_fn(\n",
    "    sde, score_model, state.params_ema, state.model_state,\n",
    "    train=False, continuous=True)\n",
    "\n",
    "def eps_fn(x, scalar_t):\n",
    "  \"\"\"Convert the score model output into the form `-std(t) * score(x, t)`.\n",
    "\n",
    "  Args:\n",
    "    x: Batch of inputs; the leading axis is the batch axis.\n",
    "    scalar_t: A single time value applied to every element of the batch.\n",
    "\n",
    "  Returns:\n",
    "    The negated score, scaled per batch element by the second value returned\n",
    "    by `sde.marginal_prob` at that time.\n",
    "  \"\"\"\n",
    "  batch_t = scalar_t * jnp.ones(x.shape[0])\n",
    "  score_val = score_fn(x, batch_t)\n",
    "  # Only the second return value (the std) is used, so a zero array stands in for the data argument.\n",
    "  sigma_t = sde.marginal_prob(jnp.zeros_like(score_val), batch_t)[1]\n",
    "  return -batch_mul(score_val, sigma_t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Number of sampler steps shared by the samplers built from `num_step` below.\n",
    "num_step = 7\n",
    "t_ab_fn = jdeis.get_sampler(\n",
    "    vpsde, eps_fn, \"t\", 2, num_step,\n",
    "    method=\"t_ab\", ab_order=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the t_ab sampler on the shared initial noise.\n",
    "t_ab_7_img = t_ab_fn(noise)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply `inverse_scaler` to the t_ab samples and plot them as an image grid.\n",
    "show_samples(inverse_scaler(t_ab_7_img))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Same arguments as the t_ab sampler except for method=\"rho_ab\"; run on the same noise.\n",
    "rho_ab_7_fn = jdeis.get_sampler(\n",
    "    vpsde, eps_fn, \"t\", 2, num_step,\n",
    "    method=\"rho_ab\", ab_order=3)\n",
    "rho_ab_7_img = rho_ab_7_fn(noise)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Apply `inverse_scaler` to the rho_ab samples and plot them as an image grid.\n",
    "show_samples(inverse_scaler(rho_ab_7_img))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Runge-Kutta variant. Per the original note it costs 3 network evaluations per step,\n",
    "# so it actually uses num_step * 3 = 21 NFE rather than num_step.\n",
    "# Uses `num_step` (instead of a repeated literal 7) so all samplers share one step count.\n",
    "rho_rk_7_fn = jdeis.get_sampler(vpsde, eps_fn, \"t\", 2, num_step, method=\"rho_rk\", rk_method=\"3kutta\")\n",
    "rho_rk_7_img = rho_rk_7_fn(noise)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Apply `inverse_scaler` to the rho_rk samples and plot them as an image grid.\n",
    "show_samples(inverse_scaler(rho_rk_7_img))"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "Score SDE demo",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
