{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.3"
    },
    "toc": {
      "base_numbering": 1,
      "nav_menu": {},
      "number_sections": true,
      "sideBar": true,
      "skip_h1_title": false,
      "title_cell": "Table of Contents",
      "title_sidebar": "Contents",
      "toc_cell": false,
      "toc_position": {},
      "toc_section_display": true,
      "toc_window_display": false
    },
    "colab": {
      "name": "Example to handle groups of variables with known causal orders btw the groups.ipynb",
      "provenance": [],
      "collapsed_sections": []
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XPXL0jCBHEaI"
      },
      "source": [
        "# Example to handle groups of varibles with known causal orders btw the groups\n",
        "The third group does not cause the first and second. The second group \n",
        "does not cause the first. This example estimates the causal structure of variables of the second group. Further, it estimates the causal structure of variables of the third group if all the variables of the third are continuous or the third group has only a binary variable."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nbFjt3O3HEaJ"
      },
      "source": [
        "## Import necessary packages"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2019-09-09T02:01:39.097825Z",
          "start_time": "2019-09-09T02:01:33.841227Z"
        },
        "id": "VR-ty4wEHEaK",
        "outputId": "e9a1f611-2e67-4a45-aa96-a90f090f24bd"
      },
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "import graphviz\n",
        "from sklearn.linear_model import LinearRegression, LogisticRegression\n",
        "import lingam\n",
        "from lingam.utils import make_dot, remove_effect, predict_adaptive_lasso\n",
        "\n",
        "from IPython.display import display_svg, SVG\n",
        "\n",
        "print([np.__version__, pd.__version__, graphviz.__version__, lingam.__version__])"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['1.16.2', '0.24.2', '0.11.1', '1.4.0']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "UMx4duoRHEaR"
      },
      "source": [
        "# Output setting\n",
        "np.set_printoptions(precision=3, suppress=True)\n",
        "np.random.seed(0)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aYR8xMgBHEaV"
      },
      "source": [
        "## Indicate which variable belong to which group"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TCjCStbCHEaW"
      },
      "source": [
        "# Dataset to be analyzed\n",
        "input_file = 'data1.csv'\n",
        "\n",
        "# Indicate which variable belong to which group\n",
        "set1_labels = ['x1_1', 'x1_2']\n",
        "set2_labels = ['x2_1', 'x2_2', 'x2_3', 'x2_4', 'x2_5']\n",
        "set3_labels = ['x3_1', 'x3_2', 'x3_3']\n",
        "labels = set1_labels + set2_labels + set3_labels"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I9SQSWGpHEaZ"
      },
      "source": [
        "## Load data"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4h2bZZK1HEaZ",
        "outputId": "e039692d-8b37-492c-f1cf-98e91f0037c7"
      },
      "source": [
        "X = pd.read_csv(input_file)\n",
        "\n",
        "# Check if the third group includes a binary data\n",
        "contain_bin_var_in_set3 = False\n",
        "for set3_label in set3_labels:\n",
        "    if len(np.unique(X[set3_label])) == 2:\n",
        "        X[set3_label] = X[set3_label].astype(int)\n",
        "        contain_bin_var_in_set3 = True\n",
        "\n",
        "if contain_bin_var_in_set3:\n",
        "    print('\\x1b[31mThe third group includes a binary variable.\\x1b[0m')\n",
        "\n",
        "X.head()"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>x1_1</th>\n",
              "      <th>x1_2</th>\n",
              "      <th>x2_1</th>\n",
              "      <th>x2_2</th>\n",
              "      <th>x2_3</th>\n",
              "      <th>x2_4</th>\n",
              "      <th>x2_5</th>\n",
              "      <th>x3_1</th>\n",
              "      <th>x3_2</th>\n",
              "      <th>x3_3</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>0.102727</td>\n",
              "      <td>0.205501</td>\n",
              "      <td>1.026972</td>\n",
              "      <td>-0.271038</td>\n",
              "      <td>-0.124869</td>\n",
              "      <td>-0.157580</td>\n",
              "      <td>0.273743</td>\n",
              "      <td>0.179777</td>\n",
              "      <td>0.048608</td>\n",
              "      <td>0.376362</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>0.562784</td>\n",
              "      <td>-3.905674</td>\n",
              "      <td>0.232378</td>\n",
              "      <td>2.365468</td>\n",
              "      <td>0.235710</td>\n",
              "      <td>-2.103891</td>\n",
              "      <td>0.458637</td>\n",
              "      <td>-0.603726</td>\n",
              "      <td>-0.053263</td>\n",
              "      <td>2.230408</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>0.230076</td>\n",
              "      <td>-0.049555</td>\n",
              "      <td>0.162457</td>\n",
              "      <td>0.885359</td>\n",
              "      <td>-1.226624</td>\n",
              "      <td>-2.974060</td>\n",
              "      <td>0.730431</td>\n",
              "      <td>-0.061475</td>\n",
              "      <td>0.883855</td>\n",
              "      <td>1.862853</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>0.094054</td>\n",
              "      <td>0.540496</td>\n",
              "      <td>-0.644040</td>\n",
              "      <td>0.962975</td>\n",
              "      <td>-0.615114</td>\n",
              "      <td>-1.208096</td>\n",
              "      <td>0.031687</td>\n",
              "      <td>1.125484</td>\n",
              "      <td>-0.864481</td>\n",
              "      <td>-1.166307</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>-0.165689</td>\n",
              "      <td>-2.430977</td>\n",
              "      <td>0.152987</td>\n",
              "      <td>2.184269</td>\n",
              "      <td>-2.179047</td>\n",
              "      <td>-2.949728</td>\n",
              "      <td>-1.430591</td>\n",
              "      <td>-6.400437</td>\n",
              "      <td>-2.622068</td>\n",
              "      <td>-1.654538</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "       x1_1      x1_2      x2_1      x2_2      x2_3      x2_4      x2_5  \\\n",
              "0  0.102727  0.205501  1.026972 -0.271038 -0.124869 -0.157580  0.273743   \n",
              "1  0.562784 -3.905674  0.232378  2.365468  0.235710 -2.103891  0.458637   \n",
              "2  0.230076 -0.049555  0.162457  0.885359 -1.226624 -2.974060  0.730431   \n",
              "3  0.094054  0.540496 -0.644040  0.962975 -0.615114 -1.208096  0.031687   \n",
              "4 -0.165689 -2.430977  0.152987  2.184269 -2.179047 -2.949728 -1.430591   \n",
              "\n",
              "       x3_1      x3_2      x3_3  \n",
              "0  0.179777  0.048608  0.376362  \n",
              "1 -0.603726 -0.053263  2.230408  \n",
              "2 -0.061475  0.883855  1.862853  \n",
              "3  1.125484 -0.864481 -1.166307  \n",
              "4 -6.400437 -2.622068 -1.654538  "
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 4
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RJ4cuRJhHEab"
      },
      "source": [
        "## Compute the residual when each variable of the second group on all the variables of the first group"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cCI19wdAHEac"
      },
      "source": [
        "# Obtain the column numbers of the variables of the first group\n",
        "set1_indices = [X.columns.get_loc(label) for label in set1_labels]\n",
        "\n",
        "# Compute the residual when each variable of the second group on all the variables of the first group\n",
        "X_removed_set1 = remove_effect(X, set1_indices)\n",
        "\n",
        "# Create the residual dataset for the second group by the residuals computed just above\n",
        "set2_indices = [X.columns.get_loc(label) for label in set2_labels]\n",
        "X2_resid = X_removed_set1[:, set2_indices]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qnlCWuqKHEae"
      },
      "source": [
        "## Perform LiNGAM on the residual dataset for the second group and draw the estimated causal graph"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ylePh0iNHEae",
        "outputId": "05a8a426-fa06-49ad-b02a-d18624615353"
      },
      "source": [
        "set2_model = lingam.DirectLiNGAM()\n",
        "set2_model.fit(X2_resid)\n",
        "make_dot(set2_model.adjacency_matrix_, labels=set2_labels)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n -->\r\n<!-- Title: %3 Pages: 1 -->\r\n<svg width=\"192pt\" height=\"218pt\"\r\n viewBox=\"0.00 0.00 191.60 218.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 214)\">\r\n<title>%3</title>\r\n<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-214 187.597,-214 187.597,4 -4,4\"/>\r\n<!-- x2_1 -->\r\n<g id=\"node1\" class=\"node\"><title>x2_1</title>\r\n<ellipse fill=\"none\" stroke=\"black\" cx=\"58.5975\" cy=\"-192\" rx=\"28.6953\" ry=\"18\"/>\r\n<text text-anchor=\"middle\" x=\"58.5975\" y=\"-188.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x2_1</text>\r\n</g>\r\n<!-- x2_3 -->\r\n<g id=\"node3\" class=\"node\"><title>x2_3</title>\r\n<ellipse fill=\"none\" stroke=\"black\" cx=\"28.5975\" cy=\"-105\" rx=\"28.6953\" ry=\"18\"/>\r\n<text text-anchor=\"middle\" x=\"28.5975\" y=\"-101.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x2_3</text>\r\n</g>\r\n<!-- x2_1&#45;&gt;x2_3 -->\r\n<g id=\"edge1\" class=\"edge\"><title>x2_1&#45;&gt;x2_3</title>\r\n<path fill=\"none\" stroke=\"black\" d=\"M43.4047,-176.314C38.3829,-170.502 33.3739,-163.436 30.5975,-156 27.9536,-148.919 26.8758,-140.877 26.6049,-133.344\"/>\r\n<polygon fill=\"black\" stroke=\"black\" points=\"30.1059,-133.213 26.6801,-123.187 23.106,-133.161 30.1059,-133.213\"/>\r\n<text text-anchor=\"middle\" x=\"43.0975\" y=\"-144.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.36</text>\r\n</g>\r\n<!-- x2_5 -->\r\n<g id=\"node5\" class=\"node\"><title>x2_5</title>\r\n<ellipse fill=\"none\" stroke=\"black\" cx=\"103.597\" cy=\"-105\" rx=\"28.6953\" ry=\"18\"/>\r\n<text text-anchor=\"middle\" x=\"103.597\" y=\"-101.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x2_5</text>\r\n</g>\r\n<!-- x2_1&#45;&gt;x2_5 -->\r\n<g id=\"edge5\" class=\"edge\"><title>x2_1&#45;&gt;x2_5</title>\r\n<path fill=\"none\" stroke=\"black\" d=\"M67.2739,-174.611C73.7483,-162.382 82.7159,-145.443 90.1207,-131.456\"/>\r\n<polygon fill=\"black\" stroke=\"black\" points=\"93.362,-132.814 94.9476,-122.339 87.1755,-129.539 93.362,-132.814\"/>\r\n<text text-anchor=\"middle\" x=\"97.0975\" y=\"-144.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.58</text>\r\n</g>\r\n<!-- x2_2 -->\r\n<g id=\"node2\" class=\"node\"><title>x2_2</title>\r\n<ellipse fill=\"none\" stroke=\"black\" cx=\"145.597\" cy=\"-192\" rx=\"28.6953\" ry=\"18\"/>\r\n<text text-anchor=\"middle\" x=\"145.597\" y=\"-188.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x2_2</text>\r\n</g>\r\n<!-- x2_4 -->\r\n<g id=\"node4\" class=\"node\"><title>x2_4</title>\r\n<ellipse fill=\"none\" stroke=\"black\" cx=\"103.597\" cy=\"-18\" rx=\"28.6953\" ry=\"18\"/>\r\n<text text-anchor=\"middle\" x=\"103.597\" y=\"-14.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x2_4</text>\r\n</g>\r\n<!-- x2_2&#45;&gt;x2_4 -->\r\n<g id=\"edge2\" class=\"edge\"><title>x2_2&#45;&gt;x2_4</title>\r\n<path fill=\"none\" stroke=\"black\" d=\"M150.864,-174.077C152.387,-168.391 153.841,-161.975 154.597,-156 159.383,-118.173 139.591,-59.7507 136.597,-54 134.106,-49.2151 130.771,-44.6166 127.177,-40.4151\"/>\r\n<polygon fill=\"black\" stroke=\"black\" points=\"129.497,-37.7723 120.099,-32.8799 124.395,-42.5649 129.497,-37.7723\"/>\r\n<text text-anchor=\"middle\" x=\"169.097\" y=\"-101.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;0.72</text>\r\n</g>\r\n<!-- x2_2&#45;&gt;x2_5 -->\r\n<g id=\"edge6\" class=\"edge\"><title>x2_2&#45;&gt;x2_5</title>\r\n<path fill=\"none\" stroke=\"black\" d=\"M137.499,-174.611C131.457,-162.382 123.087,-145.443 116.176,-131.456\"/>\r\n<polygon fill=\"black\" stroke=\"black\" points=\"119.238,-129.753 111.671,-122.339 112.963,-132.854 119.238,-129.753\"/>\r\n<text text-anchor=\"middle\" x=\"140.097\" y=\"-144.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.14</text>\r\n</g>\r\n<!-- x2_3&#45;&gt;x2_4 -->\r\n<g id=\"edge3\" class=\"edge\"><title>x2_3&#45;&gt;x2_4</title>\r\n<path fill=\"none\" stroke=\"black\" d=\"M42.006,-88.8037C53.5244,-75.7495 70.3015,-56.7354 83.4381,-41.8473\"/>\r\n<polygon fill=\"black\" stroke=\"black\" points=\"86.3132,-43.8788 90.305,-34.0647 81.0644,-39.2475 86.3132,-43.8788\"/>\r\n<text text-anchor=\"middle\" x=\"83.0975\" y=\"-57.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.30</text>\r\n</g>\r\n<!-- x2_5&#45;&gt;x2_4 -->\r\n<g id=\"edge4\" class=\"edge\"><title>x2_5&#45;&gt;x2_4</title>\r\n<path fill=\"none\" stroke=\"black\" d=\"M103.597,-86.799C103.597,-75.1626 103.597,-59.5479 103.597,-46.2368\"/>\r\n<polygon fill=\"black\" stroke=\"black\" points=\"107.098,-46.1754 103.597,-36.1754 100.098,-46.1755 107.098,-46.1754\"/>\r\n<text text-anchor=\"middle\" x=\"118.097\" y=\"-57.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;0.05</text>\r\n</g>\r\n</g>\r\n</svg>\r\n",
            "text/plain": [
              "<graphviz.dot.Digraph at 0x1cf4c7ef438>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 6
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RTB-l9n7HEag"
      },
      "source": [
        "## Regress each variable of the third group on all the variables of the first and second groups"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BbfP02NWHEag"
      },
      "source": [
        "if contain_bin_var_in_set3:\n",
        "    print('\\x1b[31mSkip because the third group includes a binary variable\\x1b[0m')\n",
        "else:\n",
        "    # Obtain column numbers of the variables of the first and second groups\n",
        "    set12_indices = [X.columns.get_loc(label) for label in set1_labels+set2_labels]\n",
        "\n",
        "    # Regress each variable of the third group on all the variables of the first and second groups\n",
        "    X_removed_set12 = remove_effect(X, set12_indices)\n",
        "\n",
        "    # Create the residual dataset for the third group by the residuals computed just above\n",
        "    set3_indices = [X.columns.get_loc(label) for label in set3_labels]\n",
        "    X3_resid = X_removed_set12[:, set3_indices]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nPjGsYLhHEai"
      },
      "source": [
        "## Perform LiNGAM on the residual dataset for the third group and draw the estimated causal graph"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yT8V-I1zHEai",
        "outputId": "ffd1a3bf-29cc-48f2-964d-a26d2c3195b9"
      },
      "source": [
        "if contain_bin_var_in_set3:\n",
        "    print('\\x1b[31mSkip because the third group includes a binary variableスキップします\\x1b[0m')\n",
        "else:\n",
        "    set3_model = lingam.DirectLiNGAM()\n",
        "    set3_model.fit(X3_resid)\n",
        "    g = make_dot(set3_model.adjacency_matrix_, labels=set3_labels)\n",
        "    display_svg(SVG(g._repr_svg_()))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/svg+xml": "<svg height=\"131pt\" viewBox=\"0.00 0.00 140.19 131.00\" width=\"140pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 127)\">\n<title>%3</title>\n<polygon fill=\"white\" points=\"-4,4 -4,-127 136.195,-127 136.195,4 -4,4\" stroke=\"none\"/>\n<!-- x3_1 -->\n<g class=\"node\" id=\"node1\"><title>x3_1</title>\n<ellipse cx=\"28.5975\" cy=\"-105\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"28.5975\" y=\"-101.3\">x3_1</text>\n</g>\n<!-- x3_2 -->\n<g class=\"node\" id=\"node2\"><title>x3_2</title>\n<ellipse cx=\"103.597\" cy=\"-105\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103.597\" y=\"-101.3\">x3_2</text>\n</g>\n<!-- x3_3 -->\n<g class=\"node\" id=\"node3\"><title>x3_3</title>\n<ellipse cx=\"103.597\" cy=\"-18\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103.597\" y=\"-14.3\">x3_3</text>\n</g>\n<!-- x3_2&#45;&gt;x3_3 -->\n<g class=\"edge\" id=\"edge1\"><title>x3_2-&gt;x3_3</title>\n<path d=\"M103.597,-86.799C103.597,-75.1626 103.597,-59.5479 103.597,-46.2368\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"107.098,-46.1754 103.597,-36.1754 100.098,-46.1755 107.098,-46.1754\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"116.097\" y=\"-57.8\">0.17</text>\n</g>\n</g>\n</svg>"
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZzKYfh_OHEaj"
      },
      "source": [
        "## Combine the causal grpahs for the second and third groups and draw it"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rQ133HTqHEak",
        "outputId": "aec61e8b-e221-46c8-b65d-be7ae9dc6c71"
      },
      "source": [
        "if contain_bin_var_in_set3:\n",
        "    print('\\x1b[31mSkip because the third group includes a binary variable\\x1b[0m')\n",
        "else:\n",
        "    # Adjucency matrix for the all the variables\n",
        "    adj_matrix = np.zeros([X.shape[1], X.shape[1]])\n",
        "\n",
        "    # Update the adjuncy matrix using the causal graph estimated for the second group\n",
        "    set2_start_pos = len(set1_labels)\n",
        "    set2_end_pos = set2_start_pos + len(set2_labels)\n",
        "    adj_matrix[set2_start_pos:set2_end_pos, set2_start_pos:set2_end_pos] = set2_model.adjacency_matrix_\n",
        "\n",
        "    # Update the adjuncy matrix using the causal graph estimated for the third group\n",
        "    set3_start_pos = len(set1_labels) + len(set2_labels)\n",
        "    set3_end_pos = set3_start_pos + len(set3_labels)\n",
        "    adj_matrix[set3_start_pos:set3_end_pos, set3_start_pos:set3_end_pos] = set3_model.adjacency_matrix_\n",
        "\n",
        "    # Compute the connection strengths from each variable of the second group to that of the third\n",
        "    for i, idx in enumerate(set3_indices):\n",
        "\n",
        "        # Obtain parents of each variable of the third group\n",
        "        set3_parents = np.where(np.abs(set3_model.adjacency_matrix_[i]) > 0)[0]\n",
        "        set3_parents = [X.columns.get_loc(set3_labels[idx]) for idx in set3_parents]\n",
        "\n",
        "        # Create the set of explanatory variables\n",
        "        predictors = []\n",
        "        predictors.extend(set2_indices) # All the variables of the second group\n",
        "        predictors.extend(set3_parents) # Parents in the third group\n",
        "\n",
        "        # Pruning\n",
        "        coefs = predict_adaptive_lasso(X_removed_set1, predictors, idx)\n",
        "        adj_matrix[idx, set2_start_pos:set2_end_pos] = coefs[:len(set2_indices)]\n",
        "\n",
        "    # Remove a part of the adjacency matrix corresponding the variables of the first group\n",
        "    adj_matrix_set23 = adj_matrix[set2_start_pos:, set2_start_pos:]\n",
        "    g = make_dot(adj_matrix_set23, labels=set2_labels+set3_labels)\n",
        "    display_svg(SVG(g._repr_svg_()))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/svg+xml": "<svg height=\"305pt\" viewBox=\"0.00 0.00 230.19 305.00\" width=\"230pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 301)\">\n<title>%3</title>\n<polygon fill=\"white\" points=\"-4,4 -4,-301 226.195,-301 226.195,4 -4,4\" stroke=\"none\"/>\n<!-- x2_1 -->\n<g class=\"node\" id=\"node1\"><title>x2_1</title>\n<ellipse cx=\"160.597\" cy=\"-279\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"160.597\" y=\"-275.3\">x2_1</text>\n</g>\n<!-- x2_3 -->\n<g class=\"node\" id=\"node3\"><title>x2_3</title>\n<ellipse cx=\"118.597\" cy=\"-192\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"118.597\" y=\"-188.3\">x2_3</text>\n</g>\n<!-- x2_1&#45;&gt;x2_3 -->\n<g class=\"edge\" id=\"edge1\"><title>x2_1-&gt;x2_3</title>\n<path d=\"M152.499,-261.611C146.457,-249.382 138.087,-232.443 131.176,-218.456\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"134.238,-216.753 126.671,-209.339 127.963,-219.854 134.238,-216.753\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"155.097\" y=\"-231.8\">0.36</text>\n</g>\n<!-- x2_5 -->\n<g class=\"node\" id=\"node5\"><title>x2_5</title>\n<ellipse cx=\"193.597\" cy=\"-192\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"193.597\" y=\"-188.3\">x2_5</text>\n</g>\n<!-- x2_1&#45;&gt;x2_5 -->\n<g class=\"edge\" id=\"edge5\"><title>x2_1-&gt;x2_5</title>\n<path d=\"M167.117,-261.207C171.773,-249.216 178.126,-232.851 183.442,-219.157\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"186.735,-220.347 187.091,-209.758 180.209,-217.814 186.735,-220.347\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"192.097\" y=\"-231.8\">0.58</text>\n</g>\n<!-- x2_2 -->\n<g class=\"node\" id=\"node2\"><title>x2_2</title>\n<ellipse cx=\"42.5975\" cy=\"-279\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"42.5975\" y=\"-275.3\">x2_2</text>\n</g>\n<!-- x2_4 -->\n<g class=\"node\" id=\"node4\"><title>x2_4</title>\n<ellipse cx=\"93.5975\" cy=\"-105\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"93.5975\" y=\"-101.3\">x2_4</text>\n</g>\n<!-- x2_2&#45;&gt;x2_4 -->\n<g class=\"edge\" id=\"edge2\"><title>x2_2-&gt;x2_4</title>\n<path d=\"M41.6929,-260.947C41.1119,-239.939 41.8545,-203.419 51.5975,-174 56.8279,-158.207 66.3718,-142.291 75.0415,-129.832\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"78.1,-131.575 81.1256,-121.422 72.4285,-127.472 78.1,-131.575\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"66.0975\" y=\"-188.3\">-0.72</text>\n</g>\n<!-- x2_2&#45;&gt;x2_5 -->\n<g class=\"edge\" id=\"edge6\"><title>x2_2-&gt;x2_5</title>\n<path d=\"M58.5004,-263.848C71.333,-252.966 90.2148,-238.127 108.597,-228 128.553,-217.006 135.566,-218.762 156.597,-210 158.408,-209.246 160.261,-208.455 162.126,-207.645\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"163.697,-210.777 171.407,-203.509 160.848,-204.383 163.697,-210.777\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"121.097\" y=\"-231.8\">0.14</text>\n</g>\n<!-- x3_1 -->\n<g class=\"node\" id=\"node6\"><title>x3_1</title>\n<ellipse cx=\"28.5975\" cy=\"-18\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"28.5975\" y=\"-14.3\">x3_1</text>\n</g>\n<!-- x2_2&#45;&gt;x3_1 -->\n<g class=\"edge\" id=\"edge7\"><title>x2_2-&gt;x3_1</title>\n<path d=\"M39.7911,-260.687C36.1223,-237.094 29.8934,-193.488 27.5975,-156 25.2741,-118.065 26.2407,-74.1114 27.2954,-46.4321\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"30.7973,-46.4576 27.7126,-36.3218 23.8032,-46.169 30.7973,-46.4576\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"42.0975\" y=\"-144.8\">-0.07</text>\n</g>\n<!-- x2_3&#45;&gt;x2_4 -->\n<g class=\"edge\" id=\"edge3\"><title>x2_3-&gt;x2_4</title>\n<path d=\"M105.231,-175.665C101.003,-169.891 96.8488,-163.015 94.5975,-156 92.2893,-148.808 91.4345,-140.729 91.3118,-133.198\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"94.8139,-133.143 91.5505,-123.064 87.8159,-132.978 94.8139,-133.143\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"107.097\" y=\"-144.8\">0.30</text>\n</g>\n<!-- x2_4&#45;&gt;x3_1 -->\n<g class=\"edge\" id=\"edge8\"><title>x2_4-&gt;x3_1</title>\n<path d=\"M81.6755,-88.4097C71.8715,-75.5891 57.8235,-57.2186 46.6422,-42.597\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"49.2987,-40.309 40.4439,-34.4915 43.7382,-44.5612 49.2987,-40.309\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"78.0975\" y=\"-57.8\">0.24</text>\n</g>\n<!-- x3_3 -->\n<g class=\"node\" id=\"node8\"><title>x3_3</title>\n<ellipse cx=\"119.597\" cy=\"-18\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"119.597\" y=\"-14.3\">x3_3</text>\n</g>\n<!-- x2_4&#45;&gt;x3_3 -->\n<g class=\"edge\" id=\"edge10\"><title>x2_4-&gt;x3_3</title>\n<path d=\"M98.7342,-87.2067C102.367,-75.332 107.311,-59.1684 111.474,-45.5567\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"114.893,-46.3448 114.471,-35.7584 108.199,-44.2972 114.893,-46.3448\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"123.097\" y=\"-57.8\">-0.31</text>\n</g>\n<!-- x2_5&#45;&gt;x2_4 -->\n<g class=\"edge\" id=\"edge4\"><title>x2_5-&gt;x2_4</title>\n<path d=\"M177.084,-176.964C161.017,-163.307 136.411,-142.391 118.068,-126.8\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"120.021,-123.867 110.135,-120.057 115.488,-129.2 120.021,-123.867\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"164.097\" y=\"-144.8\">-0.05</text>\n</g>\n<!-- x3_2 -->\n<g class=\"node\" id=\"node7\"><title>x3_2</title>\n<ellipse cx=\"180.597\" cy=\"-105\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"180.597\" y=\"-101.3\">x3_2</text>\n</g>\n<!-- x2_5&#45;&gt;x3_2 -->\n<g class=\"edge\" id=\"edge9\"><title>x2_5-&gt;x3_2</title>\n<path d=\"M190.967,-173.799C189.187,-162.163 186.799,-146.548 184.763,-133.237\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"188.196,-132.531 183.224,-123.175 181.276,-133.59 188.196,-132.531\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"201.097\" y=\"-144.8\">0.47</text>\n</g>\n<!-- x3_2&#45;&gt;x3_3 -->\n<g class=\"edge\" id=\"edge11\"><title>x3_2-&gt;x3_3</title>\n<path d=\"M169.409,-88.4097C160.351,-75.7883 147.434,-57.7881 137.023,-43.2812\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"139.595,-40.8631 130.921,-34.7793 133.908,-44.9444 139.595,-40.8631\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"167.097\" y=\"-57.8\">0.17</text>\n</g>\n</g>\n</svg>"
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aLqMhEFoHEal"
      },
      "source": [
        "## Compute causal effects from each variable of the second group to that of the third"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "B8kuVE3ZHEal",
        "outputId": "53aa7cfd-d294-425c-bdf2-8e616168676c"
      },
      "source": [
        "# Compute causal effects from each variable of the second group to that of the third\n",
        "for set3_label in set3_labels:\n",
        "    for set2_label in set2_labels:\n",
        "\n",
        "        # Create the variable index\n",
        "        var2_index = X.columns.get_loc(set2_label)\n",
        "        var3_index = X.columns.get_loc(set3_label)\n",
        "        \n",
        "        # Obtain parents of each variable of the second group\n",
        "        parents = np.where(np.abs(set2_model.adjacency_matrix_[set2_labels.index(set2_label)]) > 0)[0]\n",
        "        parents = [X.columns.get_loc(set2_labels[idx]) for idx in parents]\n",
        "\n",
        "        # Create the set of explanatory variables\n",
        "        predictors = [var2_index]\n",
        "        predictors.extend(parents)\n",
        "        predictors.extend(set1_indices)\n",
        "\n",
        "        # If all the variables of the third group are continuous, peform linear regression\n",
        "        # If they are binary, perform logistic regression\n",
        "        if len(np.unique(X[set3_label])) != 2:\n",
        "            lr = LinearRegression()\n",
        "            lr.fit(X.iloc[:, predictors], X.iloc[:, var3_index])\n",
        "            effect = lr.coef_[0]\n",
        "        else:\n",
        "            lr = LogisticRegression(solver='liblinear')\n",
        "            lr.fit(X.iloc[:, predictors], X.iloc[:, var3_index])\n",
        "            X_intervened = X.copy()\n",
        "            X_intervened.iloc[:, var2_index] = X.iloc[:, var2_index].mean() # do(x=E(x))\n",
        "            p1 = lr.predict_proba(X_intervened.iloc[:, predictors])\n",
        "            X_intervened.iloc[:, var2_index] = X.iloc[:, var2_index].mean() + 1 # do(x=E(x)+1)\n",
        "            p2 = lr.predict_proba(X_intervened.iloc[:, predictors])\n",
        "            effect = p2[:, 1].mean() - p1[:, 1].mean() # The difference btw the two averages\n",
        "\n",
        "        print(f'{set2_label} ---> {set3_label} : {effect:.3f}')"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "x2_1 ---> x3_1 : -0.004\n",
            "x2_2 ---> x3_1 : -0.258\n",
            "x2_3 ---> x3_1 : 0.085\n",
            "x2_4 ---> x3_1 : 0.243\n",
            "x2_5 ---> x3_1 : 0.000\n",
            "x2_1 ---> x3_2 : 0.322\n",
            "x2_2 ---> x3_2 : 0.112\n",
            "x2_3 ---> x3_2 : -0.014\n",
            "x2_4 ---> x3_2 : -0.045\n",
            "x2_5 ---> x3_2 : 0.496\n",
            "x2_1 ---> x3_3 : 0.054\n",
            "x2_2 ---> x3_3 : 0.270\n",
            "x2_3 ---> x3_3 : -0.085\n",
            "x2_4 ---> x3_3 : -0.339\n",
            "x2_5 ---> x3_3 : 0.121\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KnB-PyGgrPmC"
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}