{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Basics: Node and MessageNode\n",
    "\n",
    "`trace` is a comptuational grpah framework for tracing and optimizing codes. Its core data structure is the \"node\" container of python objects. To create a node, use `node` method, which creates a `Node` object. To access, the content of a node, use the `data` attribute."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "node of int 1\n",
      "string\n",
      "[1, 2, 3]\n",
      "{'a': 1, 'b': 2}\n",
      "<__main__.Foo object at 0x7f5205ed1d00>\n"
     ]
    }
   ],
   "source": [
    "from autogen.trace import node\n",
    "\n",
    "x = node(1)  # node of int\n",
    "print(\"node of int\", x.data)\n",
    "x = node(\"string\")  # node of str\n",
    "print(x.data)\n",
    "x = node([1, 2, 3])  # node of list\n",
    "print(x.data)\n",
    "x = node({\"a\": 1, \"b\": 2})  # node of dict\n",
    "print(x.data)\n",
    "\n",
    "\n",
    "class Foo:\n",
    "    def __init__(self, x):\n",
    "        self.x = x\n",
    "        self.secret = \"secret\"\n",
    "\n",
    "    def print(self, val):\n",
    "        print(val)\n",
    "\n",
    "\n",
    "x = node(Foo(\"foo\"))  # node of a class instance\n",
    "print(x.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When a computation is performed using the contents of nodes, the result is also a node. This allows for the creation of a computation graph. The computation graph is a directed acyclic graph where the edges indicate the data dependencies.\n",
    "\n",
    "Nodes that are defined manually can be marked as trainable by setting their `trainable` attribute to True; such nodes are a subclass of Node called `ParameterNode`.\n",
    "Nodes that are created automatically as a result of computations are a different subclass of Node called `MessageNode`.\n",
    "\n",
    "Nodes can be copied. This can be done in two ways with `clone` or `detach`"
   ]
  },
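  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, `clone` and `detach` both copy the wrapped data. They differ only in whether the copy stays connected to the original node. The toy class below is a hypothetical sketch of that distinction and is not the `trace` implementation. `ToyNode` and its methods are illustrative names, and the sketch assumes both methods make deep copies. That assumption is consistent with the `x_clone.data != x.data` assertions in the next cell.\n",
    "\n",
    "```python\n",
    "import copy\n",
    "\n",
    "\n",
    "class ToyNode:\n",
    "    # Hypothetical stand-in for a node: wraps data and records its parents.\n",
    "    def __init__(self, data, parents=()):\n",
    "        self.data = data\n",
    "        self.parents = list(parents)\n",
    "\n",
    "    def clone(self):\n",
    "        # Copy the data, but keep the original node as the only parent.\n",
    "        return ToyNode(copy.deepcopy(self.data), parents=[self])\n",
    "\n",
    "    def detach(self):\n",
    "        # Copy the data and drop every connection to the graph.\n",
    "        return ToyNode(copy.deepcopy(self.data))\n",
    "\n",
    "\n",
    "class Box:\n",
    "    def __init__(self, x):\n",
    "        self.x = x\n",
    "\n",
    "\n",
    "x = ToyNode(Box('foo'))\n",
    "x_clone = x.clone()\n",
    "x_detach = x.detach()\n",
    "print(x in x_clone.parents, len(x_detach.parents))  # True 0\n",
    "print(x_clone.data is x.data, x_clone.data.x == x.data.x)  # False True\n",
    "```"
   ]
  },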
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<__main__.Foo object at 0x7f512cee9970>\n"
     ]
    }
   ],
   "source": [
    "# clone returns a MessageNode whose parent is the original node\n",
    "x_clone = x.clone()\n",
    "assert x in x_clone.parents\n",
    "assert x_clone.data != x.data\n",
    "assert x_clone.data.x == x.data.x\n",
    "print(x_clone.data)\n",
    "# detach returns a new Node which is not connected to the original node\n",
    "x_detach = x.detach()\n",
    "assert len(x_detach.parents) == 0\n",
    "assert x_detach.data != x.data\n",
    "assert x_detach.data.x == x.data.x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`trace` overloads python's magic methods that gives return value explicitly (such as `__add__`), except logical operations such as `__bool__` and setters. (The comparison magic methods compares the level of the nodes in the global graph, rather than comparing the data.) \n",
    "\n",
    "When nodes are used with these magic methods, the output would be a `MessageNode`, which is a subclass of `Node` that has the inputs of the method as the parents. The attribute `description` of a `MessageNode` documents the method's function."
   ]
  },
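  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea behind the overloaded operators can be sketched in a few lines. The toy below is hypothetical and is not the `trace` source. Its `__add__` and `__truediv__` methods do three things:\n",
    "\n",
    "- wrap any non-node operand in a node,\n",
    "- compute on the raw `data`,\n",
    "- return a `ToyMessageNode` whose parents are the operands and whose `description` records the operation.\n",
    "\n",
    "```python\n",
    "class ToyNode:\n",
    "    # Hypothetical minimal node: raw data, a name, and a list of parent nodes.\n",
    "    def __init__(self, data, name='node', parents=()):\n",
    "        self.data = data\n",
    "        self.name = name\n",
    "        self.parents = list(parents)\n",
    "\n",
    "    def _binary(self, other, fn, op_name):\n",
    "        # Plain Python values are wrapped in a node automatically.\n",
    "        if not isinstance(other, ToyNode):\n",
    "            other = ToyNode(other, name=type(other).__name__)\n",
    "        return ToyMessageNode(\n",
    "            fn(self.data, other.data),\n",
    "            name=op_name,\n",
    "            parents=[self, other],\n",
    "            description=f'[{op_name}] computed from its parents',\n",
    "        )\n",
    "\n",
    "    def __add__(self, other):\n",
    "        return self._binary(other, lambda a, b: a + b, 'add')\n",
    "\n",
    "    def __truediv__(self, other):\n",
    "        return self._binary(other, lambda a, b: a / b, 'divide')\n",
    "\n",
    "\n",
    "class ToyMessageNode(ToyNode):\n",
    "    # Output of an operation; it also remembers what the operation was.\n",
    "    def __init__(self, data, name, parents, description):\n",
    "        super().__init__(data, name=name, parents=parents)\n",
    "        self.description = description\n",
    "\n",
    "\n",
    "x = ToyNode(1, name='node_x')\n",
    "y = ToyNode(3, name='node_y')\n",
    "z = x / y\n",
    "z2 = x / 3\n",
    "print(z.data, [p.name for p in z.parents])  # 0.3333333333333333 ['node_x', 'node_y']\n",
    "print([p.name for p in z2.parents])  # ['node_x', 'int']\n",
    "```\n",
    "\n",
    "A plain `3` is wrapped on the fly, so `x / 3` is recorded in the same way as `x / y`."
   ]
  },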
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MessageNode: (divide:0, dtype=<class 'float'>, data=0.3333333333333333)\n",
      "MessageNode: (divide:0, dtype=<class 'float'>, data=0.3333333333333333)\n",
      "parents: ['node_x:0', 'node_y:0']\n",
      "\n",
      "\n",
      "MessageNode: (getitem:0, dtype=<class 'int'>, data=1)\n",
      "parents: ['dict_node:0', 'str:1']\n",
      "len(dict_node) = MessageNode: (len:0, dtype=<class 'int'>, data=2)\n",
      "\n",
      "\n",
      "Node: (str:3, dtype=<class 'str'>, data=hello world)\n",
      "MessageNode: (getattr:1, dtype=<class 'str'>, data=secret)\n",
      "parents: ['Foo:1', 'str:4']\n"
     ]
    }
   ],
   "source": [
    "def print_node(node):\n",
    "    print(node)\n",
    "    print(f\"parents: {[p.name for p in node.parents]}\")\n",
    "\n",
    "\n",
    "# Basic arithmetic operations\n",
    "x = node(1, name=\"node_x\")\n",
    "y = node(3, name=\"node_y\")\n",
    "z = x / y\n",
    "z2 = x / 3  # the int 3 would be converted to a node automatically\n",
    "print(z)\n",
    "print_node(z)\n",
    "print(\"\\n\")\n",
    "\n",
    "# Index a node\n",
    "dict_node = node({\"a\": 1, \"b\": 2}, name=\"dict_node\")\n",
    "a = dict_node[\"a\"]\n",
    "print_node(a)\n",
    "print(\"len(dict_node) =\", dict_node.len())\n",
    "\n",
    "print(\"\\n\")\n",
    "\n",
    "# Getting class attribute and calling class method\n",
    "x = node(Foo(\"foo\"))\n",
    "x.call(\"print\", \"hello world\")\n",
    "print_node(x.getattr(\"secret\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Nodes can not be used in logical operations like and, or, not. This is an explicit design choice so as to ensure that logical operations in python code is explicitly traced."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "x = node(True)\n",
    "try:\n",
    "    if x:\n",
    "        print(\"True\")\n",
    "except Exception as e:\n",
    "    print(e)\n",
    "    print(\"Use if x.data instead of if x\")\n",
    "\n",
    "\n",
    "x = node([1, 2, 3])\n",
    "try:\n",
    "    1 in x\n",
    "except Exception as e:\n",
    "    print(e)\n",
    "    print(\"Use 1 in x.data instead of 1 in x\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Nodes can be used to encapsulate any python object, including functions. Here're a few examples"
   ]
  },
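  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Wrapping a function follows the same pattern: calling the function node is just another recorded operation. The hypothetical sketch below uses toy classes and is not the `trace` API. Its `__call__` method unwraps node arguments, calls the wrapped function, and records both the function node and the argument nodes as parents. The result has the same kind of parent list that the next cell prints for `fun_node(node(1))`.\n",
    "\n",
    "```python\n",
    "class ToyNode:\n",
    "    # Hypothetical node that can wrap any Python object, including functions.\n",
    "    def __init__(self, data, name=None, parents=()):\n",
    "        self.data = data\n",
    "        self.name = name or type(data).__name__\n",
    "        self.parents = list(parents)\n",
    "\n",
    "    def __call__(self, *args):\n",
    "        assert callable(self.data), 'only nodes wrapping callables can be called'\n",
    "        # Wrap plain arguments, then call the function on the raw values.\n",
    "        arg_nodes = [a if isinstance(a, ToyNode) else ToyNode(a) for a in args]\n",
    "        result = self.data(*[a.data for a in arg_nodes])\n",
    "        # The function node itself is a parent, so the function is part of the graph.\n",
    "        return ToyNode(result, name='call', parents=[self, *arg_nodes])\n",
    "\n",
    "\n",
    "def fun(x):\n",
    "    return x + 1\n",
    "\n",
    "\n",
    "fun_node = ToyNode(fun)\n",
    "y = fun_node(ToyNode(1))\n",
    "print(y.data, [p.name for p in y.parents])  # 2 ['function', 'int']\n",
    "```"
   ]
  },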
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "output: MessageNode: (call:1, dtype=<class 'int'>, data=2)\n",
      "parents [('function:0', <function fun at 0x7f5205ed63a0>), ('int:3', 1)]\n",
      "\n",
      "\n",
      "\n",
      "The attribute of the wrapped object cannot be directly accessed. Instead use getattr() or call()\n",
      "foo_node: MessageNode: (getattr:2, dtype=<class 'int'>, data=1)\n",
      "parents [('Foo:2', <__main__.Foo object at 0x7f5205ed1f10>), ('str:5', 'node')]\n",
      "non_node: MessageNode: (getattr:3, dtype=<class 'int'>, data=2)\n",
      "parents [('Foo:2', <__main__.Foo object at 0x7f5205ed1f10>), ('str:6', 'non_node')]\n",
      "output: MessageNode: (call:2, dtype=<class 'int'>, data=4)\n",
      "parents [('getattr:4', <bound method Foo.non_trace_fun of <__main__.Foo object at 0x7f5205ed1f10>>)]\n",
      "output: MessageNode: (call:4, dtype=<class 'int'>, data=4)\n",
      "parents [('getattr:6', <bound method Foo.non_trace_fun of <__main__.Foo object at 0x7f5205ed1f10>>)]\n"
     ]
    }
   ],
   "source": [
    "def fun(x):\n",
    "    return x + 1\n",
    "\n",
    "\n",
    "fun_node = node(fun)\n",
    "y = fun_node(node(1))\n",
    "print(f\"output: {y}\\nparents {[(p.name, p.data) for p in y.parents]}\")\n",
    "print(\"\\n\\n\")\n",
    "\n",
    "\n",
    "class Foo:\n",
    "\n",
    "    def __init__(self):\n",
    "        self.node = node(1)\n",
    "        self.non_node = 2\n",
    "\n",
    "    def trace_fun(self):\n",
    "        return self.node * 2\n",
    "\n",
    "    def non_trace_fun(self):\n",
    "        return self.non_node * 2\n",
    "\n",
    "\n",
    "foo = node(Foo())\n",
    "\n",
    "try:\n",
    "    foo.node\n",
    "    foo.trace_fun()\n",
    "except AttributeError:\n",
    "    print(\"The attribute of the wrapped object cannot be directly accessed. Instead use getattr() or call()\")\n",
    "\n",
    "\n",
    "attr = foo.getattr(\"node\")\n",
    "print(f\"foo_node: {attr}\\nparents {[(p.name, p.data) for p in attr.parents]}\")\n",
    "\n",
    "\n",
    "attr = foo.getattr(\"non_node\")\n",
    "print(f\"non_node: {attr}\\nparents {[(p.name, p.data) for p in attr.parents]}\")\n",
    "\n",
    "\n",
    "fun = foo.getattr(\"non_trace_fun\")\n",
    "y = fun()\n",
    "print(f\"output: {y}\\nparents {[(p.name, p.data) for p in y.parents]}\")\n",
    "\n",
    "try:\n",
    "    fun = foo.getattr(\"trace_fun\")\n",
    "    y = fun()\n",
    "except AssertionError as e:\n",
    "    print(e)\n",
    "\n",
    "y = foo.call(\"non_trace_fun\")\n",
    "print(f\"output: {y}\\nparents {[(p.name, p.data) for p in y.parents]}\")\n",
    "\n",
    "try:\n",
    "    y = foo.call(\"trace_fun\")\n",
    "except AssertionError as e:\n",
    "    print(e)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Writing Custom Node Operators\n",
    "In addition to magical methods, we can use `bundle` to write custom methods that are traceable. When decorating a method with `bundle`, it needs a description of the method. It has a format of `[method_name] description`. `bundle` will automatically add all nodes whose `data` attribute is used within the function as the parents of the output `MessageNode`.\n",
    "\n",
    "Given a function `fun`, the decorated function `bundle(description)(fun)` by default will unpack all the inputs (it unpacks all node containers), send them to `fun`, and then creates a `MessageNode` to wrap the output of `fun` which has parents containing all the nodes used in this operation. \n",
    "\n",
    "Since all inputs are unpacked, they will be set as the parents. The user can override this behavior by setting `bundle(description, unpack_input=False)`, which would let `fun` see the original inputs.\n"
   ]
  },
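  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a rough, hypothetical sketch of what such a decorator does. The names `toy_bundle` and `ToyNode` are illustrative. The real `bundle` also handles options such as `node_dict` and `n_outputs`, which this toy omits. The toy unpacks node arguments, calls the function on raw values, and wraps the result in a node whose parents are the node inputs.\n",
    "\n",
    "```python\n",
    "import functools\n",
    "\n",
    "\n",
    "class ToyNode:\n",
    "    # Hypothetical node: data, parents, and a description.\n",
    "    def __init__(self, data, parents=(), description='[Node] a leaf node'):\n",
    "        self.data = data\n",
    "        self.parents = list(parents)\n",
    "        self.description = description\n",
    "\n",
    "\n",
    "def toy_bundle(description, unpack_input=True):\n",
    "    # Hypothetical decorator factory mimicking bundle(description, unpack_input=...).\n",
    "    def decorator(fun):\n",
    "        @functools.wraps(fun)\n",
    "        def wrapper(*args):\n",
    "            node_args = [a for a in args if isinstance(a, ToyNode)]\n",
    "            if unpack_input:\n",
    "                # Replace every top-level node argument by its raw data.\n",
    "                args = [a.data if isinstance(a, ToyNode) else a for a in args]\n",
    "            output = fun(*args)\n",
    "            return ToyNode(output, parents=node_args, description=description)\n",
    "        return wrapper\n",
    "    return decorator\n",
    "\n",
    "\n",
    "@toy_bundle('[add] Add input x and input y')\n",
    "def add(x, y):\n",
    "    return x + y\n",
    "\n",
    "\n",
    "@toy_bundle('[first] Return the data of the first input', unpack_input=False)\n",
    "def first(x, y):\n",
    "    # With unpack_input=False the function sees the original node objects.\n",
    "    return x.data\n",
    "\n",
    "\n",
    "a, b = ToyNode(1), ToyNode(2)\n",
    "z = add(a, b)\n",
    "print(z.data, z.description, len(z.parents))  # 3 [add] Add input x and input y 2\n",
    "print(first(a, b).data)  # 1\n",
    "```\n",
    "\n",
    "This toy only unpacks top-level arguments. According to the description above, the real decorator also unpacks nodes nested inside node containers."
   ]
  },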
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MessageNode: (add_1:0, dtype=<class 'int'>, data=2)\n",
      "parents: ['node_x:1']\n",
      "\n",
      "\n",
      "MessageNode: (add:1, dtype=<class 'int'>, data=3)\n",
      "parents: ['node_x:2', 'node_y:1']\n",
      "\n",
      "\n",
      "MessageNode: (pass_through:0, dtype=<class 'tuple'>, data=(1, 2))\n",
      "\n",
      "\n",
      "(<autogen.trace.nodes.MessageNode object at 0x7f512cf28370>, <autogen.trace.nodes.MessageNode object at 0x7f512cf28190>)\n"
     ]
    }
   ],
   "source": [
    "from autogen.trace import bundle\n",
    "\n",
    "\n",
    "@bundle(\"[add_1] Add 1 to input x\")\n",
    "def foo(x):\n",
    "    return x + 1\n",
    "\n",
    "\n",
    "x = node(1, name=\"node_x\")\n",
    "z = foo(x)\n",
    "print_node(z)\n",
    "print(\"\\n\")\n",
    "\n",
    "\n",
    "@bundle(\"[add] Add input x and input y\")\n",
    "def foo(x, y):\n",
    "    return x + y\n",
    "\n",
    "\n",
    "x = node(1, name=\"node_x\")\n",
    "y = node(2, name=\"node_y\")\n",
    "z = foo(x, y)\n",
    "print_node(z)\n",
    "print(\"\\n\")\n",
    "\n",
    "# The output is a node of a tuple of two nodes\n",
    "\n",
    "\n",
    "@bundle(\"[pass_through] No operation, just return inputs\")\n",
    "def foo(x, y):\n",
    "    return x, y\n",
    "\n",
    "\n",
    "x = node(1, name=\"node_x\")\n",
    "y = node(2, name=\"node_y\")\n",
    "z = foo(x, y)\n",
    "print(z)\n",
    "from autogen.trace.nodes import Node\n",
    "\n",
    "assert isinstance(z, Node)\n",
    "assert isinstance(z.data, tuple)\n",
    "assert len(z.data) == 2\n",
    "print(\"\\n\")\n",
    "\n",
    "\n",
    "# The output is a tuple of two nodes\n",
    "@bundle(\"[pass_through] No operation, just return inputs\", n_outputs=2)\n",
    "def foo(x, y):\n",
    "    return x, y\n",
    "\n",
    "\n",
    "x = node(1, name=\"node_x\")\n",
    "y = node(2, name=\"node_y\")\n",
    "z = foo(x, y)\n",
    "print(z)\n",
    "assert isinstance(z, tuple)\n",
    "assert len(z) == 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Describing Relationship between Inputs and Outputs and Nodes in the Graph\n",
    "One can additionally provide `node_dict` to specify how each variable mentioned in `description` is related to the nodes in the graph. This relationship is stored in the `inputs` attribute of `MessageNode`. See examples \n",
    "below."
   ]
  },
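  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The bookkeeping can be pictured as building a dictionary from names to nodes. The helper `toy_inputs` below is hypothetical and is not part of `trace`. It uses `inspect.signature` to name each node argument after the matching parameter of the function. It then merges any extra entries from a user-supplied `node_dict`. The merge rule is an assumption, chosen to reproduce the keys printed in the next cell.\n",
    "\n",
    "```python\n",
    "import inspect\n",
    "\n",
    "\n",
    "class ToyNode:\n",
    "    def __init__(self, data, name):\n",
    "        self.data = data\n",
    "        self.name = name\n",
    "\n",
    "\n",
    "def toy_inputs(fun, args, node_dict=None):\n",
    "    # Map each positional node argument to its parameter name in fun's signature.\n",
    "    params = list(inspect.signature(fun).parameters)\n",
    "    inputs = {name: arg for name, arg in zip(params, args) if isinstance(arg, ToyNode)}\n",
    "    # Extra user-specified relationships are added on top (hypothetical merge rule).\n",
    "    if isinstance(node_dict, dict):\n",
    "        inputs.update(node_dict)\n",
    "    return inputs\n",
    "\n",
    "\n",
    "def foo(input):\n",
    "    return input + 1\n",
    "\n",
    "\n",
    "x = ToyNode(1, 'node_x')\n",
    "print({k: v.name for k, v in toy_inputs(foo, [x]).items()})  # {'input': 'node_x'}\n",
    "print({k: v.name for k, v in toy_inputs(foo, [x], {'custom_x': x}).items()})  # {'input': 'node_x', 'custom_x': 'node_x'}\n",
    "```"
   ]
  },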
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'x': ('node_x:4', 1)}\n",
      "{'input': ('node_x:4', 1)}\n",
      "{'x': ('node_x:4', 1), 'custom_x': ('node_x:4', 1)}\n"
     ]
    }
   ],
   "source": [
    "# The default value of node_dict is None. In this case, the key of the inputs dict is the name of the input nodes.\n",
    "@bundle(\"[add_1] Add 1 to input x\")\n",
    "def foo(x):\n",
    "    return x + 1\n",
    "\n",
    "\n",
    "z = foo(x)\n",
    "print({k: (v.name, v.data) for k, v in z.inputs.items()})\n",
    "\n",
    "# When node_dict is set to 'auto', the key of the inputs dict is the name specified in the function signature.\n",
    "\n",
    "\n",
    "@bundle(\"[add_1] Add 1 to input x\", node_dict=\"auto\")\n",
    "def foo(input):\n",
    "    return input + 1\n",
    "\n",
    "\n",
    "z = foo(x)\n",
    "print({k: (v.name, v.data) for k, v in z.inputs.items()})\n",
    "\n",
    "# When node_dict is set to a dict, the key of the inputs dict is the name specified in the dict.\n",
    "node_dict = {\"custom_x\": x}\n",
    "\n",
    "\n",
    "@bundle(\"[add_1] Add 1 to input x\", node_dict=node_dict)\n",
    "def foo(x):\n",
    "    return x + 1\n",
    "\n",
    "\n",
    "z = foo(x)\n",
    "print({k: (v.name, v.data) for k, v in z.inputs.items()})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using node_dict is useful when the function uses nodes that are not in the function signature.\n"
   ]
  },
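  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "How can a decorator notice that `y` was used even though it is not an argument? This notebook states that `bundle` tracks every node whose `data` attribute is read inside the function. The sketch below shows one plausible way to do that and is hypothetical; it is not necessarily how `trace` does it. It makes `data` a property that logs each access while a traced call is running. `ToyNode` and `run_traced` are illustrative names.\n",
    "\n",
    "```python\n",
    "class ToyNode:\n",
    "    # Hypothetical node whose data accesses are recorded while tracing is on.\n",
    "    _used = None  # a set during a traced call, None otherwise\n",
    "\n",
    "    def __init__(self, value, name):\n",
    "        self._value = value\n",
    "        self.name = name\n",
    "\n",
    "    @property\n",
    "    def data(self):\n",
    "        if ToyNode._used is not None:\n",
    "            ToyNode._used.add(self)\n",
    "        return self._value\n",
    "\n",
    "\n",
    "def run_traced(fun, *args):\n",
    "    # Call fun and return (output, nodes whose .data was read during the call).\n",
    "    ToyNode._used = set()\n",
    "    try:\n",
    "        output = fun(*args)\n",
    "        used = ToyNode._used\n",
    "    finally:\n",
    "        ToyNode._used = None\n",
    "    return output, used\n",
    "\n",
    "\n",
    "x = ToyNode(1, 'node_x')\n",
    "y = ToyNode(2, 'node_y')\n",
    "\n",
    "\n",
    "def foo(x):\n",
    "    return x.data + y.data  # y is read from the enclosing scope\n",
    "\n",
    "\n",
    "output, used = run_traced(foo, x)\n",
    "print(output, sorted(n.name for n in used))  # 3 ['node_x', 'node_y']\n",
    "```"
   ]
  },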
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Not all nodes used in the operator <function foo at 0x7f512cf3b310> are specified as inputs of the returned node. Missing ['node_y:4'] \n",
      "{'x': ('node_x:5', 1), 'node_y': ('node_y:4', 2)}\n"
     ]
    }
   ],
   "source": [
    "# By default, the inputs dict only contains the nodes that are in the function signature. One can update the inputs dict by using node_dict.\n",
    "x = node(1, name=\"node_x\")\n",
    "y = node(2, name=\"node_y\")\n",
    "\n",
    "\n",
    "@bundle(\"[add_1] Add input x to node_y.\", node_dict=\"auto\")\n",
    "def foo(x):\n",
    "    return x + y.data\n",
    "\n",
    "\n",
    "try:\n",
    "    z = foo(x)\n",
    "except Exception as e:\n",
    "    # Since the function signature does not contain y, the function will raise an error.\n",
    "    print(e)\n",
    "# We can use node_dict to add y to the inputs dict.\n",
    "node_dict = {\"node_y\": y}\n",
    "\n",
    "\n",
    "@bundle(\"[add_1] Add input x to node_y.\", node_dict=node_dict)\n",
    "def foo(x):\n",
    "    return x + y.data\n",
    "\n",
    "\n",
    "z = foo(x)\n",
    "print({k: (v.name, v.data) for k, v in z.inputs.items()})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualize Graph\n",
    "\n",
    "The graph of nodes can be visualized by calling `backward` method of a node. (Later we will cover how `backward` also sends feedback across the graph). "
   ]
  },
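  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Drawing the graph only requires walking from the output node back through `parents`. The sketch below is hypothetical and uses toy classes; the real `backward` does more, including propagating feedback. It collects the edges reachable from a node and emits them in Graphviz DOT syntax. The toy graph mirrors the computation in the next cell (`a = x + y`, `b = x + 1`, `final = a + b`).\n",
    "\n",
    "```python\n",
    "class ToyNode:\n",
    "    def __init__(self, data, name, parents=()):\n",
    "        self.data = data\n",
    "        self.name = name\n",
    "        self.parents = list(parents)\n",
    "\n",
    "\n",
    "def collect_edges(output):\n",
    "    # Depth-first walk from the output back to the leaves; each edge is (parent, child).\n",
    "    edges, seen, stack = [], set(), [output]\n",
    "    while stack:\n",
    "        child = stack.pop()\n",
    "        if id(child) in seen:\n",
    "            continue\n",
    "        seen.add(id(child))\n",
    "        for parent in child.parents:\n",
    "            edges.append((parent.name, child.name))\n",
    "            stack.append(parent)\n",
    "    return edges\n",
    "\n",
    "\n",
    "def to_dot(edges):\n",
    "    # A one-line DOT digraph; node names here are plain identifiers.\n",
    "    return ' '.join(['digraph G {'] + [f'{a} -> {b};' for a, b in edges] + ['}'])\n",
    "\n",
    "\n",
    "x = ToyNode(1, 'node_x')\n",
    "y = ToyNode(2, 'node_y')\n",
    "one = ToyNode(1, 'int')\n",
    "a = ToyNode(3, 'add0', [x, y])\n",
    "b = ToyNode(2, 'add1', [x, one])\n",
    "final = ToyNode(5, 'add2', [a, b])\n",
    "print(to_dot(collect_edges(final)))\n",
    "```\n",
    "\n",
    "The walk finds the same six edges that appear in the rendered graph. In the real notebook, rendering produces a `graphviz.graphs.Digraph` object, shown as SVG."
   ]
  },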
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.40.1 (20161225.0304)\n",
       " -->\n",
       "<!-- Title: %3 Pages: 1 -->\n",
       "<svg width=\"1229pt\" height=\"305pt\"\n",
       " viewBox=\"0.00 0.00 1228.57 304.86\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 300.86)\">\n",
       "<title>%3</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"transparent\" points=\"-4,4 -4,-300.86 1224.5656,-300.86 1224.5656,4 -4,4\"/>\n",
       "<!-- add0 -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>add0</title>\n",
       "<ellipse fill=\"none\" stroke=\"#000000\" cx=\"433.2828\" cy=\"-148.43\" rx=\"167.6687\" ry=\"37.4533\"/>\n",
       "<text text-anchor=\"middle\" x=\"433.2828\" y=\"-159.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">add0</text>\n",
       "<text text-anchor=\"middle\" x=\"433.2828\" y=\"-144.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">[add] This is an add operator of x and y.</text>\n",
       "<text text-anchor=\"middle\" x=\"433.2828\" y=\"-129.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">3</text>\n",
       "</g>\n",
       "<!-- add2 -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>add2</title>\n",
       "<ellipse fill=\"none\" stroke=\"#000000\" cx=\"609.2828\" cy=\"-37.4767\" rx=\"167.6687\" ry=\"37.4533\"/>\n",
       "<text text-anchor=\"middle\" x=\"609.2828\" y=\"-48.7767\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">add2</text>\n",
       "<text text-anchor=\"middle\" x=\"609.2828\" y=\"-33.7767\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">[add] This is an add operator of x and y.</text>\n",
       "<text text-anchor=\"middle\" x=\"609.2828\" y=\"-18.7767\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">5</text>\n",
       "</g>\n",
       "<!-- add0&#45;&gt;add2 -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>add0&#45;&gt;add2</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M489.4542,-113.0186C506.9179,-102.0092 526.3143,-89.7814 544.3254,-78.4269\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"546.4486,-81.2259 553.0414,-72.9322 542.7155,-75.3044 546.4486,-81.2259\"/>\n",
       "</g>\n",
       "<!-- add1 -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>add1</title>\n",
       "<ellipse fill=\"none\" stroke=\"#000000\" cx=\"786.2828\" cy=\"-148.43\" rx=\"167.6687\" ry=\"37.4533\"/>\n",
       "<text text-anchor=\"middle\" x=\"786.2828\" y=\"-159.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">add1</text>\n",
       "<text text-anchor=\"middle\" x=\"786.2828\" y=\"-144.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">[add] This is an add operator of x and y.</text>\n",
       "<text text-anchor=\"middle\" x=\"786.2828\" y=\"-129.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">2</text>\n",
       "</g>\n",
       "<!-- add1&#45;&gt;add2 -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>add1&#45;&gt;add2</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M729.7922,-113.0186C712.2293,-102.0092 692.7227,-89.7814 674.6093,-78.4269\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"676.1757,-75.278 665.8438,-72.9322 672.4577,-81.209 676.1757,-75.278\"/>\n",
       "</g>\n",
       "<!-- node_x0 -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>node_x0</title>\n",
       "<ellipse fill=\"none\" stroke=\"#000000\" cx=\"610.2828\" cy=\"-259.3833\" rx=\"197.0658\" ry=\"37.4533\"/>\n",
       "<text text-anchor=\"middle\" x=\"610.2828\" y=\"-270.6833\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">node_x0</text>\n",
       "<text text-anchor=\"middle\" x=\"610.2828\" y=\"-255.6833\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">[Node] This is a node in a computational graph.</text>\n",
       "<text text-anchor=\"middle\" x=\"610.2828\" y=\"-240.6833\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">1</text>\n",
       "</g>\n",
       "<!-- node_x0&#45;&gt;add0 -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>node_x0&#45;&gt;add0</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M552.816,-223.3599C535.4664,-212.4843 516.2949,-200.4665 498.4798,-189.299\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"500.189,-186.2397 489.8571,-183.8939 496.4711,-192.1707 500.189,-186.2397\"/>\n",
       "</g>\n",
       "<!-- node_x0&#45;&gt;add1 -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>node_x0&#45;&gt;add1</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M667.4249,-223.3599C684.6765,-212.4843 703.7397,-200.4665 721.4541,-189.299\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"723.4353,-192.1876 730.0281,-183.8939 719.7023,-186.266 723.4353,-192.1876\"/>\n",
       "</g>\n",
       "<!-- node_y0 -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>node_y0</title>\n",
       "<ellipse fill=\"none\" stroke=\"#000000\" cx=\"197.2828\" cy=\"-259.3833\" rx=\"197.0658\" ry=\"37.4533\"/>\n",
       "<text text-anchor=\"middle\" x=\"197.2828\" y=\"-270.6833\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">node_y0</text>\n",
       "<text text-anchor=\"middle\" x=\"197.2828\" y=\"-255.6833\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">[Node] This is a node in a computational graph.</text>\n",
       "<text text-anchor=\"middle\" x=\"197.2828\" y=\"-240.6833\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">2</text>\n",
       "</g>\n",
       "<!-- node_y0&#45;&gt;add0 -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>node_y0&#45;&gt;add0</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M271.3072,-224.5814C296.9424,-212.5292 325.8089,-198.9579 351.9961,-186.6462\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"353.5768,-189.7706 361.1374,-182.3485 350.5985,-183.4358 353.5768,-189.7706\"/>\n",
       "</g>\n",
       "<!-- int0 -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>int0</title>\n",
       "<ellipse fill=\"none\" stroke=\"#000000\" cx=\"1023.2828\" cy=\"-259.3833\" rx=\"197.0658\" ry=\"37.4533\"/>\n",
       "<text text-anchor=\"middle\" x=\"1023.2828\" y=\"-270.6833\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">int0</text>\n",
       "<text text-anchor=\"middle\" x=\"1023.2828\" y=\"-255.6833\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">[Node] This is a node in a computational graph.</text>\n",
       "<text text-anchor=\"middle\" x=\"1023.2828\" y=\"-240.6833\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">1</text>\n",
       "</g>\n",
       "<!-- int0&#45;&gt;add1 -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>int0&#45;&gt;add1</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M948.9448,-224.5814C923.2009,-212.5292 894.2121,-198.9579 867.9139,-186.6462\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"869.2745,-183.4186 858.7339,-182.3485 866.3065,-189.7583 869.2745,-183.4186\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7f512cf0f0a0>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from autogen.trace.nodes import GRAPH\n",
    "\n",
    "GRAPH.clear()  # to remove all the nodes\n",
    "x = node(1, name=\"node_x\")\n",
    "y = node(2, name=\"node_y\")\n",
    "a = x + y\n",
    "b = x + 1\n",
    "final = a + b\n",
    "final.backward(visualize=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Node: (bool:0, dtype=<class 'bool'>, data=True) Node: (int:0, dtype=<class 'int'>, data=1) Node: (int:1, dtype=<class 'int'>, data=0)\n"
     ]
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.40.1 (20161225.0304)\n",
       " -->\n",
       "<!-- Title: %3 Pages: 1 -->\n",
       "<svg width=\"403pt\" height=\"83pt\"\n",
       " viewBox=\"0.00 0.00 402.57 82.95\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 78.9533)\">\n",
       "<title>%3</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"transparent\" points=\"-4,4 -4,-78.9533 398.5656,-78.9533 398.5656,4 -4,4\"/>\n",
       "<!-- int0 -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>int0</title>\n",
       "<ellipse fill=\"none\" stroke=\"#000000\" cx=\"197.2828\" cy=\"-37.4767\" rx=\"197.0658\" ry=\"37.4533\"/>\n",
       "<text text-anchor=\"middle\" x=\"197.2828\" y=\"-48.7767\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">int0</text>\n",
       "<text text-anchor=\"middle\" x=\"197.2828\" y=\"-33.7767\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">[Node] This is a node in a computational graph.</text>\n",
       "<text text-anchor=\"middle\" x=\"197.2828\" y=\"-18.7767\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">1</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7f512cf0f850>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "GRAPH.clear()\n",
    "x = node(True)\n",
    "one = node(1)\n",
    "zero = node(0)\n",
    "print(x, one, zero)\n",
    "# Logical operations are not traceable\n",
    "y = one if x.data else zero\n",
    "y.backward(visualize=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.40.1 (20161225.0304)\n",
       " -->\n",
       "<!-- Title: %3 Pages: 1 -->\n",
       "<svg width=\"1229pt\" height=\"194pt\"\n",
       " viewBox=\"0.00 0.00 1228.57 193.91\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 189.9066)\">\n",
       "<title>%3</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"transparent\" points=\"-4,4 -4,-189.9066 1224.5656,-189.9066 1224.5656,4 -4,4\"/>\n",
       "<!-- bool0 -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>bool0</title>\n",
       "<ellipse fill=\"none\" stroke=\"#000000\" cx=\"197.2828\" cy=\"-148.43\" rx=\"197.0658\" ry=\"37.4533\"/>\n",
       "<text text-anchor=\"middle\" x=\"197.2828\" y=\"-159.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">bool0</text>\n",
       "<text text-anchor=\"middle\" x=\"197.2828\" y=\"-144.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">[Node] This is a node in a computational graph.</text>\n",
       "<text text-anchor=\"middle\" x=\"197.2828\" y=\"-129.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">True</text>\n",
       "</g>\n",
       "<!-- fun0 -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>fun0</title>\n",
       "<ellipse fill=\"none\" stroke=\"#000000\" cx=\"610.2828\" cy=\"-37.4767\" rx=\"230.5336\" ry=\"37.4533\"/>\n",
       "<text text-anchor=\"middle\" x=\"610.2828\" y=\"-48.7767\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">fun0</text>\n",
       "<text text-anchor=\"middle\" x=\"610.2828\" y=\"-33.7767\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">[fun] Return one if input x is True, otherwise return zero</text>\n",
       "<text text-anchor=\"middle\" x=\"610.2828\" y=\"-18.7767\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">1</text>\n",
       "</g>\n",
       "<!-- bool0&#45;&gt;fun0 -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>bool0&#45;&gt;fun0</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M311.2301,-117.8178C363.823,-103.6886 426.5149,-86.8463 480.8815,-72.2406\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"481.8307,-75.6098 490.5802,-69.635 480.0145,-68.8495 481.8307,-75.6098\"/>\n",
       "</g>\n",
       "<!-- int0 -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>int0</title>\n",
       "<ellipse fill=\"none\" stroke=\"#000000\" cx=\"610.2828\" cy=\"-148.43\" rx=\"197.0658\" ry=\"37.4533\"/>\n",
       "<text text-anchor=\"middle\" x=\"610.2828\" y=\"-159.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">int0</text>\n",
       "<text text-anchor=\"middle\" x=\"610.2828\" y=\"-144.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">[Node] This is a node in a computational graph.</text>\n",
       "<text text-anchor=\"middle\" x=\"610.2828\" y=\"-129.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">1</text>\n",
       "</g>\n",
       "<!-- int0&#45;&gt;fun0 -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>int0&#45;&gt;fun0</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M610.2828,-110.8662C610.2828,-102.6423 610.2828,-93.8301 610.2828,-85.267\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"613.7829,-85.0017 610.2828,-75.0017 606.7829,-85.0017 613.7829,-85.0017\"/>\n",
       "</g>\n",
       "<!-- int1 -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>int1</title>\n",
       "<ellipse fill=\"none\" stroke=\"#000000\" cx=\"1023.2828\" cy=\"-148.43\" rx=\"197.0658\" ry=\"37.4533\"/>\n",
       "<text text-anchor=\"middle\" x=\"1023.2828\" y=\"-159.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">int1</text>\n",
       "<text text-anchor=\"middle\" x=\"1023.2828\" y=\"-144.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">[Node] This is a node in a computational graph.</text>\n",
       "<text text-anchor=\"middle\" x=\"1023.2828\" y=\"-129.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">0</text>\n",
       "</g>\n",
       "<!-- int1&#45;&gt;fun0 -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>int1&#45;&gt;fun0</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M909.3355,-117.8178C856.7426,-103.6886 794.0507,-86.8463 739.6841,-72.2406\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"740.5511,-68.8495 729.9854,-69.635 738.7348,-75.6098 740.5511,-68.8495\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7f512cf0f0d0>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# This is traceable\n",
    "node_dict = {\"one\": one, \"zero\": zero}\n",
    "\n",
    "\n",
    "@bundle(\"[fun] Return one if input x is True, otherwise return zero\", node_dict=node_dict)\n",
    "def fun(x):\n",
    "    return one.data if x else zero.data\n",
    "\n",
    "\n",
    "y = fun(x)\n",
    "y.backward(visualize=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Broadcasting\n",
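    "Using `apply_op`, we can broadcast a node operator over containers of nodes. A container of nodes is a `list`, `tuple`, `dict`, or an instance of a custom container class (a subclass of the abstract class `BaseModule`, or of `NodeContainer` as in the last example below). `apply_op` takes the operator, an output container that serves as a template, and the input containers. It recursively applies the operator to every position that holds a node and leaves entries that are not nodes unchanged."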
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x [1, 2, 1]\n",
      "y [3, 4, 2]\n",
      "Elements in z should be added, except for the last one. Value:  [4, 6, 1]\n",
      "1+3=4\n",
      "0==0==0\n",
      "x_plus_y.x should be added. Value:  xy\n",
      "x_plus_y.y should be added. Value:  [2, 4]\n",
      "x_plus_y.z should be not added, just 1. Value:  1\n"
     ]
    }
   ],
   "source": [
    "from autogen.trace import apply_op, node, NodeContainer\n",
    "from autogen.trace import operators as ops\n",
    "\n",
    "import copy\n",
    "\n",
    "# Using list as a node container\n",
    "x = [node(1), node(2), 1]\n",
    "y = [node(3), node(4), 2]\n",
    "z = copy.deepcopy(x)  # template container that receives the results\n",
    "z = apply_op(ops.add, z, x, y)\n",
    "print(\"x\", [x[0].data, x[1].data, x[2]])\n",
    "print(\"y\", [y[0].data, y[1].data, y[2]])\n",
    "print(\"Elements in z should be added, except for the last one. Value: \", [z[0].data, z[1].data, z[2]])\n",
    "\n",
    "\n",
    "# Using dict as a node container\n",
    "x = dict(a=node(1), b=0)\n",
    "y = dict(a=node(3), b=0)\n",
    "z = copy.deepcopy(x)\n",
    "z = apply_op(ops.add, z, x, y)\n",
    "print(f\"{x['a'].data}+{y['a'].data}={z['a'].data}\")\n",
    "print(f\"{x['b']}=={y['b']}=={z['b']}\")\n",
    "\n",
    "# Using a custom class as a node container\n",
    "\n",
    "\n",
    "class Foo(NodeContainer):\n",
    "    def __init__(self, x):\n",
    "        self.x = node(x)\n",
    "        self.y = [node(1), node(2)]\n",
    "        self.z = 1\n",
    "\n",
    "\n",
    "x = Foo(\"x\")\n",
    "y = Foo(\"y\")\n",
    "x_plus_y = Foo(\"template\")  # template whose node attributes receive the results\n",
    "x_plus_y = apply_op(ops.add, x_plus_y, x, y)\n",
    "print(\"x_plus_y.x should be added. Value: \", x_plus_y.x.data)\n",
    "print(\"x_plus_y.y should be added. Value: \", [n.data for n in x_plus_y.y])\n",
    "print(\"x_plus_y.z should be not added, just 1. Value: \", x_plus_y.z)"
   ]
  },
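  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The recursion performed by `apply_op` can be pictured with a plain-Python sketch. It is a hypothetical illustration, not the `autogen.trace` implementation: `ToyNode`, `toy_apply_op`, and `Pair` are made-up names, `ToyNode` only wraps a value in `.data`, and the semantics are inferred from the outputs above. The container passed right after the operator acts as the output template, positions holding nodes receive the operator's result, and entries that are not nodes keep the template's value.\n",
    "\n",
    "```python\n",
    "import copy\n",
    "import operator\n",
    "\n",
    "\n",
    "class ToyNode:\n",
    "    # Hypothetical stand-in for a trace node: it only wraps a value in .data.\n",
    "    def __init__(self, data):\n",
    "        self.data = data\n",
    "\n",
    "\n",
    "def toy_apply_op(op, output, *inputs):\n",
    "    # Hypothetical re-creation of apply_op broadcasting, not the library code.\n",
    "    if isinstance(output, ToyNode):\n",
    "        # A node position: apply the operator to the matching input values.\n",
    "        return ToyNode(op(*[i.data for i in inputs]))\n",
    "    if isinstance(output, (list, tuple)):\n",
    "        # Recurse element-wise and rebuild a container of the same type.\n",
    "        items = [toy_apply_op(op, o, *xs) for o, *xs in zip(output, *inputs)]\n",
    "        return type(output)(items)\n",
    "    if isinstance(output, dict):\n",
    "        return {k: toy_apply_op(op, output[k], *[i[k] for i in inputs]) for k in output}\n",
    "    if hasattr(output, '__dict__'):\n",
    "        # A custom container object: recurse into each attribute in place.\n",
    "        for name in list(vars(output)):\n",
    "            value = toy_apply_op(op, getattr(output, name), *[getattr(i, name) for i in inputs])\n",
    "            setattr(output, name, value)\n",
    "        return output\n",
    "    # Anything else (e.g. a plain int) keeps the template's value.\n",
    "    return output\n",
    "\n",
    "\n",
    "# List container: the last entry is not a node, so it is left unchanged.\n",
    "x = [ToyNode(1), ToyNode(2), 1]\n",
    "y = [ToyNode(3), ToyNode(4), 2]\n",
    "z = toy_apply_op(operator.add, copy.deepcopy(x), x, y)\n",
    "print([z[0].data, z[1].data, z[2]])  # [4, 6, 1]\n",
    "\n",
    "# Dict container.\n",
    "d = toy_apply_op(operator.add, {'a': ToyNode(0), 'b': 0}, {'a': ToyNode(1), 'b': 0}, {'a': ToyNode(3), 'b': 0})\n",
    "print(d['a'].data, d['b'])  # 4 0\n",
    "\n",
    "\n",
    "class Pair:\n",
    "    # Custom container with a node, a list of nodes, and a plain value.\n",
    "    def __init__(self, s):\n",
    "        self.x = ToyNode(s)\n",
    "        self.y = [ToyNode(1), ToyNode(2)]\n",
    "        self.z = 1\n",
    "\n",
    "\n",
    "out = toy_apply_op(operator.add, Pair('template'), Pair('x'), Pair('y'))\n",
    "print(out.x.data, [n.data for n in out.y], out.z)  # xy [2, 4] 1\n",
    "```"
   ]
  },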
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
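    "## Nodes and Python Data Structures\n",
    "\n",
    "We can create a `node` over a Python data structure such as a dictionary, tuple, set, or list. Iteration is handled automatically: you can wrap a node around one of these data structures and iterate over it like a normal Python object. Each element produced by the iteration is itself a `MessageNode` (created by a `getitem` operator), so it stays connected to the computation graph; a set is first converted to a list by a `to_list` operator."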
   ]
  },
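  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following plain-Python sketch models how iterating over a wrapped container can yield element nodes that stay connected to the graph. It is a hypothetical illustration, not the `autogen.trace` implementation: `ToyNode` and its `name` and `parents` fields are made-up names. It reproduces the structure visible in the graphs of the two examples that follow: a list node is indexed with `getitem` (with the index node as a second parent), while a set node is first converted by `to_list`.\n",
    "\n",
    "```python\n",
    "class ToyNode:\n",
    "    # Hypothetical model of a node that records its parents; not the autogen.trace class.\n",
    "    def __init__(self, data, name='node', parents=()):\n",
    "        self.data = data\n",
    "        self.name = name\n",
    "        self.parents = list(parents)\n",
    "\n",
    "    def __iter__(self):\n",
    "        source = self\n",
    "        if isinstance(self.data, set):\n",
    "            # Sets cannot be indexed, so first derive a list node (cf. to_list).\n",
    "            source = ToyNode(list(self.data), 'to_list', [self])\n",
    "        if not isinstance(source.data, (list, tuple)):\n",
    "            raise TypeError('this sketch only iterates lists, tuples and sets')\n",
    "        for i in range(len(source.data)):\n",
    "            # Each element becomes a getitem node whose parents are the container and the index.\n",
    "            index = ToyNode(i, 'int')\n",
    "            yield ToyNode(source.data[i], 'getitem', [source, index])\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f'ToyNode({self.name}, data={self.data!r})'\n",
    "\n",
    "\n",
    "params = ToyNode([1, 2], 'list')\n",
    "args = ToyNode({'arg1', 'arg2'}, 'set')\n",
    "for a, p in zip(args, params):\n",
    "    print(a, p)\n",
    "\n",
    "# The last elements still link back to their containers and indices.\n",
    "print([parent.name for parent in p.parents])  # ['list', 'int']\n",
    "print(a.parents[0].name, a.parents[0].parents[0] is args)  # to_list True\n",
    "```"
   ]
  },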
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MessageNode: (getitem:0, dtype=<class 'str'>, data=arg1)\n",
      "MessageNode: (getitem:1, dtype=<class 'str'>, data=arg2)\n"
     ]
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.40.1 (20161225.0304)\n",
       " -->\n",
       "<!-- Title: %3 Pages: 1 -->\n",
       "<svg width=\"756pt\" height=\"305pt\"\n",
       " viewBox=\"0.00 0.00 755.57 304.86\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 300.86)\">\n",
       "<title>%3</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"transparent\" points=\"-4,4 -4,-300.86 751.5656,-300.86 751.5656,4 -4,4\"/>\n",
       "<!-- to_list0 -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>to_list0</title>\n",
       "<ellipse fill=\"none\" stroke=\"#000000\" cx=\"550.2828\" cy=\"-148.43\" rx=\"137.7717\" ry=\"37.4533\"/>\n",
       "<text text-anchor=\"middle\" x=\"550.2828\" y=\"-159.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">to_list0</text>\n",
       "<text text-anchor=\"middle\" x=\"550.2828\" y=\"-144.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">[to_list] This converts x to a list.</text>\n",
       "<text text-anchor=\"middle\" x=\"550.2828\" y=\"-129.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">[&#39;arg1&#39;, &#39;arg2&#39;]</text>\n",
       "</g>\n",
       "<!-- getitem1 -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>getitem1</title>\n",
       "<ellipse fill=\"none\" stroke=\"#000000\" cx=\"373.2828\" cy=\"-37.4767\" rx=\"230.5336\" ry=\"37.4533\"/>\n",
       "<text text-anchor=\"middle\" x=\"373.2828\" y=\"-48.7767\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">getitem1</text>\n",
       "<text text-anchor=\"middle\" x=\"373.2828\" y=\"-33.7767\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">[getitem] This is a getitem operator of x based on index.</text>\n",
       "<text text-anchor=\"middle\" x=\"373.2828\" y=\"-18.7767\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">arg2</text>\n",
       "</g>\n",
       "<!-- to_list0&#45;&gt;getitem1 -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>to_list0&#45;&gt;getitem1</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M495.2491,-113.9319C477.8069,-102.9981 458.3308,-90.7894 440.1518,-79.3939\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"441.6783,-76.2199 431.3464,-73.8741 437.9604,-82.151 441.6783,-76.2199\"/>\n",
       "</g>\n",
       "<!-- int15 -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>int15</title>\n",
       "<ellipse fill=\"none\" stroke=\"#000000\" cx=\"197.2828\" cy=\"-148.43\" rx=\"197.0658\" ry=\"37.4533\"/>\n",
       "<text text-anchor=\"middle\" x=\"197.2828\" y=\"-159.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">int15</text>\n",
       "<text text-anchor=\"middle\" x=\"197.2828\" y=\"-144.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">[Node] This is a node in a computational graph.</text>\n",
       "<text text-anchor=\"middle\" x=\"197.2828\" y=\"-129.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">1</text>\n",
       "</g>\n",
       "<!-- int15&#45;&gt;getitem1 -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>int15&#45;&gt;getitem1</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M254.4249,-112.4066C271.0923,-101.8992 289.4507,-90.3258 306.65,-79.4831\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"308.9306,-82.1829 315.5234,-73.8891 305.1975,-76.2613 308.9306,-82.1829\"/>\n",
       "</g>\n",
       "<!-- set0 -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>set0</title>\n",
       "<ellipse fill=\"none\" stroke=\"#000000\" cx=\"550.2828\" cy=\"-259.3833\" rx=\"197.0658\" ry=\"37.4533\"/>\n",
       "<text text-anchor=\"middle\" x=\"550.2828\" y=\"-270.6833\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">set0</text>\n",
       "<text text-anchor=\"middle\" x=\"550.2828\" y=\"-255.6833\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">[Node] This is a node in a computational graph.</text>\n",
       "<text text-anchor=\"middle\" x=\"550.2828\" y=\"-240.6833\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">{&#39;arg1&#39;, &#39;arg2&#39;}</text>\n",
       "</g>\n",
       "<!-- set0&#45;&gt;to_list0 -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>set0&#45;&gt;to_list0</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M550.2828,-221.8196C550.2828,-213.5956 550.2828,-204.7834 550.2828,-196.2203\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"553.7829,-195.955 550.2828,-185.955 546.7829,-195.9551 553.7829,-195.955\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7f512cf28d90>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from autogen.trace import node\n",
    "\n",
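    "args = node({\"arg1\", \"arg2\"}, trainable=False)  # node wrapping a set\n",
    "for a in args:  # each element is returned as a MessageNode\n",
    "    print(a)\n",
    "\n",
    "# Visualize the graph of the last element: set -> to_list -> getitem\n",
    "a.backward(visualize=True)"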
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MessageNode: (getitem:2, dtype=<class 'str'>, data=arg1) MessageNode: (getitem:3, dtype=<class 'int'>, data=1)\n",
      "MessageNode: (getitem:4, dtype=<class 'str'>, data=arg2) MessageNode: (getitem:5, dtype=<class 'int'>, data=2)\n"
     ]
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.40.1 (20161225.0304)\n",
       " -->\n",
       "<!-- Title: %3 Pages: 1 -->\n",
       "<svg width=\"816pt\" height=\"194pt\"\n",
       " viewBox=\"0.00 0.00 815.57 193.91\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 189.9066)\">\n",
       "<title>%3</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"transparent\" points=\"-4,4 -4,-189.9066 811.5656,-189.9066 811.5656,4 -4,4\"/>\n",
       "<!-- list0 -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>list0</title>\n",
       "<ellipse fill=\"none\" stroke=\"#000000\" cx=\"197.2828\" cy=\"-148.43\" rx=\"197.0658\" ry=\"37.4533\"/>\n",
       "<text text-anchor=\"middle\" x=\"197.2828\" y=\"-159.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">list0</text>\n",
       "<text text-anchor=\"middle\" x=\"197.2828\" y=\"-144.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">[Node] This is a node in a computational graph.</text>\n",
       "<text text-anchor=\"middle\" x=\"197.2828\" y=\"-129.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">[1, 2]</text>\n",
       "</g>\n",
       "<!-- getitem5 -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>getitem5</title>\n",
       "<ellipse fill=\"none\" stroke=\"#000000\" cx=\"403.2828\" cy=\"-37.4767\" rx=\"230.5336\" ry=\"37.4533\"/>\n",
       "<text text-anchor=\"middle\" x=\"403.2828\" y=\"-48.7767\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">getitem5</text>\n",
       "<text text-anchor=\"middle\" x=\"403.2828\" y=\"-33.7767\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">[getitem] This is a getitem operator of x based on index.</text>\n",
       "<text text-anchor=\"middle\" x=\"403.2828\" y=\"-18.7767\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">2</text>\n",
       "</g>\n",
       "<!-- list0&#45;&gt;getitem5 -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>list0&#45;&gt;getitem5</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M263.0289,-113.0186C283.5319,-101.9755 306.3109,-89.7065 327.4466,-78.3227\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"329.2482,-81.3277 336.3927,-73.5042 325.9288,-75.1648 329.2482,-81.3277\"/>\n",
       "</g>\n",
       "<!-- int19 -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>int19</title>\n",
       "<ellipse fill=\"none\" stroke=\"#000000\" cx=\"610.2828\" cy=\"-148.43\" rx=\"197.0658\" ry=\"37.4533\"/>\n",
       "<text text-anchor=\"middle\" x=\"610.2828\" y=\"-159.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">int19</text>\n",
       "<text text-anchor=\"middle\" x=\"610.2828\" y=\"-144.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">[Node] This is a node in a computational graph.</text>\n",
       "<text text-anchor=\"middle\" x=\"610.2828\" y=\"-129.73\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">1</text>\n",
       "</g>\n",
       "<!-- int19&#45;&gt;getitem5 -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>int19&#45;&gt;getitem5</title>\n",
       "<path fill=\"none\" stroke=\"#000000\" d=\"M544.2176,-113.0186C523.615,-101.9755 500.7254,-89.7065 479.4871,-78.3227\"/>\n",
       "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"480.9648,-75.1437 470.4976,-73.5042 477.6578,-81.3133 480.9648,-75.1437\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7f5205ed1bb0>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
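    "params = node([1, 2], trainable=False)\n",
    "args = node([\"arg1\", \"arg2\"], trainable=False)\n",
    "\n",
    "# Nodes over lists can be zipped like ordinary Python lists\n",
    "for a, p in zip(args, params):\n",
    "    print(a, p)\n",
    "\n",
    "# Lists are indexed directly, so this graph has no to_list step\n",
    "p.backward(visualize=True)"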
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "autogen",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
