{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1e6f5686",
   "metadata": {},
   "source": [
    "# Deep End-to-end Causal Inference: Demo Notebook\n",
    "\n",
    "This notebook provides a showcase of the features provided by our open source code for Deep End-to-end Causal Inference (DECI).\n",
    "\n",
    " - We begin with a simple two node example, showing how DECI can orient an edge correctly when non-Gaussian noise is present, and how DECI can then be used for treatment effect estimation\n",
    " - We show how different graph constraints can be incorporated into DECI\n",
    " - We showcase DECI on a larger graph example\n",
    " \n",
    "### Dataset availability\n",
    "To use the notebook, the CSuite datasets need to be available and be downloaded in `../../data/`. More information on the datasets and how they can be downloaded can be found here: https://github.com/microsoft/csuite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aee90314",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Use this to set the notebook's working directory to the top-level directory, where ./data is located\n",
    "os.chdir(\"../..\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0f4bcfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "from causica.experiment.steps.step_func import load_data\n",
    "from causica.models.deci.deci import DECI\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import networkx as nx\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "def plot_true_graph(dataset):\n",
    "    true_graph = nx.convert_matrix.from_numpy_matrix(dataset.get_adjacency_data_matrix(), create_using=nx.DiGraph)\n",
    "    nx.draw_networkx(true_graph, arrows=True, with_labels=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ab2482a",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cc10b3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_config = {'dataset_format': 'causal_csv', 'use_predefined_dataset': True, 'test_fraction': 0.1, \n",
    "                  'val_fraction': 0.1, 'random_seed': 0, 'negative_sample': False}\n",
    "model_config = {'tau_gumbel': 0.25, 'lambda_dag': 100.0, 'lambda_sparse': 5.0, 'spline_bins': 8, \n",
    "                'var_dist_A_mode': 'enco', 'mode_adjacency': 'learn', \n",
    "                'norm_layers': True, 'res_connection': True, 'base_distribution_type': 'spline'}\n",
    "# To speed up training you can try:\n",
    "#  increasing learning_rate\n",
    "#  increasing batch_size (reduces noise when using higher learning rate)\n",
    "#  decreasing max_steps_auglag (go as low as you can and still get a DAG)\n",
    "#  decreasing max_auglag_inner_epochs\n",
    "training_params = {'learning_rate': 0.05, 'batch_size': 256, 'stardardize_data_mean': False, \n",
    "                   'stardardize_data_std': False, 'rho': 1.0, 'safety_rho': 10000000000000.0, \n",
    "                   'alpha': 0.0, 'safety_alpha': 10000000000000.0, 'tol_dag': 1e-04, 'progress_rate': 0.65, \n",
    "                   'max_steps_auglag': 5, 'max_auglag_inner_epochs': 2000, 'max_p_train_dropout': 0.6, \n",
    "                   'reconstruction_loss_factor': 1.0, 'anneal_entropy': 'noanneal'}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2c6c891",
   "metadata": {},
   "source": [
    "## Simplest example of end-to-end causal inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c09cd6f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from causica.experiment.run_context import RunContext\n",
    "run_context = RunContext()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c65dd891",
   "metadata": {},
   "source": [
    "To load the dataset, ensure that you have run the CSuite data generation script in `causica/data_generation/csuite/simulate.py`, ensure that the CSuite datasets have been created under `./data`, and ensure that the notebook's working directory has been set correctly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34415175",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_data(\"csuite_linexp\", \"./data\", 0, dataset_config, model_config, False, run_context.download_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b42ca20",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = pd.DataFrame(dataset._train_data, columns=[\"A\", \"B\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c798b58",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f212f27",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "sns.scatterplot(x=train_data[\"A\"], y=train_data[\"B\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87be9ad5",
   "metadata": {},
   "source": [
    "Initially, it is unclear what the causal relationship between A and B is."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "663a399e",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame({'from': ['A'], 'to': ['B']})\n",
    "G = nx.from_pandas_edgelist(df, 'from', 'to')\n",
    "nx.draw_networkx(G, arrows=False, with_labels=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "901b6dd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = DECI(\"mymodel\", dataset.variables, \"mysavedir\", device, **model_config) #change cuda to cpu if GPU is not available"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7038cd91",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.run_train(dataset, training_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2902299a",
   "metadata": {},
   "source": [
    "## Causal discovery results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e168228b",
   "metadata": {},
   "outputs": [],
   "source": [
    "graph = model.networkx_graph()\n",
    "print(graph.edges)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b17f3ddc",
   "metadata": {},
   "outputs": [],
   "source": [
    "nx.draw_networkx(graph, arrows=True, with_labels=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b75aacf6",
   "metadata": {},
   "source": [
    "We can compare this with the true graph:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91b4381a",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_true_graph(dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ae451fa",
   "metadata": {},
   "source": [
    "## Causal inference results\n",
    "\n",
    "DECI has also fitted an SCM that captures the functional relationship and error distribution of this dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af7ce6f6",
   "metadata": {},
   "source": [
    "We can estimate ATE and compare it to the ATE estimate from ground truth interventional data. Here we will compute E[B|do(A=1)] - E[B|do(A=-1)]."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9758aafc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "intervention_idxs = np.array([0])\n",
    "outcome_idx = 1\n",
    "\n",
    "### Model-based ATE estimate\n",
    "do_1 = model.sample(5000, intervention_idxs=intervention_idxs, intervention_values=np.array([1.])).cpu().numpy()\n",
    "do_minus_1 = model.sample(5000, intervention_idxs=intervention_idxs, intervention_values=np.array([-1.])).cpu().numpy()\n",
    "ate_estimate = do_1[:, outcome_idx].mean() - do_minus_1[:, outcome_idx].mean()\n",
    "print(\"Estimated ATE:\", ate_estimate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "368b9625",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Interventional test data ATE\n",
    "ate_true = dataset._intervention_data[0].test_data[:, outcome_idx].mean() - dataset._intervention_data[0].reference_data[:, outcome_idx].mean()\n",
    "print(\"Interventional ATE:\", ate_true)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1304809d",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Theoretical ATE is 1.\")"
   ]
  },
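  {
   "cell_type": "markdown",
   "id": "f3a91c20",
   "metadata": {},
   "source": [
    "The ATE contrast used above can be illustrated without DECI. The next cell is a minimal sketch on a hypothetical linear SCM, B = 0.5 * A + noise with mean-zero non-Gaussian noise, where E[B|do(A=1)] - E[B|do(A=-1)] is estimated by sampling B with A held fixed. The slope 0.5, the noise distribution, and all `toy_` names are assumptions chosen for illustration; they are not taken from the dataset or from the DECI code. In such a linear model the ATE equals the slope times the difference between the two treatment values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7d2e4a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "toy_rng = np.random.default_rng(0)\n",
    "toy_slope = 0.5  # hypothetical causal effect of A on B\n",
    "\n",
    "def toy_sample_b_given_do_a(a_value, n_samples):\n",
    "    \"\"\"Sample B under do(A=a_value) in the toy SCM B = toy_slope * A + noise.\"\"\"\n",
    "    # Mean-zero, non-Gaussian noise: an exponential shifted by its mean\n",
    "    noise = toy_rng.exponential(scale=1.0, size=n_samples) - 1.0\n",
    "    return toy_slope * a_value + noise\n",
    "\n",
    "toy_ate = toy_sample_b_given_do_a(1.0, 100_000).mean() - toy_sample_b_given_do_a(-1.0, 100_000).mean()\n",
    "print(\"Toy ATE estimate:\", toy_ate)\n",
    "print(\"Toy ATE in closed form:\", toy_slope * (1.0 - (-1.0)))"
   ]
  },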
  {
   "cell_type": "markdown",
   "id": "e6dfc3ec",
   "metadata": {},
   "source": [
    "In short, we can start from data, do causal discovery and causal inference, yielding treatment effect estimates that actions can be based upon."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73e348ec",
   "metadata": {},
   "source": [
    "## Graph constraints\n",
    "First, train on a new dataset with no constraints. *Note*: this is a very difficult dataset in which all variables are strongly correlated with one another."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "812794d6",
   "metadata": {},
   "source": [
    "To load the dataset, first ensure that it has been generated under `./data`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee65609f",
   "metadata": {},
   "outputs": [],
   "source": [
    "simpson_data = load_data(\"csuite_nonlin_simpson\", \"./data\", 0, dataset_config, model_config, False, run_context.download_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7206785b",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"New dataset with {simpson_data.variables.num_groups} nodes.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26c74ca6",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"The true graph is:\")\n",
    "plot_true_graph(simpson_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8369cc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "simpson_df = pd.DataFrame(simpson_data._train_data, columns=simpson_data.variables.group_names)\n",
    "simpson_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41ee303c",
   "metadata": {},
   "outputs": [],
   "source": [
    "simpson_model = DECI(\"mymodel\", simpson_data.variables, \"mysavedir\", device, **model_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a00312cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# You may need more auglag steps / higher rho to make sure you do not get a non-DAG\n",
    "training_params['max_auglag_inner_epochs'] = 2000\n",
    "training_params['max_steps_auglag'] = 10\n",
    "\n",
    "simpson_model.run_train(simpson_data, training_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d3ac03a",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(simpson_model.networkx_graph().edges)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a31711e5",
   "metadata": {},
   "source": [
    "If we are not happy with this DAG, we could add some constraints."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e820dd95",
   "metadata": {},
   "source": [
    "Constraints are encoded using an adjacency matrix where:\n",
    " - 0 indicates that there is no directed edge i → j,\n",
    " - 1 indicates that there has to be a directed edge i → j,\n",
    " - nan indicates that the directed edge i → j is learnable.\n",
    " \n",
    "The following function converts from `tabu_` format into this matrix format, for use with DECI."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7789a160",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_constraint_matrix(variables, tabu_child_nodes=None,  tabu_parent_nodes=None, tabu_edges=None):\n",
    "    \"\"\"\n",
    "    Makes a DECI constraint matrix from GCastle constraint format.\n",
    "\n",
    "    Arguments:\n",
    "        tabu_child_nodes: Optional[List[str]]\n",
    "            nodes that cannot be children of any other nodes (root nodes)\n",
    "        tabu_parent_nodes: Optional[List[str]]\n",
    "            edges that cannot be the parent of any other node (leaf nodes)\n",
    "        tabu_edge: Optional[List[Tuple[str, str]]]\n",
    "            edges that cannot exist\n",
    "    \"\"\"\n",
    "\n",
    "    constraint = np.full((variables.num_groups, variables.num_groups), np.nan)\n",
    "    name_to_idx = {name: i for (i, name) in enumerate(variables.group_names)}\n",
    "    if tabu_child_nodes is not None:\n",
    "        for node in tabu_child_nodes:\n",
    "            idx = name_to_idx[node]\n",
    "            constraint[:, idx] = 0.0\n",
    "    if tabu_parent_nodes is not None:\n",
    "        for node in tabu_parent_nodes:\n",
    "            idx = name_to_idx[node]\n",
    "            constraint[idx, :] = 0.0\n",
    "    if tabu_edges is not None:\n",
    "        for source, sink in tabu_edges:\n",
    "            source_idx, sink_idx = name_to_idx[source], name_to_idx[sink]\n",
    "            constraint[source_idx, sink_idx] = 0.0\n",
    "    return constraint.astype(np.float32)"
   ]
  },
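  {
   "cell_type": "markdown",
   "id": "c4e8f6b2",
   "metadata": {},
   "source": [
    "As a minimal sketch of the matrix encoding described above, the next cell builds a constraint matrix by hand for a hypothetical three-node graph, without using the helper function. The graph, the chosen constraints, and the name `toy_constraint` are assumptions made for illustration only. Rows index the source node and columns index the sink node, matching the helper above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9a1b3c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Start with every entry unconstrained (nan)\n",
    "toy_constraint = np.full((3, 3), np.nan, dtype=np.float32)\n",
    "toy_constraint[:, 0] = 0.0  # node 0 has no parents: forbid every edge i -> 0\n",
    "toy_constraint[0, 2] = 1.0  # require the edge 0 -> 2\n",
    "toy_constraint[2, 1] = 0.0  # forbid the edge 2 -> 1\n",
    "print(toy_constraint)\n",
    "print(\"Unconstrained entries:\", int(np.isnan(toy_constraint).sum()))"
   ]
  },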
  {
   "cell_type": "markdown",
   "id": "7c306612",
   "metadata": {},
   "source": [
    "### Adding constraint that a node is not a child\n",
    "Let's suppose that 'Column 0' is not a child of anything (it's a root node)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee09e828",
   "metadata": {},
   "outputs": [],
   "source": [
    "training_params['max_auglag_inner_epochs'] = 1000\n",
    "training_params['max_steps_auglag'] = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e3833e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "constraint = make_constraint_matrix(simpson_data.variables, tabu_child_nodes=['Column 0'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d60af0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "simpson_model = DECI(\"mymodel\", simpson_data.variables, \"mysavedir\", device, **model_config)\n",
    "simpson_model.set_graph_constraint(constraint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "683740d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "simpson_model.run_train(simpson_data, training_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34e44fda",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(simpson_model.networkx_graph().edges)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a618ffb7",
   "metadata": {},
   "source": [
    "### Adding constraint that a node is not a parent\n",
    "Suppose we also want to specify that 'Column 3' is not a parent of anything (it's a leaf node)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "106d1c37",
   "metadata": {},
   "outputs": [],
   "source": [
    "constraint = make_constraint_matrix(\n",
    "    simpson_data.variables, tabu_child_nodes=['Column 0'], tabu_parent_nodes=['Column 3']\n",
    ")\n",
    "simpson_model = DECI(\"mymodel\", simpson_data.variables, \"mysavedir\", device, **model_config)\n",
    "simpson_model.set_graph_constraint(constraint)\n",
    "simpson_model.run_train(simpson_data, training_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86849696",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(simpson_model.networkx_graph().edges)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "332f8642",
   "metadata": {},
   "source": [
    "### Adding constraint that an edge doesn't exist\n",
    "Suppose we also want to specify that there is no edge Column 1 to Column 3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03dcb944",
   "metadata": {},
   "outputs": [],
   "source": [
    "constraint = make_constraint_matrix(\n",
    "    simpson_data.variables, tabu_child_nodes=['Column 0'], tabu_parent_nodes=['Column 3'], \n",
    "    tabu_edges=[('Column 1', 'Column 3')]\n",
    ")\n",
    "simpson_model = DECI(\"mymodel\", simpson_data.variables, \"mysavedir\", device, **model_config)\n",
    "simpson_model.set_graph_constraint(constraint)\n",
    "simpson_model.run_train(simpson_data, training_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1067ef2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(simpson_model.networkx_graph().edges)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b444342",
   "metadata": {},
   "source": [
    "### Adding a positive constraint\n",
    "It's also possible with DECI to force an edge to exist. For example, suppose we decide to enforce the existence of the egde from Column 1 to Column 3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ae97124",
   "metadata": {},
   "outputs": [],
   "source": [
    "constraint[0, 2] = 1.0\n",
    "simpson_model = DECI(\"mymodel\", simpson_data.variables, \"mysavedir\", device, **model_config)\n",
    "simpson_model.set_graph_constraint(constraint)\n",
    "simpson_model.run_train(simpson_data, training_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd55da06",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(simpson_model.networkx_graph().edges)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70ae351b",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"The correct graph is \", [(0, 1), (0, 2), (1, 2), (2, 3)])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9bfd6839",
   "metadata": {},
   "source": [
    "## A larger graph example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72dc41df",
   "metadata": {},
   "outputs": [],
   "source": [
    "large_data = load_data(\"csuite_large_backdoor\", \"./data\", 0, dataset_config, model_config, False, run_context.download_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc92d599",
   "metadata": {},
   "outputs": [],
   "source": [
    "[train_row, train_col] = np.shape(large_data._train_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8123e36",
   "metadata": {},
   "outputs": [],
   "source": [
    "large_train_data = pd.DataFrame(large_data._train_data, columns=[f\"X{i}\" for i in range(9)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbb001cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "large_train_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7930a19f",
   "metadata": {},
   "outputs": [],
   "source": [
    "if train_col < 15:\n",
    "    sns.pairplot(large_train_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eed98cdd",
   "metadata": {},
   "outputs": [],
   "source": [
    "large_model = DECI(\"mymodel\", large_data.variables, \"mysavedir\", device, **model_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59318d6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "training_params['max_steps_auglag'] = 15\n",
    "training_params['max_auglag_inner_epochs'] = 3000\n",
    "large_model.run_train(large_data, training_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cc6046f",
   "metadata": {},
   "outputs": [],
   "source": [
    "large_graph = large_model.networkx_graph()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30c86893",
   "metadata": {},
   "outputs": [],
   "source": [
    "nx.draw_networkx(large_graph, arrows=True, with_labels=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1969b257",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"The true graph is:\")\n",
    "plot_true_graph(large_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f3ad9fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "### Model-based ATE estimate\n",
    "do_1 = large_model.sample(5000, intervention_idxs=np.array([7]), intervention_values=np.array([2.])).cpu().numpy()\n",
    "do_minus_1 = large_model.sample(5000, intervention_idxs=np.array([7]), intervention_values=np.array([0.])).cpu().numpy()\n",
    "ate_estimate = do_1[:, 8].mean() - do_minus_1[:, 8].mean()\n",
    "print(\"Estimated ATE:\", ate_estimate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1215bfd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Interventional test data ATE\n",
    "ate_true = large_data._intervention_data[0].test_data[:, 8].mean() - large_data._intervention_data[0].reference_data[:, 8].mean()\n",
    "print(\"Interventional ATE:\", ate_true)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04a7a9d9",
   "metadata": {},
   "source": [
    "## Imputation results\n",
    "\n",
    "DECI also learns an imputation network that can be used to fill in missing data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "170b8d4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_missing(data):\n",
    "    missing_data = data.copy()\n",
    "    mask = np.full(missing_data.shape, fill_value=True, dtype=bool)\n",
    "    n_rows, n_cols = data.shape\n",
    "    for row in range(n_rows):\n",
    "        i = np.random.choice(list(range(n_cols)))\n",
    "        missing_data[row, i] = 0.\n",
    "        mask[row, i] = False\n",
    "    return missing_data, mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68a23583",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_with_missingness, mask = make_missing(dataset._train_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4420741a",
   "metadata": {},
   "outputs": [],
   "source": [
    "imputed = model.impute(data_with_missingness, mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce140335",
   "metadata": {},
   "outputs": [],
   "source": [
    "ax = sns.scatterplot(x=dataset._train_data[~mask], y=imputed[~mask])\n",
    "ax.set(xlabel=\"True value\", ylabel=\"Imputed value\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1eef0634",
   "metadata": {},
   "source": [
    "## Analysing the DECI model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "78de6a26",
   "metadata": {},
   "source": [
    "DECI gives us a simulator of the observational distribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd240314",
   "metadata": {},
   "outputs": [],
   "source": [
    "simulation = pd.DataFrame(model.sample(5000).cpu().numpy(), columns=[\"A\", \"B\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "884265d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.scatterplot(train_data[\"A\"], train_data[\"B\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9a7b5a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.scatterplot(simulation[\"A\"], simulation[\"B\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62dd6a63",
   "metadata": {},
   "source": [
    "The DECI model also allows us to simulate from interventional distributions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51b195a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "simulation_intervention = pd.DataFrame(\n",
    "    model.sample(5000, intervention_idxs=np.array([0]), intervention_values=np.array([4.])).cpu().numpy(), \n",
    "    columns=[\"A\", \"B\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20a697b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "simulation_intervention.min()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03b40785",
   "metadata": {},
   "source": [
    "Intervening on A causes a change in the values of B."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8224157e",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax1 = plt.subplots()\n",
    "ax2 = ax1.twinx()\n",
    "sns.kdeplot(simulation_intervention[\"B\"], ax=ax1)\n",
    "sns.kdeplot(train_data[\"B\"].astype(np.float32), ax=ax2, color='r')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7fc3ba72",
   "metadata": {},
   "source": [
    "Intervening on B does not cause a change for A."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11344e22",
   "metadata": {},
   "outputs": [],
   "source": [
    "simulation_intervention2 = pd.DataFrame(\n",
    "    model.sample(5000, intervention_idxs=np.array([1]), intervention_values=np.array([1.])).cpu().numpy(), \n",
    "    columns=[\"A\", \"B\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0302edf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax1 = plt.subplots()\n",
    "ax2 = ax1.twinx()\n",
    "sns.kdeplot(simulation_intervention2[\"A\"], ax=ax1)\n",
    "sns.kdeplot(train_data[\"A\"].astype(np.float32), ax=ax2, color='r')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "279887f8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "2fd0a64ad648981ef4b0280c53775d4f8aeceb44a2f0562bd016e7298af01310"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
