{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## VAE MNIST example: BO in a latent space"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, we use the MNIST dataset and some standard PyTorch examples to show a synthetic problem where the input to the objective function is a `28 x 28` image. The main idea is to train a [variational auto-encoder (VAE)](https://arxiv.org/abs/1312.6114) on the MNIST dataset and run Bayesian Optimization in the latent space. We also refer readers to [this tutorial](http://krasserm.github.io/2018/04/07/latent-space-optimization/), which discusses [the method](https://arxiv.org/abs/1610.02415) of jointly training a VAE with a predictor (e.g., classifier), and shows a similar tutorial for the MNIST setting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-08T21:35:49.795365900Z",
     "start_time": "2025-03-08T21:35:49.781855800Z"
    }
   },
   "outputs": [],
   "source": [
    "# Install dependencies if we are running in colab\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "    %pip install botorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-08T21:35:53.307983500Z",
     "start_time": "2025-03-08T21:35:49.793365500Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets  # transforms\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "dtype = torch.double\n",
    "SMOKE_TEST = os.environ.get(\"SMOKE_TEST\", False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Problem setup\n",
    "\n",
    "Let's first define our synthetic expensive-to-evaluate objective function. We assume that it takes the following form:\n",
    "\n",
    "$$\\text{image} \\longrightarrow \\text{image classifier} \\longrightarrow \\text{scoring function} \n",
    "\\longrightarrow \\text{score}.$$\n",
    "\n",
    "The classifier is a convolutional neural network (CNN) trained using the architecture of the [PyTorch CNN example](https://github.com/pytorch/examples/tree/master/mnist)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-08T21:35:53.323996200Z",
     "start_time": "2025-03-08T21:35:53.312488100Z"
    }
   },
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 20, 5, 1)\n",
    "        self.conv2 = nn.Conv2d(20, 50, 5, 1)\n",
    "        self.fc1 = nn.Linear(4 * 4 * 50, 500)\n",
    "        self.fc2 = nn.Linear(500, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.conv1(x))\n",
    "        x = F.max_pool2d(x, 2, 2)\n",
    "        x = F.relu(self.conv2(x))\n",
    "        x = F.max_pool2d(x, 2, 2)\n",
    "        x = x.view(-1, 4 * 4 * 50)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = self.fc2(x)\n",
    "        return F.log_softmax(x, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-08T21:35:53.347013200Z",
     "start_time": "2025-03-08T21:35:53.325995500Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_pretrained_dir() -> str:\n",
    "    \"\"\"\n",
    "    Get the directory of pretrained models, which are in the BoTorch repo.\n",
    "\n",
    "    Returns the location specified by PRETRAINED_LOCATION if that env\n",
    "    var is set; otherwise checks if we are in a likely part of the BoTorch\n",
    "    repo (botorch/botorch or botorch/tutorials) and returns the right path.\n",
    "    \"\"\"\n",
    "    if \"PRETRAINED_LOCATION\" in os.environ.keys():\n",
    "        return os.environ[\"PRETRAINED_LOCATION\"]\n",
    "    cwd = os.getcwd()\n",
    "    folder = os.path.basename(cwd)\n",
    "    # automated tests run from botorch folder\n",
    "    if folder == \"botorch\":\n",
    "        return os.path.join(cwd, \"tutorials/pretrained_models/\")\n",
    "    # typical case (running from tutorial folder)\n",
    "    elif folder == \"tutorials\":\n",
    "        return os.path.join(cwd, \"pretrained_models/\")\n",
    "    raise FileNotFoundError(\"Could not figure out location of pretrained models.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-08T21:35:53.542929600Z",
     "start_time": "2025-03-08T21:35:53.342014900Z"
    }
   },
   "outputs": [
    {
     "ename": "FileNotFoundError",
     "evalue": "Could not figure out location of pretrained models.",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mFileNotFoundError\u001B[0m                         Traceback (most recent call last)",
      "Cell \u001B[1;32mIn[5], line 1\u001B[0m\n\u001B[1;32m----> 1\u001B[0m cnn_weights_path \u001B[38;5;241m=\u001B[39m os\u001B[38;5;241m.\u001B[39mpath\u001B[38;5;241m.\u001B[39mjoin(\u001B[43mget_pretrained_dir\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m, \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mmnist_cnn.pt\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m      2\u001B[0m cnn_model \u001B[38;5;241m=\u001B[39m Net()\u001B[38;5;241m.\u001B[39mto(dtype\u001B[38;5;241m=\u001B[39mdtype, device\u001B[38;5;241m=\u001B[39mdevice)\n\u001B[0;32m      3\u001B[0m cnn_state_dict \u001B[38;5;241m=\u001B[39m torch\u001B[38;5;241m.\u001B[39mload(cnn_weights_path, map_location\u001B[38;5;241m=\u001B[39mdevice, weights_only\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mTrue\u001B[39;00m)\n",
      "Cell \u001B[1;32mIn[4], line 19\u001B[0m, in \u001B[0;36mget_pretrained_dir\u001B[1;34m()\u001B[0m\n\u001B[0;32m     17\u001B[0m \u001B[38;5;28;01melif\u001B[39;00m folder \u001B[38;5;241m==\u001B[39m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mtutorials\u001B[39m\u001B[38;5;124m\"\u001B[39m:\n\u001B[0;32m     18\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m os\u001B[38;5;241m.\u001B[39mpath\u001B[38;5;241m.\u001B[39mjoin(cwd, \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mpretrained_models/\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m---> 19\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mFileNotFoundError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mCould not figure out location of pretrained models.\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n",
      "\u001B[1;31mFileNotFoundError\u001B[0m: Could not figure out location of pretrained models."
     ]
    }
   ],
   "source": [
    "cnn_weights_path = os.path.join(get_pretrained_dir(), \"mnist_cnn.pt\")\n",
    "cnn_model = Net().to(dtype=dtype, device=device)\n",
    "cnn_state_dict = torch.load(cnn_weights_path, map_location=device, weights_only=True)\n",
    "cnn_model.load_state_dict(cnn_state_dict);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our VAE model follows the [PyTorch VAE example](https://github.com/pytorch/examples/tree/master/vae), except that we use the same data transform from the CNN tutorial for consistency. We then instantiate the model and again load a pre-trained model. To train these models, we refer readers to the PyTorch Github repository. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-08T21:35:53.545928300Z",
     "start_time": "2025-03-08T21:35:53.545928300Z"
    }
   },
   "outputs": [],
   "source": [
    "class VAE(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(784, 400)\n",
    "        self.fc21 = nn.Linear(400, 20)\n",
    "        self.fc22 = nn.Linear(400, 20)\n",
    "        self.fc3 = nn.Linear(20, 400)\n",
    "        self.fc4 = nn.Linear(400, 784)\n",
    "\n",
    "    def encode(self, x):\n",
    "        h1 = F.relu(self.fc1(x))\n",
    "        return self.fc21(h1), self.fc22(h1)\n",
    "\n",
    "    def reparameterize(self, mu, logvar):\n",
    "        std = torch.exp(0.5 * logvar)\n",
    "        eps = torch.randn_like(std)\n",
    "        return mu + eps * std\n",
    "\n",
    "    def decode(self, z):\n",
    "        h3 = F.relu(self.fc3(z))\n",
    "        return torch.sigmoid(self.fc4(h3))\n",
    "\n",
    "    def forward(self, x):\n",
    "        mu, logvar = self.encode(x.view(-1, 784))\n",
    "        z = self.reparameterize(mu, logvar)\n",
    "        return self.decode(z), mu, logvar\n",
    "\n",
    "vae_weights_path = os.path.join(get_pretrained_dir(), \"mnist_vae.pt\")\n",
    "vae_model = VAE().to(dtype=dtype, device=device)\n",
    "vae_state_dict = torch.load(vae_weights_path, map_location=device, weights_only=True)\n",
    "vae_model.load_state_dict(vae_state_dict);"
   ]
  },
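  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `reparameterize` method above draws a latent sample as `mu + eps * std` with `std = exp(0.5 * logvar)`. As a rough illustration, here is a minimal scalar sketch in plain Python; the helper `reparameterize_scalar` and the values chosen for `mu` and `logvar` are assumptions made for this example, not part of the tutorial code.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import random\n",
    "\n",
    "\n",
    "def reparameterize_scalar(mu, logvar, rng):\n",
    "    # std = exp(0.5 * logvar), so logvar = 0 corresponds to unit variance.\n",
    "    std = math.exp(0.5 * logvar)\n",
    "    eps = rng.gauss(0.0, 1.0)  # standard normal noise\n",
    "    return mu + eps * std\n",
    "\n",
    "\n",
    "rng = random.Random(0)\n",
    "# Hypothetical variational parameters: mean 2.0, variance 0.25.\n",
    "samples = [reparameterize_scalar(2.0, math.log(0.25), rng) for _ in range(10000)]\n",
    "mean = sum(samples) / len(samples)\n",
    "var = sum((s - mean) ** 2 for s in samples) / len(samples)\n",
    "print(mean, var)  # sample mean near 2.0, sample variance near 0.25\n",
    "```\n",
    "\n",
    "In the optimization below, the decoder is called directly on latent points, so this sampling step only matters when training the VAE or encoding images with it."
   ]
  },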
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now define the scoring function that maps digits to scores. The function below prefers the digit '3'."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2025-03-08T21:35:53.546929800Z"
    }
   },
   "outputs": [],
   "source": [
    "def score(y):\n",
    "    \"\"\"Returns a 'score' for each digit from 0 to 9. It is modeled as a squared exponential\n",
    "    centered at the digit '3'.\n",
    "    \"\"\"\n",
    "    return torch.exp(-2 * (y - 3) ** 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Given the scoring function, we can now write our overall objective, which as discussed above, starts with an image and outputs a score. Let's say the objective computes the expected score given the probabilities from the classifier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2025-03-08T21:35:53.548930300Z"
    }
   },
   "outputs": [],
   "source": [
    "def score_image(x):\n",
    "    \"\"\"The input x is an image and an expected score\n",
    "    based on the CNN classifier and the scoring\n",
    "    function is returned.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        probs = torch.exp(cnn_model(x))  # b x 10\n",
    "        scores = score(\n",
    "            torch.arange(10, device=device, dtype=dtype)\n",
    "        ).expand(probs.shape)\n",
    "    return (probs * scores).sum(dim=1)"
   ]
  },
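  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the expected-score computation concrete, the sketch below reproduces it in plain Python for two hand-picked probability vectors. The names `digit_score` and `expected_score` and the example probabilities are illustrative assumptions; only the formula `exp(-2 * (y - 3) ** 2)` is taken from `score` above.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def digit_score(y):\n",
    "    # Same squared-exponential form as `score` above, centered at the digit 3.\n",
    "    return math.exp(-2 * (y - 3) ** 2)\n",
    "\n",
    "\n",
    "def expected_score(probs):\n",
    "    # Expected score under a categorical distribution over the digits 0-9.\n",
    "    return sum(p * digit_score(k) for k, p in enumerate(probs))\n",
    "\n",
    "\n",
    "# Hypothetical classifier outputs: one confident '3', one confident '8'.\n",
    "confident_three = [0.0] * 10\n",
    "confident_three[3] = 1.0\n",
    "confident_eight = [0.0] * 10\n",
    "confident_eight[8] = 1.0\n",
    "\n",
    "print(expected_score(confident_three))  # 1.0\n",
    "print(expected_score(confident_eight))  # effectively zero\n",
    "```\n",
    "\n",
    "Neighboring digits contribute little: `digit_score(2)` and `digit_score(4)` both equal `exp(-2)`, roughly 0.135, so the objective mainly rewards images that the classifier confidently labels as a 3."
   ]
  },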
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we define a helper function `decode` that takes as input the parameters `mu` and `logvar` of the variational distribution and performs reparameterization and the decoding. We use batched Bayesian optimization to search over the parameters `mu` and `logvar`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2025-03-08T21:35:53.551043400Z"
    }
   },
   "outputs": [],
   "source": [
    "def decode(train_x):\n",
    "    with torch.no_grad():\n",
    "        decoded = vae_model.decode(train_x)\n",
    "    return decoded.view(train_x.shape[0], 1, 28, 28)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Model initialization and initial random batch\n",
    "\n",
    "We use a `SingleTaskGP` to model the score of an image generated by a latent representation. The model is initialized with points drawn from $[-6, 6]^{20}$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-08T21:35:53.554042900Z",
     "start_time": "2025-03-08T21:35:53.553042400Z"
    }
   },
   "outputs": [],
   "source": [
    "from botorch.models import SingleTaskGP\n",
    "from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood\n",
    "from botorch.utils.transforms import normalize, unnormalize\n",
    "from botorch.models.transforms import Standardize, Normalize\n",
    "\n",
    "d = 20\n",
    "bounds = torch.tensor([[-6.0] * d, [6.0] * d], device=device, dtype=dtype)\n",
    "\n",
    "\n",
    "def gen_initial_data(n=5):\n",
    "    # generate training data\n",
    "    train_x = unnormalize(\n",
    "        torch.rand(n, d, device=device, dtype=dtype),\n",
    "        bounds=bounds\n",
    "    )\n",
    "    train_obj = score_image(decode(train_x)).unsqueeze(-1)\n",
    "    best_observed_value = train_obj.max().item()\n",
    "    return train_x, train_obj, best_observed_value\n",
    "\n",
    "\n",
    "def get_fitted_model(train_x, train_obj, state_dict=None):\n",
    "    # initialize and fit model\n",
    "    model = SingleTaskGP(\n",
    "        train_X=normalize(train_x, bounds),\n",
    "        train_Y=train_obj,\n",
    "        outcome_transform=Standardize(m=1)\n",
    "    )\n",
    "    if state_dict is not None:\n",
    "        model.load_state_dict(state_dict)\n",
    "    mll = ExactMarginalLogLikelihood(model.likelihood, model)\n",
    "    mll.to(train_x)\n",
    "    fit_gpytorch_mll(mll)\n",
    "    return model"
   ]
  },
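  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The GP is trained on inputs rescaled to the unit cube, while the decoder expects points in $[-6, 6]^{20}$. The sketch below shows the affine maps involved, written in plain Python for a single point. The helper names `to_unit_cube` and `from_unit_cube` are hypothetical stand-ins, assumed to mirror what `normalize` and `unnormalize` do for these box bounds.\n",
    "\n",
    "```python\n",
    "LOWER, UPPER = -6.0, 6.0  # latent-space box used in this tutorial\n",
    "\n",
    "\n",
    "def to_unit_cube(x):\n",
    "    # Map each coordinate from [LOWER, UPPER] to [0, 1].\n",
    "    return [(xi - LOWER) / (UPPER - LOWER) for xi in x]\n",
    "\n",
    "\n",
    "def from_unit_cube(u):\n",
    "    # Map each coordinate from [0, 1] back to [LOWER, UPPER].\n",
    "    return [ui * (UPPER - LOWER) + LOWER for ui in u]\n",
    "\n",
    "\n",
    "latent_point = [-6.0, 0.0, 3.0, 6.0]\n",
    "unit_point = to_unit_cube(latent_point)\n",
    "print(unit_point)                  # [0.0, 0.5, 0.75, 1.0]\n",
    "print(from_unit_cube(unit_point))  # [-6.0, 0.0, 3.0, 6.0]\n",
    "```\n",
    "\n",
    "Keeping the GP inputs in $[0, 1]^{20}$ is why the acquisition function is later optimized over the unit cube and the candidates are mapped back to the latent box before decoding."
   ]
  },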
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Define a helper function that performs the essential BO step\n",
    "The helper function below takes an acquisition function as an argument, optimizes it, and returns the batch $\\{x_1, x_2, \\ldots x_q\\}$ along with the observed function values. For this example, we'll use a small batch of $q=3$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-08T21:35:53.555043700Z",
     "start_time": "2025-03-08T21:35:53.554042900Z"
    }
   },
   "outputs": [],
   "source": [
    "from botorch.optim import optimize_acqf\n",
    "\n",
    "\n",
    "BATCH_SIZE = 3 if not SMOKE_TEST else 2\n",
    "NUM_RESTARTS = 10 if not SMOKE_TEST else 2\n",
    "RAW_SAMPLES = 256 if not SMOKE_TEST else 4\n",
    "\n",
    "\n",
    "def optimize_acqf_and_get_observation(acq_func):\n",
    "    \"\"\"Optimizes the acquisition function, and returns a\n",
    "    new candidate and a noisy observation\"\"\"\n",
    "\n",
    "    # optimize\n",
    "    candidates, _ = optimize_acqf(\n",
    "        acq_function=acq_func,\n",
    "        bounds=torch.stack(\n",
    "            [\n",
    "                torch.zeros(d, dtype=dtype, device=device),\n",
    "                torch.ones(d, dtype=dtype, device=device),\n",
    "            ]\n",
    "        ),\n",
    "        q=BATCH_SIZE,\n",
    "        num_restarts=NUM_RESTARTS,\n",
    "        raw_samples=RAW_SAMPLES,\n",
    "    )\n",
    "\n",
    "    # observe new values\n",
    "    new_x = unnormalize(candidates.detach(), bounds=bounds)\n",
    "    new_obj = score_image(decode(new_x)).unsqueeze(-1)\n",
    "    return new_x, new_obj"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Perform Bayesian Optimization loop with qEI\n",
    "The Bayesian optimization \"loop\" for a batch size of $q$ simply iterates the following steps: (1) given a surrogate model, choose a batch of points $\\{x_1, x_2, \\ldots x_q\\}$, (2) observe $f(x)$ for each $x$ in the batch, and (3) update the surrogate model. We run `N_BATCH=75` iterations. The acquisition function is approximated using `MC_SAMPLES=2048` samples. We also initialize the model with 5 randomly drawn points."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2025-03-08T21:35:53.556043300Z"
    }
   },
   "outputs": [],
   "source": [
    "from botorch import fit_gpytorch_mll\n",
    "from botorch.acquisition.monte_carlo import qExpectedImprovement\n",
    "from botorch.sampling.normal import SobolQMCNormalSampler\n",
    "\n",
    "seed = 1\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "N_BATCH = 25 if not SMOKE_TEST else 3\n",
    "best_observed = []\n",
    "\n",
    "# call helper function to initialize model\n",
    "train_x, train_obj, best_value = gen_initial_data(n=5)\n",
    "best_observed.append(best_value)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are now ready to run the BO loop (this make take a few minutes, depending on your machine)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2025-03-08T21:35:53.557042700Z"
    }
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "\n",
    "print(f\"\\nRunning BO \", end=\"\")\n",
    "\n",
    "state_dict = None\n",
    "# run N_BATCH rounds of BayesOpt after the initial random batch\n",
    "for iteration in range(N_BATCH):\n",
    "\n",
    "    # fit the model\n",
    "    model = get_fitted_model(\n",
    "        train_x=train_x,\n",
    "        train_obj=train_obj,\n",
    "        state_dict=state_dict,\n",
    "    )\n",
    "\n",
    "    # define the qNEI acquisition function\n",
    "    qEI = qExpectedImprovement(\n",
    "        model=model, best_f=train_obj.max()\n",
    "    )\n",
    "\n",
    "    # optimize and get new observation\n",
    "    new_x, new_obj = optimize_acqf_and_get_observation(qEI)\n",
    "\n",
    "    # update training points\n",
    "    train_x = torch.cat((train_x, new_x))\n",
    "    train_obj = torch.cat((train_obj, new_obj))\n",
    "\n",
    "    # update progress\n",
    "    best_value = train_obj.max().item()\n",
    "    best_observed.append(best_value)\n",
    "\n",
    "    state_dict = model.state_dict()\n",
    "\n",
    "    print(\".\", end=\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "EI recommends the best point observed so far. We can visualize what the images corresponding to recommended points *would have* been if the BO process ended at various times. Here, we show the progress of the algorithm by examining the images at 0%, 10%, 25%, 50%, 75%, and 100% completion. The first image is the best image found through the initial random batch."
   ]
  },
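  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The number of evaluated points available at each snapshot follows from the 5 initial points plus `BATCH_SIZE` new points per iteration. Below is a small sketch of that arithmetic in plain Python, assuming the non-smoke-test settings `N_BATCH = 25` and `BATCH_SIZE = 3`; the helper `points_seen` and the constant `N_INIT` are hypothetical names introduced for illustration.\n",
    "\n",
    "```python\n",
    "N_INIT = 5       # initial random points\n",
    "N_BATCH = 25     # BO iterations (non-smoke-test value)\n",
    "BATCH_SIZE = 3   # candidates per iteration\n",
    "\n",
    "\n",
    "def points_seen(percent):\n",
    "    # Number of evaluated points after `percent` of the BO evaluations,\n",
    "    # rounding partial progress down to a whole number of points.\n",
    "    return N_INIT + int(N_BATCH * BATCH_SIZE * percent / 100)\n",
    "\n",
    "\n",
    "for percent in (0, 10, 25, 50, 75, 100):\n",
    "    print(percent, points_seen(percent))\n",
    "```\n",
    "\n",
    "At 0% only the 5 initial points are available, and at 100% all 80 evaluated points are considered."
   ]
  },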
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2025-03-08T21:35:53.558043600Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots(1, 6, figsize=(14, 14))\n",
    "percentages = np.array([0, 10, 25, 50, 75, 100], dtype=np.float32)\n",
    "inds = (N_BATCH * BATCH_SIZE * percentages / 100 + 4).astype(int)\n",
    "\n",
    "for i, ax in enumerate(ax.flat):\n",
    "    b = torch.argmax(score_image(decode(train_x[: inds[i], :])), dim=0)\n",
    "    img = decode(train_x[b].view(1, -1)).squeeze().cpu()\n",
    "    ax.imshow(img, alpha=0.8, cmap=\"gray\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
