{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9877bdcb",
   "metadata": {},
   "source": [
    "# Steady states"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a16df75",
   "metadata": {},
   "source": [
    "This example demonstrates how to use Diffrax to solve an ODE until it reaches a steady state. The key feature will be the use of event handling to detect that the steady state has been reached.\n",
    "\n",
    "In addition, for this example we need to backpropagate through the procedure of finding a steady state. We can do this efficiently using the implicit function theorem.\n",
    "\n",
    "This example is available as a Jupyter notebook [here](https://github.com/patrick-kidger/diffrax/blob/main/examples/steady_state.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7053a132",
   "metadata": {
    "execute": {
     "shell": {
      "execute_reply": "2022-07-15T17:49:27.190533+00:00"
     }
    },
    "iopub": {
     "execute_input": "2022-07-15T17:46:56.174218+00:00",
     "status": {
      "busy": "2022-07-15T17:46:56.173283+00:00",
      "idle": "2022-07-15T17:49:27.191890+00:00"
     }
    }
   },
   "outputs": [],
   "source": [
    "import diffrax\n",
    "import equinox as eqx  # https://github.com/patrick-kidger/equinox\n",
    "import jax.numpy as jnp\n",
    "import optax  # https://github.com/deepmind/optax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5a39737f",
   "metadata": {
    "execute": {
     "shell": {
      "execute_reply": "2022-07-15T17:49:27.200260+00:00"
     }
    },
    "iopub": {
     "execute_input": "2022-07-15T17:49:27.194682+00:00",
     "status": {
      "busy": "2022-07-15T17:49:27.194211+00:00",
      "idle": "2022-07-15T17:49:27.201694+00:00"
     }
    }
   },
   "outputs": [],
   "source": [
    "class ExponentialDecayToSteadyState(eqx.Module):\n",
    "    steady_state: float\n",
    "\n",
    "    def __call__(self, t, y, args):\n",
    "        return self.steady_state - y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "525b5430",
   "metadata": {
    "execute": {
     "shell": {
      "execute_reply": "2022-07-15T17:49:27.210168+00:00"
     }
    },
    "iopub": {
     "execute_input": "2022-07-15T17:49:27.203780+00:00",
     "status": {
      "busy": "2022-07-15T17:49:27.202937+00:00",
      "idle": "2022-07-15T17:49:27.211528+00:00"
     }
    }
   },
   "outputs": [],
   "source": [
    "def loss(model, target_steady_state):\n",
    "    term = diffrax.ODETerm(model)\n",
    "    solver = diffrax.Tsit5()\n",
    "    t0 = 0\n",
    "    t1 = jnp.inf\n",
    "    dt0 = None\n",
    "    y0 = 1.0\n",
    "    max_steps = None\n",
    "    controller = diffrax.PIDController(rtol=1e-3, atol=1e-6)\n",
    "    event = diffrax.SteadyStateEvent()\n",
    "    adjoint = diffrax.ImplicitAdjoint()\n",
    "    # This combination of event, t1, max_steps, adjoint is particularly\n",
    "    # natural: we keep integration forever until we hit the event, with\n",
    "    # no maximum time or number of steps. Backpropagation happens via\n",
    "    # the implicit function theorem.\n",
    "    sol = diffrax.diffeqsolve(\n",
    "        term,\n",
    "        solver,\n",
    "        t0,\n",
    "        t1,\n",
    "        dt0,\n",
    "        y0,\n",
    "        max_steps=max_steps,\n",
    "        stepsize_controller=controller,\n",
    "        discrete_terminating_event=event,\n",
    "        adjoint=adjoint,\n",
    "    )\n",
    "    (y1,) = sol.ys\n",
    "    return (y1 - target_steady_state) ** 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ec634466",
   "metadata": {
    "execute": {
     "shell": {
      "execute_reply": "2022-07-15T17:49:28.926507+00:00"
     }
    },
    "iopub": {
     "execute_input": "2022-07-15T17:49:27.214972+00:00",
     "status": {
      "busy": "2022-07-15T17:49:27.214240+00:00",
      "idle": "2022-07-15T17:49:28.927742+00:00"
     }
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step: 0 Steady State: 0.025839969515800476\n",
      "Step: 1 Steady State: 0.058249037712812424\n",
      "Step: 2 Steady State: 0.09451574087142944\n",
      "Step: 3 Steady State: 0.13270404934883118\n",
      "Step: 4 Steady State: 0.17144456505775452\n",
      "Step: 5 Steady State: 0.2097906768321991\n",
      "Step: 6 Steady State: 0.24709917604923248\n",
      "Step: 7 Steady State: 0.28294336795806885\n",
      "Step: 8 Steady State: 0.3170691728591919\n",
      "Step: 9 Steady State: 0.34933507442474365\n",
      "Step: 10 Steady State: 0.37968066334724426\n",
      "Step: 11 Steady State: 0.4081019163131714\n",
      "Step: 12 Steady State: 0.43463483452796936\n",
      "Step: 13 Steady State: 0.45934173464775085\n",
      "Step: 14 Steady State: 0.4823019802570343\n",
      "Step: 15 Steady State: 0.5035936236381531\n",
      "Step: 16 Steady State: 0.5233209133148193\n",
      "Step: 17 Steady State: 0.5415788888931274\n",
      "Step: 18 Steady State: 0.5584676265716553\n",
      "Step: 19 Steady State: 0.5740787982940674\n",
      "Step: 20 Steady State: 0.5885017514228821\n",
      "Step: 21 Steady State: 0.6018210053443909\n",
      "Step: 22 Steady State: 0.6141175627708435\n",
      "Step: 23 Steady State: 0.6254667043685913\n",
      "Step: 24 Steady State: 0.6359376907348633\n",
      "Step: 25 Steady State: 0.6455990076065063\n",
      "Step: 26 Steady State: 0.6545112729072571\n",
      "Step: 27 Steady State: 0.6627309322357178\n",
      "Step: 28 Steady State: 0.6703115701675415\n",
      "Step: 29 Steady State: 0.6773026585578918\n",
      "Step: 30 Steady State: 0.6837494373321533\n",
      "Step: 31 Steady State: 0.6896938681602478\n",
      "Step: 32 Steady State: 0.6951748728752136\n",
      "Step: 33 Steady State: 0.7002284526824951\n",
      "Step: 34 Steady State: 0.7048872113227844\n",
      "Step: 35 Steady State: 0.7091819047927856\n",
      "Step: 36 Steady State: 0.7131412029266357\n",
      "Step: 37 Steady State: 0.7167739868164062\n",
      "Step: 38 Steady State: 0.7201183438301086\n",
      "Step: 39 Steady State: 0.7231980562210083\n",
      "Step: 40 Steady State: 0.7260348796844482\n",
      "Step: 41 Steady State: 0.7286462187767029\n",
      "Step: 42 Steady State: 0.7310511469841003\n",
      "Step: 43 Steady State: 0.733269989490509\n",
      "Step: 44 Steady State: 0.7353137731552124\n",
      "Step: 45 Steady State: 0.7371994853019714\n",
      "Step: 46 Steady State: 0.7389383912086487\n",
      "Step: 47 Steady State: 0.740541934967041\n",
      "Step: 48 Steady State: 0.7420334219932556\n",
      "Step: 49 Steady State: 0.7434003353118896\n",
      "Step: 50 Steady State: 0.7446598410606384\n",
      "Step: 51 Steady State: 0.7458205819129944\n",
      "Step: 52 Steady State: 0.7468900680541992\n",
      "Step: 53 Steady State: 0.7478761672973633\n",
      "Step: 54 Steady State: 0.7487852573394775\n",
      "Step: 55 Steady State: 0.7496234178543091\n",
      "Step: 56 Steady State: 0.750394344329834\n",
      "Step: 57 Steady State: 0.7511063814163208\n",
      "Step: 58 Steady State: 0.751763105392456\n",
      "Step: 59 Steady State: 0.7523672580718994\n",
      "Step: 60 Steady State: 0.7529228329658508\n",
      "Step: 61 Steady State: 0.753433346748352\n",
      "Step: 62 Steady State: 0.7539049983024597\n",
      "Step: 63 Steady State: 0.7543382048606873\n",
      "Step: 64 Steady State: 0.7547407746315002\n",
      "Step: 65 Steady State: 0.7551127672195435\n",
      "Step: 66 Steady State: 0.7554563879966736\n",
      "Step: 67 Steady State: 0.7557693123817444\n",
      "Step: 68 Steady State: 0.7560611367225647\n",
      "Step: 69 Steady State: 0.7563308477401733\n",
      "Step: 70 Steady State: 0.7565800547599792\n",
      "Step: 71 Steady State: 0.756810188293457\n",
      "Step: 72 Steady State: 0.7570226788520813\n",
      "Step: 73 Steady State: 0.7572163343429565\n",
      "Step: 74 Steady State: 0.7573966979980469\n",
      "Step: 75 Steady State: 0.7575633525848389\n",
      "Step: 76 Steady State: 0.7577127814292908\n",
      "Step: 77 Steady State: 0.7578537464141846\n",
      "Step: 78 Steady State: 0.7579842805862427\n",
      "Step: 79 Steady State: 0.7581048607826233\n",
      "Step: 80 Steady State: 0.7582123279571533\n",
      "Step: 81 Steady State: 0.7583134770393372\n",
      "Step: 82 Steady State: 0.7584078907966614\n",
      "Step: 83 Steady State: 0.7584953904151917\n",
      "Step: 84 Steady State: 0.758575975894928\n",
      "Step: 85 Steady State: 0.7586501836776733\n",
      "Step: 86 Steady State: 0.7587193250656128\n",
      "Step: 87 Steady State: 0.7587832808494568\n",
      "Step: 88 Steady State: 0.7588424682617188\n",
      "Step: 89 Steady State: 0.7588958144187927\n",
      "Step: 90 Steady State: 0.7589460015296936\n",
      "Step: 91 Steady State: 0.7589924931526184\n",
      "Step: 92 Steady State: 0.7590354681015015\n",
      "Step: 93 Steady State: 0.7590752243995667\n",
      "Step: 94 Steady State: 0.7591111063957214\n",
      "Step: 95 Steady State: 0.7591448426246643\n",
      "Step: 96 Steady State: 0.7591760754585266\n",
      "Step: 97 Steady State: 0.7592049241065979\n",
      "Step: 98 Steady State: 0.7592315673828125\n",
      "Step: 99 Steady State: 0.7592562437057495\n",
      "Target: 0.7599999904632568\n"
     ]
    }
   ],
   "source": [
    "model = ExponentialDecayToSteadyState(\n",
    "    jnp.array(0.0)\n",
    ")  # initial steady state guess is 0.\n",
    "# target steady state is 0.76\n",
    "target_steady_state = jnp.array(0.76)\n",
    "optim = optax.sgd(1e-2, momentum=0.7, nesterov=True)\n",
    "opt_state = optim.init(model)\n",
    "\n",
    "\n",
    "@eqx.filter_jit\n",
    "def make_step(model, opt_state, target_steady_state):\n",
    "    grads = eqx.filter_grad(loss)(model, target_steady_state)\n",
    "    updates, opt_state = optim.update(grads, opt_state)\n",
    "    model = eqx.apply_updates(model, updates)\n",
    "    return model, opt_state\n",
    "\n",
    "\n",
    "for step in range(100):\n",
    "    model, opt_state = make_step(model, opt_state, target_steady_state)\n",
    "    print(f\"Step: {step} Steady State: {model.steady_state}\")\n",
    "print(f\"Target: {target_steady_state}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py37",
   "language": "python",
   "name": "py37"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
