{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Causal Effect"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import and settings\n",
    "In this example, we need to import `numpy`, `pandas`, and `graphviz` in addition to `lingam`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['1.16.2', '0.24.2', '0.11.1', '1.2.0']\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import graphviz\n",
    "import lingam\n",
    "\n",
    "print([np.__version__, pd.__version__, graphviz.__version__, lingam.__version__])\n",
    "\n",
    "np.set_printoptions(precision=3, suppress=True)\n",
    "np.random.seed(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Utility function\n",
    "We define a utility function to draw the directed acyclic graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_graph(adjacency_matrix, labels=None):\n",
    "    idx = np.abs(adjacency_matrix) > 0.01\n",
    "    dirs = np.where(idx)\n",
    "    d = graphviz.Digraph(engine='dot')\n",
    "    names = labels if labels else [f'x{i}' for i in range(len(adjacency_matrix))]\n",
    "    for to, from_, coef in zip(dirs[0], dirs[1], adjacency_matrix[idx]):\n",
    "        d.edge(names[from_], names[to], label=f'{coef:.2f}')\n",
    "    return d"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test data\n",
    "We use 'Auto MPG Data Set' (http://archive.ics.uci.edu/ml/datasets/Auto+MPG) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(392, 6)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>mpg</th>\n",
       "      <th>cylinders</th>\n",
       "      <th>displacement</th>\n",
       "      <th>horsepower</th>\n",
       "      <th>weight</th>\n",
       "      <th>acceleration</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>18.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>307.0</td>\n",
       "      <td>130.0</td>\n",
       "      <td>3504.0</td>\n",
       "      <td>12.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>15.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>350.0</td>\n",
       "      <td>165.0</td>\n",
       "      <td>3693.0</td>\n",
       "      <td>11.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>18.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>318.0</td>\n",
       "      <td>150.0</td>\n",
       "      <td>3436.0</td>\n",
       "      <td>11.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>16.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>304.0</td>\n",
       "      <td>150.0</td>\n",
       "      <td>3433.0</td>\n",
       "      <td>12.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>17.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>302.0</td>\n",
       "      <td>140.0</td>\n",
       "      <td>3449.0</td>\n",
       "      <td>10.5</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    mpg  cylinders  displacement  horsepower  weight  acceleration\n",
       "0  18.0        8.0         307.0       130.0  3504.0          12.0\n",
       "1  15.0        8.0         350.0       165.0  3693.0          11.5\n",
       "2  18.0        8.0         318.0       150.0  3436.0          11.0\n",
       "3  16.0        8.0         304.0       150.0  3433.0          12.0\n",
       "4  17.0        8.0         302.0       140.0  3449.0          10.5"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = pd.read_csv('http://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data-original',\n",
    "                   delim_whitespace=True, header=None,\n",
    "                   names = ['mpg', 'cylinders', 'displacement',\n",
    "                            'horsepower', 'weight', 'acceleration',\n",
    "                            'model year', 'origin', 'car name'])\n",
    "X.dropna(inplace=True)\n",
    "X.drop(['model year', 'origin', 'car name'], axis=1, inplace=True)\n",
    "print(X.shape)\n",
    "X.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Causal Discovery\n",
    "To run causal discovery, we create a `DirectLiNGAM` object and call the `fit` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
       " -->\r\n",
       "<!-- Title: %3 Pages: 1 -->\r\n",
       "<svg width=\"298pt\" height=\"479pt\"\r\n",
       " viewBox=\"0.00 0.00 298.00 479.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 475)\">\r\n",
       "<title>%3</title>\r\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-475 294,-475 294,4 -4,4\"/>\r\n",
       "<!-- 1. cylinders -->\r\n",
       "<g id=\"node1\" class=\"node\"><title>1. cylinders</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"141\" cy=\"-453\" rx=\"52.7911\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"141\" y=\"-449.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1. cylinders</text>\r\n",
       "</g>\r\n",
       "<!-- 0. mpg -->\r\n",
       "<g id=\"node2\" class=\"node\"><title>0. mpg</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"63\" cy=\"-366\" rx=\"36.2938\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"63\" y=\"-362.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0. mpg</text>\r\n",
       "</g>\r\n",
       "<!-- 1. cylinders&#45;&gt;0. mpg -->\r\n",
       "<g id=\"edge1\" class=\"edge\"><title>1. cylinders&#45;&gt;0. mpg</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M119.421,-436.354C112.231,-430.656 104.424,-423.907 98,-417 91.0048,-409.479 84.2888,-400.484 78.6435,-392.211\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"81.4041,-390.039 72.9755,-383.62 75.5612,-393.894 81.4041,-390.039\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"112.5\" y=\"-405.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;3.56</text>\r\n",
       "</g>\r\n",
       "<!-- 2. displacement -->\r\n",
       "<g id=\"node3\" class=\"node\"><title>2. displacement</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"200\" cy=\"-192\" rx=\"67.6881\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"200\" y=\"-188.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">2. displacement</text>\r\n",
       "</g>\r\n",
       "<!-- 1. cylinders&#45;&gt;2. displacement -->\r\n",
       "<g id=\"edge2\" class=\"edge\"><title>1. cylinders&#45;&gt;2. displacement</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M153.216,-435.324C162.227,-422.099 173.933,-402.778 180,-384 198.147,-327.837 200.76,-258.016 200.645,-220.282\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"204.143,-220.054 200.544,-210.089 197.144,-220.124 204.143,-220.054\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"210.5\" y=\"-318.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">33.49</text>\r\n",
       "</g>\r\n",
       "<!-- 4. weight -->\r\n",
       "<g id=\"node4\" class=\"node\"><title>4. weight</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"132\" cy=\"-279\" rx=\"43.5923\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"132\" y=\"-275.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">4. weight</text>\r\n",
       "</g>\r\n",
       "<!-- 1. cylinders&#45;&gt;4. weight -->\r\n",
       "<g id=\"edge9\" class=\"edge\"><title>1. cylinders&#45;&gt;4. weight</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M140.104,-434.879C138.541,-405.001 135.302,-343.113 133.427,-307.274\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"136.91,-306.855 132.892,-297.052 129.92,-307.221 136.91,-306.855\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"157\" y=\"-362.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">315.37</text>\r\n",
       "</g>\r\n",
       "<!-- 0. mpg&#45;&gt;4. weight -->\r\n",
       "<g id=\"edge8\" class=\"edge\"><title>0. mpg&#45;&gt;4. weight</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M70.4791,-348.316C75.4836,-338.127 82.6676,-325.151 91,-315 94.8041,-310.366 99.2717,-305.859 103.833,-301.707\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"106.213,-304.276 111.496,-295.092 101.639,-298.977 106.213,-304.276\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"109\" y=\"-318.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;36.98</text>\r\n",
       "</g>\r\n",
       "<!-- 3. horsepower -->\r\n",
       "<g id=\"node5\" class=\"node\"><title>3. horsepower</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"128\" cy=\"-18\" rx=\"63.0888\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"128\" y=\"-14.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">3. horsepower</text>\r\n",
       "</g>\r\n",
       "<!-- 0. mpg&#45;&gt;3. horsepower -->\r\n",
       "<g id=\"edge4\" class=\"edge\"><title>0. mpg&#45;&gt;3. horsepower</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M43.4117,-350.789C24.918,-335.568 0,-309.573 0,-280 0,-280 0,-280 0,-104 0,-67.3032 36.9619,-45.5007 71.1614,-33.1819\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"72.3503,-36.4747 80.7012,-29.9544 70.1069,-29.8439 72.3503,-36.4747\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"14.5\" y=\"-188.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;0.39</text>\r\n",
       "</g>\r\n",
       "<!-- 5. acceleration -->\r\n",
       "<g id=\"node6\" class=\"node\"><title>5. acceleration</title>\r\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"188\" cy=\"-105\" rx=\"63.0888\" ry=\"18\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"188\" y=\"-101.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">5. acceleration</text>\r\n",
       "</g>\r\n",
       "<!-- 0. mpg&#45;&gt;5. acceleration -->\r\n",
       "<g id=\"edge10\" class=\"edge\"><title>0. mpg&#45;&gt;5. acceleration</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M61.5957,-347.794C59.2189,-306.853 58.9694,-202.475 111,-141 117.475,-133.35 125.97,-127.292 134.948,-122.519\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"136.526,-125.644 144.03,-118.165 133.5,-119.332 136.526,-125.644\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"83.5\" y=\"-231.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.04</text>\r\n",
       "</g>\r\n",
       "<!-- 2. displacement&#45;&gt;3. horsepower -->\r\n",
       "<g id=\"edge5\" class=\"edge\"><title>2. displacement&#45;&gt;3. horsepower</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M221.169,-174.644C245.164,-154.036 278.665,-117.612 260,-87 243.613,-60.1251 212.829,-43.4172 185.109,-33.2917\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"186.055,-29.9161 175.46,-29.9788 183.782,-36.5367 186.055,-29.9161\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"277.5\" y=\"-101.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.10</text>\r\n",
       "</g>\r\n",
       "<!-- 2. displacement&#45;&gt;5. acceleration -->\r\n",
       "<g id=\"edge11\" class=\"edge\"><title>2. displacement&#45;&gt;5. acceleration</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M197.572,-173.799C195.929,-162.163 193.724,-146.548 191.845,-133.237\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"195.288,-132.588 190.425,-123.175 188.357,-133.567 195.288,-132.588\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"209.5\" y=\"-144.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;0.03</text>\r\n",
       "</g>\r\n",
       "<!-- 4. weight&#45;&gt;2. displacement -->\r\n",
       "<g id=\"edge3\" class=\"edge\"><title>4. weight&#45;&gt;2. displacement</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M143.205,-261.429C150.085,-251.517 159.231,-238.792 168,-228 170.882,-224.454 174.018,-220.8 177.163,-217.252\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"179.849,-219.5 183.966,-209.737 174.66,-214.802 179.849,-219.5\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"180.5\" y=\"-231.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.05</text>\r\n",
       "</g>\r\n",
       "<!-- 4. weight&#45;&gt;3. horsepower -->\r\n",
       "<g id=\"edge6\" class=\"edge\"><title>4. weight&#45;&gt;3. horsepower</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M128.6,-260.875C122.399,-227.443 110.384,-151.143 116,-87 117.181,-73.5146 119.655,-58.6952 122.051,-46.3426\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"125.547,-46.7138 124.103,-36.2179 118.687,-45.3235 125.547,-46.7138\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"128.5\" y=\"-144.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.02</text>\r\n",
       "</g>\r\n",
       "<!-- 5. acceleration&#45;&gt;3. horsepower -->\r\n",
       "<g id=\"edge7\" class=\"edge\"><title>5. acceleration&#45;&gt;3. horsepower</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M176.146,-87.2067C167.435,-74.8663 155.455,-57.8941 145.624,-43.9675\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"148.456,-41.9097 139.829,-35.7584 142.737,-45.9465 148.456,-41.9097\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"176.5\" y=\"-57.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;4.66</text>\r\n",
       "</g>\r\n",
       "</g>\r\n",
       "</svg>\r\n"
      ],
      "text/plain": [
       "<graphviz.dot.Digraph at 0x1aad0e60518>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = lingam.DirectLiNGAM()\n",
    "model.fit(X)\n",
    "labels = [f'{i}. {col}' for i, col in enumerate(X.columns)]\n",
    "make_graph(model.adjacency_matrix_, labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prediction Model\n",
    "We create the linear regression model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LassoCV(alphas=None, copy_X=True, cv=5, eps=0.001, fit_intercept=True,\n",
       "    max_iter=1000, n_alphas=100, n_jobs=None, normalize=False,\n",
       "    positive=False, precompute='auto', random_state=0, selection='cyclic',\n",
       "    tol=0.0001, verbose=False)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.linear_model import LassoCV\n",
    "\n",
    "target = 0 # mpg\n",
    "features = [i for i in range(X.shape[1]) if i != target]\n",
    "reg = LassoCV(cv=5, random_state=0)\n",
    "reg.fit(X.iloc[:, features], X.iloc[:, target])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Identification of Feature with Greatest Causal Influence on Prediction\n",
    "To identify of the feature having the greatest intervention effect on the prediction, we create a `CausalEffect` object and call the `estimate_effects_on_prediction` method. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>feature</th>\n",
       "      <th>effect_plus</th>\n",
       "      <th>effect_minus</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>mpg</td>\n",
       "      <td>2.040563</td>\n",
       "      <td>2.040563</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>cylinders</td>\n",
       "      <td>5.872194</td>\n",
       "      <td>5.872194</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>displacement</td>\n",
       "      <td>1.315913</td>\n",
       "      <td>1.315913</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>horsepower</td>\n",
       "      <td>1.047548</td>\n",
       "      <td>1.047548</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>weight</td>\n",
       "      <td>5.632157</td>\n",
       "      <td>5.632157</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>acceleration</td>\n",
       "      <td>0.350074</td>\n",
       "      <td>0.350074</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        feature  effect_plus  effect_minus\n",
       "0           mpg     2.040563      2.040563\n",
       "1     cylinders     5.872194      5.872194\n",
       "2  displacement     1.315913      1.315913\n",
       "3    horsepower     1.047548      1.047548\n",
       "4        weight     5.632157      5.632157\n",
       "5  acceleration     0.350074      0.350074"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ce = lingam.CausalEffect(model)\n",
    "effects = ce.estimate_effects_on_prediction(X, target, reg)\n",
    "\n",
    "df_effects = pd.DataFrame()\n",
    "df_effects['feature'] = X.columns\n",
    "df_effects['effect_plus'] = effects[:, 0]\n",
    "df_effects['effect_minus'] = effects[:, 1]\n",
    "df_effects"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cylinders\n"
     ]
    }
   ],
   "source": [
    "max_index = np.unravel_index(np.argmax(effects), effects.shape)\n",
    "print(X.columns[max_index[0]])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Estimation of Optimal Intervention\n",
    "To estimate of the intervention such that the expectation of the prediction of the post-intervention observations is equal or close to a specified value, we use `estimate_optimal_intervention` method of `CausalEffect`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimal intervention: 7.922\n"
     ]
    }
   ],
   "source": [
    "# mpg = 15\n",
    "c = ce.estimate_optimal_intervention(X, target, reg, 1, 15)\n",
    "print(f'Optimal intervention: {c:.3f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimal intervention: 6.182\n"
     ]
    }
   ],
   "source": [
    "# mpg = 21\n",
    "c = ce.estimate_optimal_intervention(X, target, reg, 1, 21)\n",
    "print(f'Optimal intervention: {c:.3f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimal intervention: 3.571\n"
     ]
    }
   ],
   "source": [
    "# mpg = 30\n",
    "c = ce.estimate_optimal_intervention(X, target, reg, 1, 30)\n",
    "print(f'Optimal intervention: {c:.3f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
