{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ebe3b3ab",
   "metadata": {},
   "source": [
    "## Example with Custom Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1abba87b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/lib/python3.7/site-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.\n",
      "  import pandas.util.testing as tm\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>x3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.080419</td>\n",
       "      <td>0.209818</td>\n",
       "      <td>1.714534</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>-2.072334</td>\n",
       "      <td>-1.348848</td>\n",
       "      <td>-3.314112</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1.301027</td>\n",
       "      <td>1.155196</td>\n",
       "      <td>1.714924</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1.802104</td>\n",
       "      <td>0.857641</td>\n",
       "      <td>2.054792</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.408486</td>\n",
       "      <td>0.814797</td>\n",
       "      <td>2.950110</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         x1        x2        x3\n",
       "0  1.080419  0.209818  1.714534\n",
       "1 -2.072334 -1.348848 -3.314112\n",
       "2  1.301027  1.155196  1.714924\n",
       "3  1.802104  0.857641  2.054792\n",
       "4  0.408486  0.814797  2.950110"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np \n",
    "import pandas as pd\n",
    "import networkx as nx \n",
    "from model.diffusion import create_model_from_graph\n",
    "import dowhy.gcm as cy\n",
    "from dowhy.gcm  import draw_samples, interventional_samples, counterfactual_samples\n",
    "\n",
    "n = 1000\n",
    "x1 = np.random.normal(size=(n))\n",
    "x2 = x1 + np.random.normal(size=(n)) \n",
    "x3 = x1 + x2 + np.random.normal(size=(n)) \n",
    "graph = nx.DiGraph([('x1', 'x2'), ('x1', 'x3'), ('x2','x3')])\n",
    "factual = pd.DataFrame({\"x1\" : x1, \"x2\" : x2, \"x3\" : x3})\n",
    "factual.head()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "da15431e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Fitting causal mechanism of node x3: 100%|██████████| 3/3 [00:05<00:00,  1.96s/it]\n"
     ]
    }
   ],
   "source": [
    "params = {'num_epochs' : 200,\n",
    "          'lr' : 1e-4,\n",
    "          'batch_size': 64,\n",
    "          'hidden_dim' : 64}\n",
    "\n",
    "diff_model = create_model_from_graph(graph, params)\n",
    "\n",
    "cy.fit(diff_model, factual)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8af2d0a0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>x3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.242151</td>\n",
       "      <td>0.602441</td>\n",
       "      <td>0.697818</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1.147019</td>\n",
       "      <td>-0.573428</td>\n",
       "      <td>-0.589873</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.560376</td>\n",
       "      <td>0.019936</td>\n",
       "      <td>-0.248564</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1.036918</td>\n",
       "      <td>0.492662</td>\n",
       "      <td>0.826894</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.173450</td>\n",
       "      <td>-0.547257</td>\n",
       "      <td>-0.620032</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         x1        x2        x3\n",
       "0  0.242151  0.602441  0.697818\n",
       "1  1.147019 -0.573428 -0.589873\n",
       "2  0.560376  0.019936 -0.248564\n",
       "3  1.036918  0.492662  0.826894\n",
       "4  0.173450 -0.547257 -0.620032"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Observational Samples\n",
    "obs_samples = draw_samples(diff_model, num_samples = 20)\n",
    "obs_samples.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2a140c90",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>x3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2</td>\n",
       "      <td>1.211636</td>\n",
       "      <td>3.109947</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>2.056405</td>\n",
       "      <td>5.458492</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>2.074754</td>\n",
       "      <td>5.663898</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2</td>\n",
       "      <td>1.536017</td>\n",
       "      <td>2.445876</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2</td>\n",
       "      <td>3.226298</td>\n",
       "      <td>6.365125</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   x1        x2        x3\n",
       "0   2  1.211636  3.109947\n",
       "1   2  2.056405  5.458492\n",
       "2   2  2.074754  5.663898\n",
       "3   2  1.536017  2.445876\n",
       "4   2  3.226298  6.365125"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "intervention = {\"x1\": lambda x: 2}\n",
    "int_samples = interventional_samples(diff_model, intervention, num_samples_to_draw=20)\n",
    "int_samples.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9a8e7015",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>x3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2</td>\n",
       "      <td>1.041874</td>\n",
       "      <td>3.504380</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>2.426617</td>\n",
       "      <td>4.643862</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>1.795991</td>\n",
       "      <td>3.238968</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2</td>\n",
       "      <td>1.050746</td>\n",
       "      <td>2.487216</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2</td>\n",
       "      <td>2.297843</td>\n",
       "      <td>6.336059</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   x1        x2        x3\n",
       "0   2  1.041874  3.504380\n",
       "1   2  2.426617  4.643862\n",
       "2   2  1.795991  3.238968\n",
       "3   2  1.050746  2.487216\n",
       "4   2  2.297843  6.336059"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cf_estimates = counterfactual_samples(diff_model, intervention, observed_data = factual)\n",
    "cf_estimates.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "401157d1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAEGCAYAAABvtY4XAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAmnklEQVR4nO3df5RcdZnn8ffTlQpUR4YOa86sNMSgx4GdTCQxPYLG8QeOMq5DJgujEWFmdf9gd9RRsxonzHJM2MM5ZCej4M6y7mHw1yyIQMDeMDgGnaAzskM0oRMzATIqPwKlM8SFRiENqXQ/+0fV7VRX33vrVnfdqnurPq9z+qT71o/7TWie/vbzfb7P19wdERHpPQPdHoCIiKRDAV5EpEcpwIuI9CgFeBGRHqUALyLSoxZ0ewD1Xv7yl/uyZcu6PQwRkdzYu3fvz919SdhjmQrwy5YtY8+ePd0ehohIbpjZE1GPKUUjItKjFOBFRHqUAryISI9SgBcR6VEK8CIiPSpTVTQiIlk0OlZm285D/HR8gtOHSmy88GzWrRru9rCaUoAXEYkxOlbmyrsOMFGZBKA8PsGVdx0AyHyQV4pGRCTGtp2HpoN7YKIyybadh7o0ouQU4EVEYvx0fKKl61miAC8iEuP0oVJL17NEAV5EJMbGC8+mVCzMuFYqFth44dnzfu/RsTJrtu7irE33sGbrLkbHyvN+z3paZBURiREspLa7iqYTi7cK8CIiTaxbNdz2ipm4xdt23UspGhGRLujE4q0CvIhIF3Ri8VYpGhGRmjR2rEa958YLz56Rg4f2Ld4GFOBFJNfignIrATuNRc8k75lmCwRz97a92XyNjIy4TnQSkaQaAyhUZ8HXXrwCIPKxsCC6ZusuyiH57+GhEvdvumD6fq0E5CTvOV9mttfdR8Ie0wxeRHKrWRuBVqpUwgJx/fW5zPC7vQtWi6wikltxAbTV4Fowi70+l5403d4FqwAvIpnVbKdnXABtNbhORqSrg+tzmY2nuQs2CQV4EcmkICVSHp/AOZESqQ/ycQG01eA6HBH4g+tzmY2vWzXMtRevYHiohNXeK2oNIA3KwYtIJkWlRK6+++CMhc5LVg9z3yNHIhc+ky6KNitbTFrWGLYQ264F1VapikZEMumsTfeQJDrFVca0qlmVTJLHW6ncaYe4KhoFeBHJpKgSwzBzKTtMY1NTJ8oiG8UF+FRz8Ga2wcwOmtk/mtmtZnZymvcTkd4RlkOP0mrZYVh+f+P2/ay8+t55te7tdllko9QCvJkNAx8FRtz9N4AC8L607icivSVsgXKoVAx97qkR16OE5fcrk874RCVyQTeJbpdFNkp7kXUBUDKzCjAI/DTl+4lID2ls0zs6VmbjHfupTM1MLb9w7Ph0MN628xDl8QkKZky6MxySfkkyo67fFJU0ndOJ/jKtSC3Au3vZzP4cOAxMAPe6+72NzzOzK4ArAJYuXZrWcEQkpxqDa7FgswJ8ZdLZsuMgLx2fmg6uQf162I7T04dKifL7Px2faGkHayf6y7QitUVWM1sM3AmsB8aBO4Dt7n5z1Gu0yCoi9cKqUuaqsadMkvcNauA7vXDaim4tsv428Ji7H3H3CnAX8MYU7yciPSYsVz5XQVom+I1gojI53YZg8WCR4sDMVgVBaiVrC6etSDPAHwbON7NBMzPg7cDDKd5PRHpM0iBaHDAWD8YvtJ4+VJpRPQPVNE6pWGDzRcvZ9p5zQ3ecZm3htBVp5uB3m9l24EHgODAG3JjW/USk9yTNlb/s5AVsvmg5G27bF7o5yqgugMY1DLt/0wW5WDhtRap18O6+2d3PcfffcPc/cPeX0ryfiPSWpLXwzx6tsOG2fQwunP1cAy47fynrVg3PKd3S7X4y86FeNCKSWY1VKQO10scwDrxwbJJiwVi0cAHPTVRmVbGcWioyPlGZ9dpmdfSN5Zp5oQAvIplWH1yTVL9UJp1FJy1g3+Z3znosouV75PW8U4AXkbZLo88LnJjRX333QZ49OnsmHohKuYxHvCbqet4pwItIIkmDdhqHVzd6sTIV+3hc5UvYom0eKmLmQgd+iEhTSQ7fCGzZcbDlo+1a0aw2Pq7CpdsnLHWaZvAi0lRUeeEnbt/Phtv2Tc/ogdBFTGjfxqC49zHgktXRC6JZayWQNgV4EWkqKqg29ns5uRidFGhXGiSuNt6B+x45Evv6vFbEzIVSNCLSVJLgPFGZjF34bCUNEnfYdrPa+Dy0EOgUBXgRaaqVwzfCLB4szmr7GxXAw/L9H79tH8s23cPKq6sNaS9ZHT0D79UF07lQikZEmmrMXbfSgzbo9RJoVmUTt4g6PlFh4x37ednJ4aEraEkwF2mVdnaTZvAiksi6VcPcv+kCHtv67qbPrd84dNKCE2FmdKzMJ27fH1tl0yzFUpnyyFSQM7dSzFaqhPJEM3gRCRU3ox2K2PIfqO8mMD5R4cq7DrDniWe4c285stVAeXyCNVt3RbYTSGJ4jumZuCohaF/9fqcpwIvILHFpFKgekdeKicokt+5+MjK4B8rjExQGmvcNGCoVZ5zeBPOrZ4+rEmr3Jq1OUoAXkVni2upCtd9Lq5oF9+nnTTkLC8axiHsUB4wta5dPj7MdOfO40sv6s1kDecnXK8CLyCxRM9pmvdkLZpy0wDjapJVAM8cmnevXr5zVc2aoVGTL2uXTwbRdQTWs53u9+n+PTrRiaBcFeBGZJelBG40m3alMVWfZjQdjt6qTG5KC+3zi9v2hv2nUl17G/XajAC8imVSfdji1VKRYsDmlYiqTzuLBIoMLFzQtqRwsDoTO9oea9GdPQxCcm53elKczWhXgRWRW2mF8ojJ9zmnc7tQo40crjH262o99zdZdob8NDNdy1xvv2D9jth/k2EfHymzZcXC6ombxYJHNFy1PdZacpFdNnjpSKsCLSGjaoTLluFfz6mFpi6AksVmwiwri9YGzPte+6KQF3HDfj/jR0y/MeM9nj1bYuD39ssVmqaE8ndGqjU4ifSxoGRCVbx+fqIQG9yCgJWm/u+eJZ2bl4xuTMvX93ccnKrOCe6Ay6W1rOzxXeTqjVTN4kT6V5Pi7MAWzWQEtKqUxOlbm5gcOz3qPySnnv3z9ANt2Hmp5Mbddue75lDrmpSOlArxIn2p2cEaUST8xi64PdEGwrH8sbrb9wrFJXjjWerBuR647T6WO86EAL9LjomaqcTPhqLx7oHFna1iw3PPEM3MqtYxTLFhbct15KnWcDwV4kR4WN1ON6vliJNt1Wr+zNSxY3hKSmpmvRQvbE7LyVOo4H1pkFelhUTPVLTsORvaTcaoz+CTK4xORQXF+25zCBY3L5tvlMe5Q7l6iAC+SU3GHZgSigu/4RCV2E9OkO8VC8yBvEHtMX6uKBWu6yakdB3j3y+HbCvAiOZS0f/m8ZqRe3VzU5CkzShznqzLpmNH09Kj5plLyVOo4H8rBi+RQ0kXCqE05JxcHmu5QTdpLJmkqxhI+d/xohevWr4wtoWxHKiUvpY7zoRm8SM6MjpUjA1/jzDZqprr5ouWJzlidS5uCMAUzLjt/aaK0z+lDpenTo65fv7IvUilpUYAXyZEgNRNlqCGlElUiuW7VcOzB1e025c7IK09rOoVvDN79kkpJi3nCJvydMDIy4nv27On2MES6Liowx7UVgNpZqF6dBb/tnCXcubc8Kz0TBMhV//Xets3QmymYccrJC0LLMgtmTLln+uCMLDOzve4+EvaYcvAiGRNXu95scTGYr5XHJ7jlgcOzJsz1efpOBXeoVuVEnbM65Z7oIG9pnVI0IhkTt4DayuJi1O/mweHWWdFrtedZkmqAN7MhM9tuZo+Y2cNm9oY07yfSC+J2WYbVb89Fu1sIzFWQc09S0y+tS3sG/zngm+5+DnAu8HDK9xPJvagZ7UBtd2mw6AjV0sO8GioVufbiFQCJavqldakFeDM7FXgz8AUAdz/m7uNp3U+kV0TN0ifdp3Px92+6gMe3vjvX6Y1FJy2Y7jgZlZKS+UlzBn8WcAT4kpmNmdlNZrYoxfuJ9IS4EsaJyiSfuH3/9Ow2682x4nbCBmPvl8Zf3ZAowJvZsJm90czeHHwkeNkC4HXA5919FfACsCnkva8wsz1mtufIkSMtDV6kV933SPT/C8FMfnSsnOkZ/OLBImOffud0OqlRMPZ+afzVDU0DvJn9N+B+4CpgY+3jkwne+yngKXffXft6O9WAP4O73+juI+4+smTJksQDF+lVcTtVA0EKo12Lru1WKhbYfNFyoHljr35p/NUNSerg1wFnu/tLrbyxu/+zmT1pZme7+yHg7cBDcxijSN+4avRA4j7qQQrjpAUDczqZKS3DDRuWGk98CjZhbdt5iA237eP0oRKXrB7mvkeOzOn4PInWdCermf0N8B53f77lNzdbCdwELAQeBT7o7s9GPV87WaWfjY6V2XDbvsTNuxYPFnmxMpWp4G7QdNNS2Fmw9TtspTXz3cl6FNhnZn8LTM/i3f2jzV7o7vuA0BuL9LugHUF5fKLpEXlhXqxMMtHGVr3tkCRv3i/H5WVBkgC/o/YhIm0SpGKCkN5qcAcyF9yT5s1VNdM5TQO8u3/FzBYCv1a7dMjdO9fEQiSHopqFBY+F9YnJs4JZ4hTL6UOl0EVkVc20X5IqmrcCPwJuAP4n8E8JyyRF+lKz05a27TyU2+C+eLAYWvHymfeemzi9oqqZzkmSovkM8M5aJQxm9mvArcDqNAcmklfNdmZmpQ9Mq4oFmy59jPrtJImwqhpVzaQjSRXND939tc2utYOqaKQXnLXpnsgZeqlYyFTVSysGgFMHi4wfrSgoZ8h8q2j2mNlNwM21ry8DFIVFIkTlmAtmuQ3uAFOcOMKvvke9gnx2JWlV8EdUNyh9tPbxUO2aiIQIyzEXC62XQWadGoJlX5IqmpeAz9Y+RKSJxhzz0GCR51883uVRpUOljdkWGeDN7HZ3f6+ZHSDkcJg0cvAivSI42BpgzdZdHT0er5NU2phtcTP4j9X+/N1ODESkV+WlamaAap49UCwYixYu4LmJCqeWirxw7DiVyRNzPZU2Zl9kgHf3n9U+/ZC7/0n9Y7UOk38y+1UiEhgdK7Nlx8FuDyOxUweLDC5cEFm62NhaoT4Hr4XWbEpSRfMOZgfzd4VcE5GasIZaWTd+tMLYp98Z+XgQxOv/Xqqmyba4HPwfAR8CXm1mP6x76BTg/6Y9MJE8C9vslHVDg0XWbN0Vu/lIjcLyJW4G/1Xgb4BrmXkS0y/d/ZlURyWSc3mrLikMGM+/eLxpnbsaheVLZB28uz/n7o8DnwOecfcn3P0J4LiZndepAYpk0ehYmTVbd3HWpntYs3XXdJ+Z4HqeKt4NOOWkBVSmZo46rM5dx+vlS5KNTp8H6g/7eL52TaQvjY6V2bh9/4xmYhu37+eq0QPTTcbyxIHxifAyzsaZuRqF5UuSRVbzuoY17j5lZkleJ9KTrr774IxyQYDKpPPV3YeZytPUPYHGmbkaheVLkkD9qJl9lBOz9g9RPX5PpC9FbVrqteAeNTOv38Ql2ZYkRfOfgDcCZeAp4DzgijQHJZJVQa49j6zF5+uM1PxL0ovmaeB9HRiLSKYFte15UzDjJ9f+2+mNV1H59nrDQyUF9x4QVwf/KXf/MzP7C8J70TQ9dFukl+Sxth1OnPcapFaWbbon9vlaNO0dcTP4h2t/qve7CPmt9R5uWCgdjuhXD62drSrZF9eL5u7an1/p3HBEOifqYOyw6wADlr+e7mGz8Y0Xns3G7ftnVQIVB4xt70l+tqpkX+SRfWZ2NyGpmYC7r233YHRkn3RKWK+YUrHAJauHuXNveVYqprHTYhYFM/NC7QfRcEwJ4+hYmavvPjhdETRUKrJl7XIF9xya65F9f17782LgX3PiyL5LgX9p3/BEOi+qp8qtu58MnaVnPbgPlYrcv+mCxM9XqWN/iGtV8F13/y6wxt3Xu/vdtY/3A7/VuSGKtF9UPj3rKZjLz19KcWB2weMLx47nuoRT0pGkDn6Rmb0q+MLMzgIWpTckkfTltXfKNetWUCzMDvCVSdf5qDJLkgC/AfiOmX3HzL4L3Ad8PNVRiaQsrKdKHlw1eoCjlfCEUd564Ej6kmx0+qaZvQY4p3bpkdpB3CK5Vd9TJU+B8dbdT0Y+VrBW96pKr2s6gzezQWAj8BF33w8sNTOd0yq5t27VcEsLk1kQt0aQ9fUD6bwkKZovAceAN9S+LgPXpDYiEYkUN0tv3NAkkiTAv9rd/wyoALj7UVrvWyTSEVEHccQ9N08uPe/M0EXW4oCpvYDMkqRd8DEzK1Hb9GRmrwaUg5fMady8FHbsXLBLtTw+gRGzky+DBgxGXnkaI688TZuUJJEkAX4z8E3gTDO7BVgDfCDpDcysQLWfTdndlbuX1Fx998HYA6EbfwBkLbgPGBQMIopkmPLqovD9my5QMJdEYlM0ZjYALKa6m/UDwK3AiLt/p4V7fIwTjctEUjE6Vo48iCPY1JT1bpBTDgsKBS4/f2nkc/La8Ey6IzbAu/sU8Cl3/3/ufo+7/7W7/zzpm5vZGcC7gZvmOU6RWFfffTDysWBTUx7KIYN2CVHyukFLuiNJiubbZvZJ4DbgheCiuz+T4LXXA58CTpnT6ERCNHZ7fNs5SyJn71AN7Gu27mLA8nGsXly5oxZSpRVJAvz62p8frrvmwKtCnjutViv/tLvvNbO3xjzvCmpHAC5dGv2rqQiEL6Te8sDhpq/Lw+w9UIhoSzxUKir3Li1pWibp7meFfMQG95o1wFozexz4GnCBmd3c+CR3v9HdR9x9ZMmSJS3/BaS/hOXRczApb8ml5505q41CqVhgy9rlXRqR5FVkgDez88xsv5k9b2b/YGb/ppU3dvcr3f0Md19G9UzXXe5++TzHK32u1xcZh0pFrlm3gmsvXsHwUAmjuoFJpyzJXMSlaG4APgn8HbCWaj79wg6MSSTS6THHzeWdwfQsXf3apR3iUjQD7v4td3/J3e8A5pw/cffvqAZe2qGXFxkvO3+pgrq0VVyAHzKzi4OPkK9FOm7dqmGGSsVuD2POrl+/ksWDM8e/eLDI9etXcs26FV0alfSquBTNd4GLIr524K60BiUSJiiPHJ+o5K7NAFQDuVIv0kmRAd7dP9jJgYg0Gh0rs2XHQcYnZte45y24A2y+SFUw0llJ6uBFOm50rMzGO/ZTycPOpARUwy7dkKRdsEjHbdt5qGeCu2rYpVs0g5eua2w9sPHCs3NV714sGOt/80zue+QIPx2fYGiwiDs8N1GZ/vto9i7dEBngm1XKuLsWWWXewloPbLhtH4MLC7xwLLudHwOLB4tsvki92CWb4mbw24F9tQ+YeYqTqmikLaJaD2Q5uA8YfPa9KxXUJfPiAvzFVFsMvBb4P8Ct7v7jjoxK+kaeUjGB4OANBXjJushFVncfdff3AW8BfgJ8xsy+Z2Zv6djopOfltb95Hn8wSf9JUkXzIvAc8AvgZcDJqY5I+srbzslnB9G8/mCS/hK3yHoB1RTN64FvA59z9z2dGpj0hrAKmfoDsG/Z3byXe9YUC9bTPXGkd8Tl4L8N/BD4HnAS8Idm9ofBg+7+0ZTHJjl31egBbnng8PSu0/L4BFfedQCAPU88w80JDurIom2/f67y75ILcQH+P5DPHeGSAaNj5RnBPTBRmeTquw/GHrGXZUE/GZE8iAvwXwNOcfcj9RfNbAnwy1RHJbm3beehyNlBXoN7sWDqJyO5ErfI+t+B3wq5/ibgunSGI73gqtEDuT6Uo2CGUZ2tD5WK06cqKTUjeRM3g1/t7lc0XnT3r5vZNSmOSXLsqtEDucmtr3n1aTx4+LkZG61KxYKOx5OeERfgB2MeU5MyAU5UyeRxxv7Qz37JtReviKzyEcm7uAD/tJm93t2/X3/RzH4TOBLxGukTo2PlXC+WQnUtQAdwSC+LC/AbgdvN7MvA3tq1EeAPqdbHS5+pn63n8UQlkX4Td6LT983sPOBDwAdqlw8C57n70x0Ym2RIY9fHXgjueT7bVSSJ2H7w7v4vwOZaaSSNJZPSP8K6PuZZccB0CIf0vMjFUqvaYmZHgEPAITM7Ymaf7tzwJCt6oblW0O96eKjEtveo5FF6X9wMfgOwBni9uz8GYGavAj5vZhvcXbXwfWJ0rNztITT1q6cs5Onnj+G13NFgcYCLV58xfcqSKmSkH8UF+D8A3uHuPw8uuPujZnY5cC/a7NQXgsOvs5pzN+Cy85dyzboV3R6KSObEBfhifXAPuPsRM9PqVI/Lcn17ccCUYhFJIC7AH5vjY5Jjo2Nltuw4yPhENuvbh0pFtqzVGagiScQF+HPN7Bch1w0d+tGTGkshs2jRSQsU3EUSiquDL3RyINI9WU7HNOqFah6RTomtg5fel4dZez0dlSeSnAJ8n8rTrD2go/JEWqMA32euGj3AV3cfZiqrdY8RFg8W2XyRFldFWqEA30fm26u9MGBMpviTIaqB2fBQifs3XZDafUV6lfq695Fbdz8559ca6X+zONUDN+qVigWlZUTmKLX/Z83sTDO7z8weMrODZvaxtO4lyUz63GffDlRSzusMD5W49uIVDA+Vpo/J0+lKInOXZormOPAJd3/QzE4B9prZt9z9oRTv2feCxdPG/itZ6yfTmI4JZuo6gEOkfVIL8O7+M+Bntc9/aWYPA8OAAnxKrho9wC0PHJ4OnOXxCa686wB7nniGW3Zn55zUghmfee+5OipPJGUdWWQ1s2XAKmB3yGNXAFcALF26tBPD6UmjY+UZwT0wUZkMvd5Nk+6aqYt0QOqLrGb2MuBO4OPuPqv1gbvf6O4j7j6yZMmStIfTs7btPBQZxLMU3KGaWxeR9KU6g691nbwTuMXd70rzXv0kLM+ely38qooR6ZzUAryZGfAF4GF3/2xa9+k3ja0FyuMTbLxjP2YwjyKZjhhWrl2ko9Kcwa+hemjIATPbV7v2p+7+jRTv2fPCzkZNu3xxPrQDVaR70qyi+R4njsGUNsl6KiaokFFAF+k+7WTNmSx3UywVCwruIhmiAJ8zGy88e9Z2/k4qmE3vMr38/KXadSqSYWo2liNB9cxEZZKC2bxaD8zVlDuPbX13x+8rIq1TgM+gsDJIYEb1zHyDe8GMVy0Z5EdPv9DS67KcIhKRmRTgMyasDPLKuw5wcnGgracuTbrz1LMvsubVp3H/T55J9JrigA7cEMkTBfiM2bLj4KxAPlGZTOVIvYnKZOLgPlQqsmWtyh1F8kQBPkNGx8qMT1S6PYxZdOCGSD6piiZDtu08FPlYqdid/1RqLSCSX5rBZ8ToWDn2AOzjHdytGvRqV2sBkXxTgM+AYGE1yoBBZTLdAF8wY8pdvdlFeogCfIdEnbQE4f1lAga0c/JuwMnFwoz7lYoFbVIS6UHmGWpBODIy4nv27On2MNqusfSx3vBQKTY1A7OPt5uv69ev1GlKIj3CzPa6+0jYY5rBtyBuFh4nboZeHp9oGsDbGdyHh0o6TUmkTyjAJxS1AQmIDZbNFk+hGsDbNUuPex9VxIj0F5VJJhQ2C5+oTMaWNjZbPK0XVK3MR8GM69avnG4ANlQqsniwqGZgIn1KM/iEovqwx/Vnj0vNNAo2E101emDOh2TrMGsRqacZfEJRTbbimm8lPZwjSJ2MjpW5c295zqmagul8FRE5QQE+gdGxMkePHZ91vThgHD12nLM23cOarbsYHSvPeDwu+AfBuD51EtaHphWXnnfmnF8rIr1HKZomokocS8UBjk85zx6t9o4JW3TdeOHZs14bVXM+nz40AwbvP28p16xbMafXi0hvUoBvIiqPfuy4z+rJHiy6BsG7fiNTs9LKuMXaRkOlIs9NVFTDLiKxFOCbiMqjRx240fj8YNEzqKHfcNs+tu08NCswt3KY9r7N70z8XBHpX8rBN9HqCUYOLNt0D6++8htcNVpN2QRpnvL4BM6JdE59zj7pfeZbSiki/UMBvom5HnI96c7NDxzmqtEDkTX0V999sKX7aKOSiLRCAb6JdauGufbiFdObh1otRbx195OR6Zdnj1amZ/GN9xkeKnH5+UtnfK2NSiLSCjUba9FZm+5puU49rqGYTksSkfmIazamGXyLWs3JFyz+oOpWFldFRFqhAN+isFx5ccCIOlHv0vPOZN2qYYZKxdDHW/2BISKSlAJ8i4JceX3Arkw5lamZzxswuPz8E5uPtqxdPusHgxZNRSRNqoMnus/76FiZLTsOTu8wXTxYZPNFywF46fhU3FvyilNLM3aWtrLpSUSkHfp+kTWsFUGpWOCS1cPc9v0nqTScl1csGIsWLmjaVsCAx7a+O40hi4hM04lORM/So2rUb939ZOhu1cqkJ+oZo9y6iHRbXwT4uNOYWm1FkIRy6yKSBX2xyBp3GlPUTLvVDU3Bs7UhSUSyItUZvJn9DvA5oADc5O5b232PJAdhx53GdN36lS3l4KNct36lgrqIZEpqM3gzKwA3AO8Cfh241Mx+vZ33SNLEC+JPYwprEXDtxSu4Zt0Ktr3n3BnlkFGT+uHa+4iIZEmaM/jXAz9290cBzOxrwO8BD7XrBnGpl/qAG3XwxtvOWcKarbumZ//XrV85/b4bbtvH6UMltqxdPv1eURU3yreLSBalGeCHgSfrvn4KOK/xSWZ2BXAFwNKlS1u6QdKDsMNq0N92zhLu3FuesfC6cft+cKbTMo2nNKmWXUTypOtVNO5+I3AjVOvgW3nt6RFNvMJSMvUBGmDN1l2zZv+Vydm3DzulSQFdRPIgzSqaMlB/CvQZtWttE9YXJmnKpJUmX2oIJiJ5lGaA/wHwGjM7y8wWAu8DdrTzBlELpElm2K1sRNKmJRHJo9RSNO5+3Mw+AuykWib5RXc/2ORlLZtryiRs4bVYsBk5eNAiqojkV6o5eHf/BvCNNO8xV1ELpmHXlHMXkTzq+WZjSTZCiYjkVd82G4vrQaMgLyK9rqd70cRthBIR6XU9HeCTboQSEelFPR3g43rQiIj0up4O8PPZCCUiknc9vciq3jEi0s96OsCDeseISP/q6RSNiEg/U4AXEelRCvAiIj1KAV5EpEcpwIuI9KhMNRszsyPAEy285OXAz1MaznxkcVxZHBNoXK3SuFrTD+N6pbsvCXsgUwG+VWa2J6qLWjdlcVxZHBNoXK3SuFrT7+NSikZEpEcpwIuI9Ki8B/gbuz2ACFkcVxbHBBpXqzSu1vT1uHKdgxcRkWh5n8GLiEgEBXgRkR6VywBvZr9jZofM7Mdmtqnb4wEwsy+a2dNm9o/dHks9MzvTzO4zs4fM7KCZfazbYwIws5PN7Ptmtr82rqu7PaaAmRXMbMzM/rrbY6lnZo+b2QEz22dm7T2dfh7MbMjMtpvZI2b2sJm9ocvjObv2bxR8/MLMPt7NMQXMbEPt+/0fzexWMzs51fvlLQdvZgXgn4B3AE8BPwAudfeHujyuNwPPA3/l7r/RzbHUM7NXAK9w9wfN7BRgL7AuA/9eBixy9+fNrAh8D/iYuz/QzXEBmNl/BkaAX3H33+32eAJm9jgw4u6Z2rhjZl8B/t7dbzKzhcCgu493eVjAdLwoA+e5eyubKNMYyzDV7/Nfd/cJM7sd+Ia7fzmte+ZxBv964Mfu/qi7HwO+Bvxel8eEu/8d8Ey3x9HI3X/m7g/WPv8l8DDQ9Qb5XvV87cti7aPrsw0zOwN4N3BTt8eSB2Z2KvBm4AsA7n4sK8G95u3AT7od3OssAEpmtgAYBH6a5s3yGOCHgSfrvn6KDASsPDCzZcAqYHeXhwJMp0L2AU8D33L3LIzreuBTwFSXxxHGgXvNbK+ZXdHtwdScBRwBvlRLa91kZou6Pag67wNu7fYgANy9DPw5cBj4GfCcu9+b5j3zGOBlDszsZcCdwMfd/RfdHg+Au0+6+0rgDOD1ZtbV1JaZ/S7wtLvv7eY4YrzJ3V8HvAv4cC0t2G0LgNcBn3f3VcALQFbWxRYCa4E7uj0WADNbTDXbcBZwOrDIzC5P8555DPBl4My6r8+oXZMItRz3ncAt7n5Xt8fTqPYr/X3A73R5KGuAtbVc99eAC8zs5u4O6YTaDBB3fxr4OtV0Zbc9BTxV99vXdqoBPwveBTzo7v/S7YHU/DbwmLsfcfcKcBfwxjRvmMcA/wPgNWZ2Vu0n9PuAHV0eU2bVFjO/ADzs7p/t9ngCZrbEzIZqn5eoLpo/0s0xufuV7n6Guy+j+n21y91TnWElZWaLaovk1FIg7wS6XrHl7v8MPGlmZ9cuvR3o6gJ+nUvJSHqm5jBwvpkN1v6/fDvVNbHU5O7QbXc/bmYfAXYCBeCL7n6wy8PCzG4F3gq83MyeAja7+xe6OyqgOiv9A+BALd8N8Kfu/o3uDQmAVwBfqVU5DAC3u3umyhIz5leBr1fjAguAr7r7N7s7pGl/DNxSm3A9Cnywy+MJfgi+A/iP3R5LwN13m9l24EHgODBGyi0LclcmKSIiyeQxRSMiIgkowIuI9CgFeBGRHqUALyLSoxTgRUR6lAK85IKZ/aqZfdXMHq1t1f8HM/t3HR7DssZuoWa2oq5r4TNm9ljt82+38J7vr/v6A2b2P9o9dulPCvCSebVNIaPA37n7q9x9NdWNSGeEPLejezvc/YC7r6y1XNgBbKx9/dsJx7QMeH/M4yJzpgAveXABcMzd/1dwwd2fcPe/gOlZ7w4z2wX8rZmdZmajZvZDM3vAzF5be94WM/tk8B61ntzLah8Pm9lf1np131vbXYuZrbZqz/r9wIeTDtjMvmNm19f6tn/MzL5sZr9f93jQSXMr8Fu1Wf+G2rXTzeybZvYjM/uzOf2LiaAAL/mwnOruvzivA37f3d8CXA2MuftrgT8F/irBPV4D3ODuy4Fx4JLa9S8Bf+zu585h3AvdfcTdPxPznE1Ue6mvdPfratdWAuuBFcB6Mzsz6sUicRTgJXfM7IbarPoHdZe/5e5BP/43Af8bwN13Af/KzH6lyds+5u77ap/vBZbVeuUM1Xr9E7xnC25r8fmBv3X359z9Rap9XV45x/eRPqcAL3lwkLoOhe7+YaqNmpbUPeeFBO9znJnf8/XHpb1U9/kk7enTVD+m6Xub2QCwMOZ1aYxF+pACvOTBLuBkM/ujumuDMc//e+AyADN7K/DzWg/8x6n9oDCz11Htyx2p1sZ43MzeVLt0WetDn/Y4sLr2+VqqJ1gB/BI4ZR7vKxJJAV4yz6sd8dYBb6mVIX4f+ArwJxEv2QKsNrMfUl3E/Pe163cCp5nZQeAjVM/2beaDwA21Tpw2178D8JdUx78feAMnZvc/BCZrKacNka8WmQN1kxQR6VGawYuI9CgFeBGRHqUALyLSoxTgRUR6lAK8iEiPUoAXEelRCvAiIj3q/wNz+GB74PnTHAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt \n",
    "ground_truth_x2 = (2 - factual[\"x1\"]) + factual[\"x2\"]\n",
    "ground_truth_x3 = 2 * (2 - factual[\"x1\"]) + factual[\"x3\"]\n",
    "fig, ax = plt.subplots()\n",
    "plt.scatter(ground_truth_x3,cf_estimates[\"x3\"])\n",
    "\n",
    "plt.xlabel(\"Ground Truth\")\n",
    "plt.ylabel(\"DCM Prediction\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9888e5a",
   "metadata": {},
   "source": [
    "## Example with Prespecified Graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8332832c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Fitting causal mechanism of node x4: 100%|██████████| 4/4 [00:00<00:00, 701.80it/s]\n",
      "Fitting causal mechanism of node x4: 100%|██████████| 4/4 [00:04<00:00,  1.12s/it]\n"
     ]
    }
   ],
   "source": [
    "from experiments.structural_equations import *\n",
    "from experiments.data_generation import ExperimentationModel\n",
    "n = 500\n",
    "scm_type = \"diamond\"\n",
    "equations_type = \"nonadditive\"\n",
    "g = get_graph(scm_type)\n",
    "structural_equations, noise_distributions = select_struct_and_noise(equations_type, scm_type)\n",
    "exper_model = ExperimentationModel(g, scm_type, structural_equations, noise_distributions)\n",
    "factual, noise = exper_model.sample(n)\n",
    "\n",
    "\n",
    "\n",
    "diff_model = create_diff_model(scm_type, params)\n",
    "\n",
    "cy.fit(diff_model, factual)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "65921563",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>x3</th>\n",
       "      <th>x4</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>-0.087256</td>\n",
       "      <td>0.056533</td>\n",
       "      <td>3.155144</td>\n",
       "      <td>-0.764645</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1.174429</td>\n",
       "      <td>2.519266</td>\n",
       "      <td>5.068688</td>\n",
       "      <td>-2.671952</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.538098</td>\n",
       "      <td>0.839421</td>\n",
       "      <td>4.260010</td>\n",
       "      <td>-2.360014</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1.633835</td>\n",
       "      <td>2.057782</td>\n",
       "      <td>4.502608</td>\n",
       "      <td>-2.698460</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>-0.688768</td>\n",
       "      <td>1.502286</td>\n",
       "      <td>5.121723</td>\n",
       "      <td>-2.739814</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         x1        x2        x3        x4\n",
       "0 -0.087256  0.056533  3.155144 -0.764645\n",
       "1  1.174429  2.519266  5.068688 -2.671952\n",
       "2  0.538098  0.839421  4.260010 -2.360014\n",
       "3  1.633835  2.057782  4.502608 -2.698460\n",
       "4 -0.688768  1.502286  5.121723 -2.739814"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "factual.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8c3fc832",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>x3</th>\n",
       "      <th>x4</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2</td>\n",
       "      <td>3.670004</td>\n",
       "      <td>4.108356</td>\n",
       "      <td>-2.690757</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>2.987519</td>\n",
       "      <td>5.194756</td>\n",
       "      <td>-2.401216</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>2.985081</td>\n",
       "      <td>3.569173</td>\n",
       "      <td>-2.583557</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2</td>\n",
       "      <td>2.511662</td>\n",
       "      <td>4.272810</td>\n",
       "      <td>-2.771418</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2</td>\n",
       "      <td>2.811127</td>\n",
       "      <td>4.155429</td>\n",
       "      <td>-2.899285</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   x1        x2        x3        x4\n",
       "0   2  3.670004  4.108356 -2.690757\n",
       "1   2  2.987519  5.194756 -2.401216\n",
       "2   2  2.985081  3.569173 -2.583557\n",
       "3   2  2.511662  4.272810 -2.771418\n",
       "4   2  2.811127  4.155429 -2.899285"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from dowhy.gcm  import interventional_samples\n",
    "intervention = {\"x1\": lambda x: 2}\n",
    "int_samples = interventional_samples(diff_model, intervention, num_samples_to_draw=20)\n",
    "int_samples.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d13c1514-8583-4e75-95ed-e190079911d7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>x3</th>\n",
       "      <th>x4</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2</td>\n",
       "      <td>2.279446</td>\n",
       "      <td>3.728182</td>\n",
       "      <td>-2.605973</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>4.364803</td>\n",
       "      <td>5.626756</td>\n",
       "      <td>-1.729665</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>3.285627</td>\n",
       "      <td>5.249624</td>\n",
       "      <td>-2.450392</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2</td>\n",
       "      <td>2.914311</td>\n",
       "      <td>4.972789</td>\n",
       "      <td>-2.597203</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2</td>\n",
       "      <td>4.018981</td>\n",
       "      <td>4.904093</td>\n",
       "      <td>-2.349076</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   x1        x2        x3        x4\n",
       "0   2  2.279446  3.728182 -2.605973\n",
       "1   2  4.364803  5.626756 -1.729665\n",
       "2   2  3.285627  5.249624 -2.450392\n",
       "3   2  2.914311  4.972789 -2.597203\n",
       "4   2  4.018981  4.904093 -2.349076"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from dowhy.gcm  import counterfactual_samples\n",
    "cf_estimates = counterfactual_samples(diff_model, intervention, observed_data = factual)\n",
    "cf_estimates.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6aec0eb",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.4 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "vscode": {
   "interpreter": {
    "hash": "40d3a090f54c6569ab1632332b64b2c03c39dcf918b08424e98f38b5ae0af88f"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
