{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RESIT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import and settings\n",
    "In this example, we need to import `numpy`, `pandas`, and `graphviz` in addition to `lingam`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-25T06:59:57.385770Z",
     "start_time": "2021-06-25T06:59:57.370810Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['1.21.5', '1.3.2', '0.17', '1.6.0']\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import graphviz\n",
    "import lingam\n",
    "from lingam.utils import print_causal_directions, print_dagc, make_dot\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "print([np.__version__, pd.__version__, graphviz.__version__, lingam.__version__])\n",
    "\n",
    "np.set_printoptions(precision=3, suppress=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test data\n",
    "First, we generate a causal structure with 7 variables. Then we create a dataset with 6 variables from x0 to x5, with x6 being the latent variable for x2 and x3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-25T06:59:59.105787Z",
     "start_time": "2021-06-25T06:59:59.059854Z"
    }
   },
   "outputs": [],
   "source": [
    "X = pd.read_csv('nonlinear_data.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-25T07:00:01.201484Z",
     "start_time": "2021-06-25T07:00:00.198943Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
       " -->\r\n",
       "<!-- Title: %3 Pages: 1 -->\r\n",
       "<svg width=\"122pt\" height=\"392pt\"\r\n",
       " viewBox=\"0.00 0.00 122.00 392.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 388)\">\r\n",
       "<title>%3</title>\r\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-388 118,-388 118,4 -4,4\"/>\r\n",
       "<!-- x0 -->\r\n",
       "<g id=\"node1\" class=\"node\"><title>x0</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"69\" cy=\"-366\" rx=\"27\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"69\" y=\"-362.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x0</text>\r\n",
       "</g>\r\n",
       "<!-- x1 -->\r\n",
       "<g id=\"node2\" class=\"node\"><title>x1</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-279\" rx=\"27\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"27\" y=\"-275.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x1</text>\r\n",
       "</g>\r\n",
       "<!-- x0&#45;&gt;x1 -->\r\n",
       "<g id=\"edge1\" class=\"edge\"><title>x0&#45;&gt;x1</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M60.4819,-348.666C57.453,-342.798 54.0412,-336.119 51,-330 47.0392,-322.031 42.8102,-313.299 39.0228,-305.394\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"42.1342,-303.787 34.6699,-296.268 35.8161,-306.801 42.1342,-303.787\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"63.5\" y=\"-318.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.00</text>\r\n",
       "</g>\r\n",
       "<!-- x2 -->\r\n",
       "<g id=\"node3\" class=\"node\"><title>x2</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"87\" cy=\"-192\" rx=\"27\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"87\" y=\"-188.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x2</text>\r\n",
       "</g>\r\n",
       "<!-- x0&#45;&gt;x2 -->\r\n",
       "<g id=\"edge2\" class=\"edge\"><title>x0&#45;&gt;x2</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M72.8209,-347.979C73.9871,-342.286 75.1772,-335.892 76,-330 81.2736,-292.239 84.2494,-248.242 85.7413,-220.506\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"89.2442,-220.538 86.2589,-210.373 82.2533,-220.181 89.2442,-220.538\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"95.5\" y=\"-275.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.00</text>\r\n",
       "</g>\r\n",
       "<!-- x1&#45;&gt;x2 -->\r\n",
       "<g id=\"edge3\" class=\"edge\"><title>x1&#45;&gt;x2</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M33.2143,-261.434C37.4299,-251.289 43.5789,-238.318 51,-228 54.5866,-223.013 58.9265,-218.12 63.3331,-213.646\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"65.9085,-216.025 70.6911,-206.571 61.0568,-210.979 65.9085,-216.025\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"63.5\" y=\"-231.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.00</text>\r\n",
       "</g>\r\n",
       "<!-- x3 -->\r\n",
       "<g id=\"node4\" class=\"node\"><title>x3</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"42\" cy=\"-105\" rx=\"27\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"42\" y=\"-101.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x3</text>\r\n",
       "</g>\r\n",
       "<!-- x1&#45;&gt;x3 -->\r\n",
       "<g id=\"edge4\" class=\"edge\"><title>x1&#45;&gt;x3</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M25.5391,-260.998C24.0633,-240.348 22.4611,-204.491 26,-174 27.6058,-160.165 30.9908,-145.058 34.2356,-132.604\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"37.6403,-133.422 36.886,-122.854 30.8854,-131.586 37.6403,-133.422\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"38.5\" y=\"-188.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.00</text>\r\n",
       "</g>\r\n",
       "<!-- x2&#45;&gt;x3 -->\r\n",
       "<g id=\"edge5\" class=\"edge\"><title>x2&#45;&gt;x3</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M78.3235,-174.611C71.8492,-162.382 62.8815,-145.443 55.4767,-131.456\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"58.422,-129.539 50.6498,-122.339 52.2355,-132.814 58.422,-129.539\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"79.5\" y=\"-144.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.00</text>\r\n",
       "</g>\r\n",
       "<!-- x4 -->\r\n",
       "<g id=\"node5\" class=\"node\"><title>x4</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"42\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"42\" y=\"-14.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x4</text>\r\n",
       "</g>\r\n",
       "<!-- x3&#45;&gt;x4 -->\r\n",
       "<g id=\"edge6\" class=\"edge\"><title>x3&#45;&gt;x4</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M42,-86.799C42,-75.1626 42,-59.5479 42,-46.2368\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"45.5001,-46.1754 42,-36.1754 38.5001,-46.1755 45.5001,-46.1754\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"54.5\" y=\"-57.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.00</text>\r\n",
       "</g>\r\n",
       "</g>\r\n",
       "</svg>\r\n"
      ],
      "text/plain": [
       "<graphviz.dot.Digraph at 0x201a773c848>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m = np.array([\n",
    "    [0, 0, 0, 0, 0],\n",
    "    [1, 0, 0, 0, 0],\n",
    "    [1, 1, 0, 0, 0],\n",
    "    [0, 1, 1, 0, 0],\n",
    "    [0, 0, 0, 1, 0]])\n",
    "\n",
    "dot = make_dot(m)\n",
    "\n",
    "# Save pdf\n",
    "dot.render('dag')\n",
    "\n",
    "# Save png\n",
    "dot.format = 'png'\n",
    "dot.render('dag')\n",
    "\n",
    "dot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Causal Discovery\n",
    "To run causal discovery, we create a `RESIT` object and call the `fit` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-25T07:00:31.798259Z",
     "start_time": "2021-06-25T07:00:14.378602Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<lingam.resit.RESIT at 0x201a773c548>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.ensemble import RandomForestRegressor\n",
    "reg = RandomForestRegressor(max_depth=4, random_state=0)\n",
    "\n",
    "model = lingam.RESIT(regressor=reg)\n",
    "model.fit(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using the `causal_order_` properties, we can see the causal ordering as a result of the causal discovery. x2 and x3, which have latent confounders as parents, are stored in a list without causal ordering."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-25T07:00:31.814217Z",
     "start_time": "2021-06-25T07:00:31.801251Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0, 1, 2, 3, 4]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.causal_order_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-09-09T01:24:30.429100Z",
     "start_time": "2019-09-09T01:24:30.422118Z"
    }
   },
   "source": [
    "Also, using the `adjacency_matrix_` properties, we can see the adjacency matrix as a result of the causal discovery. The coefficients between variables with latent confounders are np.nan."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-25T07:00:31.855308Z",
     "start_time": "2021-06-25T07:00:31.818205Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0., 0., 0., 0., 0.],\n",
       "       [1., 0., 0., 0., 0.],\n",
       "       [0., 1., 0., 0., 0.],\n",
       "       [1., 1., 0., 0., 0.],\n",
       "       [0., 0., 0., 1., 0.]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.adjacency_matrix_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can draw a causal graph by utility funciton."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-25T07:00:32.302957Z",
     "start_time": "2021-06-25T07:00:31.855308Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
       " -->\r\n",
       "<!-- Title: %3 Pages: 1 -->\r\n",
       "<svg width=\"142pt\" height=\"305pt\"\r\n",
       " viewBox=\"0.00 0.00 142.00 305.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 301)\">\r\n",
       "<title>%3</title>\r\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-301 138,-301 138,4 -4,4\"/>\r\n",
       "<!-- x0 -->\r\n",
       "<g id=\"node1\" class=\"node\"><title>x0</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"94\" cy=\"-279\" rx=\"27\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"94\" y=\"-275.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x0</text>\r\n",
       "</g>\r\n",
       "<!-- x1 -->\r\n",
       "<g id=\"node2\" class=\"node\"><title>x1</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"52\" cy=\"-192\" rx=\"27\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"52\" y=\"-188.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x1</text>\r\n",
       "</g>\r\n",
       "<!-- x0&#45;&gt;x1 -->\r\n",
       "<g id=\"edge1\" class=\"edge\"><title>x0&#45;&gt;x1</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M85.4819,-261.666C82.453,-255.798 79.0412,-249.119 76,-243 72.0392,-235.031 67.8102,-226.299 64.0228,-218.394\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"67.1342,-216.787 59.6699,-209.268 60.8161,-219.801 67.1342,-216.787\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"88.5\" y=\"-231.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.00</text>\r\n",
       "</g>\r\n",
       "<!-- x3 -->\r\n",
       "<g id=\"node4\" class=\"node\"><title>x3</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"107\" cy=\"-105\" rx=\"27\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"107\" y=\"-101.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x3</text>\r\n",
       "</g>\r\n",
       "<!-- x0&#45;&gt;x3 -->\r\n",
       "<g id=\"edge3\" class=\"edge\"><title>x0&#45;&gt;x3</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M97.9466,-260.995C99.1211,-255.304 100.284,-248.906 101,-243 105.576,-205.234 106.777,-161.239 107.034,-133.504\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"110.534,-133.391 107.091,-123.371 103.534,-133.351 110.534,-133.391\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"119.5\" y=\"-188.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.00</text>\r\n",
       "</g>\r\n",
       "<!-- x2 -->\r\n",
       "<g id=\"node3\" class=\"node\"><title>x2</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-105\" rx=\"27\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"27\" y=\"-101.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x2</text>\r\n",
       "</g>\r\n",
       "<!-- x1&#45;&gt;x2 -->\r\n",
       "<g id=\"edge2\" class=\"edge\"><title>x1&#45;&gt;x2</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M38.8871,-176.01C34.5729,-170.165 30.2964,-163.156 28,-156 25.6919,-148.808 24.837,-140.729 24.7143,-133.198\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"28.2165,-133.143 24.9531,-123.064 21.2184,-132.978 28.2165,-133.143\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"40.5\" y=\"-144.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.00</text>\r\n",
       "</g>\r\n",
       "<!-- x1&#45;&gt;x3 -->\r\n",
       "<g id=\"edge4\" class=\"edge\"><title>x1&#45;&gt;x3</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M58.6864,-174.164C62.9947,-164.158 69.0755,-151.428 76,-141 78.9005,-136.632 82.3109,-132.24 85.7892,-128.114\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"88.4446,-130.395 92.4457,-120.585 83.2001,-125.759 88.4446,-130.395\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"88.5\" y=\"-144.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.00</text>\r\n",
       "</g>\r\n",
       "<!-- x4 -->\r\n",
       "<g id=\"node5\" class=\"node\"><title>x4</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"107\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"107\" y=\"-14.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x4</text>\r\n",
       "</g>\r\n",
       "<!-- x3&#45;&gt;x4 -->\r\n",
       "<g id=\"edge5\" class=\"edge\"><title>x3&#45;&gt;x4</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M107,-86.799C107,-75.1626 107,-59.5479 107,-46.2368\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"110.5,-46.1754 107,-36.1754 103.5,-46.1755 110.5,-46.1754\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"119.5\" y=\"-57.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.00</text>\r\n",
       "</g>\r\n",
       "</g>\r\n",
       "</svg>\r\n"
      ],
      "text/plain": [
       "<graphviz.dot.Digraph at 0x201a7905988>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "make_dot(model.adjacency_matrix_)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bootstrapping\n",
    "We call `bootstrap()` method instead of `fit()`. Here, the second argument specifies the number of bootstrap sampling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-25T06:24:17.805729Z",
     "start_time": "2021-06-25T06:05:28.262570Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore', category=UserWarning)\n",
    "\n",
    "n_sampling = 100\n",
    "model = lingam.RESIT(regressor=reg)\n",
    "result = model.bootstrap(X, n_sampling=n_sampling)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Causal Directions\n",
    "Since `BootstrapResult` object is returned, we can get the ranking of the causal directions extracted by `get_causal_direction_counts()` method. In the following sample code, `n_directions` option is limited to the causal directions of the top 8 rankings, and `min_causal_effect` option is limited to causal directions with a coefficient of 0.01 or more."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-25T06:24:17.823861Z",
     "start_time": "2021-06-25T06:24:17.805729Z"
    }
   },
   "outputs": [],
   "source": [
    "cdc = result.get_causal_direction_counts(n_directions=8, min_causal_effect=0.01, split_by_causal_effect_sign=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can check the result by utility function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-25T06:24:17.838991Z",
     "start_time": "2021-06-25T06:24:17.823861Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x1 <--- x0 (b>0) (100.0%)\n",
      "x2 <--- x1 (b>0) (71.0%)\n",
      "x4 <--- x1 (b>0) (62.0%)\n",
      "x2 <--- x0 (b>0) (62.0%)\n",
      "x3 <--- x1 (b>0) (53.0%)\n",
      "x3 <--- x4 (b>0) (52.0%)\n",
      "x4 <--- x3 (b>0) (47.0%)\n",
      "x3 <--- x0 (b>0) (44.0%)\n"
     ]
    }
   ],
   "source": [
    "print_causal_directions(cdc, n_sampling)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Directed Acyclic Graphs\n",
    "Also, using the `get_directed_acyclic_graph_counts()` method, we can get the ranking of the DAGs extracted. In the following sample code, `n_dags` option is limited to the dags of the top 3 rankings, and `min_causal_effect` option is limited to causal directions with a coefficient of 0.01 or more."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-25T06:24:17.864592Z",
     "start_time": "2021-06-25T06:24:17.841984Z"
    }
   },
   "outputs": [],
   "source": [
    "dagc = result.get_directed_acyclic_graph_counts(n_dags=3, min_causal_effect=0.01, split_by_causal_effect_sign=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can check the result by utility function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-25T06:24:17.879908Z",
     "start_time": "2021-06-25T06:24:17.864592Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DAG[0]: 13.0%\n",
      "\tx1 <--- x0 (b>0)\n",
      "\tx2 <--- x1 (b>0)\n",
      "\tx3 <--- x4 (b>0)\n",
      "\tx4 <--- x0 (b>0)\n",
      "\tx4 <--- x1 (b>0)\n",
      "DAG[1]: 13.0%\n",
      "\tx1 <--- x0 (b>0)\n",
      "\tx2 <--- x0 (b>0)\n",
      "\tx2 <--- x1 (b>0)\n",
      "\tx3 <--- x4 (b>0)\n",
      "\tx4 <--- x1 (b>0)\n",
      "DAG[2]: 11.0%\n",
      "\tx1 <--- x0 (b>0)\n",
      "\tx2 <--- x1 (b>0)\n",
      "\tx3 <--- x0 (b>0)\n",
      "\tx3 <--- x1 (b>0)\n",
      "\tx4 <--- x3 (b>0)\n"
     ]
    }
   ],
   "source": [
    "print_dagc(dagc, n_sampling)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Probability\n",
    "Using the `get_probabilities()` method, we can get the probability of bootstrapping."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-25T06:24:17.894642Z",
     "start_time": "2021-06-25T06:24:17.881902Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.   0.   0.   0.02 0.  ]\n",
      " [1.   0.   0.07 0.05 0.01]\n",
      " [0.62 0.71 0.   0.06 0.03]\n",
      " [0.44 0.53 0.18 0.   0.52]\n",
      " [0.43 0.62 0.21 0.47 0.  ]]\n"
     ]
    }
   ],
   "source": [
    "prob = result.get_probabilities(min_causal_effect=0.01)\n",
    "print(prob)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-25T04:27:45.949921Z",
     "start_time": "2021-06-25T04:27:45.907243Z"
    }
   },
   "source": [
    "## Bootstrap Probability of Path\n",
    "Using the `get_paths()` method, we can explore all paths from any variable to any variable and calculate the bootstrap probability for each path. The path will be output as an array of variable indices. For example, the array `[0, 1, 3]` shows the path from variable X0 through variable X1 to variable X3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-06-25T06:24:19.797821Z",
     "start_time": "2021-06-25T06:24:19.764919Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>path</th>\n",
       "      <th>effect</th>\n",
       "      <th>probability</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[0, 1, 3]</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.53</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[0, 1, 4, 3]</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.51</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>[0, 3]</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.44</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>[0, 4, 3]</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.33</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>[0, 2, 3]</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>[0, 1, 2, 3]</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>[0, 2, 4, 3]</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.07</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>[0, 1, 2, 4, 3]</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.04</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>[0, 1, 4, 2, 3]</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>[0, 2, 1, 3]</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>[0, 4, 1, 3]</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.01</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "               path  effect  probability\n",
       "0         [0, 1, 3]     1.0         0.53\n",
       "1      [0, 1, 4, 3]     1.0         0.51\n",
       "2            [0, 3]     1.0         0.44\n",
       "3         [0, 4, 3]     1.0         0.33\n",
       "4         [0, 2, 3]     1.0         0.12\n",
       "5      [0, 1, 2, 3]     1.0         0.11\n",
       "6      [0, 2, 4, 3]     1.0         0.07\n",
       "7   [0, 1, 2, 4, 3]     1.0         0.04\n",
       "8   [0, 1, 4, 2, 3]     1.0         0.03\n",
       "9      [0, 2, 1, 3]     1.0         0.01\n",
       "10     [0, 4, 1, 3]     1.0         0.01"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from_index = 0 # index of x0\n",
    "to_index = 3 # index of x3\n",
    "\n",
    "pd.DataFrame(result.get_paths(from_index, to_index))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
