{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b5f40724",
   "metadata": {},
   "source": [
    "# Geometric Kernels\n",
    "\n",
    "[GPJax](https://github.com/JaxGaussianProcesses/GPJax) is a Python package for working with Gaussian processes in JAX. This notebook shows how kernels from Geometric Kernels can be used within GPJax."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "68c2e9c0",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO: Using numpy backend\n"
     ]
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import jax.random as jr\n",
    "import gpjax as gpx\n",
    "import meshzoo\n",
    "\n",
    "import geometric_kernels.jax  # enables the JAX backend of Geometric Kernels\n",
    "from geometric_kernels.frontends.jax.gpjax import GPJaxGeometricKernel\n",
    "from geometric_kernels.kernels import MaternKarhunenLoeveKernel\n",
    "from geometric_kernels.spaces import Mesh\n",
    "\n",
    "# Use 64-bit floats so that Cholesky factorisations are numerically stable.\n",
    "jax.config.update(\"jax_enable_x64\", True)\n",
    "key = jr.PRNGKey(123)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90869a42",
   "metadata": {},
   "source": [
    "## Data\n",
    "\n",
    "We'll now define a dataset to model. The data live on the vertices of an icosphere (a sphere mesh obtained by subdividing an icosahedron) generated with the [MeshZoo](https://github.com/meshpro/meshzoo) library. We define a Matérn kernel on this mesh and draw a single sample from the Gaussian process prior at a randomly chosen set of vertices to serve as the response variable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "868f42c1",
   "metadata": {
    "tags": []
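   },
   "outputs": [],
   "source": [
    "resolution = 3\n",
    "num_data = 25\n",
    "\n",
    "# Build a triangulated sphere and wrap it as a Geometric Kernels space.\n",
    "vertices, faces = meshzoo.icosa_sphere(resolution)\n",
    "mesh = Mesh(vertices, faces)\n",
    "\n",
    "# Matérn kernel built from a truncated Karhunen-Loève expansion, wrapped for GPJax.\n",
    "truncation_level = 20\n",
    "base_kernel = MaternKarhunenLoeveKernel(mesh, truncation_level)\n",
    "geometric_kernel = GPJaxGeometricKernel(base_kernel)\n",
    "\n",
    "init_params = geometric_kernel.init_params(key)\n",
    "\n",
    "\n",
    "def get_data():\n",
    "    # Inputs are vertex indices, stored as a column vector of shape (num_data, 1).\n",
    "    _X = jr.randint(key, minval=0, maxval=mesh.num_vertices, shape=(num_data, 1))\n",
    "    _K = geometric_kernel.gram(init_params, _X)\n",
    "    # Add a small jitter to the diagonal so the Cholesky factorisation succeeds.\n",
    "    _L = jnp.linalg.cholesky(_K.to_dense() + jnp.eye(_K.shape[0]) * 1e-6)\n",
    "    # Multiplying standard normal noise by the Cholesky factor gives a prior draw.\n",
    "    _y = _L @ jr.normal(key, (num_data,))\n",
    "    return _X, _y\n",
    "\n",
    "\n",
    "X, y = get_data()\n",
    "# Test inputs: every vertex of the mesh.\n",
    "X_test = jnp.arange(mesh.num_vertices).reshape(mesh.num_vertices, 1)"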
   ]
  },
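  {
   "cell_type": "markdown",
   "id": "3a9c1e57",
   "metadata": {},
   "source": [
    "The sampling step inside `get_data` can be written out explicitly. As a sketch, write $K$ for the Gram matrix at the sampled vertices and $\\epsilon = 10^{-6}$ for the jitter added to its diagonal (these symbols are our own notation, not identifiers from the code):\n",
    "\n",
    "$$\n",
    "K + \\epsilon I = L L^\\top, \\qquad \\mathbf{z} \\sim \\mathcal{N}(\\mathbf{0}, I), \\qquad \\mathbf{y} = L \\mathbf{z} \\sim \\mathcal{N}(\\mathbf{0}, K + \\epsilon I).\n",
    "$$\n",
    "\n",
    "The last statement holds because $\\operatorname{Cov}(L \\mathbf{z}) = L I L^\\top = K + \\epsilon I$."
   ]
  },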
  {
   "cell_type": "markdown",
   "id": "580034b6",
   "metadata": {},
   "source": [
    "## Model specification\n",
    "\n",
    "A model can now be defined. We'll purposefully keep this section brief, as the workflow is identical to that of a regular Gaussian process regression model, which is described in full in the [GPJax regression notebook](https://gpjax.readthedocs.io/en/latest/examples/regression.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f02d9a56",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_819/1632642032.py:1: DeprecatedWarning: Dataset is deprecated as of 0.5.5 and will be removed in 0.6.0. Use JaxUtils for a Dataset object\n",
      "  data = gpx.Dataset(X=X, y=y.reshape(-1, 1))\n",
      "/tmp/ipykernel_819/1632642032.py:4: DeprecatedWarning: add_parameter is deprecated as of 0.5.6 and will be removed in 0.6.0. Use method from jaxutils.config instead.\n",
      "  gpx.config.add_parameter(\"nu\", gpx.config.Softplus)\n"
     ]
    }
   ],
   "source": [
    "data = gpx.Dataset(X=X, y=y.reshape(-1, 1))\n",
    "\n",
    "prior = gpx.Prior(kernel=geometric_kernel)\n",
    "gpx.config.add_parameter(\"nu\", gpx.config.Softplus)\n",
    "\n",
    "likelihood = gpx.likelihoods.Gaussian(num_datapoints=num_data)\n",
    "\n",
    "posterior = likelihood * prior"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ee4d424",
   "metadata": {},
   "source": [
    "As with a regular conjugate Gaussian process, the marginal log-likelihood is tractable and can be evaluated using the posterior's `marginal_log_likelihood` method."
   ]
  },
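  {
   "cell_type": "markdown",
   "id": "7d42b0c8",
   "metadata": {},
   "source": [
    "For reference, the quantity evaluated below has the standard closed form for conjugate Gaussian process regression. This is a sketch in our own notation, assuming a zero mean function and Gaussian observation noise with variance $\\sigma^2$; $K$ denotes the Gram matrix of the kernel at the $n$ training inputs and $\\mathbf{y}$ the vector of observations:\n",
    "\n",
    "$$\n",
    "\\log p(\\mathbf{y}) = -\\tfrac{1}{2} \\mathbf{y}^\\top (K + \\sigma^2 I)^{-1} \\mathbf{y} - \\tfrac{1}{2} \\log \\lvert K + \\sigma^2 I \\rvert - \\tfrac{n}{2} \\log 2\\pi.\n",
    "$$\n",
    "\n",
    "The kernel enters only through $K$, which is why swapping in a geometric kernel leaves the rest of the workflow unchanged."
   ]
  },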
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3f4ae6a4",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(-23.18739116, dtype=float64)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "params, _, _ = gpx.initialise(posterior, key).unpack()\n",
    "\n",
    "posterior.marginal_log_likelihood(data)(params)"
   ]
  },
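  {
   "cell_type": "markdown",
   "id": "dec7bda6",
   "metadata": {},
   "source": [
    "Because the marginal log-likelihood is computed in JAX, its gradient with respect to the parameters can be obtained with `jax.grad`. Here we differentiate the negative marginal log-likelihood, which is the objective one would minimise when fitting the model."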
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "055ba1e4",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'kernel': {'lengthscale': Array([0.49285908], dtype=float64), 'nu': Array([-0.58913405], dtype=float64)}, 'likelihood': {'obs_noise': Array([12.30148303], dtype=float64)}, 'mean_function': {}}\n"
     ]
    }
   ],
   "source": [
    "grads = jax.grad(posterior.marginal_log_likelihood(data, negative=True))(params)\n",
    "print(grads)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23a4e36e",
   "metadata": {},
   "source": [
    "Finally, the predictive posterior can be computed at unseen points; here we evaluate it at every vertex of the mesh. Evaluating the predictive posterior returns a multivariate Gaussian distribution, from which we can compute the posterior mean and variance as follows."
   ]
  },
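  {
   "cell_type": "markdown",
   "id": "e15f6a93",
   "metadata": {},
   "source": [
    "As a sketch of what this distribution looks like, assume a zero mean function and Gaussian observation noise with variance $\\sigma^2$. Writing $K = k(X, X)$, $K_{\\ast} = k(X_{\\ast}, X)$ and $K_{\\ast\\ast} = k(X_{\\ast}, X_{\\ast})$ for training inputs $X$ and test inputs $X_{\\ast}$ (our own notation, not GPJax identifiers), the standard predictive equations for the latent function are:\n",
    "\n",
    "$$\n",
    "\\boldsymbol{\\mu}_{\\ast} = K_{\\ast} (K + \\sigma^2 I)^{-1} \\mathbf{y}, \\qquad \\Sigma_{\\ast} = K_{\\ast\\ast} - K_{\\ast} (K + \\sigma^2 I)^{-1} K_{\\ast}^\\top.\n",
    "$$\n",
    "\n",
    "The posterior mean corresponds to $\\boldsymbol{\\mu}_{\\ast}$ and the posterior variance to the diagonal of $\\Sigma_{\\ast}$."
   ]
  },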
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9cad6772",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "predictive_posterior = posterior.predict(params, data)(X_test)\n",
    "\n",
    "mu = predictive_posterior.mean()\n",
    "sigma2 = predictive_posterior.variance()"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "custom_cell_magics": "kql",
   "encoding": "# -*- coding: utf-8 -*-"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "vscode": {
   "interpreter": {
    "hash": "ac4e3e560a819fea325e71fb1f32d4b3eb3f4b4768b78485241e42962cf1d521"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
