{
    "cells": [
        {
            "cell_type": "markdown",
            "id": "1e6f5686",
            "metadata": {},
            "source": [
                "# Treatment Estimation with Deep End-to-end Causal Inference\n",
                "\n",
                "This notebook demonstrates the `ate`, `cate`, `sample`, `_counterfactual` and `ite` methods supported by the DECI model.\n",
                " \n",
                "### Dataset availability\n",
                "To use the notebook, the CSuite datasets need to be available. Ensure that you have run the CSuite data generation script in `causica/data_generation/csuite/simulate.py` before attempting to load datasets.\n",
                "\n",
                "For Microsoft internal users, the datasets will be automatically downloaded from storage."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 1,
            "id": "aee90314",
            "metadata": {},
            "outputs": [],
            "source": [
                "import os\n",
                "# Use this to set the notebook's working directory to the top-level causica directory\n",
                "os.chdir(\"../..\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 2,
            "id": "d0f4bcfe",
            "metadata": {},
            "outputs": [],
            "source": [
                "import torch\n",
                "import numpy as np\n",
                "from causica.experiment.steps.step_func import load_data\n",
                "from causica.models.deci.deci import DECI\n",
                "import seaborn as sns\n",
                "import pandas as pd\n",
                "import networkx as nx\n",
                "import matplotlib\n",
                "import matplotlib.pyplot as plt"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 3,
            "id": "8cc10b3f",
            "metadata": {},
            "outputs": [],
            "source": [
                "dataset_config = {'dataset_format': 'causal_csv', 'use_predefined_dataset': True, 'test_fraction': 0.1, \n",
                "                  'val_fraction': 0.1, 'random_seed': 0, 'negative_sample': False}\n",
                "### NOTE: setting 'var_dist_A_mode' = 'true' uses the true graph (for faster training)\n",
                "###       replace with 'enco' or 'three' to learn the graph\n",
                "model_config = {'tau_gumbel': 0.25, 'lambda_dag': 100.0, 'lambda_sparse': 5.0, 'spline_bins': 8, \n",
                "                'var_dist_A_mode': 'true', 'mode_f_sem': 'gnn_i', 'mode_adjacency': 'learn', \n",
                "                'norm_layers': True, 'res_connection': True, 'base_distribution_type': 'spline'}\n",
                "training_params = {'learning_rate': 0.05, 'batch_size': 256, 'stardardize_data_mean': False, \n",
                "                   'stardardize_data_std': False, 'rho': 1.0, 'safety_rho': 10000000000000.0, \n",
                "                   'alpha': 0.0, 'safety_alpha': 10000000000000.0, 'tol_dag': 1e-04, 'progress_rate': 0.65, \n",
                "                   'max_steps_auglag': 6, 'max_auglag_inner_epochs': 1000, 'max_p_train_dropout': 0.6, \n",
                "                   'reconstruction_loss_factor': 1.0, 'anneal_entropy': 'noanneal'}"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "c2c6c891",
            "metadata": {},
            "source": [
                "## Quickly training a DECI model"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 4,
            "id": "c09cd6f8",
            "metadata": {},
            "outputs": [
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "Failure while loading azureml_run_type_providers. Failed to load entrypoint automl = azureml.train.automl.run:AutoMLRun._from_run_dto with exception (msrest 0.7.0 (/anaconda/envs/causica/lib/python3.8/site-packages), Requirement.parse('msrest<0.7.0,>=0.5.1'), {'azureml-core'}).\n",
                        "Failure while loading azureml_run_type_providers. Failed to load entrypoint azureml.scriptrun = azureml.core.script_run:ScriptRun._from_run_dto with exception (msrest 0.7.0 (/anaconda/envs/causica/lib/python3.8/site-packages), Requirement.parse('msrest<0.7.0,>=0.5.1')).\n",
                        "Failure while loading azureml_run_type_providers. Failed to load entrypoint azureml.PipelineRun = azureml.pipeline.core.run:PipelineRun._from_dto with exception (msrest 0.7.0 (/anaconda/envs/causica/lib/python3.8/site-packages), Requirement.parse('msrest<0.7.0,>=0.5.1'), {'azureml-core'}).\n",
                        "Failure while loading azureml_run_type_providers. Failed to load entrypoint azureml.ReusedStepRun = azureml.pipeline.core.run:StepRun._from_reused_dto with exception (msrest 0.7.0 (/anaconda/envs/causica/lib/python3.8/site-packages), Requirement.parse('msrest<0.7.0,>=0.5.1'), {'azureml-core'}).\n",
                        "Failure while loading azureml_run_type_providers. Failed to load entrypoint azureml.StepRun = azureml.pipeline.core.run:StepRun._from_dto with exception (msrest 0.7.0 (/anaconda/envs/causica/lib/python3.8/site-packages), Requirement.parse('msrest<0.7.0,>=0.5.1'), {'azureml-core'}).\n",
                        "Failure while loading azureml_run_type_providers. Failed to load entrypoint hyperdrive = azureml.train.hyperdrive:HyperDriveRun._from_run_dto with exception (msrest 0.7.0 (/anaconda/envs/causica/lib/python3.8/site-packages), Requirement.parse('msrest<0.7.0,>=0.5.1'), {'azureml-core'}).\n"
                    ]
                }
            ],
            "source": [
                "try:\n",
                "    from evaluation_pipeline.aml_run_context import setup_run_context_in_aml\n",
                "    run_context = setup_run_context_in_aml()\n",
                "except ImportError:\n",
                "    from common.experiment.run_context import RunContext\n",
                "    run_context = RunContext()"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "c65dd891",
            "metadata": {},
            "source": [
                "Before loading the dataset, make sure that (1) you have run the CSuite data generation script in `causica/data_generation/csuite/simulate.py`, (2) the CSuite datasets have been created under `../data`, and (3) the notebook's working directory has been set correctly (see the `os.chdir` cell above).\n",
                "\n",
                "Here we use the true graph (`'var_dist_A_mode': 'true'` in `model_config`) because it makes training quicker. Using the true graph is not required for any of the functionality shown in this notebook."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 5,
            "id": "34415175",
            "metadata": {},
            "outputs": [
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "/home/fosteradam/Causica/open_source/causica/datasets/csv_dataset_loader.py:133: UserWarning: Validation data file not found: ./data/csuite_nonlin_simpson/val.csv.\n",
                        "  warnings.warn(f\"Validation data file not found: {val_data_path}.\", UserWarning)\n"
                    ]
                },
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "Minimum value of variable Column 0 inferred as -3.2742233276367183. This can be changed manually in the dataset's variables.json file\n",
                        "Max value of variable Column 0 inferred as 3.15141224861145. This can be changed manually in the dataset's variables.json file\n",
                        "Variable Column 0 inferred to be a queriable variable. This can be changed manually in the dataset's variables.json file by updating the \"query\" field.\n",
                        "Variable Column 0 inferred as not an active learning target variable. This can be changed manually in the dataset's variables.json file by updating the \"target\" field.\n",
                        "Variable Column 0 inferred as an always observed target variable. This can be changed manually in the dataset's variables.json file by updating the \"always_observed\" field.\n",
                        "Minimum value of variable Column 1 inferred as -1.5287634134292603. This can be changed manually in the dataset's variables.json file\n",
                        "Max value of variable Column 1 inferred as 2.87360954284668. This can be changed manually in the dataset's variables.json file\n",
                        "Variable Column 1 inferred to be a queriable variable. This can be changed manually in the dataset's variables.json file by updating the \"query\" field.\n",
                        "Variable Column 1 inferred as not an active learning target variable. This can be changed manually in the dataset's variables.json file by updating the \"target\" field.\n",
                        "Variable Column 1 inferred as an always observed target variable. This can be changed manually in the dataset's variables.json file by updating the \"always_observed\" field.\n",
                        "Minimum value of variable Column 2 inferred as -5.4948973655700675. This can be changed manually in the dataset's variables.json file\n",
                        "Max value of variable Column 2 inferred as 3.226113319396972. This can be changed manually in the dataset's variables.json file\n",
                        "Variable Column 2 inferred to be a queriable variable. This can be changed manually in the dataset's variables.json file by updating the \"query\" field.\n",
                        "Variable Column 2 inferred as not an active learning target variable. This can be changed manually in the dataset's variables.json file by updating the \"target\" field.\n",
                        "Variable Column 2 inferred as an always observed target variable. This can be changed manually in the dataset's variables.json file by updating the \"always_observed\" field.\n",
                        "Minimum value of variable Column 3 inferred as -1.937317371368408. This can be changed manually in the dataset's variables.json file\n",
                        "Max value of variable Column 3 inferred as 2.244907855987549. This can be changed manually in the dataset's variables.json file\n",
                        "Variable Column 3 inferred to be a queriable variable. This can be changed manually in the dataset's variables.json file by updating the \"query\" field.\n",
                        "Variable Column 3 inferred as not an active learning target variable. This can be changed manually in the dataset's variables.json file by updating the \"target\" field.\n",
                        "Variable Column 3 inferred as an always observed target variable. This can be changed manually in the dataset's variables.json file by updating the \"always_observed\" field.\n"
                    ]
                }
            ],
            "source": [
                "dataset = load_data(\"csuite_nonlin_simpson\", \"../data\", 0, dataset_config, model_config, False)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 6,
            "id": "9b42ca20",
            "metadata": {},
            "outputs": [],
            "source": [
                "train_data = pd.DataFrame(dataset._train_data, columns=[\"X0\", \"X1\", \"X2\", \"X3\"])"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 7,
            "id": "4c798b58",
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/html": [
                            "<div>\n",
                            "<style scoped>\n",
                            "    .dataframe tbody tr th:only-of-type {\n",
                            "        vertical-align: middle;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe tbody tr th {\n",
                            "        vertical-align: top;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe thead th {\n",
                            "        text-align: right;\n",
                            "    }\n",
                            "</style>\n",
                            "<table border=\"1\" class=\"dataframe\">\n",
                            "  <thead>\n",
                            "    <tr style=\"text-align: right;\">\n",
                            "      <th></th>\n",
                            "      <th>X0</th>\n",
                            "      <th>X1</th>\n",
                            "      <th>X2</th>\n",
                            "      <th>X3</th>\n",
                            "    </tr>\n",
                            "  </thead>\n",
                            "  <tbody>\n",
                            "    <tr>\n",
                            "      <th>0</th>\n",
                            "      <td>-0.313854</td>\n",
                            "      <td>-0.0862971</td>\n",
                            "      <td>-2.20715</td>\n",
                            "      <td>-1.1835</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>1</th>\n",
                            "      <td>0.0557112</td>\n",
                            "      <td>-0.248166</td>\n",
                            "      <td>-0.817248</td>\n",
                            "      <td>-0.760595</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>2</th>\n",
                            "      <td>0.111356</td>\n",
                            "      <td>-0.385469</td>\n",
                            "      <td>-1.03544</td>\n",
                            "      <td>-0.775432</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>3</th>\n",
                            "      <td>0.749707</td>\n",
                            "      <td>-0.736782</td>\n",
                            "      <td>-1.54338</td>\n",
                            "      <td>-1.01173</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>4</th>\n",
                            "      <td>0.0290762</td>\n",
                            "      <td>-0.454397</td>\n",
                            "      <td>-1.75066</td>\n",
                            "      <td>-1.08869</td>\n",
                            "    </tr>\n",
                            "  </tbody>\n",
                            "</table>\n",
                            "</div>"
                        ],
                        "text/plain": [
                            "          X0         X1        X2        X3\n",
                            "0  -0.313854 -0.0862971  -2.20715   -1.1835\n",
                            "1  0.0557112  -0.248166 -0.817248 -0.760595\n",
                            "2   0.111356  -0.385469  -1.03544 -0.775432\n",
                            "3   0.749707  -0.736782  -1.54338  -1.01173\n",
                            "4  0.0290762  -0.454397  -1.75066  -1.08869"
                        ]
                    },
                    "execution_count": 7,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "train_data.head()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 8,
            "id": "901b6dd8",
            "metadata": {},
            "outputs": [],
            "source": [
                "model = DECI(\"mymodel\", dataset.variables, \"mysavedir\", \"cuda\", **model_config) #change cuda to cpu if GPU is not available"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 9,
            "id": "7038cd91",
            "metadata": {},
            "outputs": [
                {
                    "name": "stderr",
                    "output_type": "stream",
                    "text": [
                        "/home/fosteradam/Causica/open_source/causica/preprocessing/data_processor.py:397: UserWarning: Data too low for continous variables [0 2 3]\n",
                        "  warnings.warn(\n",
                        "/home/fosteradam/Causica/open_source/causica/preprocessing/data_processor.py:402: UserWarning: Data too high for continous variables [0 2 3]\n",
                        "  warnings.warn(\n",
                        "/home/fosteradam/Causica/open_source/causica/utils/helper_functions.py:49: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /opt/conda/conda-bld/pytorch_1603729062494/work/torch/csrc/utils/tensor_numpy.cpp:141.)\n",
                        "  return tuple(torch.as_tensor(array, dtype=dtype, device=device) for array in arrays)\n"
                    ]
                },
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "Auglag Step: 0\n",
                        "LR: 0.05\n",
                        "Inner Step: 500, loss: 2.24, log p(x|A): -1.79, dag: 0.00000000,                 log p(A)_sp: -0.01, log q(A): 0.000, H filled: 0.000, rec: 0.433\n",
                        "Inner Step: 1000, loss: 1.86, log p(x|A): -1.51, dag: 0.00000000,                 log p(A)_sp: -0.01, log q(A): 0.000, H filled: 0.000, rec: 0.346\n",
                        "Best model found at innner step 852, with Loss 1.35\n",
                        "Dag penalty after inner: 0.0000000000\n",
                        "Time taken for this step 17.271703481674194\n",
                        "[[[0. 1. 1. 0.]\n",
                        "  [0. 0. 1. 0.]\n",
                        "  [0. 0. 0. 1.]\n",
                        "  [0. 0. 0. 0.]]]\n",
                        "Not done inner optimization.\n",
                        "Dag penalty: 0.000000000000000\n",
                        "Rho: 1.00, alpha: 0.00\n",
                        "Auglag Step: 1\n",
                        "LR: 0.05\n",
                        "Inner Step: 500, loss: 1.47, log p(x|A): -1.21, dag: 0.00000000,                 log p(A)_sp: -0.01, log q(A): 0.000, H filled: 0.000, rec: 0.254\n",
                        "Reducing lr to 0.00500\n",
                        "Inner Step: 1000, loss: 1.42, log p(x|A): -1.16, dag: 0.00000000,                 log p(A)_sp: -0.01, log q(A): 0.000, H filled: 0.000, rec: 0.249\n",
                        "Best model found at innner step 1000, with Loss 1.19\n",
                        "Dag penalty after inner: 0.0000000000\n",
                        "Time taken for this step 17.126287698745728\n",
                        "[[[0. 1. 1. 0.]\n",
                        "  [0. 0. 1. 0.]\n",
                        "  [0. 0. 0. 1.]\n",
                        "  [0. 0. 0. 0.]]]\n",
                        "Updating alpha.\n",
                        "Dag penalty: 0.000000000000000\n",
                        "Rho: 1.00, alpha: 0.00\n",
                        "Auglag Step: 2\n",
                        "LR: 0.05\n",
                        "Inner Step: 500, loss: 1.46, log p(x|A): -1.20, dag: 0.00000000,                 log p(A)_sp: -0.01, log q(A): 0.000, H filled: 0.000, rec: 0.250\n",
                        "Inner Step: 1000, loss: 1.76, log p(x|A): -1.50, dag: 0.00000000,                 log p(A)_sp: -0.01, log q(A): 0.000, H filled: 0.000, rec: 0.248\n",
                        "Best model found at innner step 900, with Loss 1.28\n",
                        "Dag penalty after inner: 0.0000000000\n",
                        "Time taken for this step 17.003172874450684\n",
                        "[[[0. 1. 1. 0.]\n",
                        "  [0. 0. 1. 0.]\n",
                        "  [0. 0. 0. 1.]\n",
                        "  [0. 0. 0. 0.]]]\n",
                        "Not done inner optimization.\n",
                        "Dag penalty: 0.000000000000000\n",
                        "Rho: 1.00, alpha: 0.00\n",
                        "Auglag Step: 3\n",
                        "LR: 0.05\n",
                        "Inner Step: 500, loss: 1.38, log p(x|A): -1.12, dag: 0.00000000,                 log p(A)_sp: -0.01, log q(A): 0.000, H filled: 0.000, rec: 0.249\n",
                        "Reducing lr to 0.00500\n",
                        "Inner Step: 1000, loss: 1.33, log p(x|A): -1.08, dag: 0.00000000,                 log p(A)_sp: -0.01, log q(A): 0.000, H filled: 0.000, rec: 0.242\n",
                        "Best model found at innner step 988, with Loss 1.17\n",
                        "Dag penalty after inner: 0.0000000000\n",
                        "Time taken for this step 17.142217874526978\n",
                        "[[[0. 1. 1. 0.]\n",
                        "  [0. 0. 1. 0.]\n",
                        "  [0. 0. 0. 1.]\n",
                        "  [0. 0. 0. 0.]]]\n",
                        "Updating alpha.\n",
                        "Dag penalty: 0.000000000000000\n",
                        "Rho: 1.00, alpha: 0.00\n",
                        "Auglag Step: 4\n",
                        "LR: 0.05\n",
                        "Inner Step: 500, loss: 1.39, log p(x|A): -1.14, dag: 0.00000000,                 log p(A)_sp: -0.01, log q(A): 0.000, H filled: 0.000, rec: 0.247\n",
                        "Inner Step: 1000, loss: 1.38, log p(x|A): -1.13, dag: 0.00000000,                 log p(A)_sp: -0.01, log q(A): 0.000, H filled: 0.000, rec: 0.246\n",
                        "Best model found at innner step 716, with Loss 1.27\n",
                        "Dag penalty after inner: 0.0000000000\n",
                        "Time taken for this step 17.15366768836975\n",
                        "[[[0. 1. 1. 0.]\n",
                        "  [0. 0. 1. 0.]\n",
                        "  [0. 0. 0. 1.]\n",
                        "  [0. 0. 0. 0.]]]\n",
                        "Not done inner optimization.\n",
                        "Dag penalty: 0.000000000000000\n",
                        "Rho: 1.00, alpha: 0.00\n",
                        "Auglag Step: 5\n",
                        "LR: 0.05\n",
                        "Inner Step: 500, loss: 1.37, log p(x|A): -1.12, dag: 0.00000000,                 log p(A)_sp: -0.01, log q(A): 0.000, H filled: 0.000, rec: 0.244\n",
                        "Inner Step: 1000, loss: 1.37, log p(x|A): -1.12, dag: 0.00000000,                 log p(A)_sp: -0.01, log q(A): 0.000, H filled: 0.000, rec: 0.244\n",
                        "Best model found at innner step 572, with Loss 1.27\n",
                        "Dag penalty after inner: 0.0000000000\n",
                        "Time taken for this step 17.25237250328064\n",
                        "[[[0. 1. 1. 0.]\n",
                        "  [0. 0. 1. 0.]\n",
                        "  [0. 0. 0. 1.]\n",
                        "  [0. 0. 0. 0.]]]\n",
                        "Updating alpha.\n",
                        "Dag penalty: 0.000000000000000\n",
                        "Rho: 1.00, alpha: 0.00\n"
                    ]
                }
            ],
            "source": [
                "model.run_train(dataset, training_params, run_context=run_context)"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "188d8c36",
            "metadata": {},
            "source": [
                "With the trained model, we are now able to quickly get ATE, CATE, ITE and counterfactual estimates without additional retraining!"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "d22d3bde",
            "metadata": {},
            "source": [
                "## Treatment effect estimation\n",
                "\n",
                "For treatment effect estimation with non-binary variables, we need to choose a \"treatment\" value `B` and a \"reference\" value `A` for each intervened variable. We use `mean + std` (computed on the training data) as the treatment value and `mean - std` as the reference value."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 10,
            "id": "2b74c138",
            "metadata": {},
            "outputs": [],
            "source": [
                "treatment_values = train_data.mean(0) + train_data.std(0)\n",
                "reference_values = train_data.mean(0) - train_data.std(0)"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "98fc6e02",
            "metadata": {},
            "source": [
                "### Average treatment effect (ATE)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 11,
            "id": "c8849361",
            "metadata": {},
            "outputs": [
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "Computing the ATE between X0=0.9882734280299655 and X0=-0.9714331856451489\n",
                        "[ 1.96  -1.441  1.259  0.526]\n",
                        "Computing the ATE between X1=0.6042951007907429 and X1=-0.8248794225513154\n",
                        "[-0.002  1.429  1.484  0.708]\n",
                        "Computing the ATE between X2=-0.0874461591519049 and X2=-2.2234084444962807\n",
                        "[ 9.615e-04 -3.557e-04  2.136e+00  8.917e-01]\n",
                        "Computing the ATE between X3=-0.29498196284004263 and X3=-1.3015321400845723\n",
                        "[ 0.005 -0.002  0.005  1.007]\n"
                    ]
                }
            ],
            "source": [
                "ates = []\n",
                "for variable in range(treatment_values.shape[0]):\n",
                "    intervention_idxs = torch.tensor([variable])\n",
                "    intervention_value = torch.tensor([treatment_values[variable]])\n",
                "    reference_value = torch.tensor([reference_values[variable]])\n",
                "    print(f\"Computing the ATE between X{variable}={treatment_values[variable]} and X{variable}={reference_values[variable]}\")\n",
                "    # This estimate uses 200k samples for accuracy. You can get away with fewer if necessary\n",
                "    ate, _ = model.cate(intervention_idxs, intervention_value, reference_value, Ngraphs=1000, Nsamples_per_graph=200)\n",
                "    print(ate)\n",
                "    ates.append(ate)"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "4897e9eb",
            "metadata": {},
            "source": [
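                "Each printed array lists the ATE of intervening on the given variable on `X0`, `X1`, `X2` and `X3`, in that order. Sanity check: `X3` has no children in the graph, so its ATE on every other variable should be 0. `X2` is downstream of `X0` and `X1`, and `X1` is downstream of `X0`, so the ATEs of `X2` on `X0` and `X1`, and of `X1` on `X0`, should also be 0."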
            ]
        },
        {
            "cell_type": "markdown",
            "id": "0399710d",
            "metadata": {},
            "source": [
                "Note: the treatment effect of a variable upon itself (e.g. `X0` on `X0`) is simply the difference between the treatment value and the reference value."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 12,
            "id": "797b2cb5",
            "metadata": {},
            "outputs": [
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "[[ 1.960e+00 -1.441e+00  1.259e+00  5.261e-01]\n",
                        " [-2.277e-03  1.429e+00  1.484e+00  7.083e-01]\n",
                        " [ 9.615e-04 -3.557e-04  2.136e+00  8.917e-01]\n",
                        " [ 4.524e-03 -2.425e-03  4.883e-03  1.007e+00]]\n"
                    ]
                }
            ],
            "source": [
                "ate_matrix = np.stack(ates)\n",
                "print(ate_matrix)"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "16d57264",
            "metadata": {},
            "source": [
                "### Conditional average treatment effect (CATE)\n",
                "Suppose we want to compute the treatment effect of `X1` on `X2`, conditional on a fixed value of `X0`. For this, we pass conditioning variables and values to the same `cate` method. The implementation of CATE involves fitting an auxiliary model, so it will be slower than ATE and counterfactual calculation.\n",
                "\n",
                "*Note 1*: there may be some variance in this estimate.\n",
                "\n",
                "*Note 2*: there are some restrictions on the graph for this method. Check the DECI paper for more info."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 13,
            "id": "daed8508",
            "metadata": {},
            "outputs": [
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "[1.356]\n"
                    ]
                }
            ],
            "source": [
                "# The variable and value to condition on\n",
                "conditioning_idxs = torch.tensor([0])\n",
                "conditioning_values = torch.tensor([treatment_values[0]])\n",
                "# The intervention variable, intervention value and reference value, same as for ATE\n",
                "intervention_idxs = torch.tensor([1])\n",
                "intervention_value = torch.tensor([treatment_values[1]])\n",
                "reference_value = torch.tensor([reference_values[1]])\n",
                "# The variables to compute the effect upon (this has to be length 1)\n",
                "effect_idxs = torch.tensor([2])\n",
                "# Set the parameters of the auxiliary CATE model\n",
                "model.cate_rff_n_features = 100\n",
                "cate_estimate, _ = model.cate(intervention_idxs, intervention_value, reference_value, effect_idxs, conditioning_idxs,\n",
                "                              conditioning_values, Ngraphs=1000, Nsamples_per_graph=20)\n",
                "print(cate_estimate)"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "4e0644bb",
            "metadata": {},
            "source": [
                "### Sample\n",
                "`sample` is the underlying method used for both ATE and CATE calculation. For ATE estimation, we simply take the mean over multiple samples. For CATE, we fit an auxiliary regression model to the samples. `sample` can draw samples under any interventional distribution."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 14,
            "id": "2905dedb",
            "metadata": {},
            "outputs": [],
            "source": [
                "%matplotlib inline"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 15,
            "id": "5026fc0f",
            "metadata": {},
            "outputs": [],
            "source": [
                "samples = model.sample(200000, intervention_idxs=torch.tensor([1]), intervention_values=torch.tensor([treatment_values[1]]),\n",
                "                       samples_per_graph=200)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 16,
            "id": "8c050ac4",
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUwAAAEvCAYAAAAuDvirAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAR9klEQVR4nO3dbYxeZV7H8e/PdkWigjwUxLbrYKiJgLobRiQhRt2qVNksvICkGqWJTRoJmt3EzaasicYXJEXNsiEKhsiGgqvQsLuhWUStxYeYYNlhZWULi0wEoYK0uyCLLxZT9u+L+5rk7nQ6c3U6M/c99PtJ7pxz/8+5zrkuHn69zjlzOqkqJEkL+45Rd0CSVgsDU5I6GZiS1MnAlKROBqYkdTIwJanT2lF3YLHOP//8mpiYGHU3JL3HPPXUU1+vqnVzbVu1gTkxMcHU1NSouyHpPSbJf55om5fkktTJwJSkTgamJHUyMCWpk4EpSZ0MTEnqZGBKUqeuwEzyUpJnkjydZKrVzk2yL8kLbXnO0P63JplO8nySa4bqV7TjTCe5M0la/YwkD7X6gSQTSzxOSTplJzPD/Nmq+kBVTbbvO4H9VbUJ2N++k+RSYCtwGbAFuCvJmtbmbmAHsKl9trT6duDNqroEuAO4ffFDkqTlcSqX5NcBu9v6buD6ofqDVfVOVb0ITANXJrkIOKuqnqjBX/N+/6w2M8d6GNg8M/uUpHHRG5gF/G2Sp5LsaLULq+o1gLa8oNXXA68MtT3Uauvb+uz6MW2q6ijwFnDeyQ1FkpZX77vkV1fVq0kuAPYl+do8+841M6x56vO1OfbAg7DeAfD+979//h5LQyZ2Prrs53hp17XLfg6NVtcMs6pebcvDwBeAK4HX22U2bXm47X4I2DjUfAPwaqtvmKN+TJska4GzgTfm6Mc9VTVZVZPr1s35l4lI0rJZMDCTfHeS751ZB34B+CqwF9jWdtsGPNLW9wJb25Pvixk83HmyXba/neSqdn/yplltZo51A/B4+essJY2ZnkvyC4EvtGcwa4G/qKq/TvIlYE+S7cDLwI0AVXUwyR7gWeAocEtVvduOdTNwH3Am8Fj7ANwLPJBkmsHMcusSjE2SltSCgVlV/wH8+Bz1bwCbT9DmNuC2OepTwOVz1L9FC1xJGle+6SNJnQxMSepkYEpSJwNTkjoZmJLUycCUpE4GpiR1WrW/l1waN76v/t7nDFOSOhmYktTJwJSkTgamJHUyMCWpk4EpSZ0MTEnqZGBKUicDU5I6GZiS1MnAlKROBqYkdTIwJamTgSlJnQxMSepkYEpSJwNTkjoZmJLUycCUpE4GpiR1MjAlqZOBKUmdDExJ6mRgSlInA1OSOhmYktTJwJSkTgamJHUyMCWpk4EpSZ0MTEnqZGBKUicDU5I6GZiS1MnAlKRO3YGZZE2Sf03yxfb93CT7krzQlucM7Xtrkukkzye5Zqh+RZJn2rY7k6TVz0jyUKsfSDKxhGOUpCVxMjPMjwLPDX3fCeyvqk3A/vadJJcCW4HLgC3AXUnWtDZ3AzuATe2zpdW3A29W1SXAHcDtixqNJC2jrsBMsgG4FvizofJ1wO62vhu4fqj+YFW9U1UvAtPAlUkuAs6qqieqqoD7Z7WZOdbDwOaZ2ackjYveGeangU8A3x6qXVhVrwG05QWtvh54ZWi/Q622vq3Prh/TpqqOAm8B5/UOQpJWwoKBmeTDwOGqeqrzmHPNDGue+nxtZvdlR5KpJFNHjhzp7I4kLY2eGebVwEeSvAQ8CHwoyZ8Dr7fLbNrycNv/ELBxqP0G4NVW3zBH/Zg2SdYCZwNvzO5IVd1TVZNVNblu3bquAUrSUlkwMKvq1qraUFUTDB7mPF5VvwrsBba13bYBj7T1vcDW9uT7YgYPd55sl+1vJ7mq3Z+8aVabmWPd0M5x3AxTkkZp7Sm03QXsSbIdeBm4EaCqDibZAzwLHA
Vuqap3W5ubgfuAM4HH2gfgXuCBJNMMZpZbT6FfkrQssloncpOTkzU1NTXqbugUTex8dNRdWFVe2nXtqLvwnpfkqaqanGubb/pIUicDU5I6GZiS1MnAlKROBqYkdTIwJamTgSlJnQxMSepkYEpSJwNTkjoZmJLUycCUpE4GpiR1MjAlqZOBKUmdDExJ6mRgSlInA1OSOhmYktTJwJSkTgamJHUyMCWpk4EpSZ0MTEnqZGBKUicDU5I6GZiS1MnAlKROBqYkdTIwJamTgSlJnQxMSepkYEpSJwNTkjoZmJLUycCUpE4GpiR1MjAlqZOBKUmdDExJ6mRgSlInA1OSOhmYktRpwcBM8l1JnkzylSQHk/x+q5+bZF+SF9rynKE2tyaZTvJ8kmuG6lckeaZtuzNJWv2MJA+1+oEkE8swVkk6JT0zzHeAD1XVjwMfALYkuQrYCeyvqk3A/vadJJcCW4HLgC3AXUnWtGPdDewANrXPllbfDrxZVZcAdwC3n/rQJGlpLRiYNfC/7ev72qeA64Ddrb4buL6tXwc8WFXvVNWLwDRwZZKLgLOq6omqKuD+WW1mjvUwsHlm9ilJ46LrHmaSNUmeBg4D+6rqAHBhVb0G0JYXtN3XA68MNT/Uauvb+uz6MW2q6ijwFnDeIsYjScumKzCr6t2q+gCwgcFs8fJ5dp9rZljz1Odrc+yBkx1JppJMHTlyZIFeS9LSWnsyO1fV/yT5Bwb3Hl9PclFVvdYutw+33Q4BG4eabQBebfUNc9SH2xxKshY4G3hjjvPfA9wDMDk5eVygSu91EzsfXZHzvLTr2hU5z2rT85R8XZLva+tnAj8HfA3YC2xru20DHmnre4Gt7cn3xQwe7jzZLtvfTnJVuz9506w2M8e6AXi83eeUpLHRM8O8CNjdnnR/B7Cnqr6Y5AlgT5LtwMvAjQBVdTDJHuBZ4ChwS1W92451M3AfcCbwWPsA3As8kGSawcxy61IMTpKW0oKBWVX/Bnxwjvo3gM0naHMbcNsc9SnguPufVfUtWuBK0rjyTR9J6mRgSlInA1OSOhmYktTJwJSkTgamJHUyMCWpk4EpSZ0MTEnqZGBKUicDU5I6GZiS1MnAlKROBqYkdTIwJamTgSlJnQxMSepkYEpSJwNTkjoZmJLUycCUpE4GpiR1MjAlqZOBKUmdDExJ6mRgSlInA1OSOhmYktTJwJSkTgamJHUyMCWpk4EpSZ0MTEnqZGBKUicDU5I6GZiS1MnAlKROBqYkdTIwJamTgSlJnQxMSepkYEpSpwUDM8nGJH+f5LkkB5N8tNXPTbIvyQttec5Qm1uTTCd5Psk1Q/UrkjzTtt2ZJK1+RpKHWv1AkollGKsknZKeGeZR4Ler6keAq4BbklwK7AT2V9UmYH/7Ttu2FbgM2ALclWRNO9bdwA5gU/tsafXtwJtVdQlwB3D7EoxNkpbUgoFZVa9V1Zfb+tvAc8B64Dpgd9ttN3B9W78OeLCq3qmqF4Fp4MokFwFnVdUTVVXA/bPazBzrYWDzzOxTksbFSd3DbJfKHwQOABdW1WswCFXggrbbeuCVoWaHWm19W59dP6ZNVR0F3gLOO5m+SdJy6w7MJN8DfA74WFV9c75d56jVPPX52szuw44kU0mmjhw5slCXJWlJdQVmkvcxCMvPVtXnW/n1dplNWx5u9UPAxqHmG4BXW33DHPVj2iRZC5wNvDG7H1V1T1VNVtXkunXrerouSUum5yl5gHuB56rqU0Ob9gLb2vo24JGh+tb25PtiBg93nmyX7W8nuaod86ZZbWaOdQPweLvPKUljY23HPlcDvwY8k+TpVvsksAvYk2Q78DJwI0BVHUyyB3iWwRP2W6rq3dbuZuA+4EzgsfaBQSA/kGSawcxy66kNS5KW3oKBWVX/zNz3GAE2n6DNbcBtc9SngMvnqH+LFriSNK5800eSOhmYktTJwJSkTj0PfXSamtj56Ki7II0VZ5iS1MnAlKROBqYkdTIwJamTgS
lJnQxMSepkYEpSJwNTkjoZmJLUycCUpE4GpiR1MjAlqZOBKUmdDExJ6mRgSlInA1OSOhmYktTJwJSkTgamJHUyMCWpk4EpSZ0MTEnqZGBKUicDU5I6GZiS1MnAlKROBqYkdTIwJamTgSlJnQxMSepkYEpSJwNTkjoZmJLUycCUpE4GpiR1MjAlqZOBKUmdDExJ6mRgSlKnBQMzyWeSHE7y1aHauUn2JXmhLc8Z2nZrkukkzye5Zqh+RZJn2rY7k6TVz0jyUKsfSDKxxGOUpCXRM8O8D9gyq7YT2F9Vm4D97TtJLgW2Ape1NnclWdPa3A3sADa1z8wxtwNvVtUlwB3A7YsdjCQtpwUDs6r+CXhjVvk6YHdb3w1cP1R/sKreqaoXgWngyiQXAWdV1RNVVcD9s9rMHOthYPPM7FOSxsli72FeWFWvAbTlBa2+HnhlaL9Drba+rc+uH9Omqo4CbwHnLbJfkrRslvqhz1wzw5qnPl+b4w+e7EgylWTqyJEji+yiJC3OYgPz9XaZTVsebvVDwMah/TYAr7b6hjnqx7RJshY4m+NvAQBQVfdU1WRVTa5bt26RXZekxVlsYO4FtrX1bcAjQ/Wt7cn3xQwe7jzZLtvfTnJVuz9506w2M8e6AXi83eeUpLGydqEdkvwl8DPA+UkOAb8H7AL2JNkOvAzcCFBVB5PsAZ4FjgK3VNW77VA3M3jifibwWPsA3As8kGSawcxy65KMTJKW2IKBWVW/fIJNm0+w/23AbXPUp4DL56h/ixa4kjTOfNNHkjoZmJLUycCUpE4L3sOUdPqZ2Pnosp/jpV3XLvs5lpozTEnqZGBKUicDU5I6GZiS1MnAlKROBqYkdTIwJamTgSlJnQxMSepkYEpSJwNTkjoZmJLUycCUpE4GpiR1MjAlqZOBKUmdDExJ6mRgSlInA1OSOvk7fVahlfh9K5KO5wxTkjoZmJLUycCUpE4GpiR1MjAlqZOBKUmdDExJ6mRgSlInA1OSOhmYktTJwJSkTgamJHUyMCWpk4EpSZ0MTEnqZGBKUicDU5I6GZiS1MlfUSFpJFbiV628tOvaJT3e2Mwwk2xJ8nyS6SQ7R90fSZptLGaYSdYAfwL8PHAI+FKSvVX17Gh7dvL8BWXSe9e4zDCvBKar6j+q6v+AB4HrRtwnSTrGWMwwgfXAK0PfDwE/udQncfYn6VSMS2Bmjlodt1OyA9jRvv5vkudP8jznA18/yTbjxjGMB8cwHuYdQ25f1DF/8EQbxiUwDwEbh75vAF6dvVNV3QPcs9iTJJmqqsnFth8HjmE8OIbxsNJjGJd7mF8CNiW5OMl3AluBvSPukyQdYyxmmFV1NMlvAn8DrAE+U1UHR9wtSTrGWAQmQFX9FfBXy3yaRV/OjxHHMB4cw3hY0TGk6rhnK5KkOYzLPUxJGnunZWAm+a32GubBJH8w6v4sVpKPJ6kk54+6LycryR8m+VqSf0vyhSTfN+o+9Vjtr/Am2Zjk75M81/77/+io+7RYSdYk+dckX1ypc552gZnkZxm8RfRjVXUZ8Ecj7tKiJNnI4FXSl0fdl0XaB1xeVT8G/Dtw64j7s6ChV3h/EbgU+OUkl462VyftKPDbVfUjwFXALatwDDM+Cjy3kic87QITuBnYVVXvAFTV4RH3Z7HuAD7BHD/gvxpU1d9W1dH29V8Y/OztuFv1r/BW1WtV9eW2/jaDwFk/2l6dvCQbgGuBP1vJ856OgfnDwE8lOZDkH5P8xKg7dLKSfAT4r6r6yqj7skR+HXhs1J3oMNcrvKsubGYkmQA+CBwYcVcW49MMJgzfXsmTjs2PFS2lJH8HfP8cm36HwZjPYXA58hPAniQ/VGP24wILjOGTwC+sbI9O3nxjqKpH2j6/w+Ay8bMr2bdF6nqFdzVI8j3A54CPVdU3R92fk5Hkw8Dhqnoqyc+s5Lnfk4FZVT93om1JbgY+3wLyySTfZvA+6pGV6l+PE40hyY8CFwNfSQKDS9kvJ7myqv
57Bbu4oPn+PQAk2QZ8GNg8bn9gnUDXK7zjLsn7GITlZ6vq86PuzyJcDXwkyS8B3wWcleTPq+pXl/vEp93PYSb5DeAHqup3k/wwsB94/yr5H/Y4SV4CJqtqVf0lCkm2AJ8CfrqqxuoPqxNJspbBA6rNwH8xeKX3V1bTW2kZ/Cm7G3ijqj424u6csjbD/HhVfXglznc63sP8DPBDSb7K4Kb9ttUalqvcHwPfC+xL8nSSPx11hxbSHlLNvML7HLBnNYVlczXwa8CH2j/3p9tMTR1OuxmmJC3W6TjDlKRFMTAlqZOBKUmdDExJ6mRgSlInA1OSOhmYktTJwJSkTv8PDUbAW9xbt4MAAAAASUVORK5CYII=",
                        "text/plain": [
                            "<Figure size 360x360 with 1 Axes>"
                        ]
                    },
                    "metadata": {
                        "needs_background": "light"
                    },
                    "output_type": "display_data"
                }
            ],
            "source": [
                "plt.figure(figsize=(5, 5))\n",
                "plt.hist(samples[:, 2].cpu().numpy())\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 17,
            "id": "760dbdbb",
            "metadata": {},
            "outputs": [],
            "source": [
                "# Take samples under a different intervention\n",
                "samples2 = model.sample(200000, intervention_idxs=torch.tensor([1]), intervention_values=torch.tensor([reference_values[1]]),\n",
                "                        samples_per_graph=200)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 18,
            "id": "9848e266",
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUwAAAEvCAYAAAAuDvirAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAATfUlEQVR4nO3df4xeVX7f8fcn9i5BSWD5MVBkezuscJoAze6Gietq1apZp8EN0Zo/QJpICVZqyRKi0UZKlZrkj6p/WDJtFVrUQoXCFkO3BYvsFiuENK7JNqpETYYNG9Z4KaOFgmsHOwshpBVEJt/+8ZypHg/jmeP54Xlsv1/So3vv995zn3Nk+eNz7507TlUhSVrY9612ByTpfGFgSlInA1OSOhmYktTJwJSkTgamJHVau9odWKyrr766xsfHV7sbki4wL7744p9W1dhc+87bwBwfH2dqamq1uyHpApPkf51pn5fkktSpKzCTfCrJU0m+k+RIkr+d5MokB5K81pZXDB1/b5LpJK8muXWofkuSl9u+B5Kk1S9J8mSrH0oyvuwjlaQl6p1h/mvgd6vqR4DPAkeAXcDBqtoIHGzbJLkRmARuArYCDyZZ087zELAT2Ng+W1t9B/BuVd0A3A/ct8RxSdKyWzAwk1wG/F3gEYCq+suq+jNgG7C3HbYXuL2tbwOeqKoPq+p1YBrYlOQ64LKqer4GL7A/NqvNzLmeArbMzD4laVT0zDA/A5wE/n2SP0rym0l+ALi2qo4DtOU17fh1wFtD7Y+22rq2Prt+WpuqOgW8B1y1qBFJ0grpCcy1wI8DD1XV54H/Q7v8PoO5ZoY1T32+NqefONmZZCrJ1MmTJ+fvtSQts57APAocrapDbfspBgH6drvMpi1PDB2/Yaj9euBYq6+fo35amyRrgcuBd2Z3pKoerqqJqpoYG5vzx6QkacUsGJhV9SfAW0n+RittAV4B9gPbW2078HRb3w9Mtiff1zN4uPNCu2x/P8nmdn/yrlltZs51B/Bc+Ys6JY2Y3h9c/yXgq0k+CXwX+EUGYbsvyQ7gTeBOgKo6nGQfg1A9BdxTVR+189wNPApcCjzbPjB4oPR4kmkGM8vJJY5LkpZdzteJ3MTERPmmj6TlluTFqpqYa59v+khSp/P2XXJdGMZ3PXNOvueNPbedk+/Rhc0ZpiR1MjAlqZOBKUmdDExJ6mRgSlInA1OSOhmYktTJwJSkTgamJHUyMCWpk4EpSZ0MTEnqZGBKUicDU5I6GZiS1MnAlKROBqYkdTIwJamTgSlJnQxMSepkYEpSJwNTkjoZmJLUycCUpE4GpiR1MjAlqZOBKUmdDExJ6mRgSlInA1OSOhmYktTJwJSkTgamJHXqCswkbyR5OclLSaZa7cokB5K81pZXDB1/b5LpJK8muXWofks7z3SSB5Kk1S9J8mSrH0oyvszjlKQlO5sZ5k9W1eeqaqJt7wIOVtVG4GDbJsmNwCRwE7AVeDDJmtbmIWAnsLF9trb6DuDdqroBuB+4b/FDkqSVsZRL8m3A3ra+F7h9qP5EVX1YVa8D08CmJNcBl1XV81VVwGOz2syc6ylgy8zsU5JGRW9gFvB7SV5MsrPVrq2q4wBteU2rrwPeGmp7tNXWtfXZ9dPaVNUp4D3gqrMbiiStrLWdx32hqo4luQY4kOQ78xw718yw5qnP1+b0Ew/CeifApz/96fl7LEnLrGuGWVXH2vIE8HVgE/B2u8ymLU+0w48CG4aarweOtfr6OeqntUmyFrgceGeOfjxcVRNVNTE2NtbTdUlaNgsGZpIfSPJDM+vATwPfBvYD29th24Gn2/p+YLI9+b6ewcOdF9pl+/tJNrf7k3fNajNzrjuA59p9TkkaGT2X5NcCX2/PYNYC/7GqfjfJHwL7kuwA3gTuBKiqw0n2Aa8Ap4B7quqjdq67gUeBS4Fn2wfgEeDxJNMMZpaTyzA2SVpWCwZmVX
0X+Owc9e8BW87QZjewe476FHDzHPUPaIErSaPKN30kqZOBKUmdDExJ6mRgSlInA1OSOhmYktTJwJSkTgamJHXq/eUbugiN73pmtbsgjRRnmJLUycCUpE4GpiR1MjAlqZOBKUmdDExJ6mRgSlInA1OSOhmYktTJwJSkTgamJHUyMCWpk4EpSZ0MTEnqZGBKUicDU5I6GZiS1MnAlKROBqYkdTIwJamTgSlJnQxMSepkYEpSJ/9fcl0UzsX/sf7GnttW/Du0upxhSlInA1OSOhmYktSpOzCTrEnyR0l+u21fmeRAktfa8oqhY+9NMp3k1SS3DtVvSfJy2/dAkrT6JUmebPVDScaXcYyStCzOZob5ZeDI0PYu4GBVbQQOtm2S3AhMAjcBW4EHk6xpbR4CdgIb22drq+8A3q2qG4D7gfsWNRpJWkFdgZlkPXAb8JtD5W3A3ra+F7h9qP5EVX1YVa8D08CmJNcBl1XV81VVwGOz2syc6ylgy8zsU5JGRe8M818Bvwr81VDt2qo6DtCW17T6OuCtoeOOttq6tj67flqbqjoFvAdcNbsTSXYmmUoydfLkyc6uS9LyWDAwk/wscKKqXuw851wzw5qnPl+b0wtVD1fVRFVNjI2NdXZHkpZHzw+ufwH4UpKfAb4fuCzJfwDeTnJdVR1vl9sn2vFHgQ1D7dcDx1p9/Rz14TZHk6wFLgfeWeSYJGlFLDjDrKp7q2p9VY0zeJjzXFX9PLAf2N4O2w483db3A5Ptyff1DB7uvNAu299Psrndn7xrVpuZc93RvuNjM0xJWk1LeTVyD7AvyQ7gTeBOgKo6nGQf8ApwCrinqj5qbe4GHgUuBZ5tH4BHgMeTTDOYWU4uoV+StCLOKjCr6hvAN9r694AtZzhuN7B7jvoUcPMc9Q9ogStJo8o3fSSpk4EpSZ0MTEnqZGBKUicDU5I6GZiS1MnAlKROBqYkdTIwJamTgSlJnQxMSepkYEpSJwNTkjoZmJLUycCUpE4GpiR1MjAlqZOBKUmdDExJ6mRgSlInA1OSOhmYktTJwJSkTgamJHUyMCWpk4EpSZ0MTEnqZGBKUicDU5I6GZiS1MnAlKROBqYkdTIwJamTgSlJnQxMSeq0YGAm+f4kLyT5VpLDSf5Zq1+Z5ECS19ryiqE29yaZTvJqkluH6rckebnteyBJWv2SJE+2+qEk4yswVklakp4Z5ofAF6vqs8DngK1JNgO7gINVtRE42LZJciMwCdwEbAUeTLKmneshYCewsX22tvoO4N2qugG4H7hv6UOTpOW1YGDWwF+0zU+0TwHbgL2tvhe4va1vA56oqg+r6nVgGtiU5Drgsqp6vqoKeGxWm5lzPQVsmZl9StKo6LqHmWRNkpeAE8CBqjoEXFtVxwHa8pp2+DrgraHmR1ttXVufXT+tTVWdAt4DrpqjHzuTTCWZOnnyZNcAJWm5dAVmVX1UVZ8D1jOYLd48z+FzzQxrnvp8bWb34+GqmqiqibGxsQV6LUnL66yeklfVnwHfYHDv8e12mU1bnmiHHQU2DDVbDxxr9fVz1E9rk2QtcDnwztn0TZJWWs9T8rEkn2rrlwI/BXwH2A9sb4dtB55u6/uByfbk+3oGD3deaJft7yfZ3O5P3jWrzcy57gCea/c5JWlkrO045jpgb3vS/X3Avqr67STPA/uS7ADeBO4EqKrDSfYBrwCngHuq6qN2rruBR4FLgWfbB+AR4PEk0wxmlpPLMThJWk4LBmZV/THw+Tnq3wO2nKHNbmD3HPUp4GP3P6vqA1rgStKo8k0fSepkYEpSJwNTkjoZmJLUycCUpE4GpiR1MjAlqZOBKUmdDExJ6mRgSlInA1OSOhmYktTJwJSkTgamJHUyMCWpk4EpSZ0MTEnqZGBKUicDU5I6GZiS1MnAlKROBqYkdTIwJamTgSlJnQxMSepkYEpSJwNTkjoZmJLUae1qd0C6UIzvembFv+ONPbet+HfozJxhSlInA1OSOhmYktTJwJ
SkTgamJHUyMCWp04KBmWRDkt9PciTJ4SRfbvUrkxxI8lpbXjHU5t4k00leTXLrUP2WJC+3fQ8kSatfkuTJVj+UZHwFxipJS9IzwzwF/EpV/SiwGbgnyY3ALuBgVW0EDrZt2r5J4CZgK/BgkjXtXA8BO4GN7bO11XcA71bVDcD9wH3LMDZJWlYLBmZVHa+qb7b194EjwDpgG7C3HbYXuL2tbwOeqKoPq+p1YBrYlOQ64LKqer6qCnhsVpuZcz0FbJmZfUrSqDire5jtUvnzwCHg2qo6DoNQBa5ph60D3hpqdrTV1rX12fXT2lTVKeA94Kqz6ZskrbTuwEzyg8BvAb9cVX8+36Fz1Gqe+nxtZvdhZ5KpJFMnT55cqMuStKy6AjPJJxiE5Ver6mut/Ha7zKYtT7T6UWDDUPP1wLFWXz9H/bQ2SdYClwPvzO5HVT1cVRNVNTE2NtbTdUlaNj1PyQM8Ahypqt8Y2rUf2N7WtwNPD9Un25Pv6xk83HmhXba/n2RzO+dds9rMnOsO4Ll2n1OSRkbPbyv6AvALwMtJXmq1XwP2APuS7ADeBO4EqKrDSfYBrzB4wn5PVX3U2t0NPApcCjzbPjAI5MeTTDOYWU4ubViStPwWDMyq+u/MfY8RYMsZ2uwGds9RnwJunqP+AS1wJWlU+aaPJHUyMCWpk4EpSZ0MTEnqZGBKUicDU5I6GZiS1MnAlKROBqYkdTIwJalTz7vkGjHju55Z7S5IFyVnmJLUycCUpE4GpiR1MjAlqZOBKUmdDExJ6mRgSlInA1OSOhmYktTJwJSkTgamJHUyMCWpk4EpSZ0MTEnqZGBKUicDU5I6GZiS1MnAlKROBqYkdTIwJamTgSlJnQxMSepkYEpSJwNTkjotGJhJvpLkRJJvD9WuTHIgyWttecXQvnuTTCd5NcmtQ/Vbkrzc9j2QJK1+SZInW/1QkvFlHqMkLYueGeajwNZZtV3AwaraCBxs2yS5EZgEbmptHkyyprV5CNgJbGyfmXPuAN6tqhuA+4H7FjsYSVpJCwZmVf0B8M6s8jZgb1vfC9w+VH+iqj6sqteBaWBTkuuAy6rq+aoq4LFZbWbO9RSwZWb2KUmjZLH3MK+tquMAbXlNq68D3ho67mirrWvrs+untamqU8B7wFWL7JckrZjlfugz18yw5qnP1+bjJ092JplKMnXy5MlFdlGSFmexgfl2u8ymLU+0+lFgw9Bx64Fjrb5+jvppbZKsBS7n47cAAKiqh6tqoqomxsbGFtl1SVqcxQbmfmB7W98OPD1Un2xPvq9n8HDnhXbZ/n6Sze3+5F2z2syc6w7guXafU5JGytqFDkjyn4C/B1yd5CjwT4E9wL4kO4A3gTsBqupwkn3AK8Ap4J6q+qid6m4GT9wvBZ5tH4BHgMeTTDOYWU4uy8ikC9D4rmfOyfe8see2c/I955sFA7Oqfu4Mu7ac4fjdwO456lPAzXPUP6AFriSNMt/0kaROBqYkdTIwJamTgSlJnQxMSepkYEpSJwNTkjoZmJLUycCUpE4GpiR1MjAlqZOBKUmdDExJ6mRgSlInA1OSOhmYktTJwJSkTgamJHUyMCWpk4EpSZ0MTEnqZGBKUicDU5I6GZiS1MnAlKROBqYkdTIwJamTgSlJndaudgckjZ7xXc+s+He8see2Ff+O5eYMU5I6GZiS1MnAlKROBqYkdfKhzzI7FzfLJa0OZ5iS1GlkAjPJ1iSvJplOsmu1+yNJs41EYCZZA/xb4B8ANwI/l+TG1e2VJJ1uVO5hbgKmq+q7AEmeALYBr6xqryStmPPxh+NHJTDXAW8NbR8F/tZyf4kPZCQtxagEZuao1ccOSnYCO9vmXyR5dZHfdzXwp4tsO6oc0/nBMZ1DuW9Rzf76mXaMSmAeBTYMba8Hjs0+qKoeBh5e6pclmaqqiaWeZ5Q4pvODYzq/jcRDH+APgY1Jrk/ySWAS2L/KfZKk04zEDLOqTiX5R8B/AdYAX6mqw6vcLUk6zU
gEJkBV/Q7wO+fo65Z8WT+CHNP5wTGdx1L1sWcrkqQ5jMo9TEkaeRd1YCb5pfY65uEk/3y1+7NckvzjJJXk6tXuy1Ik+RdJvpPkj5N8PcmnVrtPi3WhvfqbZEOS309ypP39+fJq9+lcuGgDM8lPMnib6Meq6ibgX65yl5ZFkg3A3wfeXO2+LIMDwM1V9WPA/wTuXeX+LMoF+urvKeBXqupHgc3APRfAmBZ00QYmcDewp6o+BKiqE6vcn+VyP/CrzPGD/+ebqvq9qjrVNv8Hg5/PPR/9/1d/q+ovgZlXf89bVXW8qr7Z1t8HjjB4Y++CdjEH5g8DfyfJoST/LclPrHaHlirJl4D/XVXfWu2+rIB/CDy72p1YpLle/b1gwiXJOPB54NAqd2XFjcyPFa2EJP8V+Gtz7Pp1BmO/gsHlxE8A+5J8pkb8xwYWGNOvAT99bnu0NPONp6qebsf8OoNLwK+ey74to65Xf89HSX4Q+C3gl6vqz1e7Pyvtgg7MqvqpM+1LcjfwtRaQLyT5KwbvxJ48V/1bjDONKcnfBK4HvpUEBpev30yyqar+5Bx28azM92cEkGQ78LPAllH/x2weXa/+nm+SfIJBWH61qr622v05Fy7mS/L/DHwRIMkPA59kRH+BQI+qermqrqmq8aoaZ/CX9MdHOSwXkmQr8E+AL1XV/13t/izBBffqbwb/Kj8CHKmq31jt/pwrF3NgfgX4TJJvM7gJv/08nsFcqP4N8EPAgSQvJfl3q92hxWgPrmZe/T0C7LsAXv39AvALwBfbn81LSX5mtTu10nzTR5I6XcwzTEk6KwamJHUyMCWpk4EpSZ0MTEnqZGBKUicDU5I6GZiS1On/AVhCWvcfWSmAAAAAAElFTkSuQmCC",
                        "text/plain": [
                            "<Figure size 360x360 with 1 Axes>"
                        ]
                    },
                    "metadata": {
                        "needs_background": "light"
                    },
                    "output_type": "display_data"
                }
            ],
            "source": [
                "plt.figure(figsize=(5, 5))\n",
                "plt.hist(samples2[:, 2].cpu().numpy())\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "f25cccb5",
            "metadata": {},
            "source": [
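                "The ATE is estimated as the difference between the means of these two sets of samples (treatment minus reference). For the effect on X2, for example, this is `samples[:, 2].mean() - samples2[:, 2].mean()`."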
            ]
        },
        {
            "cell_type": "markdown",
            "id": "b4197ca5",
            "metadata": {},
            "source": [
                "### Counterfactual estimation"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 19,
            "id": "ef1982f5",
            "metadata": {},
            "outputs": [],
            "source": [
                "graph = model._get_adj_matrix_tensor(round=True, samples=1, most_likely_graph=True)\n",
                "factual_samples = train_data.values.astype('float32')\n",
                "counterfactual_samples = model._counterfactual(factual_samples, graph, \n",
                "                                               intervention_idxs=torch.tensor([1]),\n",
                "                                               intervention_values=torch.tensor([reference_values[1]])\n",
                "                                              ).detach().cpu().numpy()\n",
                "counterfactual_data = pd.DataFrame(counterfactual_samples, columns=[\"X0\", \"X1\", \"X2\", \"X3\"])"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 20,
            "id": "f2824513",
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/html": [
                            "<div>\n",
                            "<style scoped>\n",
                            "    .dataframe tbody tr th:only-of-type {\n",
                            "        vertical-align: middle;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe tbody tr th {\n",
                            "        vertical-align: top;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe thead th {\n",
                            "        text-align: right;\n",
                            "    }\n",
                            "</style>\n",
                            "<table border=\"1\" class=\"dataframe\">\n",
                            "  <thead>\n",
                            "    <tr style=\"text-align: right;\">\n",
                            "      <th></th>\n",
                            "      <th>X0</th>\n",
                            "      <th>X1</th>\n",
                            "      <th>X2</th>\n",
                            "      <th>X3</th>\n",
                            "    </tr>\n",
                            "  </thead>\n",
                            "  <tbody>\n",
                            "    <tr>\n",
                            "      <th>0</th>\n",
                            "      <td>-0.313854</td>\n",
                            "      <td>-0.0862971</td>\n",
                            "      <td>-2.20715</td>\n",
                            "      <td>-1.1835</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>1</th>\n",
                            "      <td>0.0557112</td>\n",
                            "      <td>-0.248166</td>\n",
                            "      <td>-0.817248</td>\n",
                            "      <td>-0.760595</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>2</th>\n",
                            "      <td>0.111356</td>\n",
                            "      <td>-0.385469</td>\n",
                            "      <td>-1.03544</td>\n",
                            "      <td>-0.775432</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>3</th>\n",
                            "      <td>0.749707</td>\n",
                            "      <td>-0.736782</td>\n",
                            "      <td>-1.54338</td>\n",
                            "      <td>-1.01173</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>4</th>\n",
                            "      <td>0.0290762</td>\n",
                            "      <td>-0.454397</td>\n",
                            "      <td>-1.75066</td>\n",
                            "      <td>-1.08869</td>\n",
                            "    </tr>\n",
                            "  </tbody>\n",
                            "</table>\n",
                            "</div>"
                        ],
                        "text/plain": [
                            "          X0         X1        X2        X3\n",
                            "0  -0.313854 -0.0862971  -2.20715   -1.1835\n",
                            "1  0.0557112  -0.248166 -0.817248 -0.760595\n",
                            "2   0.111356  -0.385469  -1.03544 -0.775432\n",
                            "3   0.749707  -0.736782  -1.54338  -1.01173\n",
                            "4  0.0290762  -0.454397  -1.75066  -1.08869"
                        ]
                    },
                    "execution_count": 20,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "train_data.head()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 21,
            "id": "04d177fd",
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/html": [
                            "<div>\n",
                            "<style scoped>\n",
                            "    .dataframe tbody tr th:only-of-type {\n",
                            "        vertical-align: middle;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe tbody tr th {\n",
                            "        vertical-align: top;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe thead th {\n",
                            "        text-align: right;\n",
                            "    }\n",
                            "</style>\n",
                            "<table border=\"1\" class=\"dataframe\">\n",
                            "  <thead>\n",
                            "    <tr style=\"text-align: right;\">\n",
                            "      <th></th>\n",
                            "      <th>X0</th>\n",
                            "      <th>X1</th>\n",
                            "      <th>X2</th>\n",
                            "      <th>X3</th>\n",
                            "    </tr>\n",
                            "  </thead>\n",
                            "  <tbody>\n",
                            "    <tr>\n",
                            "      <th>0</th>\n",
                            "      <td>-0.313854</td>\n",
                            "      <td>-0.824879</td>\n",
                            "      <td>-2.904177</td>\n",
                            "      <td>-0.954822</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>1</th>\n",
                            "      <td>0.055711</td>\n",
                            "      <td>-0.824879</td>\n",
                            "      <td>-1.268570</td>\n",
                            "      <td>-0.528514</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>2</th>\n",
                            "      <td>0.111356</td>\n",
                            "      <td>-0.824879</td>\n",
                            "      <td>-1.292711</td>\n",
                            "      <td>-0.628512</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>3</th>\n",
                            "      <td>0.749707</td>\n",
                            "      <td>-0.824879</td>\n",
                            "      <td>-1.585754</td>\n",
                            "      <td>-0.990060</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>4</th>\n",
                            "      <td>0.029076</td>\n",
                            "      <td>-0.824879</td>\n",
                            "      <td>-1.901278</td>\n",
                            "      <td>-1.020412</td>\n",
                            "    </tr>\n",
                            "  </tbody>\n",
                            "</table>\n",
                            "</div>"
                        ],
                        "text/plain": [
                            "         X0        X1        X2        X3\n",
                            "0 -0.313854 -0.824879 -2.904177 -0.954822\n",
                            "1  0.055711 -0.824879 -1.268570 -0.528514\n",
                            "2  0.111356 -0.824879 -1.292711 -0.628512\n",
                            "3  0.749707 -0.824879 -1.585754 -0.990060\n",
                            "4  0.029076 -0.824879 -1.901278 -1.020412"
                        ]
                    },
                    "execution_count": 21,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "counterfactual_data.head()"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "2d7c3d05",
            "metadata": {},
            "source": [
                "Observations:\n",
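                " - No change to column X0: this makes sense because X0 is upstream of the variable we are intervening on, X1.\n",
                " - X1 is set to the intervention value in every row (here `reference_values[1]`, shown as -0.824879).\n",
                " - X2 and X3, which respond to the intervention on X1, change to their counterfactual values."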
            ]
        },
        {
            "cell_type": "markdown",
            "id": "95ca587c",
            "metadata": {},
            "source": [
                "### Individual treatment effect (ITE)\n",
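                "Counterfactuals are the basis for computing ITEs: for each observed (factual) sample, the ITE is the difference between its counterfactual values under the treatment and under the reference intervention. Unlike `sample`, ITE does not involve resampling the noise. Unless there is significant graph uncertainty, it makes sense to use just the most likely graph (`Ngraphs=1, most_likely_graph=True`), which is very fast."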
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 22,
            "id": "bcfab5de",
            "metadata": {},
            "outputs": [],
            "source": [
                "ite, _ = model.ite(factual_samples, intervention_idxs=torch.tensor([1]),\n",
                "                   intervention_values=torch.tensor([treatment_values[1]]),\n",
                "                   reference_values=torch.tensor([reference_values[1]]),\n",
                "                   Ngraphs=1, most_likely_graph=True)\n",
                "ite_df = pd.DataFrame(ite, columns=[\"X0\", \"X1\", \"X2\", \"X3\"])"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 23,
            "id": "3f210755",
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/html": [
                            "<div>\n",
                            "<style scoped>\n",
                            "    .dataframe tbody tr th:only-of-type {\n",
                            "        vertical-align: middle;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe tbody tr th {\n",
                            "        vertical-align: top;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe thead th {\n",
                            "        text-align: right;\n",
                            "    }\n",
                            "</style>\n",
                            "<table border=\"1\" class=\"dataframe\">\n",
                            "  <thead>\n",
                            "    <tr style=\"text-align: right;\">\n",
                            "      <th></th>\n",
                            "      <th>X0</th>\n",
                            "      <th>X1</th>\n",
                            "      <th>X2</th>\n",
                            "      <th>X3</th>\n",
                            "    </tr>\n",
                            "  </thead>\n",
                            "  <tbody>\n",
                            "    <tr>\n",
                            "      <th>0</th>\n",
                            "      <td>0.0</td>\n",
                            "      <td>1.429174</td>\n",
                            "      <td>1.499376</td>\n",
                            "      <td>0.465578</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>1</th>\n",
                            "      <td>0.0</td>\n",
                            "      <td>1.429174</td>\n",
                            "      <td>1.682924</td>\n",
                            "      <td>0.832765</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>2</th>\n",
                            "      <td>0.0</td>\n",
                            "      <td>1.429174</td>\n",
                            "      <td>1.704715</td>\n",
                            "      <td>0.828641</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>3</th>\n",
                            "      <td>0.0</td>\n",
                            "      <td>1.429174</td>\n",
                            "      <td>1.764969</td>\n",
                            "      <td>0.797604</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>4</th>\n",
                            "      <td>0.0</td>\n",
                            "      <td>1.429174</td>\n",
                            "      <td>1.672215</td>\n",
                            "      <td>0.715442</td>\n",
                            "    </tr>\n",
                            "  </tbody>\n",
                            "</table>\n",
                            "</div>"
                        ],
                        "text/plain": [
                            "    X0        X1        X2        X3\n",
                            "0  0.0  1.429174  1.499376  0.465578\n",
                            "1  0.0  1.429174  1.682924  0.832765\n",
                            "2  0.0  1.429174  1.704715  0.828641\n",
                            "3  0.0  1.429174  1.764969  0.797604\n",
                            "4  0.0  1.429174  1.672215  0.715442"
                        ]
                    },
                    "execution_count": 23,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "ite_df.head()"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "b628d8a2",
            "metadata": {},
            "source": [
                "Averaged over the population, the ITE should approximately equal the ATE, which we compute below by calling `cate` without any conditioning variables."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 24,
            "id": "aee10822",
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/plain": [
                            "X0    0.000000\n",
                            "X1    1.429174\n",
                            "X2    1.460712\n",
                            "X3    0.683575\n",
                            "dtype: float64"
                        ]
                    },
                    "execution_count": 24,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "ite_df.mean(0)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 25,
            "id": "ae425832",
            "metadata": {},
            "outputs": [
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "[-1.068e-05  1.429e+00  1.486e+00  7.072e-01]\n"
                    ]
                }
            ],
            "source": [
                "ate, _ = model.cate(intervention_idxs=torch.tensor([1]),\n",
                "                    intervention_values=torch.tensor([treatment_values[1]]),\n",
                "                    reference_values=torch.tensor([reference_values[1]]), Ngraphs=1000, Nsamples_per_graph=200)\n",
                "print(ate)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "c6dfdf9d",
            "metadata": {},
            "outputs": [],
            "source": []
        }
    ],
    "metadata": {
        "interpreter": {
            "hash": "2fd0a64ad648981ef4b0280c53775d4f8aeceb44a2f0562bd016e7298af01310"
        },
        "kernelspec": {
            "display_name": "Python [conda env:causica]",
            "language": "python",
            "name": "conda-env-causica-py"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.8.2"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 5
}