{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fffa4296-fe20-4e0d-b56c-53fde5d227df",
   "metadata": {},
   "source": [
    "# Checking analytic bounds' validity and constraints' correctness in zero-inflated MCAR and MAR models\n",
    "\n",
    "In the paper \"Zero Inflation as a Missing Data Problem: a Proxy-based Approach\", we prove that:\n",
    "1. For zero-inflation MCAR case $X^{(1)} \\rightarrow X \\leftarrow R \\rightarrow W$:\n",
    "   1. (**C1**) $p(w_0|r_1) = p(w_0|x_1)$\n",
    "   2. (**C2a**) The bound for $p(w_0|r_0)$ is  \n",
    "      1. If $OR > 1$ then $\\max_x p(w_0 \\mid x) < p(w_0 \\mid r_0) \\leq 1$.  \n",
    "      2. If $OR < 1$ then $0 \\leq p(w_0 \\mid r_0) < \\min_x p(w_0 \\mid x)$.\n",
    "   3. (**C2b**) For any $q(w_0|r_0)$ inside the bound, the matrix inversion equation $D = [\\mathbf{q}_{W|R}]^{-1} \\mathbf{p}_{WX}$ yields a valid probability table $D$ (all entries non-negative and summing to one).\n",
    "2. In zero-inflation MAR case $R \\leftarrow C \\rightarrow X^{(1)} \\rightarrow X \\leftarrow R \\rightarrow W$:\n",
    "   1. (**M3**) Either $OR(c) > 1, \\forall c$ or $OR(c) < 1, \\forall c$.\n",
    "   2. (**C3**) $p(w_0|r_1) = p(w_0|x_1, c), \\forall c$. This leads to a marginal constraint:\n",
    "      1. (**M5**) $p(w_0|x_1, c) = p(w_0|x_1), \\forall c$.\n",
    "   3. (**C4a**) The bound for $p(w_0|r_0)$ is  \n",
    "      1. If $OR(c) > 1$ then $\\max_{x, c} p(w_0 \\mid x, c) < p(w_0 \\mid r_0) \\leq 1$.  \n",
    "      2. If $OR(c) < 1$ then $0 \\leq p(w_0 \\mid r_0) < \\min_{x, c} p(w_0 \\mid x, c)$.  \n",
    "   4. (**C4b**) For any $q(w_0|r_0)$ inside the bound, the matrix inversion equation $D = [\\mathbf{q}_{W|R}]^{-1} \\mathbf{p}_{WXC}$ yields a valid probability table $D$ (all entries non-negative and summing to one).\n",
    "  \n",
    "Here $OR = \\frac{p(w_1 \\mid x_1)}{p(w_0 \\mid x_1)} \\frac{p(w_0 \\mid x_0)}{p(w_1 \\mid x_0)}$, and $OR(c) = \\frac{p(w_1 \\mid x_1, c)}{p(w_0 \\mid x_1, c)} \\frac{p(w_0 \\mid x_0, c)}{p(w_1 \\mid x_0, c)}$.\n",
    "\n",
    "This notebook checks these results by simulating random data-generating processes (DGPs) and computing the truth value of each claim."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e4211b7-372f-4011-8c5c-649efc8462b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from itertools import product\n",
    "from pathlib import Path\n",
    "\n",
    "from pgmpy.models import BayesianNetwork\n",
    "from pgmpy.inference import VariableElimination\n",
    "from pgmpy.factors.discrete.CPD import TabularCPD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6721132-a7fa-4362-ae22-0d5d7b9d54d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pgmpy throws a warning when method `add_cpds` is used, so disable it\n",
    "# https://pgmpy.org/_modules/pgmpy/models/BayesianNetwork.html#BayesianNetwork.add_cpds\n",
    "import logging\n",
    "logger = logging.getLogger('pgmpy')\n",
    "logger.setLevel(level=logging.ERROR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36d7bffb-7fb3-4e23-9889-0ff25d42d732",
   "metadata": {},
   "outputs": [],
   "source": [
    "# EXP 1\n",
    "CASE = 'MCAR'     # choose either 'MCAR' or 'MAR'\n",
    "N = 1000000\n",
    "seed = 42        # random seed\n",
    "compute_num_bound = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59f13466-7631-4017-800a-e00ed6930655",
   "metadata": {},
   "outputs": [],
   "source": [
    "# EXP 2\n",
    "CASE = 'MCAR'     # choose either 'MCAR' or 'MAR'\n",
    "N = 20\n",
    "seed = 42        # random seed\n",
    "compute_num_bound = True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef583620-ab60-4e9d-81e2-90ae3d3329c0",
   "metadata": {},
   "source": [
    "## MCAR case: $X^{(1)} \\rightarrow X \\leftarrow R \\rightarrow W$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "822fcf94-cc55-4afa-a939-68ce8ab49b18",
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_ZI_consistency_edge(net):\n",
    "    \"\"\"\n",
    "    Add the zero-inflation consistency CPD p(x | x(1), r) to `net`:\n",
    "    X = X(1) when R = 1, and X = 0 when R = 0.\n",
    "    ---\n",
    "    R=   |     0     |     1     |\n",
    "    X1=  |   0 |   1 |   0 |   1 |\n",
    "    ---\n",
    "    X=0  | 1.0 | 1.0 | 1.0 | 0.0 |\n",
    "    X=1  | 0.0 | 0.0 | 0.0 | 1.0 |\n",
    "    ---\n",
    "    The table is written for binary R and X(1); their cardinalities are read\n",
    "    from `net` and passed on as the evidence cardinalities.\n",
    "\n",
    "    Input: network containing the nodes \"R\", \"X1\" and \"X\".\n",
    "    Output: the same network object, with the CPD of X added.\n",
    "    \"\"\"\n",
    "    evidence_card = [net.get_cardinality(\"R\"), net.get_cardinality(\"X1\")]\n",
    "    x_is_0 = [1.0, 1.0, 1.0, 0.0]\n",
    "    x_is_1 = [0.0, 0.0, 0.0, 1.0]\n",
    "    consistency_cpd = TabularCPD(\n",
    "        \"X\",\n",
    "        2,\n",
    "        [x_is_0, x_is_1],\n",
    "        evidence=[\"R\", \"X1\"],\n",
    "        evidence_card=evidence_card,\n",
    "    )\n",
    "    net.add_cpds(consistency_cpd)\n",
    "    return net\n",
    "\n",
    "\n",
    "def odd_ratio(p, c=None):\n",
    "    \"\"\"\n",
    "    Input:\n",
    "        If c=None, p is 2x2 matrix with p[i,j] = p(a=i|b=j).\n",
    "        If c=0,1, p is 2x2x2 matrix with p[i,j,k] = p(a=i|b=j,c=k).\n",
    "    Output: Odd-ratio p(a=1|b=1,c) / p(a=0|b=1,c) * p(a=0|b=0,c) / p(a=1|b=0,c)\n",
    "    \"\"\"\n",
    "    if c == None:\n",
    "        return p[1, 1] / p[0, 1] * p[0, 0] / p[1, 0]\n",
    "    elif isinstance(c, int):\n",
    "        return p[1, 1, c] / p[0, 1, c] * p[0, 0, c] / p[1, 0, c]\n",
    "\n",
    "\n",
    "class CaseMCAR:\n",
    "    def __init__(self, cards):\n",
    "        \"\"\"\n",
    "        Create the graph X(1) -> X <- R -> W\n",
    "        \"\"\"\n",
    "        self.vertices = [\"X1\", \"X\", \"R\", \"W\"]\n",
    "        self.cards = cards\n",
    "        self.edges = [(\"X1\", \"X\"), (\"R\", \"X\"), (\"R\", \"W\")]\n",
    "        self.net = BayesianNetwork(self.edges)\n",
    "        # self.net.to_daft(node_pos={'X': (0,0), 'X1': (0,1), 'R': (1,0), 'W': (2,0)}).render()\n",
    "        self.get_random_cpds()\n",
    "        # If |x| < epsilon, then x is consider \"zero\" (small noise)\n",
    "        self.eps = 1e-12\n",
    "        self.eps2 = 1e-7\n",
    "        self.lb, self.ub = None, None\n",
    "        self.num_lb, self.num_ub = None, None\n",
    "\n",
    "    def get_random_cpds(self):\n",
    "        \"\"\"\n",
    "        Get a random data-generation process in the model, i.e., satisfying 3 conditions\n",
    "        1. random cpds\n",
    "        2. applying the ZI consistency\n",
    "        3. W must associate to R\n",
    "        \"\"\"\n",
    "        while True:\n",
    "            self.net.get_random_cpds(n_states=self.cards, inplace=True)\n",
    "            self.net = add_ZI_consistency_edge(self.net)\n",
    "            infer = VariableElimination(self.net)\n",
    "            self.p_WR = infer.query([\"W\", \"R\"]).values\n",
    "            self.p_WX = infer.query([\"W\", \"X\"]).values\n",
    "            # Proxy assumption: W must assoc with X, e.g., OR cannot be 1\n",
    "            # Marginal assumption: W must assoc with X\n",
    "            if (odd_ratio(self.p_WR, c=None) != 1) and (\n",
    "                odd_ratio(self.p_WX, c=None) != 1\n",
    "            ):\n",
    "                break\n",
    "\n",
    "        # Compute the conditional distributions p(w|r) and p(w|x)\n",
    "        self.p_W_X = np.zeros_like(self.p_WX)\n",
    "        for x in range(self.cards[\"X\"]):\n",
    "            self.p_W_X[:, x] = self.p_WX[:, x] / np.sum(self.p_WX[:, x])\n",
    "        self.p_W_R = np.zeros_like(self.p_WR)\n",
    "        for r in range(self.cards[\"R\"]):\n",
    "            self.p_W_R[:, r] = self.p_WR[:, r] / np.sum(self.p_WR[:, r])\n",
    "\n",
    "    def is_C1(self):\n",
    "        \"\"\"\n",
    "        This function check truth value of result:\n",
    "        (C1) $p(w_0|r_1) = p(w_0|x_1)$\n",
    "        Output: Boolean\n",
    "        \"\"\"\n",
    "        ans = np.abs(self.p_W_R[0, 1] - self.p_W_X[0, 1]) < self.eps\n",
    "        return ans\n",
    "\n",
    "    def is_C2_bound_valid(self):\n",
    "        \"\"\"\n",
    "        This function check truth value of result (C2):\n",
    "            a) If $OR > 1$ then $\\max_x p(w_0 \\mid x) < p(w_0 \\mid r_0) \\leq 1$.  \n",
    "            b) If $OR < 1$ then $0 \\leq p(w_0 \\mid r_0) < \\min_x p(w_0 \\mid x)$.\n",
    "            where p(w_0 \\mid r_0) is the ground-truth proxy-indicator c.d.f.\n",
    "        - The bound is valid when OR(a, b) = True\n",
    "        - The bound is incorrect when OR(a, b) = False\n",
    "        - AND(a, b) is always False.\n",
    "\n",
    "        Output: Boolean value of OR(a, b)\n",
    "        \"\"\"\n",
    "        OR = odd_ratio(self.p_WX, c=None)\n",
    "        if OR > 1:\n",
    "            self.lb = np.max(self.p_W_X[0, :])\n",
    "            self.ub = 1\n",
    "        elif OR < 1:\n",
    "            self.lb = 0\n",
    "            self.ub = np.min(self.p_W_X[0, :])\n",
    "        else:  # This cannot happen as we ruled out this case when we generate DGP\n",
    "            print(\"Error: OR == 1!\")\n",
    "            return 2  # error code\n",
    "        return (self.lb < self.p_W_R[0, 0]) and (self.p_W_R[0, 0] < self.ub)\n",
    "\n",
    "    def id_p_w0_r1(self):\n",
    "        \"\"\"\n",
    "        p(w0|r1) = p(w0|x1) is identified\n",
    "        \"\"\"\n",
    "        self.hat_p_w0_r1 = self.p_WX[0, 1] / np.sum(self.p_WX[:, 1], axis=0)\n",
    "        return self.hat_p_w0_r1\n",
    "\n",
    "    def is_C2_bound_int(self, n):\n",
    "        \"\"\"\n",
    "        Is bound interior valid?\n",
    "        1. Selecting n values of p(w0|r0) inside the calculated bounds.\n",
    "        2. Creating p(W|R) and check if p(RX) = [p(W|R)]^{-1} p(WX) is a random matrix.\n",
    "        \"\"\"\n",
    "        assert self.lb < self.ub\n",
    "\n",
    "        self.id_p_w0_r1()\n",
    "        p_RXs = []\n",
    "        ps = np.linspace(self.lb + 1e-5, self.ub - 1e-5, n)\n",
    "        for p in ps:\n",
    "            p_W_R = np.asarray(\n",
    "                [\n",
    "                    [p, self.hat_p_w0_r1],\n",
    "                    [1 - p, 1 - self.hat_p_w0_r1],\n",
    "                ]\n",
    "            )\n",
    "            p_RX = np.linalg.inv(p_W_R) @ self.p_WX.reshape(2, -1)\n",
    "            p_RXs.append(p_RX.flatten())\n",
    "        p_RXs = np.asarray(p_RXs)\n",
    "        is_nonnegative = (p_RXs >= -self.eps2).all()  # can be a small negative noise\n",
    "        is_addingto1 = (np.max(np.abs(1 - np.sum(p_RXs, axis=1))) <= self.eps2).all()\n",
    "        return is_nonnegative and is_addingto1\n",
    "\n",
    "    def get_numerical_bound(self, seed:int, dgp:int):\n",
    "        \"\"\"\n",
    "        Computing numerical bound using autobound package from\n",
    "        the supplement of Duarte et al. (2023) (https://doi.org/10.1080/01621459.2023.2216909).\n",
    "        Newer version is from this Docker: docker run -p 8888:8888 -it gjardimduarte/autolab:v4\n",
    "        \"\"\"\n",
    "        from autobounds.causalProblem import causalProblem\n",
    "        from autobounds.DAG import DAG\n",
    "\n",
    "        # Create the graph\n",
    "        # the ZI counterfactual X(1) is denoted by Y1, as `autobound` cannot handle similar variable names yet\n",
    "        dag = DAG()\n",
    "        dag.from_structure(\"Y1 -> X, R -> X, R -> W\")\n",
    "        problem = causalProblem(dag)\n",
    "\n",
    "        # This works better\n",
    "        # Adding Zi consistency constraint: p(X=1 | R=0, X(1)=0) = 0 => p(X=1,R=0,other_vars) = 0\n",
    "        \"\"\"probability table for MCAR\n",
    "        W,Y1, X, R, prob\n",
    "        0, 0, 1, 0, 0.0\n",
    "        0, 1, 1, 0, 0.0\n",
    "        1, 0, 1, 0, 0.0\n",
    "        1, 1, 1, 0, 0.0\n",
    "        \"\"\"\n",
    "        cartesian_prod_WX1 = np.vstack(\n",
    "            list(product(range(self.cards[\"W\"]), range(self.cards[\"X1\"])))\n",
    "        )\n",
    "        n = len(cartesian_prod_WX1)\n",
    "        data_ZI_0 = pd.DataFrame(  # probability table\n",
    "            np.hstack(\n",
    "                (\n",
    "                    cartesian_prod_WX1,\n",
    "                    np.ones((n, 1)),\n",
    "                    np.zeros((n, 1)),\n",
    "                    np.zeros((n, 1)),\n",
    "                )\n",
    "            ),\n",
    "            columns=[\"W\", \"Y1\", \"X\", \"R\", \"prob\"],\n",
    "        )\n",
    "        data_ZI_0 = io.StringIO(data_ZI_0.to_csv(index=False))\n",
    "        problem.load_data(data_ZI_0, optimize=True)\n",
    "        # Adding Zi consistency constraint: p(X=x | R=1, X(1)=x) = 1 => p(X=1,R=1,X1=0, other_vars) = 0\n",
    "        data_ZI_1 = pd.DataFrame(  # probability table\n",
    "            [[1, 1, 0, 0, 0.0], [1, 1, 0, 1, 0.0]],\n",
    "            columns=[\"X\", \"R\", \"Y1\", \"W\", \"prob\"],\n",
    "        )\n",
    "        data_ZI_1 = io.StringIO(data_ZI_1.to_csv(index=False))\n",
    "        problem.load_data(data_ZI_1, optimize=True)\n",
    "\n",
    "        # Axioms of probability constraints\n",
    "        problem.add_prob_constraints()  # sum to 1\n",
    "        for para in problem.parameters:  # non-negative\n",
    "            problem.add_constraint([(para[0], [para[1]])], symbol=\">=\")\n",
    "\n",
    "        # Adding observational data\n",
    "        data_p_WX = io.StringIO(\n",
    "            pd.DataFrame(\n",
    "                {\"W\": [0, 0, 1, 1], \"X\": [0, 1, 0, 1], \"prob\": self.p_WX.flatten()}\n",
    "            ).to_csv(index=False)\n",
    "        )\n",
    "        problem.load_data(data_p_WX, optimize=True)\n",
    "\n",
    "        # Adding estimands - p(w=0|r=0) = p(W(R=0)=0)\n",
    "        problem.set_estimand(problem.query(\"W(R=0)=0\"))\n",
    "\n",
    "        # Writing optimization programs\n",
    "        prog = problem.write_program()\n",
    "\n",
    "        # Writing problem file to solve directly with scip in terminal\n",
    "        for sense in ['max','min']:\n",
    "            Path(f\"./dgp{dgp}\").mkdir(parents=True, exist_ok=True)\n",
    "            prog.to_pip(f\"./dgp{dgp}/seed{seed}_dgp{dgp}_{sense}.pip\", sense=sense)\n",
    "        return"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51eaafd5-1b45-435b-827b-c71382b29f2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "if CASE == \"MCAR\":\n",
    "    print(\"MCAR case. Graph: X(1) -> X <- R -> W\")\n",
    "    cards = {\"X1\": 2, \"X\": 2, \"R\": 2, \"W\": 2}\n",
    "    case_mcar = CaseMCAR(cards)\n",
    "\n",
    "    np.random.seed(seed)\n",
    "    if compute_num_bound:\n",
    "        # Experiment 2: BOUND CORROBORATION VIA NUMERICAL METHODS\n",
    "        # 1. Compute analytical bounds for N DGPs\n",
    "        # 2. Write polynomial program to be solved for numerical bounds\n",
    "        # 3. You should compare the bounds manually by yourself\n",
    "        N = 50  # NOTE(review): overrides the N set in the config cell\n",
    "        file_name = f\"mcar_compare-seed{seed}-N{N}.csv\"\n",
    "        results = []\n",
    "        for dgp in tqdm(range(N)):\n",
    "            case_mcar.get_random_cpds()\n",
    "            # compute analytic bounds (stored in case_mcar.lb / case_mcar.ub)\n",
    "            case_mcar.is_C2_bound_valid()\n",
    "            # write program to solve for num bounds using scip directly in terminal\n",
    "            case_mcar.get_numerical_bound(seed, dgp)\n",
    "            results.append(\n",
    "                [seed, dgp, case_mcar.lb, case_mcar.ub, None, None, case_mcar.p_W_R[0, 0]]\n",
    "            )\n",
    "        # save results; num_lb / num_ub stay empty until filled from the scip solutions\n",
    "        pd.DataFrame(\n",
    "            np.vstack(results, dtype=object),\n",
    "            columns=[\"seed\", \"dgp\", \"lb\", \"ub\", \"num_lb\", \"num_ub\", \"p_w0_r0\"],\n",
    "        ).to_csv(file_name, mode=\"w\", index=False)\n",
    "    else:\n",
    "        # Experiment 1: BOUND VALIDITY\n",
    "        # 1. Check claims of ZI MCAR theorem, here denoted as C1, C2a, C2b\n",
    "        file_name = f\"mcar_results-seed{seed}-N{int(N/1000)}k.txt\"\n",
    "        K = min(200, N)  # Save file every K steps\n",
    "        columns = [\n",
    "            \"seed\",\n",
    "            \"DGP_#\",\n",
    "            \"is_C1\",\n",
    "            \"is_C2_vlid\",\n",
    "            \"is_C2_bint\",\n",
    "            \"lb\",\n",
    "            \"ub\",\n",
    "            *[f\"p_w0_r{i}\" for i in range(cards['R'])],\n",
    "            *[f\"p_w0_x{i}\" for i in range(cards['X'])],\n",
    "        ]\n",
    "        with open(file_name, \"w\") as f:\n",
    "            f.write(\",\".join(columns) + \"\\n\")\n",
    "        # Process the N DGPs in batches of K; the last batch is shorter when K does not divide N\n",
    "        for start in tqdm(range(0, N, K)):\n",
    "            results = []\n",
    "            for dgp in range(start, min(start + K, N)):\n",
    "                case_mcar.get_random_cpds()\n",
    "\n",
    "                re_C1 = case_mcar.is_C1()\n",
    "                re_C2_vlid = case_mcar.is_C2_bound_valid()\n",
    "                # is_C2_bound_int identifies p(w0|r1) itself before probing the bound interior\n",
    "                re_C2_bint = case_mcar.is_C2_bound_int(20)\n",
    "\n",
    "                results.append(\n",
    "                    [\n",
    "                        seed,\n",
    "                        dgp,\n",
    "                        re_C1,\n",
    "                        re_C2_vlid,\n",
    "                        re_C2_bint,\n",
    "                        case_mcar.lb,\n",
    "                        case_mcar.ub,\n",
    "                        *case_mcar.p_W_R[0, :],\n",
    "                        *case_mcar.p_W_X[0, :],\n",
    "                    ]\n",
    "                )\n",
    "            # save results\n",
    "            pd.DataFrame(\n",
    "                np.vstack(results, dtype=object), columns=columns\n",
    "            ).to_csv(file_name, mode=\"a\", index=False, header=False)\n",
    "\n",
    "        results = pd.read_csv(file_name)\n",
    "        re_C1 = np.sum(results[\"is_C1\"]).astype(int)\n",
    "        re_C2_vlid = np.sum(results[\"is_C2_vlid\"]).astype(int)\n",
    "        re_C2_bint = np.sum(results[\"is_C2_bint\"]).astype(int)\n",
    "        print(\n",
    "            f\"(C1) p(w0|r1) is identified and equals p(w0|x1) for {re_C1}/{N} times ({int(re_C1*100/N)}%).\"\n",
    "        )\n",
    "        print(\n",
    "            f\"(C2_v) The bound is valid for {re_C2_vlid}/{N} times ({int(re_C2_vlid*100/N)}%).\"\n",
    "        )\n",
    "        print(\n",
    "            f\"(C2_i) The bound interior is valid for {re_C2_bint}/{N} times ({int(re_C2_bint*100/N)}%).\"\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4eb9699-86df-4c80-89a9-3f54c3700000",
   "metadata": {},
   "source": [
    "## MAR case: $R \\leftarrow C \\rightarrow X^{(1)} \\rightarrow X \\leftarrow R \\rightarrow W$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "905356c8-9030-44db-8351-6e0819c0118a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CaseMAR:\n",
    "    def __init__(self, cards):\n",
    "        \"\"\"\n",
    "        Create the graph R <- C -> X(1) -> X <- R -> W\n",
    "        \"\"\"\n",
    "        self.vertices = [\"X1\", \"X\", \"R\", \"W\", \"C\"]\n",
    "        self.cards = cards\n",
    "        print(self.cards)\n",
    "        self.edges = [(\"X1\", \"X\"), (\"R\", \"X\"), (\"R\", \"W\"), (\"C\", \"X1\"), (\"C\", \"R\")]\n",
    "        self.net = BayesianNetwork(self.edges)\n",
    "        # self.net.to_daft(node_pos={'X': (0,0), 'X1': (0,1), 'R': (1,0), 'W': (2,0), 'C': (1,1)}).render()\n",
    "        self.get_random_cpds()\n",
    "        # If |x| < epsilon, then x is consider \"zero\" (small noise)\n",
    "        self.eps = 1e-12\n",
    "        self.eps2 = 1e-7\n",
    "        self.lb, self.ub = None, None\n",
    "        self.num_lb, self.num_ub = None, None\n",
    "\n",
    "    def get_random_cpds(self):\n",
    "        \"\"\"\n",
    "        Get a random data-generation process in the model, i.e., satisfying 3 conditions\n",
    "        1. random cpds\n",
    "        2. applying the ZI consistency\n",
    "        3. W must associate to R\n",
    "        \"\"\"\n",
    "        while True:\n",
    "            self.net.get_random_cpds(n_states=self.cards, inplace=True)\n",
    "            self.net = add_ZI_consistency_edge(self.net)\n",
    "            infer = VariableElimination(self.net)\n",
    "            self.p_WR = infer.query([\"W\", \"R\"]).values\n",
    "            self.p_WXC = infer.query([\"W\", \"X\", \"C\"]).values\n",
    "            # Proxy assumption: W must assoc with X, e.g., OR cannot be 1\n",
    "            # Marginal assumption: W must assoc with X conditional on C\n",
    "            conditions = [odd_ratio(self.p_WXC, c) != 1 for c in range(self.cards[\"C\"])]\n",
    "            if odd_ratio(self.p_WR, c=None) != 1 and np.all(conditions):\n",
    "                break\n",
    "\n",
    "        # Compute the conditional distributions p(w|r) and p(w|x,c)\n",
    "        self.p_W_XC = np.zeros_like(self.p_WXC)\n",
    "        for c in range(self.cards[\"C\"]):\n",
    "            for x in range(self.cards[\"X\"]):\n",
    "                self.p_W_XC[:, x, c] = self.p_WXC[:, x, c] / np.sum(self.p_WXC[:, x, c])\n",
    "        self.p_W_R = np.zeros_like(self.p_WR)\n",
    "        for r in range(self.cards[\"R\"]):\n",
    "            self.p_W_R[:, r] = self.p_WR[:, r] / np.sum(self.p_WR[:, r])\n",
    "\n",
    "    def is_C3M5(self):\n",
    "        \"\"\"\n",
    "        Check truth value of C3 and M5\n",
    "        - \"C3 and M5 are true\" when p(w0 | r1) =  p(w0 | x1, c) for all c\n",
    "        - False otherwise.\n",
    "        Input: BayesianNet of this case\n",
    "        output: Boolean value of AND(C3, M5)\n",
    "        \"\"\"\n",
    "        max_val = np.max([*self.p_W_XC[0, 1, :], self.p_W_R[0, 1]])\n",
    "        min_val = np.min([*self.p_W_XC[0, 1, :], self.p_W_R[0, 1]])\n",
    "        return (max_val - min_val) < self.eps\n",
    "\n",
    "    def is_M3(self):\n",
    "        \"\"\"\n",
    "        Check truth value of result M3.\n",
    "        Output: Boolean value of M3.\n",
    "        \"\"\"\n",
    "        # OR could be equal 1, so need to check both\n",
    "        is_M2_smaller_1 = [odd_ratio(self.p_WXC, c) < 1 for c in range(self.cards[\"C\"])]\n",
    "        is_M2_smaller_1 = np.all(is_M2_smaller_1)\n",
    "        is_M2_larger_1 = [odd_ratio(self.p_WXC, c) > 1 for c in range(self.cards[\"C\"])]\n",
    "        is_M2_larger_1 = np.all(is_M2_larger_1)\n",
    "        return is_M2_smaller_1 or is_M2_larger_1\n",
    "\n",
    "    def is_C4_bound_valid(self):\n",
    "        \"\"\"\n",
    "        This function check truth value of result (C4):\n",
    "            a) If $OR(c) > 1$ then $\\max_{x,c} p(w_0 \\mid x, c) < p(w_0 \\mid r_0) \\leq 1$.  \n",
    "            b) If $OR(c) < 1$ then $0 \\leq p(w_0 \\mid r_0) < \\min_{x,c} p(w_0 \\mid x, c)$.\n",
    "            where p(w_0 \\mid r_0) is the ground-truth proxy-indicator c.d.f.\n",
    "        - The bound is valid when OR(a, b) = True\n",
    "        - The bound is incorrect when OR(a, b) = False\n",
    "        - AND(a, b) is always False.\n",
    "\n",
    "        Output: Boolean value of OR(a, b)\n",
    "        \"\"\"\n",
    "        OR = odd_ratio(self.p_WXC, c=1)  # Assuming M2 is correct\n",
    "        if OR > 1:\n",
    "            self.lb = np.max(self.p_W_XC[0, :, :])\n",
    "            self.ub = 1\n",
    "        elif OR < 1:\n",
    "            self.lb = 0\n",
    "            self.ub = np.min(self.p_W_XC[0, :, :])\n",
    "        else:  # This cannot happen as we ruled out this case when we generate DGP\n",
    "            print(\"Error: OR == 1!\")\n",
    "            return 2  # error code\n",
    "        return (self.lb < self.p_W_R[0, 0]) and (self.p_W_R[0, 0] < self.ub)\n",
    "\n",
    "    def id_p_w0_r1(self):\n",
    "        \"\"\"\n",
    "        p(w0|r1) = p(w0|x1) is identified\n",
    "        \"\"\"\n",
    "        self.hat_p_w0_r1 = self.p_WXC[0, 1, :] / np.sum(self.p_WXC[:, 1, :], axis=0)\n",
    "        self.hat_p_w0_r1 = self.hat_p_w0_r1.mean()\n",
    "        return self.hat_p_w0_r1\n",
    "\n",
    "    def is_C4_bound_int(self, n):\n",
    "        \"\"\"\n",
    "        Is the bound interior valid?\n",
    "        1. Selecting n values of p(w0|r0) inside the calculated bounds.\n",
    "        2. Creating p(W|R) and check if p(RX) = [p(W|R)]^{-1} p(WXC) is a random matrix.\n",
    "        \"\"\"\n",
    "        assert self.lb < self.ub\n",
    "\n",
    "        self.id_p_w0_r1()\n",
    "        p_RXCs = []\n",
    "        ps = np.linspace(self.lb + 1e-5, self.ub - 1e-5, n)\n",
    "        for p in ps:\n",
    "            p_W_R = np.asarray(\n",
    "                [\n",
    "                    [p, self.hat_p_w0_r1],\n",
    "                    [1 - p, 1 - self.hat_p_w0_r1],\n",
    "                ]\n",
    "            )\n",
    "            p_RXC = np.linalg.inv(p_W_R) @ self.p_WXC.reshape(2, -1)\n",
    "            p_RXCs.append(p_RXC.flatten())\n",
    "        p_RXCs = np.asarray(p_RXCs)\n",
    "        is_nonnegative = (p_RXCs >= -self.eps2).all()  # can be a small negative noise\n",
    "        is_addingto1 = (np.max(np.abs(1 - np.sum(p_RXCs, axis=1))) <= self.eps2).all()\n",
    "        return is_nonnegative and is_addingto1\n",
    "\n",
    "    def get_numerical_bound(self, seed:int, dgp:int):\n",
    "        \"\"\"\n",
    "        Computing numerical bound using autobound package from\n",
    "        the supplement of Duarte et al. (2023) (https://doi.org/10.1080/01621459.2023.2216909).\n",
    "        Newer version is from this Docker: docker run -p 8888:8888 -it gjardimduarte/autolab:v4\n",
    "        \"\"\"\n",
    "        from autobounds.causalProblem import causalProblem\n",
    "        from autobounds.DAG import DAG\n",
    "\n",
    "        # Create the graph\n",
    "        # the ZI counterfactual X(1) is denoted by Y1, as `autobound` cannot handle similar variable names yet\n",
    "        dag = DAG()\n",
    "        dag.from_structure(\"C -> Y1, C -> R, Y1 -> X, R -> X, R -> W\")\n",
    "        problem = causalProblem(dag)\n",
    "\n",
    "        # This works better\n",
    "        # Adding Zi consistency constraint: p(X=1 | R=0, X(1)=0) = 0 => p(X=1,R=0,other_vars) = 0\n",
    "        \"\"\"probability table for MCAR\n",
    "        W,Y1, X, R, prob\n",
    "        0, 0, 1, 0, 0.0\n",
    "        0, 1, 1, 0, 0.0\n",
    "        1, 0, 1, 0, 0.0\n",
    "        1, 1, 1, 0, 0.0\n",
    "        \"\"\"\n",
    "        cartesian_prod_WX1C = np.vstack(\n",
    "            list(\n",
    "                product(\n",
    "                    range(self.cards[\"W\"]),\n",
    "                    range(self.cards[\"X1\"]),\n",
    "                    range(self.cards[\"C\"]),\n",
    "                )\n",
    "            )\n",
    "        )\n",
    "        n = len(cartesian_prod_WX1C)\n",
    "        data_ZI_0 = pd.DataFrame(  # probability table\n",
    "            np.hstack(\n",
    "                (\n",
    "                    np.ones((n, 1), dtype=int),\n",
    "                    np.zeros((n, 1), dtype=int),\n",
    "                    cartesian_prod_WX1C,\n",
    "                    np.zeros((n, 1), dtype=float),\n",
    "                ),\n",
    "                dtype=object,\n",
    "            ),\n",
    "            columns=[\"X\", \"R\", \"W\", \"Y1\", \"C\", \"prob\"],\n",
    "        )\n",
    "        data_ZI_0 = io.StringIO(data_ZI_0.to_csv(index=False))\n",
    "        problem.load_data(data_ZI_0, optimize=True)\n",
    "        # Adding Zi consistency constraint: p(X=x | R=1, X(1)=x) = 1 => p(X=1,R=1,X1=0,other_vars) = 0\n",
    "        cartesian_prod_WC = np.vstack(\n",
    "            list(\n",
    "                product(\n",
    "                    range(self.cards[\"W\"]),\n",
    "                    range(self.cards[\"C\"]),\n",
    "                )\n",
    "            )\n",
    "        )\n",
    "        n = len(cartesian_prod_WC)\n",
    "        data_ZI_1 = pd.DataFrame(  # p(X=1,R=1,X1=0,other_vars) = 0\n",
    "            np.hstack(\n",
    "                (\n",
    "                    np.ones((n, 1), dtype=int),\n",
    "                    np.ones((n, 1), dtype=int),\n",
    "                    np.zeros((n, 1), dtype=int),\n",
    "                    cartesian_prod_WC,\n",
    "                    np.zeros((n, 1), dtype=float),\n",
    "                ),\n",
    "                dtype=object,\n",
    "            ),\n",
    "            columns=[\"X\", \"R\", \"Y1\", \"W\", \"C\", \"prob\"],\n",
    "        )\n",
    "        problem.load_data(io.StringIO(data_ZI_1.to_csv(index=False)), optimize=True)\n",
    "        # p(X=0,R=1,X1=1,other_vars) = 0\n",
    "        data_ZI_1.loc[:, [\"X\", \"Y1\"]] = 1 - data_ZI_1.loc[:, [\"X\", \"Y1\"]]\n",
    "        problem.load_data(io.StringIO(data_ZI_1.to_csv(index=False)), optimize=True)\n",
    "\n",
    "        # Axioms of probability constraints\n",
    "        problem.add_prob_constraints()  # sum to 1\n",
    "        for para in problem.parameters:  # non-negative\n",
    "            problem.add_constraint([(para[0], [para[1]])], symbol=\">=\")\n",
    "\n",
    "        # Adding observational data\n",
    "        cartesian_prod_WXC = np.vstack(\n",
    "            list(\n",
    "                product(\n",
    "                    range(self.cards[\"W\"]),\n",
    "                    range(self.cards[\"X\"]),\n",
    "                    range(self.cards[\"C\"]),\n",
    "                )\n",
    "            )\n",
    "        )\n",
    "        data_p_WXC = pd.DataFrame(\n",
    "            np.hstack((cartesian_prod_WXC, self.p_WXC.reshape(-1, 1)), dtype=object),\n",
    "            columns=[\"W\", \"X\", \"C\", \"prob\"],\n",
    "        )\n",
    "        data_p_WXC = io.StringIO(data_p_WXC.to_csv(index=False))\n",
    "        problem.load_data(data_p_WXC, optimize=True)\n",
    "\n",
    "        # Adding estimands - p(w=0|r=0) = p(W(R=0)=0)\n",
    "        problem.set_estimand(problem.query(\"W(R=0)=0\"))\n",
    "\n",
    "        # Writing optimization programs\n",
    "        prog = problem.write_program()\n",
    "\n",
    "        # Writing problem file to solve directly with scip in terminal\n",
    "        for sense in ['max','min']:\n",
    "            Path(f\"./dgp{dgp}\").mkdir(parents=True, exist_ok=True)\n",
    "            prog.to_pip(f\"./dgp{dgp}/seed{seed}_dgp{dgp}_{sense}.pip\", sense=sense)\n",
    "        return"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa99db99-c167-4e73-b520-6ceb90d71205",
   "metadata": {},
   "outputs": [],
   "source": [
    "if CASE == \"MAR\":\n",
    "    print(\"MAR case. Graph: R <- C -> X(1) -> X <- R -> W\")\n",
    "    cards = {\"X1\": 2, \"X\": 2, \"R\": 2, \"W\": 2, \"C\": 2}\n",
    "    case_mar = CaseMAR(cards)\n",
    "\n",
    "    # Seed once so that every DGP drawn below is reproducible from `seed`.\n",
    "    np.random.seed(seed)\n",
    "    if compute_num_bound:\n",
    "        # Experiment 2: BOUND CORROBORATION VIA NUMERICAL METHODS\n",
    "        # 1. Compute analytical bounds for a small number of DGPs\n",
    "        # 2. Write polynomial programs (.pip files) to be solved for numerical bounds\n",
    "        # 3. Compare the analytic and numerical bounds manually\n",
    "        #\n",
    "        # NOTE: the RNG is deliberately NOT re-seeded here. Re-seeding with a\n",
    "        # hard-coded 42 made the data independent of `seed`, while `seed` was\n",
    "        # still written to the file names and to the \"seed\" column.\n",
    "        n_num_dgps = 50  # own name so the global N used by Experiment 1 is not clobbered\n",
    "        file_name = f\"mar_compare-seed{seed}-N{n_num_dgps}.csv\"\n",
    "        results = []\n",
    "        for dgp in tqdm(range(n_num_dgps)):\n",
    "            case_mar.get_random_cpds()\n",
    "            # compute analytic bounds (case_mar.lb / case_mar.ub are read below)\n",
    "            _ = case_mar.is_C4_bound_valid()\n",
    "            # write program to solve for num bounds using scip directly in terminal\n",
    "            _ = case_mar.get_numerical_bound(seed, dgp)\n",
    "            # num_lb / num_ub are None placeholders; presumably filled in by hand from the SCIP solutions\n",
    "            results.append(\n",
    "                [seed, dgp, case_mar.lb, case_mar.ub, None, None, case_mar.p_W_R[0, 0]]\n",
    "            )\n",
    "        # save results\n",
    "        pd.DataFrame(\n",
    "            results, columns=[\"seed\", \"dgp\", \"lb\", \"ub\", \"num_lb\", \"num_ub\", \"p_w0_r0\"]\n",
    "        ).to_csv(file_name, mode=\"w\", index=False)\n",
    "    else:\n",
    "        # Experiment 1: BOUND VALIDITY\n",
    "        # Check the claims of the ZI MAR theorem (C3, C4a, C4b) and the\n",
    "        # constraints M3, M5 on N random DGPs.\n",
    "        n_dgps = int(N)\n",
    "        file_name = f\"mar_results-seed{seed}-N{int(N/1000)}k.txt\"\n",
    "        K = max(1, min(200, n_dgps))  # Save file every K steps (>= 1 so range() below is valid)\n",
    "        columns = [\n",
    "            \"seed\",\n",
    "            \"DGP_#\",\n",
    "            \"is_C3M5\",\n",
    "            \"is_M3\",\n",
    "            \"is_C4_vlid\",\n",
    "            \"is_C4_bint\",\n",
    "            \"lb\",\n",
    "            \"ub\",\n",
    "            *[f\"p_w0_r{i}\" for i in range(cards[\"R\"])],\n",
    "            *[\n",
    "                f\"p_w0_x{i}c{j}\"\n",
    "                for i, j in product(range(cards[\"X\"]), range(cards[\"C\"]))\n",
    "            ],\n",
    "        ]\n",
    "        # Write the header once; batches are appended below. `with` closes the file.\n",
    "        with open(file_name, \"w\") as f:\n",
    "            f.write(\",\".join(columns) + \"\\n\")\n",
    "        # Process DGPs in batches of K. The last batch is shorter when N is not a\n",
    "        # multiple of K, so all N DGPs are simulated and saved.\n",
    "        for start in tqdm(range(0, n_dgps, K)):\n",
    "            batch_rows = []\n",
    "            for dgp in range(start, min(start + K, n_dgps)):\n",
    "                case_mar.get_random_cpds()\n",
    "\n",
    "                re_C3M5 = case_mar.is_C3M5()\n",
    "                re_M3 = case_mar.is_M3()\n",
    "                re_C4_vlid = case_mar.is_C4_bound_valid()\n",
    "                case_mar.id_p_w0_r1()\n",
    "                # NOTE(review): 20 presumably = number of interior q(w0|r0) values checked; confirm\n",
    "                re_C4_bint = case_mar.is_C4_bound_int(20)\n",
    "\n",
    "                batch_rows.append(\n",
    "                    [\n",
    "                        seed,\n",
    "                        dgp,  # global DGP index (previously the within-batch index)\n",
    "                        re_C3M5,\n",
    "                        re_M3,\n",
    "                        re_C4_vlid,\n",
    "                        re_C4_bint,\n",
    "                        case_mar.lb,\n",
    "                        case_mar.ub,\n",
    "                        *case_mar.p_W_R[0, :],\n",
    "                        *case_mar.p_W_XC[0, :, :].flatten(),\n",
    "                    ]\n",
    "                )\n",
    "            # append this batch to the results file (header was written above)\n",
    "            pd.DataFrame(batch_rows, columns=columns).to_csv(\n",
    "                file_name, mode=\"a\", index=False, header=False\n",
    "            )\n",
    "\n",
    "        results = pd.read_csv(file_name)\n",
    "        n_total = len(results)  # rows actually written to the file\n",
    "        denom = max(n_total, 1)  # avoid division by zero when N == 0\n",
    "\n",
    "        re_C3M5 = int(results[\"is_C3M5\"].sum())\n",
    "        re_M3 = int(results[\"is_M3\"].sum())\n",
    "        re_C4_vlid = int(results[\"is_C4_vlid\"].sum())\n",
    "        re_C4_bint = int(results[\"is_C4_bint\"].sum())\n",
    "        print(\n",
    "            f\"(C3,M5) p(w0|r1) is identified and equals p(w0|x1, c) for all c, for {re_C3M5}/{n_total} times ({int(re_C3M5*100/denom)}%).\"\n",
    "        )\n",
    "        print(\n",
    "            f\"(M3) The odd-ratio marginal constraint is correct for {re_M3}/{n_total} times ({int(re_M3*100/denom)}%)\"\n",
    "        )\n",
    "        print(\n",
    "            f\"(C4_v) The bound is valid for {re_C4_vlid}/{n_total} times ({int(re_C4_vlid*100/denom)}%).\"\n",
    "        )\n",
    "        print(\n",
    "            f\"(C4_i) The bound interior is valid for {re_C4_bint}/{n_total} times ({int(re_C4_bint*100/denom)}%).\"\n",
    "        )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
