{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/lib/python3.7/site-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.\n",
      "  import pandas.util.testing as tm\n"
     ]
    }
   ],
   "source": [
    "from experiments.structural_equations import *\n",
    "from experiments.data_generation import ExperimentationModel\n",
    "#from data_generation import ExperimentationModel\n",
    "import dowhy.gcm as cy\n",
    "import torch\n",
    "from experiments.exp_helper import *\n",
    "\n",
    "from dowhy.gcm  import (\n",
    "    counterfactual_samples, \n",
    "    FunctionalCausalModel, \n",
    "    StochasticModel, \n",
    "    StructuralCausalModel, \n",
    "    is_root_node)\n",
    "    \n",
    "n = 50\n",
    "scm_type = \"sachs\"\n",
    "equations_type = \"nonadditive\"\n",
    "g = get_graph(scm_type)\n",
    "\n",
    "weights = get_weight_matrices(g,equations_type, scm_type)\n",
    "\n",
    "g_sort = nx.topological_sort(g)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['x1', 'x4', 'x2', 'x5', 'x3', 'x6', 'x8', 'x7', 'x9', 'x10', 'x11']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "list(g_sort)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(7466, 11)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>praf</th>\n",
       "      <th>pmek</th>\n",
       "      <th>plcg</th>\n",
       "      <th>PIP2</th>\n",
       "      <th>PIP3</th>\n",
       "      <th>p44/42</th>\n",
       "      <th>pakts473</th>\n",
       "      <th>PKA</th>\n",
       "      <th>PKC</th>\n",
       "      <th>P38</th>\n",
       "      <th>pjnk</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>26.4</td>\n",
       "      <td>13.2</td>\n",
       "      <td>8.82</td>\n",
       "      <td>18.30</td>\n",
       "      <td>58.80</td>\n",
       "      <td>6.61</td>\n",
       "      <td>17.0</td>\n",
       "      <td>414.0</td>\n",
       "      <td>17.00</td>\n",
       "      <td>44.9</td>\n",
       "      <td>40.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>35.9</td>\n",
       "      <td>16.5</td>\n",
       "      <td>12.30</td>\n",
       "      <td>16.80</td>\n",
       "      <td>8.13</td>\n",
       "      <td>18.60</td>\n",
       "      <td>32.5</td>\n",
       "      <td>352.0</td>\n",
       "      <td>3.37</td>\n",
       "      <td>16.5</td>\n",
       "      <td>61.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>59.4</td>\n",
       "      <td>44.1</td>\n",
       "      <td>14.60</td>\n",
       "      <td>10.20</td>\n",
       "      <td>13.00</td>\n",
       "      <td>14.90</td>\n",
       "      <td>32.5</td>\n",
       "      <td>403.0</td>\n",
       "      <td>11.40</td>\n",
       "      <td>31.9</td>\n",
       "      <td>19.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>73.0</td>\n",
       "      <td>82.8</td>\n",
       "      <td>23.10</td>\n",
       "      <td>13.50</td>\n",
       "      <td>1.29</td>\n",
       "      <td>5.83</td>\n",
       "      <td>11.8</td>\n",
       "      <td>528.0</td>\n",
       "      <td>13.70</td>\n",
       "      <td>28.6</td>\n",
       "      <td>23.1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>33.7</td>\n",
       "      <td>19.8</td>\n",
       "      <td>5.19</td>\n",
       "      <td>9.73</td>\n",
       "      <td>24.80</td>\n",
       "      <td>21.10</td>\n",
       "      <td>46.1</td>\n",
       "      <td>305.0</td>\n",
       "      <td>4.66</td>\n",
       "      <td>25.7</td>\n",
       "      <td>81.3</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   praf  pmek   plcg   PIP2   PIP3  p44/42  pakts473    PKA    PKC   P38  pjnk\n",
       "0  26.4  13.2   8.82  18.30  58.80    6.61      17.0  414.0  17.00  44.9  40.0\n",
       "1  35.9  16.5  12.30  16.80   8.13   18.60      32.5  352.0   3.37  16.5  61.5\n",
       "2  59.4  44.1  14.60  10.20  13.00   14.90      32.5  403.0  11.40  31.9  19.5\n",
       "3  73.0  82.8  23.10  13.50   1.29    5.83      11.8  528.0  13.70  28.6  23.1\n",
       "4  33.7  19.8   5.19   9.73  24.80   21.10      46.1  305.0   4.66  25.7  81.3"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from cdt.data import load_dataset\n",
    "data_sachs, graph_sachs = load_dataset(\"sachs\")\n",
    "\n",
    "data_sachs.dropna(inplace=True)\n",
    "print(data_sachs.shape)\n",
    "data_sachs.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'P38',\n",
       " 'PIP2',\n",
       " 'PIP3',\n",
       " 'PKA',\n",
       " 'PKC',\n",
       " 'p44/42',\n",
       " 'pakts473',\n",
       " 'pjnk',\n",
       " 'plcg',\n",
       " 'pmek',\n",
       " 'praf'}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "set(graph_sachs.nodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Generated by graphviz version 7.1.0 (20230121.1956)\n -->\n<!-- Pages: 1 -->\n<svg width=\"432pt\" height=\"566pt\"\n viewBox=\"0.00 0.00 432.41 566.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 562)\">\n<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-562 428.41,-562 428.41,4 -4,4\"/>\n<!-- praf -->\n<g id=\"node1\" class=\"node\">\n<title>praf</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"134.41\" cy=\"-192\" rx=\"27\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"134.41\" y=\"-188.3\" font-family=\"Times,serif\" font-size=\"14.00\">praf</text>\n</g>\n<!-- plcg -->\n<g id=\"node3\" class=\"node\">\n<title>plcg</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"197.41\" cy=\"-105\" rx=\"27\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"197.41\" y=\"-101.3\" font-family=\"Times,serif\" font-size=\"14.00\">plcg</text>\n</g>\n<!-- praf&#45;&gt;plcg -->\n<g id=\"edge7\" class=\"edge\">\n<title>praf&#45;&gt;plcg</title>\n<path fill=\"none\" stroke=\"black\" d=\"M144.32,-174.89C150.63,-164.91 159.15,-151.96 167.41,-141 170.29,-137.19 173.45,-133.26 176.61,-129.46\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"179.03,-132.02 182.86,-122.14 173.71,-127.47 179.03,-132.02\"/>\n<text text-anchor=\"middle\" x=\"176.41\" y=\"-144.8\" font-family=\"Times,serif\" font-size=\"14.00\">1.0</text>\n</g>\n<!-- pmek -->\n<g id=\"node2\" class=\"node\">\n<title>pmek</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"134.41\" cy=\"-279\" rx=\"30.59\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"134.41\" y=\"-275.3\" font-family=\"Times,serif\" font-size=\"14.00\">pmek</text>\n</g>\n<!-- pmek&#45;&gt;praf -->\n<g id=\"edge1\" 
class=\"edge\">\n<title>pmek&#45;&gt;praf</title>\n<path fill=\"none\" stroke=\"black\" d=\"M134.41,-260.8C134.41,-249.58 134.41,-234.67 134.41,-221.69\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"137.91,-221.98 134.41,-211.98 130.91,-221.98 137.91,-221.98\"/>\n<text text-anchor=\"middle\" x=\"143.41\" y=\"-231.8\" font-family=\"Times,serif\" font-size=\"14.00\">1.0</text>\n</g>\n<!-- pmek&#45;&gt;plcg -->\n<g id=\"edge8\" class=\"edge\">\n<title>pmek&#45;&gt;plcg</title>\n<path fill=\"none\" stroke=\"black\" d=\"M145.26,-261.94C149.01,-256.09 153.09,-249.36 156.41,-243 174.81,-207.76 186.24,-163.01 192.21,-134.43\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"195.62,-135.22 194.15,-124.73 188.76,-133.85 195.62,-135.22\"/>\n<text text-anchor=\"middle\" x=\"192.41\" y=\"-188.3\" font-family=\"Times,serif\" font-size=\"14.00\">1.0</text>\n</g>\n<!-- PIP2 -->\n<g id=\"node4\" class=\"node\">\n<title>PIP2</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"197.41\" cy=\"-18\" rx=\"27.9\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"197.41\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">PIP2</text>\n</g>\n<!-- plcg&#45;&gt;PIP2 -->\n<g id=\"edge9\" class=\"edge\">\n<title>plcg&#45;&gt;PIP2</title>\n<path fill=\"none\" stroke=\"black\" d=\"M197.41,-86.8C197.41,-75.58 197.41,-60.67 197.41,-47.69\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"200.91,-47.98 197.41,-37.98 193.91,-47.98 200.91,-47.98\"/>\n<text text-anchor=\"middle\" x=\"206.41\" y=\"-57.8\" font-family=\"Times,serif\" font-size=\"14.00\">1.0</text>\n</g>\n<!-- PIP2&#45;&gt;praf -->\n<g id=\"edge2\" class=\"edge\">\n<title>PIP2&#45;&gt;praf</title>\n<path fill=\"none\" stroke=\"black\" d=\"M181.86,-33.26C169.11,-45.99 151.83,-65.85 143.41,-87 133.79,-111.18 132.13,-140.96 132.53,-162.45\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"129.02,-162.36 132.89,-172.23 136.02,-162.11 129.02,-162.36\"/>\n<text text-anchor=\"middle\" x=\"152.41\" y=\"-101.3\" 
font-family=\"Times,serif\" font-size=\"14.00\">1.0</text>\n</g>\n<!-- PIP3 -->\n<g id=\"node5\" class=\"node\">\n<title>PIP3</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"226.41\" cy=\"-366\" rx=\"27.9\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"226.41\" y=\"-362.3\" font-family=\"Times,serif\" font-size=\"14.00\">PIP3</text>\n</g>\n<!-- PIP3&#45;&gt;pmek -->\n<g id=\"edge3\" class=\"edge\">\n<title>PIP3&#45;&gt;pmek</title>\n<path fill=\"none\" stroke=\"black\" d=\"M206.53,-352.81C196.98,-346.54 185.63,-338.44 176.41,-330 168.06,-322.35 159.84,-313.06 152.91,-304.58\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"155.84,-302.65 146.89,-296.98 150.36,-306.99 155.84,-302.65\"/>\n<text text-anchor=\"middle\" x=\"185.41\" y=\"-318.8\" font-family=\"Times,serif\" font-size=\"14.00\">1.0</text>\n</g>\n<!-- pakts473 -->\n<g id=\"node7\" class=\"node\">\n<title>pakts473</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"273.41\" cy=\"-279\" rx=\"42.49\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"273.41\" y=\"-275.3\" font-family=\"Times,serif\" font-size=\"14.00\">pakts473</text>\n</g>\n<!-- PIP3&#45;&gt;pakts473 -->\n<g id=\"edge11\" class=\"edge\">\n<title>PIP3&#45;&gt;pakts473</title>\n<path fill=\"none\" stroke=\"black\" d=\"M235.48,-348.61C242.02,-336.77 251.01,-320.52 258.6,-306.8\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"261.55,-308.69 263.33,-298.24 255.42,-305.3 261.55,-308.69\"/>\n<text text-anchor=\"middle\" x=\"262.41\" y=\"-318.8\" font-family=\"Times,serif\" font-size=\"14.00\">1.0</text>\n</g>\n<!-- p44/42 -->\n<g id=\"node6\" class=\"node\">\n<title>p44/42</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"145.41\" cy=\"-366\" rx=\"35.19\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"145.41\" y=\"-362.3\" font-family=\"Times,serif\" font-size=\"14.00\">p44/42</text>\n</g>\n<!-- p44/42&#45;&gt;pmek -->\n<g id=\"edge4\" class=\"edge\">\n<title>p44/42&#45;&gt;pmek</title>\n<path fill=\"none\" stroke=\"black\" 
d=\"M139.63,-348.1C137.95,-342.42 136.32,-336 135.41,-330 134.37,-323.07 133.89,-315.54 133.71,-308.47\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"137.21,-308.72 133.67,-298.73 130.21,-308.74 137.21,-308.72\"/>\n<text text-anchor=\"middle\" x=\"144.41\" y=\"-318.8\" font-family=\"Times,serif\" font-size=\"14.00\">1.0</text>\n</g>\n<!-- p44/42&#45;&gt;pakts473 -->\n<g id=\"edge12\" class=\"edge\">\n<title>p44/42&#45;&gt;pakts473</title>\n<path fill=\"none\" stroke=\"black\" d=\"M166.26,-351.15C186.72,-337.57 218.18,-316.68 241.73,-301.04\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"243.63,-303.98 250.02,-295.53 239.76,-298.15 243.63,-303.98\"/>\n<text text-anchor=\"middle\" x=\"226.41\" y=\"-318.8\" font-family=\"Times,serif\" font-size=\"14.00\">1.0</text>\n</g>\n<!-- PKA -->\n<g id=\"node8\" class=\"node\">\n<title>PKA</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"63.41\" cy=\"-366\" rx=\"28.7\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"63.41\" y=\"-362.3\" font-family=\"Times,serif\" font-size=\"14.00\">PKA</text>\n</g>\n<!-- PKA&#45;&gt;pmek -->\n<g id=\"edge5\" class=\"edge\">\n<title>PKA&#45;&gt;pmek</title>\n<path fill=\"none\" stroke=\"black\" d=\"M54.33,-348.51C49.94,-338.15 46.9,-324.91 53.41,-315 62.55,-301.09 78.52,-292.69 93.88,-287.62\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"94.74,-291.01 103.39,-284.9 92.81,-284.28 94.74,-291.01\"/>\n<text text-anchor=\"middle\" x=\"62.41\" y=\"-318.8\" font-family=\"Times,serif\" font-size=\"14.00\">1.0</text>\n</g>\n<!-- PKA&#45;&gt;pakts473 -->\n<g id=\"edge13\" class=\"edge\">\n<title>PKA&#45;&gt;pakts473</title>\n<path fill=\"none\" stroke=\"black\" d=\"M74.92,-349.05C83.88,-337.82 97.31,-323.35 112.41,-315 122.21,-309.59 180.21,-297.71 223.81,-289.3\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"224.46,-292.74 233.62,-287.42 223.14,-285.86 224.46,-292.74\"/>\n<text text-anchor=\"middle\" x=\"121.41\" y=\"-318.8\" font-family=\"Times,serif\" 
font-size=\"14.00\">1.0</text>\n</g>\n<!-- PKC -->\n<g id=\"node9\" class=\"node\">\n<title>PKC</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"220.41\" cy=\"-453\" rx=\"27.9\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"220.41\" y=\"-449.3\" font-family=\"Times,serif\" font-size=\"14.00\">PKC</text>\n</g>\n<!-- PKC&#45;&gt;pmek -->\n<g id=\"edge6\" class=\"edge\">\n<title>PKC&#45;&gt;pmek</title>\n<path fill=\"none\" stroke=\"black\" d=\"M192.05,-452.62C144.65,-451.93 50.35,-442.77 7.41,-384 -22.31,-343.31 49.62,-309.21 96.95,-292.01\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"98.09,-295.32 106.37,-288.71 95.78,-288.72 98.09,-295.32\"/>\n<text text-anchor=\"middle\" x=\"16.41\" y=\"-362.3\" font-family=\"Times,serif\" font-size=\"14.00\">1.0</text>\n</g>\n<!-- PKC&#45;&gt;pakts473 -->\n<g id=\"edge14\" class=\"edge\">\n<title>PKC&#45;&gt;pakts473</title>\n<path fill=\"none\" stroke=\"black\" d=\"M232.51,-436.38C242.21,-423.12 255.48,-403.24 263.41,-384 272.79,-361.27 272.73,-354.44 275.41,-330 276.18,-323.03 276.26,-315.48 276.02,-308.42\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"279.53,-308.46 275.45,-298.68 272.54,-308.87 279.53,-308.46\"/>\n<text text-anchor=\"middle\" x=\"282.41\" y=\"-362.3\" font-family=\"Times,serif\" font-size=\"14.00\">1.0</text>\n</g>\n<!-- PKC&#45;&gt;PKA -->\n<g id=\"edge17\" class=\"edge\">\n<title>PKC&#45;&gt;PKA</title>\n<path fill=\"none\" stroke=\"black\" d=\"M199.57,-440.71C172.66,-426.15 125.62,-400.68 94.49,-383.83\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"96.39,-380.87 85.93,-379.19 93.06,-387.03 96.39,-380.87\"/>\n<text text-anchor=\"middle\" x=\"160.41\" y=\"-405.8\" font-family=\"Times,serif\" font-size=\"14.00\">1.0</text>\n</g>\n<!-- P38 -->\n<g id=\"node10\" class=\"node\">\n<title>P38</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"278.41\" cy=\"-540\" rx=\"27\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"278.41\" y=\"-536.3\" font-family=\"Times,serif\" 
font-size=\"14.00\">P38</text>\n</g>\n<!-- P38&#45;&gt;pakts473 -->\n<g id=\"edge15\" class=\"edge\">\n<title>P38&#45;&gt;pakts473</title>\n<path fill=\"none\" stroke=\"black\" d=\"M281.82,-521.89C288.06,-488.49 300.25,-412.24 295.41,-348 294.3,-333.21 294.82,-329.17 290.41,-315 289.62,-312.44 288.66,-309.84 287.61,-307.26\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"290.91,-306.05 283.58,-298.4 284.54,-308.95 290.91,-306.05\"/>\n<text text-anchor=\"middle\" x=\"305.41\" y=\"-405.8\" font-family=\"Times,serif\" font-size=\"14.00\">1.0</text>\n</g>\n<!-- P38&#45;&gt;PKC -->\n<g id=\"edge18\" class=\"edge\">\n<title>P38&#45;&gt;PKC</title>\n<path fill=\"none\" stroke=\"black\" d=\"M267.78,-523.41C259.33,-511.03 247.34,-493.47 237.54,-479.11\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"240.56,-477.32 232.03,-471.03 234.78,-481.26 240.56,-477.32\"/>\n<text text-anchor=\"middle\" x=\"262.41\" y=\"-492.8\" font-family=\"Times,serif\" font-size=\"14.00\">1.0</text>\n</g>\n<!-- pjnk -->\n<g id=\"node11\" class=\"node\">\n<title>pjnk</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"397.41\" cy=\"-366\" rx=\"27\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"397.41\" y=\"-362.3\" font-family=\"Times,serif\" font-size=\"14.00\">pjnk</text>\n</g>\n<!-- pjnk&#45;&gt;PIP2 -->\n<g id=\"edge10\" class=\"edge\">\n<title>pjnk&#45;&gt;PIP2</title>\n<path fill=\"none\" stroke=\"black\" d=\"M399.02,-347.76C400.48,-330.6 402.41,-303.5 402.41,-280 402.41,-280 402.41,-280 402.41,-104 402.41,-33.02 296.26,-20.18 236.73,-18.56\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"237.06,-15.07 227,-18.4 236.94,-22.06 237.06,-15.07\"/>\n<text text-anchor=\"middle\" x=\"411.41\" y=\"-188.3\" font-family=\"Times,serif\" font-size=\"14.00\">1.0</text>\n</g>\n<!-- pjnk&#45;&gt;pakts473 -->\n<g id=\"edge16\" class=\"edge\">\n<title>pjnk&#45;&gt;pakts473</title>\n<path fill=\"none\" stroke=\"black\" d=\"M384.98,-350.06C375.49,-339.32 361.67,-325.04 347.41,-315 
337.38,-307.93 325.69,-301.65 314.6,-296.42\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"316.17,-293.29 305.61,-292.37 313.29,-299.67 316.17,-293.29\"/>\n<text text-anchor=\"middle\" x=\"374.41\" y=\"-318.8\" font-family=\"Times,serif\" font-size=\"14.00\">1.0</text>\n</g>\n</g>\n</svg>\n",
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7ffb8d59c950>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import graphviz\n",
    "import networkx as nx \n",
    "\n",
    "def make_graph(adjacency_matrix, labels=None):\n",
    "    idx = np.abs(adjacency_matrix) > 0.01\n",
    "    dirs = np.where(idx)\n",
    "    d = graphviz.Digraph(engine='dot')\n",
    "    names = labels if labels else [f'x{i}' for i in range(len(adjacency_matrix))]\n",
    "    for name in names:\n",
    "        d.node(name)\n",
    "    for to, from_, coef in zip(dirs[0], dirs[1], adjacency_matrix[idx]):\n",
    "        d.edge(names[from_], names[to], label=str(coef))\n",
    "    return d\n",
    "\n",
    "labels = [f'{col}' for i, col in enumerate(data_sachs.columns)]\n",
    "adj_matrix = nx.to_numpy_matrix(graph_sachs)\n",
    "adj_matrix = np.asarray(adj_matrix)\n",
    "graph_dot = make_graph(adj_matrix, labels)\n",
    "display(graph_dot)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x1 0 True\n",
      "x2 0 True\n",
      "x3 0 True\n",
      "x4 0 True\n",
      "x5 0 True\n",
      "x6 0 True\n",
      "x22 0 True\n",
      "x23 0 True\n",
      "x24 0 True\n",
      "x25 0 True\n",
      "x26 0 True\n",
      "x27 0 True\n",
      "x7 3 False\n",
      "x8 3 False\n",
      "x9 3 False\n",
      "x10 3 False\n",
      "x11 3 False\n",
      "x12 3 False\n",
      "x16 6 False\n",
      "x17 6 False\n",
      "x18 6 False\n",
      "x19 6 False\n",
      "x20 6 False\n",
      "x21 6 False\n",
      "x13 6 False\n",
      "x14 6 False\n",
      "x15 6 False\n",
      "x28 9 False\n",
      "x29 9 False\n",
      "x30 9 False\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "for node in g_sort:\n",
    "    print(node, g.in_degree(node), is_root_node(g,node))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data, graph = load_dataset(\"sachs\")\n",
    "output = obj.orient_graph(data, nx.Graph(graph))\n",
    "\n",
    "        #To view the directed graph run the following command\n",
    "nx.draw_networkx(output, font_size=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No GPU automatically detected. Setting SETTINGS.GPU to 0, and SETTINGS.NJOBS to cpu_count.\n",
      "Fitting causal mechanism of node pakts473: 100%|██████████| 11/11 [00:00<00:00, 629.33it/s]\n",
      "overflow encountered in exp\n",
      "invalid value encountered in true_divide\n",
      "Fitting causal mechanism of node pakts473: 100%|██████████| 11/11 [00:02<00:00,  5.17it/s]\n"
     ]
    }
   ],
   "source": [
    "structural_equations, noise_distributions = select_struct_and_noise(equations_type, scm_type, weights, g)\n",
    "exper_model = ExperimentationModel(g, scm_type, structural_equations, noise_distributions)\n",
    "factual, noise = exper_model.sample(n)\n",
    "\n",
    "params = {'num_epochs' : 200,\n",
    "          'lr' : 1e-4,\n",
    "          'batch_size': 64,\n",
    "          'hidden_dim' : 64}\n",
    "\n",
    "diff_model = create_diff_model(scm_type, params, g)\n",
    "\n",
    "cy.fit(diff_model, factual)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<dowhy.gcm.stochastic_models.EmpiricalDistribution at 0x7fac539a9b90>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "noise_distributions[\"plcg\"] "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "c = noise_distributions[\"plcg\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'EmpiricalDistribution'"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[4.]])"
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "structural_equations['x4'](np.array([10]),np.array([5]),np.array([6]),np.array([7]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>x3</th>\n",
       "      <th>x4</th>\n",
       "      <th>x5</th>\n",
       "      <th>x6</th>\n",
       "      <th>x7</th>\n",
       "      <th>x8</th>\n",
       "      <th>x9</th>\n",
       "      <th>x10</th>\n",
       "      <th>...</th>\n",
       "      <th>x21</th>\n",
       "      <th>x22</th>\n",
       "      <th>x23</th>\n",
       "      <th>x24</th>\n",
       "      <th>x25</th>\n",
       "      <th>x26</th>\n",
       "      <th>x27</th>\n",
       "      <th>x28</th>\n",
       "      <th>x29</th>\n",
       "      <th>x30</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.280730</td>\n",
       "      <td>-0.736336</td>\n",
       "      <td>-0.518266</td>\n",
       "      <td>1.514953</td>\n",
       "      <td>0.003381</td>\n",
       "      <td>-0.774351</td>\n",
       "      <td>7.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>...</td>\n",
       "      <td>21.0</td>\n",
       "      <td>0.322201</td>\n",
       "      <td>-2.600692</td>\n",
       "      <td>0.839476</td>\n",
       "      <td>-0.748764</td>\n",
       "      <td>-1.151320</td>\n",
       "      <td>0.334883</td>\n",
       "      <td>28.0</td>\n",
       "      <td>29.0</td>\n",
       "      <td>30.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1.194791</td>\n",
       "      <td>1.218467</td>\n",
       "      <td>0.282073</td>\n",
       "      <td>0.753858</td>\n",
       "      <td>1.779787</td>\n",
       "      <td>1.026975</td>\n",
       "      <td>7.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>...</td>\n",
       "      <td>21.0</td>\n",
       "      <td>0.815590</td>\n",
       "      <td>-0.595904</td>\n",
       "      <td>2.049084</td>\n",
       "      <td>-1.337861</td>\n",
       "      <td>-0.709285</td>\n",
       "      <td>-0.180292</td>\n",
       "      <td>28.0</td>\n",
       "      <td>29.0</td>\n",
       "      <td>30.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1.019362</td>\n",
       "      <td>0.613177</td>\n",
       "      <td>-0.366826</td>\n",
       "      <td>-1.333737</td>\n",
       "      <td>1.180799</td>\n",
       "      <td>-0.251376</td>\n",
       "      <td>7.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>...</td>\n",
       "      <td>21.0</td>\n",
       "      <td>0.754049</td>\n",
       "      <td>1.723956</td>\n",
       "      <td>-1.086761</td>\n",
       "      <td>0.706095</td>\n",
       "      <td>0.061811</td>\n",
       "      <td>-0.810736</td>\n",
       "      <td>28.0</td>\n",
       "      <td>29.0</td>\n",
       "      <td>30.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>-1.977910</td>\n",
       "      <td>-0.401665</td>\n",
       "      <td>-0.952598</td>\n",
       "      <td>-0.314029</td>\n",
       "      <td>-1.125427</td>\n",
       "      <td>-1.549231</td>\n",
       "      <td>7.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>...</td>\n",
       "      <td>21.0</td>\n",
       "      <td>0.432008</td>\n",
       "      <td>-0.571919</td>\n",
       "      <td>-1.608113</td>\n",
       "      <td>0.291620</td>\n",
       "      <td>-0.576826</td>\n",
       "      <td>0.125973</td>\n",
       "      <td>28.0</td>\n",
       "      <td>29.0</td>\n",
       "      <td>30.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>-0.734826</td>\n",
       "      <td>-0.803476</td>\n",
       "      <td>0.327771</td>\n",
       "      <td>-1.011682</td>\n",
       "      <td>0.566823</td>\n",
       "      <td>-0.414858</td>\n",
       "      <td>7.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>...</td>\n",
       "      <td>21.0</td>\n",
       "      <td>-1.754055</td>\n",
       "      <td>1.487604</td>\n",
       "      <td>0.105887</td>\n",
       "      <td>-0.226048</td>\n",
       "      <td>1.236473</td>\n",
       "      <td>0.719216</td>\n",
       "      <td>28.0</td>\n",
       "      <td>29.0</td>\n",
       "      <td>30.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 30 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "         x1        x2        x3        x4        x5        x6   x7   x8   x9  \\\n",
       "0  1.280730 -0.736336 -0.518266  1.514953  0.003381 -0.774351  7.0  8.0  9.0   \n",
       "1  1.194791  1.218467  0.282073  0.753858  1.779787  1.026975  7.0  8.0  9.0   \n",
       "2  1.019362  0.613177 -0.366826 -1.333737  1.180799 -0.251376  7.0  8.0  9.0   \n",
       "3 -1.977910 -0.401665 -0.952598 -0.314029 -1.125427 -1.549231  7.0  8.0  9.0   \n",
       "4 -0.734826 -0.803476  0.327771 -1.011682  0.566823 -0.414858  7.0  8.0  9.0   \n",
       "\n",
       "    x10  ...   x21       x22       x23       x24       x25       x26  \\\n",
       "0  10.0  ...  21.0  0.322201 -2.600692  0.839476 -0.748764 -1.151320   \n",
       "1  10.0  ...  21.0  0.815590 -0.595904  2.049084 -1.337861 -0.709285   \n",
       "2  10.0  ...  21.0  0.754049  1.723956 -1.086761  0.706095  0.061811   \n",
       "3  10.0  ...  21.0  0.432008 -0.571919 -1.608113  0.291620 -0.576826   \n",
       "4  10.0  ...  21.0 -1.754055  1.487604  0.105887 -0.226048  1.236473   \n",
       "\n",
       "        x27   x28   x29   x30  \n",
       "0  0.334883  28.0  29.0  30.0  \n",
       "1 -0.180292  28.0  29.0  30.0  \n",
       "2 -0.810736  28.0  29.0  30.0  \n",
       "3  0.125973  28.0  29.0  30.0  \n",
       "4  0.719216  28.0  29.0  30.0  \n",
       "\n",
       "[5 rows x 30 columns]"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "factual, noise = reindex_columns(column_order,exper_model.sample(n))\n",
    "factual.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "column_order  = ['x'+str(i+1) for i in range(30)]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def reindex_columns(column_order, *dfs):\n",
    "    if type(dfs) is tuple:\n",
    "        dfs = dfs[0]\n",
    "    result = [df.reindex(columns = column_order) for df in dfs]\n",
    "    if len(result) == 1:\n",
    "        return result[0]\n",
    "    else:\n",
    "        return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'factual' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m/var/folders/p2/qjns7zj118l97lkbqykpkg7c0000gn/T/ipykernel_40166/3849623792.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mfactual\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdescribe\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m: name 'factual' is not defined"
     ]
    }
   ],
   "source": [
    "# Per-column summary statistics of the factual sample\n",
    "factual.describe(include=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Learn the causal mechanisms of diff_model from the factual data\n",
    "cy.fit(causal_model=diff_model, data=factual)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 1\n",
      "0 2\n",
      "0 3\n",
      "0 4\n",
      "0 5\n",
      "0 6\n",
      "0 7\n",
      "0 8\n",
      "0 9\n",
      "1 0\n",
      "1 2\n",
      "1 3\n",
      "1 4\n",
      "1 5\n",
      "1 6\n",
      "1 7\n",
      "1 8\n",
      "1 9\n",
      "2 0\n",
      "2 1\n",
      "2 3\n",
      "2 4\n",
      "2 5\n",
      "2 6\n",
      "2 7\n",
      "2 8\n",
      "2 9\n",
      "3 0\n",
      "3 1\n",
      "3 2\n",
      "3 4\n",
      "3 5\n",
      "3 6\n",
      "3 7\n",
      "3 8\n",
      "3 9\n",
      "4 0\n",
      "4 1\n",
      "4 2\n",
      "4 3\n",
      "4 5\n",
      "4 6\n",
      "4 7\n",
      "4 8\n",
      "4 9\n",
      "5 0\n",
      "5 1\n",
      "5 2\n",
      "5 3\n",
      "5 4\n",
      "5 6\n",
      "5 7\n",
      "5 8\n",
      "5 9\n",
      "6 0\n",
      "6 1\n",
      "6 2\n",
      "6 3\n",
      "6 4\n",
      "6 5\n",
      "6 7\n",
      "6 8\n",
      "6 9\n",
      "7 0\n",
      "7 1\n",
      "7 2\n",
      "7 3\n",
      "7 4\n",
      "7 5\n",
      "7 6\n",
      "7 8\n",
      "7 9\n",
      "8 0\n",
      "8 1\n",
      "8 2\n",
      "8 3\n",
      "8 4\n",
      "8 5\n",
      "8 6\n",
      "8 7\n",
      "8 9\n",
      "9 0\n",
      "9 1\n",
      "9 2\n",
      "9 3\n",
      "9 4\n",
      "9 5\n",
      "9 6\n",
      "9 7\n",
      "9 8\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): p=2 lies outside [0, 1]; networkx returns the complete\n",
    "# digraph for p >= 1, so every ordered pair (i, j) with i != j is printed.\n",
    "G = nx.gnp_random_graph(n=10, p=2, directed=True)\n",
    "for i, j in G.edges():\n",
    "    print(i, j)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>x3</th>\n",
       "      <th>x4</th>\n",
       "      <th>x5</th>\n",
       "      <th>x6</th>\n",
       "      <th>x7</th>\n",
       "      <th>x8</th>\n",
       "      <th>x9</th>\n",
       "      <th>x10</th>\n",
       "      <th>...</th>\n",
       "      <th>x21</th>\n",
       "      <th>x22</th>\n",
       "      <th>x23</th>\n",
       "      <th>x24</th>\n",
       "      <th>x25</th>\n",
       "      <th>x26</th>\n",
       "      <th>x27</th>\n",
       "      <th>x28</th>\n",
       "      <th>x29</th>\n",
       "      <th>x30</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.380432</td>\n",
       "      <td>-0.054595</td>\n",
       "      <td>0.238639</td>\n",
       "      <td>2</td>\n",
       "      <td>-0.612684</td>\n",
       "      <td>-0.156689</td>\n",
       "      <td>-0.262524</td>\n",
       "      <td>1.210203</td>\n",
       "      <td>0.843139</td>\n",
       "      <td>-0.775936</td>\n",
       "      <td>...</td>\n",
       "      <td>-1.116675</td>\n",
       "      <td>2.053635</td>\n",
       "      <td>1.602174</td>\n",
       "      <td>2.086162</td>\n",
       "      <td>1.822321</td>\n",
       "      <td>2.006210</td>\n",
       "      <td>2.046672</td>\n",
       "      <td>0.717799</td>\n",
       "      <td>0.399789</td>\n",
       "      <td>1.293791</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>-0.955347</td>\n",
       "      <td>0.992932</td>\n",
       "      <td>0.383092</td>\n",
       "      <td>2</td>\n",
       "      <td>0.112514</td>\n",
       "      <td>0.337923</td>\n",
       "      <td>-0.508810</td>\n",
       "      <td>-0.669491</td>\n",
       "      <td>-1.309308</td>\n",
       "      <td>0.117035</td>\n",
       "      <td>...</td>\n",
       "      <td>-6.247526</td>\n",
       "      <td>-9.416972</td>\n",
       "      <td>-9.309222</td>\n",
       "      <td>-9.414317</td>\n",
       "      <td>12.078979</td>\n",
       "      <td>12.034266</td>\n",
       "      <td>11.978007</td>\n",
       "      <td>16.915749</td>\n",
       "      <td>17.479839</td>\n",
       "      <td>17.401810</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.401062</td>\n",
       "      <td>-1.059337</td>\n",
       "      <td>0.419592</td>\n",
       "      <td>2</td>\n",
       "      <td>0.127034</td>\n",
       "      <td>-0.751926</td>\n",
       "      <td>1.074784</td>\n",
       "      <td>-2.552156</td>\n",
       "      <td>1.899485</td>\n",
       "      <td>-2.357856</td>\n",
       "      <td>...</td>\n",
       "      <td>-1.801044</td>\n",
       "      <td>1.467458</td>\n",
       "      <td>1.333084</td>\n",
       "      <td>1.507125</td>\n",
       "      <td>4.350846</td>\n",
       "      <td>4.356070</td>\n",
       "      <td>4.359715</td>\n",
       "      <td>3.635096</td>\n",
       "      <td>4.254907</td>\n",
       "      <td>3.310828</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>-0.515384</td>\n",
       "      <td>0.204643</td>\n",
       "      <td>-0.595731</td>\n",
       "      <td>2</td>\n",
       "      <td>-0.206072</td>\n",
       "      <td>0.421071</td>\n",
       "      <td>0.780384</td>\n",
       "      <td>0.919033</td>\n",
       "      <td>1.213293</td>\n",
       "      <td>-0.737511</td>\n",
       "      <td>...</td>\n",
       "      <td>0.529430</td>\n",
       "      <td>-4.374533</td>\n",
       "      <td>-4.579927</td>\n",
       "      <td>-5.191138</td>\n",
       "      <td>0.174204</td>\n",
       "      <td>0.173448</td>\n",
       "      <td>1.273415</td>\n",
       "      <td>6.900311</td>\n",
       "      <td>5.098008</td>\n",
       "      <td>5.988788</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>-1.313250</td>\n",
       "      <td>-0.658888</td>\n",
       "      <td>-0.375854</td>\n",
       "      <td>2</td>\n",
       "      <td>0.007002</td>\n",
       "      <td>0.270937</td>\n",
       "      <td>-0.790399</td>\n",
       "      <td>-0.058367</td>\n",
       "      <td>-0.749085</td>\n",
       "      <td>0.864503</td>\n",
       "      <td>...</td>\n",
       "      <td>-1.230122</td>\n",
       "      <td>0.903154</td>\n",
       "      <td>0.750776</td>\n",
       "      <td>0.917787</td>\n",
       "      <td>2.752282</td>\n",
       "      <td>2.621456</td>\n",
       "      <td>2.630431</td>\n",
       "      <td>1.674232</td>\n",
       "      <td>1.845793</td>\n",
       "      <td>1.755621</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 30 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "         x1        x2        x3  x4        x5        x6        x7        x8  \\\n",
       "0  1.380432 -0.054595  0.238639   2 -0.612684 -0.156689 -0.262524  1.210203   \n",
       "1 -0.955347  0.992932  0.383092   2  0.112514  0.337923 -0.508810 -0.669491   \n",
       "2  0.401062 -1.059337  0.419592   2  0.127034 -0.751926  1.074784 -2.552156   \n",
       "3 -0.515384  0.204643 -0.595731   2 -0.206072  0.421071  0.780384  0.919033   \n",
       "4 -1.313250 -0.658888 -0.375854   2  0.007002  0.270937 -0.790399 -0.058367   \n",
       "\n",
       "         x9       x10  ...       x21       x22       x23       x24        x25  \\\n",
       "0  0.843139 -0.775936  ... -1.116675  2.053635  1.602174  2.086162   1.822321   \n",
       "1 -1.309308  0.117035  ... -6.247526 -9.416972 -9.309222 -9.414317  12.078979   \n",
       "2  1.899485 -2.357856  ... -1.801044  1.467458  1.333084  1.507125   4.350846   \n",
       "3  1.213293 -0.737511  ...  0.529430 -4.374533 -4.579927 -5.191138   0.174204   \n",
       "4 -0.749085  0.864503  ... -1.230122  0.903154  0.750776  0.917787   2.752282   \n",
       "\n",
       "         x26        x27        x28        x29        x30  \n",
       "0   2.006210   2.046672   0.717799   0.399789   1.293791  \n",
       "1  12.034266  11.978007  16.915749  17.479839  17.401810  \n",
       "2   4.356070   4.359715   3.635096   4.254907   3.310828  \n",
       "3   0.173448   1.273415   6.900311   5.098008   5.988788  \n",
       "4   2.621456   2.630431   1.674232   1.845793   1.755621  \n",
       "\n",
       "[5 rows x 30 columns]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Constant intervention do(x4 := 2) applied to every observed sample\n",
    "intervention = {\"x4\": lambda x: 2}\n",
    "cf = counterfactual_samples(diff_model, intervention, observed_data=factual)\n",
    "cf.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>x3</th>\n",
       "      <th>x4</th>\n",
       "      <th>x5</th>\n",
       "      <th>x6</th>\n",
       "      <th>x7</th>\n",
       "      <th>x8</th>\n",
       "      <th>x9</th>\n",
       "      <th>x10</th>\n",
       "      <th>...</th>\n",
       "      <th>x21</th>\n",
       "      <th>x22</th>\n",
       "      <th>x23</th>\n",
       "      <th>x24</th>\n",
       "      <th>x25</th>\n",
       "      <th>x26</th>\n",
       "      <th>x27</th>\n",
       "      <th>x28</th>\n",
       "      <th>x29</th>\n",
       "      <th>x30</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.380432</td>\n",
       "      <td>-0.054595</td>\n",
       "      <td>0.238639</td>\n",
       "      <td>2</td>\n",
       "      <td>-0.610865</td>\n",
       "      <td>-0.154493</td>\n",
       "      <td>-0.271134</td>\n",
       "      <td>1.200994</td>\n",
       "      <td>0.840306</td>\n",
       "      <td>0.927244</td>\n",
       "      <td>...</td>\n",
       "      <td>-1.143358</td>\n",
       "      <td>2.082734</td>\n",
       "      <td>1.403598</td>\n",
       "      <td>2.179997</td>\n",
       "      <td>1.858397</td>\n",
       "      <td>2.030448</td>\n",
       "      <td>2.061045</td>\n",
       "      <td>0.715350</td>\n",
       "      <td>0.360021</td>\n",
       "      <td>1.253323</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>-0.955347</td>\n",
       "      <td>0.992932</td>\n",
       "      <td>0.383092</td>\n",
       "      <td>2</td>\n",
       "      <td>0.110149</td>\n",
       "      <td>0.337573</td>\n",
       "      <td>-0.514662</td>\n",
       "      <td>-0.672042</td>\n",
       "      <td>-1.306220</td>\n",
       "      <td>1.346642</td>\n",
       "      <td>...</td>\n",
       "      <td>-6.241237</td>\n",
       "      <td>-9.424867</td>\n",
       "      <td>-9.361702</td>\n",
       "      <td>-9.388106</td>\n",
       "      <td>12.071120</td>\n",
       "      <td>12.010654</td>\n",
       "      <td>11.969192</td>\n",
       "      <td>16.917230</td>\n",
       "      <td>17.504717</td>\n",
       "      <td>17.530664</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.401062</td>\n",
       "      <td>-1.059337</td>\n",
       "      <td>0.419592</td>\n",
       "      <td>2</td>\n",
       "      <td>0.124072</td>\n",
       "      <td>-0.747213</td>\n",
       "      <td>1.072889</td>\n",
       "      <td>-2.547284</td>\n",
       "      <td>1.893577</td>\n",
       "      <td>0.763354</td>\n",
       "      <td>...</td>\n",
       "      <td>-1.829005</td>\n",
       "      <td>-3.558096</td>\n",
       "      <td>-3.536239</td>\n",
       "      <td>-3.609093</td>\n",
       "      <td>4.372499</td>\n",
       "      <td>4.379045</td>\n",
       "      <td>4.368882</td>\n",
       "      <td>6.284487</td>\n",
       "      <td>6.677610</td>\n",
       "      <td>6.156301</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>-0.515384</td>\n",
       "      <td>0.204643</td>\n",
       "      <td>-0.595731</td>\n",
       "      <td>2</td>\n",
       "      <td>-0.207191</td>\n",
       "      <td>0.422201</td>\n",
       "      <td>0.778857</td>\n",
       "      <td>0.910243</td>\n",
       "      <td>1.206585</td>\n",
       "      <td>1.146851</td>\n",
       "      <td>...</td>\n",
       "      <td>0.484317</td>\n",
       "      <td>-4.066872</td>\n",
       "      <td>-4.353944</td>\n",
       "      <td>-4.930606</td>\n",
       "      <td>0.210748</td>\n",
       "      <td>0.209979</td>\n",
       "      <td>1.288528</td>\n",
       "      <td>6.387726</td>\n",
       "      <td>4.610743</td>\n",
       "      <td>5.509547</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>-1.313250</td>\n",
       "      <td>-0.658888</td>\n",
       "      <td>-0.375854</td>\n",
       "      <td>2</td>\n",
       "      <td>0.004302</td>\n",
       "      <td>0.272771</td>\n",
       "      <td>-0.795434</td>\n",
       "      <td>-0.065140</td>\n",
       "      <td>-0.748678</td>\n",
       "      <td>1.398704</td>\n",
       "      <td>...</td>\n",
       "      <td>-1.256000</td>\n",
       "      <td>0.314371</td>\n",
       "      <td>0.140731</td>\n",
       "      <td>0.231219</td>\n",
       "      <td>2.782066</td>\n",
       "      <td>2.642280</td>\n",
       "      <td>2.640446</td>\n",
       "      <td>1.966016</td>\n",
       "      <td>2.187673</td>\n",
       "      <td>2.153543</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 30 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "         x1        x2        x3  x4        x5        x6        x7        x8  \\\n",
       "0  1.380432 -0.054595  0.238639   2 -0.610865 -0.154493 -0.271134  1.200994   \n",
       "1 -0.955347  0.992932  0.383092   2  0.110149  0.337573 -0.514662 -0.672042   \n",
       "2  0.401062 -1.059337  0.419592   2  0.124072 -0.747213  1.072889 -2.547284   \n",
       "3 -0.515384  0.204643 -0.595731   2 -0.207191  0.422201  0.778857  0.910243   \n",
       "4 -1.313250 -0.658888 -0.375854   2  0.004302  0.272771 -0.795434 -0.065140   \n",
       "\n",
       "         x9       x10  ...       x21       x22       x23       x24        x25  \\\n",
       "0  0.840306  0.927244  ... -1.143358  2.082734  1.403598  2.179997   1.858397   \n",
       "1 -1.306220  1.346642  ... -6.241237 -9.424867 -9.361702 -9.388106  12.071120   \n",
       "2  1.893577  0.763354  ... -1.829005 -3.558096 -3.536239 -3.609093   4.372499   \n",
       "3  1.206585  1.146851  ...  0.484317 -4.066872 -4.353944 -4.930606   0.210748   \n",
       "4 -0.748678  1.398704  ... -1.256000  0.314371  0.140731  0.231219   2.782066   \n",
       "\n",
       "         x26        x27        x28        x29        x30  \n",
       "0   2.030448   2.061045   0.715350   0.360021   1.253323  \n",
       "1  12.010654  11.969192  16.917230  17.504717  17.530664  \n",
       "2   4.379045   4.368882   6.284487   6.677610   6.156301  \n",
       "3   0.209979   1.288528   6.387726   4.610743   5.509547  \n",
       "4   2.642280   2.640446   1.966016   2.187673   2.153543  \n",
       "\n",
       "[5 rows x 30 columns]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reference counterfactuals from the data-generating model, reusing the\n",
    "# noise returned alongside `factual` (presumably the true exogenous noise).\n",
    "gt_cf = exper_model.get_counterfactuals(intervention, noise)\n",
    "gt_cf.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x1_1</th>\n",
       "      <th>x1_2</th>\n",
       "      <th>x1_3</th>\n",
       "      <th>x2_1</th>\n",
       "      <th>x2_2</th>\n",
       "      <th>x2_3</th>\n",
       "      <th>x3_1</th>\n",
       "      <th>x3_2</th>\n",
       "      <th>x3_3</th>\n",
       "      <th>x4_1</th>\n",
       "      <th>...</th>\n",
       "      <th>x7_3</th>\n",
       "      <th>x8_1</th>\n",
       "      <th>x8_2</th>\n",
       "      <th>x8_3</th>\n",
       "      <th>x9_1</th>\n",
       "      <th>x9_2</th>\n",
       "      <th>x9_3</th>\n",
       "      <th>x10_1</th>\n",
       "      <th>x10_2</th>\n",
       "      <th>x10_3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>-0.103733</td>\n",
       "      <td>-0.033400</td>\n",
       "      <td>0.566829</td>\n",
       "      <td>0.876375</td>\n",
       "      <td>1.263319</td>\n",
       "      <td>1.791450</td>\n",
       "      <td>0.493246</td>\n",
       "      <td>0.482372</td>\n",
       "      <td>0.678418</td>\n",
       "      <td>0.226348</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.633237</td>\n",
       "      <td>0.424472</td>\n",
       "      <td>-0.240294</td>\n",
       "      <td>-0.273037</td>\n",
       "      <td>-1.020092</td>\n",
       "      <td>1.890594</td>\n",
       "      <td>0.411665</td>\n",
       "      <td>-0.500063</td>\n",
       "      <td>2.160795</td>\n",
       "      <td>2.434649</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.667339</td>\n",
       "      <td>2.225528</td>\n",
       "      <td>0.741775</td>\n",
       "      <td>3.976137</td>\n",
       "      <td>3.936624</td>\n",
       "      <td>3.905270</td>\n",
       "      <td>2.312294</td>\n",
       "      <td>3.734514</td>\n",
       "      <td>3.071793</td>\n",
       "      <td>-0.572782</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.611924</td>\n",
       "      <td>-0.933952</td>\n",
       "      <td>-0.948881</td>\n",
       "      <td>-0.926097</td>\n",
       "      <td>1.574893</td>\n",
       "      <td>0.247763</td>\n",
       "      <td>-0.491311</td>\n",
       "      <td>-1.483850</td>\n",
       "      <td>-1.773281</td>\n",
       "      <td>-1.347077</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.388569</td>\n",
       "      <td>0.228234</td>\n",
       "      <td>0.012651</td>\n",
       "      <td>-0.416194</td>\n",
       "      <td>-0.172314</td>\n",
       "      <td>0.000270</td>\n",
       "      <td>0.026170</td>\n",
       "      <td>0.138145</td>\n",
       "      <td>0.320625</td>\n",
       "      <td>0.088661</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.033992</td>\n",
       "      <td>0.742928</td>\n",
       "      <td>0.837364</td>\n",
       "      <td>0.781059</td>\n",
       "      <td>1.258485</td>\n",
       "      <td>0.464501</td>\n",
       "      <td>-0.248795</td>\n",
       "      <td>-0.416523</td>\n",
       "      <td>-0.414264</td>\n",
       "      <td>-0.405229</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>-0.045460</td>\n",
       "      <td>0.155043</td>\n",
       "      <td>-1.874966</td>\n",
       "      <td>-0.611283</td>\n",
       "      <td>-0.838967</td>\n",
       "      <td>-0.840588</td>\n",
       "      <td>1.906283</td>\n",
       "      <td>-1.691144</td>\n",
       "      <td>0.091714</td>\n",
       "      <td>-4.834391</td>\n",
       "      <td>...</td>\n",
       "      <td>-1.808452</td>\n",
       "      <td>4.537625</td>\n",
       "      <td>4.859809</td>\n",
       "      <td>4.486367</td>\n",
       "      <td>1.999552</td>\n",
       "      <td>1.489966</td>\n",
       "      <td>1.371748</td>\n",
       "      <td>-5.395244</td>\n",
       "      <td>-3.546914</td>\n",
       "      <td>-5.832434</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>-0.605171</td>\n",
       "      <td>1.514338</td>\n",
       "      <td>0.312392</td>\n",
       "      <td>2.308388</td>\n",
       "      <td>2.289083</td>\n",
       "      <td>2.522494</td>\n",
       "      <td>1.673373</td>\n",
       "      <td>3.769293</td>\n",
       "      <td>2.880118</td>\n",
       "      <td>0.100824</td>\n",
       "      <td>...</td>\n",
       "      <td>-1.838072</td>\n",
       "      <td>5.897972</td>\n",
       "      <td>6.804343</td>\n",
       "      <td>6.394295</td>\n",
       "      <td>4.454259</td>\n",
       "      <td>4.291104</td>\n",
       "      <td>4.315624</td>\n",
       "      <td>-7.631972</td>\n",
       "      <td>-6.201092</td>\n",
       "      <td>-8.334439</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 30 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       x1_1      x1_2      x1_3      x2_1      x2_2      x2_3      x3_1  \\\n",
       "0 -0.103733 -0.033400  0.566829  0.876375  1.263319  1.791450  0.493246   \n",
       "1  0.667339  2.225528  0.741775  3.976137  3.936624  3.905270  2.312294   \n",
       "2  0.388569  0.228234  0.012651 -0.416194 -0.172314  0.000270  0.026170   \n",
       "3 -0.045460  0.155043 -1.874966 -0.611283 -0.838967 -0.840588  1.906283   \n",
       "4 -0.605171  1.514338  0.312392  2.308388  2.289083  2.522494  1.673373   \n",
       "\n",
       "       x3_2      x3_3      x4_1  ...      x7_3      x8_1      x8_2      x8_3  \\\n",
       "0  0.482372  0.678418  0.226348  ... -0.633237  0.424472 -0.240294 -0.273037   \n",
       "1  3.734514  3.071793 -0.572782  ... -0.611924 -0.933952 -0.948881 -0.926097   \n",
       "2  0.138145  0.320625  0.088661  ... -0.033992  0.742928  0.837364  0.781059   \n",
       "3 -1.691144  0.091714 -4.834391  ... -1.808452  4.537625  4.859809  4.486367   \n",
       "4  3.769293  2.880118  0.100824  ... -1.838072  5.897972  6.804343  6.394295   \n",
       "\n",
       "       x9_1      x9_2      x9_3     x10_1     x10_2     x10_3  \n",
       "0 -1.020092  1.890594  0.411665 -0.500063  2.160795  2.434649  \n",
       "1  1.574893  0.247763 -0.491311 -1.483850 -1.773281 -1.347077  \n",
       "2  1.258485  0.464501 -0.248795 -0.416523 -0.414264 -0.405229  \n",
       "3  1.999552  1.489966  1.371748 -5.395244 -3.546914 -5.832434  \n",
       "4  4.454259  4.291104  4.315624 -7.631972 -6.201092 -8.334439  \n",
       "\n",
       "[5 rows x 30 columns]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First rows of the factual sample\n",
    "factual.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a manual layout (node name -> [x, y]) for drawing g.\n",
    "# NOTE(review): `j` is never defined in this cell, so on a fresh kernel the\n",
    "# loop raises NameError (or silently reuses a leaked global `j`). The intended\n",
    "# node naming ('x{i}_{j}') and layout are unclear - TODO confirm and fix.\n",
    "positions = {}\n",
    "for i in range(1,31):\n",
    "    # 1 for the first three of every six indices, 0 for the other three\n",
    "    col = (((i-1)%6)<=2)*1\n",
    "    # y-row advances once every six indices\n",
    "    row = (i-1)//6\n",
    "    positions[\"x\"+str(i)+\"_\"+str(j)] = [col*4 + j*1, row*10]\n",
    "nx.draw_networkx(g,pos=positions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[ 3,  6],\n",
       "        [ 9, 12]]),\n",
       " array([[12,  9],\n",
       "        [ 6,  3]]))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def create_f(c):\n",
    "    \"\"\"Return a function x -> snapshot * x, where snapshot is a copy of `c`.\n",
    "\n",
    "    The copy is bound as a lambda default argument at creation time, so the\n",
    "    later in-place write to `c` does not change the returned function.\n",
    "    Note: this deliberately mutates the caller's array (c[0, 0] = 99).\n",
    "    \"\"\"\n",
    "    # Default argument captures a snapshot of `c` now, not at call time\n",
    "    scale = lambda x, c=c.copy(): c * x\n",
    "    # Modify the original matrix; `scale` keeps using its snapshot\n",
    "    c[0, 0] = 99\n",
    "    return scale\n",
    "\n",
    "\n",
    "f1 = create_f(np.array([[1, 2], [3, 4]]))\n",
    "f2 = create_f(np.array([[4, 3], [2, 1]]))\n",
    "f1(3), f2(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['x10', 'x11', 'x12', 'x13', 'x14', 'x15']"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# g is directed, so successors() yields the same nodes as neighbors()\n",
    "list(g.successors(\"x5\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OutEdgeView([('x1', 'x7'), ('x1', 'x8'), ('x1', 'x9'), ('x1', 'x10'), ('x1', 'x11'), ('x1', 'x12'), ('x1', 'x13'), ('x1', 'x14'), ('x1', 'x15'), ('x7', 'x10'), ('x7', 'x11'), ('x7', 'x12'), ('x7', 'x13'), ('x7', 'x14'), ('x7', 'x15'), ('x8', 'x10'), ('x8', 'x11'), ('x8', 'x12'), ('x8', 'x13'), ('x8', 'x14'), ('x8', 'x15'), ('x9', 'x10'), ('x9', 'x11'), ('x9', 'x12'), ('x9', 'x13'), ('x9', 'x14'), ('x9', 'x15'), ('x2', 'x7'), ('x2', 'x8'), ('x2', 'x9'), ('x2', 'x10'), ('x2', 'x11'), ('x2', 'x12'), ('x2', 'x13'), ('x2', 'x14'), ('x2', 'x15'), ('x3', 'x7'), ('x3', 'x8'), ('x3', 'x9'), ('x3', 'x10'), ('x3', 'x11'), ('x3', 'x12'), ('x3', 'x13'), ('x3', 'x14'), ('x3', 'x15'), ('x4', 'x10'), ('x4', 'x11'), ('x4', 'x12'), ('x4', 'x13'), ('x4', 'x14'), ('x4', 'x15'), ('x5', 'x10'), ('x5', 'x11'), ('x5', 'x12'), ('x5', 'x13'), ('x5', 'x14'), ('x5', 'x15'), ('x6', 'x10'), ('x6', 'x11'), ('x6', 'x12'), ('x6', 'x13'), ('x6', 'x14'), ('x6', 'x15')])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Edge view of g: (source, target) pairs\n",
    "g.edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "pd.set_option('display.max_columns', None)  # show every column of wide frames\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code; only open trusted files.\n",
    "with open('vaca_cf.pickle', 'rb') as pickle_file:\n",
    "    b = pickle.load(pickle_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 intervened\n",
      "0 children\n",
      "0 all\n",
      "1 intervened\n",
      "1 children\n",
      "1 all\n",
      "2 all\n"
     ]
    }
   ],
   "source": [
    "# Flatten b (indexed 0..len-1, each entry keyed by tag) into a list of DataFrames\n",
    "extended = []\n",
    "for i in range(len(b)):\n",
    "    entry = b[i]\n",
    "    for key in entry:\n",
    "        print(i, key)\n",
    "        frame = pd.DataFrame(entry[key].numpy())\n",
    "        extended.append(frame)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>9</th>\n",
       "      <th>10</th>\n",
       "      <th>11</th>\n",
       "      <th>12</th>\n",
       "      <th>13</th>\n",
       "      <th>14</th>\n",
       "      <th>15</th>\n",
       "      <th>16</th>\n",
       "      <th>17</th>\n",
       "      <th>18</th>\n",
       "      <th>19</th>\n",
       "      <th>20</th>\n",
       "      <th>21</th>\n",
       "      <th>22</th>\n",
       "      <th>23</th>\n",
       "      <th>24</th>\n",
       "      <th>25</th>\n",
       "      <th>26</th>\n",
       "      <th>27</th>\n",
       "      <th>28</th>\n",
       "      <th>29</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.175660</td>\n",
       "      <td>-1.252793</td>\n",
       "      <td>-0.964547</td>\n",
       "      <td>0.697684</td>\n",
       "      <td>0.906275</td>\n",
       "      <td>-1.260133</td>\n",
       "      <td>7.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>17.0</td>\n",
       "      <td>18.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>20.0</td>\n",
       "      <td>21.0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>23.0</td>\n",
       "      <td>24.0</td>\n",
       "      <td>25.0</td>\n",
       "      <td>26.0</td>\n",
       "      <td>27.0</td>\n",
       "      <td>28.0</td>\n",
       "      <td>29.0</td>\n",
       "      <td>30.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.983459</td>\n",
       "      <td>-0.406646</td>\n",
       "      <td>1.435120</td>\n",
       "      <td>0.554117</td>\n",
       "      <td>0.273827</td>\n",
       "      <td>-1.260133</td>\n",
       "      <td>7.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>17.0</td>\n",
       "      <td>18.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>20.0</td>\n",
       "      <td>21.0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>23.0</td>\n",
       "      <td>24.0</td>\n",
       "      <td>25.0</td>\n",
       "      <td>26.0</td>\n",
       "      <td>27.0</td>\n",
       "      <td>28.0</td>\n",
       "      <td>29.0</td>\n",
       "      <td>30.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.334610</td>\n",
       "      <td>-1.505293</td>\n",
       "      <td>-0.273915</td>\n",
       "      <td>-0.397968</td>\n",
       "      <td>-1.780755</td>\n",
       "      <td>-1.260133</td>\n",
       "      <td>7.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>17.0</td>\n",
       "      <td>18.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>20.0</td>\n",
       "      <td>21.0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>23.0</td>\n",
       "      <td>24.0</td>\n",
       "      <td>25.0</td>\n",
       "      <td>26.0</td>\n",
       "      <td>27.0</td>\n",
       "      <td>28.0</td>\n",
       "      <td>29.0</td>\n",
       "      <td>30.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>-0.284237</td>\n",
       "      <td>0.883004</td>\n",
       "      <td>-0.821244</td>\n",
       "      <td>0.473253</td>\n",
       "      <td>-0.529820</td>\n",
       "      <td>-1.260133</td>\n",
       "      <td>7.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>17.0</td>\n",
       "      <td>18.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>20.0</td>\n",
       "      <td>21.0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>23.0</td>\n",
       "      <td>24.0</td>\n",
       "      <td>25.0</td>\n",
       "      <td>26.0</td>\n",
       "      <td>27.0</td>\n",
       "      <td>28.0</td>\n",
       "      <td>29.0</td>\n",
       "      <td>30.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>-0.511007</td>\n",
       "      <td>-0.714949</td>\n",
       "      <td>-0.155606</td>\n",
       "      <td>0.208860</td>\n",
       "      <td>-0.317596</td>\n",
       "      <td>-1.260133</td>\n",
       "      <td>7.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>9.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>17.0</td>\n",
       "      <td>18.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>20.0</td>\n",
       "      <td>21.0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>23.0</td>\n",
       "      <td>24.0</td>\n",
       "      <td>25.0</td>\n",
       "      <td>26.0</td>\n",
       "      <td>27.0</td>\n",
       "      <td>28.0</td>\n",
       "      <td>29.0</td>\n",
       "      <td>30.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         0         1         2         3         4         5    6    7    8   \\\n",
       "0  1.175660 -1.252793 -0.964547  0.697684  0.906275 -1.260133  7.0  8.0  9.0   \n",
       "1  0.983459 -0.406646  1.435120  0.554117  0.273827 -1.260133  7.0  8.0  9.0   \n",
       "2  0.334610 -1.505293 -0.273915 -0.397968 -1.780755 -1.260133  7.0  8.0  9.0   \n",
       "3 -0.284237  0.883004 -0.821244  0.473253 -0.529820 -1.260133  7.0  8.0  9.0   \n",
       "4 -0.511007 -0.714949 -0.155606  0.208860 -0.317596 -1.260133  7.0  8.0  9.0   \n",
       "\n",
       "     9     10    11    12    13    14    15    16    17    18    19    20  \\\n",
       "0  10.0  11.0  12.0  13.0  14.0  15.0  16.0  17.0  18.0  19.0  20.0  21.0   \n",
       "1  10.0  11.0  12.0  13.0  14.0  15.0  16.0  17.0  18.0  19.0  20.0  21.0   \n",
       "2  10.0  11.0  12.0  13.0  14.0  15.0  16.0  17.0  18.0  19.0  20.0  21.0   \n",
       "3  10.0  11.0  12.0  13.0  14.0  15.0  16.0  17.0  18.0  19.0  20.0  21.0   \n",
       "4  10.0  11.0  12.0  13.0  14.0  15.0  16.0  17.0  18.0  19.0  20.0  21.0   \n",
       "\n",
       "     21    22    23    24    25    26    27    28    29  \n",
       "0  22.0  23.0  24.0  25.0  26.0  27.0  28.0  29.0  30.0  \n",
       "1  22.0  23.0  24.0  25.0  26.0  27.0  28.0  29.0  30.0  \n",
       "2  22.0  23.0  24.0  25.0  26.0  27.0  28.0  29.0  30.0  \n",
       "3  22.0  23.0  24.0  25.0  26.0  27.0  28.0  29.0  30.0  \n",
       "4  22.0  23.0  24.0  25.0  26.0  27.0  28.0  29.0  30.0  "
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first five rows of the second-to-last frame in `extended`.\n",
    "extended[-2].iloc[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5001.190384768413 0\n",
      "8.589344703125107 0\n",
      "8.589344703125107 12\n",
      "5010.07999941008 0\n",
      "14.116810523437607 0\n",
      "14.116810523437607 12\n",
      "5.440602027343857 12\n"
     ]
    }
   ],
   "source": [
    "# `gt` holds reference values (presumably ground truth) for the first len(gt) rows.\n",
    "gt = pd.Series([999.762727 ,  1000.426657  , 998.827536 ,  1000.113370 ,  999.832834])\n",
    "# Alternative reference values, kept for comparison:\n",
    "#gt = pd.Series([999.541886, 1001.880628, 1000.031750  , 1001.290895 , 1000.017388])\n",
    "# For each frame in `extended`, report the smallest total absolute deviation from `gt`\n",
    "# over its first len(gt) rows, and the position of the column that attains it.\n",
    "for frame in extended:\n",
    "    # Take |x - gt| element-wise BEFORE summing over rows; summing signed differences\n",
    "    # first lets positive and negative errors cancel, so a poor match could score ~0.\n",
    "    diffs = np.asarray(frame.iloc[:len(gt)].sub(gt, axis=0).abs().sum(axis=0))\n",
    "    print(np.min(diffs), np.argmin(diffs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>9</th>\n",
       "      <th>10</th>\n",
       "      <th>11</th>\n",
       "      <th>12</th>\n",
       "      <th>13</th>\n",
       "      <th>14</th>\n",
       "      <th>15</th>\n",
       "      <th>16</th>\n",
       "      <th>17</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>998.075134</td>\n",
       "      <td>-0.177327</td>\n",
       "      <td>0.658807</td>\n",
       "      <td>354.925232</td>\n",
       "      <td>-402.208954</td>\n",
       "      <td>-168.273041</td>\n",
       "      <td>-555.842834</td>\n",
       "      <td>891.648010</td>\n",
       "      <td>-67.056458</td>\n",
       "      <td>627.107544</td>\n",
       "      <td>671.225708</td>\n",
       "      <td>-18.162128</td>\n",
       "      <td>909.044434</td>\n",
       "      <td>211.253754</td>\n",
       "      <td>-273.006927</td>\n",
       "      <td>-658.565674</td>\n",
       "      <td>751.012024</td>\n",
       "      <td>-216.188736</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>998.074463</td>\n",
       "      <td>-0.177569</td>\n",
       "      <td>0.657873</td>\n",
       "      <td>354.924835</td>\n",
       "      <td>-402.209320</td>\n",
       "      <td>-168.272995</td>\n",
       "      <td>-555.843140</td>\n",
       "      <td>891.648315</td>\n",
       "      <td>-67.056839</td>\n",
       "      <td>627.107361</td>\n",
       "      <td>671.225525</td>\n",
       "      <td>-18.160534</td>\n",
       "      <td>909.059937</td>\n",
       "      <td>211.254059</td>\n",
       "      <td>-273.006042</td>\n",
       "      <td>-658.565857</td>\n",
       "      <td>751.012634</td>\n",
       "      <td>-216.187927</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>998.075134</td>\n",
       "      <td>-0.177123</td>\n",
       "      <td>0.658522</td>\n",
       "      <td>354.924835</td>\n",
       "      <td>-402.209259</td>\n",
       "      <td>-168.273544</td>\n",
       "      <td>-555.842896</td>\n",
       "      <td>891.648193</td>\n",
       "      <td>-67.056206</td>\n",
       "      <td>627.107422</td>\n",
       "      <td>671.225525</td>\n",
       "      <td>-18.161379</td>\n",
       "      <td>909.056152</td>\n",
       "      <td>211.252594</td>\n",
       "      <td>-273.009186</td>\n",
       "      <td>-658.565247</td>\n",
       "      <td>751.012634</td>\n",
       "      <td>-216.187775</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>998.074280</td>\n",
       "      <td>-0.177889</td>\n",
       "      <td>0.657544</td>\n",
       "      <td>354.923859</td>\n",
       "      <td>-402.209106</td>\n",
       "      <td>-168.273468</td>\n",
       "      <td>-555.842773</td>\n",
       "      <td>891.647583</td>\n",
       "      <td>-67.057220</td>\n",
       "      <td>627.107422</td>\n",
       "      <td>671.225708</td>\n",
       "      <td>-18.159908</td>\n",
       "      <td>909.050842</td>\n",
       "      <td>211.252045</td>\n",
       "      <td>-273.003204</td>\n",
       "      <td>-658.565552</td>\n",
       "      <td>751.012024</td>\n",
       "      <td>-216.188477</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>998.074768</td>\n",
       "      <td>-0.177731</td>\n",
       "      <td>0.658285</td>\n",
       "      <td>354.924988</td>\n",
       "      <td>-402.209320</td>\n",
       "      <td>-168.273605</td>\n",
       "      <td>-555.842957</td>\n",
       "      <td>891.648560</td>\n",
       "      <td>-67.056885</td>\n",
       "      <td>627.107849</td>\n",
       "      <td>671.225891</td>\n",
       "      <td>-18.160280</td>\n",
       "      <td>909.063171</td>\n",
       "      <td>211.249832</td>\n",
       "      <td>-273.005371</td>\n",
       "      <td>-658.565430</td>\n",
       "      <td>751.012268</td>\n",
       "      <td>-216.188385</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           0         1         2           3           4           5   \\\n",
       "0  998.075134 -0.177327  0.658807  354.925232 -402.208954 -168.273041   \n",
       "1  998.074463 -0.177569  0.657873  354.924835 -402.209320 -168.272995   \n",
       "2  998.075134 -0.177123  0.658522  354.924835 -402.209259 -168.273544   \n",
       "3  998.074280 -0.177889  0.657544  354.923859 -402.209106 -168.273468   \n",
       "4  998.074768 -0.177731  0.658285  354.924988 -402.209320 -168.273605   \n",
       "\n",
       "           6           7          8           9           10         11  \\\n",
       "0 -555.842834  891.648010 -67.056458  627.107544  671.225708 -18.162128   \n",
       "1 -555.843140  891.648315 -67.056839  627.107361  671.225525 -18.160534   \n",
       "2 -555.842896  891.648193 -67.056206  627.107422  671.225525 -18.161379   \n",
       "3 -555.842773  891.647583 -67.057220  627.107422  671.225708 -18.159908   \n",
       "4 -555.842957  891.648560 -67.056885  627.107849  671.225891 -18.160280   \n",
       "\n",
       "           12          13          14          15          16          17  \n",
       "0  909.044434  211.253754 -273.006927 -658.565674  751.012024 -216.188736  \n",
       "1  909.059937  211.254059 -273.006042 -658.565857  751.012634 -216.187927  \n",
       "2  909.056152  211.252594 -273.009186 -658.565247  751.012634 -216.187775  \n",
       "3  909.050842  211.252045 -273.003204 -658.565552  751.012024 -216.188477  \n",
       "4  909.063171  211.249832 -273.005371 -658.565430  751.012268 -216.188385  "
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first five rows of the frame at position 1 in `extended`.\n",
    "extended[1].iloc[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([  -12.3887677 , -5003.65018698, -4999.47151635, -3228.13879822,\n",
       "       -7013.80850647, -5844.12920044, -7781.97714661,  -544.52188538,\n",
       "       -5338.04615448, -1867.22494934, -1646.63419007, -5093.56677674,\n",
       "        -457.48801087, -3946.50026367, -6367.7932782 , -8295.59030579,\n",
       "       -1247.70096253, -6083.70384644])"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: pickle.load can execute arbitrary code; only load pickles from a trusted source.\n",
    "with open('vaca_cf.pickle', mode='rb') as pickle_file:\n",
    "    data_loader = pickle.load(pickle_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>9</th>\n",
       "      <th>10</th>\n",
       "      <th>11</th>\n",
       "      <th>12</th>\n",
       "      <th>13</th>\n",
       "      <th>14</th>\n",
       "      <th>15</th>\n",
       "      <th>16</th>\n",
       "      <th>17</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>998.075134</td>\n",
       "      <td>-0.177327</td>\n",
       "      <td>0.658807</td>\n",
       "      <td>354.925232</td>\n",
       "      <td>-402.208954</td>\n",
       "      <td>-168.273041</td>\n",
       "      <td>-555.842834</td>\n",
       "      <td>891.648010</td>\n",
       "      <td>-67.056458</td>\n",
       "      <td>627.107544</td>\n",
       "      <td>671.225708</td>\n",
       "      <td>-18.162128</td>\n",
       "      <td>909.044434</td>\n",
       "      <td>211.253754</td>\n",
       "      <td>-273.006927</td>\n",
       "      <td>-658.565674</td>\n",
       "      <td>751.012024</td>\n",
       "      <td>-216.188736</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>998.074463</td>\n",
       "      <td>-0.177569</td>\n",
       "      <td>0.657873</td>\n",
       "      <td>354.924835</td>\n",
       "      <td>-402.209320</td>\n",
       "      <td>-168.272995</td>\n",
       "      <td>-555.843140</td>\n",
       "      <td>891.648315</td>\n",
       "      <td>-67.056839</td>\n",
       "      <td>627.107361</td>\n",
       "      <td>671.225525</td>\n",
       "      <td>-18.160534</td>\n",
       "      <td>909.059937</td>\n",
       "      <td>211.254059</td>\n",
       "      <td>-273.006042</td>\n",
       "      <td>-658.565857</td>\n",
       "      <td>751.012634</td>\n",
       "      <td>-216.187927</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>998.075134</td>\n",
       "      <td>-0.177123</td>\n",
       "      <td>0.658522</td>\n",
       "      <td>354.924835</td>\n",
       "      <td>-402.209259</td>\n",
       "      <td>-168.273544</td>\n",
       "      <td>-555.842896</td>\n",
       "      <td>891.648193</td>\n",
       "      <td>-67.056206</td>\n",
       "      <td>627.107422</td>\n",
       "      <td>671.225525</td>\n",
       "      <td>-18.161379</td>\n",
       "      <td>909.056152</td>\n",
       "      <td>211.252594</td>\n",
       "      <td>-273.009186</td>\n",
       "      <td>-658.565247</td>\n",
       "      <td>751.012634</td>\n",
       "      <td>-216.187775</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>998.074280</td>\n",
       "      <td>-0.177889</td>\n",
       "      <td>0.657544</td>\n",
       "      <td>354.923859</td>\n",
       "      <td>-402.209106</td>\n",
       "      <td>-168.273468</td>\n",
       "      <td>-555.842773</td>\n",
       "      <td>891.647583</td>\n",
       "      <td>-67.057220</td>\n",
       "      <td>627.107422</td>\n",
       "      <td>671.225708</td>\n",
       "      <td>-18.159908</td>\n",
       "      <td>909.050842</td>\n",
       "      <td>211.252045</td>\n",
       "      <td>-273.003204</td>\n",
       "      <td>-658.565552</td>\n",
       "      <td>751.012024</td>\n",
       "      <td>-216.188477</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>998.074768</td>\n",
       "      <td>-0.177731</td>\n",
       "      <td>0.658285</td>\n",
       "      <td>354.924988</td>\n",
       "      <td>-402.209320</td>\n",
       "      <td>-168.273605</td>\n",
       "      <td>-555.842957</td>\n",
       "      <td>891.648560</td>\n",
       "      <td>-67.056885</td>\n",
       "      <td>627.107849</td>\n",
       "      <td>671.225891</td>\n",
       "      <td>-18.160280</td>\n",
       "      <td>909.063171</td>\n",
       "      <td>211.249832</td>\n",
       "      <td>-273.005371</td>\n",
       "      <td>-658.565430</td>\n",
       "      <td>751.012268</td>\n",
       "      <td>-216.188385</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           0         1         2           3           4           5   \\\n",
       "0  998.075134 -0.177327  0.658807  354.925232 -402.208954 -168.273041   \n",
       "1  998.074463 -0.177569  0.657873  354.924835 -402.209320 -168.272995   \n",
       "2  998.075134 -0.177123  0.658522  354.924835 -402.209259 -168.273544   \n",
       "3  998.074280 -0.177889  0.657544  354.923859 -402.209106 -168.273468   \n",
       "4  998.074768 -0.177731  0.658285  354.924988 -402.209320 -168.273605   \n",
       "\n",
       "           6           7          8           9           10         11  \\\n",
       "0 -555.842834  891.648010 -67.056458  627.107544  671.225708 -18.162128   \n",
       "1 -555.843140  891.648315 -67.056839  627.107361  671.225525 -18.160534   \n",
       "2 -555.842896  891.648193 -67.056206  627.107422  671.225525 -18.161379   \n",
       "3 -555.842773  891.647583 -67.057220  627.107422  671.225708 -18.159908   \n",
       "4 -555.842957  891.648560 -67.056885  627.107849  671.225891 -18.160280   \n",
       "\n",
       "           12          13          14          15          16          17  \n",
       "0  909.044434  211.253754 -273.006927 -658.565674  751.012024 -216.188736  \n",
       "1  909.059937  211.254059 -273.006042 -658.565857  751.012634 -216.187927  \n",
       "2  909.056152  211.252594 -273.009186 -658.565247  751.012634 -216.187775  \n",
       "3  909.050842  211.252045 -273.003204 -658.565552  751.012024 -216.188477  \n",
       "4  909.063171  211.249832 -273.005371 -658.565430  751.012268 -216.188385  "
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Repeats the `extended[1].head()` preview shown in an earlier cell (first five rows).\n",
    "extended[1].iloc[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5 3\n"
     ]
    }
   ],
   "source": [
    "# Name-rebinding demo: `b = a` binds b to the same object, and rebinding `a`\n",
    "# afterwards does not change what `b` refers to, so this prints \"5 3\".\n",
    "a = 3\n",
    "b = a\n",
    "a = 5\n",
    "print(f\"{a} {b}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ObservableValue:\n",
    "    \"\"\"Holds a single value and calls registered callbacks whenever it changes.\n",
    "\n",
    "    Assigning to `value` notifies every observer, in registration order, with the\n",
    "    new value -- but only when the new value compares unequal (`!=`) to the old one.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, value):\n",
    "        self._value = value\n",
    "        # Callables taking one argument (the new value); filled by add_observer.\n",
    "        self._observers = []\n",
    "\n",
    "    @property\n",
    "    def value(self):\n",
    "        \"\"\"The currently stored value.\"\"\"\n",
    "        return self._value\n",
    "\n",
    "    @value.setter\n",
    "    def value(self, new_value):\n",
    "        changed = new_value != self._value\n",
    "        if changed:\n",
    "            # Store first so observers are called with the new value.\n",
    "            self._value = new_value\n",
    "            self._notify_observers()\n",
    "\n",
    "    def _notify_observers(self):\n",
    "        \"\"\"Invoke each registered observer with the current value.\"\"\"\n",
    "        for callback in self._observers:\n",
    "            callback(self._value)\n",
    "\n",
    "    def add_observer(self, observer):\n",
    "        \"\"\"Register `observer`, a callable invoked as observer(new_value) on change.\"\"\"\n",
    "        self._observers.append(observer)\n",
    "\n",
    "\n",
    "def on_value_changed(new_value):\n",
    "    \"\"\"Example observer that prints the new value.\"\"\"\n",
    "    print(f\"Value changed to: {new_value}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.4 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "40d3a090f54c6569ab1632332b64b2c03c39dcf918b08424e98f38b5ae0af88f"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
