{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a0e67c18-e1c7-4be0-be5e-ded76887fdc8",
   "metadata": {},
   "source": [
    "This notebook illustrates the incompatibility issue discussed in the paper https://openreview.net/forum?id=7MYznm5Kp2.\n",
    "\n",
    "We create a DGP (a Bayesian network) for the ZI MCAR graph in Figure 2(a) and obtain the ground-truth $p(W|R)$. We then sample $100000$ data points $(W_i, X_i)$ from this DGP and estimate $\\hat{p}(W,X)$ by counting (the MLE for categorical data). Finally, we compute $\\hat{p}(R,X)$ from $\\hat{p}(W,X)$ and the true $p(W|R)$ using the Kuroki-Pearl matrix inversion equation. The resulting $\\hat{p}(R,X)$ has a negative element, so it is not a valid probability distribution:\n",
    "```\n",
    "[ 0.51789683 -0.08300896]\n",
    "[ 0.05589317  0.50921896]\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "b2e30271-1bca-4b4b-a070-fddd601e7881",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:pgmpy:Replacing existing CPD for X\n",
      "WARNING:pgmpy:Replacing existing CPD for X1\n",
      "WARNING:pgmpy:Replacing existing CPD for X\n",
      "WARNING:pgmpy:Replacing existing CPD for R\n",
      "WARNING:pgmpy:Replacing existing CPD for W\n",
      "WARNING:pgmpy:Replacing existing CPD for X\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MCAR case. Graph: X(1) -> X <- R -> W\n",
      "\n",
      "\n",
      "True p(W,X):\n",
      " [[0.42643891 0.31215362]\n",
      " [0.14620603 0.11520144]]\n",
      "True p(W|R):\n",
      " [[0.74919143 0.73043156]\n",
      " [0.25080857 0.26956844]]\n",
      "True p(R,X):\n",
      " [[0.43502295 0.        ]\n",
      " [0.13762199 0.42735506]]\n",
      "Computed p(R,X) via matrix inversion using true p(W,X) and true p(W|R):\n",
      " [[ 4.35022949e-01 -1.81411279e-16]\n",
      " [ 1.37621992e-01  4.27355059e-01]]\n",
      "Element-wise difference to true p(R,X):\n",
      " [[1.33226763e-15 1.81411279e-16]\n",
      " [1.38777878e-16 4.99600361e-16]]\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "af727fb8f7d04551b02b8a5c1e84c3ce",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Estimated p(W,X):\n",
      " [[0.42883 0.30976]\n",
      " [0.14496 0.11645]]\n",
      "Computed p(R,X) via matrix inversion using estimated p(W,X) and true p(W|R):\n",
      " [[ 0.51789683 -0.08300896]\n",
      " [ 0.05589317  0.50921896]]\n",
      "Element-wise difference to true p(R,X):\n",
      " [[0.08287388 0.08300896]\n",
      " [0.08172882 0.0818639 ]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from pgmpy.models import BayesianNetwork\n",
    "from pgmpy.inference import VariableElimination\n",
    "from pgmpy.factors.discrete.CPD import TabularCPD\n",
    "\n",
    "def add_ZI_consistency_edge(net):\n",
    "    \"\"\"\n",
    "    Attach the deterministic ZI consistency CPD p(x | x(1), r) to `net`.\n",
    "\n",
    "    Per the table below, X copies X1 when R=1 and is forced to 0 when R=0.\n",
    "    ---\n",
    "    R=   |     0     |     1     |\n",
    "    X1=  |   0 |   1 |   0 |   1 |\n",
    "    ---\n",
    "    X=0  | 1.0 | 1.0 | 1.0 | 0.0 |\n",
    "    X=1  | 0.0 | 0.0 | 0.0 | 1.0 |\n",
    "    ---\n",
    "\n",
    "    Input:\n",
    "        net: network with nodes \"X\", \"X1\" and \"R\" that provides get_cardinality() and add_cpds().\n",
    "    Output: the same `net` object, after add_cpds() has been called with the CPD of X.\n",
    "    Raises:\n",
    "        ValueError: if R or X1 is not binary (the table is hard-coded for two states each).\n",
    "    \"\"\"\n",
    "    card_R = net.get_cardinality(\"R\")\n",
    "    card_X1 = net.get_cardinality(\"X1\")\n",
    "    # Fail early with a clear message instead of handing TabularCPD a mis-shaped table.\n",
    "    if card_R != 2 or card_X1 != 2:\n",
    "        raise ValueError(\n",
    "            f\"ZI consistency CPD requires binary R and X1, got |R|={card_R}, |X1|={card_X1}\"\n",
    "        )\n",
    "    cpd_X = TabularCPD(\n",
    "        \"X\",\n",
    "        2,\n",
    "        [[1.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]],\n",
    "        evidence=[\"R\", \"X1\"],\n",
    "        evidence_card=[card_R, card_X1],\n",
    "    )\n",
    "    net.add_cpds(cpd_X)\n",
    "    return net\n",
    "\n",
    "\n",
    "def odd_ratio(p, c=None):\n",
    "    \"\"\"\n",
    "    Odds ratio of a 2x2 table, optionally taken within one slice of a 2x2x2 table.\n",
    "\n",
    "    Input:\n",
    "        If c=None, p is 2x2 matrix with p[i,j] = p(a=i|b=j).\n",
    "        If c=0,1, p is 2x2x2 matrix with p[i,j,k] = p(a=i|b=j,c=k).\n",
    "        c may be a Python int or a NumPy integer.\n",
    "        A joint table p(a,b) gives the same value, since the per-column normalisers cancel.\n",
    "    Output: Odd-ratio p(a=1|b=1,c) / p(a=0|b=1,c) * p(a=0|b=0,c) / p(a=1|b=0,c)\n",
    "    Raises:\n",
    "        TypeError: if c is neither None nor an integer (previously this silently returned None).\n",
    "    \"\"\"\n",
    "    # Identity check: `== None` is the wrong test and is elementwise for array-likes.\n",
    "    if c is None:\n",
    "        table = p\n",
    "    elif isinstance(c, (int, np.integer)):\n",
    "        table = p[:, :, c]\n",
    "    else:\n",
    "        raise TypeError(f\"c must be None or an integer, got {type(c).__name__}\")\n",
    "    return table[1, 1] / table[0, 1] * table[0, 0] / table[1, 0]\n",
    "\n",
    "\n",
    "class CaseMCAR:\n",
    "    \"\"\"\n",
    "    Random data-generating process on the graph X(1) -> X <- R -> W.\n",
    "\n",
    "    After get_random_cpds() the instance holds the joint tables p_WR, p_WX and p_RX\n",
    "    (axes ordered as in the attribute name) and the conditional table p_W_R with\n",
    "    p_W_R[w, r] = p(W=w | R=r).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, cards):\n",
    "        \"\"\"\n",
    "        Create the graph X(1) -> X <- R -> W and draw a first random set of CPDs.\n",
    "\n",
    "        Input:\n",
    "            cards: dict mapping each of \"X1\", \"X\", \"R\", \"W\" to its number of states.\n",
    "        \"\"\"\n",
    "        self.vertices = [\"X1\", \"X\", \"R\", \"W\"]\n",
    "        self.cards = cards\n",
    "        self.edges = [(\"X1\", \"X\"), (\"R\", \"X\"), (\"R\", \"W\")]\n",
    "        self.net = BayesianNetwork(self.edges)\n",
    "        # self.net.to_daft(node_pos={'X': (0,0), 'X1': (0,1), 'R': (1,0), 'W': (2,0)}).render()\n",
    "        # Tolerances come first: get_random_cpds() reads eps2 for its association test.\n",
    "        # If |x| < epsilon, then x is considered \"zero\" (small noise)\n",
    "        self.eps = 1e-15\n",
    "        self.eps2 = 1e-12\n",
    "        self.lb, self.ub = None, None\n",
    "        self.num_lb, self.num_ub = None, None\n",
    "        self.get_random_cpds()\n",
    "\n",
    "    def _joint(self, infer, variables):\n",
    "        \"\"\"Query the joint over `variables` and return it with axes in the requested order.\"\"\"\n",
    "        factor = infer.query(variables)\n",
    "        # The axes of factor.values follow factor.variables; align them explicitly\n",
    "        # rather than assuming they come back in the order that was requested.\n",
    "        axes = [factor.variables.index(v) for v in variables]\n",
    "        return np.transpose(factor.values, axes)\n",
    "\n",
    "    def _is_associated(self, joint):\n",
    "        \"\"\"True when the odds ratio of the 2x2 `joint` differs from 1 by more than eps2.\"\"\"\n",
    "        # A NaN odds ratio (0/0) compares False here, so such a draw is rejected.\n",
    "        return abs(odd_ratio(joint, c=None) - 1.0) > self.eps2\n",
    "\n",
    "    def get_random_cpds(self):\n",
    "        \"\"\"\n",
    "        Get a random data-generation process in the model, i.e., satisfying 3 conditions\n",
    "        1. random cpds\n",
    "        2. applying the ZI consistency\n",
    "        3. W must associate to R (proxy assumption) and to X (marginal assumption)\n",
    "\n",
    "        Redraws until condition 3 holds, then stores p_WR, p_WX, p_RX and p_W_R on self.\n",
    "        \"\"\"\n",
    "        while True:\n",
    "            self.net.get_random_cpds(n_states=self.cards, inplace=True)\n",
    "            self.net = add_ZI_consistency_edge(self.net)\n",
    "            infer = VariableElimination(self.net)\n",
    "            self.p_WR = self._joint(infer, [\"W\", \"R\"])\n",
    "            self.p_WX = self._joint(infer, [\"W\", \"X\"])\n",
    "            # Proxy assumption: W must be associated with R, i.e. the odds ratio cannot be 1.\n",
    "            # Marginal assumption: W must be associated with X.\n",
    "            if self._is_associated(self.p_WR) and self._is_associated(self.p_WX):\n",
    "                break\n",
    "\n",
    "        self.p_RX = self._joint(infer, [\"R\", \"X\"])\n",
    "        # Conditional p(w|r) = p(w,r) / p(r): normalise every R-column of the joint.\n",
    "        self.p_W_R = self.p_WR / self.p_WR.sum(axis=0, keepdims=True)\n",
    "\n",
    "\n",
    "np.random.seed(81)\n",
    "\n",
    "# Build the data-generating process (DGP)\n",
    "print(\"MCAR case. Graph: X(1) -> X <- R -> W\")\n",
    "cards = {\"X1\": 2, \"X\": 2, \"R\": 2, \"W\": 2}\n",
    "case_mcar = CaseMCAR(cards)\n",
    "# The constructor has already drawn one set of CPDs; this call draws a fresh set,\n",
    "# and that second draw is the DGP used for everything below.\n",
    "case_mcar.get_random_cpds()\n",
    "\n",
    "# Ground-truth tables of the DGP\n",
    "print(\"\\n\")\n",
    "print(\"True p(W,X):\\n\", case_mcar.p_WX)\n",
    "print(\"True p(W|R):\\n\", case_mcar.p_W_R)\n",
    "print(\"True p(R,X):\\n\", case_mcar.p_RX)\n",
    "\n",
    "# Kuroki-Pearl step with exact inputs: p(R,X) = p(W|R)^{-1} p(W,X)\n",
    "inv_p_W_R = np.linalg.inv(case_mcar.p_W_R)\n",
    "hat_p_RX = inv_p_W_R @ case_mcar.p_WX\n",
    "print(\"Computed p(R,X) via matrix inversion using true p(W,X) and true p(W|R):\\n\", hat_p_RX)\n",
    "print(\"Element-wise difference to true p(R,X):\\n\", np.abs(hat_p_RX - case_mcar.p_RX))\n",
    "\n",
    "# Draw a finite sample from the DGP\n",
    "print(\"\\n\")\n",
    "samples = case_mcar.net.simulate(100000)\n",
    "\n",
    "# Plug-in estimate of p(W,X): relative frequency of each (w, x) cell\n",
    "hat_p_WX = np.array(\n",
    "    [[np.mean((samples.W == w) * (samples.X == x)) for x in range(2)] for w in range(2)]\n",
    ")\n",
    "print(\"Estimated p(W,X):\\n\", hat_p_WX)\n",
    "\n",
    "# Same Kuroki-Pearl step, now with the estimated p(W,X) and the true p(W|R)\n",
    "hat_p_RX = inv_p_W_R @ hat_p_WX\n",
    "print(\"Computed p(R,X) via matrix inversion using estimated p(W,X) and true p(W|R):\\n\", hat_p_RX)\n",
    "print(\"Element-wise difference to true p(R,X):\\n\", np.abs(hat_p_RX - case_mcar.p_RX))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab29d0f4-0210-4dee-a8cf-b4daf151daef",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
