{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## VAE MNIST example: BO in a latent space"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, we use the MNIST dataset and some standard PyTorch examples to show a synthetic problem where the input to the objective function is a `28 x 28` image. The main idea is to train a [variational auto-encoder (VAE)](https://arxiv.org/abs/1312.6114) on the MNIST dataset and run Bayesian optimization in the latent space. We also refer readers to [this tutorial](http://krasserm.github.io/2018/04/07/latent-space-optimization/), which discusses [the method](https://arxiv.org/abs/1610.02415) of jointly training a VAE with a predictor (e.g., a classifier) and works through a similar MNIST example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets  # transforms\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "dtype = torch.double\n",
    "SMOKE_TEST = os.environ.get(\"SMOKE_TEST\", False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Problem setup\n",
    "\n",
    "Let's first define our synthetic expensive-to-evaluate objective function. We assume that it takes the following form:\n",
    "\n",
    "$$\\text{image} \\longrightarrow \\text{image classifier} \\longrightarrow \\text{scoring function} \n",
    "\\longrightarrow \\text{score}.$$\n",
    "\n",
    "The classifier is a convolutional neural network (CNN) trained using the architecture of the [PyTorch CNN example](https://github.com/pytorch/examples/tree/master/mnist)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 20, 5, 1)\n",
    "        self.conv2 = nn.Conv2d(20, 50, 5, 1)\n",
    "        self.fc1 = nn.Linear(4 * 4 * 50, 500)\n",
    "        self.fc2 = nn.Linear(500, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.conv1(x))\n",
    "        x = F.max_pool2d(x, 2, 2)\n",
    "        x = F.relu(self.conv2(x))\n",
    "        x = F.max_pool2d(x, 2, 2)\n",
    "        x = x.view(-1, 4 * 4 * 50)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = self.fc2(x)\n",
    "        return F.log_softmax(x, dim=1)"
   ]
  },
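  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `4 * 4 * 50` input size of `fc1` follows from the layer shapes above. As a quick sanity check, the sketch below traces the spatial size of a `28 x 28` input through the two `5 x 5` convolutions (stride 1, no padding) and the two `2 x 2` max-pooling steps. The helper names `conv_out` and `pool_out` are illustrative and are not PyTorch functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv_out(size, kernel, stride=1):\n",
    "    # output size of an unpadded convolution along one spatial dimension\n",
    "    return (size - kernel) // stride + 1\n",
    "\n",
    "\n",
    "def pool_out(size, kernel=2, stride=2):\n",
    "    # output size of an unpadded max-pooling step along one spatial dimension\n",
    "    return (size - kernel) // stride + 1\n",
    "\n",
    "\n",
    "size = 28\n",
    "size = pool_out(conv_out(size, kernel=5))  # 28 -> 24 -> 12\n",
    "size = pool_out(conv_out(size, kernel=5))  # 12 -> 8 -> 4\n",
    "flat_features = size * size * 50  # conv2 has 50 output channels\n",
    "print(size, flat_features)"
   ]
  },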
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_pretrained_dir() -> str:\n",
    "    \"\"\"\n",
    "    Get the directory of pretrained models, which are in the BoTorch repo.\n",
    "\n",
    "    Returns the location specified by PRETRAINED_LOCATION if that env\n",
    "    var is set; otherwise checks if we are in a likely part of the BoTorch\n",
    "    repo (botorch/botorch or botorch/tutorials) and returns the right path.\n",
    "    \"\"\"\n",
    "    if \"PRETRAINED_LOCATION\" in os.environ.keys():\n",
    "        return os.environ[\"PRETRAINED_LOCATION\"]\n",
    "    cwd = os.getcwd()\n",
    "    folder = os.path.basename(cwd)\n",
    "    # automated tests run from botorch folder\n",
    "    if folder == \"botorch\":  \n",
    "        return os.path.join(cwd, \"tutorials/pretrained_models/\")\n",
    "    # typical case (running from tutorial folder)\n",
    "    elif folder == \"tutorials\":\n",
    "        return os.path.join(cwd, \"pretrained_models/\")\n",
    "    raise FileNotFoundError(\"Could not figure out location of pretrained models.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnn_weights_path = os.path.join(get_pretrained_dir(), \"mnist_cnn.pt\")\n",
    "cnn_model = Net().to(dtype=dtype, device=device)\n",
    "cnn_state_dict = torch.load(cnn_weights_path, map_location=device)\n",
    "cnn_model.load_state_dict(cnn_state_dict);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our VAE model follows the [PyTorch VAE example](https://github.com/pytorch/examples/tree/master/vae), except that we use the same data transform as the CNN example for consistency. We then instantiate the model and again load pre-trained weights. To train these models yourself, see the PyTorch GitHub repository."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VAE(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(784, 400)\n",
    "        self.fc21 = nn.Linear(400, 20)\n",
    "        self.fc22 = nn.Linear(400, 20)\n",
    "        self.fc3 = nn.Linear(20, 400)\n",
    "        self.fc4 = nn.Linear(400, 784)\n",
    "\n",
    "    def encode(self, x):\n",
    "        h1 = F.relu(self.fc1(x))\n",
    "        return self.fc21(h1), self.fc22(h1)\n",
    "\n",
    "    def reparameterize(self, mu, logvar):\n",
    "        std = torch.exp(0.5 * logvar)\n",
    "        eps = torch.randn_like(std)\n",
    "        return mu + eps * std\n",
    "\n",
    "    def decode(self, z):\n",
    "        h3 = F.relu(self.fc3(z))\n",
    "        return torch.sigmoid(self.fc4(h3))\n",
    "\n",
    "    def forward(self, x):\n",
    "        mu, logvar = self.encode(x.view(-1, 784))\n",
    "        z = self.reparameterize(mu, logvar)\n",
    "        return self.decode(z), mu, logvar\n",
    "\n",
    "vae_weights_path = os.path.join(get_pretrained_dir(), \"mnist_vae.pt\")\n",
    "vae_model = VAE().to(dtype=dtype, device=device)\n",
    "vae_state_dict = torch.load(vae_weights_path, map_location=device)\n",
    "vae_model.load_state_dict(vae_state_dict);"
   ]
  },
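  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the sampling step in `reparameterize` can be written as follows, where the symbols are chosen here to mirror the variable names `mu` and `logvar`:\n",
    "\n",
    "$$z = \\mu + \\sigma \\odot \\epsilon, \\qquad \\sigma = \\exp\\left(\\tfrac{1}{2} \\log \\sigma^2\\right), \\qquad \\epsilon \\sim \\mathcal{N}(0, I).$$\n",
    "\n",
    "Here `logvar` holds $\\log \\sigma^2$ elementwise, so `torch.exp(0.5 * logvar)` recovers the standard deviation $\\sigma$."
   ]
  },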
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now define the scoring function that maps digits to scores. The function below prefers the digit '3'."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def score(y):\n",
    "    \"\"\"Returns a 'score' for each digit from 0 to 9. It is modeled as a squared exponential\n",
    "    centered at the digit '3'.\n",
    "    \"\"\"\n",
    "    return torch.exp(-2 * (y - 3) ** 2)"
   ]
  },
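  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for how sharply the score is peaked around '3', the following sketch evaluates the same expression with plain Python floats. `digit_score` is a hypothetical stand-alone helper introduced only for this illustration; it is not used elsewhere in the tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def digit_score(k):\n",
    "    # same functional form as `score`, evaluated for a single Python number\n",
    "    return math.exp(-2 * (k - 3) ** 2)\n",
    "\n",
    "\n",
    "# the score is 1 at k = 3 and decays quickly for neighboring digits\n",
    "for k in range(10):\n",
    "    print(k, f'{digit_score(k):.2e}')"
   ]
  },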
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Given the scoring function, we can now write our overall objective, which, as discussed above, starts with an image and outputs a score. We define the objective as the expected score under the class probabilities predicted by the classifier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def score_image(x):\n",
    "    \"\"\"The input x is an image and an expected score \n",
    "    based on the CNN classifier and the scoring \n",
    "    function is returned.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        probs = torch.exp(cnn_model(x))  # b x 10\n",
    "        scores = score(\n",
    "            torch.arange(10, device=device, dtype=dtype)\n",
    "        ).expand(probs.shape)\n",
    "    return (probs * scores).sum(dim=1)"
   ]
  },
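  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written out, the objective computed by `score_image` is an expectation of the score over the classifier's predicted class probabilities. The symbols $p_k$ and $s$ are notation introduced here for illustration:\n",
    "\n",
    "$$f(x) = \\sum_{k=0}^{9} p_k(x) \\, s(k), \\qquad s(k) = \\exp\\left(-2 (k - 3)^2\\right),$$\n",
    "\n",
    "where $p_k(x)$ is the probability that the CNN assigns to digit $k$ for image $x$. Since $s(3) = 1$ and $s(k) \\le e^{-2}$ for every other digit, $f(x)$ is close to 1 only when the classifier is confident that the image shows a '3'."
   ]
  },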
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we define a helper function `decode` that takes a batch of latent vectors and maps them through the VAE decoder to `1 x 28 x 28` images. No reparameterization is needed at this stage: batched Bayesian optimization searches directly over the 20-dimensional latent space."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode(train_x):\n",
    "    with torch.no_grad():\n",
    "        decoded = vae_model.decode(train_x)\n",
    "    return decoded.view(train_x.shape[0], 1, 28, 28)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Model initialization and initial random batch\n",
    "\n",
    "We use a `SingleTaskGP` to model the score of an image generated by a latent representation. The model is initialized with points drawn from $[-6, 6]^{20}$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
   ],
   "source": [
    "from botorch import fit_gpytorch_mll\n",
    "from botorch.models import SingleTaskGP\n",
    "from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood\n",
    "from botorch.utils.transforms import normalize, unnormalize\n",
    "from botorch.models.transforms import Standardize, Normalize\n",
    "\n",
    "d = 20\n",
    "bounds = torch.tensor([[-6.0] * d, [6.0] * d], device=device, dtype=dtype)\n",
    "\n",
    "\n",
    "def gen_initial_data(n=5):\n",
    "    # generate training data\n",
    "    train_x = unnormalize(\n",
    "        torch.rand(n, d, device=device, dtype=dtype), \n",
    "        bounds=bounds\n",
    "    )\n",
    "    train_obj = score_image(decode(train_x)).unsqueeze(-1)\n",
    "    best_observed_value = train_obj.max().item()\n",
    "    return train_x, train_obj, best_observed_value\n",
    "\n",
    "\n",
    "def get_fitted_model(train_x, train_obj, state_dict=None):\n",
    "    # initialize and fit model\n",
    "    model = SingleTaskGP(\n",
    "        train_X=normalize(train_x, bounds), \n",
    "        train_Y=train_obj,\n",
    "        outcome_transform=Standardize(m=1)\n",
    "    )\n",
    "    if state_dict is not None:\n",
    "        model.load_state_dict(state_dict)\n",
    "    mll = ExactMarginalLogLikelihood(model.likelihood, model)\n",
    "    mll.to(train_x)\n",
    "    fit_gpytorch_mll(mll)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Define a helper function that performs the essential BO step\n",
    "The helper function below takes an acquisition function as an argument, optimizes it, and returns the batch $\\{x_1, x_2, \\ldots x_q\\}$ along with the observed function values. For this example, we'll use a small batch of $q=3$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from botorch.optim import optimize_acqf\n",
    "\n",
    "\n",
    "BATCH_SIZE = 3 if not SMOKE_TEST else 2\n",
    "NUM_RESTARTS = 10 if not SMOKE_TEST else 2\n",
    "RAW_SAMPLES = 256 if not SMOKE_TEST else 4\n",
    "\n",
    "\n",
    "def optimize_acqf_and_get_observation(acq_func):\n",
    "    \"\"\"Optimizes the acquisition function and returns the new\n",
    "    candidates together with their observed objective values.\"\"\"\n",
    "\n",
    "    # optimize\n",
    "    candidates, _ = optimize_acqf(\n",
    "        acq_function=acq_func,\n",
    "        bounds=torch.stack(\n",
    "            [\n",
    "                torch.zeros(d, dtype=dtype, device=device),\n",
    "                torch.ones(d, dtype=dtype, device=device),\n",
    "            ]\n",
    "        ),\n",
    "        q=BATCH_SIZE,\n",
    "        num_restarts=NUM_RESTARTS,\n",
    "        raw_samples=RAW_SAMPLES,\n",
    "    )\n",
    "\n",
    "    # observe new values\n",
    "    new_x = unnormalize(candidates.detach(), bounds=bounds)\n",
    "    new_obj = score_image(decode(new_x)).unsqueeze(-1)\n",
    "    return new_x, new_obj"
   ]
  },
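  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The GP is fit on inputs rescaled to the unit cube, and the acquisition function is optimized over $[0, 1]^{20}$, so candidates have to be mapped back to $[-6, 6]^{20}$ before decoding. The sketch below shows this affine mapping for a single coordinate in plain Python. `to_unit` and `from_unit` are hypothetical names that mirror, per dimension, what `normalize` and `unnormalize` do with the `bounds` tensor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "LOWER, UPPER = -6.0, 6.0\n",
    "\n",
    "\n",
    "def to_unit(x, lower=LOWER, upper=UPPER):\n",
    "    # map a coordinate from [lower, upper] to [0, 1]\n",
    "    return (x - lower) / (upper - lower)\n",
    "\n",
    "\n",
    "def from_unit(u, lower=LOWER, upper=UPPER):\n",
    "    # map a coordinate from [0, 1] back to [lower, upper]\n",
    "    return lower + u * (upper - lower)\n",
    "\n",
    "\n",
    "print(to_unit(-6.0), to_unit(0.0), to_unit(6.0))\n",
    "print(from_unit(0.25))"
   ]
  },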
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Perform Bayesian Optimization loop with qEI\n",
    "The Bayesian optimization \"loop\" for a batch size of $q$ simply iterates the following steps: (1) given a surrogate model, choose a batch of points $\\{x_1, x_2, \\ldots x_q\\}$, (2) observe $f(x)$ for each $x$ in the batch, and (3) update the surrogate model. We run `N_BATCH=25` iterations. The acquisition function is approximated by Monte Carlo sampling; since no sampler is passed to `qExpectedImprovement`, BoTorch's default sampler is used. We also initialize the model with 5 randomly drawn points."
   ]
  },
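  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, a standard way to write the batch expected improvement and its Monte Carlo estimate is given below. The symbols $f^*$ (best observed value), $N$ (number of samples), and $\\xi_{ij}$ (the value of the $i$-th joint posterior sample at candidate $x_j$) are notation chosen here for illustration:\n",
    "\n",
    "$$\\text{qEI}(x_1, \\ldots, x_q) = \\mathbb{E}\\left[\\max\\left(\\max_{j=1,\\ldots,q} f(x_j) - f^*, \\, 0\\right)\\right] \\approx \\frac{1}{N} \\sum_{i=1}^{N} \\max\\left(\\max_{j=1,\\ldots,q} \\xi_{ij} - f^*, \\, 0\\right).$$"
   ]
  },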
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
   ],
   "source": [
    "from botorch import fit_gpytorch_mll\n",
    "from botorch.acquisition.monte_carlo import qExpectedImprovement\n",
    "from botorch.sampling.normal import SobolQMCNormalSampler\n",
    "\n",
    "seed = 1\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "N_BATCH = 25 if not SMOKE_TEST else 3\n",
    "best_observed = []\n",
    "\n",
    "# call helper function to initialize model\n",
    "train_x, train_obj, best_value = gen_initial_data(n=5)\n",
    "best_observed.append(best_value)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are now ready to run the BO loop (this may take a few minutes, depending on your machine)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Running BO ........................."
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "\n",
    "print(\"\\nRunning BO \", end=\"\")\n",
    "\n",
    "state_dict = None\n",
    "# run N_BATCH rounds of BayesOpt after the initial random batch\n",
    "for iteration in range(N_BATCH):\n",
    "\n",
    "    # fit the model\n",
    "    model = get_fitted_model(\n",
    "        train_x=train_x,\n",
    "        train_obj=train_obj,\n",
    "        state_dict=state_dict,\n",
    "    )\n",
    "\n",
    "    # define the qEI acquisition function\n",
    "    qEI = qExpectedImprovement(\n",
    "        model=model, best_f=train_obj.max()\n",
    "    )\n",
    "\n",
    "    # optimize and get new observation\n",
    "    new_x, new_obj = optimize_acqf_and_get_observation(qEI)\n",
    "\n",
    "    # update training points\n",
    "    train_x = torch.cat((train_x, new_x))\n",
    "    train_obj = torch.cat((train_obj, new_obj))\n",
    "\n",
    "    # update progress\n",
    "    best_value = train_obj.max().item()\n",
    "    best_observed.append(best_value)\n",
    "\n",
    "    state_dict = model.state_dict()\n",
    "\n",
    "    print(\".\", end=\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "EI recommends the best point observed so far. We can visualize what the images corresponding to recommended points *would have* been if the BO process ended at various times. Here, we show the progress of the algorithm by examining the images at 0%, 10%, 25%, 50%, 75%, and 100% completion. The first image is the best image found through the initial random batch."
   ]
  },
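  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The number of evaluated points at each completion level is the size of the initial batch plus the corresponding fraction of the BO evaluations. A small sketch of this bookkeeping follows, assuming the values used in this notebook outside of smoke-test mode (5 initial points, 25 iterations, batch size 3). The names `n_init`, `n_batch`, and `batch_size` are local to this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_init, n_batch, batch_size = 5, 25, 3\n",
    "percentages = [0, 10, 25, 50, 75, 100]\n",
    "\n",
    "# number of points evaluated once the given percentage of the BO budget is spent\n",
    "n_points = [n_init + int(n_batch * batch_size * p / 100) for p in percentages]\n",
    "print(n_points)"
   ]
  },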
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAABGwAAADJCAYAAAB2bqQSAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAZD0lEQVR4nO3dXWzddf3A8U87WDf30DEeWipb3IUJJpih6zYLiCgNC0bChAuIF4IPPNmRwC6MU4GEkEzBCAGHJCIQL2CGC0bEBGMGG0L2QCcEELNgQmQKLZK4rgz2YPv7XyyUfz1nc6c9D99zvq9Xci727enp91veHvTj2e/XVhRFEQAAAAAko73RGwAAAABgMgMbAAAAgMQY2AAAAAAkxsAGAAAAIDEGNgAAAACJMbABAAAASIyBDQAAAEBiDGwAAAAAEmNgAwAAAJAYAxsAAACAxJxQqxfesGFD3HXXXTE0NBRLly6N++67L1asWPE/v298fDzefvvtmDdvXrS1tdVqexBFUcTo6Gj09PREe/vUZpc6J3U6Jwc6Jwc6Jwc6JwcVdV7UwMaNG4uZM2cWDz30UPGXv/yluOaaa4oFCxYUw8PD//N79+zZU0SEh0fdHnv27NG5R8s/dO6Rw0PnHjk8dO6Rw0PnHjk8jqfztqIoiqiylStXxvLly+MXv/hFRByZVi5atChuvPHG+MEPfnDM7x0ZGYkFCxbEZz/72ZgxY0a1twYTxsbG4tVXX429e/dGZ2dnxd+vc5qBzsmBzsmBzsmBzslBJZ1X/a9EHTp0KHbt2hXr1q2bWGtvb4/+/v7Ytm1byfMPHjwYBw8enPjz6OhoRETMmDHDf1Coi6l85FHnNBudkwOdkwOdkwOdk4Pj6bzqFx1+7733YmxsLLq6uiatd3V1xdDQUMnz169fH52dnROPRYsWVXtLUHU6Jwc6Jwc6Jwc6Jwc6pxU1/C5R69ati5GRkYnHnj17Gr0lqDqdkwOdkwOdkwOdkwOd0wyq/leiTjnllJgxY0YMDw9PWh8eHo7u7u6S53d0dERHR0e1twE1pXNyoHNyoHNyoHNyoHNaUdU/YTNz5sxYtmxZbN68eWJtfHw8Nm/eHH19fdX+cdAQOicHOicHOicHOicHOqcVVf0TNhERa9eujauuuip6e3tjxYoVcc8998T+/fvjW9/6Vi1+HDSEzsmBzsmBzsmBzsmBzmk1NRnYXHHFFfGvf/0rbr311hgaGoqzzz47nn766ZILQEEz0zk50Dk50Dk50Dk50Dmtpq0oiqLRm/j/9u3bF52dnXH22We7nRo1NTY2Fi+//HKMjIzE/Pnz6/qzdU696Jwc6Jwc6Jwc6JwcVNJ5w+8SBQAAAMBkBjYAAAAAiTGwAQAAAEiMgQ0AAABAYgxsAAAAABJjYAMAAACQGAMbAAAAgMQY2AAAAAAkxsAGAAAAIDEGNgAAAACJMbABAAAASIyBDQAAAEBiTmj0BqiO6667ruz6NddcU5Of99BDD5Vdv//++2vy82gN27ZtK1l77733yj73xRdfLFm7/fbbq74nqLYtW7aUrM2dO7f+Gymjt7e30VugRZTrPCKN1nVOtXg/Jwc6T5tP2AAAAAAkxsAGAAAAIDEGNgAAAACJMbABAAAASIyBDQAAAEBi3CUqYYODg43ewlF9+9vfLrvuLlH5+e53v1uydv311x/3959++ull10866aQp7wmm6le/+lXZ9c997nN13kltlPv3ijsw5Klc663ceYTWc+P9nBzovPX5hA0AAABAYgxsAAAAABJjYAMAAACQGAMbAAAAgMS46HAiUr7AcCVcGCo/Dz74YMnaY489Vva5+/fvr/V24Li9+OKLJWttbW0N2EljuUBrayvXeYTWP6Lz1uD9/Ajv561N50fk1rlP2AAAAAAkxsAGAAAAIDEGNgAAAACJMbABAAAASIyBDQAAAEBi3CWqznbu3NnoLUDNuRsUzSDHOytU4vnn
ny9ZO++88xqwE6ZD58dWrvMIrTcbnR+b9/PWoPNja9XOfcIGAAAAIDEGNgAAAACJMbABAAAASIyBDQAAAEBiXHS4ztrbazMj6+3trej527dvL1k74QQ5APko9745ODhYt59VS9U4x6xZs6qwExrtaO1p/Qidtwbv58em89ag82Nr1c59wgYAAAAgMQY2AAAAAIkxsAEAAABIjIENAAAAQGIMbAAAAAAS47ZACdu1a1fZ9euuu27ar/3vf/+7ZO3UU0+d9usCNLN63xUBGkXrtDqNkwOdtz6fsAEAAABIjIENAAAAQGIMbAAAAAASU/HA5rnnnotLLrkkenp6oq2tLTZt2jTp60VRxK233hqnn356zJ49O/r7++ONN96o1n6hLnRODnRODnRODnRODnROjioe2Ozfvz+WLl0aGzZsKPv1O++8M+6999544IEHYseOHTFnzpxYtWpVHDhwYNqbbQUffPBB2Udvb2/J47rrriv7qIZTTz215MHHdE4OdN4a5syZU/LgYzpvDeU61/rHdN4aNH5sOm8NOq9MxXeJuvjii+Piiy8u+7WiKOKee+6JH//4x3HppZdGRMRvfvOb6Orqik2bNsWVV145vd1CneicHOicHOicHOicHOicHFX1GjZvvvlmDA0NRX9//8RaZ2dnrFy5MrZt21b2ew4ePBj79u2b9ICU6Zwc6Jwc6Jwc6Jwc6JxWVdWBzdDQUEREdHV1TVrv6uqa+Np/W79+fXR2dk48Fi1aVM0tQdXpnBzonBzonBzonBzonFbV8LtErVu3LkZGRiYee/bsafSWoOp0Tg50Tg50Tg50Tg50TjOo6sCmu7s7IiKGh4cnrQ8PD0987b91dHTE/PnzJz0gZTonBzonBzonBzonBzqnVVV80eFjWbJkSXR3d8fmzZvj7LPPjoiIffv2xY4dO+KGG26o5o9qWueff36jtxARRy7M9d/a2tqm/bq9vb3Tfo3U6Tw9J5xQ/q1s9uzZJWujo6O13k5L0Hl6vvGNb5RdX7t2bU1+nvdznTdKudZ1PnU6T4/38+rTeXp0Xh0VD2zef//9+Nvf/jbx5zfffDNefvnlWLhwYSxevDhuuummuOOOO+LTn/50LFmyJG655Zbo6emJ1atXV3PfUFM6Jwc6Jwc6Jwc6Jwc6J0cVD2wGBwfjy1/+8sSfP5qQXXXVVfHII4/E97///di/f39ce+21sXfv3jjvvPPi6aefjlmzZlVv11BjOicHOicHOicHOicHOidHFQ9sLrjggrJ/neYjbW1tcfvtt8ftt98+rY1BI+mcHOicHOicHOicHOicHDX8LlEAAAAATFbViw7TPJYvX16ytmPHjrLPnTFjRsnaOeecU/U9wf8yODhY0fNXrFhRo51AdVTadDnlOj/aReTLvc8/9dRT094DHEutOo8o3/rR/vuM1qkl7+fkQOf15xM2AAAAAIkxsAEAAABIjIENAAAAQGIMbAAAAAASY2ADAAAAkBh3iaqCSq6W3dvbW8OdTM/KlSsbvQWYUI2r0I+Pj1dhJ1CZarRbiel2/vnPf75KOyEnzdZ5hNapXLN1rnGmQudp8wkbAAAAgMQY2AAAAAAkxsAGAAAAIDEGNgAAAACJcdHhCk33okxH+/6NGzeWrP3sZz+b1s8CoHZSuUhfR0dHydoLL7xw3K/b09Mz5T2RhxRaL9d5hNapjhQaj/B+Tm3pvDn5hA0AAABAYgxsAAAAABJjYAMAAACQGAMbAAAAgMQY2AAAAAAkxl2ijqLeV9G+8sorj2stIqIoipK15cuXV31PUA/t7dOfG/f29lZhJ3B09f53Qjnvv/9+2fVK7qwAx5JC5xHlW9c51ZJC597PqTWdtw6fsAEAAABIjIENAAAAQGIMbAAAAAASY2ADAAAAkBgDGwAAAIDEuEtURGzZsqXRW6jIHXfc0egtHPXOPjt37ixZe+CBB8o+98EHH6zqnkhfW1tbydr4+HjZ57rzE43Q0dFRk9cdGhoq
u/61r33tuF/jaHtztwWmop6tV9J5RPm96ZxKeT8nBzpvfT5hAwAAAJAYAxsAAACAxBjYAAAAACTGwAYAAAAgMS46HBFz585t9BbKev3118uuP/nkk3XeSalyFxc+muuvv77s+tVXX12ydt555011SyTkaBfyLnfxsb6+vhrvBo7fwYMHy67/6U9/Kln70Y9+VPa5H3zwQVX39JGj7a1WPvnJT5as/fOf/6zrHqidcj2V6zyifOu16jyivq2X6zxC663A+/nHvJ+3Lp1/rFU79wkbAAAAgMQY2AAAAAAkxsAGAAAAIDEGNgAAAACJMbABAAAASIy7RCVi//79JWvf/OY3G7ATqEy5u6zNmjWr7HO/8IUv1Ho7UBM333xzo7dQd+XuSNjb29uAnVAvOv+Y1luXzo/QeGvT+RGt0LlP2AAAAAAkxsAGAAAAIDEGNgAAAACJMbABAAAASIyLDkfEBRdcULK2ZcuWuu7hS1/6Ul1/XiUGBwdr8rpHuzAtzaUoipI1FxcGAACYHp+wAQAAAEiMgQ0AAABAYgxsAAAAABJjYAMAAACQmIoGNuvXr4/ly5fHvHnz4rTTTovVq1fH7t27Jz3nwIEDMTAwECeffHLMnTs3Lr/88hgeHq7qpqGWdE4OdE4OdE4OdE4OdE6uKrpL1NatW2NgYCCWL18e//nPf+KHP/xhXHTRRfH666/HnDlzIiLi5ptvjt///vfx+OOPR2dnZ6xZsyYuu+yyeOGFF2pygGp4//336/azvvrVr9btZ1WqVneDajat2vl0VdLH6tWry67/4x//qNJuqm/79u0la618tyudV6aS/g8cOFCydvDgwbLPfeWVV0rWvvjFLx7/xmro0KFDjd7CtOm8MtPtPKJ86+U6j0ijdZ3r/Fi8n6dD55XReeuoaGDz9NNPT/rzI488Eqeddlrs2rUrzj///BgZGYlf//rX8eijj8ZXvvKViIh4+OGH4zOf+Uxs3769pf/HD61D5+RA5+RA5+RA5+RA5+RqWtewGRkZiYiIhQsXRkTErl274vDhw9Hf3z/xnDPPPDMWL14c27ZtK/saBw8ejH379k16QEp0Tg50Tg50Tg50Tg50Ti6mPLAZHx+Pm266Kc4999w466yzIiJiaGgoZs6cGQsWLJj03K6urhgaGir7OuvXr4/Ozs6Jx6JFi6a6Jag6nZMDnZMDnZMDnZMDnZOTKQ9sBgYG4rXXXouNGzdOawPr1q2LkZGRiceePXum9XpQTTonBzonBzonBzonBzonJxVdw+Yja9asiaeeeiqee+65OOOMMybWu7u749ChQ7F3795J083h4eHo7u4u+1odHR3R0dExlW3UVG9vb9n19vbSGdfOnTvLPveKK64oWXv33Xent7Eq+elPf9roLRz1d5yKHDqvlU2bNpVdT/mfea5/t1nn1Tdr1qzjWotI50J95ZxzzjmN3kLV6Lz6jtZ0uXWd14fOq8/7eXp0Xn06T1tFn7ApiiLWrFkTTzzxRDzzzDOxZMmSSV9ftmxZnHjiibF58+aJtd27d8dbb70VfX191dkx1JjOyYHOyYHOyYHOyYHOyVVFn7AZGBiIRx99NJ588smYN2/exN8H7OzsjNmzZ0dnZ2d85zvfibVr18bChQtj/vz5ceONN0ZfX1+2/+81zUfn5EDn5EDn5EDn5EDn5Kqigc0vf/nLiIi44IILJq0//PDDcfXVV0dExN133x3t7e1x+eWXx8GDB2PVqlVx//33V2WzUA86Jwc6Jwc6Jwc6Jwc6J1cVDWyKovifz5k1a1Zs2LAhNmzYMOVNQSPpnBzonBzonBzonBzonFxN+S5RAAAAANTGlO4SlbPx8fGStZTvfFPurlYRERdeeGGddwIRg4ODJWvl/jMVEbFixYpabwcAACBZPmEDAAAAkBgDGwAAAIDEGNgAAAAAJMbABgAAACAxLjrc4nbu3NnoLSR9UWaOXyX/HP/whz+UXT/55JNL1u6+++4p7wnqZWxs
rGRtxowZDdjJ9PT19ZWsHT58uAE7IUXlOo9ovtbLdR6hdY7wfk4OdN46fMIGAAAAIDEGNgAAAACJMbABAAAASIyBDQAAAEBiDGwAAAAAEuMuUS3uaHf2KXeV8B07dhz3646Pj5ddX7FixXG/Bq1r1apVjd4CVNXKlStr8rqDg4PTfg134qNadE4OdE4OdN46fMIGAAAAIDEGNgAAAACJMbABAAAASIyBDQAAAEBiXHQ4U2NjYyVrLgAFUF/ed8mBzsmBzsmBzuvPJ2wAAAAAEmNgAwAAAJAYAxsAAACAxBjYAAAAACTGwAYAAAAgMQY2AAAAAIkxsAEAAABIjIENAAAAQGIMbAAAAAASY2ADAAAAkBgDGwAAAIDEGNgAAAAAJMbABgAAACAxBjYAAAAAiTGwAQAAAEjMCY3ewH8riiIiIsbGxhq8E1rdR4191Fw96Zx60Tk50Dk50Dk50Dk5qKTz5AY2o6OjERHx6quvNngn5GJ0dDQ6Ozvr/jMjdE796Jwc6Jwc6Jwc6JwcHE/nbUUjxpfHMD4+Hm+//XbMmzcvRkdHY9GiRbFnz56YP39+o7dWVfv27WvZs0U0x/mKoojR0dHo6emJ9vb6/u1AnbeGZjifzmuvGTqYjmY4n85rrxk6mI5mOJ/Oa68ZOpiOZjifzmuvGTqYjmY4XyWdJ/cJm/b29jjjjDMiIqKtrS0iIubPn5/sL3u6WvlsEemfr96T+4/ovLWkfj6d10crny0i/fPpvD5a+WwR6Z9P5/XRymeLSP98Oq+PVj5bRPrnO97OXXQYAAAAIDEGNgAAAACJSXpg09HREbfddlt0dHQ0eitV18pni2j981VTK/+uWvlsEa1/vmpq5d9VK58tovXPV02t/Ltq5bNFtP75qqmVf1etfLaI1j9fNbXy76qVzxbReudL7qLDAAAAALlL+hM2AAAAADkysAEAAABIjIENAAAAQGIMbAAAAAASk/TAZsOGDfGpT30qZs2aFStXroydO3c2eksVe+655+KSSy6Jnp6eaGtri02bNk36elEUceutt8bpp58es2fPjv7+/njjjTcas9kKrV+/PpYvXx7z5s2L0047LVavXh27d++e9JwDBw7EwMBAnHzyyTF37ty4/PLLY3h4uEE7TpPO06bz6tB52nReHTpPm86rQ+dp03l16DxtOXWe7MDmt7/9baxduzZuu+22+POf/xxLly6NVatWxbvvvtvorVVk//79sXTp0tiwYUPZr995551x7733xgMPPBA7duyIOXPmxKpVq+LAgQN13mnltm7dGgMDA7F9+/b44x//GIcPH46LLroo9u/fP/Gcm2++OX73u9/F448/Hlu3bo233347LrvssgbuOi0613kOdK7zHOhc5znQuc5zoHOdJ6VI1IoVK4qBgYGJP4+NjRU9PT3F+vXrG7ir6YmI4oknnpj48/j4eNHd3V3cddddE2t79+4tOjo6iscee6wBO5yed999t4iIYuvWrUVRHDnLiSeeWDz++OMTz/nrX/9aRESxbdu2Rm0zKTrXeQ50rvMc6FznOdC5znOgc52nJMlP2Bw6dCh27doV/f39E2vt7e3R398f27Zta+DOquvNN9+MoaGhSefs7OyMlStXNuU5R0ZGIiJi4cKFERGxa9euOHz48KTznXnmmbF48eKmPF+16VznOdC5znOgc53nQOc6z4HOdZ6aJAc27733XoyNjUVXV9ek9a6urhgaGmrQrqrvo7O0wjnHx8fjpptuinPPPTfOOuusiDhyvpkzZ8aCBQsmPbcZz1cLOm++c+q8cjpvvnPqvHI6b75z6rxyOm++c+q8cjpvvnO2eucnNHoDtIaBgYF47bXX4vnnn2/0VqBmdE4OdE4OdE4OdE4OWr3zJD9hc8opp8SMGTNKruI8PDwc3d3dDdpV9X10lmY/55o1a+Kpp56KZ599Ns4444yJ9e7u7jh06FDs3bt30vOb7Xy1ovPmOqfOp0bnzXVOnU+Nzpvr
nDqfGp031zl1PjU6b65z5tB5kgObmTNnxrJly2Lz5s0Ta+Pj47F58+bo6+tr4M6qa8mSJdHd3T3pnPv27YsdO3Y0xTmLoog1a9bEE088Ec8880wsWbJk0teXLVsWJ5544qTz7d69O956662mOF+t6VznOdC5znOgc53nQOc6z4HOdZ6cRl7x+Fg2btxYdHR0FI888kjx+uuvF9dee22xYMGCYmhoqNFbq8jo6Gjx0ksvFS+99FIREcXPf/7z4qWXXir+/ve/F0VRFD/5yU+KBQsWFE8++WTxyiuvFJdeemmxZMmS4sMPP2zwzv+3G264oejs7Cy2bNlSvPPOOxOPDz74YOI5119/fbF48eLimWeeKQYHB4u+vr6ir6+vgbtOi851ngOd6zwHOtd5DnSu8xzoXOcpSXZgUxRFcd999xWLFy8uZs6cWaxYsaLYvn17o7dUsWeffbaIiJLHVVddVRTFkVuq3XLLLUVXV1fR0dFRXHjhhcXu3bsbu+njVO5cEVE8/PDDE8/58MMPi+9973vFSSedVHziE58ovv71rxfvvPNO4zadIJ2nTefVofO06bw6dJ42nVeHztOm8+rQedpy6rytKIpi6p/PAQAAAKDakryGDQAAAEDODGwAAAAAEmNgAwAAAJAYAxsAAACAxBjYAAAAACTGwAYAAAAgMQY2AAAAAIkxsAEAAABIjIENAAAAQGIMbAAAAAASY2ADAAAAkBgDGwAAAIDE/B8p7caKHn6ulwAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 1400x1400 with 6 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "fig, axes = plt.subplots(1, 6, figsize=(14, 14))\n",
    "percentages = np.array([0, 10, 25, 50, 75, 100], dtype=np.float32)\n",
    "# number of evaluated points at each stage, including the 5 initial points\n",
    "inds = (N_BATCH * BATCH_SIZE * percentages / 100 + 5).astype(int)\n",
    "\n",
    "for i, ax in enumerate(axes.flat):\n",
    "    b = torch.argmax(score_image(decode(train_x[: inds[i], :])), dim=0)\n",
    "    img = decode(train_x[b].view(1, -1)).squeeze().cpu()\n",
    "    ax.imshow(img, alpha=0.8, cmap=\"gray\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
