{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a321084a0fdec304",
   "metadata": {
    "collapsed": false,
    "id": "a321084a0fdec304"
   },
   "source": [
    "# <h1><center>Defining Nodes in **rex** (Robotic Environments with jaX)  <a href=\"http://colab.research.google.com/github/anonymous/rex/blob/master/examples/sim2real.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" width=\"140\" align=\"center\"/></a></center></h1>\n",
    "\n",
    "This notebook offers an introductory tutorial for **rex (Robotic Environments with jaX)**, a **JAX-based framework** for creating **graph-based environments** tailored for **sim2real robotics**.\n",
    "\n",
    "In this tutorial, we will guide you through the process of defining **nodes**, which are the **fundamental building blocks** for constructing **graph-based simulations** and **real-world systems** within rex. Specifically, we will demonstrate how to define the nodes used in the **sim2real.ipynb** notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b5eb9f9c80904dd3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-08T10:19:04.528714504Z",
     "start_time": "2024-10-08T10:19:02.944387450Z"
    },
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "collapsed": true,
    "id": "b5eb9f9c80904dd3",
    "outputId": "0abd47b5-b897-41bd-991b-7725fd621d8e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Installing rex via `pip install rex-lib[examples]`. If you are running this in a Colab notebook, you can ignore this message.\n",
      "Collecting rex-lib[examples]\n",
      "  Downloading rex_lib-0.0.5-py3-none-any.whl.metadata (15 kB)\n",
      "Collecting dill>=0.3.8 (from rex-lib[examples])\n",
      "  Downloading dill-0.3.9-py3-none-any.whl.metadata (10 kB)\n",
      "Collecting distrax>=0.1.5 (from rex-lib[examples])\n",
      "  Downloading distrax-0.1.5-py3-none-any.whl.metadata (13 kB)\n",
      "Collecting equinox>=0.11.4 (from rex-lib[examples])\n",
      "  Downloading equinox-0.11.7-py3-none-any.whl.metadata (18 kB)\n",
      "Collecting evosax>=0.1.6 (from rex-lib[examples])\n",
      "  Downloading evosax-0.1.6-py3-none-any.whl.metadata (26 kB)\n",
      "Requirement already satisfied: flax>=0.8.5 in /usr/local/lib/python3.10/dist-packages (from rex-lib[examples]) (0.8.5)\n",
      "Collecting gymnasium>=0.29.1 (from rex-lib[examples])\n",
      "  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)\n",
      "Requirement already satisfied: jax>=0.4.30 in /usr/local/lib/python3.10/dist-packages (from rex-lib[examples]) (0.4.33)\n",
      "Requirement already satisfied: matplotlib>=3.7.0 in /usr/local/lib/python3.10/dist-packages (from rex-lib[examples]) (3.7.1)\n",
      "Requirement already satisfied: networkx>=3.2.1 in /usr/local/lib/python3.10/dist-packages (from rex-lib[examples]) (3.3)\n",
      "Requirement already satisfied: optax>=0.2.3 in /usr/local/lib/python3.10/dist-packages (from rex-lib[examples]) (0.2.3)\n",
      "Collecting seaborn>=0.13.2 (from rex-lib[examples])\n",
      "  Downloading seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)\n",
      "Collecting supergraph>=0.0.8 (from rex-lib[examples])\n",
      "  Downloading supergraph-0.0.8-py3-none-any.whl.metadata (1.2 kB)\n",
      "Requirement already satisfied: termcolor>=2.4.0 in /usr/local/lib/python3.10/dist-packages (from rex-lib[examples]) (2.4.0)\n",
      "Requirement already satisfied: tqdm>=4.66.4 in /usr/local/lib/python3.10/dist-packages (from rex-lib[examples]) (4.66.5)\n",
      "Collecting brax>=0.10.5 (from rex-lib[examples])\n",
      "  Downloading brax-0.11.0-py3-none-any.whl.metadata (7.7 kB)\n",
      "Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from brax>=0.10.5->rex-lib[examples]) (1.4.0)\n",
      "Collecting dm-env (from brax>=0.10.5->rex-lib[examples])\n",
      "  Downloading dm_env-1.6-py3-none-any.whl.metadata (966 bytes)\n",
      "Requirement already satisfied: etils in /usr/local/lib/python3.10/dist-packages (from brax>=0.10.5->rex-lib[examples]) (1.9.4)\n",
      "Requirement already satisfied: flask in /usr/local/lib/python3.10/dist-packages (from brax>=0.10.5->rex-lib[examples]) (2.2.5)\n",
      "Collecting flask-cors (from brax>=0.10.5->rex-lib[examples])\n",
      "  Downloading Flask_Cors-5.0.0-py2.py3-none-any.whl.metadata (5.5 kB)\n",
      "Requirement already satisfied: grpcio in /usr/local/lib/python3.10/dist-packages (from brax>=0.10.5->rex-lib[examples]) (1.64.1)\n",
      "Requirement already satisfied: gym in /usr/local/lib/python3.10/dist-packages (from brax>=0.10.5->rex-lib[examples]) (0.25.2)\n",
      "Requirement already satisfied: jaxlib>=0.4.6 in /usr/local/lib/python3.10/dist-packages (from brax>=0.10.5->rex-lib[examples]) (0.4.33)\n",
      "Collecting jaxopt (from brax>=0.10.5->rex-lib[examples])\n",
      "  Downloading jaxopt-0.8.3-py3-none-any.whl.metadata (2.6 kB)\n",
      "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from brax>=0.10.5->rex-lib[examples]) (3.1.4)\n",
      "Collecting ml-collections (from brax>=0.10.5->rex-lib[examples])\n",
      "  Downloading ml_collections-0.1.1.tar.gz (77 kB)\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m77.9/77.9 kB\u001B[0m \u001B[31m1.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25h  Preparing metadata (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "Collecting mujoco (from brax>=0.10.5->rex-lib[examples])\n",
      "  Downloading mujoco-3.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (44 kB)\n",
      "\u001B[2K     \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m44.4/44.4 kB\u001B[0m \u001B[31m1.2 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hCollecting mujoco-mjx (from brax>=0.10.5->rex-lib[examples])\n",
      "  Downloading mujoco_mjx-3.2.3-py3-none-any.whl.metadata (3.4 kB)\n",
      "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from brax>=0.10.5->rex-lib[examples]) (1.26.4)\n",
      "Requirement already satisfied: orbax-checkpoint in /usr/local/lib/python3.10/dist-packages (from brax>=0.10.5->rex-lib[examples]) (0.6.4)\n",
      "Requirement already satisfied: Pillow in /usr/local/lib/python3.10/dist-packages (from brax>=0.10.5->rex-lib[examples]) (10.4.0)\n",
      "Collecting pytinyrenderer (from brax>=0.10.5->rex-lib[examples])\n",
      "  Downloading pytinyrenderer-0.0.14-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.2 kB)\n",
      "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from brax>=0.10.5->rex-lib[examples]) (1.13.1)\n",
      "Collecting tensorboardX (from brax>=0.10.5->rex-lib[examples])\n",
      "  Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl.metadata (5.8 kB)\n",
      "Collecting trimesh (from brax>=0.10.5->rex-lib[examples])\n",
      "  Downloading trimesh-4.4.9-py3-none-any.whl.metadata (18 kB)\n",
      "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from brax>=0.10.5->rex-lib[examples]) (4.12.2)\n",
      "Requirement already satisfied: chex>=0.1.8 in /usr/local/lib/python3.10/dist-packages (from distrax>=0.1.5->rex-lib[examples]) (0.1.87)\n",
      "Requirement already satisfied: tensorflow-probability>=0.15.0 in /usr/local/lib/python3.10/dist-packages (from distrax>=0.1.5->rex-lib[examples]) (0.24.0)\n",
      "Collecting jaxtyping>=0.2.20 (from equinox>=0.11.4->rex-lib[examples])\n",
      "  Downloading jaxtyping-0.2.34-py3-none-any.whl.metadata (6.4 kB)\n",
      "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from evosax>=0.1.6->rex-lib[examples]) (6.0.2)\n",
      "Collecting dotmap (from evosax>=0.1.6->rex-lib[examples])\n",
      "  Downloading dotmap-1.3.30-py3-none-any.whl.metadata (3.2 kB)\n",
      "Requirement already satisfied: msgpack in /usr/local/lib/python3.10/dist-packages (from flax>=0.8.5->rex-lib[examples]) (1.0.8)\n",
      "Requirement already satisfied: tensorstore in /usr/local/lib/python3.10/dist-packages (from flax>=0.8.5->rex-lib[examples]) (0.1.66)\n",
      "Requirement already satisfied: rich>=11.1 in /usr/local/lib/python3.10/dist-packages (from flax>=0.8.5->rex-lib[examples]) (13.9.1)\n",
      "Requirement already satisfied: cloudpickle>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from gymnasium>=0.29.1->rex-lib[examples]) (2.2.1)\n",
      "Collecting farama-notifications>=0.0.1 (from gymnasium>=0.29.1->rex-lib[examples])\n",
      "  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)\n",
      "Requirement already satisfied: ml-dtypes>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from jax>=0.4.30->rex-lib[examples]) (0.4.1)\n",
      "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax>=0.4.30->rex-lib[examples]) (3.4.0)\n",
      "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.7.0->rex-lib[examples]) (1.3.0)\n",
      "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.7.0->rex-lib[examples]) (0.12.1)\n",
      "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.7.0->rex-lib[examples]) (4.54.1)\n",
      "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.7.0->rex-lib[examples]) (1.4.7)\n",
      "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.7.0->rex-lib[examples]) (24.1)\n",
      "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.7.0->rex-lib[examples]) (3.1.4)\n",
      "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.7.0->rex-lib[examples]) (2.8.2)\n",
      "Requirement already satisfied: pandas>=1.2 in /usr/local/lib/python3.10/dist-packages (from seaborn>=0.13.2->rex-lib[examples]) (2.2.2)\n",
      "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.10/dist-packages (from chex>=0.1.8->distrax>=0.1.5->rex-lib[examples]) (0.12.1)\n",
      "Collecting typeguard==2.13.3 (from jaxtyping>=0.2.20->equinox>=0.11.4->rex-lib[examples])\n",
      "  Downloading typeguard-2.13.3-py3-none-any.whl.metadata (3.6 kB)\n",
      "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.2->seaborn>=0.13.2->rex-lib[examples]) (2024.2)\n",
      "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.2->seaborn>=0.13.2->rex-lib[examples]) (2024.2)\n",
      "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7->matplotlib>=3.7.0->rex-lib[examples]) (1.16.0)\n",
      "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1->flax>=0.8.5->rex-lib[examples]) (3.0.0)\n",
      "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1->flax>=0.8.5->rex-lib[examples]) (2.18.0)\n",
      "Requirement already satisfied: decorator in /usr/local/lib/python3.10/dist-packages (from tensorflow-probability>=0.15.0->distrax>=0.1.5->rex-lib[examples]) (4.4.2)\n",
      "Requirement already satisfied: gast>=0.3.2 in /usr/local/lib/python3.10/dist-packages (from tensorflow-probability>=0.15.0->distrax>=0.1.5->rex-lib[examples]) (0.6.0)\n",
      "Requirement already satisfied: dm-tree in /usr/local/lib/python3.10/dist-packages (from tensorflow-probability>=0.15.0->distrax>=0.1.5->rex-lib[examples]) (0.1.8)\n",
      "Requirement already satisfied: Werkzeug>=2.2.2 in /usr/local/lib/python3.10/dist-packages (from flask->brax>=0.10.5->rex-lib[examples]) (3.0.4)\n",
      "Requirement already satisfied: itsdangerous>=2.0 in /usr/local/lib/python3.10/dist-packages (from flask->brax>=0.10.5->rex-lib[examples]) (2.2.0)\n",
      "Requirement already satisfied: click>=8.0 in /usr/local/lib/python3.10/dist-packages (from flask->brax>=0.10.5->rex-lib[examples]) (8.1.7)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->brax>=0.10.5->rex-lib[examples]) (2.1.5)\n",
      "Requirement already satisfied: gym-notices>=0.0.4 in /usr/local/lib/python3.10/dist-packages (from gym->brax>=0.10.5->rex-lib[examples]) (0.0.8)\n",
      "Requirement already satisfied: contextlib2 in /usr/local/lib/python3.10/dist-packages (from ml-collections->brax>=0.10.5->rex-lib[examples]) (21.6.0)\n",
      "Collecting glfw (from mujoco->brax>=0.10.5->rex-lib[examples])\n",
      "  Downloading glfw-2.7.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-manylinux2014_x86_64.whl.metadata (5.4 kB)\n",
      "Requirement already satisfied: pyopengl in /usr/local/lib/python3.10/dist-packages (from mujoco->brax>=0.10.5->rex-lib[examples]) (3.1.7)\n",
      "Requirement already satisfied: nest_asyncio in /usr/local/lib/python3.10/dist-packages (from orbax-checkpoint->brax>=0.10.5->rex-lib[examples]) (1.6.0)\n",
      "Requirement already satisfied: protobuf in /usr/local/lib/python3.10/dist-packages (from orbax-checkpoint->brax>=0.10.5->rex-lib[examples]) (3.20.3)\n",
      "Requirement already satisfied: humanize in /usr/local/lib/python3.10/dist-packages (from orbax-checkpoint->brax>=0.10.5->rex-lib[examples]) (4.10.0)\n",
      "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich>=11.1->flax>=0.8.5->rex-lib[examples]) (0.1.2)\n",
      "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from etils->brax>=0.10.5->rex-lib[examples]) (2024.6.1)\n",
      "Requirement already satisfied: importlib_resources in /usr/local/lib/python3.10/dist-packages (from etils->brax>=0.10.5->rex-lib[examples]) (6.4.5)\n",
      "Requirement already satisfied: zipp in /usr/local/lib/python3.10/dist-packages (from etils->brax>=0.10.5->rex-lib[examples]) (3.20.2)\n",
      "Downloading brax-0.11.0-py3-none-any.whl (998 kB)\n",
      "\u001B[2K   \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m998.6/998.6 kB\u001B[0m \u001B[31m11.8 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hDownloading dill-0.3.9-py3-none-any.whl (119 kB)\n",
      "\u001B[2K   \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m119.4/119.4 kB\u001B[0m \u001B[31m5.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hDownloading distrax-0.1.5-py3-none-any.whl (319 kB)\n",
      "\u001B[2K   \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m319.7/319.7 kB\u001B[0m \u001B[31m9.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hDownloading equinox-0.11.7-py3-none-any.whl (178 kB)\n",
      "\u001B[2K   \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m178.4/178.4 kB\u001B[0m \u001B[31m7.5 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hDownloading evosax-0.1.6-py3-none-any.whl (240 kB)\n",
      "\u001B[2K   \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m240.4/240.4 kB\u001B[0m \u001B[31m8.5 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hDownloading gymnasium-1.0.0-py3-none-any.whl (958 kB)\n",
      "\u001B[2K   \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m958.1/958.1 kB\u001B[0m \u001B[31m13.3 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hDownloading seaborn-0.13.2-py3-none-any.whl (294 kB)\n",
      "\u001B[2K   \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m294.9/294.9 kB\u001B[0m \u001B[31m8.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hDownloading supergraph-0.0.8-py3-none-any.whl (65 kB)\n",
      "\u001B[2K   \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m65.5/65.5 kB\u001B[0m \u001B[31m2.5 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hDownloading rex_lib-0.0.5-py3-none-any.whl (115 kB)\n",
      "\u001B[2K   \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m115.1/115.1 kB\u001B[0m \u001B[31m4.9 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hDownloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)\n",
      "Downloading jaxtyping-0.2.34-py3-none-any.whl (42 kB)\n",
      "\u001B[2K   \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m42.4/42.4 kB\u001B[0m \u001B[31m1.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hDownloading typeguard-2.13.3-py3-none-any.whl (17 kB)\n",
      "Downloading dm_env-1.6-py3-none-any.whl (26 kB)\n",
      "Downloading dotmap-1.3.30-py3-none-any.whl (11 kB)\n",
      "Downloading Flask_Cors-5.0.0-py2.py3-none-any.whl (14 kB)\n",
      "Downloading jaxopt-0.8.3-py3-none-any.whl (172 kB)\n",
      "\u001B[2K   \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m172.3/172.3 kB\u001B[0m \u001B[31m5.2 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hDownloading mujoco-3.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.1 MB)\n",
      "\u001B[2K   \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m6.1/6.1 MB\u001B[0m \u001B[31m23.7 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hDownloading mujoco_mjx-3.2.3-py3-none-any.whl (6.7 MB)\n",
      "\u001B[2K   \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m6.7/6.7 MB\u001B[0m \u001B[31m12.5 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hDownloading pytinyrenderer-0.0.14-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.9 MB)\n",
      "\u001B[2K   \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m1.9/1.9 MB\u001B[0m \u001B[31m13.9 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hDownloading tensorboardX-2.6.2.2-py2.py3-none-any.whl (101 kB)\n",
      "\u001B[2K   \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m101.7/101.7 kB\u001B[0m \u001B[31m3.1 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hDownloading trimesh-4.4.9-py3-none-any.whl (700 kB)\n",
      "\u001B[2K   \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m700.1/700.1 kB\u001B[0m \u001B[31m20.4 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hDownloading glfw-2.7.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-manylinux2014_x86_64.whl (211 kB)\n",
      "\u001B[2K   \u001B[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001B[0m \u001B[32m211.8/211.8 kB\u001B[0m \u001B[31m7.5 MB/s\u001B[0m eta \u001B[36m0:00:00\u001B[0m\n",
      "\u001B[?25hBuilding wheels for collected packages: ml-collections\n",
      "  Building wheel for ml-collections (setup.py) ... \u001B[?25l\u001B[?25hdone\n",
      "  Created wheel for ml-collections: filename=ml_collections-0.1.1-py3-none-any.whl size=94507 sha256=8b83b1225aa4d52136d84206a5cb94da537f08a16dbd7b480fa90dd833c1cf78\n",
      "  Stored in directory: /root/.cache/pip/wheels/7b/89/c9/a9b87790789e94aadcfc393c283e3ecd5ab916aed0a31be8fe\n",
      "Successfully built ml-collections\n",
      "Installing collected packages: pytinyrenderer, glfw, farama-notifications, dotmap, typeguard, trimesh, tensorboardX, supergraph, ml-collections, gymnasium, dm-env, dill, jaxtyping, seaborn, mujoco, flask-cors, mujoco-mjx, jaxopt, equinox, distrax, evosax, brax, rex-lib\n",
      "  Attempting uninstall: typeguard\n",
      "    Found existing installation: typeguard 4.3.0\n",
      "    Uninstalling typeguard-4.3.0:\n",
      "      Successfully uninstalled typeguard-4.3.0\n",
      "  Attempting uninstall: seaborn\n",
      "    Found existing installation: seaborn 0.13.1\n",
      "    Uninstalling seaborn-0.13.1:\n",
      "      Successfully uninstalled seaborn-0.13.1\n",
      "\u001B[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
      "inflect 7.4.0 requires typeguard>=4.0.1, but you have typeguard 2.13.3 which is incompatible.\u001B[0m\u001B[31m\n",
      "\u001B[0mSuccessfully installed brax-0.11.0 dill-0.3.9 distrax-0.1.5 dm-env-1.6 dotmap-1.3.30 equinox-0.11.7 evosax-0.1.6 farama-notifications-0.0.4 flask-cors-5.0.0 glfw-2.7.0 gymnasium-1.0.0 jaxopt-0.8.3 jaxtyping-0.2.34 ml-collections-0.1.1 mujoco-3.2.3 mujoco-mjx-3.2.3 pytinyrenderer-0.0.14 rex-lib-0.0.5 seaborn-0.13.2 supergraph-0.0.8 tensorboardX-2.6.2.2 trimesh-4.4.9 typeguard-2.13.3\n"
     ]
    }
   ],
   "source": [
    "# @title Install Necessary Libraries\n",
    "# @markdown This cell installs the required libraries for the project.\n",
    "# @markdown If you are running this notebook in Google Colab, most libraries should already be installed.\n",
    "\n",
    "try:\n",
    "    import rex  # noqa: F401\n",
    "\n",
    "    print(\"Rex already installed\")\n",
    "except ImportError:\n",
    "    print(\n",
    "        \"Installing rex via `pip install rex-lib[examples]`. \"\n",
    "        \"If you are running this in a Colab notebook, you can ignore this message.\"\n",
    "    )\n",
    "    !pip install rex-lib[examples]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8081a83b3afe366",
   "metadata": {
    "collapsed": false,
    "id": "8081a83b3afe366"
   },
   "source": [
    "# Introduction to Nodes in Rex\n",
    "\n",
    "In **Rex**, a **node** represents a fundamental computational unit within a graph-based system. Nodes encapsulate specific functionality and interact by passing data through connections, forming a network that can model complex systems. This tutorial introduces how to define nodes, specify their properties like rates and delays, and manage their interactions within a graph.\n",
    "\n",
    "## Defining Nodes\n",
    "\n",
    "Nodes are defined by creating subclasses of the `BaseNode` class. This base class provides a standardized API and essential functionality that all nodes inherit. When defining a node, you can specify several parameters directly in the `__init__` method:\n",
    "\n",
    "- **`name`**: A unique identifier for the node.\n",
    "- **`rate`**: The frequency at which the node's `step` method is called (in Hz).\n",
    "- **`delay`** (optional): The expected computation delay of the node (in seconds).\n",
    "- **`delay_dist`**: A distribution representing variability in the node's computation delay, useful for simulations.\n",
    "- **`advance`**: If `True`, the node's `step` method triggers when all inputs are ready; if `False`, it throttles until the scheduled time.\n",
    "- **`scheduling`**: Determines how the node's execution is scheduled. Options include `Scheduling.FREQUENCY` and `Scheduling.PHASE`.\n",
    "- **`color`**: Used for visualization purposes.\n",
    "- **`order`**: Determines the node's order in visualizations.\n",
    "\n",
    "Here's a basic example of a node definition:\n",
    "\n",
    "```python\n",
    "class MyNode(BaseNode):\n",
    "    def __init__(\n",
    "        self,\n",
    "        name: str,\n",
    "        rate: float,\n",
    "        delay: float = None,  # Expected computation delay (used for phase-shifting)\n",
    "        delay_dist: Union[DelayDistribution, distrax.Distribution] = None,  # Sim. computation delay\n",
    "        advance: bool = False,\n",
    "        scheduling: Scheduling = Scheduling.FREQUENCY,\n",
    "        color: str = None,\n",
    "        order: int = None\n",
    "    ):\n",
    "        super().__init__(name, rate, delay, delay_dist, advance, scheduling, color, order)\n",
    "        # Additional initialization if needed\n",
    "\n",
    "    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None):\n",
    "        # Initialize parameters\n",
    "        return MyParams()\n",
    "\n",
    "    def init_state(self, rng: jax.Array = None, graph_state: GraphState = None):\n",
    "        # Initialize state\n",
    "        return MyState()\n",
    "\n",
    "    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None):\n",
    "        # Initialize default output\n",
    "        return MyOutput()\n",
    "\n",
    "    def step(self, step_state: StepState):\n",
    "        # Node's computation logic\n",
    "        new_state = ...\n",
    "        output = ...\n",
    "        return step_state.replace(state=new_state), output\n",
    "```\n",
    "\n",
    "## Connecting Nodes\n",
    "\n",
    "Nodes interact by passing outputs from one node to the inputs of another. This is achieved through the `connect` method, which establishes a connection between two nodes.\n",
    "\n",
    "### Connection API\n",
    "\n",
    "When connecting nodes, you can specify several parameters that control the nature of the connection:\n",
    "\n",
    "- **`output_node`**: node whose output will be connected as an input.\n",
    "- **`blocking`**: `True`, the receiving node waits for the input before proceeding. This can create dependencies between nodes.\n",
    "- **`delay`**: An additional delay introduced in the connection, which can control the phase shift between nodes.\n",
    "- **`delay_dist`**: Used in simulation to model communication delays between nodes.\n",
    "- **`window`**: Determines how many past messages are stored and accessible in the input buffer.\n",
    "- **`skip`**: If `True`, the connection is skipped when messages arrive simultaneously, helping resolve cyclic dependencies.\n",
    "- **`jitter`**: Controls how to handle irregularities in message timing (e.g., `Jitter.LATEST` uses the most recent message).\n",
    "- **`name`**: A shadow name for the input; defaults to the output node's name.\n",
    "\n",
    "#### Including `delay_dist` in Connection\n",
    "\n",
    "The `delay_dist` parameter allows you to specify a distribution that models the variability in communication delay between nodes. This is particularly useful in simulations where network latency or message passing delays are significant.\n",
    "\n",
    "#### Resolving Cyclic Dependencies with `skip`\n",
    "\n",
    "In graphs where nodes depend on each other's outputs (creating a cycle), the `skip` parameter can be used to resolve the dependency. By setting `skip=True` on a connection, you instruct the receiving node to proceed without waiting for the current message if it arrives simultaneously. This breaks the cycle and allows the system to function.\n",
    "\n",
    "#### Example Connection\n",
    "\n",
    "```python\n",
    "node_a.connect(\n",
    "    output_node=node_b,\n",
    "    blocking=True,\n",
    "    delay=0.01,  # Expected communication delay (used for phase-shifting)\n",
    "    delay_dist=distrax.Normal(loc=0.01, scale=0.005), # Sim. communication delay\n",
    "    window=5,\n",
    "    skip=False,\n",
    "    jitter=Jitter.LATEST,\n",
    "    name=\"input_from_b\"\n",
    ")\n",
    "```\n",
    "\n",
    "In this example, `node_a` connects to `node_b` with a blocking connection, an added delay of 0.01 seconds, and a delay distribution for simulation purposes. The `window` size is set to 5, meaning the last five messages are stored. The `skip` parameter is `False`, so the node will wait for the input.\n",
    "\n",
    "## Node Data Structure\n",
    "\n",
    "Nodes manage four main types of data (defined as [pytrees](https://jax.readthedocs.io/en/latest/working-with-pytrees.html#patterns)), typically defined using immutable dataclasses for efficiency and safety:\n",
    "\n",
    "1. **Parameters**: Static configurations that usually remain constant during execution.\n",
    "2. **State**: Dynamic data that evolves over time with each `step`.\n",
    "3. **Outputs**: Data produced by a node's `step` method and sent to connected nodes.\n",
    "4. **Inputs**: Buffers that hold incoming data from other nodes, respecting the specified window size.\n",
    "\n",
    "### Immutable Dataclasses\n",
    "\n",
    "Using immutable dataclasses (e.g., via `@struct.dataclass` from Flax) ensures that the data structures are compatible with JAX's JIT compilation and functional programming paradigms. Additionally, dataclasses allow you to define specific methods related to the data structure, providing encapsulation and clarity.\n",
    "\n",
    "```python\n",
    "@struct.dataclass\n",
    "class MyParams:\n",
    "    some_parameter: float\n",
    "\n",
    "    def adjust_parameter(self, factor: float):\n",
    "        return self.replace(some_parameter=self.some_parameter * factor)\n",
    "\n",
    "@struct.dataclass\n",
    "class MyState:\n",
    "    some_state_variable: jax.Array\n",
    "\n",
    "    def update_state(self, delta: jax.Array):\n",
    "        return self.replace(some_state_variable=self.some_state_variable + delta)\n",
    "\n",
    "@struct.dataclass\n",
    "class MyOutput:\n",
    "    some_output_data: jax.Array\n",
    "```\n",
    "\n",
    "In this example, `MyParams` and `MyState` include methods to adjust parameters and update state, respectively. This encapsulation enhances code organization and readability.\n",
    "\n",
    "### Initialization\n",
    "\n",
    "Node data is initialized using specific methods that you should override:\n",
    "\n",
    "- **`init_params`**: Initializes the node's parameters.\n",
    "- **`init_state`**: Initializes the node's state.\n",
    "- **`init_output`**: Provides a default output, useful for initializing input buffers in connected nodes.\n",
    "\n",
    "These methods are typically called during the graph's initialization phase using `graph.init()`.\n",
    "\n",
    "## The `step` Method in Detail\n",
    "\n",
    "The `step` method defines how a node processes inputs and updates its state at each timestep. It receives a `StepState` object with all necessary information.\n",
    "\n",
    "### `StepState` Attributes\n",
    "\n",
    "- **`rng`**: Random number generator (updated if used).\n",
    "- **`state`**: Node's current state.\n",
    "- **`params`**: Static parameters influencing behavior.\n",
    "- **`inputs`**: Dictionary of `InputState` instances (keyed by input names).\n",
    "- **`eps`**: Episode number relates to the current computation graph used for simulation (unrelated to RL episode number).\n",
    "- **`seq`**: Current step number (auto-increments with each step).\n",
    "- **`ts`**: Timestamp at the start of the step.\n",
    "\n",
    "### Accessing Inputs\n",
    "\n",
    "Each `InputState` in `step_state.inputs` contains:\n",
    "\n",
    "- **`data`**: Messages from the connected node.\n",
    "- **`seq`**: Sequence numbers of the received messages.\n",
    "- **`ts_sent`**: Timestamps when messages were sent.\n",
    "- **`ts_recv`**: Timestamps when messages were received.\n",
    "\n",
    "For example, accessing the most recent message:\n",
    "\n",
    "```python\n",
    "latest_sensor_input = step_state.inputs['sensor'][-1].data\n",
    "```\n",
    "\n",
    "### Implementing the `step` Method\n",
    "\n",
    "The typical steps to implement the `step` method can be condensed into the following block:\n",
    "\n",
    "```python\n",
    "def step(self, step_state: StepState):\n",
    "    # Unpack StepState\n",
    "    rng, state, params, inputs = step_state.rng, step_state.state, step_state.params, step_state.inputs\n",
    "    \n",
    "    # Access latest input\n",
    "    control_signal = inputs['controller'][-1].data\n",
    "    \n",
    "    # Update state\n",
    "    new_state_variable = state.some_state_variable + control_signal * params.gain\n",
    "    new_state = state.replace(some_state_variable=new_state_variable)\n",
    "    \n",
    "    # Produce output\n",
    "    output = MyOutput(some_output_data=new_state_variable)\n",
    "    \n",
    "    # Update RNG if randomness is involved\n",
    "    rng, _ = jax.random.split(rng)\n",
    "    \n",
    "    # Return updated StepState and output\n",
    "    return step_state.replace(state=new_state, rng=rng), output\n",
    "```\n",
    "\n",
    "### Working with Time and Sequence\n",
    "\n",
    "Use `eps`, `ts` and `seq` for time-dependent logic:\n",
    "\n",
    "```python\n",
    "if step_state.ts > params.activation_time:\n",
    "    # Perform time-based logic\n",
    "    pass\n",
    "```\n",
    "\n",
    "### Handling Input Windows\n",
    "\n",
    "If the input window size is greater than 1, you can access past messages:\n",
    "\n",
    "```python\n",
    "recent_sensor_data = inputs['sensor_input'][-3:].data\n",
    "```\n",
    "\n",
    "### JIT Compilation and Side Effects Handling with External Callbacks\n",
    "\n",
    "Rex advocates for JIT-compiling the `step` method of each node to enhance performance. However, interfacing with real hardware often involves side effects that JAX's JIT compilation doesn't handle natively.\n",
    "\n",
    "To include side-effecting code (e.g., sending commands to actuators, reading sensor data), you must use JAX's external callback mechanism. This involves wrapping side-effecting functions with `jax.experimental.io_callback` to ensure compatibility with JIT compilation.\n",
    "\n",
    "Refer to the [JAX documentation on external callbacks](https://jax.readthedocs.io/en/latest/external_callbacks.html) for detailed guidance.\n",
    "\n",
    "```python\n",
    "def step(self, step_state: StepState):\n",
    "    # Compute outputs\n",
    "    output = ...\n",
    "\n",
    "    # Side-effecting function\n",
    "    def _apply_action(action):\n",
    "        # Code that interacts with hardware\n",
    "        return np.array(1.0)  # Dummy return value\n",
    "\n",
    "    # Wrap side-effecting code\n",
    "    _ = jax.experimental.io_callback(\n",
    "        _apply_action,\n",
    "        result_shape=jnp.array(1.0),\n",
    "        arg=output.some_output_data\n",
    "    )\n",
    "\n",
    "    # Update state and return\n",
    "    return step_state, output\n",
    "```\n",
    "\n",
    "## Real-World Nodes and Lifecycle Methods\n",
    "\n",
    "When nodes interface with real hardware or external systems, additional lifecycle management is necessary. The `BaseNode` API accommodates this through:\n",
    "\n",
    "- **`startup`**: Called before an episode starts, allowing the node to prepare (e.g., initialize hardware).\n",
    "- **`stop`**: Called after an episode ends, enabling the node to clean up resources or safely shut down hardware.\n",
    "\n",
    "```python\n",
    "class RealWorldNode(BaseNode):\n",
    "    def __init__(\n",
    "        self,\n",
    "        name: str,\n",
    "        rate: float,\n",
    "        delay: float = None,\n",
    "        delay_dist: Union[DelayDistribution, distrax.Distribution] = None,\n",
    "        advance: bool = False,\n",
    "        scheduling: Scheduling = Scheduling.FREQUENCY,\n",
    "        color: str = None,\n",
    "        order: int = None\n",
    "    ):\n",
    "        super().__init__(name, rate, delay, delay_dist, advance, scheduling, color, order)\n",
    "        # Additional initialization if needed\n",
    "\n",
    "    def startup(self, graph_state: GraphState, timeout: float = None):\n",
    "        # Initialize hardware connections\n",
    "        return True  # Return True if successful\n",
    "\n",
    "    def stop(self, timeout: float = None):\n",
    "        # Safely shut down hardware\n",
    "        return True\n",
    "```\n",
    "\n",
    "## Summary\n",
    "\n",
    "By following these guidelines, you can define robust and efficient nodes within the Rex framework. Nodes can be customized extensively through their parameters and state, connected flexibly to form complex graphs, and optimized using JIT compilation. Proper handling of side effects ensures that nodes interfacing with real-world systems remain performant and reliable.\n",
    "\n",
    "In the following examples, we'll implement specific nodes that illustrate these concepts in practice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d0bb247d74e6bac0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-08T10:19:04.579702461Z",
     "start_time": "2024-10-08T10:19:04.547984646Z"
    },
    "cellView": "form",
    "id": "d0bb247d74e6bac0"
   },
   "outputs": [],
   "source": [
    "# @title Example: Agent\n",
    "\n",
    "from typing import Tuple, Union\n",
    "\n",
    "import jax\n",
    "from flax import struct\n",
    "from flax.core import FrozenDict\n",
    "from jax import numpy as jnp\n",
    "\n",
    "from rex import base\n",
    "from rex.base import GraphState, StepState\n",
    "from rex.node import BaseNode\n",
    "from rex.ppo import Policy\n",
    "\n",
    "\n",
    "@struct.dataclass\n",
    "class AgentOutput(base.Base):\n",
    "    \"\"\"Agent's output\"\"\"\n",
    "\n",
    "    action: jax.typing.ArrayLike  # Torque to apply to the pendulum\n",
    "\n",
    "\n",
    "@struct.dataclass\n",
    "class AgentParams(base.Base):\n",
    "    # Policy\n",
    "    policy: Policy\n",
    "    # Observations\n",
    "    num_act: Union[int, jax.typing.ArrayLike] = struct.field(pytree_node=False)  # Action history length\n",
    "    num_obs: Union[int, jax.typing.ArrayLike] = struct.field(pytree_node=False)  # Observation history length\n",
    "    # Action\n",
    "    max_torque: Union[float, jax.typing.ArrayLike]\n",
    "    # Initial state\n",
    "    init_method: str = struct.field(pytree_node=False)  # \"random\", \"parametrized\"\n",
    "    parametrized: jax.typing.ArrayLike\n",
    "    max_th: Union[float, jax.typing.ArrayLike]\n",
    "    max_thdot: Union[float, jax.typing.ArrayLike]\n",
    "    # Train\n",
    "    gamma: Union[float, jax.typing.ArrayLike]\n",
    "    tmax: Union[float, jax.typing.ArrayLike]\n",
    "\n",
    "    @staticmethod\n",
    "    def process_inputs(inputs: FrozenDict[str, base.InputState]) -> jax.Array:\n",
    "        th, thdot = inputs[\"sensor\"][-1].data.th, inputs[\"sensor\"][-1].data.thdot\n",
    "        obs = jnp.array([jnp.cos(th), jnp.sin(th), thdot])\n",
    "        return obs\n",
    "\n",
    "    @staticmethod\n",
    "    def get_observation(step_state: StepState) -> jax.Array:\n",
    "        # Unpack StepState\n",
    "        inputs, state = step_state.inputs, step_state.state\n",
    "\n",
    "        # Convert inputs to single observation\n",
    "        single_obs = AgentParams.process_inputs(inputs)\n",
    "\n",
    "        # Concatenate with previous observations\n",
    "        obs = jnp.concatenate([single_obs, state.history_obs.flatten(), state.history_act.flatten()])\n",
    "        return obs\n",
    "\n",
    "    @staticmethod\n",
    "    def update_state(step_state: StepState, action: jax.Array) -> \"AgentState\":\n",
    "        # Unpack StepState\n",
    "        state, params, inputs = step_state.state, step_state.params, step_state.inputs\n",
    "\n",
    "        # Convert inputs to observation\n",
    "        single_obs = AgentParams.process_inputs(inputs)\n",
    "\n",
    "        # Update obs history\n",
    "        if params.num_obs > 0:\n",
    "            history_obs = jnp.roll(state.history_obs, shift=1, axis=0)\n",
    "            history_obs = history_obs.at[0].set(single_obs)\n",
    "        else:\n",
    "            history_obs = state.history_obs\n",
    "\n",
    "        # Update act history\n",
    "        if params.num_act > 0:\n",
    "            history_act = jnp.roll(state.history_act, shift=1, axis=0)\n",
    "            history_act = history_act.at[0].set(action)\n",
    "        else:\n",
    "            history_act = state.history_act\n",
    "\n",
    "        new_state = state.replace(history_obs=history_obs, history_act=history_act)\n",
    "        return new_state\n",
    "\n",
    "    @staticmethod\n",
    "    def to_output(action: jax.Array) -> AgentOutput:\n",
    "        return AgentOutput(action=action)\n",
    "\n",
    "\n",
    "@struct.dataclass\n",
    "class AgentState(base.Base):\n",
    "    history_act: jax.typing.ArrayLike  # History of actions\n",
    "    history_obs: jax.typing.ArrayLike  # History of observations\n",
    "    init_th: Union[float, jax.typing.ArrayLike]  # Initial angle of the pendulum\n",
    "    init_thdot: Union[float, jax.typing.ArrayLike]  # Initial angular velocity of the pendulum\n",
    "\n",
    "\n",
    "class Agent(BaseNode):\n",
    "    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -> AgentParams:\n",
    "        return AgentParams(\n",
    "            policy=None,  # Policy must be set by the user\n",
    "            num_act=4,  # Number of actions to keep in history\n",
    "            num_obs=4,  # Number of observations to keep in history\n",
    "            max_torque=2.0,  # Maximum torque that can be applied to the pendulum\n",
    "            init_method=\"parametrized\",  # \"random\" or \"parametrized\"\n",
    "            parametrized=jnp.array([jnp.pi, 0.0]),  # [th, thdot]\n",
    "            max_th=jnp.pi,  # Maximum initial angle of the pendulum\n",
    "            max_thdot=9.0,  # Maximum initial angular velocity of the pendulum\n",
    "            gamma=0.99,  # Discount factor  (used during training)\n",
    "            tmax=3.0,  # Maximum time for an episode (used during training)\n",
    "        )\n",
    "\n",
    "    def init_state(self, rng: jax.Array = None, graph_state: GraphState = None) -> AgentState:\n",
    "        graph_state = graph_state or base.GraphState()\n",
    "        params = graph_state.params.get(self.name, self.init_params(rng, graph_state))\n",
    "        history_act = jnp.zeros((params.num_act, 1), dtype=jnp.float32)  # [torque]\n",
    "        history_obs = jnp.zeros((params.num_obs, 3), dtype=jnp.float32)  # [cos(th), sin(th), thdot]\n",
    "\n",
    "        # Set the initial state of the pendulum\n",
    "        if params.init_method == \"parametrized\":\n",
    "            init_th, init_thdot = params.parametrized\n",
    "        elif params.init_method == \"random\":\n",
    "            rng = rng if rng is not None else jax.random.PRNGKey(0)\n",
    "            rngs = jax.random.split(rng, num=2)\n",
    "            init_th = jax.random.uniform(rngs[0], shape=(), minval=-params.max_th, maxval=params.max_th)\n",
    "            init_thdot = jax.random.uniform(rngs[1], shape=(), minval=-params.max_thdot, maxval=params.max_thdot)\n",
    "        else:\n",
    "            raise ValueError(f\"Invalid init_method: {params.init_method}\")\n",
    "        return AgentState(history_act=history_act, history_obs=history_obs, init_th=init_th, init_thdot=init_thdot)\n",
    "\n",
    "    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -> AgentOutput:\n",
    "        \"\"\"Default output of the node.\"\"\"\n",
    "        rng = jax.random.PRNGKey(0) if rng is None else rng\n",
    "        graph_state = graph_state or base.GraphState()\n",
    "        params = graph_state.params.get(self.name, self.init_params(rng, graph_state))\n",
    "        action = jax.random.uniform(rng, shape=(1,), minval=-params.max_torque, maxval=params.max_torque)\n",
    "        return AgentOutput(action=action)\n",
    "\n",
    "    def step(self, step_state: StepState) -> Tuple[StepState, AgentOutput]:\n",
    "        \"\"\"Step the node.\"\"\"\n",
    "        # Unpack StepState\n",
    "        rng, params = step_state.rng, step_state.params\n",
    "\n",
    "        # Prepare output\n",
    "        rng, rng_net = jax.random.split(rng)\n",
    "        if params.policy is not None:  # Use policy to get action\n",
    "            obs = AgentParams.get_observation(step_state)\n",
    "            action = params.policy.get_action(obs, rng=None)  # Supply rng for stochastic policies\n",
    "        else:  # Random action if no policy is set\n",
    "            action = jax.random.uniform(rng_net, shape=(1,), minval=-params.max_torque, maxval=params.max_torque)\n",
    "        output = AgentParams.to_output(action)  # Convert action to output message\n",
    "\n",
    "        # Update step_state (observation and action history)\n",
    "        new_state = params.update_state(step_state, action)  # Update state\n",
    "        new_step_state = step_state.replace(rng=rng, state=new_state)  # Update step_state\n",
    "        return new_step_state, output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d0bd6f838b051602",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-08T10:19:04.580137757Z",
     "start_time": "2024-10-08T10:19:04.550775597Z"
    },
    "cellView": "form",
    "id": "d0bd6f838b051602"
   },
   "outputs": [],
   "source": [
    "# @title Example: Actuator\n",
    "\n",
    "from typing import Tuple, Union\n",
    "\n",
    "import jax\n",
    "import numpy as onp\n",
    "from flax import struct\n",
    "\n",
    "from rex import base\n",
    "from rex.base import GraphState, StepState\n",
    "from rex.jax_utils import tree_dynamic_slice\n",
    "from rex.node import BaseNode\n",
    "\n",
    "\n",
    "@struct.dataclass\n",
    "class ActuatorOutput(base.Base):\n",
    "    \"\"\"Pendulum actuator output\"\"\"\n",
    "\n",
    "    action: jax.typing.ArrayLike  # Torque to apply to the pendulum\n",
    "\n",
    "\n",
    "@struct.dataclass\n",
    "class ActuatorParams(base.Base):\n",
    "    \"\"\"Pendulum actuator param definition\"\"\"\n",
    "\n",
    "    actuator_delay: Union[float, jax.typing.ArrayLike]\n",
    "\n",
    "\n",
    "class Actuator(BaseNode):\n",
    "    \"\"\"This is a simple actuator node definition that could interface a real actuator.\n",
    "\n",
    "    When interfacing real hardware, you would send the action to real hardware in the .step method.\n",
    "    Optionally, you could also specify a startup routine that is called right before an episode starts.\n",
    "    Finally, a stop routine is called after the episode is done.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, *args, **kwargs):\n",
    "        \"\"\"No special initialization needed.\"\"\"\n",
    "        super().__init__(*args, **kwargs)\n",
    "\n",
    "    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -> ActuatorParams:\n",
    "        \"\"\"Default params of the node.\"\"\"\n",
    "        actuator_delay = 0.05\n",
    "        return ActuatorParams(actuator_delay=actuator_delay)\n",
    "\n",
    "    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -> ActuatorOutput:\n",
    "        \"\"\"Default output of the node.\"\"\"\n",
    "        return ActuatorOutput(action=jnp.array([0.0], dtype=jnp.float32))\n",
    "\n",
    "    def startup(self, graph_state: base.GraphState, timeout: float = None) -> bool:\n",
    "        \"\"\"Starts the node in the state specified by graph_state.\n",
    "\n",
    "        This method is called right before an episode starts.\n",
    "        It can be used to move (a real) robot to a starting position as specified by the graph_state.\n",
    "\n",
    "        Not used when running in compiled mode.\n",
    "        :param graph_state: The graph state.\n",
    "        :param timeout: The timeout of the startup.\n",
    "        :return: Whether the node has started successfully.\n",
    "        \"\"\"\n",
    "        # Move robot to starting position specified by graph_state (e.g. graph_state.state[\"agent\"].init_th)\n",
    "        return True  # Not doing anything here\n",
    "\n",
    "    def step(self, step_state: StepState) -> Tuple[StepState, ActuatorOutput]:\n",
    "        \"\"\"If we were to control a real robot, you would send the action to the robot here.\"\"\"\n",
    "        # Prepare output\n",
    "        output = step_state.inputs[\"agent\"][-1].data\n",
    "        output = ActuatorOutput(action=output.action)\n",
    "\n",
    "        def _apply_action(action):\n",
    "            \"\"\"\n",
    "            Not really doing anything here, just a dummy implementation.\n",
    "            Include some side-effecting code here (e.g. sending the action to a real robot).\n",
    "\n",
    "            The .step method may be jit-compiled, it is important to wrap any side-effecting code in a host_callback.\n",
    "            See the jax documentation for more information on how to do this:\n",
    "            https://jax.readthedocs.io/en/latest/external-callbacks.html\n",
    "            \"\"\"\n",
    "            # print(f\"Applying action: {action}\") # Apply action to the robot\n",
    "            return onp.array(1.0)  # Must match dtype and shape of return_shape\n",
    "\n",
    "        # Apply action to the robot\n",
    "        return_shape = jnp.array(1.0)  # Must match dtype and shape of return_shape\n",
    "        _ = jax.experimental.io_callback(_apply_action, return_shape, output)\n",
    "\n",
    "        # Update state\n",
    "        new_step_state = step_state\n",
    "        return new_step_state, output\n",
    "\n",
    "    def stop(self, timeout: float = None) -> bool:\n",
    "        \"\"\"Stopping routine that is called after the episode is done.\n",
    "\n",
    "        **IMPORTANT** It may happen that stop is called *before* the final .step call of an episode returns,\n",
    "        which may cause unsafe behavior when the final step undoes the work of the .stop method.\n",
    "        This should be handled by the user. For example, by stopping \"longer\" before returning here.\n",
    "\n",
    "        Only ran when running asynchronously.\n",
    "        :param timeout: The timeout of the stop\n",
    "        :return: Whether the node has stopped successfully.\n",
    "        \"\"\"\n",
    "        # Stop the robot (e.g. set the torque to 0)\n",
    "        return True\n",
    "\n",
    "\n",
    "class SimActuator(BaseNode):\n",
    "    \"\"\"This is a simple simulated actuator node definition that can either\n",
    "    1. Feedthrough the agent's action (for normal operation, e.g., training).\n",
    "       Optionally, you could include some noise or other modifications to the action.\n",
    "    2. Reapply the recorded actuator outputs for system identification if available.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, *args, outputs: ActuatorOutput = None, **kwargs):\n",
    "        \"\"\"Initialize Actuator for system identification.\n",
    "\n",
    "        Here, we will reapply the recorded actuator outputs for system identification if available.\n",
    "\n",
    "        :param outputs: Recorded actuator Outputs to be used for system identification.\n",
    "        \"\"\"\n",
    "        super().__init__(*args, **kwargs)\n",
    "        self._outputs = outputs\n",
    "\n",
    "    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -> ActuatorParams:\n",
    "        \"\"\"Default params of the node.\"\"\"\n",
    "        actuator_delay = 0.05\n",
    "        return ActuatorParams(actuator_delay=actuator_delay)\n",
    "\n",
    "    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -> ActuatorOutput:\n",
    "        \"\"\"Default output of the node.\"\"\"\n",
    "        return ActuatorOutput(action=jnp.array([0.0], dtype=jnp.float32))\n",
    "\n",
    "    def step(self, step_state: StepState) -> Tuple[StepState, ActuatorOutput]:\n",
    "        # Get action from dataset if available, else use the one provided by the agent\n",
    "        if self._outputs is not None:  # Use the recorded action (for system identification)\n",
    "            output = tree_dynamic_slice(self._outputs, jnp.array([step_state.eps, step_state.seq]))\n",
    "            output = jax.tree_util.tree_map(lambda _o: _o[0, 0], output)\n",
    "        else:  # Feedthrough the agent's action (for normal operation, e.g., training)\n",
    "            output = step_state.inputs[\"agent\"][-1].data\n",
    "            output = ActuatorOutput(action=output.action)\n",
    "        new_step_state = step_state\n",
    "        return new_step_state, output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c607c4fad859faad",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-08T10:19:04.584261634Z",
     "start_time": "2024-10-08T10:19:04.572931658Z"
    },
    "cellView": "form",
    "id": "c607c4fad859faad"
   },
   "outputs": [],
   "source": [
    "# @title Example: Sensor\n",
    "\n",
    "from typing import Dict, Tuple, Union\n",
    "\n",
    "import jax\n",
    "from flax import struct\n",
    "\n",
    "from rex import base\n",
    "from rex.base import GraphState, StepState\n",
    "from rex.node import BaseNode\n",
    "\n",
    "\n",
    "@struct.dataclass\n",
    "class SensorOutput(base.Base):\n",
    "    \"\"\"Output message definition of the sensor node.\"\"\"\n",
    "\n",
    "    th: Union[float, jax.typing.ArrayLike]\n",
    "    thdot: Union[float, jax.typing.ArrayLike]\n",
    "\n",
    "\n",
    "@struct.dataclass\n",
    "class SensorParams(base.Base):\n",
    "    \"\"\"\n",
    "    Other than the sensor delay, we don't have any other parameters.\n",
    "    You could add more parameters here if needed, such as noise levels etc.\n",
    "    \"\"\"\n",
    "\n",
    "    sensor_delay: Union[float, jax.typing.ArrayLike]\n",
    "\n",
    "\n",
    "@struct.dataclass\n",
    "class SensorState:\n",
    "    \"\"\"We use this state to record the reconstruction loss.\"\"\"\n",
    "\n",
    "    loss_th: Union[float, jax.typing.ArrayLike]\n",
    "    loss_thdot: Union[float, jax.typing.ArrayLike]\n",
    "\n",
    "\n",
    "class Sensor(BaseNode):\n",
    "    \"\"\"This is a simple sensor node definition that interfaces a real sensor.\n",
    "\n",
    "    When interfacing real hardware, you would grab the sensor measurement in the .step method.\n",
    "    Optionally, you could also specify a startup routine that is called right before an episode starts.\n",
    "    Finally, a stop routine is called after the episode is done.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, *args, **kwargs):\n",
    "        \"\"\"No special initialization needed.\"\"\"\n",
    "        super().__init__(*args, **kwargs)\n",
    "\n",
    "    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -> SensorParams:\n",
    "        \"\"\"Default params of the node.\"\"\"\n",
    "        sensor_delay = 0.05\n",
    "        return SensorParams(sensor_delay=sensor_delay)\n",
    "\n",
    "    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -> SensorOutput:\n",
    "        \"\"\"Default output of the node.\"\"\"\n",
    "        # Randomly define some initial sensor values\n",
    "        th = jnp.pi\n",
    "        thdot = 0.0\n",
    "        return SensorOutput(th=th, thdot=thdot)\n",
    "\n",
    "    def startup(self, graph_state: base.GraphState, timeout: float = None) -> bool:\n",
    "        \"\"\"Starts the node in the state specified by graph_state.\n",
    "\n",
    "        This method is called right before an episode starts.\n",
    "        It can be used to move (a real) robot to a starting position as specified by the graph_state.\n",
    "\n",
    "        Not used when running in compiled mode.\n",
    "        :param graph_state: The graph state.\n",
    "        :param timeout: The timeout of the startup.\n",
    "        :return: Whether the node has started successfully.\n",
    "        \"\"\"\n",
    "        return True  # Not doing anything here\n",
    "\n",
    "    def step(self, step_state: StepState) -> Tuple[StepState, SensorOutput]:\n",
    "        \"\"\"If we were to interface a real hardware, you would grab the sensor measurement here.\"\"\"\n",
    "\n",
    "        \"\"\"\n",
    "        As the .step method may be jit-compiled, it is important to wrap any side-effecting code in a host_callback.\n",
    "        See the jax documentation for more information on how to do this:\n",
    "        https://jax.readthedocs.io/en/latest/external-callbacks.html\n",
    "        \"\"\"\n",
    "        world = step_state.inputs[\"world\"][-1].data\n",
    "\n",
    "        def _grab_measurement():\n",
    "            \"\"\"\n",
    "            Not really doing anything here, just a dummy implementation.\n",
    "            Include some side-effecting code here (e.g. grabbing measurement from sensor).\n",
    "\n",
    "            The .step method may be jit-compiled, it is important to wrap any side-effecting code in a host_callback.\n",
    "            See the jax documentation for more information on how to do this:\n",
    "            https://jax.readthedocs.io/en/latest/external-callbacks.html\n",
    "            \"\"\"\n",
    "            # print(\"Grabbing sensor measurement\")\n",
    "            sensor_msg = onp.array(1.0)  # Dummy sensor measurement (not actually used)\n",
    "            return sensor_msg  # Must match dtype and shape of return_shape\n",
    "\n",
    "        # Grab sensor measurement\n",
    "        return_shape = jnp.array(1.0)  # Must match dtype and shape of return_shape\n",
    "        _ = jax.experimental.io_callback(_grab_measurement, return_shape)\n",
    "\n",
    "        # Prepare output\n",
    "        output = SensorOutput(th=world.th, thdot=world.thdot)\n",
    "\n",
    "        # Update state (NOOP)\n",
    "        new_step_state = step_state\n",
    "\n",
    "        return new_step_state, output\n",
    "\n",
    "    def stop(self, timeout: float = None) -> bool:\n",
    "        \"\"\"Stopping routine that is called after the episode is done.\n",
    "\n",
    "        **IMPORTANT** It may happen that stop is called *before* the final .step call of an episode returns,\n",
    "        which may cause unsafe behavior when the final step undoes the work of the .stop method.\n",
    "        This should be handled by the user. For example, by stopping \"longer\" before returning here.\n",
    "\n",
    "        Only ran when running asynchronously.\n",
    "        :param timeout: The timeout of the stop\n",
    "        :return: Whether the node has stopped successfully.\n",
    "        \"\"\"\n",
    "        return True  # Not doing anything here\n",
    "\n",
    "\n",
    "class SimSensor(BaseNode):\n",
    "    \"\"\"This is a simple simulated sensor node definition that can either\n",
    "    1. Convert the world state into a realistic sensor measurement (for normal operation, e.g., training).\n",
    "       Optionally, you could include some noise or other modifications to the sensor measurement.\n",
    "    2. Calculate a reconstruction loss based on the sensor measurement and the recorded sensor outputs.\n",
    "\n",
    "    By calculating and aggregating the reconstruction loss here, we take time-scale differences and delays into account.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, *args, outputs: SensorOutput = None, **kwargs):\n",
    "        \"\"\"Initialize a simulated sensor for system identification.\n",
    "\n",
    "        If outputs are provided, we will calculate the reconstruction loss based on the recorded sensor outputs.\n",
    "\n",
    "        :param outputs: Recorded sensor Outputs to be used for system identification.\n",
    "        \"\"\"\n",
    "        super().__init__(*args, **kwargs)\n",
    "        self._outputs = outputs\n",
    "\n",
    "    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -> SensorParams:\n",
    "        \"\"\"Default params of the node.\"\"\"\n",
    "        sensor_delay = 0.05\n",
    "        return SensorParams(sensor_delay=sensor_delay)\n",
    "\n",
    "    def init_state(self, rng: jax.Array = None, graph_state: GraphState = None) -> SensorState:\n",
    "        \"\"\"Default state of the node.\"\"\"\n",
    "        return SensorState(loss_th=0.0, loss_thdot=0.0)  # Initialize reconstruction loss to zero at the start of the episode\n",
    "\n",
    "    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -> SensorOutput:\n",
    "        \"\"\"Default output of the node.\"\"\"\n",
    "        # Randomly define some initial sensor values\n",
    "        th = jnp.pi\n",
    "        thdot = 0.0\n",
    "        return SensorOutput(th=th, thdot=thdot)  # Fix the initial sensor values\n",
    "\n",
    "    def init_delays(\n",
    "        self, rng: jax.Array = None, graph_state: base.GraphState = None\n",
    "    ) -> Dict[str, Union[float, jax.typing.ArrayLike]]:\n",
    "        \"\"\"Initialize trainable communication delays.\n",
    "\n",
    "        **Note** These only include trainable delays that were specified while connecting the nodes.\n",
    "\n",
    "        :param rng: Random number generator.\n",
    "        :param graph_state: The graph state that may be used to get the default output.\n",
    "        :return: Trainable delays (e.g., {input_name: delay}). Can be an incomplete dictionary.\n",
    "                 Entries for non-trainable delays or non-existent connections are ignored.\n",
    "        \"\"\"\n",
    "        graph_state = graph_state or GraphState()\n",
    "        params = graph_state.params.get(self.name, self.init_params(rng, graph_state))\n",
    "        delays = {\"world\": params.sensor_delay}\n",
    "        return delays\n",
    "\n",
    "    def step(self, step_state: StepState) -> Tuple[StepState, SensorOutput]:\n",
    "        # Determine output\n",
    "        data = step_state.inputs[\"world\"][-1].data\n",
    "        output = SensorOutput(th=data.th, thdot=data.thdot)\n",
    "\n",
    "        # Calculate loss\n",
    "        if self._outputs is not None:  # Calculate reconstruction loss and aggregate in state\n",
    "            output_rec = tree_dynamic_slice(self._outputs, jnp.array([step_state.eps, step_state.seq]))\n",
    "            output_rec = jax.tree_util.tree_map(lambda _o: _o[0, 0], output_rec)\n",
    "            th_rec, thdot_rec = output_rec.th, output_rec.thdot\n",
    "            state = step_state.state\n",
    "            loss_th = state.loss_th + (jnp.sin(output.th) - jnp.sin(th_rec)) ** 2 + (jnp.cos(output.th) - jnp.cos(th_rec)) ** 2\n",
    "            loss_thdot = state.loss_thdot + (output.thdot - thdot_rec) ** 2\n",
    "            new_state = state.replace(loss_th=loss_th, loss_thdot=loss_thdot)\n",
    "        else:  # NOOP\n",
    "            new_state = step_state.state\n",
    "\n",
    "        # Update step_state\n",
    "        new_step_state = step_state.replace(state=new_state)\n",
    "        return new_step_state, output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "af33ee6be159b66f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-08T10:19:04.619357308Z",
     "start_time": "2024-10-08T10:19:04.597453689Z"
    },
    "cellView": "form",
    "id": "af33ee6be159b66f"
   },
   "outputs": [],
   "source": [
    "# @title Example: ODE simulation node\n",
    "\n",
    "from math import ceil\n",
    "from typing import Dict, Tuple, Union\n",
    "\n",
    "import jax\n",
    "from flax import struct\n",
    "\n",
    "from rex import base\n",
    "from rex.base import GraphState, StepState\n",
    "from rex.node import BaseWorld\n",
    "\n",
    "\n",
    "@struct.dataclass\n",
    "class OdeParams(base.Base):\n",
    "    \"\"\"Pendulum ode param definition\"\"\"\n",
    "\n",
    "    max_speed: Union[float, jax.typing.ArrayLike]\n",
    "    J: Union[float, jax.typing.ArrayLike]\n",
    "    mass: Union[float, jax.typing.ArrayLike]\n",
    "    length: Union[float, jax.typing.ArrayLike]\n",
    "    b: Union[float, jax.typing.ArrayLike]\n",
    "    K: Union[float, jax.typing.ArrayLike]\n",
    "    R: Union[float, jax.typing.ArrayLike]\n",
    "    c: Union[float, jax.typing.ArrayLike]\n",
    "    dt_substeps_min: Union[float, jax.typing.ArrayLike] = struct.field(pytree_node=False)\n",
    "    dt: Union[float, jax.typing.ArrayLike] = struct.field(pytree_node=False)\n",
    "\n",
    "    @property\n",
    "    def substeps(self) -> int:\n",
    "        substeps = ceil(self.dt / self.dt_substeps_min)\n",
    "        return int(substeps)\n",
    "\n",
    "    @property\n",
    "    def dt_substeps(self) -> float:\n",
    "        substeps = self.substeps\n",
    "        dt_substeps = self.dt / substeps\n",
    "        return dt_substeps\n",
    "\n",
    "    def step(\n",
    "        self, substeps: int, dt_substeps: jax.typing.ArrayLike, x: \"OdeState\", us: jax.typing.ArrayLike\n",
    "    ) -> Tuple[\"OdeState\", \"OdeState\"]:\n",
    "        \"\"\"Step the pendulum ode.\"\"\"\n",
    "\n",
    "        def _scan_fn(_x, _u):\n",
    "            next_x = self._runge_kutta4(dt_substeps, _x, _u)\n",
    "            # Clip velocity\n",
    "            clip_thdot = jnp.clip(next_x.thdot, -self.max_speed, self.max_speed)\n",
    "            next_x = next_x.replace(thdot=clip_thdot)\n",
    "            return next_x, next_x\n",
    "\n",
    "        x_final, x_substeps = jax.lax.scan(_scan_fn, x, us, length=substeps)\n",
    "        return x_final, x_substeps\n",
    "\n",
    "    def _runge_kutta4(self, dt: jax.typing.ArrayLike, x: \"OdeState\", u: jax.typing.ArrayLike) -> \"OdeState\":\n",
    "        k1 = self._ode(x, u)\n",
    "        k2 = self._ode(x + k1 * dt * 0.5, u)\n",
    "        k3 = self._ode(x + k2 * dt * 0.5, u)\n",
    "        k4 = self._ode(x + k3 * dt, u)\n",
    "        return x + (k1 + k2 * 2 + k3 * 2 + k4) * (dt / 6)\n",
    "\n",
    "    def _ode(self, x: \"OdeState\", u: jax.typing.ArrayLike) -> \"OdeState\":\n",
    "        \"\"\"dx function for the pendulum ode\"\"\"\n",
    "        # Downward := [pi, 0], Upward := [0, 0]\n",
    "        g, J, m, l, b, K, R, c = 9.81, self.J, self.mass, self.length, self.b, self.K, self.R, self.c  # noqa: E741\n",
    "        th, thdot = x.th, x.thdot\n",
    "        activation = jnp.sign(thdot)\n",
    "        ddx = (u * K / R + m * g * l * jnp.sin(th) - b * thdot - thdot * K * K / R - c * activation) / J\n",
    "        return OdeState(th=thdot, thdot=ddx, loss_task=0.0)  # No derivative for loss_task\n",
    "\n",
    "\n",
    "@struct.dataclass\n",
    "class OdeState(base.Base):\n",
    "    \"\"\"Pendulum state definition\"\"\"\n",
    "\n",
    "    loss_task: Union[float, jax.typing.ArrayLike]\n",
    "    th: Union[float, jax.typing.ArrayLike]\n",
    "    thdot: Union[float, jax.typing.ArrayLike]\n",
    "\n",
    "\n",
    "@struct.dataclass\n",
    "class OdeOutput(base.Base):\n",
    "    \"\"\"World output definition\"\"\"\n",
    "\n",
    "    th: Union[float, jax.typing.ArrayLike]\n",
    "    thdot: Union[float, jax.typing.ArrayLike]\n",
    "\n",
    "\n",
    "class OdeWorld(BaseWorld):  # We inherit from BaseWorld for convenience, but you can inherit from BaseNode if you want\n",
    "    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -> OdeParams:\n",
    "        \"\"\"Default params of the node.\"\"\"\n",
    "        return OdeParams(\n",
    "            max_speed=40.0,  # Clip angular velocity to this value\n",
    "            J=0.000159931461600856,  # 0.000159931461600856,\n",
    "            mass=0.0508581731919534,  # 0.0508581731919534,\n",
    "            length=0.0415233722862552,  # 0.0415233722862552,\n",
    "            b=1.43298488e-05,  # 1.43298488358436e-05,\n",
    "            K=0.03333912,  # 0.0333391179016334,\n",
    "            R=7.73125142,  # 7.73125142447252,\n",
    "            c=0.000975041213361349,  # 0.000975041213361349,\n",
    "            # Backend parameters\n",
    "            dt_substeps_min=1 / 100,  # Minimum substep size for ode integration\n",
    "            dt=1 / self.rate,  # Time step per .step() call\n",
    "        )\n",
    "\n",
    "    def init_state(self, rng: jax.Array = None, graph_state: GraphState = None) -> OdeState:\n",
    "        \"\"\"Default state of the node.\"\"\"\n",
    "        graph_state = graph_state or GraphState()\n",
    "\n",
    "        # Try to grab state from graph_state\n",
    "        state = graph_state.state.get(\"agent\", None)\n",
    "        init_th = state.init_th if state is not None else jnp.pi\n",
    "        init_thdot = state.init_thdot if state is not None else 0.0\n",
    "        return OdeState(th=init_th, thdot=init_thdot, loss_task=0.0)\n",
    "\n",
    "    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -> OdeOutput:\n",
    "        \"\"\"Default output of the node.\"\"\"\n",
    "        graph_state = graph_state or GraphState()\n",
    "        # Grab output from state\n",
    "        world_state = graph_state.state.get(self.name, self.init_state(rng, graph_state))\n",
    "        return OdeOutput(th=world_state.th, thdot=world_state.thdot)\n",
    "\n",
    "    def init_delays(\n",
    "        self, rng: jax.Array = None, graph_state: base.GraphState = None\n",
    "    ) -> Dict[str, Union[float, jax.typing.ArrayLike]]:\n",
    "        graph_state = graph_state or GraphState()\n",
    "        params = graph_state.params.get(\"actuator\")\n",
    "        delays = {}\n",
    "        if hasattr(params, \"actuator_delay\"):\n",
    "            delays[\"actuator\"] = params.actuator_delay\n",
    "        return delays\n",
    "\n",
    "    def step(self, step_state: StepState) -> Tuple[StepState, OdeOutput]:\n",
    "        \"\"\"Step the node.\"\"\"\n",
    "        # Unpack StepState\n",
    "        _, state, params, inputs = step_state.rng, step_state.state, step_state.params, step_state.inputs\n",
    "\n",
    "        # Apply dynamics\n",
    "        u = inputs[\"actuator\"].data.action[-1][0]  # [-1] to get the latest action, [0] reduces the dimension to scalar\n",
    "        us = jnp.array([u] * params.substeps)\n",
    "        new_state = params.step(params.substeps, params.dt_substeps, state, us)[0]\n",
    "        next_th, next_thdot = new_state.th, new_state.thdot\n",
    "        output = OdeOutput(th=next_th, thdot=next_thdot)  # Prepare output\n",
    "\n",
    "        # Calculate cost (penalize angle error, angular velocity and input voltage)\n",
    "        norm_next_th = next_th - 2 * jnp.pi * jnp.floor((next_th + jnp.pi) / (2 * jnp.pi))\n",
    "        loss_task = state.loss_task + norm_next_th**2 + 0.1 * (next_thdot / (1 + 10 * abs(norm_next_th))) ** 2 + 0.01 * u**2\n",
    "\n",
    "        # Update state\n",
    "        new_state = new_state.replace(loss_task=loss_task)\n",
    "        new_step_state = step_state.replace(state=new_state)\n",
    "        return new_step_state, output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "be52a2b17fc1b826",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-08T10:19:04.794966639Z",
     "start_time": "2024-10-08T10:19:04.599595207Z"
    },
    "cellView": "form",
    "id": "be52a2b17fc1b826"
   },
   "outputs": [],
   "source": [
    "# @title Example: Brax simulation node\n",
    "from typing import Tuple, Union\n",
    "\n",
    "import jax\n",
    "from flax import struct\n",
    "\n",
    "from rex import base\n",
    "from rex.base import GraphState, StepState\n",
    "from rex.node import BaseWorld\n",
    "\n",
    "\n",
    "try:\n",
    "    from brax.generalized import pipeline as gen_pipeline\n",
    "    from brax.io import mjcf\n",
    "    from brax.positional import pipeline as pos_pipeline\n",
    "    from brax.spring import pipeline as spring_pipeline\n",
    "\n",
    "    Systems = Union[gen_pipeline.System, spring_pipeline.System, pos_pipeline.System]\n",
    "    Pipelines = Union[gen_pipeline.State, spring_pipeline.State, pos_pipeline.State]\n",
    "except ModuleNotFoundError as e:\n",
    "    print(\"Brax not installed. Install it with `pip install brax`\")\n",
    "    raise e\n",
    "\n",
    "\n",
    "@struct.dataclass\n",
    "class BraxParams(base.Base):\n",
    "    max_speed: Union[float, jax.typing.ArrayLike]\n",
    "    damping: Union[float, jax.typing.ArrayLike]\n",
    "    armature: Union[float, jax.typing.ArrayLike]\n",
    "    gear: Union[float, jax.typing.ArrayLike]\n",
    "    mass_weight: Union[float, jax.typing.ArrayLike]\n",
    "    radius_weight: Union[float, jax.typing.ArrayLike]\n",
    "    offset: Union[float, jax.typing.ArrayLike]\n",
    "    friction_loss: Union[float, jax.typing.ArrayLike]\n",
    "    backend: str = struct.field(pytree_node=False)\n",
    "    dt: Union[float, jax.typing.ArrayLike] = struct.field(pytree_node=False)\n",
    "\n",
    "    @property\n",
    "    def substeps(self) -> int:\n",
    "        dt_substeps_per_backend = {\"generalized\": 1 / 100, \"spring\": 1 / 100, \"positional\": 1 / 100}[self.backend]\n",
    "        substeps = ceil(self.dt / dt_substeps_per_backend)\n",
    "        return int(substeps)\n",
    "\n",
    "    @property\n",
    "    def dt_substeps(self) -> float:\n",
    "        substeps = self.substeps\n",
    "        dt_substeps = self.dt / substeps\n",
    "        return dt_substeps\n",
    "\n",
    "    @property\n",
    "    def pipeline(self) -> Pipelines:\n",
    "        return {\"generalized\": gen_pipeline, \"spring\": spring_pipeline, \"positional\": pos_pipeline}[self.backend]\n",
    "\n",
    "    @property\n",
    "    def sys(self) -> Systems:\n",
    "        base_sys = mjcf.loads(DISK_PENDULUM_XML)\n",
    "        # Appropriately replace parameters for the disk pendulum\n",
    "        itransform = base_sys.link.inertia.transform.replace(pos=jnp.array([[0.0, self.offset, 0.0]]))\n",
    "        i = base_sys.link.inertia.i.at[0, 0, 0].set(\n",
    "            0.5 * self.mass_weight * self.radius_weight**2\n",
    "        )  # inertia of cylinder in local frame.\n",
    "        inertia = base_sys.link.inertia.replace(transform=itransform, mass=jnp.array([self.mass_weight]), i=i)\n",
    "        link = base_sys.link.replace(inertia=inertia)\n",
    "        actuator = base_sys.actuator.replace(gear=jnp.array([self.gear]))\n",
    "        dof = base_sys.dof.replace(armature=jnp.array([self.armature]), damping=jnp.array([self.damping]))\n",
    "        opt = base_sys.opt.replace(timestep=self.dt_substeps)\n",
    "        new_sys = base_sys.replace(link=link, actuator=actuator, dof=dof, opt=opt)\n",
    "        return new_sys\n",
    "\n",
    "    def step(\n",
    "        self, substeps: int, dt_substeps: jax.typing.ArrayLike, x: Pipelines, us: jax.typing.ArrayLike\n",
    "    ) -> Tuple[Pipelines, Pipelines]:\n",
    "        \"\"\"Step the pendulum ode.\"\"\"\n",
    "        # Appropriately replace timestep for the disk pendulum\n",
    "        sys = self.sys.replace(opt=self.sys.opt.replace(timestep=dt_substeps))\n",
    "\n",
    "        def _scan_fn(_x, _u):\n",
    "            # Add friction loss\n",
    "            thdot = x.qd[0]\n",
    "            activation = jnp.sign(thdot)\n",
    "            friction = self.friction_loss * activation / sys.actuator.gear[0]\n",
    "            _u_friction = _u - friction\n",
    "            # Step\n",
    "            next_x = gen_pipeline.step(sys, _x, jnp.array(_u_friction)[None])\n",
    "            # Clip velocity\n",
    "            next_x = next_x.replace(qd=jnp.clip(next_x.qd, -self.max_speed, self.max_speed))\n",
    "            return next_x, next_x\n",
    "\n",
    "        x_final, x_substeps = jax.lax.scan(_scan_fn, x, us, length=substeps)\n",
    "        return x_final, x_substeps\n",
    "\n",
    "\n",
    "@struct.dataclass\n",
    "class BraxState(base.Base):\n",
    "    \"\"\"Pendulum state definition\"\"\"\n",
    "\n",
    "    loss_task: Union[float, jax.typing.ArrayLike]\n",
    "    pipeline_state: Pipelines\n",
    "\n",
    "    @property\n",
    "    def th(self):\n",
    "        return self.pipeline_state.q[..., 0]\n",
    "\n",
    "    @property\n",
    "    def thdot(self):\n",
    "        return self.pipeline_state.qd[..., 0]\n",
    "\n",
    "\n",
    "@struct.dataclass\n",
    "class BraxOutput(base.Base):\n",
    "    \"\"\"World output definition\"\"\"\n",
    "\n",
    "    th: Union[float, jax.typing.ArrayLike]\n",
    "    thdot: Union[float, jax.typing.ArrayLike]\n",
    "\n",
    "\n",
    "class BraxWorld(BaseWorld):  # We inherit from BaseWorld for convenience, but you can inherit from BaseNode if you want\n",
    "    def __init__(self, *args, backend: str = \"generalized\", **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "        self.backend = backend\n",
    "\n",
    "    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -> BraxParams:\n",
    "        \"\"\"Default params of the node.\"\"\"\n",
    "        return BraxParams(\n",
    "            # Realistic parameters for the disk pendulum\n",
    "            max_speed=40.0,\n",
    "            damping=0.00015877,\n",
    "            armature=6.4940527e-06,\n",
    "            gear=0.00428677,\n",
    "            mass_weight=0.05076142,\n",
    "            radius_weight=0.05121992,\n",
    "            offset=0.04161447,\n",
    "            friction_loss=0.00097525,\n",
    "            # Backend parameters\n",
    "            dt=1 / self.rate,\n",
    "            backend=self.backend,\n",
    "        )\n",
    "\n",
    "    def init_state(self, rng: jax.Array = None, graph_state: GraphState = None) -> BraxState:\n",
    "        \"\"\"Default state of the node.\"\"\"\n",
    "        graph_state = graph_state or GraphState()\n",
    "\n",
    "        # Try to grab state from graph_state\n",
    "        state = graph_state.state.get(\"agent\", None)\n",
    "        init_th = state.init_th if state is not None else jnp.pi\n",
    "        init_thdot = state.init_thdot if state is not None else 0.0\n",
    "\n",
    "        # Set the initial state of the disk pendulum\n",
    "        params = graph_state.params.get(self.name, self.init_params(rng, graph_state))\n",
    "        sys = params.sys\n",
    "        q = sys.init_q.at[0].set(init_th)\n",
    "        qd = jnp.array([init_thdot])\n",
    "        pipeline_state = params.pipeline.init(sys, q, qd)\n",
    "        return BraxState(pipeline_state=pipeline_state, loss_task=0.0)\n",
    "\n",
    "    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -> BraxOutput:\n",
    "        \"\"\"Default output of the node.\"\"\"\n",
    "        graph_state = graph_state or GraphState()\n",
    "        # Grab output from state\n",
    "        state = graph_state.state.get(self.name, self.init_state(rng, graph_state))\n",
    "        return BraxOutput(th=state.pipeline_state.q[0], thdot=state.pipeline_state.qd[0])\n",
    "\n",
    "    def step(self, step_state: StepState) -> Tuple[StepState, BraxOutput]:\n",
    "        \"\"\"Step the node.\"\"\"\n",
    "\n",
    "        # Unpack StepState\n",
    "        _, state, params, inputs = step_state.rng, step_state.state, step_state.params, step_state.inputs\n",
    "\n",
    "        # Apply dynamics\n",
    "        u = inputs[\"actuator\"].data.action[-1][0]  # [-1] to get the latest action, [0] reduces the dimension to scalar\n",
    "        us = jnp.array([u] * params.substeps)\n",
    "        x = state.pipeline_state\n",
    "        next_x = params.step(params.substeps, params.dt_substeps, x, us)[0]\n",
    "        new_state = state.replace(pipeline_state=next_x)\n",
    "        next_th, next_thdot = new_state.th, new_state.thdot\n",
    "        output = BraxOutput(th=next_th, thdot=next_thdot)  # Prepare output\n",
    "\n",
    "        # Calculate cost (penalize angle error, angular velocity and input voltage)\n",
    "        norm_next_th = next_th - 2 * jnp.pi * jnp.floor((next_th + jnp.pi) / (2 * jnp.pi))\n",
    "        loss_task = state.loss_task + norm_next_th**2 + 0.1 * (next_thdot / (1 + 10 * abs(norm_next_th))) ** 2 + 0.01 * u**2\n",
    "\n",
    "        # Update state\n",
    "        new_state = new_state.replace(loss_task=loss_task)\n",
    "        new_step_state = step_state.replace(state=new_state)\n",
    "        return new_step_state, output\n",
    "\n",
    "\n",
    "DISK_PENDULUM_XML = \"\"\"\n",
    "<mujoco model=\"disk_pendulum\">\n",
    "    <compiler inertiafromgeom=\"auto\" angle=\"radian\" coordinate=\"local\" eulerseq=\"xyz\" autolimits=\"true\"/>\n",
    "    <option gravity=\"0 0 -9.81\" timestep=\"0.01\" iterations=\"10\"/>\n",
    "    <custom>\n",
    "        <numeric data=\"10\" name=\"constraint_ang_damping\"/> <!-- positional & spring -->\n",
    "        <numeric data=\"1\" name=\"spring_inertia_scale\"/>  <!-- positional & spring -->\n",
    "        <numeric data=\"0\" name=\"ang_damping\"/>  <!-- positional & spring -->\n",
    "        <numeric data=\"0\" name=\"spring_mass_scale\"/>  <!-- positional & spring -->\n",
    "        <numeric data=\"0.5\" name=\"joint_scale_pos\"/> <!-- positional -->\n",
    "        <numeric data=\"0.1\" name=\"joint_scale_ang\"/> <!-- positional -->\n",
    "        <numeric data=\"3000\" name=\"constraint_stiffness\"/>  <!-- spring -->\n",
    "        <numeric data=\"10000\" name=\"constraint_limit_stiffness\"/>  <!-- spring -->\n",
    "        <numeric data=\"50\" name=\"constraint_vel_damping\"/>  <!-- spring -->\n",
    "        <numeric data=\"10\" name=\"solver_maxls\"/>  <!-- generalized -->\n",
    "    </custom>\n",
    "\n",
    "    <asset>\n",
    "        <texture builtin=\"flat\" height=\"1278\" mark=\"cross\" markrgb=\"1 1 1\" name=\"texgeom\" random=\"0.01\" rgb1=\"0.8 0.6 0.4\" rgb2=\"0.8 0.6 0.4\" type=\"cube\" width=\"127\"/>\n",
    "        <material name=\"geom\" texture=\"texgeom\" texuniform=\"true\"/>\n",
    "    </asset>\n",
    "\n",
    "    <default>\n",
    "        <geom contype=\"0\" friction=\"1 0.1 0.1\" material=\"geom\"/>\n",
    "    </default>\n",
    "\n",
    "    <worldbody>\n",
    "        <light cutoff=\"100\" diffuse=\"1 1 1\" dir=\"-0 0 -1.3\" directional=\"true\" exponent=\"1\" pos=\"0 0 1.3\" specular=\".1 .1 .1\"/>\n",
    "        <geom name=\"table\" type=\"plane\" pos=\"0 0.0 -0.1\" size=\"1 1 0.1\" contype=\"8\" conaffinity=\"11\" condim=\"3\"/>\n",
    "        <body name=\"disk\" pos=\"0.0 0.0 0.0\" euler=\"1.5708 0.0 0.0\">\n",
    "            <joint name=\"hinge_joint\" type=\"hinge\" axis=\"0 0 1\" range=\"-180 180\" armature=\"0.00022993\" damping=\"0.0001\" limited=\"false\"/>\n",
    "            <geom name=\"disk_geom\" type=\"cylinder\" size=\"0.06 0.001\" contype=\"0\" conaffinity=\"0\" condim=\"3\" mass=\"0.0\"/>\n",
    "            <geom name=\"mass_geom\" type=\"cylinder\" size=\"0.02 0.005\" contype=\"0\" conaffinity=\"0\"  condim=\"3\" rgba=\"0.04 0.04 0.04 1\"\n",
    "                  pos=\"0.0 0.04 0.\" mass=\"0.05085817\"/>\n",
    "        </body>\n",
    "    </worldbody>\n",
    "\n",
    "    <actuator>\n",
    "        <motor joint=\"hinge_joint\" ctrllimited=\"false\" ctrlrange=\"-3.0 3.0\"  gear=\"0.01\"/>\n",
    "    </actuator>\n",
    "</mujoco>\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2048c34d476e6ebd",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-08T10:19:04.798554708Z",
     "start_time": "2024-10-08T10:19:04.795782471Z"
    },
    "id": "2048c34d476e6ebd"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
