{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1fe536ed",
   "metadata": {},
   "source": [
    "# Computing second-order sensitivities"
   ]
  },
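  {
   "cell_type": "markdown",
   "id": "598ab169-05d8-4733-a6cc-9fa91aa92198",
   "metadata": {},
   "source": [
    "This example demonstrates how to compute the Hessian of a differential equation solve: here, the second derivatives of the final prey population with respect to the initial condition `y0`.\n",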
    "\n",
    "This example is available as a Jupyter notebook [here](https://github.com/patrick-kidger/diffrax/blob/main/examples/hessian.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6d6bdf63",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((Array(3.9131193, dtype=float32, weak_type=True),\n",
       "  Array(-2.374867, dtype=float32, weak_type=True)),\n",
       " (Array(-2.3748531, dtype=float32, weak_type=True),\n",
       "  Array(1.688472, dtype=float32, weak_type=True)))"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from diffrax import diffeqsolve, ODETerm, Tsit5\n",
    "\n",
    "\n",
    "# Lotka-Volterra predator-prey dynamics.\n",
    "def vector_field(t, y, args):\n",
    "    prey, predator = y\n",
    "    α, β, γ, δ = args\n",
    "    d_prey = α * prey - β * prey * predator\n",
    "    d_predator = -γ * predator + δ * prey * predator\n",
    "    d_y = d_prey, d_predator\n",
    "    return d_y\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "@jax.hessian\n",
    "def run(y0):\n",
    "    term = ODETerm(vector_field)\n",
    "    # `scan_kind=\"bounded\"` keeps the solver compatible with higher-order autodiff.\n",
    "    solver = Tsit5(scan_kind=\"bounded\")\n",
    "    t0 = 0\n",
    "    t1 = 140\n",
    "    dt0 = 0.1\n",
    "    args = (0.1, 0.02, 0.4, 0.02)\n",
    "    sol = diffeqsolve(term, solver, t0, t1, dt0, y0, args=args)\n",
    "    # Only the final time `t1` is saved by default, so each leaf of `sol.ys` has length 1.\n",
    "    ((prey,), _) = sol.ys\n",
    "    return prey\n",
    "\n",
    "\n",
    "y0 = (jnp.array(10.0), jnp.array(10.0))\n",
    "run(y0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3ec6532-5b0a-4e4c-af33-bef58c0a7319",
   "metadata": {},
   "source": [
    "Note the use of the `scan_kind` argument to `Tsit5`. By default, Diffrax internally uses constructs that are optimised specifically for first-order reverse-mode automatic differentiation. Passing `scan_kind=\"bounded\"` switches to a different implementation that is compatible with higher-order autodiff. (In this case: for the loop over stages in the Runge-Kutta solver.)\n",
    "\n",
    "Similarly, if using `saveat=SaveAt(ts=...)` (or a handful of other esoteric cases) then you will also need to pass `adjoint=DirectAdjoint()` to `diffeqsolve`. (In this case: for the loop that saves output.)"
   ]
  }
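   ,
   {
    "cell_type": "markdown",
    "id": "hessian-saveat-sketch",
    "metadata": {},
    "source": [
     "As a rough illustration of the `SaveAt(ts=...)` case, the sketch below reuses the same Lotka-Volterra setup and adds `adjoint=DirectAdjoint()`. The function name `run_saveat`, the five save times, and the choice to sum the saved prey values into a scalar are assumptions made for this sketch rather than part of the example above, and it has not been run here, so no output is shown.\n",
     "\n",
     "```python\n",
     "import jax\n",
     "import jax.numpy as jnp\n",
     "from diffrax import diffeqsolve, DirectAdjoint, ODETerm, SaveAt, Tsit5\n",
     "\n",
     "\n",
     "def vector_field(t, y, args):\n",
     "    prey, predator = y\n",
     "    α, β, γ, δ = args\n",
     "    d_prey = α * prey - β * prey * predator\n",
     "    d_predator = -γ * predator + δ * prey * predator\n",
     "    return d_prey, d_predator\n",
     "\n",
     "\n",
     "@jax.jit\n",
     "@jax.hessian\n",
     "def run_saveat(y0):  # hypothetical name, for illustration only\n",
     "    term = ODETerm(vector_field)\n",
     "    solver = Tsit5(scan_kind=\"bounded\")\n",
     "    # Assumed save times: five evenly spaced points between t0 and t1.\n",
     "    saveat = SaveAt(ts=jnp.linspace(0.0, 140.0, 5))\n",
     "    args = (0.1, 0.02, 0.4, 0.02)\n",
     "    sol = diffeqsolve(\n",
     "        term,\n",
     "        solver,\n",
     "        0,\n",
     "        140,\n",
     "        0.1,\n",
     "        y0,\n",
     "        args=args,\n",
     "        saveat=saveat,\n",
     "        adjoint=DirectAdjoint(),\n",
     "    )\n",
     "    prey, _ = sol.ys  # each leaf now has one entry per save time\n",
     "    # `jax.hessian` is applied to a scalar here, so reduce over the save times.\n",
     "    return jnp.sum(prey)\n",
     "\n",
     "\n",
     "y0 = (jnp.array(10.0), jnp.array(10.0))\n",
     "hess = run_saveat(y0)\n",
     "```\n",
     "\n",
     "Because `y0` is a tuple of two scalars, `hess` should have the same nested-tuple layout as the output above: a tuple of two tuples, each holding two scalar arrays."
    ]
   }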
 ],
 "metadata": {
  "jupytext": {
   "formats": "ipynb,py:light"
  },
  "kernelspec": {
   "display_name": "py39",
   "language": "python",
   "name": "py39"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
