{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0cace868-5be2-41b1-9d73-a3472ca93a87",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "28b6982f-5077-46fa-8f85-650dd117aaf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import beanmachine.ppl.experimental.gg_algebra as gga\n",
    "from torch import tensor\n",
    "from torch.distributions import HalfNormal, Normal\n",
    "\n",
    "import beanmachine.ppl as bm\n",
    "from beanmachine.ppl.inference.bmg_inference import BMGInference"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d5a6da9-dd54-48ec-aa27-49ac79e119d7",
   "metadata": {},
   "source": [
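    "Let's consider the $N(0,1) / N(0,1) \\sim \\mathrm{Cauchy}$ example. The model below draws both the numerator and the denominator from a half-normal with scale 1, so it traces the positive half of this ratio."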
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "45e8dfbd-e7f2-414a-8c81-8cfa4c2c973d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Denominator of the ratio queried below: a half-normal random variable with scale 1.0.\n",
    "@bm.random_variable\n",
    "def n():\n",
    "    return HalfNormal(1.0)\n",
    "\n",
    "\n",
    "# Numerator of the ratio: a second half-normal random variable with scale 1.0.\n",
    "@bm.random_variable\n",
    "def x():\n",
    "    return HalfNormal(1.0)\n",
    "\n",
    "\n",
    "# Deterministic functional of the two draws: the ratio x() / n().\n",
    "@bm.functional\n",
    "def x2():\n",
    "    return x() / n()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "453b9abc-c591-4f76-8770-4902068a9240",
   "metadata": {},
   "source": [
    "Trace an execution and perform static analysis with the BMG runtime:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9a7214db-9c82-4bb8-8341-88a2b0b6d629",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{<beanmachine.ppl.compiler.bmg_nodes.UntypedConstantNode at 0x7fa626abc5b0>: 0,\n",
       " <beanmachine.ppl.compiler.bmg_nodes.HalfNormalNode at 0x7fa626abc310>: 1,\n",
       " <beanmachine.ppl.compiler.bmg_nodes.SampleNode at 0x7fa626abc160>: 2,\n",
       " <beanmachine.ppl.compiler.bmg_nodes.SampleNode at 0x7fa626c2ce80>: 3,\n",
       " <beanmachine.ppl.compiler.bmg_nodes.DivisionNode at 0x7fa626c2fc10>: 4,\n",
       " <beanmachine.ppl.compiler.bmg_nodes.Query at 0x7fa626c2f490>: 5}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Query only the ratio; nothing is conditioned on observed data.\n",
    "queries = [x2()]\n",
    "observations = {}\n",
    "# NOTE(review): `_accumulate_graph` and `_bmg` are underscore-prefixed names, so this\n",
    "# relies on Bean Machine internals and may break between versions.\n",
    "rt = BMGInference()._accumulate_graph(queries, observations)._bmg\n",
    "# Displayed as a mapping from graph node to its integer index (see the cell output).\n",
    "rt._nodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "168438a0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "digraph \"graph\" {\n",
      "  N0[label=1.0];\n",
      "  N1[label=HalfNormal];\n",
      "  N2[label=Sample];\n",
      "  N3[label=Sample];\n",
      "  N4[label=\"/\"];\n",
      "  N5[label=Query];\n",
      "  N0 -> N1[label=sigma];\n",
      "  N1 -> N2[label=operand];\n",
      "  N1 -> N3[label=operand];\n",
      "  N2 -> N4[label=left];\n",
      "  N3 -> N4[label=right];\n",
      "  N4 -> N5[label=operator];\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from beanmachine.ppl.compiler.gen_dot import to_dot\n",
    "\n",
    "# Print the accumulated graph as Graphviz DOT text; node Nk corresponds to index k in rt._nodes.\n",
    "print(to_dot(rt))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9cb77f11-3c7d-4060-bd72-e5f59c58751e",
   "metadata": {},
   "source": [
    "Inspect the inferred generalized Gamma algebra `.gga` tail for the half normal."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3e339594-0c35-4314-b9ec-0a6b2cb17b0b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sample(HalfNormal(tensor(1.)))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Index 2 is the first Sample node of the graph above (a draw from HalfNormal(1.0)).\n",
    "normal_node = list(rt._nodes)[2]\n",
    "print(normal_node)\n",
    "\n",
    "# TODO: normal_node.gga == normal(mu, sigma)\n",
    "# use rho=0 and sigma=None for RV\n",
    "# The algebra module is bound as `gga` in the import cell; the bare name `gg_algebra`\n",
    "# is never defined, so referring to it raises NameError on a fresh kernel.\n",
    "normal_node.gga == gga.normal(0, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "540df382-7a8e-461d-9d6c-11b553a153d2",
   "metadata": {},
   "source": [
    "Check that the generalized Gamma algebra `.gga` tail inferred for the division node is Cauchy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "80e43c57-40ae-4653-9d09-58f8610a2d99",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(Sample(HalfNormal(tensor(1.)))/Sample(HalfNormal(tensor(1.))))\n",
      "Sample(HalfNormal(tensor(1.))) Sample(HalfNormal(tensor(1.)))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Index 4 is the DivisionNode, i.e. the ratio of the two half-normal samples.\n",
    "div_node = list(rt._nodes)[4]\n",
    "print(div_node)\n",
    "\n",
    "print(div_node.left, div_node.right)\n",
    "\n",
    "# TODO: rewrite division as power and multiplication\n",
    "# assert div_node.gga == cauchy()\n",
    "# Use the `gga` alias from the import cell; `gg_algebra` is not a bound name here.\n",
    "div_node.gga == gga.cauchy()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab5c0af9-da9f-4a3e-9895-55af50ca28a2",
   "metadata": {},
   "source": [
    "The final query node for `x2()` has its tail computed as:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "a57c39b2-2b29-40a1-af8d-3ffb2dd6cb36",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "c x^(-2.0)"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The last entry of rt._nodes is the Query node; keep its tail for the cells below.\n",
    "tail = list(rt._nodes)[-1].gga\n",
    "tail"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5115ee5-f3ad-43ee-a6d4-0f9e1ce57b26",
   "metadata": {},
   "source": [
    "Use this auxiliary data to initialize a variational approximator with guaranteed matching tail asymptotics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "764394c7-4db7-4863-9279-39665ffb3f2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.distributions as dist\n",
    "\n",
    "\n",
    "class AbsTransform(dist.transforms.Transform):\n",
    "    r\"\"\"\n",
    "    Transform via the mapping :math:`y = |x|`.\n",
    "    \"\"\"\n",
    "    domain = dist.constraints.real\n",
    "    codomain = dist.constraints.positive\n",
    "\n",
    "    def __eq__(self, other):\n",
    "        return isinstance(other, AbsTransform)\n",
    "\n",
    "    def _call(self, x):\n",
    "        return x.abs()\n",
    "\n",
    "    def _inverse(self, y):\n",
    "        return y\n",
    "\n",
    "    def log_abs_det_jacobian(self, x, y):\n",
    "        return torch.tensor(2.0).log().expand(x.shape)\n",
    "\n",
    "\n",
    "def make_positive(d: dist.Distribution) -> dist.Distribution:\n",
    "    \"\"\"Return ``d`` pushed through ``AbsTransform``, i.e. the distribution of ``|x|`` for ``x ~ d``.\"\"\"\n",
    "    fold = AbsTransform()\n",
    "    return dist.TransformedDistribution(d, [fold])\n",
    "\n",
    "\n",
    "def make_ggdist(tail: gg_algebra.GGTail) -> dist.Distribution:\n",
    "    if abs(tail.rho) < 1e-15:\n",
    "        return make_positive(\n",
    "            dist.StudentT(\n",
    "                df=-(tail.nu1),\n",
    "            )\n",
    "        )\n",
    "    return dist.TransformedDistribution(\n",
    "        dist.Gamma(\n",
    "            concentration=(tail.nu + 1) / tail.rho,\n",
    "            rate=tail.sigma,\n",
    "        ),\n",
    "        [dist.transforms.PowerTransform(exponent=1.0 / tail.rho)],\n",
    "    )\n",
    "\n",
    "\n",
    "# Build the approximating family from the tail of the x2() query node computed in the previous cell.\n",
    "q = make_ggdist(tail)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "40858c99-2784-4e9b-bc66-5505c9677d7e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[AbsTransform()]\n",
      "StudentT(df: 1.0, loc: 0.0, scale: 1.0)\n"
     ]
    }
   ],
   "source": [
    "# Show the two pieces of q: its transform list, then its base distribution.\n",
    "print(q.transforms, q.base_dist, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb464487-b779-41d0-b4f3-a1e14725fcd0",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Let's try a multiplicative noise sequence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7d936d59-b13e-441e-89a5-6cb9b386b237",
   "metadata": {},
   "outputs": [],
   "source": [
    "%autoreload"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "5a2ddfb7-d9bd-4118-b5f8-9d4baad5f06c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Multiplicative noise term of the recursion below: a Normal(0, 1) random variable indexed by i.\n",
    "@bm.random_variable\n",
    "def a(i):\n",
    "    return Normal(0.0, 1.0)\n",
    "\n",
    "\n",
    "# Additive noise term of the recursion below: a Normal(0, 1) random variable indexed by i.\n",
    "@bm.random_variable\n",
    "def b(i):\n",
    "    return Normal(0.0, 1.0)\n",
    "\n",
    "\n",
    "# Recursion x(i) = a(i) * x(i - 1) + b(i), with base case x(0) = b(0).\n",
    "# NOTE(review): this rebinds the name `x` that x2() from the first example calls with no\n",
    "# arguments, so the earlier cells no longer work as written once this cell has run.\n",
    "@bm.functional\n",
    "def x(i):\n",
    "    if i == 0:\n",
    "        return b(0)\n",
    "    return a(i) * x(i - 1) + b(i)\n",
    "\n",
    "\n",
    "# Query the 20th term of the recursion and accumulate its graph. This overwrites the\n",
    "# `queries`, `observations` and `rt` that were bound for the first example.\n",
    "queries = [x(20)]\n",
    "observations = {}\n",
    "rt = BMGInference()._accumulate_graph(queries, observations)._bmg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "c31b907e-275a-4ee8-925b-b64c53dd7831",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "digraph \"graph\" {\n",
      "  N00[label=0.0];\n",
      "  N01[label=1.0];\n",
      "  N02[label=Normal];\n",
      "  N03[label=Sample];\n",
      "  N04[label=Sample];\n",
      "  N05[label=Sample];\n",
      "  N06[label=Sample];\n",
      "  N07[label=Sample];\n",
      "  N08[label=Sample];\n",
      "  N09[label=Sample];\n",
      "  N10[label=Sample];\n",
      "  N11[label=Sample];\n",
      "  N12[label=Sample];\n",
      "  N13[label=Sample];\n",
      "  N14[label=Sample];\n",
      "  N15[label=Sample];\n",
      "  N16[label=Sample];\n",
      "  N17[label=Sample];\n",
      "  N18[label=Sample];\n",
      "  N19[label=Sample];\n",
      "  N20[label=Sample];\n",
      "  N21[label=Sample];\n",
      "  N22[label=Sample];\n",
      "  N23[label=Sample];\n",
      "  N24[label=\"*\"];\n",
      "  N25[label=Sample];\n",
      "  N26[label=\"+\"];\n",
      "  N27[label=\"*\"];\n",
      "  N28[label=Sample];\n",
      "  N29[label=\"+\"];\n",
      "  N30[label=\"*\"];\n",
      "  N31[label=Sample];\n",
      "  N32[label=\"+\"];\n",
      "  N33[label=\"*\"];\n",
      "  N34[label=Sample];\n",
      "  N35[label=\"+\"];\n",
      "  N36[label=\"*\"];\n",
      "  N37[label=Sample];\n",
      "  N38[label=\"+\"];\n",
      "  N39[label=\"*\"];\n",
      "  N40[label=Sample];\n",
      "  N41[label=\"+\"];\n",
      "  N42[label=\"*\"];\n",
      "  N43[label=Sample];\n",
      "  N44[label=\"+\"];\n",
      "  N45[label=\"*\"];\n",
      "  N46[label=Sample];\n",
      "  N47[label=\"+\"];\n",
      "  N48[label=\"*\"];\n",
      "  N49[label=Sample];\n",
      "  N50[label=\"+\"];\n",
      "  N51[label=\"*\"];\n",
      "  N52[label=Sample];\n",
      "  N53[label=\"+\"];\n",
      "  N54[label=\"*\"];\n",
      "  N55[label=Sample];\n",
      "  N56[label=\"+\"];\n",
      "  N57[label=\"*\"];\n",
      "  N58[label=Sample];\n",
      "  N59[label=\"+\"];\n",
      "  N60[label=\"*\"];\n",
      "  N61[label=Sample];\n",
      "  N62[label=\"+\"];\n",
      "  N63[label=\"*\"];\n",
      "  N64[label=Sample];\n",
      "  N65[label=\"+\"];\n",
      "  N66[label=\"*\"];\n",
      "  N67[label=Sample];\n",
      "  N68[label=\"+\"];\n",
      "  N69[label=\"*\"];\n",
      "  N70[label=Sample];\n",
      "  N71[label=\"+\"];\n",
      "  N72[label=\"*\"];\n",
      "  N73[label=Sample];\n",
      "  N74[label=\"+\"];\n",
      "  N75[label=\"*\"];\n",
      "  N76[label=Sample];\n",
      "  N77[label=\"+\"];\n",
      "  N78[label=\"*\"];\n",
      "  N79[label=Sample];\n",
      "  N80[label=\"+\"];\n",
      "  N81[label=\"*\"];\n",
      "  N82[label=Sample];\n",
      "  N83[label=\"+\"];\n",
      "  N84[label=Query];\n",
      "  N00 -> N02[label=mu];\n",
      "  N01 -> N02[label=sigma];\n",
      "  N02 -> N03[label=operand];\n",
      "  N02 -> N04[label=operand];\n",
      "  N02 -> N05[label=operand];\n",
      "  N02 -> N06[label=operand];\n",
      "  N02 -> N07[label=operand];\n",
      "  N02 -> N08[label=operand];\n",
      "  N02 -> N09[label=operand];\n",
      "  N02 -> N10[label=operand];\n",
      "  N02 -> N11[label=operand];\n",
      "  N02 -> N12[label=operand];\n",
      "  N02 -> N13[label=operand];\n",
      "  N02 -> N14[label=operand];\n",
      "  N02 -> N15[label=operand];\n",
      "  N02 -> N16[label=operand];\n",
      "  N02 -> N17[label=operand];\n",
      "  N02 -> N18[label=operand];\n",
      "  N02 -> N19[label=operand];\n",
      "  N02 -> N20[label=operand];\n",
      "  N02 -> N21[label=operand];\n",
      "  N02 -> N22[label=operand];\n",
      "  N02 -> N23[label=operand];\n",
      "  N02 -> N25[label=operand];\n",
      "  N02 -> N28[label=operand];\n",
      "  N02 -> N31[label=operand];\n",
      "  N02 -> N34[label=operand];\n",
      "  N02 -> N37[label=operand];\n",
      "  N02 -> N40[label=operand];\n",
      "  N02 -> N43[label=operand];\n",
      "  N02 -> N46[label=operand];\n",
      "  N02 -> N49[label=operand];\n",
      "  N02 -> N52[label=operand];\n",
      "  N02 -> N55[label=operand];\n",
      "  N02 -> N58[label=operand];\n",
      "  N02 -> N61[label=operand];\n",
      "  N02 -> N64[label=operand];\n",
      "  N02 -> N67[label=operand];\n",
      "  N02 -> N70[label=operand];\n",
      "  N02 -> N73[label=operand];\n",
      "  N02 -> N76[label=operand];\n",
      "  N02 -> N79[label=operand];\n",
      "  N02 -> N82[label=operand];\n",
      "  N03 -> N81[label=left];\n",
      "  N04 -> N78[label=left];\n",
      "  N05 -> N75[label=left];\n",
      "  N06 -> N72[label=left];\n",
      "  N07 -> N69[label=left];\n",
      "  N08 -> N66[label=left];\n",
      "  N09 -> N63[label=left];\n",
      "  N10 -> N60[label=left];\n",
      "  N11 -> N57[label=left];\n",
      "  N12 -> N54[label=left];\n",
      "  N13 -> N51[label=left];\n",
      "  N14 -> N48[label=left];\n",
      "  N15 -> N45[label=left];\n",
      "  N16 -> N42[label=left];\n",
      "  N17 -> N39[label=left];\n",
      "  N18 -> N36[label=left];\n",
      "  N19 -> N33[label=left];\n",
      "  N20 -> N30[label=left];\n",
      "  N21 -> N27[label=left];\n",
      "  N22 -> N24[label=left];\n",
      "  N23 -> N24[label=right];\n",
      "  N24 -> N26[label=left];\n",
      "  N25 -> N26[label=right];\n",
      "  N26 -> N27[label=right];\n",
      "  N27 -> N29[label=left];\n",
      "  N28 -> N29[label=right];\n",
      "  N29 -> N30[label=right];\n",
      "  N30 -> N32[label=left];\n",
      "  N31 -> N32[label=right];\n",
      "  N32 -> N33[label=right];\n",
      "  N33 -> N35[label=left];\n",
      "  N34 -> N35[label=right];\n",
      "  N35 -> N36[label=right];\n",
      "  N36 -> N38[label=left];\n",
      "  N37 -> N38[label=right];\n",
      "  N38 -> N39[label=right];\n",
      "  N39 -> N41[label=left];\n",
      "  N40 -> N41[label=right];\n",
      "  N41 -> N42[label=right];\n",
      "  N42 -> N44[label=left];\n",
      "  N43 -> N44[label=right];\n",
      "  N44 -> N45[label=right];\n",
      "  N45 -> N47[label=left];\n",
      "  N46 -> N47[label=right];\n",
      "  N47 -> N48[label=right];\n",
      "  N48 -> N50[label=left];\n",
      "  N49 -> N50[label=right];\n",
      "  N50 -> N51[label=right];\n",
      "  N51 -> N53[label=left];\n",
      "  N52 -> N53[label=right];\n",
      "  N53 -> N54[label=right];\n",
      "  N54 -> N56[label=left];\n",
      "  N55 -> N56[label=right];\n",
      "  N56 -> N57[label=right];\n",
      "  N57 -> N59[label=left];\n",
      "  N58 -> N59[label=right];\n",
      "  N59 -> N60[label=right];\n",
      "  N60 -> N62[label=left];\n",
      "  N61 -> N62[label=right];\n",
      "  N62 -> N63[label=right];\n",
      "  N63 -> N65[label=left];\n",
      "  N64 -> N65[label=right];\n",
      "  N65 -> N66[label=right];\n",
      "  N66 -> N68[label=left];\n",
      "  N67 -> N68[label=right];\n",
      "  N68 -> N69[label=right];\n",
      "  N69 -> N71[label=left];\n",
      "  N70 -> N71[label=right];\n",
      "  N71 -> N72[label=right];\n",
      "  N72 -> N74[label=left];\n",
      "  N73 -> N74[label=right];\n",
      "  N74 -> N75[label=right];\n",
      "  N75 -> N77[label=left];\n",
      "  N76 -> N77[label=right];\n",
      "  N77 -> N78[label=right];\n",
      "  N78 -> N80[label=left];\n",
      "  N79 -> N80[label=right];\n",
      "  N80 -> N81[label=right];\n",
      "  N81 -> N83[label=left];\n",
      "  N82 -> N83[label=right];\n",
      "  N83 -> N84[label=operator];\n",
      "}\n",
      "\n",
      "{<beanmachine.ppl.compiler.bmg_nodes.UntypedConstantNode object at 0x128ca23d0>: 0, <beanmachine.ppl.compiler.bmg_nodes.UntypedConstantNode object at 0x128ca2730>: 1, <beanmachine.ppl.compiler.bmg_nodes.NormalNode object at 0x128ca2b50>: 2, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128ca1910>: 3, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128ca2430>: 4, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128ca4220>: 5, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128ca5280>: 6, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128ca5a90>: 7, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128ca5070>: 8, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128cac100>: 9, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128cac5b0>: 10, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128caca90>: 11, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128cacf40>: 12, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128cae460>: 13, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128cac8b0>: 14, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128ca4370>: 15, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128ca1a30>: 16, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128caeb20>: 17, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c30040>: 18, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c30520>: 19, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c30a00>: 20, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c30ee0>: 21, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128cb2400>: 22, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c92ee0>: 23, <beanmachine.ppl.compiler.bmg_nodes.MultiplicationNode object at 0x128cb27c0>: 24, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c94460>: 25, 
<beanmachine.ppl.compiler.bmg_nodes.AdditionNode object at 0x128c94ee0>: 26, <beanmachine.ppl.compiler.bmg_nodes.MultiplicationNode object at 0x128c94f10>: 27, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c951c0>: 28, <beanmachine.ppl.compiler.bmg_nodes.AdditionNode object at 0x128c95640>: 29, <beanmachine.ppl.compiler.bmg_nodes.MultiplicationNode object at 0x128c955b0>: 30, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c952e0>: 31, <beanmachine.ppl.compiler.bmg_nodes.AdditionNode object at 0x128c9bf10>: 32, <beanmachine.ppl.compiler.bmg_nodes.MultiplicationNode object at 0x128c9b2e0>: 33, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c9beb0>: 34, <beanmachine.ppl.compiler.bmg_nodes.AdditionNode object at 0x128c9bbb0>: 35, <beanmachine.ppl.compiler.bmg_nodes.MultiplicationNode object at 0x128c9b9a0>: 36, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c9bb80>: 37, <beanmachine.ppl.compiler.bmg_nodes.AdditionNode object at 0x128c9b130>: 38, <beanmachine.ppl.compiler.bmg_nodes.MultiplicationNode object at 0x128c5d160>: 39, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c5d430>: 40, <beanmachine.ppl.compiler.bmg_nodes.AdditionNode object at 0x128c5d670>: 41, <beanmachine.ppl.compiler.bmg_nodes.MultiplicationNode object at 0x128c5d7c0>: 42, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c95d60>: 43, <beanmachine.ppl.compiler.bmg_nodes.AdditionNode object at 0x128c94280>: 44, <beanmachine.ppl.compiler.bmg_nodes.MultiplicationNode object at 0x128c942b0>: 45, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c92e20>: 46, <beanmachine.ppl.compiler.bmg_nodes.AdditionNode object at 0x128c5da00>: 47, <beanmachine.ppl.compiler.bmg_nodes.MultiplicationNode object at 0x128c5db50>: 48, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c5de20>: 49, <beanmachine.ppl.compiler.bmg_nodes.AdditionNode object at 0x128c608e0>: 50, 
<beanmachine.ppl.compiler.bmg_nodes.MultiplicationNode object at 0x128c603d0>: 51, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c600a0>: 52, <beanmachine.ppl.compiler.bmg_nodes.AdditionNode object at 0x128c60880>: 53, <beanmachine.ppl.compiler.bmg_nodes.MultiplicationNode object at 0x128c60850>: 54, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c60a30>: 55, <beanmachine.ppl.compiler.bmg_nodes.AdditionNode object at 0x128c60e50>: 56, <beanmachine.ppl.compiler.bmg_nodes.MultiplicationNode object at 0x128c60df0>: 57, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c5f160>: 58, <beanmachine.ppl.compiler.bmg_nodes.AdditionNode object at 0x128c5f3a0>: 59, <beanmachine.ppl.compiler.bmg_nodes.MultiplicationNode object at 0x128c5f4f0>: 60, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c5f7c0>: 61, <beanmachine.ppl.compiler.bmg_nodes.AdditionNode object at 0x128c5fa00>: 62, <beanmachine.ppl.compiler.bmg_nodes.MultiplicationNode object at 0x128c5fb50>: 63, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c5fe20>: 64, <beanmachine.ppl.compiler.bmg_nodes.AdditionNode object at 0x128c61220>: 65, <beanmachine.ppl.compiler.bmg_nodes.MultiplicationNode object at 0x128c611c0>: 66, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c61460>: 67, <beanmachine.ppl.compiler.bmg_nodes.AdditionNode object at 0x128c616a0>: 68, <beanmachine.ppl.compiler.bmg_nodes.MultiplicationNode object at 0x128c61970>: 69, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c61b80>: 70, <beanmachine.ppl.compiler.bmg_nodes.AdditionNode object at 0x128c61ca0>: 71, <beanmachine.ppl.compiler.bmg_nodes.MultiplicationNode object at 0x128c61f70>: 72, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c641f0>: 73, <beanmachine.ppl.compiler.bmg_nodes.AdditionNode object at 0x128c64370>: 74, <beanmachine.ppl.compiler.bmg_nodes.MultiplicationNode object at 0x128c64430>: 75, 
<beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c64820>: 76, <beanmachine.ppl.compiler.bmg_nodes.AdditionNode object at 0x128c64d60>: 77, <beanmachine.ppl.compiler.bmg_nodes.MultiplicationNode object at 0x128c64e80>: 78, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c61040>: 79, <beanmachine.ppl.compiler.bmg_nodes.AdditionNode object at 0x128c60eb0>: 80, <beanmachine.ppl.compiler.bmg_nodes.MultiplicationNode object at 0x128c60430>: 81, <beanmachine.ppl.compiler.bmg_nodes.SampleNode object at 0x128c64730>: 82, <beanmachine.ppl.compiler.bmg_nodes.AdditionNode object at 0x128c64a60>: 83, <beanmachine.ppl.compiler.bmg_nodes.Query object at 0x128c64e20>: 84}\n"
     ]
    }
   ],
   "source": [
    "from beanmachine.ppl.compiler.gen_dot import to_dot\n",
    "\n",
    "# Print the DOT graph, then the node -> index mapping used to pick nodes in the cells below.\n",
    "print(to_dot(rt))\n",
    "print(rt._nodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "3dcd43a2-012b-4a54-be02-d18ca0a1850a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(Sample(Normal(tensor(0.),tensor(1.)))*Sample(Normal(tensor(0.),tensor(1.)))) c x^(-0.5) exp(-x)\n"
     ]
    }
   ],
   "source": [
    "# First MultiplicationNode of the x(20) graph (Sample * Sample): index 24 in the\n",
    "# rt._nodes listing printed above. Index 9 is a plain Sample node in this graph, so it\n",
    "# would not show the product tail this cell is meant to display.\n",
    "node = list(rt._nodes)[24]\n",
    "print(node, node.gga)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "6848c05c-9003-4034-8382-57f492a53bfe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "((Sample(Normal(tensor(0.),tensor(1.)))*Sample(Normal(tensor(0.),tensor(1.))))+Sample(Normal(tensor(0.),tensor(1.)))) c x^(-0.5) exp(-x)\n"
     ]
    }
   ],
   "source": [
    "# First AdditionNode of the x(20) graph ((Sample * Sample) + Sample): index 26 in the\n",
    "# rt._nodes listing printed above. Index 11 is a plain Sample node in this graph.\n",
    "node = list(rt._nodes)[26]\n",
    "print(node, node.gga)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "e793c077-89c7-430e-8bf4-9541cc02f45b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "((Sample(Normal(tensor(0.),tensor(1.)))*((Sample(Normal(tensor(0.),tensor(1.)))*((Sample(Normal(tensor(0.),tensor(1.)))*((Sample(Normal(tensor(0.),tensor(1.)))*((Sample(Normal(tensor(0.),tensor(1.)))*((Sample(Normal(tensor(0.),tensor(1.)))*((Sample(Normal(tensor(0.),tensor(1.)))*((Sample(Normal(tensor(0.),tensor(1.)))*((Sample(Normal(tensor(0.),tensor(1.)))*((Sample(Normal(tensor(0.),tensor(1.)))*((Sample(Normal(tensor(0.),tensor(1.)))*((Sample(Normal(tensor(0.),tensor(1.)))*((Sample(Normal(tensor(0.),tensor(1.)))*((Sample(Normal(tensor(0.),tensor(1.)))*((Sample(Normal(tensor(0.),tensor(1.)))*((Sample(Normal(tensor(0.),tensor(1.)))*((Sample(Normal(tensor(0.),tensor(1.)))*((Sample(Normal(tensor(0.),tensor(1.)))*((Sample(Normal(tensor(0.),tensor(1.)))*((Sample(Normal(tensor(0.),tensor(1.)))*Sample(Normal(tensor(0.),tensor(1.))))+Sample(Normal(tensor(0.),tensor(1.)))))+Sample(Normal(tensor(0.),tensor(1.)))))+Sample(Normal(tensor(0.),tensor(1.)))))+Sample(Normal(tensor(0.),tensor(1.)))))+Sample(Normal(tensor(0.),tensor(1.)))))+Sample(Normal(tensor(0.),tensor(1.)))))+Sample(Normal(tensor(0.),tensor(1.)))))+Sample(Normal(tensor(0.),tensor(1.)))))+Sample(Normal(tensor(0.),tensor(1.)))))+Sample(Normal(tensor(0.),tensor(1.)))))+Sample(Normal(tensor(0.),tensor(1.)))))+Sample(Normal(tensor(0.),tensor(1.)))))+Sample(Normal(tensor(0.),tensor(1.)))))+Sample(Normal(tensor(0.),tensor(1.)))))+Sample(Normal(tensor(0.),tensor(1.)))))+Sample(Normal(tensor(0.),tensor(1.)))))+Sample(Normal(tensor(0.),tensor(1.)))))+Sample(Normal(tensor(0.),tensor(1.)))))+Sample(Normal(tensor(0.),tensor(1.)))))+Sample(Normal(tensor(0.),tensor(1.)))) c x^(-0.9524) exp(-10.5 * x^0.09524)\n"
     ]
    }
   ],
   "source": [
    "# Index 83 is the outermost AdditionNode, the value of x(20), which feeds the Query node at index 84.\n",
    "node = list(rt._nodes)[83]\n",
    "print(node, node.gga)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dffe3629-1b38-4393-8c86-ef76812a3ca5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "83b589253c6ae2165fd99d3b5e434b8a0ff74c98e791d87ced25152a201010fd"
  },
  "kernelspec": {
   "display_name": "Python 3.10.4 ('gga')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
