{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.3"
    },
    "toc": {
      "base_numbering": 1,
      "nav_menu": {},
      "number_sections": true,
      "sideBar": true,
      "skip_h1_title": false,
      "title_cell": "Table of Contents",
      "title_sidebar": "Contents",
      "toc_cell": false,
      "toc_position": {},
      "toc_section_display": true,
      "toc_window_display": false
    },
    "colab": {
      "name": "Example to handle groups of varibles with known causal orders btw the groups for multiple datasets.ipynb",
      "provenance": [],
      "collapsed_sections": []
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6V7jVfw6nWln"
      },
      "source": [
        "# Example to handle groups of varibles with known causal orders btw the groups for multiple datasets"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Kj0bkRKwnWlo"
      },
      "source": [
        "## Import necessary packages"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2019-09-09T02:01:39.097825Z",
          "start_time": "2019-09-09T02:01:33.841227Z"
        },
        "id": "9xT7kywznWlp",
        "outputId": "aa86da98-efb5-4332-d231-3270457979fd"
      },
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "import graphviz\n",
        "from sklearn.linear_model import LinearRegression, LogisticRegression\n",
        "import lingam\n",
        "from lingam.utils import make_dot, remove_effect, predict_adaptive_lasso\n",
        "\n",
        "from IPython.display import display_svg, SVG\n",
        "\n",
        "print([np.__version__, pd.__version__, graphviz.__version__, lingam.__version__])"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['1.16.2', '0.24.2', '0.11.1', '1.4.0']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qhr83J1wnWlu"
      },
      "source": [
        "# Output setting\n",
        "np.set_printoptions(precision=3, suppress=True)\n",
        "np.random.seed(0)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FwEQGk2XnWlx"
      },
      "source": [
        "## Indicate which variable belong to which group"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HD_B12nFnWly"
      },
      "source": [
        "# More than one Datasets to be analyzed\n",
        "input_files = ['data1.csv', 'data3.csv']\n",
        "\n",
        "# Indicate which variable belong to which group\n",
        "set1_labels = ['x1_1', 'x1_2']\n",
        "set2_labels = ['x2_1', 'x2_2', 'x2_3', 'x2_4', 'x2_5']\n",
        "set3_labels = ['x3_1', 'x3_2', 'x3_3']\n",
        "labels = set1_labels + set2_labels + set3_labels"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6Fl2FU4rnWl1"
      },
      "source": [
        "## Load data"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "djASmKX2nWl1",
        "outputId": "bfd6d87d-27f7-4883-e4d2-f30b012a3ba7"
      },
      "source": [
        "X_list = [pd.read_csv(input_file) for input_file in input_files]\n",
        "\n",
        "# Check if the third group includes a binary data\n",
        "contain_bin_var_in_set3 = False\n",
        "for i, X in enumerate(X_list):\n",
        "    for set3_label in set3_labels:\n",
        "        if len(np.unique(X[set3_label])) == 2:\n",
        "            X[set3_label] = X[set3_label].astype(int)\n",
        "            print(f'\\x1b[31m変数セット3に2値の離散変数を含んでいます\\x1b[0m：{[i]} {input_files[i]}')\n",
        "            contain_bin_var_in_set3 = True\n",
        "            break\n",
        "        \n",
        "X_list[0].head()"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>x1_1</th>\n",
              "      <th>x1_2</th>\n",
              "      <th>x2_1</th>\n",
              "      <th>x2_2</th>\n",
              "      <th>x2_3</th>\n",
              "      <th>x2_4</th>\n",
              "      <th>x2_5</th>\n",
              "      <th>x3_1</th>\n",
              "      <th>x3_2</th>\n",
              "      <th>x3_3</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>0.102727</td>\n",
              "      <td>0.205501</td>\n",
              "      <td>1.026972</td>\n",
              "      <td>-0.271038</td>\n",
              "      <td>-0.124869</td>\n",
              "      <td>-0.157580</td>\n",
              "      <td>0.273743</td>\n",
              "      <td>0.179777</td>\n",
              "      <td>0.048608</td>\n",
              "      <td>0.376362</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>0.562784</td>\n",
              "      <td>-3.905674</td>\n",
              "      <td>0.232378</td>\n",
              "      <td>2.365468</td>\n",
              "      <td>0.235710</td>\n",
              "      <td>-2.103891</td>\n",
              "      <td>0.458637</td>\n",
              "      <td>-0.603726</td>\n",
              "      <td>-0.053263</td>\n",
              "      <td>2.230408</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>0.230076</td>\n",
              "      <td>-0.049555</td>\n",
              "      <td>0.162457</td>\n",
              "      <td>0.885359</td>\n",
              "      <td>-1.226624</td>\n",
              "      <td>-2.974060</td>\n",
              "      <td>0.730431</td>\n",
              "      <td>-0.061475</td>\n",
              "      <td>0.883855</td>\n",
              "      <td>1.862853</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>0.094054</td>\n",
              "      <td>0.540496</td>\n",
              "      <td>-0.644040</td>\n",
              "      <td>0.962975</td>\n",
              "      <td>-0.615114</td>\n",
              "      <td>-1.208096</td>\n",
              "      <td>0.031687</td>\n",
              "      <td>1.125484</td>\n",
              "      <td>-0.864481</td>\n",
              "      <td>-1.166307</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>-0.165689</td>\n",
              "      <td>-2.430977</td>\n",
              "      <td>0.152987</td>\n",
              "      <td>2.184269</td>\n",
              "      <td>-2.179047</td>\n",
              "      <td>-2.949728</td>\n",
              "      <td>-1.430591</td>\n",
              "      <td>-6.400437</td>\n",
              "      <td>-2.622068</td>\n",
              "      <td>-1.654538</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "       x1_1      x1_2      x2_1      x2_2      x2_3      x2_4      x2_5  \\\n",
              "0  0.102727  0.205501  1.026972 -0.271038 -0.124869 -0.157580  0.273743   \n",
              "1  0.562784 -3.905674  0.232378  2.365468  0.235710 -2.103891  0.458637   \n",
              "2  0.230076 -0.049555  0.162457  0.885359 -1.226624 -2.974060  0.730431   \n",
              "3  0.094054  0.540496 -0.644040  0.962975 -0.615114 -1.208096  0.031687   \n",
              "4 -0.165689 -2.430977  0.152987  2.184269 -2.179047 -2.949728 -1.430591   \n",
              "\n",
              "       x3_1      x3_2      x3_3  \n",
              "0  0.179777  0.048608  0.376362  \n",
              "1 -0.603726 -0.053263  2.230408  \n",
              "2 -0.061475  0.883855  1.862853  \n",
              "3  1.125484 -0.864481 -1.166307  \n",
              "4 -6.400437 -2.622068 -1.654538  "
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 4
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Eyb1qo-WnWl3"
      },
      "source": [
        "## Compute the residual when each variable of the second group on all the variables of the first group"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BS04gQNwnWl4"
      },
      "source": [
        "X2_resid_list = []\n",
        "\n",
        "for X in X_list:\n",
        "    # Obtain the column numbers of the variables of the first group\n",
        "    set1_indices = [X.columns.get_loc(label) for label in set1_labels]\n",
        "\n",
        "    # Compute the residual when each variable of the second group on all the variables of the first group\n",
        "    X_removed_set1 = remove_effect(X, set1_indices)\n",
        "\n",
        "    # Creat the residual dataset for the second group by the residuals computed just above\n",
        "    set2_indices = [X.columns.get_loc(label) for label in set2_labels]\n",
        "    X2_resid_list.append(X_removed_set1[:, set2_indices])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tj1CP0cJnWl6"
      },
      "source": [
        "## Perform LiNGAM on the residual dataset for the second group and draw the estimated causal graph"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RIf09ChQnWl6",
        "outputId": "253828ac-d3b2-4063-f87b-15d934d513cb"
      },
      "source": [
        "# Perform LiNGAM\n",
        "set2_model = lingam.MultiGroupDirectLiNGAM()\n",
        "set2_model.fit(X2_resid_list)\n",
        "\n",
        "# Draw the estimated causal graph\n",
        "for i, _ in enumerate(X_list):\n",
        "    print(f'{[i]} {input_files[i]}')\n",
        "    g = make_dot(set2_model.adjacency_matrices_[i], labels=set2_labels)\n",
        "    display_svg(SVG(g._repr_svg_()))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[0] data1.csv\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "image/svg+xml": "<svg height=\"218pt\" viewBox=\"0.00 0.00 191.60 218.00\" width=\"192pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 214)\">\n<title>%3</title>\n<polygon fill=\"white\" points=\"-4,4 -4,-214 187.597,-214 187.597,4 -4,4\" stroke=\"none\"/>\n<!-- x2_1 -->\n<g class=\"node\" id=\"node1\"><title>x2_1</title>\n<ellipse cx=\"58.5975\" cy=\"-192\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"58.5975\" y=\"-188.3\">x2_1</text>\n</g>\n<!-- x2_3 -->\n<g class=\"node\" id=\"node3\"><title>x2_3</title>\n<ellipse cx=\"28.5975\" cy=\"-105\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"28.5975\" y=\"-101.3\">x2_3</text>\n</g>\n<!-- x2_1&#45;&gt;x2_3 -->\n<g class=\"edge\" id=\"edge1\"><title>x2_1-&gt;x2_3</title>\n<path d=\"M43.4047,-176.314C38.3829,-170.502 33.3739,-163.436 30.5975,-156 27.9536,-148.919 26.8758,-140.877 26.6049,-133.344\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"30.1059,-133.213 26.6801,-123.187 23.106,-133.161 30.1059,-133.213\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"43.0975\" y=\"-144.8\">0.35</text>\n</g>\n<!-- x2_5 -->\n<g class=\"node\" id=\"node5\"><title>x2_5</title>\n<ellipse cx=\"103.597\" cy=\"-105\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103.597\" y=\"-101.3\">x2_5</text>\n</g>\n<!-- x2_1&#45;&gt;x2_5 -->\n<g class=\"edge\" id=\"edge5\"><title>x2_1-&gt;x2_5</title>\n<path d=\"M67.2739,-174.611C73.7483,-162.382 82.7159,-145.443 90.1207,-131.456\" fill=\"none\" stroke=\"black\"/>\n<polygon 
fill=\"black\" points=\"93.362,-132.814 94.9476,-122.339 87.1755,-129.539 93.362,-132.814\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"97.0975\" y=\"-144.8\">0.60</text>\n</g>\n<!-- x2_2 -->\n<g class=\"node\" id=\"node2\"><title>x2_2</title>\n<ellipse cx=\"145.597\" cy=\"-192\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"145.597\" y=\"-188.3\">x2_2</text>\n</g>\n<!-- x2_4 -->\n<g class=\"node\" id=\"node4\"><title>x2_4</title>\n<ellipse cx=\"103.597\" cy=\"-18\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103.597\" y=\"-14.3\">x2_4</text>\n</g>\n<!-- x2_2&#45;&gt;x2_4 -->\n<g class=\"edge\" id=\"edge2\"><title>x2_2-&gt;x2_4</title>\n<path d=\"M150.864,-174.077C152.387,-168.391 153.841,-161.975 154.597,-156 159.383,-118.173 139.591,-59.7507 136.597,-54 134.106,-49.2151 130.771,-44.6166 127.177,-40.4151\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"129.497,-37.7723 120.099,-32.8799 124.395,-42.5649 129.497,-37.7723\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"169.097\" y=\"-101.3\">-0.72</text>\n</g>\n<!-- x2_2&#45;&gt;x2_5 -->\n<g class=\"edge\" id=\"edge6\"><title>x2_2-&gt;x2_5</title>\n<path d=\"M137.499,-174.611C131.457,-162.382 123.087,-145.443 116.176,-131.456\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"119.238,-129.753 111.671,-122.339 112.963,-132.854 119.238,-129.753\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"140.097\" y=\"-144.8\">0.16</text>\n</g>\n<!-- x2_3&#45;&gt;x2_4 -->\n<g class=\"edge\" id=\"edge3\"><title>x2_3-&gt;x2_4</title>\n<path d=\"M42.006,-88.8037C53.5244,-75.7495 70.3015,-56.7354 
83.4381,-41.8473\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"86.3132,-43.8788 90.305,-34.0647 81.0644,-39.2475 86.3132,-43.8788\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"83.0975\" y=\"-57.8\">0.30</text>\n</g>\n<!-- x2_5&#45;&gt;x2_4 -->\n<g class=\"edge\" id=\"edge4\"><title>x2_5-&gt;x2_4</title>\n<path d=\"M103.597,-86.799C103.597,-75.1626 103.597,-59.5479 103.597,-46.2368\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"107.098,-46.1754 103.597,-36.1754 100.098,-46.1755 107.098,-46.1754\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"118.097\" y=\"-57.8\">-0.05</text>\n</g>\n</g>\n</svg>"
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "[1] data3.csv\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "image/svg+xml": "<svg height=\"218pt\" viewBox=\"0.00 0.00 189.60 218.00\" width=\"190pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 214)\">\n<title>%3</title>\n<polygon fill=\"white\" points=\"-4,4 -4,-214 185.597,-214 185.597,4 -4,4\" stroke=\"none\"/>\n<!-- x2_1 -->\n<g class=\"node\" id=\"node1\"><title>x2_1</title>\n<ellipse cx=\"58.5975\" cy=\"-192\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"58.5975\" y=\"-188.3\">x2_1</text>\n</g>\n<!-- x2_3 -->\n<g class=\"node\" id=\"node3\"><title>x2_3</title>\n<ellipse cx=\"28.5975\" cy=\"-105\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"28.5975\" y=\"-101.3\">x2_3</text>\n</g>\n<!-- x2_1&#45;&gt;x2_3 -->\n<g class=\"edge\" id=\"edge1\"><title>x2_1-&gt;x2_3</title>\n<path d=\"M43.4047,-176.314C38.3829,-170.502 33.3739,-163.436 30.5975,-156 27.9536,-148.919 26.8758,-140.877 26.6049,-133.344\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"30.1059,-133.213 26.6801,-123.187 23.106,-133.161 30.1059,-133.213\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"43.0975\" y=\"-144.8\">0.27</text>\n</g>\n<!-- x2_5 -->\n<g class=\"node\" id=\"node5\"><title>x2_5</title>\n<ellipse cx=\"103.597\" cy=\"-105\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103.597\" y=\"-101.3\">x2_5</text>\n</g>\n<!-- x2_1&#45;&gt;x2_5 -->\n<g class=\"edge\" id=\"edge4\"><title>x2_1-&gt;x2_5</title>\n<path d=\"M67.2739,-174.611C73.7483,-162.382 82.7159,-145.443 90.1207,-131.456\" fill=\"none\" stroke=\"black\"/>\n<polygon 
fill=\"black\" points=\"93.362,-132.814 94.9476,-122.339 87.1755,-129.539 93.362,-132.814\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"97.0975\" y=\"-144.8\">0.45</text>\n</g>\n<!-- x2_2 -->\n<g class=\"node\" id=\"node2\"><title>x2_2</title>\n<ellipse cx=\"145.597\" cy=\"-192\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"145.597\" y=\"-188.3\">x2_2</text>\n</g>\n<!-- x2_4 -->\n<g class=\"node\" id=\"node4\"><title>x2_4</title>\n<ellipse cx=\"123.597\" cy=\"-18\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"123.597\" y=\"-14.3\">x2_4</text>\n</g>\n<!-- x2_2&#45;&gt;x2_4 -->\n<g class=\"edge\" id=\"edge2\"><title>x2_2-&gt;x2_4</title>\n<path d=\"M150.864,-174.077C152.387,-168.391 153.841,-161.975 154.597,-156 155.434,-149.386 155.383,-147.62 154.597,-141 150.609,-107.39 140.1,-69.7946 132.42,-45.3373\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"135.719,-44.1613 129.33,-35.7099 129.054,-46.3012 135.719,-44.1613\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"167.097\" y=\"-101.3\">-0.60</text>\n</g>\n<!-- x2_2&#45;&gt;x2_5 -->\n<g class=\"edge\" id=\"edge5\"><title>x2_2-&gt;x2_5</title>\n<path d=\"M137.499,-174.611C131.457,-162.382 123.087,-145.443 116.176,-131.456\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"119.238,-129.753 111.671,-122.339 112.963,-132.854 119.238,-129.753\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"140.097\" y=\"-144.8\">0.12</text>\n</g>\n<!-- x2_3&#45;&gt;x2_4 -->\n<g class=\"edge\" id=\"edge3\"><title>x2_3-&gt;x2_4</title>\n<path d=\"M44.7129,-89.5809C59.9539,-75.9442 82.9977,-55.3261 
100.25,-39.89\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"102.601,-42.4825 107.72,-33.2061 97.9338,-37.2658 102.601,-42.4825\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"95.0975\" y=\"-57.8\">0.36</text>\n</g>\n</g>\n</svg>"
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CMWKk36fnWl7"
      },
      "source": [
        "Regress each variable of the third group on all the variables of the first and second groups and compute the residuals"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-l3UnpkFnWl8"
      },
      "source": [
        "X3_resid_list = []\n",
        "\n",
        "if contain_bin_var_in_set3:\n",
        "    print('\\x1b[31mSkip because the third group includes a binary variable\\x1b[0m')\n",
        "else:\n",
        "    for X in X_list:\n",
        "        # Obtain column numbers of the variables of the first and second groups\n",
        "        set12_indices = [X.columns.get_loc(label) for label in set1_labels+set2_labels]\n",
        "\n",
        "        # Regress each variable of the third group on all the variables of the first and second groups\n",
        "        X_removed_set12 = remove_effect(X, set12_indices)\n",
        "\n",
        "        # Create the residual dataset for the third group by the residuals computed just above\n",
        "        set3_indices = [X.columns.get_loc(label) for label in set3_labels]\n",
        "        X3_resid_list.append(X_removed_set12[:, set3_indices])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4nhOnU6nnWl9"
      },
      "source": [
        "## Combine the causal grpahs for the second and third groups and draw it"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Jmt-K7dSnWl9",
        "outputId": "5108813f-5f88-4f03-d3f0-cec726d8c79a"
      },
      "source": [
        "if contain_bin_var_in_set3:\n",
        "    print('\\x1b[31mSkip because the third group includes a binary variable\\x1b[0m')\n",
        "else:\n",
        "    # Perform LiNGAM\n",
        "    set3_model = lingam.MultiGroupDirectLiNGAM()\n",
        "    set3_model.fit(X3_resid_list)\n",
        "    \n",
        "    # Draw the estimated graph\n",
        "    for i, _ in enumerate(X_list):\n",
        "        print(f'{[i]} {input_files[i]}')\n",
        "        g = make_dot(set3_model.adjacency_matrices_[i], labels=set3_labels)\n",
        "        display_svg(SVG(g._repr_svg_()))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[0] data1.csv\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "image/svg+xml": "<svg height=\"131pt\" viewBox=\"0.00 0.00 140.19 131.00\" width=\"140pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 127)\">\n<title>%3</title>\n<polygon fill=\"white\" points=\"-4,4 -4,-127 136.195,-127 136.195,4 -4,4\" stroke=\"none\"/>\n<!-- x3_1 -->\n<g class=\"node\" id=\"node1\"><title>x3_1</title>\n<ellipse cx=\"28.5975\" cy=\"-105\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"28.5975\" y=\"-101.3\">x3_1</text>\n</g>\n<!-- x3_2 -->\n<g class=\"node\" id=\"node2\"><title>x3_2</title>\n<ellipse cx=\"103.597\" cy=\"-105\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103.597\" y=\"-101.3\">x3_2</text>\n</g>\n<!-- x3_3 -->\n<g class=\"node\" id=\"node3\"><title>x3_3</title>\n<ellipse cx=\"103.597\" cy=\"-18\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103.597\" y=\"-14.3\">x3_3</text>\n</g>\n<!-- x3_2&#45;&gt;x3_3 -->\n<g class=\"edge\" id=\"edge1\"><title>x3_2-&gt;x3_3</title>\n<path d=\"M103.597,-86.799C103.597,-75.1626 103.597,-59.5479 103.597,-46.2368\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"107.098,-46.1754 103.597,-36.1754 100.098,-46.1755 107.098,-46.1754\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"116.097\" y=\"-57.8\">0.17</text>\n</g>\n</g>\n</svg>"
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "[1] data3.csv\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "image/svg+xml": "<svg height=\"131pt\" viewBox=\"0.00 0.00 140.19 131.00\" width=\"140pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 127)\">\n<title>%3</title>\n<polygon fill=\"white\" points=\"-4,4 -4,-127 136.195,-127 136.195,4 -4,4\" stroke=\"none\"/>\n<!-- x3_1 -->\n<g class=\"node\" id=\"node1\"><title>x3_1</title>\n<ellipse cx=\"28.5975\" cy=\"-105\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"28.5975\" y=\"-101.3\">x3_1</text>\n</g>\n<!-- x3_2 -->\n<g class=\"node\" id=\"node2\"><title>x3_2</title>\n<ellipse cx=\"103.597\" cy=\"-105\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103.597\" y=\"-101.3\">x3_2</text>\n</g>\n<!-- x3_3 -->\n<g class=\"node\" id=\"node3\"><title>x3_3</title>\n<ellipse cx=\"103.597\" cy=\"-18\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103.597\" y=\"-14.3\">x3_3</text>\n</g>\n<!-- x3_2&#45;&gt;x3_3 -->\n<g class=\"edge\" id=\"edge1\"><title>x3_2-&gt;x3_3</title>\n<path d=\"M103.597,-86.799C103.597,-75.1626 103.597,-59.5479 103.597,-46.2368\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"107.098,-46.1754 103.597,-36.1754 100.098,-46.1755 107.098,-46.1754\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"116.097\" y=\"-57.8\">0.19</text>\n</g>\n</g>\n</svg>"
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iAf76a_8nWl_"
      },
      "source": [
        "## Compute causal effects from each variable of the second group to that of the third"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "h3Fcly25nWmA",
        "outputId": "3965fdd1-0cbc-4257-db1d-2efbee0a7526"
      },
      "source": [
        "if contain_bin_var_in_set3:\n",
        "    print('\\x1b[31mSkip because the third group includes a binary variable\\x1b[0m')\n",
        "else:\n",
        "    for i, X in enumerate(X_list):\n",
        "        # Adjucency matrix for the all the variables\n",
        "        adj_matrix = np.zeros([X.shape[1], X.shape[1]])\n",
        "\n",
        "        # Update the adjuncy matrix using the causal graph estimated for the second group\n",
        "        set2_start_pos = len(set1_labels)\n",
        "        set2_end_pos = set2_start_pos + len(set2_labels)\n",
        "        adj_matrix[set2_start_pos:set2_end_pos, set2_start_pos:set2_end_pos] = set2_model.adjacency_matrices_[i]\n",
        "\n",
        "        # Update the adjuncy matrix using the causal graph estimated for the third group\n",
        "        set3_start_pos = len(set1_labels) + len(set2_labels)\n",
        "        set3_end_pos = set3_start_pos + len(set3_labels)\n",
        "        adj_matrix[set3_start_pos:set3_end_pos, set3_start_pos:set3_end_pos] = set3_model.adjacency_matrices_[i]\n",
        "\n",
        "        # Compute the connection strengths from each variable of the second group to that of the third\n",
        "        for j, idx in enumerate(set3_indices):\n",
        "\n",
        "            # Obtain parents of each variable of the third group\n",
        "            set3_parents = np.where(np.abs(set3_model.adjacency_matrices_[i][j]) > 0)[0]\n",
        "            set3_parents = [X.columns.get_loc(set3_labels[idx]) for idx in set3_parents]\n",
        "\n",
        "            # Create the set of explanatory variables\n",
        "            predictors = []\n",
        "            predictors.extend(set2_indices) # All the variables of the second group\n",
        "            predictors.extend(set3_parents) # Parents in the third group\n",
        "\n",
        "            # Pruning\n",
        "            coefs = predict_adaptive_lasso(X_removed_set1, predictors, idx)\n",
        "            adj_matrix[idx, set2_start_pos:set2_end_pos] = coefs[:len(set2_indices)]\n",
        "\n",
        "        # Remove a part of the adjacency matrix corresponding the variables of the first group\n",
        "        adj_matrix_set23 = adj_matrix[set2_start_pos:, set2_start_pos:]\n",
        "        g = make_dot(adj_matrix_set23, labels=set2_labels+set3_labels)\n",
        "        print(f'{[i]} {input_files[i]}')\n",
        "        display_svg(SVG(g._repr_svg_()))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[0] data1.csv\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "image/svg+xml": "<svg height=\"305pt\" viewBox=\"0.00 0.00 245.60 305.00\" width=\"246pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 301)\">\n<title>%3</title>\n<polygon fill=\"white\" points=\"-4,4 -4,-301 241.597,-301 241.597,4 -4,4\" stroke=\"none\"/>\n<!-- x2_1 -->\n<g class=\"node\" id=\"node1\"><title>x2_1</title>\n<ellipse cx=\"53.5975\" cy=\"-279\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"53.5975\" y=\"-275.3\">x2_1</text>\n</g>\n<!-- x2_3 -->\n<g class=\"node\" id=\"node3\"><title>x2_3</title>\n<ellipse cx=\"28.5975\" cy=\"-192\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"28.5975\" y=\"-188.3\">x2_3</text>\n</g>\n<!-- x2_1&#45;&gt;x2_3 -->\n<g class=\"edge\" id=\"edge1\"><title>x2_1-&gt;x2_3</title>\n<path d=\"M40.8604,-262.562C36.816,-256.778 32.8202,-249.923 30.5975,-243 28.2869,-235.804 27.3158,-227.724 27.0406,-220.192\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"30.5406,-220.062 27.0479,-210.059 23.5406,-220.057 30.5406,-220.062\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"43.0975\" y=\"-231.8\">0.35</text>\n</g>\n<!-- x2_5 -->\n<g class=\"node\" id=\"node5\"><title>x2_5</title>\n<ellipse cx=\"103.597\" cy=\"-192\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"103.597\" y=\"-188.3\">x2_5</text>\n</g>\n<!-- x2_1&#45;&gt;x2_5 -->\n<g class=\"edge\" id=\"edge5\"><title>x2_1-&gt;x2_5</title>\n<path d=\"M62.003,-261.588C67.163,-251.733 74.0229,-239.012 80.5975,-228 82.6766,-224.518 84.9417,-220.891 87.213,-217.349\" 
fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"90.229,-219.131 92.7693,-208.846 84.3691,-215.302 90.229,-219.131\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"93.0975\" y=\"-231.8\">0.60</text>\n</g>\n<!-- x2_2 -->\n<g class=\"node\" id=\"node2\"><title>x2_2</title>\n<ellipse cx=\"168.597\" cy=\"-279\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"168.597\" y=\"-275.3\">x2_2</text>\n</g>\n<!-- x2_4 -->\n<g class=\"node\" id=\"node4\"><title>x2_4</title>\n<ellipse cx=\"76.5975\" cy=\"-105\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"76.5975\" y=\"-101.3\">x2_4</text>\n</g>\n<!-- x2_2&#45;&gt;x2_4 -->\n<g class=\"edge\" id=\"edge2\"><title>x2_2-&gt;x2_4</title>\n<path d=\"M171.869,-260.959C176.226,-233.038 180.632,-177.053 154.597,-141 151.629,-136.889 129.608,-127.196 109.719,-119.05\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"110.802,-115.713 100.219,-115.207 108.176,-122.202 110.802,-115.713\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"190.097\" y=\"-188.3\">-0.72</text>\n</g>\n<!-- x2_2&#45;&gt;x2_5 -->\n<g class=\"edge\" id=\"edge6\"><title>x2_2-&gt;x2_5</title>\n<path d=\"M150.994,-264.55C143.792,-258.483 135.729,-250.909 129.597,-243 123.91,-235.664 118.867,-226.898 114.772,-218.769\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"117.884,-217.163 110.419,-209.644 111.566,-220.177 117.884,-217.163\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"142.097\" y=\"-231.8\">0.16</text>\n</g>\n<!-- x3_3 -->\n<g class=\"node\" id=\"node8\"><title>x3_3</title>\n<ellipse cx=\"168.597\" cy=\"-18\" 
fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"168.597\" y=\"-14.3\">x3_3</text>\n</g>\n<!-- x2_2&#45;&gt;x3_3 -->\n<g class=\"edge\" id=\"edge9\"><title>x2_2-&gt;x3_3</title>\n<path d=\"M180.91,-262.612C190.565,-249.496 203.235,-229.691 208.597,-210 219.043,-171.639 209.133,-90.6542 194.597,-54 192.997,-49.9628 190.807,-45.9772 188.374,-42.2229\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"191.105,-40.0266 182.414,-33.9669 185.429,-44.1237 191.105,-40.0266\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"225.097\" y=\"-144.8\">0.04</text>\n</g>\n<!-- x2_3&#45;&gt;x2_4 -->\n<g class=\"edge\" id=\"edge3\"><title>x2_3-&gt;x2_4</title>\n<path d=\"M32.5469,-174.076C35.3197,-164.039 39.61,-151.306 45.5975,-141 48.1154,-136.666 51.2222,-132.386 54.4874,-128.387\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"57.3147,-130.471 61.2742,-120.644 52.0506,-125.857 57.3147,-130.471\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"58.0975\" y=\"-144.8\">0.30</text>\n</g>\n<!-- x3_1 -->\n<g class=\"node\" id=\"node6\"><title>x3_1</title>\n<ellipse cx=\"68.5975\" cy=\"-18\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"68.5975\" y=\"-14.3\">x3_1</text>\n</g>\n<!-- x2_4&#45;&gt;x3_1 -->\n<g class=\"edge\" id=\"edge7\"><title>x2_4-&gt;x3_1</title>\n<path d=\"M72.5686,-87.0044C71.3885,-81.3141 70.2428,-74.9137 69.5975,-69 68.8001,-61.6933 68.4207,-53.7513 68.2701,-46.3933\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"71.7684,-46.1453 68.1911,-36.1727 64.7686,-46.1995 71.7684,-46.1453\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" 
text-anchor=\"middle\" x=\"82.0975\" y=\"-57.8\">0.30</text>\n</g>\n<!-- x2_4&#45;&gt;x3_3 -->\n<g class=\"edge\" id=\"edge10\"><title>x2_4-&gt;x3_3</title>\n<path d=\"M92.1836,-89.5483C102.964,-79.557 117.644,-65.9621 130.597,-54 135.441,-49.5275 140.635,-44.7381 145.58,-40.1824\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"148.074,-42.6439 153.059,-33.295 143.332,-37.4947 148.074,-42.6439\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"145.097\" y=\"-57.8\">-0.27</text>\n</g>\n<!-- x2_5&#45;&gt;x2_4 -->\n<g class=\"edge\" id=\"edge4\"><title>x2_5-&gt;x2_4</title>\n<path d=\"M89.4028,-176.173C84.7304,-170.346 80.0954,-163.306 77.5975,-156 75.1537,-148.853 74.2554,-140.789 74.1344,-133.256\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"77.637,-133.202 74.4009,-123.113 70.6394,-133.018 77.637,-133.202\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"92.0975\" y=\"-144.8\">-0.05</text>\n</g>\n<!-- x3_2 -->\n<g class=\"node\" id=\"node7\"><title>x3_2</title>\n<ellipse cx=\"159.597\" cy=\"-105\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"159.597\" y=\"-101.3\">x3_2</text>\n</g>\n<!-- x2_5&#45;&gt;x3_2 -->\n<g class=\"edge\" id=\"edge8\"><title>x2_5-&gt;x3_2</title>\n<path d=\"M114.131,-175.012C122.422,-162.427 134.115,-144.679 143.567,-130.333\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"146.533,-132.192 149.112,-121.916 140.687,-128.341 146.533,-132.192\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"148.097\" y=\"-144.8\">0.53</text>\n</g>\n<!-- x3_2&#45;&gt;x3_3 -->\n<g class=\"edge\" id=\"edge11\"><title>x3_2-&gt;x3_3</title>\n<path d=\"M161.419,-86.799C162.651,-75.1626 164.304,-59.5479 
165.714,-46.2368\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"169.206,-46.4884 166.779,-36.1754 162.245,-45.7513 169.206,-46.4884\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"178.097\" y=\"-57.8\">0.17</text>\n</g>\n</g>\n</svg>"
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "[1] data3.csv\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "image/svg+xml": "<svg height=\"305pt\" viewBox=\"0.00 0.00 227.73 305.00\" width=\"228pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 301)\">\n<title>%3</title>\n<polygon fill=\"white\" points=\"-4,4 -4,-301 223.728,-301 223.728,4 -4,4\" stroke=\"none\"/>\n<!-- x2_1 -->\n<g class=\"node\" id=\"node1\"><title>x2_1</title>\n<ellipse cx=\"169.13\" cy=\"-279\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"169.13\" y=\"-275.3\">x2_1</text>\n</g>\n<!-- x2_3 -->\n<g class=\"node\" id=\"node3\"><title>x2_3</title>\n<ellipse cx=\"191.13\" cy=\"-192\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"191.13\" y=\"-188.3\">x2_3</text>\n</g>\n<!-- x2_1&#45;&gt;x2_3 -->\n<g class=\"edge\" id=\"edge1\"><title>x2_1-&gt;x2_3</title>\n<path d=\"M173.477,-261.207C176.53,-249.409 180.679,-233.378 184.188,-219.822\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"187.6,-220.607 186.718,-210.049 180.823,-218.853 187.6,-220.607\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"193.63\" y=\"-231.8\">0.27</text>\n</g>\n<!-- x2_5 -->\n<g class=\"node\" id=\"node5\"><title>x2_5</title>\n<ellipse cx=\"47.1303\" cy=\"-192\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"47.1303\" y=\"-188.3\">x2_5</text>\n</g>\n<!-- x2_1&#45;&gt;x2_5 -->\n<g class=\"edge\" id=\"edge4\"><title>x2_1-&gt;x2_5</title>\n<path d=\"M150.601,-265.09C130.268,-250.924 97.4371,-228.05 74.2166,-211.872\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"76.0396,-208.876 65.8338,-206.031 
72.038,-214.619 76.0396,-208.876\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"128.63\" y=\"-231.8\">0.45</text>\n</g>\n<!-- x2_2 -->\n<g class=\"node\" id=\"node2\"><title>x2_2</title>\n<ellipse cx=\"47.1303\" cy=\"-279\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"47.1303\" y=\"-275.3\">x2_2</text>\n</g>\n<!-- x2_4 -->\n<g class=\"node\" id=\"node4\"><title>x2_4</title>\n<ellipse cx=\"122.13\" cy=\"-105\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"122.13\" y=\"-101.3\">x2_4</text>\n</g>\n<!-- x2_2&#45;&gt;x2_4 -->\n<g class=\"edge\" id=\"edge2\"><title>x2_2-&gt;x2_4</title>\n<path d=\"M61.4391,-263.22C66.6042,-257.292 72.1387,-250.171 76.1303,-243 95.9318,-207.427 109.112,-161.906 116.11,-133.342\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"119.594,-133.818 118.49,-123.28 112.782,-132.206 119.594,-133.818\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"119.63\" y=\"-188.3\">-0.60</text>\n</g>\n<!-- x2_2&#45;&gt;x2_5 -->\n<g class=\"edge\" id=\"edge5\"><title>x2_2-&gt;x2_5</title>\n<path d=\"M47.1303,-260.799C47.1303,-249.163 47.1303,-233.548 47.1303,-220.237\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"50.6304,-220.175 47.1303,-210.175 43.6304,-220.175 50.6304,-220.175\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"59.6303\" y=\"-231.8\">0.12</text>\n</g>\n<!-- x3_3 -->\n<g class=\"node\" id=\"node8\"><title>x3_3</title>\n<ellipse cx=\"35.1303\" cy=\"-18\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"35.1303\" 
y=\"-14.3\">x3_3</text>\n</g>\n<!-- x2_2&#45;&gt;x3_3 -->\n<g class=\"edge\" id=\"edge8\"><title>x2_2-&gt;x3_3</title>\n<path d=\"M35.4259,-262.512C26.2488,-249.336 14.2092,-229.5 9.13027,-210 -4.64836,-157.098 -0.796519,-140.758 9.13027,-87 11.8105,-72.4853 17.456,-57.0625 22.8035,-44.5806\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"26.0779,-45.832 26.966,-35.2744 19.688,-42.9739 26.0779,-45.832\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"13.6303\" y=\"-144.8\">0.04</text>\n</g>\n<!-- x2_3&#45;&gt;x2_4 -->\n<g class=\"edge\" id=\"edge3\"><title>x2_3-&gt;x2_4</title>\n<path d=\"M178.475,-175.41C168.067,-162.589 153.155,-144.219 141.285,-129.597\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"143.726,-127.05 134.706,-121.491 138.291,-131.461 143.726,-127.05\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"172.63\" y=\"-144.8\">0.36</text>\n</g>\n<!-- x3_1 -->\n<g class=\"node\" id=\"node6\"><title>x3_1</title>\n<ellipse cx=\"135.13\" cy=\"-18\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"135.13\" y=\"-14.3\">x3_1</text>\n</g>\n<!-- x2_4&#45;&gt;x3_1 -->\n<g class=\"edge\" id=\"edge6\"><title>x2_4-&gt;x3_1</title>\n<path d=\"M124.761,-86.799C126.541,-75.1626 128.929,-59.5479 130.965,-46.2368\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"134.451,-46.5897 132.503,-36.1754 127.532,-45.5313 134.451,-46.5897\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"142.63\" y=\"-57.8\">0.30</text>\n</g>\n<!-- x2_4&#45;&gt;x3_3 -->\n<g class=\"edge\" id=\"edge9\"><title>x2_4-&gt;x3_3</title>\n<path d=\"M107.372,-89.5809C93.6234,-76.1485 72.9417,-55.9422 57.2253,-40.5871\" fill=\"none\" stroke=\"black\"/>\n<polygon 
fill=\"black\" points=\"59.5598,-37.9747 49.961,-33.4898 54.6679,-42.9817 59.5598,-37.9747\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"98.6303\" y=\"-57.8\">-0.27</text>\n</g>\n<!-- x3_2 -->\n<g class=\"node\" id=\"node7\"><title>x3_2</title>\n<ellipse cx=\"47.1303\" cy=\"-105\" fill=\"none\" rx=\"28.6953\" ry=\"18\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"47.1303\" y=\"-101.3\">x3_2</text>\n</g>\n<!-- x2_5&#45;&gt;x3_2 -->\n<g class=\"edge\" id=\"edge7\"><title>x2_5-&gt;x3_2</title>\n<path d=\"M47.1303,-173.799C47.1303,-162.163 47.1303,-146.548 47.1303,-133.237\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"50.6304,-133.175 47.1303,-123.175 43.6304,-133.175 50.6304,-133.175\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"59.6303\" y=\"-144.8\">0.53</text>\n</g>\n<!-- x3_2&#45;&gt;x3_3 -->\n<g class=\"edge\" id=\"edge10\"><title>x3_2-&gt;x3_3</title>\n<path d=\"M40.7683,-87.1457C38.9119,-81.4648 37.1181,-75.0343 36.1303,-69 34.9429,-61.7464 34.4315,-53.8225 34.2788,-46.4647\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"37.7788,-46.2325 34.2748,-36.2339 30.7788,-46.2353 37.7788,-46.2325\" stroke=\"black\"/>\n<text font-family=\"Times New Roman,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"48.6303\" y=\"-57.8\">0.19</text>\n</g>\n</g>\n</svg>"
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4jz_IDrOnWmB"
      },
      "source": [
        "## Compute causal effects from each variable of the second group to that of the third"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xKkhua-DnWmC",
        "outputId": "c0b93995-966b-4e4f-fba1-0cbeb44ad187"
      },
      "source": [
        "for i, X in enumerate(X_list):\n",
        "    print(f'{input_files[i]}')\n",
        "\n",
        "    # The explanatory variables for a second-group variable do not depend on the\n",
        "    # third-group target, so build them once per dataset rather than once per pair\n",
        "    predictors_by_label = {}\n",
        "    for set2_label in set2_labels:\n",
        "\n",
        "        # Obtain the parents of this variable within the second group:\n",
        "        # the non-zero entries of its row in the adjacency matrix for dataset i\n",
        "        parents = np.where(np.abs(set2_model.adjacency_matrices_[i][set2_labels.index(set2_label)]) > 0)[0]\n",
        "        parents = [X.columns.get_loc(set2_labels[idx]) for idx in parents]\n",
        "\n",
        "        # Create the set of explanatory variables; the cause itself comes first\n",
        "        # so that its coefficient is always lr.coef_[0]\n",
        "        predictors = [X.columns.get_loc(set2_label)]\n",
        "        predictors.extend(parents)\n",
        "        predictors.extend(set1_indices)\n",
        "        predictors_by_label[set2_label] = predictors\n",
        "\n",
        "    # Compute causal effects from each variable of the second group to that of the third\n",
        "    for set3_label in set3_labels:\n",
        "        var3_index = X.columns.get_loc(set3_label)\n",
        "\n",
        "        # A target with exactly two distinct values is treated as binary\n",
        "        is_binary = len(np.unique(X[set3_label])) == 2\n",
        "\n",
        "        for set2_label in set2_labels:\n",
        "            var2_index = X.columns.get_loc(set2_label)\n",
        "            predictors = predictors_by_label[set2_label]\n",
        "\n",
        "            if not is_binary:\n",
        "                # Continuous target: perform linear regression and take the\n",
        "                # coefficient of the cause as the effect\n",
        "                lr = LinearRegression()\n",
        "                lr.fit(X.iloc[:, predictors], X.iloc[:, var3_index])\n",
        "                effect = lr.coef_[0]\n",
        "            else:\n",
        "                # Binary target: perform logistic regression and compare the average\n",
        "                # predicted probability (column 1 of predict_proba) under two interventions\n",
        "                lr = LogisticRegression(solver='liblinear')\n",
        "                lr.fit(X.iloc[:, predictors], X.iloc[:, var3_index])\n",
        "                var2_mean = X.iloc[:, var2_index].mean()\n",
        "                X_intervened = X.copy()\n",
        "                X_intervened.iloc[:, var2_index] = var2_mean # do(x=E(x))\n",
        "                p1 = lr.predict_proba(X_intervened.iloc[:, predictors])\n",
        "                X_intervened.iloc[:, var2_index] = var2_mean + 1 # do(x=E(x)+1)\n",
        "                p2 = lr.predict_proba(X_intervened.iloc[:, predictors])\n",
        "                effect = p2[:, 1].mean() - p1[:, 1].mean() # The difference btw the two averages\n",
        "\n",
        "            print(f'{set2_label} ---> {set3_label} : {effect:.3f}')"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "data1.csv\n",
            "x2_1 ---> x3_1 : -0.004\n",
            "x2_2 ---> x3_1 : -0.258\n",
            "x2_3 ---> x3_1 : 0.085\n",
            "x2_4 ---> x3_1 : 0.243\n",
            "x2_5 ---> x3_1 : 0.000\n",
            "x2_1 ---> x3_2 : 0.322\n",
            "x2_2 ---> x3_2 : 0.112\n",
            "x2_3 ---> x3_2 : -0.014\n",
            "x2_4 ---> x3_2 : -0.045\n",
            "x2_5 ---> x3_2 : 0.496\n",
            "x2_1 ---> x3_3 : 0.054\n",
            "x2_2 ---> x3_3 : 0.270\n",
            "x2_3 ---> x3_3 : -0.085\n",
            "x2_4 ---> x3_3 : -0.339\n",
            "x2_5 ---> x3_3 : 0.121\n",
            "data3.csv\n",
            "x2_1 ---> x3_1 : 0.070\n",
            "x2_2 ---> x3_1 : -0.160\n",
            "x2_3 ---> x3_1 : 0.101\n",
            "x2_4 ---> x3_1 : 0.396\n",
            "x2_5 ---> x3_1 : 0.031\n",
            "x2_1 ---> x3_2 : 0.255\n",
            "x2_2 ---> x3_2 : 0.033\n",
            "x2_3 ---> x3_2 : 0.033\n",
            "x2_4 ---> x3_2 : 0.016\n",
            "x2_5 ---> x3_2 : 0.588\n",
            "x2_1 ---> x3_3 : 0.048\n",
            "x2_2 ---> x3_3 : 0.258\n",
            "x2_3 ---> x3_3 : -0.101\n",
            "x2_4 ---> x3_3 : -0.295\n",
            "x2_5 ---> x3_3 : 0.130\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3TFhfRNknWmD"
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}