{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1fe536ed",
   "metadata": {},
   "source": [
    "# Stiff ODE"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "324996ff",
   "metadata": {},
   "source": [
    "This example demonstrates the use of implicit integrators to handle stiff dynamical systems. In this case we consider the Robertson problem, a standard stiff test problem from chemical kinetics.\n",
    "\n",
    "This example is available as a Jupyter notebook [here](https://github.com/patrick-kidger/diffrax/blob/main/examples/stiff_ode.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6d6bdf63",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import diffrax\n",
    "import equinox as eqx  # https://github.com/patrick-kidger/equinox\n",
    "import jax\n",
    "import jax.numpy as jnp"
   ]
  },
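  {
   "cell_type": "markdown",
   "id": "c389c1dc",
   "metadata": {},
   "source": [
    "Using 64-bit precision is important when solving problems with tolerances of `1e-8` (or smaller): 32-bit floats carry only about 7 significant decimal digits, so such tolerances cannot be met in single precision. JAX defaults to 32-bit, so 64-bit mode must be enabled explicitly."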
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "327c1eda",
   "metadata": {},
   "outputs": [],
   "source": [
    "jax.config.update(\"jax_enable_x64\", True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "62c84f51",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Robertson(eqx.Module):\n",
    "    k1: float\n",
    "    k2: float\n",
    "    k3: float\n",
    "\n",
    "    def __call__(self, t, y, args):\n",
    "        f0 = -self.k1 * y[0] + self.k3 * y[1] * y[2]\n",
    "        f1 = self.k1 * y[0] - self.k2 * y[1] ** 2 - self.k3 * y[1] * y[2]\n",
    "        f2 = self.k2 * y[1] ** 2\n",
    "        return jnp.stack([f0, f1, f2])"
   ]
  },
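  {
   "cell_type": "markdown",
   "id": "5c0a9e3b",
   "metadata": {},
   "source": [
    "For reference, the vector field implemented by `Robertson` above can be written out as follows, where $y_0, y_1, y_2$ denote the components of `y`. This is transcribed directly from the code, not taken from a separate reference:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\frac{\\mathrm{d}y_0}{\\mathrm{d}t} &= -k_1 y_0 + k_3 y_1 y_2 \\\\\n",
    "\\frac{\\mathrm{d}y_1}{\\mathrm{d}t} &= k_1 y_0 - k_2 y_1^2 - k_3 y_1 y_2 \\\\\n",
    "\\frac{\\mathrm{d}y_2}{\\mathrm{d}t} &= k_2 y_1^2\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "The three right-hand sides sum to zero, so $y_0 + y_1 + y_2$ is conserved along the solution, which gives a quick sanity check on the results. With the rate constants used later in this example ($k_1 = 0.04$, $k_2 = 3 \\times 10^7$, $k_3 = 10^4$), the coefficients differ by many orders of magnitude, which is what makes the system stiff."
   ]
  },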
  {
   "cell_type": "markdown",
   "id": "aa95b544",
   "metadata": {},
   "source": [
    "One should almost always use adaptive step sizes when using implicit integrators, so that the step size can be reduced if the nonlinear solve inside each implicit step fails to converge.\n",
    "\n",
    "Note that the solver takes a `root_finder` argument, e.g. `Kvaerno5(root_finder=VeryChord())`. If you want to optimise performance then you can try adjusting the error tolerances, kappa value, and maximum number of steps for the nonlinear solver."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e15519bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def main(k1, k2, k3):\n",
    "    robertson = Robertson(k1, k2, k3)\n",
    "    terms = diffrax.ODETerm(robertson)\n",
    "    t0 = 0.0\n",
    "    t1 = 100.0\n",
    "    y0 = jnp.array([1.0, 0.0, 0.0])\n",
    "    dt0 = 0.0002\n",
    "    solver = diffrax.Kvaerno5()\n",
    "    saveat = diffrax.SaveAt(ts=jnp.array([0.0, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2]))\n",
    "    stepsize_controller = diffrax.PIDController(rtol=1e-8, atol=1e-8)\n",
    "    sol = diffrax.diffeqsolve(\n",
    "        terms,\n",
    "        solver,\n",
    "        t0,\n",
    "        t1,\n",
    "        dt0,\n",
    "        y0,\n",
    "        saveat=saveat,\n",
    "        stepsize_controller=stepsize_controller,\n",
    "    )\n",
    "    return sol"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ac4ccb4",
   "metadata": {},
   "source": [
    "Call `main` once so that everything is JIT-compiled, then time a second call."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7c13ba1c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Results:\n",
      "t=0.0, y=[1.0, 0.0, 0.0]\n",
      "t=0.0001, y=[0.9999960000079533, 3.983706533880144e-06, 1.6285512749146082e-08]\n",
      "t=0.001, y=[0.9999600015663405, 2.9190972128800855e-05, 1.0807461530666215e-05]\n",
      "t=0.01, y=[0.9996006826927243, 3.6450525072358624e-05, 0.0003628667822034639]\n",
      "t=0.1, y=[0.9960777367939997, 3.5804375234045534e-05, 0.003886458830766306]\n",
      "t=1.0, y=[0.9664587937089412, 3.0745881216629206e-05, 0.03351046040984223]\n",
      "t=10.0, y=[0.8413689033339126, 1.623370548364245e-05, 0.15861486296060393]\n",
      "t=100.0, y=[0.6172348816081134, 6.153591196187991e-06, 0.382758964800691]\n",
      "Took 34 steps in 0.0003752708435058594 seconds.\n"
     ]
    }
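   ],
   "source": [
    "main(0.04, 3e7, 1e4)  # warm-up call: triggers JIT compilation\n",
    "\n",
    "start = time.time()\n",
    "sol = main(0.04, 3e7, 1e4)\n",
    "sol.ys.block_until_ready()  # JAX dispatches asynchronously; wait for the result before stopping the clock\n",
    "end = time.time()\n",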
    "\n",
    "print(\"Results:\")\n",
    "for ti, yi in zip(sol.ts, sol.ys):\n",
    "    print(f\"t={ti.item()}, y={yi.tolist()}\")\n",
    "print(f\"Took {sol.stats['num_steps']} steps in {end - start} seconds.\")"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "formats": "ipynb,py:light"
  },
  "kernelspec": {
   "display_name": "jax030",
   "language": "python",
   "name": "jax030"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
