{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Usecase: apply SCM on regression tasks generated by sklearn\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Prepare environment\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Set up the runtime\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Project directory: /home/xj265/phd/codebase/Euphratica/Euphratica-dev\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "project_dir = os.getcwd()\n",
    "while not os.path.exists(os.path.join(project_dir, \".git\")):\n",
    "    project_dir = os.path.dirname(project_dir)\n",
    "print(f\"Project directory: {project_dir}\")\n",
    "sys.path.insert(0, project_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Import customised libraries\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import lingam\n",
    "import numpy as np\n",
    "from lingam.utils import make_dot\n",
    "from sklearn.datasets import make_regression\n",
    "from tabeval.plugins import Plugins"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Specify the results of interest"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Generate the data with known causal structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((100000, 10), (100000,))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X, y, coef = make_regression(\n",
    "    n_samples=100000,\n",
    "    n_features=10,\n",
    "    n_informative=5,\n",
    "    noise=0.1,\n",
    "    coef=True,\n",
    "    random_state=42,\n",
    ")\n",
    "X.shape, y.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Export results\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Fit the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<lingam.direct_lingam.DirectLiNGAM at 0x721657384790>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = lingam.DirectLiNGAM()\n",
    "model.fit(np.concatenate([X, y.reshape(-1, 1)], axis=1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Compare the causal roots and informative feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.        , 71.5561938 , 90.17260756,  0.        ,  0.        ,\n",
       "       21.14662516,  0.        , 56.03943168, 84.91182137,  0.        ])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "coef"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 8.1.0 (20230707.2238)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"757pt\" height=\"223pt\"\n",
       " viewBox=\"0.00 0.00 757.00 222.50\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 218.5)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-218.5 753,-218.5 753,4 -4,4\"/>\n",
       "<!-- x0 -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>x0</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-196.5\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"27\" y=\"-191.82\" font-family=\"Times,serif\" font-size=\"14.00\">x0</text>\n",
       "</g>\n",
       "<!-- x1 -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>x1</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"239\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"239\" y=\"-13.32\" font-family=\"Times,serif\" font-size=\"14.00\">x1</text>\n",
       "</g>\n",
       "<!-- x2 -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>x2</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"99\" cy=\"-196.5\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"99\" y=\"-191.82\" font-family=\"Times,serif\" font-size=\"14.00\">x2</text>\n",
       "</g>\n",
       "<!-- x2&#45;&gt;x1 -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>x2&#45;&gt;x1</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M98.37,-178.21C98.36,-155.88 101.21,-116.64 119,-89.25 138.71,-58.91 176.46,-40.02 204.38,-29.57\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"205.28,-32.61 213.54,-25.97 202.95,-26.01 205.28,-32.61\"/>\n",
       "<text text-anchor=\"middle\" x=\"137\" y=\"-102.58\" font-family=\"Times,serif\" font-size=\"14.00\">&#45;1.26</text>\n",
       "</g>\n",
       "<!-- x8 -->\n",
       "<g id=\"node9\" class=\"node\">\n",
       "<title>x8</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"314\" cy=\"-107.25\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"314\" y=\"-102.58\" font-family=\"Times,serif\" font-size=\"14.00\">x8</text>\n",
       "</g>\n",
       "<!-- x2&#45;&gt;x8 -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>x2&#45;&gt;x8</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M122.88,-187.37C131.95,-184.41 142.4,-181.14 152,-178.5 186.3,-169.06 198.48,-176.99 230,-160.5 240.11,-155.21 239.69,-149.84 249,-143.25 259.42,-135.88 271.56,-128.96 282.51,-123.24\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"283.81,-125.99 291.15,-118.35 280.65,-119.75 283.81,-125.99\"/>\n",
       "<text text-anchor=\"middle\" x=\"267\" y=\"-147.2\" font-family=\"Times,serif\" font-size=\"14.00\">&#45;0.62</text>\n",
       "</g>\n",
       "<!-- x10 -->\n",
       "<g id=\"node11\" class=\"node\">\n",
       "<title>x10</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"239\" cy=\"-107.25\" rx=\"29.64\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"239\" y=\"-102.58\" font-family=\"Times,serif\" font-size=\"14.00\">x10</text>\n",
       "</g>\n",
       "<!-- x2&#45;&gt;x10 -->\n",
       "<g id=\"edge9\" class=\"edge\">\n",
       "<title>x2&#45;&gt;x10</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M108.01,-179.22C115.16,-167.67 126.17,-152.62 139.5,-143.25 160.72,-128.34 170.56,-133.93 195,-125.25 198.14,-124.14 201.39,-122.95 204.64,-121.73\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"205.78,-124.67 213.88,-117.84 203.29,-118.12 205.78,-124.67\"/>\n",
       "<text text-anchor=\"middle\" x=\"159.25\" y=\"-147.2\" font-family=\"Times,serif\" font-size=\"14.00\">90.52</text>\n",
       "</g>\n",
       "<!-- x3 -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>x3</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"506\" cy=\"-196.5\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"506\" y=\"-191.82\" font-family=\"Times,serif\" font-size=\"14.00\">x3</text>\n",
       "</g>\n",
       "<!-- x4 -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>x4</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"578\" cy=\"-196.5\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"578\" y=\"-191.82\" font-family=\"Times,serif\" font-size=\"14.00\">x4</text>\n",
       "</g>\n",
       "<!-- x5 -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>x5</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"188\" cy=\"-196.5\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"188\" y=\"-191.82\" font-family=\"Times,serif\" font-size=\"14.00\">x5</text>\n",
       "</g>\n",
       "<!-- x5&#45;&gt;x1 -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>x5&#45;&gt;x1</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M178.36,-179.44C166.89,-158.12 150.84,-119.61 164,-89.25 173.46,-67.43 193.41,-49.28 210.35,-36.9\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"212.16,-39.21 218.37,-30.63 208.16,-33.47 212.16,-39.21\"/>\n",
       "<text text-anchor=\"middle\" x=\"182\" y=\"-102.58\" font-family=\"Times,serif\" font-size=\"14.00\">&#45;0.30</text>\n",
       "</g>\n",
       "<!-- x5&#45;&gt;x8 -->\n",
       "<g id=\"edge7\" class=\"edge\">\n",
       "<title>x5&#45;&gt;x8</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M214.49,-192.09C236.72,-187.99 268.1,-179.17 289,-160.5 296.44,-153.86 301.87,-144.52 305.72,-135.57\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"309.3,-137.02 309.56,-126.43 302.75,-134.54 309.3,-137.02\"/>\n",
       "<text text-anchor=\"middle\" x=\"319\" y=\"-147.2\" font-family=\"Times,serif\" font-size=\"14.00\">&#45;0.15</text>\n",
       "</g>\n",
       "<!-- x5&#45;&gt;x10 -->\n",
       "<g id=\"edge10\" class=\"edge\">\n",
       "<title>x5&#45;&gt;x10</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M184.78,-178.51C183.61,-167.78 183.76,-153.97 189.5,-143.25 193.52,-135.75 199.94,-129.49 206.84,-124.45\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"208.22,-127.08 214.7,-118.7 204.4,-121.21 208.22,-127.08\"/>\n",
       "<text text-anchor=\"middle\" x=\"209.25\" y=\"-147.2\" font-family=\"Times,serif\" font-size=\"14.00\">21.27</text>\n",
       "</g>\n",
       "<!-- x6 -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>x6</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"650\" cy=\"-196.5\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"650\" y=\"-191.82\" font-family=\"Times,serif\" font-size=\"14.00\">x6</text>\n",
       "</g>\n",
       "<!-- x7 -->\n",
       "<g id=\"node8\" class=\"node\">\n",
       "<title>x7</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"434\" cy=\"-196.5\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"434\" y=\"-191.82\" font-family=\"Times,serif\" font-size=\"14.00\">x7</text>\n",
       "</g>\n",
       "<!-- x7&#45;&gt;x1 -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>x7&#45;&gt;x1</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M449.86,-181.8C460.01,-171.3 470.07,-156.41 463,-143.25 424.86,-72.31 328.56,-39.56 275.49,-26.46\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"276.4,-22.85 265.86,-23.96 274.8,-29.66 276.4,-22.85\"/>\n",
       "<text text-anchor=\"middle\" x=\"468\" y=\"-102.58\" font-family=\"Times,serif\" font-size=\"14.00\">&#45;0.78</text>\n",
       "</g>\n",
       "<!-- x7&#45;&gt;x8 -->\n",
       "<g id=\"edge8\" class=\"edge\">\n",
       "<title>x7&#45;&gt;x8</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M430.58,-178.38C427.49,-167.04 421.87,-152.58 412,-143.25 395.39,-127.54 371.25,-118.81 351.09,-113.99\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"351.88,-110.37 341.37,-111.7 350.41,-117.22 351.88,-110.37\"/>\n",
       "<text text-anchor=\"middle\" x=\"441\" y=\"-147.2\" font-family=\"Times,serif\" font-size=\"14.00\">&#45;0.39</text>\n",
       "</g>\n",
       "<!-- x7&#45;&gt;x10 -->\n",
       "<g id=\"edge11\" class=\"edge\">\n",
       "<title>x7&#45;&gt;x10</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M412.76,-184.96C399.57,-178.25 382.35,-169.22 367.5,-160.5 355.38,-153.39 353.79,-149.08 341,-143.25 314.5,-131.17 305.41,-135.09 278,-125.25 276.02,-124.54 274,-123.78 271.96,-122.98\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"273.47,-119.4 262.89,-118.83 270.81,-125.88 273.47,-119.4\"/>\n",
       "<text text-anchor=\"middle\" x=\"387.25\" y=\"-147.2\" font-family=\"Times,serif\" font-size=\"14.00\">56.06</text>\n",
       "</g>\n",
       "<!-- x8&#45;&gt;x1 -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>x8&#45;&gt;x1</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M303.35,-90.66C295.83,-80.01 285.3,-65.74 275,-54 271.05,-49.5 266.62,-44.89 262.26,-40.55\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"265.22,-38.58 255.6,-34.14 260.35,-43.61 265.22,-38.58\"/>\n",
       "<text text-anchor=\"middle\" x=\"306\" y=\"-57.95\" font-family=\"Times,serif\" font-size=\"14.00\">&#45;1.19</text>\n",
       "</g>\n",
       "<!-- x9 -->\n",
       "<g id=\"node10\" class=\"node\">\n",
       "<title>x9</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"722\" cy=\"-196.5\" rx=\"27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"722\" y=\"-191.82\" font-family=\"Times,serif\" font-size=\"14.00\">x9</text>\n",
       "</g>\n",
       "<!-- x10&#45;&gt;x1 -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>x10&#45;&gt;x1</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M239,-89.01C239,-77.06 239,-60.88 239,-47.08\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"242.5,-47.2 239,-37.2 235.5,-47.2 242.5,-47.2\"/>\n",
       "<text text-anchor=\"middle\" x=\"254.75\" y=\"-57.95\" font-family=\"Times,serif\" font-size=\"14.00\">0.01</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x72165740ffd0>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "make_dot(model.adjacency_matrix_)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gnn",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
