{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Scalable Exact GP Posterior Sampling using Contour Integral Quadrature\n",
    "\n",
    "This notebook demonstrates the simplest usage of contour integral quadrature with msMINRES as described [here](https://arxiv.org/pdf/2006.11267.pdf) to sample from the predictive distribution of an exact GP.\n",
    "\n",
    "Note that to achieve results where Cholesky would run the GPU out of memory, you'll either need to have KeOps installed (see our KeOps tutorial in this same folder), or use the `checkpoint_kernel` beta feature. Despite this, on this relatively simple example with 1000 training points but seeking to sample at 10000 test points in 1D (50000 with KeOps), we will achieve significant speed ups over Cholesky."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import gpytorch\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "import warnings\n",
    "warnings.simplefilter(\"ignore\", gpytorch.utils.warnings.NumericalWarning)\n",
    "\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix the RNG seed so the noisy targets below are reproducible under Restart & Run All\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Training data is 1000 points in [0,1] inclusive regularly spaced\n",
    "train_x = torch.linspace(0, 1, 1000)\n",
    "# True function is sin(2*pi*x) with Gaussian noise (standard deviation 0.2)\n",
    "train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Are we running with KeOps?\n",
    "\n",
    "If you have KeOps, change the below flag to `True` to run with a significantly larger test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "HAVE_KEOPS = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define an Exact GP Model and train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ExactGPModel(gpytorch.models.ExactGP):\n",
    "    \"\"\"Exact GP with a constant mean and a scaled RBF kernel.\n",
    "\n",
    "    The RBF kernel is taken from `gpytorch.kernels.keops` when the notebook-level\n",
    "    flag `HAVE_KEOPS` is true, and from `gpytorch.kernels` otherwise.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, train_x, train_y, likelihood):\n",
    "        super().__init__(train_x, train_y, likelihood)\n",
    "        self.mean_module = gpytorch.means.ConstantMean()\n",
    "\n",
    "        # Only touch the keops sub-package when the flag asks for it\n",
    "        base_kernel = gpytorch.kernels.keops.RBFKernel() if HAVE_KEOPS else gpytorch.kernels.RBFKernel()\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(base_kernel)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Evaluate the mean and kernel modules at `x` and wrap them in a MultivariateNormal.\"\"\"\n",
    "        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))\n",
    "\n",
    "# Build the Gaussian likelihood first, then the GP model on the training data\n",
    "likelihood = gpytorch.likelihoods.GaussianLikelihood()\n",
    "model = ExactGPModel(train_x, train_y, likelihood)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the training data, model and likelihood to the GPU when one is available\n",
    "if torch.cuda.is_available():\n",
    "    train_x, train_y = train_x.cuda(), train_y.cuda()\n",
    "    model, likelihood = model.cuda(), likelihood.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 1/50 - Loss: 0.861   lengthscale: 0.693   noise: 0.693\n",
      "Iter 2/50 - Loss: 0.814   lengthscale: 0.644   noise: 0.644\n",
      "Iter 3/50 - Loss: 0.763   lengthscale: 0.598   noise: 0.598\n",
      "Iter 4/50 - Loss: 0.718   lengthscale: 0.554   noise: 0.554\n",
      "Iter 5/50 - Loss: 0.666   lengthscale: 0.513   noise: 0.513\n",
      "Iter 6/50 - Loss: 0.618   lengthscale: 0.474   noise: 0.474\n",
      "Iter 7/50 - Loss: 0.572   lengthscale: 0.439   noise: 0.437\n",
      "Iter 8/50 - Loss: 0.530   lengthscale: 0.408   noise: 0.402\n",
      "Iter 9/50 - Loss: 0.486   lengthscale: 0.380   noise: 0.370\n",
      "Iter 10/50 - Loss: 0.452   lengthscale: 0.355   noise: 0.339\n",
      "Iter 11/50 - Loss: 0.415   lengthscale: 0.334   noise: 0.311\n",
      "Iter 12/50 - Loss: 0.376   lengthscale: 0.316   noise: 0.285\n",
      "Iter 13/50 - Loss: 0.331   lengthscale: 0.301   noise: 0.261\n",
      "Iter 14/50 - Loss: 0.293   lengthscale: 0.288   noise: 0.238\n",
      "Iter 15/50 - Loss: 0.258   lengthscale: 0.276   noise: 0.217\n",
      "Iter 16/50 - Loss: 0.219   lengthscale: 0.266   noise: 0.198\n",
      "Iter 17/50 - Loss: 0.188   lengthscale: 0.258   noise: 0.181\n",
      "Iter 18/50 - Loss: 0.146   lengthscale: 0.250   noise: 0.165\n",
      "Iter 19/50 - Loss: 0.109   lengthscale: 0.244   noise: 0.150\n",
      "Iter 20/50 - Loss: 0.081   lengthscale: 0.238   noise: 0.136\n",
      "Iter 21/50 - Loss: 0.042   lengthscale: 0.234   noise: 0.124\n",
      "Iter 22/50 - Loss: 0.007   lengthscale: 0.230   noise: 0.113\n",
      "Iter 23/50 - Loss: -0.020   lengthscale: 0.227   noise: 0.103\n",
      "Iter 24/50 - Loss: -0.043   lengthscale: 0.224   noise: 0.094\n",
      "Iter 25/50 - Loss: -0.075   lengthscale: 0.222   noise: 0.085\n",
      "Iter 26/50 - Loss: -0.094   lengthscale: 0.221   noise: 0.078\n",
      "Iter 27/50 - Loss: -0.108   lengthscale: 0.221   noise: 0.071\n",
      "Iter 28/50 - Loss: -0.149   lengthscale: 0.221   noise: 0.065\n",
      "Iter 29/50 - Loss: -0.160   lengthscale: 0.221   noise: 0.060\n",
      "Iter 30/50 - Loss: -0.174   lengthscale: 0.222   noise: 0.055\n",
      "Iter 31/50 - Loss: -0.191   lengthscale: 0.223   noise: 0.050\n",
      "Iter 32/50 - Loss: -0.205   lengthscale: 0.224   noise: 0.047\n",
      "Iter 33/50 - Loss: -0.205   lengthscale: 0.226   noise: 0.043\n",
      "Iter 34/50 - Loss: -0.211   lengthscale: 0.228   noise: 0.040\n",
      "Iter 35/50 - Loss: -0.211   lengthscale: 0.230   noise: 0.038\n",
      "Iter 36/50 - Loss: -0.208   lengthscale: 0.232   noise: 0.035\n",
      "Iter 37/50 - Loss: -0.219   lengthscale: 0.235   noise: 0.034\n",
      "Iter 38/50 - Loss: -0.199   lengthscale: 0.237   noise: 0.032\n",
      "Iter 39/50 - Loss: -0.198   lengthscale: 0.240   noise: 0.031\n",
      "Iter 40/50 - Loss: -0.201   lengthscale: 0.243   noise: 0.030\n",
      "Iter 41/50 - Loss: -0.204   lengthscale: 0.246   noise: 0.029\n",
      "Iter 42/50 - Loss: -0.196   lengthscale: 0.249   noise: 0.028\n",
      "Iter 43/50 - Loss: -0.196   lengthscale: 0.252   noise: 0.028\n",
      "Iter 44/50 - Loss: -0.197   lengthscale: 0.254   noise: 0.028\n",
      "Iter 45/50 - Loss: -0.183   lengthscale: 0.256   noise: 0.028\n",
      "Iter 46/50 - Loss: -0.183   lengthscale: 0.258   noise: 0.028\n",
      "Iter 47/50 - Loss: -0.195   lengthscale: 0.261   noise: 0.028\n",
      "Iter 48/50 - Loss: -0.197   lengthscale: 0.263   noise: 0.029\n",
      "Iter 49/50 - Loss: -0.193   lengthscale: 0.265   noise: 0.029\n",
      "Iter 50/50 - Loss: -0.203   lengthscale: 0.268   noise: 0.030\n"
     ]
    }
   ],
   "source": [
    "# The testing framework sets CI in the environment; use a very short run there\n",
    "import os\n",
    "smoke_test = 'CI' in os.environ\n",
    "training_iter = 2 if smoke_test else 50\n",
    "\n",
    "# Put the model and likelihood into training mode for hyperparameter fitting\n",
    "model.train()\n",
    "likelihood.train()\n",
    "\n",
    "# Adam over model.parameters(), which includes the GaussianLikelihood parameters\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.1)\n",
    "\n",
    "# Marginal log likelihood; its negative is used as the training loss\n",
    "mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)\n",
    "\n",
    "log_template = 'Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f'\n",
    "for i in range(training_iter):\n",
    "    # Clear gradients left over from the previous step\n",
    "    optimizer.zero_grad()\n",
    "    # Forward pass, negative marginal log likelihood, backward pass\n",
    "    output = model(train_x)\n",
    "    loss = -mll(output, train_y)\n",
    "    loss.backward()\n",
    "    # Report the hyperparameter values this loss was computed with (before the update)\n",
    "    print(log_template % (\n",
    "        i + 1,\n",
    "        training_iter,\n",
    "        loss.item(),\n",
    "        model.covar_module.base_kernel.lengthscale.item(),\n",
    "        model.likelihood.noise.item(),\n",
    "    ))\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define test set\n",
    "\n",
    "If we have KeOps installed, we'll test on 50000 points instead of 10000."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([10000])\n"
     ]
    }
   ],
   "source": [
    "# Use the larger test set only when KeOps is available\n",
    "test_n = 50000 if HAVE_KEOPS else 10000\n",
    "\n",
    "# Regularly spaced test inputs on [0, 1], moved to the GPU when one is available\n",
    "test_x = torch.linspace(0, 1, test_n)\n",
    "if torch.cuda.is_available():\n",
    "    test_x = test_x.cuda()\n",
    "print(test_x.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Draw a sample with CIQ\n",
    "\n",
    "To do this, we just add the `ciq_samples` setting to the rsample call. We additionally demonstrate all relevant settings for controlling Contour Integral Quadrature:\n",
    "\n",
    "- The `ciq_samples` setting determines whether or not to use CIQ\n",
    "- The `num_contour_quadrature` setting controls the number of quadrature sites (Q in the paper).\n",
    "- The `minres_tolerance` setting controls the error we tolerate from minres (here, <0.01%).\n",
    "\n",
    "Note that, of these settings, increasing `num_contour_quadrature` is unlikely to improve performance. As Theorem 1 from the paper demonstrates, virtually all of the error in this method is controlled by `minres_tolerance`. Here, we use a quite tight tolerance for minres."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running with CIQ\n",
      "CPU times: user 711 ms, sys: 38.4 ms, total: 749 ms\n",
      "Wall time: 677 ms\n",
      "Running with Cholesky\n",
      "CPU times: user 1min 36s, sys: 892 ms, total: 1min 37s\n",
      "Wall time: 30.9 s\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "# NOTE(review): train() is called right before eval() -- presumably to reset any cached\n",
    "# predictive state when this cell is re-run; confirm against the GPyTorch docs\n",
    "model.train()\n",
    "likelihood.train()\n",
    "\n",
    "# Switch to evaluation (predictive posterior) mode\n",
    "model.eval()\n",
    "likelihood.eval()\n",
    "\n",
    "test_x.requires_grad_(True)\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Predictive distribution: feed the model output through the likelihood\n",
    "    observed_pred = likelihood(model(test_x))\n",
    "\n",
    "    # Settings that control CIQ sampling:\n",
    "    #   ciq_samples(True) -- use CIQ for sampling\n",
    "    #   num_contour_quadrature(10) -- use 10 quadrature sites (Q in the paper)\n",
    "    #   minres_tolerance(1e-4) -- error tolerance from minres (here, <0.01%)\n",
    "    print(\"Running with CIQ\")\n",
    "    with gpytorch.settings.ciq_samples(True), \\\n",
    "            gpytorch.settings.num_contour_quadrature(10), \\\n",
    "            gpytorch.settings.minres_tolerance(1e-4):\n",
    "        %time y_samples = observed_pred.rsample()\n",
    "\n",
    "    print(\"Running with Cholesky\")\n",
    "    # Make sure we use Cholesky; y_samples is overwritten, only the timings are compared\n",
    "    with gpytorch.settings.fast_computations(covar_root_decomposition=False):\n",
    "        %time y_samples = observed_pred.rsample()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
