{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Causal Effect for Logistic Regression"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import and settings\n",
    "In this example, we need to import `numpy`, `pandas`, and `graphviz` in addition to `lingam`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['1.16.2', '0.24.2', '0.11.1', '1.4.1']\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import graphviz\n",
    "import lingam\n",
    "from lingam.utils import make_prior_knowledge\n",
    "\n",
    "print([np.__version__, pd.__version__, graphviz.__version__, lingam.__version__])\n",
    "\n",
    "np.set_printoptions(precision=3, suppress=True)\n",
    "np.random.seed(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Utility function\n",
    "We define a utility function to draw the directed acyclic graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_graph(adjacency_matrix, labels=None):\n",
    "    idx = np.abs(adjacency_matrix) > 0.01\n",
    "    dirs = np.where(idx)\n",
    "    d = graphviz.Digraph(engine='dot')\n",
    "    names = labels if labels else [f'x{i}' for i in range(len(adjacency_matrix))]\n",
    "    for to, from_, coef in zip(dirs[0], dirs[1], adjacency_matrix[idx]):\n",
    "        d.edge(names[from_], names[to], label=f'{coef:.2f}')\n",
    "    return d"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test data\n",
    "We use 'Wine Quality Data Set' (https://archive.ics.uci.edu/ml/datasets/Wine+Quality)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1599, 12)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>fixed acidity</th>\n",
       "      <th>volatile acidity</th>\n",
       "      <th>citric acid</th>\n",
       "      <th>residual sugar</th>\n",
       "      <th>chlorides</th>\n",
       "      <th>free sulfur dioxide</th>\n",
       "      <th>total sulfur dioxide</th>\n",
       "      <th>density</th>\n",
       "      <th>pH</th>\n",
       "      <th>sulphates</th>\n",
       "      <th>alcohol</th>\n",
       "      <th>quality</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>7.4</td>\n",
       "      <td>0.70</td>\n",
       "      <td>0.00</td>\n",
       "      <td>1.9</td>\n",
       "      <td>0.076</td>\n",
       "      <td>11.0</td>\n",
       "      <td>34.0</td>\n",
       "      <td>0.9978</td>\n",
       "      <td>3.51</td>\n",
       "      <td>0.56</td>\n",
       "      <td>9.4</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>7.8</td>\n",
       "      <td>0.88</td>\n",
       "      <td>0.00</td>\n",
       "      <td>2.6</td>\n",
       "      <td>0.098</td>\n",
       "      <td>25.0</td>\n",
       "      <td>67.0</td>\n",
       "      <td>0.9968</td>\n",
       "      <td>3.20</td>\n",
       "      <td>0.68</td>\n",
       "      <td>9.8</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>7.8</td>\n",
       "      <td>0.76</td>\n",
       "      <td>0.04</td>\n",
       "      <td>2.3</td>\n",
       "      <td>0.092</td>\n",
       "      <td>15.0</td>\n",
       "      <td>54.0</td>\n",
       "      <td>0.9970</td>\n",
       "      <td>3.26</td>\n",
       "      <td>0.65</td>\n",
       "      <td>9.8</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>11.2</td>\n",
       "      <td>0.28</td>\n",
       "      <td>0.56</td>\n",
       "      <td>1.9</td>\n",
       "      <td>0.075</td>\n",
       "      <td>17.0</td>\n",
       "      <td>60.0</td>\n",
       "      <td>0.9980</td>\n",
       "      <td>3.16</td>\n",
       "      <td>0.58</td>\n",
       "      <td>9.8</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>7.4</td>\n",
       "      <td>0.70</td>\n",
       "      <td>0.00</td>\n",
       "      <td>1.9</td>\n",
       "      <td>0.076</td>\n",
       "      <td>11.0</td>\n",
       "      <td>34.0</td>\n",
       "      <td>0.9978</td>\n",
       "      <td>3.51</td>\n",
       "      <td>0.56</td>\n",
       "      <td>9.4</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   fixed acidity  volatile acidity  citric acid  residual sugar  chlorides  \\\n",
       "0            7.4              0.70         0.00             1.9      0.076   \n",
       "1            7.8              0.88         0.00             2.6      0.098   \n",
       "2            7.8              0.76         0.04             2.3      0.092   \n",
       "3           11.2              0.28         0.56             1.9      0.075   \n",
       "4            7.4              0.70         0.00             1.9      0.076   \n",
       "\n",
       "   free sulfur dioxide  total sulfur dioxide  density    pH  sulphates  \\\n",
       "0                 11.0                  34.0   0.9978  3.51       0.56   \n",
       "1                 25.0                  67.0   0.9968  3.20       0.68   \n",
       "2                 15.0                  54.0   0.9970  3.26       0.65   \n",
       "3                 17.0                  60.0   0.9980  3.16       0.58   \n",
       "4                 11.0                  34.0   0.9978  3.51       0.56   \n",
       "\n",
       "   alcohol  quality  \n",
       "0      9.4        0  \n",
       "1      9.8        0  \n",
       "2      9.8        0  \n",
       "3      9.8        1  \n",
       "4      9.4        0  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = pd.read_csv('https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv', sep=';')\n",
    "X['quality'] = np.where(X['quality']>5, 1, 0)\n",
    "print(X.shape)\n",
    "X.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Causal Discovery\n",
    "To run causal discovery, we create a `DirectLiNGAM` object and call the `fit` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
       " -->\r\n",
       "<!-- Title: %3 Pages: 1 -->\r\n",
       "<svg width=\"904pt\" height=\"566pt\"\r\n",
       " viewBox=\"0.00 0.00 903.79 566.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 562)\">\r\n",
       "<title>%3</title>\r\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-562 899.792,-562 899.792,4 -4,4\"/>\r\n",
       "<!-- 2. citric acid -->\r\n",
       "<g id=\"node1\" class=\"node\"><title>2. citric acid</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"347\" cy=\"-366\" rx=\"55.7903\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"347\" y=\"-362.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">2. citric acid</text>\r\n",
       "</g>\r\n",
       "<!-- 1. volatile acidity -->\r\n",
       "<g id=\"node2\" class=\"node\"><title>1. volatile acidity</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"202\" cy=\"-279\" rx=\"71.4873\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"202\" y=\"-275.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1. volatile acidity</text>\r\n",
       "</g>\r\n",
       "<!-- 2. citric acid&#45;&gt;1. volatile acidity -->\r\n",
       "<g id=\"edge1\" class=\"edge\"><title>2. citric acid&#45;&gt;1. volatile acidity</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M321.077,-349.804C297.905,-336.22 263.725,-316.183 237.921,-301.057\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"239.572,-297.968 229.175,-295.93 236.032,-304.007 239.572,-297.968\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"298.5\" y=\"-318.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;0.62</text>\r\n",
       "</g>\r\n",
       "<!-- 9. sulphates -->\r\n",
       "<g id=\"node10\" class=\"node\"><title>9. sulphates</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"186\" cy=\"-192\" rx=\"53.8905\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"186\" y=\"-188.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">9. sulphates</text>\r\n",
       "</g>\r\n",
       "<!-- 2. citric acid&#45;&gt;9. sulphates -->\r\n",
       "<g id=\"edge18\" class=\"edge\"><title>2. citric acid&#45;&gt;9. sulphates</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M348.192,-347.67C349.173,-318.817 346.719,-260.865 315,-228 313.541,-226.488 272.417,-215.478 236.991,-206.204\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"237.748,-202.784 227.188,-203.644 235.979,-209.557 237.748,-202.784\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"358.5\" y=\"-275.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.09</text>\r\n",
       "</g>\r\n",
       "<!-- 10. alcohol -->\r\n",
       "<g id=\"node11\" class=\"node\"><title>10. alcohol</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"400\" cy=\"-105\" rx=\"50.8918\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"400\" y=\"-101.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">10. alcohol</text>\r\n",
       "</g>\r\n",
       "<!-- 2. citric acid&#45;&gt;10. alcohol -->\r\n",
       "<g id=\"edge23\" class=\"edge\"><title>2. citric acid&#45;&gt;10. alcohol</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M355.485,-348.121C361.888,-334.627 370.339,-315.048 375,-297 389.764,-239.826 396.055,-170.658 398.537,-133.24\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"402.039,-133.327 399.166,-123.129 395.052,-132.893 402.039,-133.327\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"401.5\" y=\"-231.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.83</text>\r\n",
       "</g>\r\n",
       "<!-- 1. volatile acidity&#45;&gt;9. sulphates -->\r\n",
       "<g id=\"edge17\" class=\"edge\"><title>1. volatile acidity&#45;&gt;9. sulphates</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M198.762,-260.799C196.572,-249.163 193.633,-233.548 191.127,-220.237\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"194.523,-219.355 189.233,-210.175 187.643,-220.65 194.523,-219.355\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"209.5\" y=\"-231.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;0.20</text>\r\n",
       "</g>\r\n",
       "<!-- 1. volatile acidity&#45;&gt;10. alcohol -->\r\n",
       "<g id=\"edge22\" class=\"edge\"><title>1. volatile acidity&#45;&gt;10. alcohol</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M220.985,-261.508C256.647,-230.529 333.683,-163.608 374.081,-128.516\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"376.395,-131.142 381.649,-121.941 371.804,-125.857 376.395,-131.142\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"333.5\" y=\"-188.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.36</text>\r\n",
       "</g>\r\n",
       "<!-- 11. quality -->\r\n",
       "<g id=\"node12\" class=\"node\"><title>11. quality</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"88\" cy=\"-18\" rx=\"48.1917\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"88\" y=\"-14.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">11. quality</text>\r\n",
       "</g>\r\n",
       "<!-- 1. volatile acidity&#45;&gt;11. quality -->\r\n",
       "<g id=\"edge29\" class=\"edge\"><title>1. volatile acidity&#45;&gt;11. quality</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M168.135,-263.128C130.139,-244.352 70.4903,-208.089 45,-156 26.8072,-118.823 50.2997,-72.2429 69.0591,-44.0942\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"72.1011,-45.8496 74.9247,-35.638 66.3493,-41.8599 72.1011,-45.8496\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"59.5\" y=\"-144.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;0.50</text>\r\n",
       "</g>\r\n",
       "<!-- 4. chlorides -->\r\n",
       "<g id=\"node3\" class=\"node\"><title>4. chlorides</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"121\" cy=\"-453\" rx=\"53.0913\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"121\" y=\"-449.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">4. chlorides</text>\r\n",
       "</g>\r\n",
       "<!-- 4. chlorides&#45;&gt;2. citric acid -->\r\n",
       "<g id=\"edge6\" class=\"edge\"><title>4. chlorides&#45;&gt;2. citric acid</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M161.436,-441.234C183.93,-434.889 212.291,-426.296 237,-417 261.348,-407.84 287.918,-395.822 308.845,-385.864\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"310.535,-388.935 318.035,-381.451 307.505,-382.625 310.535,-388.935\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"283.5\" y=\"-405.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.70</text>\r\n",
       "</g>\r\n",
       "<!-- 4. chlorides&#45;&gt;1. volatile acidity -->\r\n",
       "<g id=\"edge2\" class=\"edge\"><title>4. chlorides&#45;&gt;1. volatile acidity</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M127.137,-434.756C134.759,-413.867 148.534,-377.775 163,-348 169.987,-333.619 178.775,-318.116 186.239,-305.549\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"189.241,-307.35 191.396,-296.976 183.242,-303.742 189.241,-307.35\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"175.5\" y=\"-362.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.51</text>\r\n",
       "</g>\r\n",
       "<!-- 8. pH -->\r\n",
       "<g id=\"node9\" class=\"node\"><title>8. pH</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"452\" cy=\"-366\" rx=\"31.3957\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"452\" y=\"-362.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">8. pH</text>\r\n",
       "</g>\r\n",
       "<!-- 4. chlorides&#45;&gt;8. pH -->\r\n",
       "<g id=\"edge15\" class=\"edge\"><title>4. chlorides&#45;&gt;8. pH</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M167.62,-444.117C203.897,-437.686 255.4,-427.906 300,-417 325.204,-410.837 387.575,-392.756 412,-384 414.217,-383.205 416.489,-382.351 418.768,-381.465\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"420.108,-384.698 428.065,-377.703 417.482,-378.21 420.108,-384.698\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"368.5\" y=\"-405.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;0.76</text>\r\n",
       "</g>\r\n",
       "<!-- 4. chlorides&#45;&gt;9. sulphates -->\r\n",
       "<g id=\"edge19\" class=\"edge\"><title>4. chlorides&#45;&gt;9. sulphates</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M117.056,-434.862C109.985,-400.437 98.0052,-321.046 122,-261 129.141,-243.13 143.42,-227.212 156.682,-215.278\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"159.004,-217.897 164.311,-208.727 154.444,-212.586 159.004,-217.897\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"122.5\" y=\"-318.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.24</text>\r\n",
       "</g>\r\n",
       "<!-- 4. chlorides&#45;&gt;10. alcohol -->\r\n",
       "<g id=\"edge25\" class=\"edge\"><title>4. chlorides&#45;&gt;10. alcohol</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M114.216,-435.057C112.071,-429.37 109.794,-422.959 108,-417 92.2119,-364.559 86.2633,-351.407 80,-297 73.377,-239.469 79.9081,-212.688 123,-174 155.09,-145.189 271.749,-124.104 343.039,-113.569\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"343.819,-116.992 353.212,-112.091 342.813,-110.065 343.819,-116.992\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"94.5\" y=\"-275.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;1.46</text>\r\n",
       "</g>\r\n",
       "<!-- 4. chlorides&#45;&gt;11. quality -->\r\n",
       "<g id=\"edge30\" class=\"edge\"><title>4. chlorides&#45;&gt;11. quality</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M79.6971,-441.443C44.6443,-429.705 0,-406.92 0,-367 0,-367 0,-367 0,-104 0,-74.4294 26.024,-51.8835 49.7679,-37.4669\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"51.7665,-40.3553 58.6973,-32.3419 48.282,-34.2842 51.7665,-40.3553\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"14.5\" y=\"-231.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;0.63</text>\r\n",
       "</g>\r\n",
       "<!-- 7. density -->\r\n",
       "<g id=\"node4\" class=\"node\"><title>7. density</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"271\" cy=\"-540\" rx=\"46.2923\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"271\" y=\"-536.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">7. density</text>\r\n",
       "</g>\r\n",
       "<!-- 7. density&#45;&gt;2. citric acid -->\r\n",
       "<g id=\"edge7\" class=\"edge\"><title>7. density&#45;&gt;2. citric acid</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M275.998,-522.063C282.394,-501.172 294.301,-464.77 308,-435 314.684,-420.475 323.447,-404.96 330.973,-392.423\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"333.966,-394.237 336.187,-383.877 327.991,-390.591 333.966,-394.237\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"326\" y=\"-449.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;29.64</text>\r\n",
       "</g>\r\n",
       "<!-- 7. density&#45;&gt;1. volatile acidity -->\r\n",
       "<g id=\"edge3\" class=\"edge\"><title>7. density&#45;&gt;1. volatile acidity</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M251.14,-523.407C244.998,-517.81 238.636,-511.111 234,-504 211.583,-469.614 210.139,-457.423 203,-417 196.354,-379.366 197.656,-335.006 199.548,-307.187\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"203.051,-307.273 200.315,-297.038 196.071,-306.746 203.051,-307.273\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"218.5\" y=\"-405.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">17.23</text>\r\n",
       "</g>\r\n",
       "<!-- 7. density&#45;&gt;4. chlorides -->\r\n",
       "<g id=\"edge9\" class=\"edge\"><title>7. density&#45;&gt;4. chlorides</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M230.119,-531.442C208.825,-526.153 183.107,-517.544 163,-504 153.309,-497.473 144.614,-488.097 137.684,-479.225\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"140.357,-476.952 131.604,-470.982 134.723,-481.107 140.357,-476.952\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"175.5\" y=\"-492.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">4.11</text>\r\n",
       "</g>\r\n",
       "<!-- 7. density&#45;&gt;8. pH -->\r\n",
       "<g id=\"edge16\" class=\"edge\"><title>7. density&#45;&gt;8. pH</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M288.271,-523.232C303.165,-509.557 325.182,-489.163 344,-471 360.281,-455.286 363.773,-450.77 380,-435 396.125,-419.329 414.566,-401.922 428.747,-388.644\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"431.23,-391.114 436.146,-381.729 426.45,-386 431.23,-391.114\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"395.5\" y=\"-449.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">27.46</text>\r\n",
       "</g>\r\n",
       "<!-- 7. density&#45;&gt;9. sulphates -->\r\n",
       "<g id=\"edge20\" class=\"edge\"><title>7. density&#45;&gt;9. sulphates</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M267.116,-522.021C260.035,-488.527 246.668,-411.67 257,-348 259.476,-332.74 260.705,-328.631 268,-315 272.782,-306.064 278.711,-306.586 282,-297 287.193,-281.866 289.349,-275.212 282,-261 270.529,-238.817 248.115,-222.312 227.829,-211.121\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"229.188,-207.882 218.704,-206.352 225.946,-214.086 229.188,-207.882\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"269.5\" y=\"-362.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">4.69</text>\r\n",
       "</g>\r\n",
       "<!-- 7. density&#45;&gt;10. alcohol -->\r\n",
       "<g id=\"edge26\" class=\"edge\"><title>7. density&#45;&gt;10. alcohol</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M314.792,-534.246C383.087,-525.828 509.958,-505.824 537,-471 568.991,-429.802 483.635,-228.707 430,-141 427.745,-137.313 425.107,-133.624 422.347,-130.098\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"424.898,-127.692 415.811,-122.245 419.518,-132.17 424.898,-127.692\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"539.5\" y=\"-318.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;617.38</text>\r\n",
       "</g>\r\n",
       "<!-- 0. fixed acidity -->\r\n",
       "<g id=\"node5\" class=\"node\"><title>0. fixed acidity</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"589\" cy=\"-540\" rx=\"64.189\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"589\" y=\"-536.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0. fixed acidity</text>\r\n",
       "</g>\r\n",
       "<!-- 0. fixed acidity&#45;&gt;2. citric acid -->\r\n",
       "<g id=\"edge4\" class=\"edge\"><title>0. fixed acidity&#45;&gt;2. citric acid</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M561.497,-523.407C538.669,-510.169 505.594,-490.299 478,-471 440.638,-444.869 399.829,-411.671 373.856,-389.904\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"375.871,-387.025 365.967,-383.262 371.362,-392.38 375.871,-387.025\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"490.5\" y=\"-449.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.09</text>\r\n",
       "</g>\r\n",
       "<!-- 3. residual sugar -->\r\n",
       "<g id=\"node6\" class=\"node\"><title>3. residual sugar</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"687\" cy=\"-453\" rx=\"69.5877\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"687\" y=\"-449.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">3. residual sugar</text>\r\n",
       "</g>\r\n",
       "<!-- 0. fixed acidity&#45;&gt;3. residual sugar -->\r\n",
       "<g id=\"edge8\" class=\"edge\"><title>0. fixed acidity&#45;&gt;3. residual sugar</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M607.895,-522.611C622.809,-509.676 643.796,-491.473 660.406,-477.066\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"662.902,-479.535 668.163,-470.339 658.315,-474.247 662.902,-479.535\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"656.5\" y=\"-492.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.09</text>\r\n",
       "</g>\r\n",
       "<!-- 5. free sulfur dioxide -->\r\n",
       "<g id=\"node7\" class=\"node\"><title>5. free sulfur dioxide</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"810\" cy=\"-366\" rx=\"85.5853\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"810\" y=\"-362.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">5. free sulfur dioxide</text>\r\n",
       "</g>\r\n",
       "<!-- 0. fixed acidity&#45;&gt;5. free sulfur dioxide -->\r\n",
       "<g id=\"edge10\" class=\"edge\"><title>0. fixed acidity&#45;&gt;5. free sulfur dioxide</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M642.641,-529.983C681.257,-520.986 732.28,-503.611 766,-471 787.509,-450.199 799.011,-417.234 804.792,-394.004\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"808.241,-394.623 807.061,-384.094 801.418,-393.06 808.241,-394.623\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"803.5\" y=\"-449.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;1.07</text>\r\n",
       "</g>\r\n",
       "<!-- 0. fixed acidity&#45;&gt;8. pH -->\r\n",
       "<g id=\"edge14\" class=\"edge\"><title>0. fixed acidity&#45;&gt;8. pH</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M575.614,-522.195C551.055,-491.362 498.75,-425.693 470.697,-390.473\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"473.412,-388.265 464.444,-382.623 467.937,-392.626 473.412,-388.265\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"549.5\" y=\"-449.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;0.08</text>\r\n",
       "</g>\r\n",
       "<!-- 0. fixed acidity&#45;&gt;10. alcohol -->\r\n",
       "<g id=\"edge21\" class=\"edge\"><title>0. fixed acidity&#45;&gt;10. alcohol</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M589,-521.744C589,-504.57 589,-477.464 589,-454 589,-454 589,-454 589,-191 589,-133.59 515.57,-114.725 460.249,-108.658\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"460.446,-105.161 450.152,-107.668 459.763,-112.127 460.446,-105.161\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"601.5\" y=\"-318.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.53</text>\r\n",
       "</g>\r\n",
       "<!-- 3. residual sugar&#45;&gt;2. citric acid -->\r\n",
       "<g id=\"edge5\" class=\"edge\"><title>3. residual sugar&#45;&gt;2. citric acid</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M638.02,-440.098C582.984,-426.676 491.001,-404.097 412,-384 407.687,-382.903 403.215,-381.755 398.728,-380.596\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"399.496,-377.179 388.937,-378.056 397.738,-383.955 399.496,-377.179\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"552.5\" y=\"-405.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.02</text>\r\n",
       "</g>\r\n",
       "<!-- 3. residual sugar&#45;&gt;5. free sulfur dioxide -->\r\n",
       "<g id=\"edge11\" class=\"edge\"><title>3. residual sugar&#45;&gt;5. free sulfur dioxide</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M710.135,-436.012C729.218,-422.825 756.506,-403.967 777.737,-389.296\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"779.892,-392.06 786.129,-383.496 775.913,-386.302 779.892,-392.06\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"767.5\" y=\"-405.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.54</text>\r\n",
       "</g>\r\n",
       "<!-- 6. total sulfur dioxide -->\r\n",
       "<g id=\"node8\" class=\"node\"><title>6. total sulfur dioxide</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"777\" cy=\"-279\" rx=\"87.9851\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"777\" y=\"-275.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">6. total sulfur dioxide</text>\r\n",
       "</g>\r\n",
       "<!-- 3. residual sugar&#45;&gt;6. total sulfur dioxide -->\r\n",
       "<g id=\"edge12\" class=\"edge\"><title>3. residual sugar&#45;&gt;6. total sulfur dioxide</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M683.229,-434.712C679.446,-413.143 675.982,-375.726 690,-348 699.87,-328.478 718.227,-312.849 735.543,-301.521\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"737.802,-304.238 744.453,-295.992 734.111,-298.291 737.802,-304.238\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"702.5\" y=\"-362.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.44</text>\r\n",
       "</g>\r\n",
       "<!-- 3. residual sugar&#45;&gt;10. alcohol -->\r\n",
       "<g id=\"edge24\" class=\"edge\"><title>3. residual sugar&#45;&gt;10. alcohol</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M669.395,-435.577C654.31,-419.647 635,-393.987 635,-367 635,-367 635,-367 635,-191 635,-164.616 624.98,-155.593 603,-141 579.886,-125.654 510.222,-116.053 458.924,-110.871\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"459.221,-107.383 448.928,-109.892 458.54,-114.35 459.221,-107.383\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"647.5\" y=\"-275.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.28</text>\r\n",
       "</g>\r\n",
       "<!-- 5. free sulfur dioxide&#45;&gt;6. total sulfur dioxide -->\r\n",
       "<g id=\"edge13\" class=\"edge\"><title>5. free sulfur dioxide&#45;&gt;6. total sulfur dioxide</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M803.322,-347.799C798.76,-336.047 792.622,-320.238 787.421,-306.842\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"790.55,-305.231 783.668,-297.175 784.025,-307.764 790.55,-305.231\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"807.5\" y=\"-318.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.99</text>\r\n",
       "</g>\r\n",
       "<!-- 8. pH&#45;&gt;10. alcohol -->\r\n",
       "<g id=\"edge27\" class=\"edge\"><title>8. pH&#45;&gt;10. alcohol</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M450.349,-347.868C446.429,-310.003 435.26,-216.561 414,-141 413.202,-138.165 412.253,-135.245 411.229,-132.358\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"414.426,-130.921 407.576,-122.838 407.891,-133.429 414.426,-130.921\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"448.5\" y=\"-231.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">3.76</text>\r\n",
       "</g>\r\n",
       "<!-- 9. sulphates&#45;&gt;10. alcohol -->\r\n",
       "<g id=\"edge28\" class=\"edge\"><title>9. sulphates&#45;&gt;10. alcohol</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M198.913,-174.09C208.288,-163.002 221.93,-149.106 237,-141 268.741,-123.927 308.17,-115.182 340.297,-110.703\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"340.982,-114.143 350.452,-109.393 340.087,-107.201 340.982,-114.143\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"249.5\" y=\"-144.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.25</text>\r\n",
       "</g>\r\n",
       "<!-- 9. sulphates&#45;&gt;11. quality -->\r\n",
       "<g id=\"edge31\" class=\"edge\"><title>9. sulphates&#45;&gt;11. quality</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M176.425,-174.195C159.226,-144.01 123.006,-80.4388 102.662,-44.7341\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"105.606,-42.8294 97.614,-35.8735 99.5235,-46.2948 105.606,-42.8294\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"159.5\" y=\"-101.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.44</text>\r\n",
       "</g>\r\n",
       "<!-- 10. alcohol&#45;&gt;11. quality -->\r\n",
       "<g id=\"edge32\" class=\"edge\"><title>10. alcohol&#45;&gt;11. quality</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M361.116,-93.4066C304.017,-77.8508 197.588,-48.8557 135.796,-32.0214\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"136.476,-28.5791 125.908,-29.3274 134.636,-35.333 136.476,-28.5791\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"275.5\" y=\"-57.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.15</text>\r\n",
       "</g>\r\n",
       "</g>\r\n",
       "</svg>\r\n"
      ],
      "text/plain": [
       "<graphviz.dot.Digraph at 0x22d72bbb4e0>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pk = make_prior_knowledge(\n",
    "    n_variables=len(X.columns),\n",
    "    sink_variables=[11])\n",
    "\n",
    "model = lingam.DirectLiNGAM(prior_knowledge=pk)\n",
    "model.fit(X)\n",
    "labels = [f'{i}. {col}' for i, col in enumerate(X.columns)]\n",
    "make_graph(model.adjacency_matrix_, labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prediction Model\n",
    "We create a logistic regression model because the target (`quality`) is a discrete variable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n",
       "          intercept_scaling=1, max_iter=100, multi_class='warn',\n",
       "          n_jobs=None, penalty='l2', random_state=None, solver='liblinear',\n",
       "          tol=0.0001, verbose=0, warm_start=False)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "target = 11 # quality\n",
    "features = [i for i in range(X.shape[1]) if i != target]\n",
    "reg = LogisticRegression(solver='liblinear')\n",
    "reg.fit(X.iloc[:, features], X.iloc[:, target])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Identification of the Feature with the Greatest Causal Influence on Prediction\n",
    "To identify the feature with the greatest intervention effect on the prediction, we create a `CausalEffect` object and call its `estimate_effects_on_prediction` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>feature</th>\n",
       "      <th>effect_plus</th>\n",
       "      <th>effect_minus</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>fixed acidity</td>\n",
       "      <td>0.108700</td>\n",
       "      <td>0.111510</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>volatile acidity</td>\n",
       "      <td>0.283712</td>\n",
       "      <td>0.266202</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>citric acid</td>\n",
       "      <td>0.204237</td>\n",
       "      <td>0.214388</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>residual sugar</td>\n",
       "      <td>0.013498</td>\n",
       "      <td>0.013540</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>chlorides</td>\n",
       "      <td>0.013722</td>\n",
       "      <td>0.013679</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>free sulfur dioxide</td>\n",
       "      <td>0.088231</td>\n",
       "      <td>0.086463</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>total sulfur dioxide</td>\n",
       "      <td>0.403675</td>\n",
       "      <td>0.369129</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>density</td>\n",
       "      <td>0.542298</td>\n",
       "      <td>0.481731</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>pH</td>\n",
       "      <td>0.131429</td>\n",
       "      <td>0.135560</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>sulphates</td>\n",
       "      <td>0.242358</td>\n",
       "      <td>0.256786</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>alcohol</td>\n",
       "      <td>0.411940</td>\n",
       "      <td>0.455436</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>quality</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 feature  effect_plus  effect_minus\n",
       "0          fixed acidity     0.108700      0.111510\n",
       "1       volatile acidity     0.283712      0.266202\n",
       "2            citric acid     0.204237      0.214388\n",
       "3         residual sugar     0.013498      0.013540\n",
       "4              chlorides     0.013722      0.013679\n",
       "5    free sulfur dioxide     0.088231      0.086463\n",
       "6   total sulfur dioxide     0.403675      0.369129\n",
       "7                density     0.542298      0.481731\n",
       "8                     pH     0.131429      0.135560\n",
       "9              sulphates     0.242358      0.256786\n",
       "10               alcohol     0.411940      0.455436\n",
       "11               quality     0.000000      0.000000"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ce = lingam.CausalEffect(model)\n",
    "effects = ce.estimate_effects_on_prediction(X, target, reg)\n",
    "\n",
    "df_effects = pd.DataFrame()\n",
    "df_effects['feature'] = X.columns\n",
    "df_effects['effect_plus'] = effects[:, 0]\n",
    "df_effects['effect_minus'] = effects[:, 1]\n",
    "df_effects"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "density\n"
     ]
    }
   ],
   "source": [
    "max_index = np.unravel_index(np.argmax(effects), effects.shape)\n",
    "print(X.columns[max_index[0]])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Estimation of Optimal Intervention\n",
    "The `estimate_optimal_intervention` method of `CausalEffect` is available only for linear regression models, so it cannot be applied to the logistic regression model used in this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
