{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "planned-pharmacy",
   "metadata": {},
   "source": [
    "# Automatic rendering of NumPyro models\n",
    "\n",
    "In this tutorial we will demonstrate how to create beautiful visualizations of your probabilistic graphical models using [numpyro.render_model](https://num.pyro.ai/en/stable/utilities.html#render-model)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3292972e-941b-4ad7-933a-7c0f11d32ef3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# install the development version of NumPyro into the kernel's own environment\n",
    "%pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "nearby-beach",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import flax.linen as flax_nn\n",
    "from jax import nn\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import numpyro\n",
    "from numpyro.contrib.module import flax_module\n",
    "import numpyro.distributions as dist\n",
    "import numpyro.distributions.constraints as constraints\n",
    "\n",
    "assert numpyro.__version__.startswith(\"0.15.0\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "pleasant-alias",
   "metadata": {},
   "source": [
    "## A Simple Example\n",
    "\n",
    "The visualization interface can be readily used with your models:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fresh-throw",
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(data):\n",
    "    \"\"\"Latent mean `m`, log-normal scale `sd`, and Normal observations `obs` in plate `N`.\"\"\"\n",
    "    m = numpyro.sample(\"m\", dist.Normal(0, 1))\n",
    "    sd = numpyro.sample(\"sd\", dist.LogNormal(m, 1))\n",
    "    likelihood = dist.Normal(m, sd)\n",
    "    num_obs = len(data)\n",
    "    with numpyro.plate(\"N\", num_obs):\n",
    "        numpyro.sample(\"obs\", likelihood, obs=data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "capital-ferry",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 7.0.3 (0)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"105pt\" height=\"227pt\"\n",
       " viewBox=\"0.00 0.00 105.00 227.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 223)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-223 101,-223 101,4 -4,4\"/>\n",
       "<g id=\"clust1\" class=\"cluster\">\n",
       "<title>cluster_N</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"19,-8 19,-83 89,-83 89,-8 19,-8\"/>\n",
       "<text text-anchor=\"middle\" x=\"74.5\" y=\"-15.8\" font-family=\"Times,serif\" font-size=\"14.00\">N</text>\n",
       "</g>\n",
       "<!-- m -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>m</title>\n",
       "<ellipse fill=\"white\" stroke=\"black\" cx=\"54\" cy=\"-201\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"54\" y=\"-197.3\" font-family=\"Times,serif\" font-size=\"14.00\">m</text>\n",
       "</g>\n",
       "<!-- sd -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>sd</title>\n",
       "<ellipse fill=\"white\" stroke=\"black\" cx=\"27\" cy=\"-129\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"27\" y=\"-125.3\" font-family=\"Times,serif\" font-size=\"14.00\">sd</text>\n",
       "</g>\n",
       "<!-- m&#45;&gt;sd -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>m&#45;&gt;sd</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M47.46,-183.05C44.48,-175.32 40.87,-165.96 37.52,-157.27\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"40.88,-156.27 34.02,-148.2 34.35,-158.79 40.88,-156.27\"/>\n",
       "</g>\n",
       "<!-- obs -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>obs</title>\n",
       "<ellipse fill=\"grey\" stroke=\"black\" cx=\"54\" cy=\"-57\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"54\" y=\"-53.3\" font-family=\"Times,serif\" font-size=\"14.00\">obs</text>\n",
       "</g>\n",
       "<!-- m&#45;&gt;obs -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>m&#45;&gt;obs</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M57.65,-182.91C59.68,-172.57 61.98,-159.09 63,-147 64.34,-131.06 64.34,-126.94 63,-111 62.32,-102.97 61.08,-94.33 59.73,-86.4\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"63.17,-85.79 57.93,-76.58 56.29,-87.05 63.17,-85.79\"/>\n",
       "</g>\n",
       "<!-- sd&#45;&gt;obs -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>sd&#45;&gt;obs</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M33.54,-111.05C36.52,-103.32 40.13,-93.96 43.48,-85.27\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"46.65,-86.79 46.98,-76.2 40.12,-84.27 46.65,-86.79\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7fe553f0a1f0>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# placeholder observations used to trace the model\n",
    "data = jnp.ones((10,))\n",
    "numpyro.render_model(model, model_args=(data,))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "preceding-moisture",
   "metadata": {},
   "source": [
    "The visualization can be saved to a file by providing `filename='path'` to `numpyro.render_model`. You can use different formats such as PDF or PNG by changing the filename's suffix.\n",
    "When not saving to a file (`filename=None`), you can also change the format with `graph.format = 'pdf'` where `graph` is the object returned by `numpyro.render_model`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "extreme-bacteria",
   "metadata": {},
   "outputs": [],
   "source": [
    "# the \".pdf\" suffix of the filename selects the output format\n",
    "graph = numpyro.render_model(\n",
    "    model,\n",
    "    model_args=(data,),\n",
    "    filename=\"model.pdf\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "naughty-intent",
   "metadata": {},
   "source": [
    "## Tweaking the visualization\n",
    "\n",
    "As `numpyro.render_model` returns a `graphviz.Digraph` object, you can further customize the visualization of this graph.\n",
    "For example, you could use the [unflatten preprocessor](https://graphviz.readthedocs.io/en/stable/api.html#graphviz.unflatten) to improve the layout aspect ratio for more complex models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "coordinate-valve",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mace(positions, annotations):\n",
    "    \"\"\"\n",
    "    This model corresponds to the plate diagram in Figure 3 of https://aclanthology.org/Q18-1040.pdf.\n",
    "\n",
    "    :param positions: 0-based annotator index for each column (position) of ``annotations``.\n",
    "    :param annotations: 0-based class labels with shape ``(num_items, num_positions)``.\n",
    "    \"\"\"\n",
    "    num_annotators = int(np.max(positions)) + 1\n",
    "    num_classes = int(np.max(annotations)) + 1\n",
    "    num_items, num_positions = annotations.shape\n",
    "\n",
    "    with numpyro.plate(\"annotator\", num_annotators):\n",
    "        # float concentration: the Dirichlet parameter is a positive real, not an integer\n",
    "        epsilon = numpyro.sample(\"epsilon\", dist.Dirichlet(jnp.full(num_classes, 10.0)))\n",
    "        theta = numpyro.sample(\"theta\", dist.Beta(0.5, 0.5))\n",
    "\n",
    "    with numpyro.plate(\"item\", num_items, dim=-2):\n",
    "        c = numpyro.sample(\"c\", dist.DiscreteUniform(0, num_classes - 1))\n",
    "\n",
    "        with numpyro.plate(\"position\", num_positions):\n",
    "            # s == 0: the observed label is the true class c; otherwise it is drawn from epsilon\n",
    "            s = numpyro.sample(\"s\", dist.Bernoulli(1 - theta[positions]))\n",
    "            probs = jnp.where(\n",
    "                s[..., None] == 0, nn.one_hot(c, num_classes), epsilon[positions]\n",
    "            )\n",
    "            numpyro.sample(\"y\", dist.Categorical(probs), obs=annotations)\n",
    "\n",
    "\n",
    "positions = np.array([1, 1, 1, 2, 3, 4, 5])\n",
    "# fmt: off\n",
    "annotations = np.array([\n",
    "    [1, 3, 1, 2, 2, 2, 1, 3, 2, 2, 4, 2, 1, 2, 1,\n",
    "     1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1,\n",
    "     1, 3, 1, 2, 2, 4, 2, 2, 3, 1, 1, 1, 2, 1, 2],\n",
    "    [1, 3, 1, 2, 2, 2, 2, 3, 2, 3, 4, 2, 1, 2, 2,\n",
    "     1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 3, 1, 1, 1,\n",
    "     1, 3, 1, 2, 2, 3, 2, 3, 3, 1, 1, 2, 3, 2, 2],\n",
    "    [1, 3, 2, 2, 2, 2, 2, 3, 2, 2, 4, 2, 1, 2, 1,\n",
    "     1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 2, 1, 1, 2,\n",
    "     1, 3, 1, 2, 2, 3, 1, 2, 3, 1, 1, 1, 2, 1, 2],\n",
    "    [1, 4, 2, 3, 3, 3, 2, 3, 2, 2, 4, 3, 1, 3, 1,\n",
    "     2, 1, 1, 2, 1, 2, 2, 3, 2, 1, 1, 2, 1, 1, 1,\n",
    "     1, 3, 1, 2, 3, 4, 2, 3, 3, 1, 1, 2, 2, 1, 2],\n",
    "    [1, 3, 1, 1, 2, 3, 1, 4, 2, 2, 4, 3, 1, 2, 1,\n",
    "     1, 1, 1, 2, 3, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1,\n",
    "     1, 2, 1, 2, 2, 3, 2, 2, 4, 1, 1, 1, 2, 1, 2],\n",
    "    [1, 3, 2, 2, 2, 2, 1, 3, 2, 2, 4, 4, 1, 1, 1,\n",
    "     1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 2,\n",
    "     1, 3, 1, 2, 3, 4, 3, 3, 3, 1, 1, 1, 2, 1, 2],\n",
    "    [1, 4, 2, 1, 2, 2, 1, 3, 3, 3, 4, 3, 1, 2, 1,\n",
    "     1, 1, 1, 1, 2, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1,\n",
    "     1, 3, 1, 2, 2, 3, 2, 3, 2, 1, 1, 1, 2, 1, 2],\n",
    "]).T\n",
    "# fmt: on\n",
    "\n",
    "# convert the 1-based annotator ids and labels above to 0-based indices for Python indexing\n",
    "positions -= 1\n",
    "annotations -= 1\n",
    "\n",
    "mace_graph = numpyro.render_model(mace, model_args=(positions, annotations))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "loose-spotlight",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 7.0.3 (0)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"233pt\" height=\"293pt\"\n",
       " viewBox=\"0.00 0.00 233.00 293.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 289)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-289 229,-289 229,4 -4,4\"/>\n",
       "<g id=\"clust1\" class=\"cluster\">\n",
       "<title>cluster_annotator</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"8,-202 8,-277 200,-277 200,-202 8,-202\"/>\n",
       "<text text-anchor=\"middle\" x=\"157\" y=\"-209.8\" font-family=\"Times,serif\" font-size=\"14.00\">annotator</text>\n",
       "</g>\n",
       "<g id=\"clust2\" class=\"cluster\">\n",
       "<title>cluster_item</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"65,-8 65,-194 217,-194 217,-8 65,-8\"/>\n",
       "<text text-anchor=\"middle\" x=\"192.5\" y=\"-15.8\" font-family=\"Times,serif\" font-size=\"14.00\">item</text>\n",
       "</g>\n",
       "<g id=\"clust3\" class=\"cluster\">\n",
       "<title>cluster_position</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"135,-39 135,-186 209,-186 209,-39 135,-39\"/>\n",
       "<text text-anchor=\"middle\" x=\"172\" y=\"-46.8\" font-family=\"Times,serif\" font-size=\"14.00\">position</text>\n",
       "</g>\n",
       "<!-- epsilon -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>epsilon</title>\n",
       "<ellipse fill=\"white\" stroke=\"black\" cx=\"60\" cy=\"-251\" rx=\"44.39\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"60\" y=\"-247.3\" font-family=\"Times,serif\" font-size=\"14.00\">epsilon</text>\n",
       "</g>\n",
       "<!-- y -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>y</title>\n",
       "<ellipse fill=\"grey\" stroke=\"black\" cx=\"170\" cy=\"-88\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"170\" y=\"-84.3\" font-family=\"Times,serif\" font-size=\"14.00\">y</text>\n",
       "</g>\n",
       "<!-- epsilon&#45;&gt;y -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>epsilon&#45;&gt;y</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M54.97,-232.74C49.44,-210.02 43.47,-169.74 61,-142 76.8,-117.01 108.08,-103.49 133.1,-96.37\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"133.8,-99.81 142.61,-93.91 132.05,-93.03 133.8,-99.81\"/>\n",
       "</g>\n",
       "<!-- theta -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>theta</title>\n",
       "<ellipse fill=\"white\" stroke=\"black\" cx=\"157\" cy=\"-251\" rx=\"35.19\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"157\" y=\"-247.3\" font-family=\"Times,serif\" font-size=\"14.00\">theta</text>\n",
       "</g>\n",
       "<!-- s -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>s</title>\n",
       "<ellipse fill=\"white\" stroke=\"black\" cx=\"172\" cy=\"-160\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"172\" y=\"-156.3\" font-family=\"Times,serif\" font-size=\"14.00\">s</text>\n",
       "</g>\n",
       "<!-- theta&#45;&gt;s -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>theta&#45;&gt;s</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M159.89,-232.84C161.97,-220.53 164.81,-203.66 167.22,-189.36\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"170.65,-190.04 168.87,-179.6 163.75,-188.88 170.65,-190.04\"/>\n",
       "</g>\n",
       "<!-- c -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>c</title>\n",
       "<ellipse fill=\"white\" stroke=\"black\" cx=\"100\" cy=\"-160\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"100\" y=\"-156.3\" font-family=\"Times,serif\" font-size=\"14.00\">c</text>\n",
       "</g>\n",
       "<!-- c&#45;&gt;y -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>c&#45;&gt;y</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M114.5,-144.5C124.08,-134.92 136.81,-122.19 147.66,-111.34\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"149.85,-114.1 154.44,-104.56 144.9,-109.15 149.85,-114.1\"/>\n",
       "</g>\n",
       "<!-- s&#45;&gt;y -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>s&#45;&gt;y</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M171.51,-141.7C171.3,-134.41 171.05,-125.73 170.82,-117.54\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"174.32,-117.51 170.53,-107.62 167.32,-117.71 174.32,-117.51\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7fe590b83d60>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# default layout of the graph returned by numpyro.render_model\n",
    "mace_graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "thrown-filling",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 7.0.3 (0)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"322pt\" height=\"285pt\"\n",
       " viewBox=\"0.00 0.00 322.00 285.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 281)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-281 318,-281 318,4 -4,4\"/>\n",
       "<g id=\"clust1\" class=\"cluster\">\n",
       "<title>cluster_annotator</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"8,-194 8,-269 200,-269 200,-194 8,-194\"/>\n",
       "<text text-anchor=\"middle\" x=\"157\" y=\"-201.8\" font-family=\"Times,serif\" font-size=\"14.00\">annotator</text>\n",
       "</g>\n",
       "<g id=\"clust2\" class=\"cluster\">\n",
       "<title>cluster_item</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"208,-8 208,-269 306,-269 306,-8 208,-8\"/>\n",
       "<text text-anchor=\"middle\" x=\"281.5\" y=\"-15.8\" font-family=\"Times,serif\" font-size=\"14.00\">item</text>\n",
       "</g>\n",
       "<g id=\"clust3\" class=\"cluster\">\n",
       "<title>cluster_position</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"216,-39 216,-186 290,-186 290,-39 216,-39\"/>\n",
       "<text text-anchor=\"middle\" x=\"253\" y=\"-46.8\" font-family=\"Times,serif\" font-size=\"14.00\">position</text>\n",
       "</g>\n",
       "<!-- epsilon -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>epsilon</title>\n",
       "<ellipse fill=\"white\" stroke=\"black\" cx=\"60\" cy=\"-243\" rx=\"44.39\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"60\" y=\"-239.3\" font-family=\"Times,serif\" font-size=\"14.00\">epsilon</text>\n",
       "</g>\n",
       "<!-- y -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>y</title>\n",
       "<ellipse fill=\"grey\" stroke=\"black\" cx=\"251\" cy=\"-88\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"251\" y=\"-84.3\" font-family=\"Times,serif\" font-size=\"14.00\">y</text>\n",
       "</g>\n",
       "<!-- epsilon&#45;&gt;y -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>epsilon&#45;&gt;y</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M76.86,-226.06C87.11,-216.53 100.59,-204.31 113,-194 150.49,-162.85 195.76,-129.1 223.97,-108.5\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"225.95,-111.39 231.98,-102.68 221.83,-105.73 225.95,-111.39\"/>\n",
       "</g>\n",
       "<!-- theta -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>theta</title>\n",
       "<ellipse fill=\"white\" stroke=\"black\" cx=\"157\" cy=\"-243\" rx=\"35.19\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"157\" y=\"-239.3\" font-family=\"Times,serif\" font-size=\"14.00\">theta</text>\n",
       "</g>\n",
       "<!-- s -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>s</title>\n",
       "<ellipse fill=\"white\" stroke=\"black\" cx=\"251\" cy=\"-160\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"251\" y=\"-156.3\" font-family=\"Times,serif\" font-size=\"14.00\">s</text>\n",
       "</g>\n",
       "<!-- theta&#45;&gt;s -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>theta&#45;&gt;s</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M171.14,-226.14C180.14,-216.42 192.27,-203.97 204,-194 209.92,-188.97 216.6,-183.94 223.01,-179.39\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"224.63,-182.52 230.87,-173.96 220.65,-176.76 224.63,-182.52\"/>\n",
       "</g>\n",
       "<!-- s&#45;&gt;y -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>s&#45;&gt;y</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M251,-141.7C251,-134.41 251,-125.73 251,-117.54\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"254.5,-117.62 251,-107.62 247.5,-117.62 254.5,-117.62\"/>\n",
       "</g>\n",
       "<!-- c -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>c</title>\n",
       "<ellipse fill=\"white\" stroke=\"black\" cx=\"271\" cy=\"-243\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"271\" y=\"-239.3\" font-family=\"Times,serif\" font-size=\"14.00\">c</text>\n",
       "</g>\n",
       "<!-- c&#45;&gt;y -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>c&#45;&gt;y</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M278.18,-225.4C285.85,-205.29 295.79,-170.57 287,-142 283.82,-131.65 277.85,-121.51 271.67,-112.91\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"274.55,-110.91 265.67,-105.13 269.01,-115.19 274.55,-110.91\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.sources.Source at 0x7fe590bda940>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# same graph after the unflatten preprocessor; note it returns a new graphviz Source object\n",
    "mace_graph.unflatten(stagger=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50a92902",
   "metadata": {},
   "source": [
    "## Rendering the parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74f32e20",
   "metadata": {},
   "source": [
    "We can render parameters defined via `numpyro.param` by setting `render_params=True` in `numpyro.render_model`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "645df936",
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(data):\n",
    "    \"\"\"Exponential observations whose rate `lambda` is LogNormal with learnable `m` and `sd`.\"\"\"\n",
    "    m = numpyro.param(\"m\", 0.0)\n",
    "    sd = numpyro.param(\"sd\", 1.0, constraint=constraints.positive)\n",
    "    rate = numpyro.sample(\"lambda\", dist.LogNormal(m, sd))\n",
    "    num_obs = len(data)\n",
    "    with numpyro.plate(\"N\", num_obs):\n",
    "        numpyro.sample(\"obs\", dist.Exponential(rate), obs=data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "66fc9f55",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 7.0.3 (0)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"98pt\" height=\"206pt\"\n",
       " viewBox=\"0.00 0.00 97.69 206.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 202)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-202 93.69,-202 93.69,4 -4,4\"/>\n",
       "<g id=\"clust1\" class=\"cluster\">\n",
       "<title>cluster_N</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"9.85,-8 9.85,-83 79.85,-83 79.85,-8 9.85,-8\"/>\n",
       "<text text-anchor=\"middle\" x=\"65.35\" y=\"-15.8\" font-family=\"Times,serif\" font-size=\"14.00\">N</text>\n",
       "</g>\n",
       "<!-- lambda -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>lambda</title>\n",
       "<ellipse fill=\"white\" stroke=\"black\" cx=\"44.85\" cy=\"-129\" rx=\"44.69\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"44.85\" y=\"-125.3\" font-family=\"Times,serif\" font-size=\"14.00\">lambda</text>\n",
       "</g>\n",
       "<!-- obs -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>obs</title>\n",
       "<ellipse fill=\"grey\" stroke=\"black\" cx=\"44.85\" cy=\"-57\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"44.85\" y=\"-53.3\" font-family=\"Times,serif\" font-size=\"14.00\">obs</text>\n",
       "</g>\n",
       "<!-- lambda&#45;&gt;obs -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>lambda&#45;&gt;obs</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M44.85,-110.7C44.85,-103.41 44.85,-94.73 44.85,-86.54\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"48.35,-86.62 44.85,-76.62 41.35,-86.62 48.35,-86.62\"/>\n",
       "</g>\n",
       "<!-- sd -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>sd</title>\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"36.35,-198 19.35,-198 19.35,-183 36.35,-183 36.35,-198\"/>\n",
       "<text text-anchor=\"middle\" x=\"27.85\" y=\"-186.8\" font-family=\"Times,serif\" font-size=\"14.00\">sd</text>\n",
       "</g>\n",
       "<!-- sd&#45;&gt;lambda -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>sd&#45;&gt;lambda</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M29.67,-183.13C31.39,-177.08 34.12,-167.54 36.79,-158.2\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"40.15,-159.19 39.53,-148.62 33.41,-157.27 40.15,-159.19\"/>\n",
       "</g>\n",
       "<!-- m -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>m</title>\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"68.85,-198 54.85,-198 54.85,-183 68.85,-183 68.85,-198\"/>\n",
       "<text text-anchor=\"middle\" x=\"61.85\" y=\"-186.8\" font-family=\"Times,serif\" font-size=\"14.00\">m</text>\n",
       "</g>\n",
       "<!-- m&#45;&gt;lambda -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>m&#45;&gt;lambda</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M60.03,-183.13C58.3,-177.08 55.57,-167.54 52.9,-158.2\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"56.28,-157.27 50.16,-148.62 49.55,-159.19 56.28,-157.27\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7fe5b0516640>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = jnp.ones((10,))\n",
    "# render_params=True also draws the sites created with numpyro.param\n",
    "numpyro.render_model(model, model_args=(data,), render_params=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f09e73e1",
   "metadata": {},
   "source": [
    "## Distribution and Constraint annotations\n",
    "\n",
    "It is possible to display the distribution of each random variable (RV) in the generated plot by passing `render_distributions=True` to `numpyro.render_model`. The constraints associated with parameters are also displayed when `render_distributions=True`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e3ac34ff",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 7.0.3 (0)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"374pt\" height=\"244pt\"\n",
       " viewBox=\"0.00 0.00 374.35 244.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 240)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-240 370.35,-240 370.35,4 -4,4\"/>\n",
       "<g id=\"clust1\" class=\"cluster\">\n",
       "<title>cluster_N</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"9.85,-8 9.85,-83 79.85,-83 79.85,-8 9.85,-8\"/>\n",
       "<text text-anchor=\"middle\" x=\"65.35\" y=\"-15.8\" font-family=\"Times,serif\" font-size=\"14.00\">N</text>\n",
       "</g>\n",
       "<!-- lambda -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>lambda</title>\n",
       "<ellipse fill=\"white\" stroke=\"black\" cx=\"44.85\" cy=\"-129\" rx=\"44.69\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"44.85\" y=\"-125.3\" font-family=\"Times,serif\" font-size=\"14.00\">lambda</text>\n",
       "</g>\n",
       "<!-- obs -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>obs</title>\n",
       "<ellipse fill=\"grey\" stroke=\"black\" cx=\"44.85\" cy=\"-57\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"44.85\" y=\"-53.3\" font-family=\"Times,serif\" font-size=\"14.00\">obs</text>\n",
       "</g>\n",
       "<!-- lambda&#45;&gt;obs -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>lambda&#45;&gt;obs</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M44.85,-110.7C44.85,-103.41 44.85,-94.73 44.85,-86.54\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"48.35,-86.62 44.85,-76.62 41.35,-86.62 48.35,-86.62\"/>\n",
       "</g>\n",
       "<!-- sd -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>sd</title>\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"36.35,-217 19.35,-217 19.35,-202 36.35,-202 36.35,-217\"/>\n",
       "<text text-anchor=\"middle\" x=\"27.85\" y=\"-205.8\" font-family=\"Times,serif\" font-size=\"14.00\">sd</text>\n",
       "</g>\n",
       "<!-- sd&#45;&gt;lambda -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>sd&#45;&gt;lambda</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M29.24,-202.08C31.3,-192.53 35.31,-174.02 38.75,-158.14\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"42.1,-159.21 40.8,-148.7 35.26,-157.73 42.1,-159.21\"/>\n",
       "</g>\n",
       "<!-- m -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>m</title>\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"68.85,-217 54.85,-217 54.85,-202 68.85,-202 68.85,-217\"/>\n",
       "<text text-anchor=\"middle\" x=\"61.85\" y=\"-205.8\" font-family=\"Times,serif\" font-size=\"14.00\">m</text>\n",
       "</g>\n",
       "<!-- m&#45;&gt;lambda -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>m&#45;&gt;lambda</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M60.45,-202.08C58.39,-192.53 54.38,-174.02 50.94,-158.14\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"54.43,-157.73 48.9,-148.7 47.59,-159.21 54.43,-157.73\"/>\n",
       "</g>\n",
       "<!-- distribution_description_node -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>distribution_description_node</title>\n",
       "<text text-anchor=\"start\" x=\"95.35\" y=\"-220.8\" font-family=\"Times,serif\" font-size=\"14.00\">lambda ~ LogNormal</text>\n",
       "<text text-anchor=\"start\" x=\"95.35\" y=\"-205.8\" font-family=\"Times,serif\" font-size=\"14.00\">obs ~ Exponential</text>\n",
       "<text text-anchor=\"start\" x=\"95.35\" y=\"-190.8\" font-family=\"Times,serif\" font-size=\"14.00\">sd ∈ GreaterThan(lower_bound=0.0)</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7fe5b04970a0>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# annotate sample sites with their distributions and params with their constraints\n",
    "numpyro.render_model(\n",
    "    model,\n",
    "    model_args=(data,),\n",
    "    render_params=True,\n",
    "    render_distributions=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a1e4c4e",
   "metadata": {},
   "source": [
    "In the plot above, **'~'** denotes the distribution of a random variable and **'$\\in$'** denotes the constraint on a parameter."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8dd92f46-8221-4b94-b6c9-32dafedb0e48",
   "metadata": {},
   "source": [
    "## Rendering deterministic sites"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f247446-d110-4ed5-8943-72ccc5d26f6c",
   "metadata": {},
   "source": [
    "We can also render deterministic sites defined via `numpyro.deterministic`. Such sites are drawn with a dashed line to distinguish them from random sites. The following example illustrates this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "105d3c5e-a8f0-4b3e-b5fe-38f3bf729994",
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(data):\n",
    "    \"\"\"Normal observations whose mean is the deterministic site `m_transformed` = m + 1.\"\"\"\n",
    "    m = numpyro.sample(\"m\", dist.Normal(0, 1))\n",
    "    sd = numpyro.sample(\"sd\", dist.LogNormal(m, 1))\n",
    "    # deterministic site (rendered with a dashed outline)\n",
    "    m_shifted = numpyro.deterministic(\"m_transformed\", m + 1)\n",
    "    num_obs = len(data)\n",
    "    with numpyro.plate(\"N\", num_obs):\n",
    "        numpyro.sample(\"obs\", dist.Normal(m_shifted, sd), obs=data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "9cfc818f-6338-4b2f-a37c-f584955a1fd8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 7.0.3 (0)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"242pt\" height=\"227pt\"\n",
       " viewBox=\"0.00 0.00 242.24 227.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 223)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-223 238.24,-223 238.24,4 -4,4\"/>\n",
       "<g id=\"clust1\" class=\"cluster\">\n",
       "<title>cluster_N</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"55,-8 55,-83 125,-83 125,-8 55,-8\"/>\n",
       "<text text-anchor=\"middle\" x=\"110.5\" y=\"-15.8\" font-family=\"Times,serif\" font-size=\"14.00\">N</text>\n",
       "</g>\n",
       "<!-- m -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>m</title>\n",
       "<ellipse fill=\"white\" stroke=\"black\" cx=\"90\" cy=\"-201\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"90\" y=\"-197.3\" font-family=\"Times,serif\" font-size=\"14.00\">m</text>\n",
       "</g>\n",
       "<!-- sd -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>sd</title>\n",
       "<ellipse fill=\"white\" stroke=\"black\" cx=\"27\" cy=\"-129\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"27\" y=\"-125.3\" font-family=\"Times,serif\" font-size=\"14.00\">sd</text>\n",
       "</g>\n",
       "<!-- m&#45;&gt;sd -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>m&#45;&gt;sd</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M76.64,-185.15C68.29,-175.87 57.35,-163.73 47.87,-153.19\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"50.65,-151.04 41.35,-145.95 45.44,-155.72 50.65,-151.04\"/>\n",
       "</g>\n",
       "<!-- m_transformed -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>m_transformed</title>\n",
       "<ellipse fill=\"white\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"153\" cy=\"-129\" rx=\"81.49\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"153\" y=\"-125.3\" font-family=\"Times,serif\" font-size=\"14.00\">m_transformed</text>\n",
       "</g>\n",
       "<!-- m&#45;&gt;m_transformed -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>m&#45;&gt;m_transformed</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M103.36,-185.15C111.13,-176.53 121.12,-165.42 130.11,-155.43\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"132.56,-157.94 136.65,-148.17 127.36,-153.26 132.56,-157.94\"/>\n",
       "</g>\n",
       "<!-- obs -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>obs</title>\n",
       "<ellipse fill=\"grey\" stroke=\"black\" cx=\"90\" cy=\"-57\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"90\" y=\"-53.3\" font-family=\"Times,serif\" font-size=\"14.00\">obs</text>\n",
       "</g>\n",
       "<!-- sd&#45;&gt;obs -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>sd&#45;&gt;obs</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M40.36,-113.15C48.71,-103.87 59.65,-91.73 69.13,-81.19\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"71.56,-83.72 75.65,-73.95 66.35,-79.04 71.56,-83.72\"/>\n",
       "</g>\n",
       "<!-- m_transformed&#45;&gt;obs -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>m_transformed&#45;&gt;obs</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M137.75,-111.05C129.72,-102.13 119.75,-91.06 111,-81.33\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"113.7,-79.1 104.41,-74.01 108.5,-83.79 113.7,-79.1\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7fe590b8a130>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = jnp.ones(10)\n",
    "numpyro.render_model(model, model_args=(data,))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cad4b94a",
   "metadata": {},
   "source": [
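    "## Rendering neural network parameters\n",
    "\n",
    "Parameters of neural network modules registered with `flax_module` are rendered as well when `render_params=True`. In the example below, the parameters of `affine_net` appear as a single node labeled `affine_net`, connected to the `obs` site that depends on them."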
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "cc0df5e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(data):\n",
    "    lambda_base = numpyro.sample(\"lambda\", dist.Normal(0, 1))\n",
    "    net = flax_module(\"affine_net\", flax_nn.Dense(1), input_shape=(1,))\n",
    "    lambd = jnp.exp(net(jnp.expand_dims(lambda_base, -1)).squeeze(-1))\n",
    "    with numpyro.plate(\"N\", len(data)):\n",
    "        numpyro.sample(\"obs\", dist.Exponential(lambd), obs=data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e24497d7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 7.0.3 (0)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"352pt\" height=\"157pt\"\n",
       " viewBox=\"0.00 0.00 351.85 157.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 153)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-153 347.85,-153 347.85,4 -4,4\"/>\n",
       "<g id=\"clust1\" class=\"cluster\">\n",
       "<title>cluster_N</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"58.85,-8 58.85,-83 128.85,-83 128.85,-8 58.85,-8\"/>\n",
       "<text text-anchor=\"middle\" x=\"114.35\" y=\"-15.8\" font-family=\"Times,serif\" font-size=\"14.00\">N</text>\n",
       "</g>\n",
       "<!-- lambda -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>lambda</title>\n",
       "<ellipse fill=\"white\" stroke=\"black\" cx=\"44.85\" cy=\"-130\" rx=\"44.69\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"44.85\" y=\"-126.3\" font-family=\"Times,serif\" font-size=\"14.00\">lambda</text>\n",
       "</g>\n",
       "<!-- obs -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>obs</title>\n",
       "<ellipse fill=\"grey\" stroke=\"black\" cx=\"93.85\" cy=\"-57\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"93.85\" y=\"-53.3\" font-family=\"Times,serif\" font-size=\"14.00\">obs</text>\n",
       "</g>\n",
       "<!-- lambda&#45;&gt;obs -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>lambda&#45;&gt;obs</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M56.46,-112.17C62.49,-103.44 69.97,-92.6 76.65,-82.91\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"79.33,-85.2 82.13,-74.98 73.57,-81.22 79.33,-85.2\"/>\n",
       "</g>\n",
       "<!-- affine_net$params -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>affine_net$params</title>\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"177.85,-137.5 107.85,-137.5 107.85,-122.5 177.85,-122.5 177.85,-137.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"142.85\" y=\"-126.3\" font-family=\"Times,serif\" font-size=\"14.00\">affine_net</text>\n",
       "</g>\n",
       "<!-- affine_net$params&#45;&gt;obs -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>affine_net$params&#45;&gt;obs</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M138.67,-122.95C132.61,-114.17 121.04,-97.41 111.17,-83.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"114.12,-81.22 105.56,-74.98 108.36,-85.2 114.12,-81.22\"/>\n",
       "</g>\n",
       "<!-- distribution_description_node -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>distribution_description_node</title>\n",
       "<text text-anchor=\"start\" x=\"203.85\" y=\"-133.8\" font-family=\"Times,serif\" font-size=\"14.00\">lambda ~ Normal</text>\n",
       "<text text-anchor=\"start\" x=\"203.85\" y=\"-118.8\" font-family=\"Times,serif\" font-size=\"14.00\">obs ~ Exponential</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7fe5b0526fa0>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "numpyro.render_model(\n",
    "    model, model_args=(data,), render_distributions=True, render_params=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "751d2870-66e5-4c6a-8806-62d901beedad",
   "metadata": {},
   "source": [
    "## Overlapping non-nested plates"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "853768b6-31e2-4499-baf6-7c424c537a45",
   "metadata": {},
   "source": [
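    "Note that overlapping non-nested plates may be drawn as multiple rectangles. In the example below, `plate2` is drawn twice: once nested inside `plate1` (around `y`) and once on its own (around `z`)."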
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "414ef383-fc62-4dba-8f90-b2c0b952be3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def model():\n",
    "    plate1 = numpyro.plate(\"plate1\", 2, dim=-2)\n",
    "    plate2 = numpyro.plate(\"plate2\", 3, dim=-1)\n",
    "    with plate1:\n",
    "        x = numpyro.sample(\"x\", dist.Normal(0, 1))\n",
    "    with plate1, plate2:\n",
    "        y = numpyro.sample(\"y\", dist.Normal(x, 1))\n",
    "    with plate2:\n",
    "        numpyro.sample(\"z\", dist.Normal(y.sum(-2, keepdims=True), 1), obs=jnp.zeros(3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "11b05889-7df5-477f-baf2-c3b522a79813",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 7.0.3 (0)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"110pt\" height=\"285pt\"\n",
       " viewBox=\"0.00 0.00 110.00 285.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 281)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-281 106,-281 106,4 -4,4\"/>\n",
       "<g id=\"clust1\" class=\"cluster\">\n",
       "<title>cluster_plate1</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"8,-91 8,-269 94,-269 94,-91 8,-91\"/>\n",
       "<text text-anchor=\"middle\" x=\"63\" y=\"-98.8\" font-family=\"Times,serif\" font-size=\"14.00\">plate1</text>\n",
       "</g>\n",
       "<g id=\"clust2\" class=\"cluster\">\n",
       "<title>cluster_plate2</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"16,-122 16,-197 86,-197 86,-122 16,-122\"/>\n",
       "<text text-anchor=\"middle\" x=\"55\" y=\"-129.8\" font-family=\"Times,serif\" font-size=\"14.00\">plate2</text>\n",
       "</g>\n",
       "<g id=\"clust3\" class=\"cluster\">\n",
       "<title>cluster_plate2__CLONE</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"16,-8 16,-83 86,-83 86,-8 16,-8\"/>\n",
       "<text text-anchor=\"middle\" x=\"55\" y=\"-15.8\" font-family=\"Times,serif\" font-size=\"14.00\">plate2</text>\n",
       "</g>\n",
       "<!-- x -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>x</title>\n",
       "<ellipse fill=\"white\" stroke=\"black\" cx=\"51\" cy=\"-243\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"51\" y=\"-239.3\" font-family=\"Times,serif\" font-size=\"14.00\">x</text>\n",
       "</g>\n",
       "<!-- y -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>y</title>\n",
       "<ellipse fill=\"white\" stroke=\"black\" cx=\"51\" cy=\"-171\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"51\" y=\"-167.3\" font-family=\"Times,serif\" font-size=\"14.00\">y</text>\n",
       "</g>\n",
       "<!-- x&#45;&gt;y -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>x&#45;&gt;y</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M51,-224.7C51,-217.41 51,-208.73 51,-200.54\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"54.5,-200.62 51,-190.62 47.5,-200.62 54.5,-200.62\"/>\n",
       "</g>\n",
       "<!-- z -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>z</title>\n",
       "<ellipse fill=\"grey\" stroke=\"black\" cx=\"51\" cy=\"-57\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"51\" y=\"-53.3\" font-family=\"Times,serif\" font-size=\"14.00\">z</text>\n",
       "</g>\n",
       "<!-- y&#45;&gt;z -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>y&#45;&gt;z</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M51,-152.99C51,-135.39 51,-107.6 51,-86.61\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"54.5,-86.68 51,-76.68 47.5,-86.68 54.5,-86.68\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7fe553f0a280>"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "numpyro.render_model(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bf72672-e6d7-4301-9b87-880ef84047f6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "numpyro",
   "language": "python",
   "name": "numpyro"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
