{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GPyTorch Regression With KeOps\n",
    "\n",
    "## Introduction\n",
    "\n",
    "[KeOps](https://github.com/getkeops/keops) is a software package for fast kernel operations that integrates with PyTorch. Because `KeOps` performs kernel matrix multiplies efficiently on the GPU, we can plug it into the rest of GPyTorch.\n",
    "\n",
    "In this tutorial, we'll demonstrate how to combine the kernel matmuls of `KeOps` with all of the bells and whistles of GPyTorch, including our preconditioning for conjugate gradients.\n",
    "\n",
    "Specifically, we will train an exact GP on `3droad`, which has hundreds of thousands of data points. The highly optimized matmuls of `KeOps`, combined with algorithmic speed improvements like preconditioning, allow us to train on a dataset of this size in a matter of minutes using only a single GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import gpytorch\n",
    "import tqdm.notebook as tqdm\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Downloading Data\n",
    "We will be using the 3droad UCI dataset, which contains a total of 434,874 data points. We will use the first half of the data for training and the second half for testing.\n",
    "\n",
    "The next cell will download this dataset from Dropbox and load it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import urllib.request\n",
    "import os.path\n",
    "from scipy.io import loadmat\n",
    "from math import floor\n",
    "\n",
    "if not os.path.isfile('../3droad.mat'):\n",
    "    print('Downloading \\'3droad\\' UCI dataset...')\n",
    "    urllib.request.urlretrieve('https://www.dropbox.com/s/f6ow1i59oqx05pl/3droad.mat?dl=1', '../3droad.mat')\n",
    "    \n",
    "data = torch.Tensor(loadmat('../3droad.mat')['data'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num train: 217437\n",
      "Num test: 217437\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "N = data.shape[0]\n",
    "# make train/test split (first half train, second half test)\n",
    "n_train = int(0.5 * N)\n",
    "\n",
    "train_x, train_y = data[:n_train, :-1], data[:n_train, -1]\n",
    "test_x, test_y = data[n_train:, :-1], data[n_train:, -1]\n",
    "\n",
    "# normalize features using training statistics only\n",
    "mean = train_x.mean(dim=-2, keepdim=True)\n",
    "std = train_x.std(dim=-2, keepdim=True) + 1e-6  # prevent dividing by 0\n",
    "train_x = (train_x - mean) / std\n",
    "test_x = (test_x - mean) / std\n",
    "\n",
    "# normalize labels using training statistics only\n",
    "mean, std = train_y.mean(), train_y.std()\n",
    "train_y = (train_y - mean) / std\n",
    "test_y = (test_y - mean) / std\n",
    "\n",
    "# make contiguous\n",
    "train_x, train_y = train_x.contiguous(), train_y.contiguous()\n",
    "test_x, test_y = test_x.contiguous(), test_y.contiguous()\n",
    "\n",
    "output_device = torch.device('cuda:0')\n",
    "\n",
    "train_x, train_y = train_x.to(output_device), train_y.to(output_device)\n",
    "test_x, test_y = test_x.to(output_device), test_y.to(output_device)\n",
    "\n",
    "print(\n",
    "    f\"Num train: {train_y.size(-1)}\\n\"\n",
    "    f\"Num test: {test_y.size(-1)}\"\n",
    ")"
   ]
  },
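  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The preprocessing above amounts to the following standardization. The notation is introduced here for illustration only: $m_j$ and $s_j$ are the mean and standard deviation of feature $j$ over the training inputs, and $m_y$, $s_y$ are those of the training labels.\n",
    "\n",
    "$$\n",
    "\\tilde{x}_{ij} = \\frac{x_{ij} - m_j}{s_j + 10^{-6}}, \\qquad \\tilde{y}_i = \\frac{y_i - m_y}{s_y}\n",
    "$$\n",
    "\n",
    "The test split is transformed with the same training statistics, so nothing from the test data enters the preprocessing."
   ]
  },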
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using KeOps with a GPyTorch Model\n",
    "\n",
    "Using KeOps with one of our pre-built kernels is as straightforward as swapping the kernel out. For example, in the cell below, we copy the simple GP from our basic tutorial notebook and replace `gpytorch.kernels.MaternKernel` with `gpytorch.kernels.keops.MaternKernel`."
   ]
  },
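  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the covariance function built in the next cell can be written out explicitly. This is the standard Matérn form with $\\nu = 5/2$, multiplied by the output scale that `ScaleKernel` adds. The symbols $s$ (output scale), $\\ell$ (lengthscale), and $r$ (distance between two inputs) are notation introduced here, not names from the code.\n",
    "\n",
    "$$\n",
    "k(x, x') = s \\left(1 + \\frac{\\sqrt{5}\\, r}{\\ell} + \\frac{5 r^2}{3 \\ell^2}\\right) \\exp\\left(-\\frac{\\sqrt{5}\\, r}{\\ell}\\right), \\qquad r = \\lVert x - x' \\rVert_2\n",
    "$$\n",
    "\n",
    "Swapping in the KeOps kernel should leave this function unchanged; as we understand it, only the way products with the resulting kernel matrix are computed on the GPU differs."
   ]
  },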
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We will use the simplest form of GP model, exact inference\n",
    "class ExactGPModel(gpytorch.models.ExactGP):\n",
    "    def __init__(self, train_x, train_y, likelihood):\n",
    "        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)\n",
    "        self.mean_module = gpytorch.means.ConstantMean()\n",
    "\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.keops.MaternKernel(nu=2.5))\n",
    "\n",
    "    def forward(self, x):\n",
    "        mean_x = self.mean_module(x)\n",
    "        covar_x = self.covar_module(x)\n",
    "        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
    "\n",
    "# initialize likelihood and model\n",
    "likelihood = gpytorch.likelihoods.GaussianLikelihood().cuda()\n",
    "model = ExactGPModel(train_x, train_y, likelihood).cuda()\n",
    "\n",
    "# Because we know some properties about this dataset,\n",
    "# we will initialize the lengthscale to be somewhat small\n",
    "# This step isn't necessary, but it will help the model converge faster.\n",
    "model.covar_module.base_kernel.lengthscale = 0.05"
   ]
  },
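  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training loop below minimizes the negative marginal log likelihood. As a sketch in standard GP notation (symbols introduced here for illustration: $K$ is the kernel matrix on the training inputs, $\\sigma^2$ the likelihood noise, $\\mu$ the constant mean, and $n$ the number of training points):\n",
    "\n",
    "$$\n",
    "\\log p(y \\mid X) = -\\frac{1}{2} (y - \\mu)^\\top (K + \\sigma^2 I)^{-1} (y - \\mu) - \\frac{1}{2} \\log \\lvert K + \\sigma^2 I \\rvert - \\frac{n}{2} \\log 2\\pi\n",
    "$$\n",
    "\n",
    "The linear solve in the first term is where preconditioned conjugate gradients enters: it only needs matrix-vector products with $K + \\sigma^2 I$, which is the operation `KeOps` accelerates. To our knowledge, `ExactMarginalLogLikelihood` reports this quantity divided by $n$, so the loss shown in the progress bar is best read as a per-data-point value."
   ]
  },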
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "691194d2d51e4d389fef9f0f7cb34f6b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training:   0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Find optimal model hyperparameters\n",
    "model.train()\n",
    "likelihood.train()\n",
    "\n",
    "# Use the Adam optimizer\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters\n",
    "\n",
    "# \"Loss\" for GPs - the marginal log likelihood\n",
    "mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)\n",
    "\n",
    "training_iter = 25\n",
    "iterator = tqdm.tqdm(range(training_iter), desc=\"Training\")\n",
    "for i in iterator:\n",
    "    # Zero gradients from previous iteration\n",
    "    optimizer.zero_grad()\n",
    "    # Output from model\n",
    "    output = model(train_x)\n",
    "    # Calc loss and backprop gradients\n",
    "    loss = -mll(output, train_y)\n",
    "    print_values = dict(\n",
    "        loss=loss.item(),\n",
    "        ls=model.covar_module.base_kernel.lengthscale.norm().item(),\n",
    "        os=model.covar_module.outputscale.item(),\n",
    "        noise=model.likelihood.noise.item(),\n",
    "        mu=model.mean_module.constant.item(),\n",
    "    )\n",
    "    iterator.set_postfix(**print_values)\n",
    "    loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get into evaluation (predictive posterior) mode\n",
    "model.eval()\n",
    "\n",
    "# Test points are regularly spaced along [0,1]\n",
    "# Make predictions by feeding model through likelihood\n",
    "with torch.no_grad(), gpytorch.settings.fast_pred_var():\n",
    "    observed_pred = model.likelihood(model(test_x))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute RMSE"
   ]
  },
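  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The quantity computed in the next cell is the usual root mean squared error between the predictive mean and the test labels. Written out, with $\\hat{y}_i$ denoting the predictive mean at test point $i$ and $n_*$ the number of test points (both symbols are notation introduced here):\n",
    "\n",
    "$$\n",
    "\\mathrm{RMSE} = \\sqrt{\\frac{1}{n_*} \\sum_{i=1}^{n_*} (\\hat{y}_i - y_i)^2}\n",
    "$$\n",
    "\n",
    "Because the labels were standardized with the training mean and standard deviation, the result is expressed in units of the training-label standard deviation rather than the original label units."
   ]
  },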
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE: 0.138\n"
     ]
    }
   ],
   "source": [
    "rmse = (observed_pred.mean - test_y).square().mean().sqrt().item()\n",
    "print(f\"RMSE: {rmse:.3f}\")"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
