{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2b8151af-990c-4b5f-8494-a146c1b5a7e2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/lib/python3.7/site-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.\n",
      "  import pandas.util.testing as tm\n",
      "Fitting causal mechanism of node x2: 100%|██████████| 2/2 [00:00<00:00, 788.33it/s]\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "<lambda>() takes 2 positional arguments but 4 were given",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m/var/folders/p2/qjns7zj118l97lkbqykpkg7c0000gn/T/ipykernel_56958/1503293713.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     10\u001b[0m \u001b[0mstructural_equations\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnoise_distributions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mselect_struct_and_noise\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mequations_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscm_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m \u001b[0mexper_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mExperimentationModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscm_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstructural_equations\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnoise_distributions\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0mfactual\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnoise\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mexper_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     14\u001b[0m params = {'num_epochs' : 200,\n",
      "\u001b[0;32m~/Documents/Amazon/new_icml_submission/CDM/experiments/data_generation.py\u001b[0m in \u001b[0;36msample\u001b[0;34m(self, num_samples)\u001b[0m\n\u001b[1;32m    119\u001b[0m     def sample(self,\n\u001b[1;32m    120\u001b[0m                num_samples: int) -> Tuple[pd.DataFrame, pd.DataFrame]:\n\u001b[0;32m--> 121\u001b[0;31m         \u001b[0mdata_samples\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnoise_samples\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_draw_data_and_noise_samples\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_samples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    122\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mdata_samples\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnoise_samples\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    123\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Documents/Amazon/new_icml_submission/CDM/experiments/data_generation.py\u001b[0m in \u001b[0;36m_draw_data_and_noise_samples\u001b[0;34m(self, num_samples)\u001b[0m\n\u001b[1;32m    163\u001b[0m                 drawn_samples[node] = self.model.causal_mechanism(node).evaluate(\n\u001b[1;32m    164\u001b[0m                     \u001b[0mcolumn_stack_selected_numpy_arrays\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdrawn_samples\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_ordered_predecessors\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgraph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 165\u001b[0;31m                     noise)\n\u001b[0m\u001b[1;32m    166\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    167\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mconvert_to_data_frame\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdrawn_samples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconvert_to_data_frame\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdrawn_noise_samples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Documents/Amazon/new_icml_submission/CDM/experiments/data_generation.py\u001b[0m in \u001b[0;36mevaluate\u001b[0;34m(self, parent_samples, noise_samples)\u001b[0m\n\u001b[1;32m     28\u001b[0m                  noise_samples: np.ndarray) -> np.ndarray:\n\u001b[1;32m     29\u001b[0m         \u001b[0mparent_samples\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnoise_samples\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mshape_into_2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparent_samples\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnoise_samples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_formula\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnoise_samples\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mparent_samples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     32\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mclone\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Documents/Amazon/new_icml_submission/CDM/experiments/data_generation.py\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(noise, parents)\u001b[0m\n\u001b[1;32m    110\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    111\u001b[0m                 spread = lambda func: (lambda noise, parents: func(noise,\n\u001b[0;32m--> 112\u001b[0;31m                                     *np.hsplit(parents,list(range(1,parents.shape[1])))))\n\u001b[0m\u001b[1;32m    113\u001b[0m                 \u001b[0mspread_func\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mspread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstructural_equations\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    114\u001b[0m                 self.model.set_causal_mechanism(node, \n",
      "\u001b[0;31mTypeError\u001b[0m: <lambda>() takes 2 positional arguments but 4 were given"
     ]
    }
   ],
   "source": [
    "import dowhy.gcm as cy\n",
    "\n",
    "from experiments.data_generation import ExperimentationModel\n",
    "from experiments.structural_equations import create_diff_model, get_graph, select_struct_and_noise\n",
    "\n",
    "# TODO(review): no random seed is set, so sampled data and model training are not reproducible.\n",
    "n = 50\n",
    "scm_type = \"bivariate_multivar\"\n",
    "equations_type = \"nonadditive\"\n",
    "g = get_graph(scm_type)\n",
    "structural_equations, noise_distributions = select_struct_and_noise(equations_type, scm_type)\n",
    "exper_model = ExperimentationModel(g, scm_type, structural_equations, noise_distributions)\n",
    "\n",
    "# NOTE(review): in the saved run this call raised\n",
    "#   TypeError: <lambda>() takes 2 positional arguments but 4 were given\n",
    "# inside the `spread` wrapper in experiments/data_generation.py: the structural-equation\n",
    "# lambda received the noise plus one argument per column of its parents array, but only\n",
    "# accepts (noise, parents). Later cells that use `factual` fail until that is fixed there.\n",
    "factual, noise = exper_model.sample(n)\n",
    "\n",
    "# Hyper-parameters passed to create_diff_model.\n",
    "params = {'num_epochs': 200,\n",
    "          'lr': 1e-4,\n",
    "          'batch_size': 64,\n",
    "          'hidden_dim': 64}\n",
    "\n",
    "diff_model = create_diff_model(scm_type, params)\n",
    "\n",
    "cy.fit(diff_model, factual)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "65921563",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'factual' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m/var/folders/p2/qjns7zj118l97lkbqykpkg7c0000gn/T/ipykernel_56944/3862671678.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mfactual\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m: name 'factual' is not defined"
     ]
    }
   ],
   "source": [
    "# Preview the first rows of the sampled observational data instead of dumping the whole frame.\n",
    "factual.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d13c1514-8583-4e75-95ed-e190079911d7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>x3</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2</td>\n",
       "      <td>0.215984</td>\n",
       "      <td>1.175192</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>0.796382</td>\n",
       "      <td>2.667339</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>0.183704</td>\n",
       "      <td>2.233022</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2</td>\n",
       "      <td>1.662512</td>\n",
       "      <td>5.522245</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2</td>\n",
       "      <td>1.075608</td>\n",
       "      <td>3.786495</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>495</th>\n",
       "      <td>2</td>\n",
       "      <td>0.713797</td>\n",
       "      <td>1.372172</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>496</th>\n",
       "      <td>2</td>\n",
       "      <td>0.665122</td>\n",
       "      <td>3.174364</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>497</th>\n",
       "      <td>2</td>\n",
       "      <td>2.405847</td>\n",
       "      <td>5.381603</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>498</th>\n",
       "      <td>2</td>\n",
       "      <td>0.975035</td>\n",
       "      <td>1.891542</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>499</th>\n",
       "      <td>2</td>\n",
       "      <td>1.333937</td>\n",
       "      <td>3.042212</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>500 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     x1        x2        x3\n",
       "0     2  0.215984  1.175192\n",
       "1     2  0.796382  2.667339\n",
       "2     2  0.183704  2.233022\n",
       "3     2  1.662512  5.522245\n",
       "4     2  1.075608  3.786495\n",
       "..   ..       ...       ...\n",
       "495   2  0.713797  1.372172\n",
       "496   2  0.665122  3.174364\n",
       "497   2  2.405847  5.381603\n",
       "498   2  0.975035  1.891542\n",
       "499   2  1.333937  3.042212\n",
       "\n",
       "[500 rows x 3 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from dowhy.gcm import counterfactual_samples\n",
    "\n",
    "# Counterfactual query: the intervention replaces every observed x1 value with the constant 2.\n",
    "intervention = {\"x1\": lambda x1_value: 2}\n",
    "counterfactual_samples(diff_model, intervention, observed_data=factual)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f9e98bc-6412-422d-bb9d-4036c67c2762",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.4 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "vscode": {
   "interpreter": {
    "hash": "40d3a090f54c6569ab1632332b64b2c03c39dcf918b08424e98f38b5ae0af88f"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
