{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed26117a",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import pyreadr\n",
    "import numpy as np\n",
    "import scipy as sp\n",
    "import pandas as pd\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "# import rpy2.robjects as robjects\n",
    "import statsmodels.api as sm\n",
    "import statsmodels.formula.api as smf\n",
    "import multiprocess as mp\n",
    "from collections import namedtuple\n",
    "from multiprocess.pool import ThreadPool\n",
    "from sklearn.dummy import DummyRegressor\n",
    "from scipy.stats import norm\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48492432-fcde-4e30-8518-efa3daa83fef",
   "metadata": {},
   "outputs": [],
   "source": [
    "import signal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "145b6a9e-8689-4e6c-8b66-e6a4c8bfdeb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from statsmodels.gam.api import BSplines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b88d68f3-8893-45e9-9c7e-b26ea524021c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "# Silence DeprecationWarnings so that cell output stays readable.\n",
    "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9c35f59-fce8-4ef0-b9e9-c7fac978f2c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import dill\n",
    "\n",
    "\n",
    "def write_pkl(obj, fname):\n",
    "    \"\"\"Serialize `obj` to the file `fname` with dill (binary mode; overwrites).\"\"\"\n",
    "    with open(fname, mode='wb') as handle:\n",
    "        dill.dump(obj, handle)\n",
    "\n",
    "\n",
    "def read_pkl(fname):\n",
    "    \"\"\"Load and return the object dill-serialized in `fname`.\n",
    "\n",
    "    Security: unpickling can execute arbitrary code -- only load files you created.\n",
    "    \"\"\"\n",
    "    with open(fname, mode='rb') as handle:\n",
    "        return dill.load(handle)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70d576a5",
   "metadata": {},
   "source": [
    "<!--$$\\require{cancel}$$-->\n",
    "# Simulation scheme"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9dd25fae",
   "metadata": {},
   "source": [
    "## Data generating process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22aec556",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "# Seed the global NumPy RNG for the exploratory draws below; N is their sample size.\n",
    "np.random.seed(0)\n",
    "N = 10_000"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "851e4cc8",
   "metadata": {},
   "source": [
    "Predefine true parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37f5cef8",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "# True coefficients of the data-generating process. Each comment lists the\n",
    "# design-matrix columns the vector is multiplied with in the model functions later on.\n",
    "# stage 1\n",
    "beta_phi1 = np.array([.5, .5])  # phi1: [1, X1^2]\n",
    "beta_K1 = np.array([0, 1, 2])  # K1: [1, X1^2] via beta_K1[:-1]; last entry unused\n",
    "beta_p1 = np.array([0, 3, 5, -0.5])  # p1: [1, X1^2, A1, A1*X1]\n",
    "beta_mu1 = np.array([0.5, -0.3, 1, -0.5, 0])  # mu1: [1, X1^2, A1, A1*X1] via beta_mu1[:-1]; last entry unused (was 0.2)\n",
    "# stage 2\n",
    "beta_phi2 = np.array([0.7, -0.5, 0.5, -0.1])  # phi2: [1, X1^2, X2, X2^2]\n",
    "beta_K2 = np.array([-3, 1, 1, 0.5, 1])  # K2: [1, X1, X2, A2, A2*X2]\n",
    "beta_p2 = np.array([ 0.5 , 2,  1, -0.8 , 0.65])  # p2: [1, X1, X1*X2, A1, A2]\n",
    "beta_mu2 = np.array([-3, 1, 1.5, -.5, .01, 1.5, 1, -.5])  # mu2: [1, X1, A1, A1*X1, exp(X2), A2, A2*A1, A2*X2] (a 0.12 entry was dropped)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cdde0bcc",
   "metadata": {},
   "source": [
    "### First stage\n",
    "* Baseline covariate\n",
    "    * $X_1 \\sim \\text{Uniform}(-.3, .7)$ \n",
    "    <!-- * $X_1 \\sim \\text{Bernoulli}(0.75)$ -->"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56a50760",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "# Baseline covariate X1 ~ Uniform(-0.3, 0.7)\n",
    "X1 = np.random.uniform(low=-.3, high=.7, size=N)\n",
    "#np.random.binomial(1, 0.75, N)  # alternative: X1 ~ Bernoulli(0.75)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f254c309",
   "metadata": {},
   "source": [
    "* Propensity\n",
    "    * $A_1 \\sim \\text{Bernoulli}(\\varphi_1(X_1))$ where\n",
    "      $\\varphi_1(X_1) := \\mathbb P(A_1=1 | X_1)$ is the propensity score for $A_1$\n",
    "    * $\\text{logit}\\left( \\varphi_1(X_1) \\right) = 0.5 + 0.5X_1^2$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07f27cc0-90ab-44a6-9d3d-d7d3772d01ad",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "def expit(x):\n",
    "    \"\"\"Inverse logit (sigmoid) function, 1 / (1 + exp(-x)).\n",
    "\n",
    "    Evaluated as exp(-log(1 + exp(-x))) via ``np.logaddexp``, so large |x|\n",
    "    saturates to 0.0 / 1.0. The naive exp(x) / (1 + exp(x)) overflows to\n",
    "    inf / inf = nan for x above ~709.\n",
    "\n",
    "    Args:\n",
    "      x: A numeric value, list, NumPy array or pandas Series.\n",
    "\n",
    "    Returns:\n",
    "      The elementwise inverse logit of x (scalar in, scalar out).\n",
    "    \"\"\"\n",
    "    return np.exp(-np.logaddexp(0, np.negative(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ed889b8-dddc-4f15-8b4b-683af6ce5bde",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stage-1 propensity coefficients, paired with [1, X1^2] in phi1\n",
    "beta_phi1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25cb898a-b809-4d5e-9d1d-9d2268c19168",
   "metadata": {},
   "outputs": [],
   "source": [
    "def phi1(x1):\n",
    "    \"\"\"Propensity score for A1, P(A1 = 1 | X1).\n",
    "\n",
    "    logit(phi1) = beta_phi1 @ [1, X1^2], using the global ``beta_phi1``.\n",
    "\n",
    "    Args:\n",
    "      x1: A numeric scalar (int or float) or 1-D array-like.\n",
    "\n",
    "    Returns:\n",
    "      A 1-D NumPy array of propensity scores, one per element of x1.\n",
    "    \"\"\"\n",
    "    # atleast_1d also accepts float/NumPy scalars (an int-only check would not).\n",
    "    x1 = np.atleast_1d(x1)\n",
    "    design = np.c_[np.ones(len(x1)), x1**2]\n",
    "    return expit(design @ beta_phi1)\n",
    "\n",
    "\n",
    "def genA1(x1):\n",
    "    \"\"\"Draw A1 ~ Bernoulli(phi1(x1)) from the global NumPy RNG.\n",
    "\n",
    "    Args:\n",
    "      x1: A numeric scalar or 1-D array-like.\n",
    "\n",
    "    Returns:\n",
    "      A 1-D integer array of A1 values in {0, 1}.\n",
    "    \"\"\"\n",
    "    return np.random.binomial(1, phi1(x1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5643160a-59d8-40e9-96af-df86ebbabcb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Observed stage-1 treatment for the exploratory sample\n",
    "A1 = genA1(X1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad9d8d33",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "# Share treated; E[phi1(X1)] is ~0.64 under beta_phi1 = [.5, .5]\n",
    "print(A1.mean())\n",
    "plt.hist(A1, density=True);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b707fb0",
   "metadata": {},
   "source": [
    "* Censoring\n",
    "    * $C_1 \\sim \\text{Bernoulli}(1 - K_1^{A_1}( X_1))$\n",
    "      where $K_1^{A_1}( X_1) := \\mathbb P(C_1=0|X_1, A_1)$ is the non-censoring probability for $C_1$ (so $C_1 = 1$ means censored)\n",
    "    * $\\text{logit}\\left( K_1^{A_1}( X_1) \\right) = X_1^2 + \\eta_1$\n",
    "      for fixed $\\eta_1 = 2.5$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d26458ad-1f8f-4464-913d-bd6f6064061b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stage-1 non-censoring coefficients; K1 uses only beta_K1[:-1] with [1, X1^2]\n",
    "beta_K1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "611a197d-b5cd-41f7-9eb0-2275c1cce845",
   "metadata": {},
   "outputs": [],
   "source": [
    "eta1 = 2.5\n",
    "\n",
    "\n",
    "def K1(x1, a1):\n",
    "    \"\"\"Stage-1 non-censoring probability, P(C1 = 0 | X1, A1).\n",
    "\n",
    "    logit(K1) = beta_K1[:-1] @ [1, X1^2] + eta1, using the globals ``beta_K1``\n",
    "    and ``eta1``. A1 does not enter the current model; ``a1`` is accepted only\n",
    "    to keep a uniform signature.\n",
    "\n",
    "    Args:\n",
    "      x1: A numeric scalar or 1-D array-like.\n",
    "      a1: A numeric scalar or 1-D array-like (currently unused).\n",
    "\n",
    "    Returns:\n",
    "      A 1-D NumPy array of non-censoring probabilities.\n",
    "    \"\"\"\n",
    "    x1 = np.atleast_1d(x1)\n",
    "    design = np.c_[np.ones(len(x1)), x1**2]\n",
    "    return expit(design @ beta_K1[:-1] + eta1)\n",
    "\n",
    "\n",
    "def genC1(x1, a1):\n",
    "    \"\"\"Draw the stage-1 censoring indicator C1 ~ Bernoulli(1 - K1(x1, a1)).\n",
    "\n",
    "    Args:\n",
    "      x1: A numeric scalar or 1-D array-like.\n",
    "      a1: A numeric scalar or 1-D array-like.\n",
    "\n",
    "    Returns:\n",
    "      A 1-D integer array with 1 = censored, 0 = not censored.\n",
    "    \"\"\"\n",
    "    return np.random.binomial(1, 1 - K1(x1, a1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5aa279fd",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "def genC1a(x1):\n",
    "    \"\"\"Draw potential stage-1 censoring indicators under A1 = 0 and A1 = 1.\n",
    "\n",
    "    Args:\n",
    "      x1: A numeric value or NumPy array.\n",
    "\n",
    "    Returns:\n",
    "      A NumPy array with two columns, C1_0 and C1_1 (drawn in that order).\n",
    "    \"\"\"\n",
    "    draws = [genC1(x1, treat) for treat in (0, 1)]\n",
    "    return np.column_stack(draws)\n",
    "\n",
    "C1a = genC1a(X1)\n",
    "C1_0, C1_1 = C1a[:, 0], C1a[:, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "653e0a7a",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "# Censoring rates; K1 ignores A1, so both are ~7% (mean of 1 - expit(2.5 + X1^2))\n",
    "print( np.mean(C1_0) )\n",
    "print( np.mean(C1_1) )\n",
    "plt.hist(C1_0, density=True, alpha=.5) \n",
    "plt.hist(C1_1, density=True, alpha=.5);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ef74b87",
   "metadata": {},
   "source": [
    "* Survival\n",
    "    * $S_1 \\sim \\text{Bernoulli}(p_1^{A_1}( X_1))$ where $p_1^{A_1}( X_1) := \\mathbb P(S_1=1 |  X_1, A_1, C_1=0)$ is the survival probability for $S_1$\n",
    "    * $\\text{logit}\\left( p_1^{A_1}( X_1) \\right) = 3X_1^2 + 5A_1 - 0.5 A_1X_1$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f3ac4ce-8704-43b9-8e3d-dbfd22ac2c75",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stage-1 survival coefficients, paired with [1, X1^2, A1, A1*X1] in p1\n",
    "beta_p1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdf521b9",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "def p1(x1, a1):\n",
    "    \"\"\"Stage-1 survival probability, P(S1 = 1 | X1, A1, C1 = 0).\n",
    "\n",
    "    logit(p1) = beta_p1 @ [1, X1^2, A1, A1*X1], using the global ``beta_p1``.\n",
    "\n",
    "    Args:\n",
    "      x1: A numeric scalar or 1-D array-like.\n",
    "      a1: A scalar treatment (broadcast to len(x1)) or a 1-D array-like.\n",
    "\n",
    "    Returns:\n",
    "      A 1-D NumPy array of survival probabilities.\n",
    "    \"\"\"\n",
    "    x1 = np.atleast_1d(x1)\n",
    "    # A scalar treatment (int or float) is applied to every subject.\n",
    "    a1 = np.broadcast_to(a1, x1.shape)\n",
    "    design = np.c_[np.ones(len(x1)), x1**2, a1, a1 * x1]\n",
    "    return expit(design @ beta_p1)\n",
    "\n",
    "\n",
    "def genS1(x1, a1):\n",
    "    \"\"\"Draw the stage-1 survival indicator S1 ~ Bernoulli(p1(x1, a1)).\n",
    "\n",
    "    Args:\n",
    "      x1: A numeric scalar or 1-D array-like.\n",
    "      a1: A numeric scalar or 1-D array-like.\n",
    "\n",
    "    Returns:\n",
    "      A 1-D integer array with 1 = survived, 0 = died.\n",
    "    \"\"\"\n",
    "    return np.random.binomial(1, p1(x1, a1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34964aed",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "def genS1a(x1):\n",
    "    \"\"\"Draw potential stage-1 survival indicators under A1 = 0 and A1 = 1.\n",
    "\n",
    "    The two columns are drawn independently, so the monotonicity\n",
    "    S1_0 <= S1_1 is NOT enforced.\n",
    "\n",
    "    Args:\n",
    "      x1: A numeric value or NumPy array.\n",
    "\n",
    "    Returns:\n",
    "      A NumPy array with two columns, S1_0 and S1_1.\n",
    "    \"\"\"\n",
    "    potential = [genS1(x1, treat) for treat in (0, 1)]\n",
    "    return np.column_stack(potential)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "351c2958",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "S1a = genS1a(X1)\n",
    "S1_0 = S1a[:, 0]\n",
    "S1_1 = S1a[:, 1]\n",
    "\n",
    "# Under the current beta_p1: E[S1_0] ~ 0.59 and E[S1_1] > 0.99\n",
    "print( np.mean(S1_0) )\n",
    "print( np.mean(S1_1) )\n",
    "\n",
    "# Share with S1_0 <= S1_1 (not enforced: the two columns are drawn independently)\n",
    "print( \"monotonicity?\", np.mean(S1_0 <= S1_1) )\n",
    "\n",
    "plt.figure(figsize=(4,3))\n",
    "plt.stem(S1a.mean(0))\n",
    "plt.xticks(range(2), [\"S1_0\", \"S1_1\"])\n",
    "plt.ylabel(\"probability\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10007cbb",
   "metadata": {},
   "source": [
    "* Outcome: \n",
    "    * $X_{2} | \\{X_1, A_1\\} \\sim \\mathcal N \\left(0.5 - 0.3X_{1}^2 + A_1 - 0.5A_1X_{1},\\; 1.5^2\\right)$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3b514f9-d44c-4290-a690-f60cafa78270",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coefficients actually used by mu1, paired with [1, X1^2, A1, A1*X1]\n",
    "beta_mu1[:-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5ac562e",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "def mu1(x1, a1, s1):\n",
    "    \"\"\"Conditional mean of the intermediate covariate X2 given (X1, A1).\n",
    "\n",
    "    mu1 = beta_mu1[:-1] @ [1, X1^2, A1, A1*X1], using the global ``beta_mu1``.\n",
    "    ``s1`` does not enter the current model; it is kept for a uniform signature.\n",
    "\n",
    "    Args:\n",
    "      x1: A numeric scalar or 1-D array-like.\n",
    "      a1: A scalar treatment (broadcast to len(x1)) or a 1-D array-like.\n",
    "      s1: A numeric scalar or 1-D array-like (currently unused).\n",
    "\n",
    "    Returns:\n",
    "      A 1-D NumPy array of means.\n",
    "    \"\"\"\n",
    "    x1 = np.atleast_1d(x1)\n",
    "    a1 = np.broadcast_to(a1, x1.shape)\n",
    "    design = np.c_[np.ones(len(x1)), x1**2, a1, a1 * x1]\n",
    "    return design @ beta_mu1[:-1]\n",
    "\n",
    "\n",
    "def genX2(x1, a1, s1):\n",
    "    \"\"\"Draw X2 ~ Normal(mu1(x1, a1, s1), 1.5^2).\n",
    "\n",
    "    Args:\n",
    "      x1: A numeric scalar or 1-D array-like.\n",
    "      a1: A numeric scalar or 1-D array-like.\n",
    "      s1: A numeric scalar or 1-D array-like.\n",
    "\n",
    "    Returns:\n",
    "      A 1-D NumPy array of X2 values.\n",
    "    \"\"\"\n",
    "    mean = mu1(x1, a1, s1)\n",
    "    # Size is taken from the mean so scalar x1 works (len() of a scalar fails).\n",
    "    return np.random.normal(loc=mean, scale=1.5, size=len(mean))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3c77166",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "def genX2a(x1, s1a):\n",
    "    \"\"\"Draw potential X2 values under A1 = 0 and A1 = 1.\n",
    "\n",
    "    Args:\n",
    "      x1: A numeric value or NumPy array.\n",
    "      s1a: A NumPy array with two columns, S1_0 and S1_1; column k is passed\n",
    "        as s1 for treatment k.\n",
    "\n",
    "    Returns:\n",
    "      A NumPy array with two columns, X2_0 and X2_1.\n",
    "    \"\"\"\n",
    "    columns = [genX2(x1, treat, s1a[:, treat]) for treat in (0, 1)]\n",
    "    return np.column_stack(columns)\n",
    "\n",
    "X2a = genX2a(X1, S1a)\n",
    "X2_0, X2_1 = X2a[:, 0], X2a[:, 1]\n",
    "# X2[C1 == 1 | S1 == 0] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4e1f995-e291-4b24-b273-e67dcac02207",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Densities of the potential covariates X2_0 (column 0) and X2_1 (column 1)\n",
    "plt.figure(figsize=(4,3))\n",
    "sns.kdeplot(X2a);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "524b5adc",
   "metadata": {},
   "source": [
    "### Second stage\n",
    "\n",
    "* Propensity\n",
    "    * $A_2 \\sim \\text{Bernoulli}(\\varphi_2(\\overline{ X}_2, A_1))$ where\n",
    "      $\\varphi_2(\\overline{ X}_2, A_1) := \\mathbb P(A_2=1 | \\overline{ X}_2, A_1, C_1=0, S_1=1)$ is the propensity score for $A_2$\n",
    "    * $\\text{logit}\\left( \\varphi_2(\\overline{ X}_2, A_1) \\right) = 0.7 - 0.5X_{1}^2 + 0.5X_{2} -0.1 X_2^2$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d64eb874-3b24-45d4-abf4-8198e311f99d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stage-2 propensity coefficients, paired with [1, X1^2, X2, X2^2] in phi2\n",
    "beta_phi2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "701d327e-6fcd-4299-b5df-94f5ffb61f0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def phi2(x1, x2):\n",
    "    \"\"\"Stage-2 propensity score, P(A2 = 1 | X1, X2, A1, C1 = 0, S1 = 1).\n",
    "\n",
    "    logit(phi2) = beta_phi2 @ [1, X1^2, X2, X2^2], using the global\n",
    "    ``beta_phi2``. A1 does not enter the current model.\n",
    "\n",
    "    Args:\n",
    "      x1: A numeric scalar or 1-D array-like.\n",
    "      x2: A numeric scalar or 1-D array-like of the same length as x1.\n",
    "\n",
    "    Returns:\n",
    "      A 1-D NumPy array of propensity scores.\n",
    "    \"\"\"\n",
    "    x1 = np.atleast_1d(x1)\n",
    "    x2 = np.atleast_1d(x2)\n",
    "    design = np.c_[np.ones(len(x1)), x1**2, x2, x2**2]\n",
    "    return expit(design @ beta_phi2)\n",
    "\n",
    "\n",
    "def genA2(x1, x2):\n",
    "    \"\"\"Draw A2 ~ Bernoulli(phi2(x1, x2)); returns a 1-D integer array.\"\"\"\n",
    "    return np.random.binomial(1, phi2(x1, x2))\n",
    "\n",
    "\n",
    "def genA2a(x1, x2a):\n",
    "    \"\"\"Draw potential A2 under A1 = 0 and A1 = 1.\n",
    "\n",
    "    Args:\n",
    "      x1: A numeric value or NumPy array.\n",
    "      x2a: A NumPy array with two columns, X2_0 and X2_1.\n",
    "\n",
    "    Returns:\n",
    "      A NumPy array with two columns, A2_0 (uses X2_0) and A2_1 (uses X2_1).\n",
    "    \"\"\"\n",
    "    a2_0 = genA2(x1, x2a[:, 0])\n",
    "    a2_1 = genA2(x1, x2a[:, 1])\n",
    "    return np.c_[a2_0, a2_1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "059c5cbd-4032-44db-924b-eb0bb542d369",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Potential stage-2 treatments under A1 = 0 (column 0) and A1 = 1 (column 1)\n",
    "A2a = genA2a(X1, X2a)\n",
    "A2_0 = A2a[:, 0]\n",
    "A2_1 = A2a[:, 1]\n",
    "# A2[C1 == 1 | S1 == 0] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e205d97a-8613-4699-88cb-d4810be42b45",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Marginal share with A2 = 1 under A1 = 0 and under A1 = 1\n",
    "plt.figure(figsize=(4,3))\n",
    "plt.stem(A2a.mean(0))\n",
    "plt.xticks(range(2), [\"A2_0\", \"A2_1\"])\n",
    "plt.ylabel(\"probability\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9435ea41-119c-42df-bd43-f5c02d5f4cfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Share with A1 = 1, then share with A2 = 1 within the A1 = 0 and A1 = 1 groups\n",
    "(\n",
    "    A1.mean(),\n",
    "    A2_0[A1==0].mean(),\n",
    "    A2_1[A1==1].mean()\n",
    ")  # balanced."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6200984e",
   "metadata": {},
   "source": [
    "* Censor\n",
    "    * $C_2 \\sim \\text{Bernoulli}(1 - K_2^{A_1A_2}(\\overline{ X}_2))$ where $K_2^{A_1A_2}(\\overline{ X}_2) := \\mathbb P(C_2=0 | \\overline{ X}_2, \\overline A_2, C_1=0, S_1=1)$ is the non-censoring probability for $C_2$\n",
    "    * $\\text{logit}\\left( K_2^{A_1A_2}( \\overline X_2) \\right) = -3 + X_1 + X_2 + 0.5 A_2 + A_2X_2 + \\eta_2$ for fixed $\\eta_2=4$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "391a2f1a-30a3-4820-9211-f5b0f82c2b0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stage-2 non-censoring coefficients, paired with [1, X1, X2, A2, A2*X2] in K2\n",
    "beta_K2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "193b7c3a-1c0d-45c2-8620-87eee0730d4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "eta2 = 4  # Adjust eta2 for different censoring rates\n",
    "# eta2 = 5 : 15% censoring\n",
    "#      = 3 : 20%\n",
    "#      = 0 : 50%\n",
    "\n",
    "\n",
    "def K2(x1, x2, a1, a2, eta2=5):\n",
    "    \"\"\"Stage-2 non-censoring probability, P(C2 = 0 | X1, X2, A1, A2, C1 = 0, S1 = 1).\n",
    "\n",
    "    logit(K2) = beta_K2 @ [1, X1, X2, A2, A2*X2] + eta2, using the global\n",
    "    ``beta_K2``. A1 does not enter the current model.\n",
    "\n",
    "    NOTE: the default ``eta2=5`` differs from the notebook-level ``eta2 = 4``;\n",
    "    pass ``eta2`` explicitly to use the configured value.\n",
    "\n",
    "    Args:\n",
    "      x1, x2: Numeric scalars or 1-D array-likes of equal length.\n",
    "      a1: Stage-1 treatment (currently unused).\n",
    "      a2: A scalar treatment (broadcast to len(x1)) or a 1-D array-like.\n",
    "      eta2: Intercept shift; larger values mean less censoring.\n",
    "\n",
    "    Returns:\n",
    "      A 1-D NumPy array of non-censoring probabilities.\n",
    "    \"\"\"\n",
    "    x1 = np.atleast_1d(x1)\n",
    "    x2 = np.atleast_1d(x2)\n",
    "    a2 = np.broadcast_to(a2, x1.shape)\n",
    "    design = np.c_[np.ones(len(x1)), x1, x2, a2, a2 * x2]\n",
    "    return expit(design @ beta_K2 + eta2)\n",
    "\n",
    "\n",
    "def genC2(x1, x2, a1, a2, eta2=5):\n",
    "    \"\"\"Draw the stage-2 censoring indicator C2 ~ Bernoulli(1 - K2(...)); 1 = censored.\"\"\"\n",
    "    return np.random.binomial(1, 1 - K2(x1, x2, a1, a2, eta2))\n",
    "\n",
    "\n",
    "def genC2a(x1, x2a, c1a, eta2=5):\n",
    "    \"\"\"Draw potential stage-2 censoring indicators for (A1, A2) in 00, 01, 10, 11.\n",
    "\n",
    "    Regime (a1, a2) uses column a1 of ``x2a``. The draws are independent of\n",
    "    ``c1a`` (no C1/C2 monotonicity is imposed); ``c1a`` is kept for the signature.\n",
    "\n",
    "    Returns:\n",
    "      A NumPy array with four columns, C2_00, C2_01, C2_10, C2_11.\n",
    "    \"\"\"\n",
    "    c2_00 = genC2(x1, x2a[:, 0], 0, 0, eta2)\n",
    "    c2_01 = genC2(x1, x2a[:, 0], 0, 1, eta2)\n",
    "    c2_10 = genC2(x1, x2a[:, 1], 1, 0, eta2)\n",
    "    c2_11 = genC2(x1, x2a[:, 1], 1, 1, eta2)\n",
    "    return np.c_[c2_00, c2_01, c2_10, c2_11]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "baf7d63c-376f-4efc-8437-54b2a84b0b70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass the configured eta2 explicitly: genC2a's default (eta2=5) differs from eta2 = 4.\n",
    "C2a = genC2a(X1, X2a, C1a, eta2=eta2)\n",
    "C2_00 = C2a[:, 0]\n",
    "C2_01 = C2a[:, 1]\n",
    "C2_10 = C2a[:, 2]\n",
    "C2_11 = C2a[:, 3]\n",
    "# C2[C1 == 1 | S1 == 0] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30281b01-029d-4c74-9f77-b2a016e435cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Potential stage-2 censoring rates per regime (C2a holds 0/1 draws; the NaN mask is defensive)\n",
    "print(\"Mean of C2_00:\", np.mean(C2_00, where=~np.isnan(C2_00)))\n",
    "print(\"Mean of C2_01:\", np.mean(C2_01, where=~np.isnan(C2_01)))\n",
    "print(\"Mean of C2_10:\", np.mean(C2_10, where=~np.isnan(C2_10)))\n",
    "print(\"Mean of C2_11:\", np.mean(C2_11, where=~np.isnan(C2_11)))\n",
    "\n",
    "plt.figure(figsize=(4,3))\n",
    "plt.stem(C2a.mean(0))\n",
    "plt.xticks(range(4), [\"C2_00\", \"C2_01\", \"C2_10\", \"C2_11\"])\n",
    "plt.ylabel(\"probability\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36101959",
   "metadata": {},
   "source": [
    "* Survival\n",
    "    * $S_2 \\sim \\text{Bernoulli}(p_2^{A_1A_2}(\\overline{X}_2))$ where $p_2^{A_1A_2}(\\overline{X}_2) := \\mathbb P(S_2=1 | \\overline{X}_2, \\overline A_2, C_2=0, S_1=1)$ is the survival probability for $S_2$\n",
    "    * $\\text{logit}\\left( p_2^{A_1A_2}(\\overline{X}_2) \\right) = 0.5 + 2 X_1 + X_1X_2 -0.8 A_1 + 0.65 A_2$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0699900a-264c-4dc5-a9ce-e89e476db075",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stage-2 survival coefficients, paired with [1, X1, X1*X2, A1, A2] in p2\n",
    "beta_p2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd49d7ad-36ee-4559-97b9-bcf9af1648eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def p2(x1, x2, a1, a2):\n",
    "    \"\"\"Stage-2 survival probability, P(S2 = 1 | X1, X2, A1, A2, C2 = 0, S1 = 1).\n",
    "\n",
    "    logit(p2) = beta_p2 @ [1, X1, X1*X2, A1, A2], using the global ``beta_p2``.\n",
    "\n",
    "    Args:\n",
    "      x1, x2: Numeric scalars or 1-D array-likes of equal length.\n",
    "      a1, a2: Scalar treatments (broadcast to len(x1)) or 1-D array-likes.\n",
    "\n",
    "    Returns:\n",
    "      A 1-D NumPy array of survival probabilities.\n",
    "    \"\"\"\n",
    "    x1 = np.atleast_1d(x1)\n",
    "    x2 = np.atleast_1d(x2)\n",
    "    a1 = np.broadcast_to(a1, x1.shape)\n",
    "    a2 = np.broadcast_to(a2, x1.shape)\n",
    "    design = np.c_[np.ones(len(x1)), x1, x1 * x2, a1, a2]\n",
    "    return expit(design @ beta_p2)\n",
    "\n",
    "\n",
    "def genS2(x1, x2, a1, a2):\n",
    "    \"\"\"Draw the stage-2 survival indicator S2 ~ Bernoulli(p2(...)); 1 = survived.\"\"\"\n",
    "    return np.random.binomial(1, p2(x1, x2, a1, a2))\n",
    "\n",
    "\n",
    "def genS2a(x1, x2a, s1a):\n",
    "    \"\"\"Draw potential stage-2 survival indicators for (A1, A2) in 00, 01, 10, 11.\n",
    "\n",
    "    Regime (a1, a2) uses column a1 of ``x2a``. The draws are independent of\n",
    "    ``s1a``: neither S2 <= S1 nor monotonicity across regimes is enforced, so a\n",
    "    subject can have S2 = 1 under a regime where S1 = 0. ``s1a`` is kept for\n",
    "    the signature.\n",
    "\n",
    "    Returns:\n",
    "      A NumPy array with four columns, S2_00, S2_01, S2_10, S2_11.\n",
    "    \"\"\"\n",
    "    s2_00 = genS2(x1, x2a[:, 0], 0, 0)\n",
    "    s2_01 = genS2(x1, x2a[:, 0], 0, 1)\n",
    "    s2_10 = genS2(x1, x2a[:, 1], 1, 0)\n",
    "    s2_11 = genS2(x1, x2a[:, 1], 1, 1)\n",
    "    return np.c_[s2_00, s2_01, s2_10, s2_11]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd659bb0-ca35-4828-aee6-b71fd8bb4e9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Potential stage-2 survival indicators for regimes 00, 01, 10, 11\n",
    "S2a = genS2a(X1, X2a, S1a)\n",
    "S2_00 = S2a[:, 0]\n",
    "S2_01 = S2a[:, 1]\n",
    "S2_10 = S2a[:, 2]\n",
    "S2_11 = S2a[:, 3]\n",
    "# S2[C1 == 1 | C2 == 1 | S1 == 0] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d2447bc-3242-4a22-81fc-6df23b803096",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stage-2 survival rates per regime (not conditioned on S1)\n",
    "print(\"Mean of S2_00:\", np.mean(S2_00))\n",
    "print(\"Mean of S2_01:\", np.mean(S2_01))\n",
    "print(\"Mean of S2_10:\", np.mean(S2_10))\n",
    "print(\"Mean of S2_11:\", np.mean(S2_11))\n",
    "\n",
    "plt.figure(figsize=(4,3))\n",
    "plt.stem(S2a.mean(0))\n",
    "plt.xticks(range(4), [\"S2_00\", \"S2_01\", \"S2_10\", \"S2_11\"])\n",
    "plt.ylabel(\"probability\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f47a650",
   "metadata": {},
   "source": [
    "* Outcome: \n",
    "    * $Y | \\{\\bar X, \\bar A, \\bar C, \\bar S\\} = \\mu_2^{A_1A_2}(\\overline{\\mathbf X}_2) + \\epsilon_2$\n",
    "    * $\\mu_2^{A_1A_2}(\\overline{\\mathbf X}_2) = -3 + X_1 + 1.5A_1 -0.5 A_1X_1 + \\exp(X_2)/100 + A_2(1.5 + A_1 -0.5 X_2)$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b49ddaf3-9adf-4f4d-bd86-d9d6d45188e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Outcome coefficients, paired with [1, X1, A1, A1*X1, exp(X2), A2, A2*A1, A2*X2] in mu2\n",
    "beta_mu2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a81ea5d-8626-4540-8fc2-f6711d8de6fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mu2(x1, x2, a1, a2):\n",
    "    \"\"\"Conditional mean of the final outcome Y.\n",
    "\n",
    "    mu2 = beta_mu2 @ [1, X1, A1, A1*X1, exp(X2), A2, A2*A1, A2*X2], using the\n",
    "    global ``beta_mu2``.\n",
    "\n",
    "    Args:\n",
    "      x1, x2: Numeric scalars or 1-D array-likes of equal length.\n",
    "      a1, a2: Scalar treatments (broadcast to len(x1)) or 1-D array-likes.\n",
    "\n",
    "    Returns:\n",
    "      A 1-D NumPy array of conditional means.\n",
    "    \"\"\"\n",
    "    x1 = np.atleast_1d(x1)\n",
    "    x2 = np.atleast_1d(x2)\n",
    "    a1 = np.broadcast_to(a1, x1.shape)\n",
    "    a2 = np.broadcast_to(a2, x1.shape)\n",
    "    design = np.c_[np.ones(len(x1)), x1, a1, a1 * x1, np.exp(x2), a2, a2 * a1, a2 * x2]\n",
    "    return design @ beta_mu2\n",
    "\n",
    "\n",
    "def genY(x1, x2, a1, a2):\n",
    "    \"\"\"Draw Y = mu2(x1, x2, a1, a2) + eps with eps ~ Normal(0, 1.5^2).\"\"\"\n",
    "    mean = mu2(x1, x2, a1, a2)\n",
    "    # Size is taken from the mean so scalar x1 works (len() of a scalar fails).\n",
    "    return mean + np.random.normal(loc=0, scale=1.5, size=len(mean))\n",
    "\n",
    "\n",
    "def genYa(x1, x2a):\n",
    "    \"\"\"Draw potential outcomes for (A1, A2) in 00, 01, 10, 11.\n",
    "\n",
    "    Regime (a1, a2) uses column a1 of ``x2a``.\n",
    "\n",
    "    Returns:\n",
    "      A NumPy array with four columns, Y_00, Y_01, Y_10, Y_11.\n",
    "    \"\"\"\n",
    "    y_00 = genY(x1, x2a[:, 0], 0, 0)\n",
    "    y_01 = genY(x1, x2a[:, 0], 0, 1)\n",
    "    y_10 = genY(x1, x2a[:, 1], 1, 0)\n",
    "    y_11 = genY(x1, x2a[:, 1], 1, 1)\n",
    "    return np.column_stack((y_00, y_01, y_10, y_11))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc6492de-2c36-49a5-be49-4cfeec8446ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Potential outcomes Y under regimes 00, 01, 10, 11\n",
    "Ya = genYa(X1, X2a)\n",
    "Y_00 = Ya[:, 0]\n",
    "Y_01 = Ya[:, 1]\n",
    "Y_10 = Ya[:, 2]\n",
    "Y_11 = Ya[:, 3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d0eaac7-c524-409e-b345-ab50936fdd54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Densities of the four potential outcomes\n",
    "sns.kdeplot(Y_00, color='red', label='Y_00')\n",
    "sns.kdeplot(Y_01, color='green', label='Y_01')\n",
    "sns.kdeplot(Y_10, color='blue', label='Y_10')\n",
    "sns.kdeplot(Y_11, color='purple', label='Y_11')\n",
    "plt.xlabel('Y')\n",
    "plt.ylabel('Density')\n",
    "# plt.ylim(0, 0.12)\n",
    "# plt.xlim(-15, 15)\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8008ce93",
   "metadata": {},
   "source": [
    "## Functions `follow_a1()` and `follow_a1a2()`\n",
    "* Return potential outcomes of a variable if it were to follow specific treatments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffc8e9d2-d567-4a15-9654-8d7af055a985",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _pick_columns(alldata, keys):\n",
    "    \"\"\"Return alldata[keys[i]] at row position i for every row (vectorized).\"\"\"\n",
    "    n = len(alldata)\n",
    "    if n == 0:\n",
    "        return np.array([])\n",
    "    cols, col_idx = np.unique(np.asarray(keys), return_inverse=True)\n",
    "    values = alldata[cols.tolist()].to_numpy()\n",
    "    return values[np.arange(n), col_idx.ravel()]\n",
    "\n",
    "\n",
    "def follow_a1(a1, alldata, var=\"a2\"):\n",
    "    \"\"\"Pick, row by row, the potential value of `var` under the given A1.\n",
    "\n",
    "    Row i gets column ``{var}_{int(a1[i])}`` of ``alldata`` at position i.\n",
    "\n",
    "    Args:\n",
    "      a1: 1-D array-like of treatments, one per row of ``alldata`` (positional).\n",
    "      alldata: DataFrame holding columns ``{var}_0``, ``{var}_1``, ...\n",
    "      var: Column prefix.\n",
    "\n",
    "    Returns:\n",
    "      A 1-D NumPy array with one value per row of ``alldata``.\n",
    "    \"\"\"\n",
    "    keys = [f\"{var}_{int(x)}\" for x in np.asarray(a1).ravel()]\n",
    "    return _pick_columns(alldata, keys)\n",
    "\n",
    "\n",
    "def follow_a1a2(a1, a2, alldata, var=\"y\"):\n",
    "    \"\"\"Pick, row by row, the potential value of `var` under the given (A1, A2).\n",
    "\n",
    "    Row i gets column ``{var}_{int(a1[i])}{int(a2[i])}`` of ``alldata`` at position i.\n",
    "\n",
    "    Args:\n",
    "      a1, a2: 1-D array-likes (or broadcastable scalars) of treatments.\n",
    "      alldata: DataFrame holding columns ``{var}_00``, ``{var}_01``, ...\n",
    "      var: Column prefix.\n",
    "\n",
    "    Returns:\n",
    "      A 1-D NumPy array with one value per row of ``alldata``.\n",
    "    \"\"\"\n",
    "    b1, b2 = np.broadcast_arrays(np.asarray(a1), np.asarray(a2))\n",
    "    keys = [f\"{var}_{int(x)}{int(y)}\" for x, y in zip(b1.ravel(), b2.ravel())]\n",
    "    return _pick_columns(alldata, keys)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa11026f",
   "metadata": {},
   "source": [
    "## Function `genData()`\n",
    "* Generates all the potential outcomes in trajectories."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ba9a504-0d44-482d-9d08-7930e3ce7b50",
   "metadata": {},
   "outputs": [],
   "source": [
    "def genData(N, eta2):\n",
    "    \"\"\"Simulate one dataset of N subjects from the two-stage data-generating process.\n",
    "\n",
    "    Args:\n",
    "      N: Number of subjects.\n",
    "      eta2: Stage-2 censoring intercept shift, forwarded to genC2a.\n",
    "\n",
    "    Returns:\n",
    "      A dict with keys\n",
    "        \"all\": every potential outcome plus the trajectory each subject actually\n",
    "               follows (X1, A1, C1, S1, X2, A2, C2, S2, Y); values downstream of\n",
    "               censoring or death are set to 0 (C2 itself is left as drawn),\n",
    "        \"obs\": the followed trajectory with downstream values set to NaN,\n",
    "        \"as\":  the rows of \"all\" for always-survivors, i.e. subjects with S1 = 1\n",
    "               under both A1 values and S2 = 1 under all four (A1, A2) regimes.\n",
    "    \"\"\"\n",
    "    # Stage 1\n",
    "    x1 = np.random.uniform(low=-.3, high=.7, size=N)\n",
    "    a1 = genA1(x1)\n",
    "    c1a = genC1a(x1)\n",
    "    s1a = genS1a(x1)\n",
    "    x2a = genX2a(x1, s1a)\n",
    "\n",
    "    # Stage 2\n",
    "    a2a = genA2a(x1, x2a)\n",
    "    c2a = genC2a(x1, x2a, c1a, eta2)\n",
    "    s2a = genS2a(x1, x2a, s1a)\n",
    "    ya = genYa(x1, x2a)\n",
    "\n",
    "    # All potential outcomes, one column per (variable, regime)\n",
    "    dfa = pd.DataFrame(np.c_[x1, x2a, a1, a2a, c1a, c2a, s1a, s2a, ya])\n",
    "    dfa.columns = ['x1', 'x2_0', 'x2_1', 'a1', 'a2_0', 'a2_1',\n",
    "                   'c1_0', 'c1_1', 'c2_00', 'c2_01', 'c2_10', 'c2_11',\n",
    "                   's1_0', 's1_1', 's2_00', 's2_01', 's2_10', 's2_11',\n",
    "                   'y_00', 'y_01', 'y_10', 'y_11']\n",
    "\n",
    "    # Trajectory followed under the assigned A1 and the resulting A2\n",
    "    dfa[\"X1\"] = dfa[\"x1\"]\n",
    "    dfa[\"A1\"] = dfa[\"a1\"]\n",
    "    dfa[\"C1\"] = follow_a1(dfa.a1, dfa, var=\"c1\")\n",
    "    dfa[\"S1\"] = follow_a1(dfa.a1, dfa, var=\"s1\")\n",
    "    dfa[\"X2\"] = follow_a1(dfa.a1, dfa, var=\"x2\")\n",
    "    a2a1_a = follow_a1(dfa.a1, dfa, var=\"a2\")\n",
    "    dfa[\"A2\"] = a2a1_a\n",
    "    dfa[\"C2\"] = follow_a1a2(dfa.a1, a2a1_a, dfa, var=\"c2\")\n",
    "    dfa[\"S2\"] = follow_a1a2(dfa.a1, a2a1_a, dfa, var=\"s2\")\n",
    "    dfa[\"Y\"] = follow_a1a2(dfa.a1, a2a1_a, dfa, var=\"y\")\n",
    "    dfa.loc[dfa.C1==1, [\"S1\", \"X2\", \"A2\", \"S2\", \"Y\"]] = 0\n",
    "    dfa.loc[(~np.isnan(dfa.S1)) & (dfa.S1==0), [\"X2\", \"A2\", \"S2\", \"Y\"]] = 0\n",
    "    dfa.loc[(~np.isnan(dfa.C2)) & (dfa.C2==1), [\"S2\", \"Y\"]] = 0\n",
    "    dfa.loc[(~np.isnan(dfa.S2)) & (dfa.S2==0), \"Y\"] = 0\n",
    "\n",
    "    # Observed data: NaN for everything after censoring or death\n",
    "    dfs = dfa.loc[:, [\"X1\", \"A1\", \"C1\", \"S1\",\n",
    "                      \"X2\", \"A2\", \"C2\", \"S2\", \"Y\"]].copy()\n",
    "    dfs.loc[dfs.C1==1, [\"S1\", \"X2\", \"A2\", \"C2\", \"S2\", \"Y\"]] = np.nan  # Set NA for C1==1 rows\n",
    "    dfs.loc[(~np.isnan(dfs.S1)) & (dfs.S1==0), [\"X2\", \"A2\", \"C2\", \"S2\", \"Y\"]] = np.nan  # Set NA for S1==0 rows\n",
    "    dfs.loc[(~np.isnan(dfs.C2)) & (dfs.C2==1), [\"S2\", \"Y\"]] = np.nan  # Set NA for C2==1 rows\n",
    "    dfs.loc[(~np.isnan(dfs.S2)) & (dfs.S2==0), \"Y\"] = np.nan  # Set NA for S2==0 rows\n",
    "\n",
    "    # Always survivors: survive both stages under every treatment regime.\n",
    "    # S2 is drawn independently of S1 and without cross-regime monotonicity,\n",
    "    # so S2_00 == 1 alone does not identify this stratum.\n",
    "    always = (s1a == 1).all(axis=1) & (s2a == 1).all(axis=1)\n",
    "    dfas = dfa[always]\n",
    "\n",
    "    return {\"all\": dfa, \"obs\": dfs, \"as\": dfas}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be237808-fc5d-49b3-8ed7-013c7c2571e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(0)\n",
    "\n",
    "# Silence chained-assignment warnings only while simulating; the context manager\n",
    "# restores the previous setting afterwards, even if genData raises.\n",
    "# NOTE(review): the \"15\" suffix presumably refers to ~15% censoring (eta2 = 5);\n",
    "# eta2 is currently 4 -- confirm the naming.\n",
    "with pd.option_context(\"mode.chained_assignment\", None):\n",
    "    data_15 = genData(N=10_000, eta2=eta2)\n",
    "XX_all15 = data_15[\"all\"]\n",
    "XX_obs15 = data_15[\"obs\"]\n",
    "XX_as15 = data_15[\"as\"]\n",
    "\n",
    "# print(XX_all15.head())\n",
    "print(XX_all15.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9454d409-3cf7-47ef-86af-bc704d84c227",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Observed stage-2 survivors vs. always-survivors, each out of all subjects\n",
    "print(sum((XX_obs15.S2 == 1) & (~np.isnan(XX_obs15.S2))), \n",
    "      \"observed survivors / total\", XX_obs15.shape[0])\n",
    "print(XX_as15.shape[0], \n",
    "      \"always survivors / total\", XX_obs15.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee56883d-dfb2-4067-bd8f-4e1c22e469a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Linear decision rules: recommend treatment 1 when the linear index is >= 0.\n",
    "# NaN covariates (e.g. X2 after censoring or death) compare False, giving 0.\n",
    "def reg1(eta1, eta2, data):\n",
    "    \"\"\"Stage-1 rule 1{eta1 + eta2 * X1 >= 0} as an integer array.\"\"\"\n",
    "    linear_index = eta1 + eta2 * data[\"X1\"]\n",
    "    return np.where(linear_index >= 0, 1, 0)\n",
    "\n",
    "def reg2(eta1, eta2, eta3, data):\n",
    "    \"\"\"Stage-2 rule 1{eta1 + eta2 * X1 + eta3 * X2 >= 0} as an integer array.\"\"\"\n",
    "    linear_index = eta1 + eta2 * data[\"X1\"] + eta3 * data[\"X2\"]\n",
    "    return np.where(linear_index >= 0, 1, 0)\n",
    "\n",
    "# Parameters of the true regime: eta_true[:2] feed reg1, eta_true[2:] feed reg2.\n",
    "eta_true = np.array([-0.14, 0.99, 0, -0.707, 0.707])\n",
    "# alternative setting: eta_true = np.array([0.511, 0.303, 0.378, 0.046, 0.642])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57a2c5bf-5757-4546-a1af-b6e698a45c75",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Large simulated sample (N = 500,000) used to approximate population-level\n",
    "# quantities; pickled so the raw cell below can reload it instead.\n",
    "df_large = genData(N=500_000, eta2=eta2)\n",
    "df_all = df_large[\"all\"]\n",
    "write_pkl(df_large, \"df_large_sim2.pkl\")"
   ]
  },
  {
   "cell_type": "raw",
   "id": "419acb62-b0b9-44ed-974c-a9306f40b2c6",
   "metadata": {},
   "source": [
    "df_large = read_pkl(\"df_large_sim2.pkl\")\n",
    "df_all = df_large[\"all\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab42c75d-ccb5-4524-91c9-7149ee9922a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Observed stage-2 survivors (NaN never equals 1) and always survivors.\n",
    "print(df_large[\"obs\"].S2.eq(1).sum(),\n",
    "      \"observed survivors / total\", len(df_large[\"obs\"]))\n",
    "print(len(df_large[\"as\"]),\n",
    "      \"always survivors / total\", len(df_large[\"obs\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbad9d10-fa2d-48fc-89a8-0bdbac1156be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Observed proportion with C = 1 at each stage (NaN rows are skipped).\n",
    "tuple(df_large[\"obs\"][col].mean() for col in (\"C1\", \"C2\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5494d26b-d4a9-4c27-9eac-331ef379bc1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Observed proportion with S = 1 at each stage (NaN rows are skipped).\n",
    "tuple(df_large[\"obs\"][col].mean() for col in (\"S1\", \"S2\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "667242a0-f17c-455f-98a2-fb56113ba758",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "f7cad1b3-79a9-44c0-a241-94aa5d3d3b5e",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "# Estimators"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32b8005e",
   "metadata": {},
   "source": [
    "## Fit all the nuisance models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5bb46396-8257-490a-83b3-7bd9da9865a1",
   "metadata": {},
   "source": [
    "- first stage\n",
    "  - $\\text{logit}\\left( \\varphi_1(X_1) \\right) = 0.5 + 0.5X_1^2$\n",
    "  - $\\text{logit}\\left( K_1^{A_1}( X_1) \\right) = X_1^2 + \\eta_1$\n",
    "      for fixed $\\eta_1 = 2.5$\n",
    "  - $\\text{logit}\\left( p_1^{A_1}( X_1) \\right) = 3X_1^2 + 5A_1 - 0.5 A_1X_1$\n",
    "- second stage\n",
    "  - $\\text{logit}\\left( \\varphi_2(\\overline{ X}_2, A_1) \\right) = 0.7 - 0.5X_{1}^2 + 0.5X_{2} -0.1 X_2^2$\n",
    "  - $\\text{logit}\\left( K_2^{A_1A_2}( \\overline X_2) \\right) = -3 + X_1 + X_2 + 0.5 A_2 + A_2X_2 + \\eta_2$ for fixed $\\eta_2=4$\n",
    "  - $\\text{logit}\\left( p_2^{A_1A_2}(\\overline{X}_2) \\right) = 0.5 + 2 X_1 + X_1X_2 -0.8 A_1 + 0.65 A_2$\n",
    "- outcome\n",
    "  - $\\mu_2^{A_1A_2}(\\overline{\\mathbf X}_2) = -3 + X_1 + 1.5A_1 -0.5 A_1X_1 + \\exp(X_2)/100 + A_2(1.5 + A_1 -0.5 X_2)$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ffb4eac-7653-4497-96d8-0765dbe9ab24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Silence SettingWithCopyWarning from here on; print arrays with 3 decimals.\n",
    "pd.set_option(\"mode.chained_assignment\", None)\n",
    "np.set_printoptions(precision=3, suppress=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b5296a5-2604-4510-9b99-e6d72be63854",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_models(data,\n",
    "               phi1_false=False, phi2_false=False,\n",
    "               K1_false=False, K2_false=False,\n",
    "               p1_false=False, p2_false=False,\n",
    "               mu2_false=False, Ep2_false=False,\n",
    "               Emupi_false=False):\n",
    "    \"\"\"\n",
    "    Fit the nuisance models on `data` and build prediction closures over it.\n",
    "\n",
    "    Args:\n",
    "        data: pandas DataFrame with the observed data (columns used:\n",
    "            X1, A1, C1, S1, X2, A2, C2, S2, Y).\n",
    "        phi1_false, phi2_false, K1_false, K2_false, p1_false, p2_false:\n",
    "            when True, fit the deliberately misspecified formula listed for\n",
    "            that model in `binomial_specs` instead of the correct one.\n",
    "        mu2_false: when True, fit the misspecified outcome regression\n",
    "            'Y ~ 1 + X1 + X2' instead of the correct one.\n",
    "        Ep2_false, Emupi_false: not used for fitting; returned as\n",
    "            'Ep2.false' / 'Emupi.false' and read by m_p200, g_ and m_mu_pi\n",
    "            to switch to an intercept-only regression.\n",
    "\n",
    "    Returns:\n",
    "        dict with the fitted statsmodels results ('phi1.hat', 'phi2.hat',\n",
    "        'K1.hat', 'K2.hat', 'p1.hat', 'p2.hat', 'mu2.hat'), prediction\n",
    "        closures evaluated on the rows of `data` ('ps1', 'ps2', 'cp1',\n",
    "        'cp2', 'sp1', 'sp2', 'm2', 'pcs1pcs2', 'pcs1pc2', 'pcs1', 'pc1')\n",
    "        and the two flags above.\n",
    "    \"\"\"\n",
    "    # model key -> (misspecify?, correct formula, misspecified formula)\n",
    "    binomial_specs = {\n",
    "        'phi1.hat': (phi1_false, 'A1 ~ 1 + I(X1**2)', 'A1 ~ 0 + X1 + I(X1**2)'),\n",
    "        'phi2.hat': (phi2_false, 'A2 ~ 1 + I(X1**2) + X2 + I(X2**2)', 'A2 ~ 1 + X2'),\n",
    "        'K1.hat': (K1_false, 'I(1-C1) ~ 1 + I(X1**2)', 'I(1-C1) ~ 0 + X1'),\n",
    "        'K2.hat': (K2_false, 'I(1-C2) ~ 1 + X1 + X2 + A2 + A2:X2', 'I(1-C2) ~ 1 + X2'),\n",
    "        'p1.hat': (p1_false, 'S1 ~ 1 + I(X1**2) + A1 + A1:X1', 'S1 ~ 0 + A1'),\n",
    "        'p2.hat': (p2_false, 'S2 ~ 1 + X1 + X1:X2 + A1 + A2', 'S2 ~ 0 + A1 + A2'),\n",
    "    }\n",
    "    models = {}\n",
    "    for key, (misspecify, correct, wrong) in binomial_specs.items():\n",
    "        models[key] = smf.glm(wrong if misspecify else correct, data=data,\n",
    "                              family=sm.families.Binomial()).fit()\n",
    "\n",
    "    # Outcome regression (OLS).\n",
    "    mu2_formula = ('Y ~ 1 + X1 + X2' if mu2_false else\n",
    "                   'Y ~ 1 + X1 + A1 + A1:X1 + I(np.exp(X2)) + A2 + A1:A2 + X2:A2')\n",
    "    models['mu2.hat'] = smf.ols(mu2_formula, data=data).fit()\n",
    "\n",
    "    def _as_rows(a):\n",
    "        # Repeat an int treatment value once per row of `data`.\n",
    "        return np.repeat(a, len(data.X1)) if isinstance(a, int) else a\n",
    "\n",
    "    def _prob_of(pred, a):\n",
    "        # Models predict P(A = 1 | .); flip to P(A = 0 | .) on rows where a == 0.\n",
    "        zero = (~np.isnan(a)) & (a == 0)\n",
    "        pred[zero] = 1 - pred[zero]\n",
    "        return pred\n",
    "\n",
    "    # Treatment probabilities (phi models).\n",
    "    def ps1(a1):\n",
    "        a1 = _as_rows(a1)\n",
    "        fitdf = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'], 'A1': a1})\n",
    "        return _prob_of(models['phi1.hat'].predict(fitdf).values, a1)\n",
    "\n",
    "    def ps2(a1, a2):\n",
    "        a1, a2 = _as_rows(a1), _as_rows(a2)\n",
    "        fitdf = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'],\n",
    "                              'A1': a1, 'A2': a2})\n",
    "        return _prob_of(models['phi2.hat'].predict(fitdf).values, a2)\n",
    "\n",
    "    # Non-censoring probabilities (K models for 1 - C).\n",
    "    def cp1(a1):\n",
    "        fitdf = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'], 'A1': a1})\n",
    "        return models['K1.hat'].predict(fitdf).values\n",
    "\n",
    "    def cp2(a1, a2):\n",
    "        fitdf = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'],\n",
    "                              'A1': a1, 'A2': a2})\n",
    "        return models['K2.hat'].predict(fitdf).values\n",
    "\n",
    "    # Survival probabilities (p models for S).\n",
    "    def sp1(a1):\n",
    "        fitdf = pd.DataFrame({'const': 1, 'X1': data.X1, 'A1': a1})\n",
    "        return models['p1.hat'].predict(fitdf).values\n",
    "\n",
    "    def sp2(a1, a2):\n",
    "        fitdf = pd.DataFrame({'const': 1, 'X1': data.X1, 'X2': data.X2,\n",
    "                              'A1': a1, 'A2': a2})\n",
    "        return models['p2.hat'].predict(fitdf).values\n",
    "\n",
    "    # Products of treatment, non-censoring and survival probabilities.\n",
    "    def pcs1pcs2(a1, a2):\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1) * ps2(a1, a2) * cp2(a1, a2) * sp2(a1, a2)\n",
    "\n",
    "    def pcs1pc2(a1, a2):\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1) * ps2(a1, a2) * cp2(a1, a2)\n",
    "\n",
    "    def pcs1(a1):\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1)\n",
    "\n",
    "    def pc1(a1):\n",
    "        return ps1(a1) * cp1(a1)\n",
    "\n",
    "    # Outcome regression predictions.\n",
    "    def m2(a1, a2):\n",
    "        fitdf = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'],\n",
    "                              'A1': a1, 'A2': a2})\n",
    "        return models['mu2.hat'].predict(fitdf).values\n",
    "\n",
    "    return {**models,\n",
    "            'ps1': ps1, 'ps2': ps2,\n",
    "            'cp1': cp1, 'cp2': cp2,\n",
    "            'sp1': sp1, 'sp2': sp2, 'm2': m2,\n",
    "            'pcs1pcs2': pcs1pcs2, 'pcs1pc2': pcs1pc2, 'pcs1': pcs1, 'pc1': pc1,\n",
    "            'Ep2.false': Ep2_false, 'Emupi.false': Emupi_false}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e6285b4",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "# `models` uses the correctly specified formulas; every m_* variant refits\n",
    "# with the listed *_false flags of fit_models switched on.\n",
    "models = fit_models(XX_obs15)\n",
    "m_123457 = fit_models(XX_obs15, mu2_false=True)\n",
    "m_12345 = fit_models(XX_obs15, Emupi_false=True, mu2_false=True)\n",
    "m_12367 = fit_models(XX_obs15, Ep2_false=True, p2_false=True)\n",
    "\n",
    "# NOTE(review): m_1234 uses exactly the same flags as m_123457 -- confirm intent.\n",
    "m_1234 = fit_models(XX_obs15, mu2_false=True)\n",
    "m_1236 = fit_models(XX_obs15, p2_false=True)\n",
    "\n",
    "m_1235 = fit_models(XX_obs15, Emupi_false=True)\n",
    "m_1237 = fit_models(XX_obs15, Ep2_false=True)\n",
    "\n",
    "m_4567 = fit_models(\n",
    "    XX_obs15,\n",
    "    phi1_false=True, phi2_false=True,\n",
    "    K1_false=True, K2_false=True,\n",
    "    p1_false=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57b5be2f",
   "metadata": {},
   "source": [
    "## Plug-in estimator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d76f2351-145c-4e07-b6b5-4c023ae5bda9",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "source": [
    "### Plugin for D\n",
    "\n",
    "$$\n",
    "D = \\mathbb E \\bigg[ p_1^0(X_1) \\mathbb E\\big( p_2^0(H_1) \\big| X_1, A_1=0, C_1=0, S_1=1 \\big) \\bigg],\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "989af59c-9ef9-4386-a6de-04219704efdb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def m_p200(data, models):\n",
    "    \"\"\"\n",
    "    Regress the stage-2 survival pseudo-outcome on X1.\n",
    "\n",
    "    The pseudo-outcome is 1{A1 = 0, C1 = 0, S1 = 1} / pcs1(0) * sp2(0, 0),\n",
    "    with every nuisance prediction evaluated on the rows of `data`; its\n",
    "    regression on X1 serves as E[p_2^0(H_1) | X_1, A_1=0, C_1=0, S_1=1].\n",
    "\n",
    "    Args:\n",
    "        data: pandas DataFrame of observed data; NaNs are replaced by 0\n",
    "            before the regression.\n",
    "        models: dict from fit_models() or true_models(). If\n",
    "            models['Ep2.false'] is True an intercept-only OLS is fitted,\n",
    "            otherwise a GLMGam with a cubic B-spline in X1 plus X1, X1**2.\n",
    "\n",
    "    Returns:\n",
    "        Fitted values, one per row of `data`.\n",
    "    \"\"\"\n",
    "    use_intercept_only = models['Ep2.false']\n",
    "\n",
    "    def sp2(a1, a2):\n",
    "        frame = pd.DataFrame({'const': 1, 'X1': data.X1, 'X2': data.X2,\n",
    "                              'A1': a1, 'A2': a2})\n",
    "        return models['p2.hat'].predict(frame).values\n",
    "\n",
    "    def ps1(a1):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, len(data.X1))\n",
    "        frame = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'], 'A1': a1})\n",
    "        prob = models['phi1.hat'].predict(frame).values\n",
    "        zero = (~np.isnan(a1)) & (a1 == 0)\n",
    "        prob[zero] = 1 - prob[zero]\n",
    "        return prob\n",
    "\n",
    "    def cp1(a1):\n",
    "        frame = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'], 'A1': a1})\n",
    "        return models['K1.hat'].predict(frame).values\n",
    "\n",
    "    def sp1(a1):\n",
    "        frame = pd.DataFrame({'const': 1, 'X1': data.X1, 'A1': a1})\n",
    "        return models['p1.hat'].predict(frame).values\n",
    "\n",
    "    def pcs1(a1):\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1)\n",
    "\n",
    "    # Missing values (rows past censoring/death) are treated as 0.\n",
    "    data_filled = data.fillna(0)\n",
    "\n",
    "    p200x = sp2(0, 0)\n",
    "    eligible = (data_filled.A1 == 0) * (data_filled.C1 == 0) * (data_filled.S1 == 1)\n",
    "    data_filled['target'] = eligible / pcs1(0) * p200x\n",
    "    data_filled = data_filled.fillna(0)\n",
    "\n",
    "    if use_intercept_only:\n",
    "        m_p200_model = smf.ols('target ~ 1', data=data_filled).fit()\n",
    "        return m_p200_model.predict(data_filled)\n",
    "\n",
    "    # Cubic B-spline in X1 with df = 3 + 3 + 1, alongside a quadratic in X1.\n",
    "    bs = BSplines(data_filled.X1, df=[7], degree=[3])\n",
    "    m_p200_model = smf.glmgam('target ~ 1 + X1 + I(X1**2)', data=data_filled,\n",
    "                              smoother=bs).fit()\n",
    "    return m_p200_model.predict(data_filled, exog_smooth=bs.x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73e45080-4831-4428-bbc6-f8fe195ce16f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def D_hat(data, models):\n",
    "    \"\"\"\n",
    "    Plug-in estimate of D = E[ p_1^0(X_1) E( p_2^0(H_1) | X_1, A_1=0, C_1=0, S_1=1 ) ].\n",
    "\n",
    "    Args:\n",
    "        data: pandas DataFrame of observed data on which D is evaluated.\n",
    "        models: dict from fit_models() or true_models(); 'p1.hat' is\n",
    "            evaluated on `data`, and m_p200 uses the other entries\n",
    "            (including the 'Ep2.false' flag).\n",
    "\n",
    "    Returns:\n",
    "        float, the mean over rows of p_1^0(X1) * E_p200 with NaN terms set to 0.\n",
    "    \"\"\"\n",
    "    # Evaluate p_1^0 on `data` itself rather than through models['sp1'], which\n",
    "    # is bound to the frame fit_models() was called with; m_p200 below is also\n",
    "    # evaluated on `data`, so the two factors line up row by row.\n",
    "    p1_frame = pd.DataFrame({'const': 1, 'X1': data.X1, 'A1': 0})\n",
    "    p10x = models['p1.hat'].predict(p1_frame).values\n",
    "\n",
    "    # Conditional mean of p_2^0(H_1) given X1 among A1=0, C1=0, S1=1.\n",
    "    E_p200 = m_p200(data, models)\n",
    "\n",
    "    Dhat = p10x * E_p200\n",
    "    Dhat[np.isnan(Dhat)] = 0\n",
    "    return np.mean(Dhat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3382ace-52a1-4e03-9033-e53e19d0182f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# D-hat under the correctly specified models and three misspecified variants.\n",
    "tuple(D_hat(XX_obs15, fitted) for fitted in (models, m_123457, m_12345, m_4567))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "12cb1ea8-8282-4a56-90e3-b26dc6e0b3dc",
   "metadata": {},
   "source": [
    "### Plugin for N\n",
    "$$\n",
    "N = \\mathbb E \\bigg[ g(X_1) p_1^0(X_1) \\mathbb E\\big( p_2^0(H_1) \\big| X_1, A_1=0, C_1=0, S_1=1 \\big) \\bigg],\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\begin{align*}\n",
    "g(X_1)\n",
    " &= \\sum_{a_1,a_2 \\in \\{0,1\\}} m_{\\mu_2, \\pi}(x_1, \\bar a) \\mathbf1\\{\\pi_1(X_1) = a_1\\} \\\\\n",
    " &= \\sum_{a_1,a_2 \\in \\{0,1\\}} \\mathbb E \\bigg[ \\mu_2^{a_2}(H_1) \\mathbf1\\{\\pi_2(H_1)=a_2\\} \\bigg| X_1, A_1=a_1, C_1=0, S_1=1 \\bigg] \\mathbf1\\{\\pi_1(X_1) = a_1\\} \\\\\n",
    " &= \\sum_{a_1 \\in \\{0,1\\}} \\mathbb E \\bigg[ \\sum_{a_2} \\mu_2^{a_2}(H_1) \\mathbf1\\{\\pi_2(H_1)=a_2\\} \\bigg| X_1, A_1=a_1, C_1=0, S_1=1 \\bigg] \\mathbf1\\{\\pi_1(X_1) = a_1\\}\n",
    "\\end{align*}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9d787c7-c5c0-4d81-93cd-cae88b0991b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def g_(d1, d2, data, models):\n",
    "    \"\"\"\n",
    "    Estimate g(X1) for the regime with per-row decisions (d1, d2).\n",
    "\n",
    "    The pseudo-outcome 1{A1 = d1, C1 = 0, S1 = 1} / pcs1(d1) * m2(d1, d2) is\n",
    "    regressed on X1, with every nuisance prediction evaluated on the rows\n",
    "    of `data`.\n",
    "\n",
    "    Args:\n",
    "        d1, d2: stage-1 / stage-2 decisions, one per row of `data`.\n",
    "        data: pandas DataFrame of observed data; NaNs are replaced by 0\n",
    "            before the regression.\n",
    "        models: dict from fit_models() or true_models(). If\n",
    "            models['Emupi.false'] is True an intercept-only OLS is fitted,\n",
    "            otherwise a GLMGam with a cubic B-spline in X1 plus X1, X1**2.\n",
    "\n",
    "    Returns:\n",
    "        Fitted values, one per row of `data`.\n",
    "    \"\"\"\n",
    "    use_intercept_only = models['Emupi.false']\n",
    "\n",
    "    def m2(a1, a2):\n",
    "        frame = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'],\n",
    "                              'A1': a1, 'A2': a2})\n",
    "        return models['mu2.hat'].predict(frame).values\n",
    "\n",
    "    def ps1(a1):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, len(data.X1))\n",
    "        frame = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'], 'A1': a1})\n",
    "        prob = models['phi1.hat'].predict(frame).values\n",
    "        zero = (~np.isnan(a1)) & (a1 == 0)\n",
    "        prob[zero] = 1 - prob[zero]\n",
    "        return prob\n",
    "\n",
    "    def cp1(a1):\n",
    "        frame = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'], 'A1': a1})\n",
    "        return models['K1.hat'].predict(frame).values\n",
    "\n",
    "    def sp1(a1):\n",
    "        frame = pd.DataFrame({'const': 1, 'X1': data.X1, 'A1': a1})\n",
    "        return models['p1.hat'].predict(frame).values\n",
    "\n",
    "    def pcs1(a1):\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1)\n",
    "\n",
    "    # Missing values (rows past censoring/death) are treated as 0.\n",
    "    data_filled = data.fillna(0)\n",
    "\n",
    "    followed = (data_filled.A1 == d1) * (data_filled.C1 == 0) * (data_filled.S1 == 1)\n",
    "    data_filled['target'] = followed / pcs1(d1) * m2(d1, d2)\n",
    "    data_filled = data_filled.fillna(0)\n",
    "\n",
    "    if use_intercept_only:\n",
    "        m_m2_model = smf.ols('target ~ 1', data=data_filled).fit()\n",
    "        return m_m2_model.predict(data_filled)\n",
    "\n",
    "    # Cubic B-spline in X1 with df = 3 + 3 + 1, alongside a quadratic in X1.\n",
    "    bs = BSplines(data_filled['X1'], df=[7], degree=[3])\n",
    "    m_m2_model = smf.glmgam('target ~ 1 + X1 + I(X1**2)', data=data_filled,\n",
    "                            smoother=bs).fit()\n",
    "    return m_m2_model.predict(data_filled, exog_smooth=bs.x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c0a4d23-a18e-48e2-8e6a-e5c441fb122e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pygam import LinearGAM, s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3748666-ac98-4864-b88c-5154818b445b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def m_mu_pi(d1, d2, a1, a2, data, models):\n",
    "    \"\"\"\n",
    "    Regress the pseudo-outcome of treatment path (a1, a2) on X1.\n",
    "\n",
    "    The pseudo-outcome is\n",
    "        1{A1 = a1, C1 = 0, S1 = 1} / pcs1(a1) * m2(a1, a2) * 1{d2 = a2},\n",
    "    with every nuisance prediction evaluated on the rows of `data`.\n",
    "\n",
    "    Args:\n",
    "        d1: stage-1 decisions; not used in the computation.\n",
    "        d2: stage-2 decisions, compared elementwise with a2.\n",
    "        a1, a2: treatment values defining the path.\n",
    "        data: pandas DataFrame of observed data; NaNs are replaced by 0.\n",
    "        models: dict from fit_models() or true_models(). If\n",
    "            models['Emupi.false'] is True an intercept-only OLS is fitted,\n",
    "            otherwise a pygam LinearGAM.\n",
    "\n",
    "    Returns:\n",
    "        Fitted values, one per row of `data`. If the pseudo-outcome is\n",
    "        constant (e.g. all zeros because d2 never equals a2), nothing is\n",
    "        fitted and that constant is returned for every row.\n",
    "    \"\"\"\n",
    "    Emupi_false = models['Emupi.false']\n",
    "\n",
    "    def m2(b1, b2):\n",
    "        fitdf = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'],\n",
    "                              'A1': b1, 'A2': b2})\n",
    "        return models['mu2.hat'].predict(fitdf).values\n",
    "\n",
    "    def ps1(b1):\n",
    "        if isinstance(b1, int):\n",
    "            b1 = np.repeat(b1, len(data.X1))\n",
    "        fitdf = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'], 'A1': b1})\n",
    "        pred = models['phi1.hat'].predict(fitdf).values\n",
    "        zero = (~np.isnan(b1)) & (b1 == 0)\n",
    "        pred[zero] = 1 - pred[zero]\n",
    "        return pred\n",
    "\n",
    "    def cp1(b1):\n",
    "        fitdf = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'], 'A1': b1})\n",
    "        return models['K1.hat'].predict(fitdf).values\n",
    "\n",
    "    def sp1(b1):\n",
    "        fitdf = pd.DataFrame({'const': 1, 'X1': data.X1, 'A1': b1})\n",
    "        return models['p1.hat'].predict(fitdf).values\n",
    "\n",
    "    def pcs1(b1):\n",
    "        return ps1(b1) * cp1(b1) * sp1(b1)\n",
    "\n",
    "    # Missing values (rows past censoring/death) are treated as 0.\n",
    "    data_filled = data.fillna(0)\n",
    "\n",
    "    at_risk = (data_filled.A1 == a1) * (data_filled.C1 == 0) * (data_filled.S1 == 1)\n",
    "    data_filled['target'] = at_risk / pcs1(a1) * m2(a1, a2) * (d2 == a2)\n",
    "    data_filled = data_filled.fillna(0)\n",
    "\n",
    "    # The pseudo-outcome is a continuous weighted outcome of either sign, so a\n",
    "    # mean of 0 or 1 says nothing about degeneracy; only a constant target\n",
    "    # leaves nothing to smooth.\n",
    "    target = data_filled['target'].to_numpy()\n",
    "    if target.size == 0 or np.ptp(target) == 0:\n",
    "        return np.full(target.shape[0], target[0] if target.size else 0.0)\n",
    "\n",
    "    if Emupi_false:\n",
    "        m_m2_model = smf.ols('target ~ 1', data=data_filled).fit()\n",
    "        return m_m2_model.predict(data_filled)\n",
    "\n",
    "    # NOTE(review): s(0) smooths column 0 (X1) only; the X1**2 column is\n",
    "    # passed to the GAM but not referenced by any term.\n",
    "    x1_x1sq = np.c_[data_filled.X1, data_filled.X1**2]\n",
    "    m_m2_model = LinearGAM(s(0)).fit(x1_x1sq, data_filled.target)\n",
    "    return m_m2_model.predict(x1_x1sq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b253a53b-1170-4a45-9dda-a7ebee079704",
   "metadata": {},
   "outputs": [],
   "source": [
    "def N_hat(d1, d2, data, models):\n",
    "    \"\"\"\n",
    "    Plug-in estimate of N = E[ g(X_1) p_1^0(X_1) E( p_2^0(H_1) | X_1, A_1=0, C_1=0, S_1=1 ) ].\n",
    "\n",
    "    Args:\n",
    "        d1: stage-1 decisions, one per row of `data` (e.g. from reg1).\n",
    "        d2: stage-2 decisions, one per row of `data` (e.g. from reg2).\n",
    "        data: pandas DataFrame of observed data on which N is evaluated.\n",
    "        models: dict from fit_models() or true_models().\n",
    "\n",
    "    Returns:\n",
    "        float, the mean over rows of g(X1) * p_1^0(X1) * E_p200 with NaN\n",
    "        terms set to 0.\n",
    "    \"\"\"\n",
    "    # Evaluate p_1^0 on `data` itself rather than through models['sp1'], which\n",
    "    # is bound to the frame fit_models() was called with; m_p200 and g_ below\n",
    "    # are also evaluated on `data`, so all factors line up row by row.\n",
    "    p1_frame = pd.DataFrame({'const': 1, 'X1': data.X1, 'A1': 0})\n",
    "    p10x = models['p1.hat'].predict(p1_frame).values\n",
    "\n",
    "    # Conditional mean of p_2^0(H_1) given X1 among A1=0, C1=0, S1=1.\n",
    "    E_p200 = m_p200(data, models)\n",
    "\n",
    "    # g(X1) for the regime (d1, d2).\n",
    "    g = g_(d1, d2, data, models)\n",
    "\n",
    "    Nhat = g * p10x * E_p200\n",
    "    Nhat[np.isnan(Nhat)] = 0\n",
    "    return np.mean(Nhat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fadae74b-ec0b-47f3-9418-7439f90491a4",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "# Decisions of the true regime on the observed sample.\n",
    "d1_true = reg1(eta_true[0], eta_true[1], XX_obs15)\n",
    "d2_true = reg2(eta_true[2], eta_true[3], eta_true[4], XX_obs15)\n",
    "\n",
    "# Share of subjects assigned treatment 1 at each stage.\n",
    "d1_true.mean(), d2_true.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a561769-a42d-4e87-90aa-f4f60eefc9d0",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "# N-hat for the true regime under correct and misspecified nuisance models.\n",
    "tuple(N_hat(d1_true, d2_true, XX_obs15, fitted) for fitted in (models, m_4567))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "63f86b86-46aa-497a-9c48-fdee60f717cf",
   "metadata": {},
   "source": [
    "### Plugin for V=N/D\n",
    "\n",
    "$$\n",
    "V_{1111}(\\pi) = \\mathbb E \\left[ g(X_1) \\frac{ p_1^0(X_1) \\mathbb E\\left( p_2^0(H_1) | X_1, A_1=0, C_1=0, S_1=1 \\right) }{ \\mathbb E\\left[ p_1^0(X_1) \\mathbb E\\left( p_2^0(H_1) | X_1, A_1=0, C_1=0, S_1=1 \\right) \\right] } \\right],\n",
    "$$\n",
    "\n",
    "$$\n",
    "g(X_1) = \\sum_{a_1,a_2 \\in \\{0,1\\}} \\mathbb E \\left[ \\mu_2^{a_2}(H_1) \\mathbf1\\{\\pi_2(H_1)=a_2\\} | X_1, A_1=a_1, C_1=0, S_1=1 \\right] \\mathbf1\\{\\pi_1(X_1) = a_1\\}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb624d62",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "def V_plugin(d1, d2, data, models):\n",
    "    \"\"\"\n",
    "    Plug-in value V = N / D of the regime with decisions (d1, d2).\n",
    "\n",
    "    Args:\n",
    "        d1, d2: stage-1 / stage-2 decisions per row of `data`\n",
    "            (e.g. from reg1 / reg2).\n",
    "        data: pandas DataFrame of observed data.\n",
    "        models: dict from fit_models() or true_models().\n",
    "\n",
    "    Returns:\n",
    "        float, N_hat(d1, d2, data, models) / D_hat(data, models).\n",
    "    \"\"\"\n",
    "    numerator = N_hat(d1, d2, data, models)\n",
    "    denominator = D_hat(data, models)\n",
    "    return numerator / denominator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "715b3b99-7ff2-4d33-807c-d9485af63024",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plug-in value of the true regime with the correctly specified models.\n",
    "V_hat = V_plugin(d1=d1_true, d2=d2_true, data=XX_obs15, models=models)\n",
    "V_hat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25bb38c9-a37a-466f-ad56-30385b39b138",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e7f7155-b8ae-4c85-8e00-5963b8fee7a3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
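  {
   "cell_type": "markdown",
   "id": "44c02631-117b-40e1-89bb-e56f2b4ece41",
   "metadata": {},
   "source": [
    "## \"True\" value from the identification formula and the true models\n",
    "- Plug-in estimator evaluated on the large simulated sample (`df_large`) with the true nuisance functions and the flexible (GAM) conditional-mean regressions."
   ]
  },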
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb72f170-92ff-408e-88ef-0b753478f94d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def true_models(data):\n",
    "    \"\"\"\n",
    "    Build a fit_models-style dict from the true nuisance functions.\n",
    "\n",
    "    The '*.hat' entries wrap phi1, phi2, K1, K2, p1, p2 and mu2 (defined\n",
    "    elsewhere in this notebook) in objects exposing predict(fitdf); the\n",
    "    closures evaluate the same functions directly on the rows of `data`.\n",
    "    Both 'Ep2.false' and 'Emupi.false' are False.\n",
    "    \"\"\"\n",
    "    class _TrueModel:\n",
    "        # predict(fitdf) returns pd.Series(func(fitdf.<col> for each col)).\n",
    "        def __init__(self, md, *cols):\n",
    "            self.model_func = md\n",
    "            self.cols = cols\n",
    "\n",
    "        def predict(self, fitdf):\n",
    "            return pd.Series(self.model_func(*[getattr(fitdf, c) for c in self.cols]))\n",
    "\n",
    "    models = {\n",
    "        'phi1.hat': _TrueModel(phi1, 'X1'),\n",
    "        'phi2.hat': _TrueModel(phi2, 'X1', 'X2'),\n",
    "        'K1.hat': _TrueModel(K1, 'X1', 'A1'),\n",
    "        # NOTE(review): predict() calls K2 without eta2, whereas cp2 below\n",
    "        # passes eta2 explicitly -- confirm K2's signature makes them agree.\n",
    "        'K2.hat': _TrueModel(K2, 'X1', 'X2', 'A1', 'A2'),\n",
    "        'p1.hat': _TrueModel(p1, 'X1', 'A1'),\n",
    "        'p2.hat': _TrueModel(p2, 'X1', 'X2', 'A1', 'A2'),\n",
    "        'mu2.hat': _TrueModel(mu2, 'X1', 'X2', 'A1', 'A2'),\n",
    "    }\n",
    "\n",
    "    def as_rows(a):\n",
    "        # Repeat an int treatment value once per row of `data`.\n",
    "        return np.repeat(a, len(data.X1)) if isinstance(a, int) else a\n",
    "\n",
    "    # Treatment probabilities: phi gives P(A = 1 | .), flipped where a == 0.\n",
    "    def ps1(a1):\n",
    "        a1 = as_rows(a1)\n",
    "        pred = phi1(data['X1'])\n",
    "        zero = (~np.isnan(a1)) & (a1 == 0)\n",
    "        pred[zero] = 1 - pred[zero]\n",
    "        return pred\n",
    "\n",
    "    def ps2(a1, a2):\n",
    "        a2 = as_rows(a2)\n",
    "        pred = phi2(data.X1, data.X2)\n",
    "        zero = (~np.isnan(a2)) & (a2 == 0)\n",
    "        pred[zero] = 1 - pred[zero]\n",
    "        return pred\n",
    "\n",
    "    # Non-censoring and survival probabilities.\n",
    "    def cp1(a1):\n",
    "        return K1(data.X1, as_rows(a1))\n",
    "\n",
    "    def cp2(a1, a2):\n",
    "        return K2(data.X1, data.X2, as_rows(a1), as_rows(a2), eta2)\n",
    "\n",
    "    def sp1(a1):\n",
    "        return p1(data.X1, as_rows(a1))\n",
    "\n",
    "    def sp2(a1, a2):\n",
    "        return p2(data.X1, data.X2, as_rows(a1), as_rows(a2))\n",
    "\n",
    "    # Products of treatment, non-censoring and survival probabilities.\n",
    "    def pcs1pcs2(a1, a2):\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1) * ps2(a1, a2) * cp2(a1, a2) * sp2(a1, a2)\n",
    "\n",
    "    def pcs1pc2(a1, a2):\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1) * ps2(a1, a2) * cp2(a1, a2)\n",
    "\n",
    "    def pcs1(a1):\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1)\n",
    "\n",
    "    def pc1(a1):\n",
    "        return ps1(a1) * cp1(a1)\n",
    "\n",
    "    # True outcome regression.\n",
    "    def m2(a1, a2):\n",
    "        return mu2(data.X1, data.X2, as_rows(a1), as_rows(a2))\n",
    "\n",
    "    return {**models,\n",
    "            'ps1': ps1, 'ps2': ps2,\n",
    "            'cp1': cp1, 'cp2': cp2,\n",
    "            'sp1': sp1, 'sp2': sp2, 'm2': m2,\n",
    "            'pcs1pcs2': pcs1pcs2, 'pcs1pc2': pcs1pc2, 'pcs1': pcs1, 'pc1': pc1,\n",
    "            'Ep2.false': False, 'Emupi.false': False}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72556fbf-3f6f-4a3c-a544-c6c43c8f371b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the true (+ flexible) nuisance models on the large sample and cache them to disk\n",
    "tm = true_models(\n",
    "    df_large[\"obs\"]\n",
    ")\n",
    "write_pkl(tm, \"tm_df_large_sim2.pkl\")"
   ]
  },
  {
   "cell_type": "raw",
   "id": "696b8ec6-dd1a-49d3-8ee3-2faed8213b3c",
   "metadata": {},
   "source": [
    "tm = read_pkl(\"tm_df_large_sim2.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee76f62d-ac05-4f4a-a7ab-59fa8d42875e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Denominator D evaluated on the large sample with the models in `tm`\n",
    "D_true_id = D_hat(\n",
    "    df_large[\"obs\"], tm\n",
    ")\n",
    "D_true_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "894f0d19-9ca1-4cbc-874f-bb2d9edf416e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed decision rules from the true parameters: eta_true[:2] for stage 1, eta_true[2:] for stage 2\n",
    "d1_id = reg1(*eta_true[:2], df_large[\"obs\"])\n",
    "d2_id = reg2(*eta_true[2:], df_large[\"obs\"])\n",
    "\n",
    "# Numerator N at these rules, evaluated with the models in `tm`\n",
    "N_true_id = N_hat(d1_id, d2_id, df_large[\"obs\"], tm)\n",
    "N_true_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b412e34d-2b70-4a09-b1e4-9695af8da1e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Value of the fixed regime: V = N / D\n",
    "V_true_id = N_true_id / D_true_id\n",
    "V_true_id"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db2abb54",
   "metadata": {},
   "source": [
    "## SD-MR (multiply robust)\n",
    "$$\\hat V_\\text{SD-MR}(\\pi) = \\frac{ \\mathbb P_n \\hat{\\mathcal V}_N(O) }{ \\mathbb P_n \\hat{\\mathcal V}_D(O) }$$\n",
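    "where $\\mathbb{P}_n$ is the empirical mean and $\\hat{\\mathcal V}_N(O)$, $\\hat{\\mathcal V}_D(O)$ are the estimated EIF-based pseudo-outcomes whose means estimate the numerator $N$ and the denominator $D$ (defined below)."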
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7109d746",
   "metadata": {},
   "source": [
    "### D-MR\n",
    "* EIF-based estimator $\\mathbb P_n \\hat{\\mathcal V}_D(O)$ where\n",
    "$$\n",
    "\\begin{align}\n",
    "{\\mathcal V}_D(O)\n",
    " &= \n",
    "   \\frac{(1-A_1)(1-C_1)}{(\\varphi_1K_1)^0(X_1)} \\left\\{S_1 - p_1^0(X_1)\\right\\} m_{p_2}^{00}(X_1) \\\\\n",
    " &\\quad+\n",
    "   p_1^0(X_1) \\bigg[ \\frac{(1-A_1)(1-A_2)(1-C_1)(1-C_2)S_1}{(\\varphi_1^0K_1^0p_1^0)(X_1)(\\varphi_2^0K_2^0)(H_1)} \\big(S_2 - p_2^0(H_1)\\big) \\\\\n",
    " &\\quad\\qquad\\qquad+ \\frac{(1-A_1)(1-C_1)S_1}{(\\varphi_1^0K_1^0p_1^0)(X_1)} \\left\\{ p_2^0(H_1) - m_{p_2}^{00}(X_1) \\right\\} \\bigg] \\\\\n",
    " &\\quad+ p_1^0(X_1) m_{p_2}^{00}(X_1)\n",
    "\\end{align}\n",
    "$$\n",
    "and $$m_{p_2}^{a_1,a_2}(x_1) = \\mathbb E[ p_2^{a_2}(H_1) | X_1=x_1, A_1=a_1, C_1=0, S_1=1 ].$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95fa6ecd-e637-43a7-a70a-83aebcafe102",
   "metadata": {},
   "outputs": [],
   "source": [
    "def phi_D_terms(data, models):\n",
    "    \"\"\"\n",
    "    Per-observation terms of the efficient influence function (EIF) based\n",
    "    estimator of the denominator D (see the D-MR formula above).\n",
    "\n",
    "    Args:\n",
    "        data: pandas DataFrame with columns A1, A2, C1, C2, S1, S2 (NaN allowed).\n",
    "            The prediction functions in `models` evaluate on the sample they were\n",
    "            built from, so `data` is expected to be that same sample.\n",
    "        models: dict of nuisance models and prediction functions (as returned by\n",
    "            fit_models / true_models); uses 'sp1', 'sp2', 'pc1', 'pcs1', 'pcs1pc2'\n",
    "            and whatever m_p200 reads.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (D1, D2, D3, D4) of per-observation terms; their sum is the\n",
    "        uncentered EIF of D (see phi_D). NaNs in the weighted terms are set to 0.\n",
    "    \"\"\"\n",
    "    sp1 = models['sp1']; sp2 = models['sp2']\n",
    "    pcs1pc2 = models['pcs1pc2']\n",
    "    pcs1 = models['pcs1']\n",
    "    pc1 = models['pc1']\n",
    "\n",
    "    # Zero-fill missing values so the indicator products below vanish\n",
    "    data_filled = data.fillna(0)\n",
    "\n",
    "    # Nuisance predictions at the reference regime (a1, a2) = (0, 0)\n",
    "    p10x = sp1(0)                   # p_1^0(X_1)\n",
    "    p200x = sp2(0, 0)               # p_2^{00}(H_1)\n",
    "    E_p200 = m_p200(data, models)   # m_{p_2}^{00}(X_1)\n",
    "\n",
    "    A1 = data_filled.A1; A2 = data_filled.A2; C1 = data_filled.C1\n",
    "    C2 = data_filled.C2; S1 = data_filled.S1; S2 = data_filled.S2\n",
    "\n",
    "    # Inverse-probability weights; each propensity product is predicted only once\n",
    "    w1 = (1-A1)*(1-C1) / pc1(0)\n",
    "    w2 = (1-A1)*(1-A2)*(1-C1)*(1-C2)*S1 / pcs1pc2(0, 0)\n",
    "    w3 = (1-A1)*(1-C1)*S1 / pcs1(0)\n",
    "\n",
    "    # Stage-1 survival residual S1 - p_1^0(X_1), scaled by m_{p_2}^{00}(X_1)\n",
    "    D11 = w1 * S1\n",
    "    D12 = w1 * p10x\n",
    "    D11[np.isnan(D11)] = 0\n",
    "    D12[np.isnan(D12)] = 0\n",
    "    D1 = (D11 - D12) * E_p200\n",
    "\n",
    "    # Stage-2 survival residual S2 - p_2^{00}(H_1), scaled by p_1^0(X_1)\n",
    "    D211 = w2 * S2\n",
    "    D212 = w2 * p200x\n",
    "    D211[np.isnan(D211)] = 0\n",
    "    D212[np.isnan(D212)] = 0\n",
    "    D2 = p10x * (D211 - D212)\n",
    "\n",
    "    # Residual of p_2^{00}(H_1) around m_{p_2}^{00}(X_1), scaled by p_1^0(X_1)\n",
    "    D221 = w3 * p200x\n",
    "    D222 = w3 * E_p200\n",
    "    D221[np.isnan(D221)] = 0\n",
    "    D222[np.isnan(D222)] = 0\n",
    "    D3 = p10x * (D221 - D222)\n",
    "\n",
    "    # Plug-in term p_1^0(X_1) * m_{p_2}^{00}(X_1)\n",
    "    D4 = p10x * E_p200\n",
    "\n",
    "    return D1, D2, D3, D4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68e7df80-3323-4404-b76d-694799e6e985",
   "metadata": {},
   "outputs": [],
   "source": [
    "def phi_D(data, models):\n",
    "    \"\"\"Uncentered EIF of D for every observation (sum of the phi_D_terms components).\"\"\"\n",
    "    first, second, third, fourth = phi_D_terms(data, models)\n",
    "    total = first + second\n",
    "    total = total + third\n",
    "    return total + fourth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8697b540-b0e2-44fe-8fec-71307447617d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def D_MR(data, models):\n",
    "    \"\"\"\n",
    "    Stabilized EIF-based (multiply robust) estimator of the denominator D.\n",
    "\n",
    "    Each inverse-probability-weighted term from phi_D_terms is divided by the sum\n",
    "    of its own weights instead of n (Hajek-type normalization); the plug-in term\n",
    "    D4 is a plain sample mean.\n",
    "\n",
    "    Args:\n",
    "        data: pandas DataFrame with columns A1, A2, C1, C2, S1, S2 (NaN allowed);\n",
    "            expected to be the sample the prediction functions in `models` use.\n",
    "        models: dict of nuisance models and prediction functions; uses 'pc1',\n",
    "            'pcs1', 'pcs1pc2' and whatever phi_D_terms reads.\n",
    "\n",
    "    Returns:\n",
    "        A float: the estimate of D.\n",
    "    \"\"\"\n",
    "    pcs1pc2 = models['pcs1pc2']\n",
    "    pcs1 = models['pcs1']\n",
    "    pc1 = models['pc1']\n",
    "\n",
    "    data_filled = data.fillna(0)\n",
    "    A1 = data_filled.A1; A2 = data_filled.A2; C1 = data_filled.C1\n",
    "    C2 = data_filled.C2; S1 = data_filled.S1\n",
    "\n",
    "    D1, D2, D3, D4 = phi_D_terms(data, models)\n",
    "\n",
    "    def _weight_total(w):\n",
    "        # Total IPW weight, counting units with undefined (NaN) weight as 0\n",
    "        w[np.isnan(w)] = 0\n",
    "        return w.sum()\n",
    "\n",
    "    w1 = _weight_total((1-A1)*(1-C1) / pc1(0))\n",
    "    w2 = _weight_total((1-A1)*(1-A2)*(1-C1)*(1-C2)*S1 / pcs1pc2(0, 0))\n",
    "    w3 = _weight_total((1-A1)*(1-C1)*S1 / pcs1(0))\n",
    "\n",
    "    return D1.sum() / w1 + D2.sum() / w2 + D3.sum() / w3 + D4.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7db5470-2c22-4add-93cf-d19d243a8ba3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Nuisance models from `true_models`, built on the XX_obs15 sample\n",
    "tm15 = true_models(XX_obs15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7927e68-291c-4fc6-a0e5-2b261252da26",
   "metadata": {},
   "outputs": [],
   "source": [
    "# D_MR under several nuisance-model settings, then D_MR with tm15, the plug-in D_hat, and D_true_id\n",
    "(\n",
    "    *(D_MR(XX_obs15, spec) for spec in (models, m_12345, m_12367)),\n",
    "    \">>>\",\n",
    "    *(D_MR(XX_obs15, spec) for spec in (m_1234, m_1236, m_1235, m_1237)),\n",
    "    \">>>\",\n",
    "    D_MR(XX_obs15, tm15),\n",
    "    D_hat(XX_obs15, models),\n",
    "    D_true_id,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b63994e2-ff36-4eaa-9826-55b5e6e04a2a",
   "metadata": {},
   "source": [
    "### N-MR\n",
    "\n",
    "* EIF-based estimator $\\mathbb P_n \\hat{\\mathcal V}_N(O)$ where\n",
    "$$\n",
    "\\begin{align}\n",
    "{\\mathcal V}_N(O)\n",
    " &= \\sum_{a_1,a_2} \\bigg[ \\frac{\\mathbb1\\{\\bar A=\\bar a\\}(1-C_1)(1-C_2)S_1S_2}{(\\varphi_1K_1p_1)^{a_1}(X_1)(\\varphi_2K_2p_2)^{a_2}(H_1)} \\big\\{Y - \\mu_2^{a_2}(H_1)\\big\\} \\cdot \\mathbb1\\{\\pi(H_1)=\\bar a\\} \\\\\n",
    " &\\qquad\\qquad+ \\frac{\\mathbb1\\{A_1=a_1\\}(1-C_1)S_1}{(\\varphi_1K_1p_1)^{a_1}(X_1)} \\left\\{ \\mu_2^{a_2}(H_1) \\mathbb1\\{\\pi_2(H_1)=a_2\\} - m_{\\mu_2, \\pi}^{a_1a_2}(X_1) \\right\\} \\cdot \\mathbb1\\{\\pi_1(X_1)=a_1\\} \\bigg] \\\\\n",
    " &\\qquad\\quad \\times p_1^0(X_1) \\cdot m_{p_2}^{00}(X_1) \\\\\n",
    " &+ m_{\\mu_2, \\pi}(X_1) \\cdot \\left[ \\frac{(1-A_1)(1-C_1)}{(\\varphi_1^0K_1^0)(X_1)} \\left(S_1 - p_1^0(X_1)\\right) \\right] \\cdot m_{p_2}^{00}(X_1) \\\\\n",
    " &+ m_{\\mu_2, \\pi}(X_1) \\cdot p_1^0(X_1) \\times \\bigg[ \\frac{(1-A_1)(1-A_2)(1-C_1)(1-C_2)S_1}{(\\varphi_1K_1p_1)^0(X_1)(\\varphi_2^0K_2^0)(H_1)} \\big(S_2 - p_2^0(H_1)\\big) \\\\\n",
    " &\\quad\\qquad\\qquad\\qquad\\qquad\\qquad+ \\frac{(1-A_1)(1-C_1)S_1}{(\\varphi_1K_1p_1)^0(X_1)} \\left\\{ p_2^0(H_1) - m_{p_2}^{00}(X_1) \\right\\} \\bigg] \\\\\n",
    " &+ m_{\\mu_2, \\pi}(X_1) \\cdot p_1^0(X_1) \\cdot m_{p_2}^{00}(X_1)\n",
    "\\end{align}\n",
    "$$\n",
    "and $$m_{\\mu_2, \\pi}^{a_1a_2}(x_1) = \\mathbb E[\\mu_2^{a_2}(H_1) \\mathbb1\\{\\pi_2(H_1)=a_2\\} | X_1=x_1, A_1=a_1, C_1=0, S_1=1 ],$$\n",
    "$$\n",
    "m_{\\mu_2, \\pi}(x_1) = \\sum_{a_1, a_2} m_{\\mu_2, \\pi}^{a_1a_2}(x_1) \\mathbb1\\{\\pi_1(x_1)=a_1\\}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7befb40-8144-439b-901e-6edf5945b75b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def phi_N_terms(d1, d2, data, models):\n",
    "    \"\"\"\n",
    "    Per-observation terms of the efficient influence function (EIF) based\n",
    "    estimator of the numerator N (see the N-MR formula above).\n",
    "\n",
    "    Args:\n",
    "        d1: stage-1 decisions, one per observation.\n",
    "        d2: stage-2 decisions, one per observation.\n",
    "        data: pandas DataFrame with columns A1, A2, C1, C2, S1, S2, Y (NaN allowed);\n",
    "            expected to be the sample the prediction functions in `models` use.\n",
    "        models: dict of nuisance models and prediction functions; uses 'sp1',\n",
    "            'sp2', 'm2', 'pc1', 'pcs1', 'pcs1pc2', 'pcs1pcs2' and whatever\n",
    "            m_p200, g_ and m_mu_pi read.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (N11, N12, N21, N22, N3, N4) of per-observation terms; their sum is\n",
    "        the uncentered EIF of N (see phi_N). NaNs in the weighted terms are set to 0.\n",
    "    \"\"\"\n",
    "    sp1 = models['sp1']; sp2 = models['sp2']; m2 = models['m2']\n",
    "    pcs1pcs2 = models['pcs1pcs2']\n",
    "    pcs1pc2 = models['pcs1pc2']\n",
    "    pcs1 = models['pcs1']\n",
    "    pc1 = models['pc1']\n",
    "\n",
    "    # Zero-fill missing values, as phi_D_terms does, so the terms shared with D\n",
    "    # (N21, N22, N3) are computed from the same inputs as D's terms.\n",
    "    data_filled = data.fillna(0)\n",
    "\n",
    "    # Conditional means / nuisance predictions\n",
    "    p10x = sp1(0)                       # p_1^0(X_1)\n",
    "    p200x = sp2(0, 0)                   # p_2^{00}(H_1)\n",
    "    E_p200 = m_p200(data, models)       # m_{p_2}^{00}(X_1)\n",
    "    E_mu2pi = g_(d1, d2, data, models)  # m_{mu_2, pi}(X_1)\n",
    "\n",
    "    A1 = data_filled.A1; A2 = data_filled.A2\n",
    "    C1 = data_filled.C1; C2 = data_filled.C2\n",
    "    S1 = data_filled.S1; S2 = data_filled.S2; Y = data_filled.Y\n",
    "\n",
    "    def _n12_term(a1, a2):\n",
    "        # Stage-1 weighted residual mu_2^{a2}(H_1) 1{d2=a2} - m_{mu_2,pi}^{a1 a2}(X_1),\n",
    "        # restricted to units with A1 = a1 and d1 = a1\n",
    "        val = (A1==a1)*(1-C1)*S1 / pcs1(a1) * (m2(a1, a2)*(d2==a2)\n",
    "                                               - m_mu_pi(d1, d2, a1, a2, data, models)) * (d1==a1)\n",
    "        val[np.isnan(val)] = 0\n",
    "        return val\n",
    "\n",
    "    # N11: outcome residual Y - mu_2(H_1) for units whose treatments follow (d1, d2)\n",
    "    w11 = (A1==d1)*(A2==d2)*(1-C1)*(1-C2)*S1*S2 / pcs1pcs2(d1, d2)\n",
    "    N111_ = w11 * Y\n",
    "    N112_ = w11 * m2(d1, d2)\n",
    "    N111_[np.isnan(N111_)] = 0\n",
    "    N112_[np.isnan(N112_)] = 0\n",
    "    N11_ = N111_ - N112_\n",
    "\n",
    "    # N12: stage-1 residual terms summed over all (a1, a2)\n",
    "    N12_ = _n12_term(0, 0) + _n12_term(1, 0) + _n12_term(0, 1) + _n12_term(1, 1)\n",
    "    N12_[np.isnan(N12_)] = 0\n",
    "\n",
    "    N11 = N11_ * p10x * E_p200\n",
    "    N12 = N12_ * p10x * E_p200\n",
    "\n",
    "    # N21, N22: the weighted residuals of D2 and D3 in phi_D_terms, scaled by m_{mu_2,pi}\n",
    "    w21 = (1-A1)*(1-A2)*(1-C1)*(1-C2)*S1 / pcs1pc2(0, 0)\n",
    "    N211 = w21 * S2\n",
    "    N212 = w21 * p200x\n",
    "    N211[np.isnan(N211)] = 0\n",
    "    N212[np.isnan(N212)] = 0\n",
    "    N21 = E_mu2pi * p10x * (N211 - N212)\n",
    "\n",
    "    w22 = (1-A1)*(1-C1)*S1 / pcs1(0)\n",
    "    N221 = w22 * p200x\n",
    "    N222 = w22 * E_p200\n",
    "    N221[np.isnan(N221)] = 0\n",
    "    N222[np.isnan(N222)] = 0\n",
    "    N22 = E_mu2pi * p10x * (N221 - N222)\n",
    "\n",
    "    # N3: the stage-1 survival residual of D1 in phi_D_terms, scaled by m_{mu_2,pi}\n",
    "    w3 = (1-A1)*(1-C1) / pc1(0)\n",
    "    N311 = w3 * S1\n",
    "    N312 = w3 * p10x\n",
    "    N311[np.isnan(N311)] = 0\n",
    "    N312[np.isnan(N312)] = 0\n",
    "    N3 = E_mu2pi * (N311 - N312) * E_p200\n",
    "\n",
    "    # N4: plug-in term\n",
    "    N4 = E_mu2pi * p10x * E_p200\n",
    "\n",
    "    return N11, N12, N21, N22, N3, N4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddffe601-b73f-4bac-af66-756a7a4da3c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def phi_N(d1, d2, data, models):\n",
    "    \"\"\"Uncentered EIF of N for every observation (sum of the phi_N_terms components).\"\"\"\n",
    "    parts = phi_N_terms(d1, d2, data, models)\n",
    "    total = parts[0]\n",
    "    for part in parts[1:]:\n",
    "        total = total + part\n",
    "    return total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9431a645-b1d4-4ab2-8492-0d6d9e32c092",
   "metadata": {},
   "outputs": [],
   "source": [
    "def N_MR(d1, d2, data, models):\n",
    "    \"\"\"\n",
    "    Stabilized EIF-based (multiply robust) estimator of the numerator N.\n",
    "\n",
    "    Each inverse-probability-weighted term from phi_N_terms is divided by the sum\n",
    "    of its own weights instead of n (Hajek-type normalization); the plug-in term\n",
    "    N4 is a plain sample mean.\n",
    "\n",
    "    Args:\n",
    "        d1: stage-1 decisions, one per observation.\n",
    "        d2: stage-2 decisions, one per observation.\n",
    "        data: pandas DataFrame with columns A1, A2, C1, C2, S1, S2, Y (NaN allowed);\n",
    "            expected to be the sample the prediction functions in `models` use.\n",
    "        models: dict of nuisance models and prediction functions; uses 'pc1',\n",
    "            'pcs1', 'pcs1pc2', 'pcs1pcs2' and whatever phi_N_terms reads.\n",
    "\n",
    "    Returns:\n",
    "        A float: the estimate of N.\n",
    "    \"\"\"\n",
    "    pcs1pcs2 = models['pcs1pcs2']\n",
    "    pcs1pc2 = models['pcs1pc2']\n",
    "    pcs1 = models['pcs1']\n",
    "    pc1 = models['pc1']\n",
    "\n",
    "    data_filled = data.fillna(0)\n",
    "    A1 = data_filled.A1; A2 = data_filled.A2\n",
    "    C1 = data_filled.C1; C2 = data_filled.C2\n",
    "    S1 = data_filled.S1; S2 = data_filled.S2\n",
    "\n",
    "    N11, N12, N21, N22, N3, N4 = phi_N_terms(d1, d2, data, models)\n",
    "\n",
    "    def _weight_total(w):\n",
    "        # Total IPW weight, counting units with undefined (NaN) weight as 0\n",
    "        w[np.isnan(w)] = 0\n",
    "        return w.sum()\n",
    "\n",
    "    w11 = _weight_total((A1==d1)*(A2==d2)*(1-C1)*(1-C2)*S1*S2 / pcs1pcs2(d1, d2))\n",
    "    w12 = _weight_total((A1==d1)*(1-C1)*S1 / pcs1(d1))\n",
    "    w21 = _weight_total((1-A1)*(1-A2)*(1-C1)*(1-C2)*S1 / pcs1pc2(0, 0))\n",
    "    w22 = _weight_total((1-A1)*(1-C1)*S1 / pcs1(0))\n",
    "    w3 = _weight_total((1-A1)*(1-C1) / pc1(0))\n",
    "\n",
    "    return (N11.sum() / w11 + N12.sum() / w12 + N21.sum() / w21\n",
    "            + N22.sum() / w22 + N3.sum() / w3 + N4.mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b321063b",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "# N_MR under several nuisance-model settings, compared with the plug-in N_hat and N_true_id\n",
    "(\n",
    "    *(N_MR(d1_true, d2_true, XX_obs15, spec) for spec in (models, m_12345, m_123457, m_12367)),\n",
    "    \">>>\",\n",
    "    *(N_MR(d1_true, d2_true, XX_obs15, spec) for spec in (m_1234, m_1236, m_1235, m_1237)),\n",
    "    \">>>\",\n",
    "    N_MR(d1_true, d2_true, XX_obs15, m_4567),\n",
    "    N_hat(d1_true, d2_true, XX_obs15, m_4567),\n",
    "    \">>>\",\n",
    "    N_MR(d1_true, d2_true, XX_obs15, tm15),\n",
    "    N_hat(d1_true, d2_true, XX_obs15, models),\n",
    "    N_true_id,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "deb3797c",
   "metadata": {},
   "source": [
    "### $V^{MR}(\\pi) = N/D$\n",
    "$$\n",
    "\\hat V^\\text{MR}(\\pi) = \\frac{ \\mathbb P_n \\hat{\\mathcal V}_N }{ \\mathbb P_n \\hat{\\mathcal V}_D }\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2337c575-bec9-4d0a-a11c-ca72be462846",
   "metadata": {},
   "outputs": [],
   "source": [
    "def V_MR(d1, d2, data, models):\n",
    "    \"\"\"EIF-based (multiply robust) estimate of V(pi) = N / D, i.e. N_MR / D_MR.\"\"\"\n",
    "    numerator = N_MR(d1, d2, data, models)\n",
    "    denominator = D_MR(data, models)\n",
    "    return numerator / denominator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbedd337",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "# V_MR under several nuisance-model settings, compared with the plug-in V_plugin and V_true_id\n",
    "(\n",
    "    *(V_MR(d1_true, d2_true, XX_obs15, spec) for spec in (models, m_12345, m_123457)),\n",
    "    \">>>\",\n",
    "    V_MR(d1_true, d2_true, XX_obs15, m_4567),\n",
    "    V_plugin(d1_true, d2_true, XX_obs15, m_4567),\n",
    "    \">>>\",\n",
    "    V_plugin(d1_true, d2_true, XX_obs15, models),\n",
    "    V_true_id,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "394379e4",
   "metadata": {},
   "source": [
    "---\n",
    "# OPE with fixed decision rules"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f168e241",
   "metadata": {},
   "source": [
    "## Function `simulate_fixpi()`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ee13ab9-7b37-443a-be2c-5b897021d553",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# ```python\n",
    "# Pre-generate and cache the simulated datasets that simulate_fixpi reads later.\n",
    "Ns = [1000, 2000, 5000]\n",
    "\n",
    "# write_pkl opens the file directly and cannot create missing directories.\n",
    "os.makedirs(\"./simdata\", exist_ok=True)\n",
    "\n",
    "for m in range(500):\n",
    "    for N in Ns:\n",
    "        fpath = f\"./simdata/[2]data_NM_{N:04d}_{m:03d}_{eta2:01d}.pkl\"\n",
    "        if not os.path.exists(fpath):\n",
    "            write_pkl(genData(N, eta2), fpath)\n",
    "# ```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee4b7b18-7550-4ff0-b6de-2419f1fa173e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def simulate_fixpi(m, eta2, eta_true=eta_true):\n",
    "    \"\"\"\n",
    "    Run one simulation replicate with fixed decision rules.\n",
    "\n",
    "    For every sample size in the global `Ns`, load (or generate and cache) a dataset,\n",
    "    fit the nuisance models under several (mis)specification settings, and evaluate\n",
    "    the EIF-based (MR) and plug-in estimators of N, D and V = N / D.\n",
    "\n",
    "    Args:\n",
    "      m: Replicate index; selects the cached data file and, when eta_true is a\n",
    "         4-d array, the replicate-specific rule parameters eta_true[m].\n",
    "      eta2: Data-generating parameter passed to genData (also part of the file name).\n",
    "      eta_true: Rule parameters. Either one parameter vector shared by the MR and\n",
    "         plug-in rules, or a 4-d array with eta_true[m][k] = (beta_mr, beta_pg),\n",
    "         where k = 0 is used for the first sample size and k = 1 for all others.\n",
    "         Defaults to the global eta_true at definition time.\n",
    "\n",
    "    Returns:\n",
    "      A namedtuple Result(value_hat, Nhats, Dhats). Each field is an array of shape\n",
    "      (len(Ns), 13): columns 0-5 hold the MR estimates for the six model settings,\n",
    "      columns 6-11 the plug-in estimates for the same settings, and column 12 the\n",
    "      plug-in estimate under the true models.\n",
    "    \"\"\"\n",
    "    Result = namedtuple('Result', ['value_hat', 'Nhats', 'Dhats'])\n",
    "\n",
    "    per_replicate_eta = len(np.array(eta_true).shape) == 4\n",
    "    if per_replicate_eta:\n",
    "        beta_mr_1k, beta_pg_1k = eta_true[m][0]\n",
    "        beta_mr_2k, beta_pg_2k = eta_true[m][1]\n",
    "\n",
    "    n_cols = 2 * 6 + 1  # 6 model settings x {MR, plug-in} + true-model plug-in\n",
    "    value_hat = np.zeros((len(Ns), n_cols))\n",
    "    Nhats = np.zeros((len(Ns), n_cols))\n",
    "    Dhats = np.zeros((len(Ns), n_cols))\n",
    "\n",
    "    for i, N in enumerate(Ns):\n",
    "        # Load the cached dataset, or simulate it and cache it\n",
    "        fpath = f\"./simdata/[2]data_NM_{N:04d}_{m:03d}_{eta2:01d}.pkl\"\n",
    "        if os.path.exists(fpath):\n",
    "            dat = read_pkl(fpath)\n",
    "        else:\n",
    "            dat = genData(N, eta2)\n",
    "            os.makedirs(os.path.dirname(fpath), exist_ok=True)\n",
    "            write_pkl(dat, fpath)\n",
    "\n",
    "        if per_replicate_eta:\n",
    "            # Only two parameter sets exist per replicate; sizes after the first share the second.\n",
    "            if i == 0:\n",
    "                beta_mr, beta_pg = beta_mr_1k, beta_pg_1k\n",
    "            else:\n",
    "                beta_mr, beta_pg = beta_mr_2k, beta_pg_2k\n",
    "        else:\n",
    "            beta_mr = beta_pg = eta_true\n",
    "\n",
    "        XX_obs = dat['obs']\n",
    "\n",
    "        # Default fit_models (no misspecification flags) and the true models\n",
    "        models = fit_models(XX_obs)\n",
    "        tm = true_models(XX_obs)\n",
    "\n",
    "        # Settings that should still lead to a consistent MR estimator\n",
    "        models_12367 = fit_models(XX_obs, p2_false=True, Ep2_false=True)  # 45 wrong\n",
    "        models_12467 = fit_models(XX_obs, phi2_false=True, K2_false=True,\n",
    "                                  Ep2_false=True)  # 35 wrong\n",
    "        models_4567 = fit_models(XX_obs, phi2_false=True, K2_false=True,\n",
    "                                 phi1_false=True, K1_false=True,\n",
    "                                 p1_false=True)  # 123 wrong\n",
    "        models_123457 = fit_models(XX_obs, mu2_false=True)  # 6 wrong\n",
    "\n",
    "        # Setting that is not guaranteed to be consistent\n",
    "        models_12345 = fit_models(XX_obs, mu2_false=True,\n",
    "                                  Emupi_false=True)  # 67 wrong\n",
    "\n",
    "        # Fixed decision rules for the MR and plug-in estimators\n",
    "        d1mr = reg1(*beta_mr[:2], XX_obs)\n",
    "        d2mr = reg2(*beta_mr[2:], XX_obs)\n",
    "        d1pg = reg1(*beta_pg[:2], XX_obs)\n",
    "        d2pg = reg2(*beta_pg[2:], XX_obs)\n",
    "\n",
    "        all_models = [models,\n",
    "                      models_12367, models_12467, models_4567, models_123457,\n",
    "                      models_12345]\n",
    "        n_model = len(all_models)\n",
    "\n",
    "        # EIF-based (MR) estimators\n",
    "        for mm, mdl in enumerate(all_models):\n",
    "            Nhats[i, mm] = N_MR(d1mr, d2mr, XX_obs, mdl)\n",
    "            Dhats[i, mm] = D_MR(XX_obs, mdl)\n",
    "            value_hat[i, mm] = Nhats[i, mm] / Dhats[i, mm]\n",
    "\n",
    "        # Plug-in estimators\n",
    "        for mm, mdl in enumerate(all_models):\n",
    "            col = mm + n_model\n",
    "            Nhats[i, col] = N_hat(d1pg, d2pg, XX_obs, mdl)\n",
    "            Dhats[i, col] = D_hat(XX_obs, mdl)\n",
    "            value_hat[i, col] = Nhats[i, col] / Dhats[i, col]\n",
    "\n",
    "        # Plug-in under the true models (last column)\n",
    "        Nhats[i, -1] = N_hat(d1pg, d2pg, XX_obs, tm)\n",
    "        Dhats[i, -1] = D_hat(XX_obs, tm)\n",
    "        value_hat[i, -1] = Nhats[i, -1] / Dhats[i, -1]\n",
    "\n",
    "    return Result(value_hat, Nhats, Dhats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d73cc2a-161d-4407-88a2-84d0fbd6c806",
   "metadata": {},
   "outputs": [],
   "source": [
    "def initializer():\n",
    "    \"\"\"Worker initializer: make the worker process ignore SIGINT (Ctrl+C).\"\"\"\n",
    "    sigint, ignore = signal.SIGINT, signal.SIG_IGN\n",
    "    signal.signal(sigint, ignore)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e83620a8-aa39-447a-95aa-1be9f0c4ea97",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def simulate_fixpi_mp(m):\n",
    "    \"\"\"Single-argument wrapper around simulate_fixpi using the global eta2 (for map-style APIs).\"\"\"\n",
    "    result = simulate_fixpi(m, eta2)\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "481db7f3-c290-47a7-b301-f24a15244759",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# ```python\n",
    "# NOTE: %%time must stay on the first line; IPython only recognizes cell magics there.\n",
    "\n",
    "Ns = [1000, 2000, 5000]\n",
    "M = 500\n",
    "\n",
    "# Censoring rate ~= 15%\n",
    "# eta2 = 4\n",
    "\n",
    "# Silence pandas' chained-assignment warnings only while simulating;\n",
    "# option_context restores the previous setting even if a replicate raises.\n",
    "sim_15 = []\n",
    "with pd.option_context(\"mode.chained_assignment\", None):\n",
    "    for m in tqdm(range(M)):\n",
    "        sim_15.append(simulate_fixpi(m, eta2))\n",
    "\n",
    "write_pkl(sim_15, \"sim_15_contiX2_sim2.pkl\")\n",
    "# ```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a919a6e-c4eb-471b-89bc-71ce39876b0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulation settings used to produce sim_15\n",
    "Ns = [1000, 2000, 5000]  # sample sizes\n",
    "M = 500  # number of replicates"
   ]
  },
  {
   "cell_type": "raw",
   "id": "643b9970-1ea0-4a9f-9cd7-3ebc5c611422",
   "metadata": {},
   "source": [
    "sim_15 = read_pkl(\"sim_15_contiX2_sim2.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3c64ac7-3888-4cc7-8fcc-41909fff4b2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack per-replicate results into arrays of shape (replicates, len(Ns), 13)\n",
    "D_15, N_15, V_15 = (\n",
    "    np.array([getattr(sim, field) for sim in sim_15])\n",
    "    for field in (\"Dhats\", \"Nhats\", \"value_hat\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a8bd70e-f26f-4a28-81f8-684a490355bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def stylize_axes(ax):\n",
    "    \"\"\"Hide the top and right spines and keep ticks only on the left and bottom.\"\"\"\n",
    "    for side in ('top', 'right'):\n",
    "        ax.spines[side].set_visible(False)\n",
    "    ax.yaxis.set_ticks_position('left')\n",
    "    ax.xaxis.set_ticks_position('bottom')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72369c66-c621-4342-86ed-eef279d3e438",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Labels M1..M6 for the six nuisance-model settings\n",
    "mdls = [f'M{k}' for k in range(1, 7)]\n",
    "mrmdls = [f'M{k}' for k in range(1, 7)]\n",
    "pgmdls = [f'M{k}\\n      Plugin' if k == 3 else f'M{k}' for k in range(1, 7)]\n",
    "mdl_lbl = mrmdls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67af51fc-fc12-4d2e-9611-c3a8bbda9017",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Boxplots of the MR (EIF-based) value estimates across replicates for N=Ns[1] and N=Ns[2].\n",
    "# V_15 columns [0, 1, 2, 4, 5, 3] are the MR estimates ordered to match the labels M1..M6.\n",
    "mr_cols = [0, 1, 2, 4, 5, 3]\n",
    "fig, axes = plt.subplots(1, 2, figsize=(4.3, 2.2), dpi=1200)\n",
    "\n",
    "for ax, size_idx in zip(axes, (1, 2)):\n",
    "    stylize_axes(ax)\n",
    "    ax.boxplot(\n",
    "        V_15[:, size_idx, mr_cols], tick_labels=mdl_lbl[:],\n",
    "        sym='x', flierprops={'markersize': 3, 'markeredgewidth': .25}\n",
    "    )\n",
    "    ax.tick_params(axis='x', labelrotation=0)\n",
    "    ax.set_ylim(V_true_id*.1, V_true_id*2)\n",
    "    ax.set_title(f\"N={Ns[size_idx]}\")\n",
    "    ax.axhline(V_true_id, color='r', ls=\"-\", lw=.75, label='True V')\n",
    "\n",
    "axes[0].set_ylabel(\"Always-survivor value\")\n",
    "axes[1].set_yticks([], [])  # both panels use the same y-range; hide duplicate ticks\n",
    "fig.tight_layout();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "328b28b0-381b-4824-86eb-84af7b4c63e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bias and Monte-Carlo SE of the value estimates for each sample size.\n",
    "# Columns are ordered exactly as in the boxplot ([0, 1, 2, 4, 5, 3] for the MR estimates,\n",
    "# the same order shifted by 6 for the plug-in estimates), so the labels M1..M6 denote\n",
    "# the same model settings in this table and in the figure.\n",
    "mr_cols = [0, 1, 2, 4, 5, 3]\n",
    "table_cols = mr_cols + [c + len(mr_cols) for c in mr_cols]\n",
    "\n",
    "summary_cols = {}\n",
    "for size_idx, n_obs in enumerate(Ns):\n",
    "    V_n = V_15[:, size_idx, table_cols]\n",
    "    summary_cols[(f\"N={n_obs}\", \"Bias\")] = V_n.mean(0) - V_true_id\n",
    "    summary_cols[(f\"N={n_obs}\", \"SE\")] = V_n.std(0)\n",
    "\n",
    "df_15 = pd.DataFrame(summary_cols)  # tuple keys -> MultiIndex columns\n",
    "df_15.index = pd.MultiIndex.from_tuples([(\"EIF\", lbl.strip()) for lbl in mdl_lbl[:6]] +\n",
    "                                        [(\"Plugin\", lbl.strip()) for lbl in mdl_lbl[:6]])\n",
    "df_15.round(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bae27802-729c-4606-af71-694251b79bb3",
   "metadata": {},
   "source": [
    "## Analytical confidence interval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ff48732-0388-4646-8208-610345ca18a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def eif_MR(d1, d2, data, models):\n",
    "    \"\"\"\n",
    "    Estimated influence function of the ratio estimator V = N / D.\n",
    "\n",
    "    Linearises the ratio as (phi_N - V * phi_D) / D, where phi_N and phi_D are the\n",
    "    contributions returned by `phi_N` / `phi_D`, centred at N_hat / D_hat\n",
    "    (presumably one value per row of `data`).\n",
    "    \"\"\"\n",
    "    denom = D_hat(data, models)\n",
    "    numer = N_hat(d1, d2, data, models)\n",
    "    ratio = numer / denom\n",
    "    phi_denom = phi_D(data, models) - denom\n",
    "    phi_numer = phi_N(d1, d2, data, models) - numer\n",
    "    return (phi_numer - ratio * phi_denom) / denom\n",
    "\n",
    "def var_MR(d1, d2, data, models, trim=0.0):\n",
    "    \"\"\"\n",
    "    Variance estimate for the MR estimator: the mean of eif_MR(...)**2.\n",
    "\n",
    "    `trim` is the proportion cut from each tail of the squared EIF values\n",
    "    (scipy.stats.mstats.trimmed_mean) to damp outliers; 0.0 means no trimming.\n",
    "    The standard error used in this notebook is sqrt(var_MR(...) / n).\n",
    "    \"\"\"\n",
    "    squared_eif = eif_MR(d1, d2, data, models) ** 2\n",
    "    return sp.stats.mstats.trimmed_mean(squared_eif, (trim, trim))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5beb25d7-3d75-4fb5-8d34-791c020bfacb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One N=2000 sample, evaluated at the true rule parameters eta_true.\n",
    "XX_obs2k = genData(2000, eta2)['obs']\n",
    "d1_2k, d2_2k = reg1(*eta_true[:2], XX_obs2k), reg2(*eta_true[2:], XX_obs2k)\n",
    "models_2k = fit_models(XX_obs2k)\n",
    "\n",
    "V_MR(d1_2k, d2_2k, XX_obs2k, models_2k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7407a057-6bd1-4608-a198-36a07b3329ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.sqrt( var_MR(d1_2k, d2_2k, XX_obs2k, models_2k) / 2000 )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50e0f109-1de5-4557-8d75-967f0c43dc24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Monte-Carlo sampling distribution of V_MR: 500 independent N=2000 datasets,\n",
    "# each evaluated at eta_true with freshly fitted nuisance models.\n",
    "Vs_2k = np.empty(500)\n",
    "for rep in tqdm(range(len(Vs_2k))):\n",
    "    sample_obs = genData(2000, eta2)['obs']\n",
    "    d1_rep = reg1(*eta_true[:2], sample_obs)\n",
    "    d2_rep = reg2(*eta_true[2:], sample_obs)\n",
    "    Vs_2k[rep] = V_MR(d1_rep, d2_rep, sample_obs, fit_models(sample_obs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1b5e322-a011-4f28-a29f-dea02abe5ede",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Nonparametric bootstrap of V_MR on the single sample XX_obs2k: resample rows\n",
    "# with replacement, refit the nuisance models and re-estimate, 500 times.\n",
    "Vs_boot_2k = np.empty(500)\n",
    "n_rows = len(XX_obs2k)\n",
    "for b in tqdm(range(len(Vs_boot_2k))):\n",
    "    boot_idx = np.random.choice(n_rows, size=n_rows, replace=True)\n",
    "    boot_obs = XX_obs2k.iloc[boot_idx].reset_index(drop=True)\n",
    "    d1_boot = reg1(*eta_true[:2], boot_obs)\n",
    "    d2_boot = reg2(*eta_true[2:], boot_obs)\n",
    "    Vs_boot_2k[b] = V_MR(d1_boot, d2_boot, boot_obs, fit_models(boot_obs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fececb70-7101-4843-97f6-8f3fbed838cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# EIF-based (analytical) vs bootstrap 95% CI from the single sample XX_obs2k,\n",
    "# shown against the Monte-Carlo spread of V_MR over independent datasets (Vs_2k).\n",
    "n_obs_2k = len(XX_obs2k)\n",
    "V_hat_2k = V_MR(d1_2k, d2_2k, XX_obs2k, models_2k)  # computed once, reused by both bars\n",
    "se_eif_2k = np.sqrt(var_MR(d1_2k, d2_2k, XX_obs2k, models_2k) / n_obs_2k)\n",
    "\n",
    "plt.figure(figsize=(3,5))\n",
    "\n",
    "sns.stripplot(Vs_2k, s=2.5, alpha=.5, c=\"gray\")\n",
    "plt.errorbar(-.05, V_hat_2k, 1.96 * se_eif_2k,\n",
    "             capsize=5, elinewidth=2, c=\"C3\", label=\"EIF\")\n",
    "plt.errorbar(0.05, V_hat_2k, 1.96 * Vs_boot_2k.std(),\n",
    "             capsize=5, elinewidth=2, c=\"C1\", label=f\"bootstrap M={len(Vs_boot_2k)}\")\n",
    "\n",
    "plt.legend()\n",
    "# Labels are derived from the arrays so they match the actual repetition counts.\n",
    "plt.xticks([0], [f\"repeat={len(Vs_2k)}\"])\n",
    "plt.title(f\"95% Confidence Interval, N={n_obs_2k}\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7212bec-6c2a-4174-8c75-b08773785895",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "bec9a442-bd59-499d-bd99-b5d3a5b4d2a9",
   "metadata": {},
   "source": [
    "### Simulation: coverage of the analytical CI"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7a8295c-a4fb-4ec6-a65c-a51a592e99bf",
   "metadata": {},
   "source": [
    "Fit the correctly specified nuisance models on every simulated dataset and save them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b7c58cb-1958-488a-9a01-aa1c5be93a9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "Ns = [1000, 2000, 5000]\n",
    "M = 500\n",
    "\n",
    "# all_models_correct[i][m]: nuisance models fitted on cached dataset m with N = Ns[i].\n",
    "all_models_correct = []\n",
    "for N in Ns:\n",
    "    models_for_N = []\n",
    "    for m in tqdm(range(M)):\n",
    "        sim_obs = read_pkl(f\"./simdata/[2]data_NM_{N:04d}_{m:03d}_{eta2:01d}.pkl\")['obs']\n",
    "        models_for_N.append(fit_models(sim_obs))\n",
    "    all_models_correct.append(models_for_N)\n",
    "\n",
    "write_pkl(all_models_correct, 'all_models_correct_sim2.pkl')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf659ee6-32d1-4d8f-81db-6b624aefe88d",
   "metadata": {},
   "source": [
    "```\n",
    "100%|█████████████████████████████████████████| 500/500 [08:07<00:00,  1.03it/s]\n",
    "CPU times: user 32min 26s, sys: 2min 13s, total: 34min 40s\n",
    "Wall time: 8min 43s\n",
    "```"
   ]
  },
  {
   "cell_type": "raw",
   "id": "2e98a6d7-802e-4b01-bcd3-f7e98eae674b",
   "metadata": {},
   "source": [
    "all_models_correct = read_pkl('all_models_correct_sim2.pkl')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6faee159-d301-4cb9-aff4-ec8cb84beceb",
   "metadata": {},
   "source": [
    "Also computing the \"true\" value here, as the average of 20 large (1M-observation) samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3d56d55-98af-4730-a601-0f54dee0a3af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Large-sample \"true\" value: V_plugin with the true nuisance models at eta_true,\n",
    "# computed on n_rep_true independent samples of 1M observations (averaged later\n",
    "# via Vs_true.mean()). The loop index is not `_`, which IPython uses for the last output.\n",
    "n_rep_true = 20\n",
    "Vs_true = np.empty(n_rep_true)\n",
    "for r in tqdm(range(n_rep_true)):\n",
    "    dd = genData(N=1_000_000, eta2=eta2)\n",
    "    tmh = true_models(dd['obs'])\n",
    "    \n",
    "    d1dh = reg1(*eta_true[:2], dd['obs'])\n",
    "    d2dh = reg2(*eta_true[2:], dd['obs'])\n",
    "    \n",
    "    Vs_true[r] = V_plugin(d1dh, d2dh, dd['obs'], tmh)\n",
    "\n",
    "write_pkl(Vs_true, \"Vs_true_largesample_sim2.pkl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4fdf5b21-4afb-4731-9997-e99031f78e37",
   "metadata": {},
   "source": [
    "```\n",
    "100%|███████████████████████████████████████████| 20/20 [26:04<00:00, 78.21s/it]\n",
    "```"
   ]
  },
  {
   "cell_type": "raw",
   "id": "1317bc42-f053-4d3e-81ac-73a1eb9b9558",
   "metadata": {},
   "source": [
    "Vs_true = read_pkl(\"Vs_true_largesample_sim2.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f97fade9-f27c-47d7-b8ec-ebea13c125ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "Vs_true.mean(), Vs_true.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03d85df4-538e-4105-a9f9-0d0eff83e8c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: despite the name, all_var stores standard errors sqrt(var_MR / N) of V_MR\n",
    "# at eta_true, one per dataset, laid out as (len(Ns), M).\n",
    "all_var = []\n",
    "for i, N in enumerate(Ns):\n",
    "    se_N = []\n",
    "    for m in tqdm(range(M)):\n",
    "        dat_obs = read_pkl(f\"./simdata/[2]data_NM_{N:04d}_{m:03d}_{eta2:01d}.pkl\")['obs']\n",
    "        model_m = all_models_correct[i][m]\n",
    "\n",
    "        d1_m = reg1(*eta_true[:2], dat_obs)\n",
    "        d2_m = reg2(*eta_true[2:], dat_obs)\n",
    "        se_N.append(np.sqrt(var_MR(d1_m, d2_m, dat_obs, model_m) / N))\n",
    "    all_var.append(se_N)\n",
    "\n",
    "all_var = np.array(all_var)\n",
    "write_pkl(all_var, \"var_eif_sim2.pkl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0801d6f-e03b-458d-bd6f-fe26befa1487",
   "metadata": {},
   "source": [
    "```\n",
    "100%|█████████████████████████████████████████| 500/500 [38:50<00:00,  4.66s/it]\n",
    "```"
   ]
  },
  {
   "cell_type": "raw",
   "id": "231e0315-b466-439b-9c72-5f0c38b2537a",
   "metadata": {},
   "source": [
    "all_var = read_pkl(\"var_eif_sim2.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a26b8c85-030d-4ead-ba20-438686bb151e",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_var[0].mean(), all_var[1].mean(), all_var[2].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f0da160-f6e9-465f-ad8c-47ed9a3b3b6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "V_true_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16fd5b2e-8db0-4f42-970a-217466d2cec8",
   "metadata": {},
   "outputs": [],
   "source": [
    "V_15[:, :, 0].mean(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddd6bf0e-bce0-47e1-b4d0-06b47f58bbf3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coverage of the analytical (EIF) 95% CI V_15[:, i, 0] +/- 1.96 * SE against the\n",
    "# large-sample \"true\" value, for each sample size Ns[i].\n",
    "V_target = Vs_true.mean()\n",
    "\n",
    "def _ci_eif(i):\n",
    "    \"\"\"(lower, upper) EIF-based 95% CI per dataset for N = Ns[i].\"\"\"\n",
    "    centre, half_width = V_15[:, i, 0], 1.96 * all_var[i]\n",
    "    return centre - half_width, centre + half_width\n",
    "\n",
    "def _coverage(lower, upper):\n",
    "    \"\"\"Fraction of intervals that strictly contain V_target.\"\"\"\n",
    "    return ((V_target - lower > 0) & (upper - V_target > 0)).mean(0)\n",
    "\n",
    "L_eif_1k, U_eif_1k = _ci_eif(0)\n",
    "L_eif_2k, U_eif_2k = _ci_eif(1)\n",
    "L_eif_5k, U_eif_5k = _ci_eif(2)\n",
    "coverage_eif = (_coverage(L_eif_1k, U_eif_1k),\n",
    "                _coverage(L_eif_2k, U_eif_2k),\n",
    "                _coverage(L_eif_5k, U_eif_5k))\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(3, 4), sharey=True)\n",
    "for ax, n_size, cov in zip(axes, Ns, coverage_eif):\n",
    "    sns.barplot([cov], ax=ax)\n",
    "    ax.set_xticks(range(1), ['$\\\\phi$'])\n",
    "    ax.set_title(f\"N={n_size}\")\n",
    "axes[0].set_ylim(0.4, 1.01)  # applies to all panels via sharey\n",
    "\n",
    "fig.suptitle(f\"Coverage rate (analytical, M={M})\")\n",
    "fig.tight_layout()\n",
    "\n",
    "coverage_eif"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33690ea2",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "# OPL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "438fa8ff-ed65-49ee-b40b-30c675d619c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def norm1check1(etas):\n",
    "    \"\"\"L2 norm of the parameters passed to reg1 (etas[:2]).\"\"\"\n",
    "    first_pair = etas[:2]\n",
    "    return np.linalg.norm(first_pair, ord=2)\n",
    "\n",
    "def norm1check2(etas):\n",
    "    \"\"\"L2 norm of the parameters passed to reg2 (etas[2:]).\"\"\"\n",
    "    return np.linalg.norm(etas[2:], ord=2)\n",
    "\n",
    "# Keep each rule's parameter vector in the closed unit L2 ball: 0 <= ||.||_2 <= 1.\n",
    "norm1const1 = sp.optimize.NonlinearConstraint(norm1check1, lb=0., ub=1.)\n",
    "norm1const2 = sp.optimize.NonlinearConstraint(norm1check2, lb=0., ub=1.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47807c7f-1c32-4cc6-a9a4-9cb953bf7951",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The typename matches the variable name: with the old typename 'decision', pickle\n",
    "# looked the class up as the *function* `decision` below and failed.\n",
    "Decision = namedtuple('Decision', ['d1', 'd2'])\n",
    "\n",
    "def decision(reg1, reg2, ga_sol, data):\n",
    "    \"\"\"\n",
    "    Apply the parametrised decision rules to `data`.\n",
    "    \n",
    "    Args:\n",
    "      reg1: Rule for the first decision; called as reg1(*ga_sol[:2], data).\n",
    "      reg2: Rule for the second decision; called as reg2(*ga_sol[2:], data).\n",
    "      ga_sol: Rule parameters, e.g. the `.x` of an optimizer result.\n",
    "      data: Data to apply the rules to.\n",
    "    \n",
    "    Returns:\n",
    "      A Decision namedtuple with fields d1 and d2.\n",
    "    \"\"\"\n",
    "    d1 = reg1(*ga_sol[:2], data)\n",
    "    d2 = reg2(*ga_sol[2:], data)\n",
    "    return Decision(d1, d2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af319190",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "def decision_opt(df_as):\n",
    "    \"\"\"\n",
    "    Oracle decisions: for each row, the pair (d1, d2) whose outcome column is largest.\n",
    "    \n",
    "    The arg-max column name 'y_<d1><d2>' is split into its two digits,\n",
    "    e.g. 'y_10' -> d1=1, d2=0.\n",
    "    \n",
    "    Args:\n",
    "      df_as: DataFrame (always-survivors) with columns y_00, y_01, y_10, y_11.\n",
    "    \n",
    "    Returns:\n",
    "      A Decision namedtuple of integer arrays d1 and d2.\n",
    "    \"\"\"\n",
    "    # Compute the arg-max column once and reuse it for both decisions.\n",
    "    best_col = df_as[['y_00', 'y_01', 'y_10', 'y_11']].idxmax(axis=1)\n",
    "    d1_opt = best_col.str.get(2).astype(int).values\n",
    "    d2_opt = best_col.str.get(3).astype(int).values\n",
    "\n",
    "    return Decision(d1_opt, d2_opt)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "959dce0a-78b8-4f54-a9ee-ac7be903d0bc",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "source": [
    "```python\n",
    "d_opt = decision_opt(XX_as15)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3860da7b-bc80-40a2-9bf5-a5f13e93406c",
   "metadata": {},
   "source": [
    "### V_plugin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b628c7f4-4d24-40a8-aa53-f234341ff24a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def obj_plugin(etas, reg1, reg2, data, models):\n",
    "    \"\"\"\n",
    "    Policy-search objective based on the plug-in value estimator V_plugin.\n",
    "    \n",
    "    Args:\n",
    "      etas: Rule parameters; etas[:2] go to reg1, etas[2:] to reg2.\n",
    "      reg1, reg2: Decision-rule functions.\n",
    "      data: pandas DataFrame containing the observed data.\n",
    "      models: Fitted nuisance models passed to V_plugin.\n",
    "    \n",
    "    Returns:\n",
    "      The negated V_plugin value, so that a minimiser maximises the value.\n",
    "    \"\"\"\n",
    "    d1 = reg1(*etas[:2], data)\n",
    "    d2 = reg2(*etas[2:], data)\n",
    "    return -V_plugin(d1, d2, data, models)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29236fcb-eba3-46c9-aa0d-fc19300388f1",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "source": [
    "```python\n",
    "%%time\n",
    "with ThreadPool(processes=mp.cpu_count()-1) as pool:\n",
    "    res_pg = sp.optimize.differential_evolution(obj_plugin, bounds=[(-1., 1.)]*len(eta_true),\n",
    "                                                args=[reg1, reg2, XX_obs15, models],\n",
    "                                                workers=pool.map,\n",
    "                                                constraints=[norm1const1, norm1const2])\n",
    "\n",
    "d_pg = decision(reg1, reg2, res_pg.x, XX_obs15)\n",
    "# d_fixed = decision(reg1, reg2, eta_true, XX_obs15)\n",
    "\n",
    "d_pg.d1[XX_as15.index], d_pg.d2[XX_as15.index]\n",
    "\n",
    "# Percentage of correct decision\n",
    "\n",
    "np.round(\n",
    "    [(d_pg.d1[XX_as15.index] == d_opt.d1).mean(),\n",
    "     (d_pg.d2[XX_as15.index] == d_opt.d2).mean()],\n",
    "    4\n",
    ") * 100\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6a21964",
   "metadata": {},
   "source": [
    "### V_MR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27260f9c-4cdf-4026-8b2f-dceb5915a82e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def obj_sdmr(etas, reg1, reg2, data, models):\n",
    "    \"\"\"\n",
    "    Policy-search objective based on the MR value estimator V_MR.\n",
    "    \n",
    "    Args:\n",
    "      etas: Rule parameters; etas[:2] go to reg1, etas[2:] to reg2.\n",
    "      reg1, reg2: Decision-rule functions.\n",
    "      data: pandas DataFrame containing the observed data.\n",
    "      models: Fitted nuisance models passed to V_MR.\n",
    "    \n",
    "    Returns:\n",
    "      The negated V_MR value, so that minimising this objective maximises the\n",
    "      estimated value. No penalty is applied here; in this notebook the unit-norm\n",
    "      restriction is passed to the optimizer as constraints instead.\n",
    "    \"\"\"\n",
    "    d1 = reg1(*etas[:2], data)\n",
    "    d2 = reg2(*etas[2:], data)\n",
    "    return -V_MR(d1, d2, data, models)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "720347bb-0e2e-4662-9243-470b60b44ae7",
   "metadata": {},
   "source": [
    "```python\n",
    "%%time\n",
    "with ThreadPool(processes=mp.cpu_count()-1) as pool:\n",
    "    res = sp.optimize.differential_evolution(obj_sdmr, bounds=[(-1., 1.)]*len(eta_true),\n",
    "                                             args=[reg1, reg2, XX_obs15, models],\n",
    "                                             workers=pool.map,\n",
    "                                             constraints=[norm1const1, norm1const2])\n",
    "\n",
    "eta_true, res.x\n",
    "\n",
    "d_sdmr = decision(reg1, reg2, res.x, XX_obs15)\n",
    "# d_fixed = decision(reg1, reg2, eta_true, XX_obs15)\n",
    "\n",
    "d_sdmr.d1[XX_as15.index], d_sdmr.d2[XX_as15.index]\n",
    "\n",
    "\n",
    "# Percentage of correct decision\n",
    "np.round(\n",
    "    [(d_sdmr.d1[XX_as15.index] == d_opt.d1).mean(),\n",
    "     (d_sdmr.d2[XX_as15.index] == d_opt.d2).mean()],\n",
    "    4\n",
    ") * 100\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4af03d7d-0aab-4db9-8406-669636a9ebd1",
   "metadata": {},
   "source": [
    "### Optimal linear rules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e07ac9d-c48d-4232-8222-bb12d3abd60d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def obj_linear(etas, reg1, reg2, data, tm):\n",
    "    \"\"\"\n",
    "    Objective for finding the best linear rule under the true nuisance models.\n",
    "    \n",
    "    Args:\n",
    "      etas: Rule parameters; etas[:2] go to reg1, etas[2:] to reg2.\n",
    "      reg1, reg2: Decision-rule functions.\n",
    "      data: pandas DataFrame containing the observed data.\n",
    "      tm: True nuisance models (e.g. true_models(data)) passed to V_plugin.\n",
    "    \n",
    "    Returns:\n",
    "      The negated plug-in value under `tm`, for use with a minimiser.\n",
    "    \"\"\"\n",
    "    d1 = reg1(*etas[:2], data)\n",
    "    d2 = reg2(*etas[2:], data)\n",
    "    return -V_plugin(d1, d2, data, tm)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93c6fe86-006b-4bf9-b6c3-c0587302a81a",
   "metadata": {},
   "source": [
    "```ipython\n",
    "%%time\n",
    "with ThreadPool(processes=mp.cpu_count()-1) as pool:\n",
    "    res_opt = sp.optimize.differential_evolution(obj_linear, bounds=[(-1., 1.)]*len(eta_true),\n",
    "                                                 args=[reg1, reg2, data_15['as']],\n",
    "                                                 workers=pool.map, popsize=50,\n",
    "                                                 constraints=[norm1const1, norm1const2])\n",
    "```\n",
    "\n",
    "```\n",
    "CPU times: total: 37.4 s\n",
    "Wall time: 1min 42s\n",
    "\n",
    "> res_opt.x\n",
    "array([ 0.16369132,  0.81932788,  0.64243744, -0.16940193,  0.13017567 ])\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b577085-9c74-49c7-8f6c-c6937d50ffb6",
   "metadata": {},
   "source": [
    "### True linear optimal decision rule (w/ large sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb1d636a-a4d1-43fc-af3a-962fa10e8c97",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "def get_opt_pi(df_large, eta2=None, n_workers=15):\n",
    "    \"\"\"\n",
    "    Best linear decision rule under the true nuisance models.\n",
    "    \n",
    "    Maximises V_plugin evaluated with `true_models` (via obj_linear) over the five\n",
    "    rule parameters, searching the box [-1, 1]^5 with differential evolution and\n",
    "    constraining each rule's parameters with norm1const1 / norm1const2.\n",
    "    \n",
    "    Args:\n",
    "      df_large: Simulated data dict; only df_large['obs'] is used.\n",
    "      eta2: Unused; optional for compatibility with existing calls.\n",
    "      n_workers: Number of threads evaluating the population in parallel.\n",
    "      \n",
    "    Returns:\n",
    "      Array of the 5 optimal parameters (etas[:2] -> reg1, etas[2:] -> reg2).\n",
    "    \"\"\"\n",
    "    XX_obs = df_large['obs']\n",
    "    # Only the true nuisance models are needed; fitting estimated models on the\n",
    "    # large sample would be wasted work.\n",
    "    tm = true_models(XX_obs)\n",
    "\n",
    "    with ThreadPool(processes=n_workers) as pool:\n",
    "        res_linear = sp.optimize.differential_evolution(obj_linear, bounds=[(-1., 1.)]*5,\n",
    "                                                        args=(reg1, reg2, XX_obs, tm),\n",
    "                                                        workers=pool.map,\n",
    "                                                        constraints=[norm1const1, norm1const2])\n",
    "\n",
    "    return res_linear.x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7784f806-3e55-4c04-94e0-c6513945ec15",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ```python\n",
    "pi_opt = get_opt_pi(df_large, eta2)\n",
    "write_pkl(pi_opt, \"linear_opt_rule_sim2.pkl\")\n",
    "# ```"
   ]
  },
  {
   "cell_type": "raw",
   "id": "c73a15d0-efc9-4e32-a188-cdbbf0f9165a",
   "metadata": {},
   "source": [
    "pi_opt = read_pkl(\"linear_opt_rule_sim2.pkl\")  # opt rule from `V_true_id()`\n",
    "pi_opt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9377fa18-0879-4ca6-be50-7abe45d8b776",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "source": [
    "## Simulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7559a55",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "def simulate(m, eta2):\n",
    "    \"\"\"\n",
    "    Learn decision-rule parameters on the m-th simulated dataset for every N in Ns.\n",
    "    \n",
    "    For each sample size the dataset is read from ./simdata (or generated with\n",
    "    genData and cached there), nuisance models are fitted with fit_models, and the\n",
    "    five rule parameters are optimised with differential evolution twice:\n",
    "    maximising V_MR (obj_sdmr) and maximising V_plugin (obj_plugin). Both runs\n",
    "    search the box [-1, 1]^5 under norm1const1 / norm1const2 and start from\n",
    "    x0 ~ Uniform(pi_opt - .3, pi_opt + .132), drawn from numpy's global RNG.\n",
    "    \n",
    "    Args:\n",
    "      m: Index of the simulated dataset (part of the cache file name).\n",
    "      eta2: Setting passed to genData (also part of the cache file name).\n",
    "      \n",
    "    Returns:\n",
    "      List with one [res_mr.x, res_pg.x] pair per N in Ns.\n",
    "    \"\"\"\n",
    "    beta_ = []\n",
    "    for N in Ns:\n",
    "        fpath = f\"./simdata/[2]data_NM_{N:04d}_{m:03d}_{eta2:01d}.pkl\"\n",
    "        if os.path.exists(fpath):\n",
    "            dat = read_pkl(fpath)\n",
    "        else:\n",
    "            dat = genData(N, eta2)\n",
    "            # Create the cache directory on first use so write_pkl does not fail.\n",
    "            os.makedirs(os.path.dirname(fpath), exist_ok=True)\n",
    "            write_pkl(dat, fpath)\n",
    "        \n",
    "        XX_obs = dat['obs']\n",
    "        \n",
    "        # Correctly specified nuisance models. Partially misspecified variants can be\n",
    "        # fitted with fit_models(..., <nuisance>_false=True); settings expected to stay\n",
    "        # consistent: 1256, 2345, 2356, 3456, 12346 (1234 is not guaranteed).\n",
    "        models = fit_models(XX_obs)\n",
    "\n",
    "        # Objective evaluations run serially (no `workers=`), so no thread pool is opened.\n",
    "        res_mr = sp.optimize.differential_evolution(obj_sdmr, bounds=[(-1., 1.)]*5,\n",
    "                                                    args=(reg1, reg2, XX_obs, models),\n",
    "                                                    x0=np.random.uniform(pi_opt-.3, pi_opt+.132),\n",
    "                                                    constraints=[norm1const1, norm1const2])\n",
    "        res_pg = sp.optimize.differential_evolution(obj_plugin, bounds=[(-1., 1.)]*5,\n",
    "                                                    args=(reg1, reg2, XX_obs, models),\n",
    "                                                    x0=np.random.uniform(pi_opt-.3, pi_opt+.132),\n",
    "                                                    constraints=[norm1const1, norm1const2])\n",
    "\n",
    "        beta_.append([res_mr.x, res_pg.x])\n",
    "\n",
    "    return beta_"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2987e9bf-b3a5-404b-a214-d172dbcfc055",
   "metadata": {},
   "source": [
    "Initialized with `x0 = uniform(pi_opt - .3, pi_opt + .132)` (elementwise), as in `simulate` above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59065078-69cd-4fdf-9f7a-3c7fa4d5d790",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# ```python\n",
    "# Silence warnings only for this loop: catch_warnings() restores the previous\n",
    "# filters afterwards, whereas warnings.resetwarnings() would also drop the\n",
    "# DeprecationWarning filter installed at the top of the notebook.\n",
    "beta_opt = []\n",
    "with warnings.catch_warnings():\n",
    "    warnings.simplefilter(\"ignore\")\n",
    "    for m in tqdm(range(M)):\n",
    "        beta_opt.append( simulate(m, eta2) )\n",
    "\n",
    "# Shape (M, len(Ns), 2, 5): dataset, sample size, [MR, plug-in], rule parameters.\n",
    "beta_opt = np.array(beta_opt)\n",
    "write_pkl(beta_opt, \"mr,pg_opt_sim2.pkl\")\n",
    "# ```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66d8c17c-18be-4bf3-a6c5-45b6bece9b36",
   "metadata": {},
   "source": [
    "```\n",
    "beta_opt = read_pkl(\"mr,pg_opt_sim2.pkl\")\n",
    "```"
   ]
  },
  {
   "cell_type": "raw",
   "id": "d25898be-9762-46f8-b602-e8207158ebb2",
   "metadata": {},
   "source": [
    "beta_opt = []\n",
    "for m in range(M):\n",
    "    try:\n",
    "        fpath = f\"./simresult/mr,pg_opt_sim2_{m:03d}.pkl\"\n",
    "        beta_opt_m = read_pkl(fpath)\n",
    "        beta_opt_m = [x[0] for x in beta_opt_m]  # only MR. no PG result\n",
    "        beta_opt.append(beta_opt_m)\n",
    "    except:\n",
    "        print(m, end=', ')\n",
    "\n",
    "beta_opt = np.array(beta_opt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6a62c7d-76f5-4479-b361-2909135152ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "beta_opt.shape  # (n_dataset, n_Ns, n_eta)  # only MR. no PG result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5027673f-a013-4d7b-ae83-5f363eb810e9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# ```python\n",
    "# For each N and dataset m: apply the learned MR rule to that dataset and\n",
    "# evaluate V_MR with freshly fitted nuisance models.\n",
    "Vhat_opl_single = []\n",
    "\n",
    "for i, N in enumerate(Ns):\n",
    "    for m in tqdm(range(M)):\n",
    "        fpath = f\"./simdata/[2]data_NM_{N:04d}_{m:03d}_{eta2:01d}.pkl\"\n",
    "        dat_obs = read_pkl(fpath)['obs']\n",
    "\n",
    "        # MR rule parameters for this N. beta_opt is (M, len(Ns), 5) when loaded\n",
    "        # MR-only, but (M, len(Ns), 2, 5) with [MR, PG] rows when produced by\n",
    "        # `simulate`; take the MR row in the latter case.\n",
    "        beta_mr_dat = np.asarray(beta_opt[m][i])\n",
    "        if beta_mr_dat.ndim == 2:\n",
    "            beta_mr_dat = beta_mr_dat[0]\n",
    "\n",
    "        d1mr_dat = reg1(*beta_mr_dat[:2], dat_obs)\n",
    "        d2mr_dat = reg2(*beta_mr_dat[2:], dat_obs)\n",
    "\n",
    "        model_dat = fit_models(dat_obs)\n",
    "        vmr_dat = V_MR(d1mr_dat, d2mr_dat, dat_obs, model_dat)\n",
    "\n",
    "        # Plug-in slot left as None (not computed) to keep the (MR, PG) layout.\n",
    "        Vhat_opl_single.append([vmr_dat, None])\n",
    "\n",
    "Vhat_opl_single = np.array(Vhat_opl_single).reshape(len(Ns), M, -1)  # Ns, M, mrpg\n",
    "write_pkl(Vhat_opl_single, \"Vhat_opl_single_sim2.pkl\")\n",
    "# ```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f7a885a-2a9b-4dbf-9654-0ec94b99d877",
   "metadata": {},
   "source": [
    "```\n",
    "100%|█████████████████████████████████████████| 500/500 [07:55<00:00,  1.05it/s]\n",
    "100%|█████████████████████████████████████████| 500/500 [09:28<00:00,  1.14s/it]\n",
    "100%|███████████████████████████████████████| 500/500 [1:17:06<00:00,  9.25s/it]\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e3b27dd-8303-4a22-96d0-51c1d7ad76bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# expected (len(Ns), M, 2): last axis = (MR estimate, PG placeholder)\n",
    "Vhat_opl_single.shape"
   ]
  },
  {
   "cell_type": "raw",
   "id": "1ba6eff7-c559-4379-bdab-1fdf53062127",
   "metadata": {},
   "source": [
    "# reload the cached result instead of re-running the loop (raw cell: not executed)\n",
    "Vhat_opl_single = read_pkl(\"Vhat_opl_single_sim2.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37821108-3f9c-47a3-b1c8-8c1ce978f114",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ```python\n",
    "# Large-sample value of each replicate's learned rule: plug-in value on df_large\n",
    "# using tm (presumably the true nuisance models).\n",
    "V_opl_single = []\n",
    "dat_obs = df_large['obs']  # same evaluation data for every (N, m): hoisted out of the loops\n",
    "\n",
    "for i_N, N in enumerate(Ns):\n",
    "    for m in tqdm(range(M)):\n",
    "        beta_mr_dat = beta_opt[m][i_N]  # axis 1 of beta_opt follows the order of Ns\n",
    "        \n",
    "        # fixed decisions\n",
    "        d1mr_dat = reg1(*beta_mr_dat[:2], dat_obs)\n",
    "        d2mr_dat = reg2(*beta_mr_dat[2:], dat_obs)\n",
    "        \n",
    "        # plug-in value of the learned rule under tm\n",
    "        vmr_dat = V_plugin(d1mr_dat, d2mr_dat, dat_obs, tm)\n",
    "\n",
    "        # PG slot is NaN (not None) so the stacked array stays numeric\n",
    "        V_opl_single.append([vmr_dat, np.nan])\n",
    "\n",
    "\n",
    "V_opl_single = np.array(V_opl_single).reshape(len(Ns), M, -1)  # Ns, M, mrpg\n",
    "write_pkl(V_opl_single, \"V_opl_single_sim2.pkl\")\n",
    "# ```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c21dc03b-e489-4e9c-b99a-6f0730c63e2f",
   "metadata": {},
   "source": [
    "```\n",
    "100%|███████████████████████████████████████| 500/500 [1:26:20<00:00, 10.36s/it]\n",
    "100%|███████████████████████████████████████| 500/500 [1:18:46<00:00,  9.45s/it]\n",
    "100%|███████████████████████████████████████| 500/500 [1:27:27<00:00, 10.50s/it]\n",
    "```"
   ]
  },
  {
   "cell_type": "raw",
   "id": "66bd3b7d-4b22-4a56-9148-276357a1a645",
   "metadata": {},
   "source": [
    "# reload the cached result instead of re-running the loop (raw cell: not executed)\n",
    "V_opl_single = read_pkl(\"V_opl_single_sim2.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3552f937-890f-4271-a0bc-ba9428bee918",
   "metadata": {},
   "outputs": [],
   "source": [
    "Vhat_opl_single.shape, V_opl_single.shape  # each (len(Ns), M, 2): Ns, M, mrpg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2f63012-58bf-4fb0-b995-3c4fc3a0543c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MR estimates side by side with the large-sample values of the same learned rules\n",
    "Vall_opl_single = np.concatenate(\n",
    "    [Vhat_opl_single[:, :, [0]], V_opl_single[:, :, [0]]], axis=-1\n",
    ")\n",
    "Vall_opl_single.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0cf7610-486f-4cda-95ea-478b327bfb6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference value: rule pi_opt (presumably the true optimal rule) evaluated on df_large\n",
    "# with tm; drawn as the red horizontal line in the boxplots below\n",
    "d1_opt = reg1(*pi_opt[:2], df_large[\"obs\"])\n",
    "d2_opt = reg2(*pi_opt[2:], df_large[\"obs\"])\n",
    "V_true_id_opt = V_plugin(d1_opt, d2_opt, df_large[\"obs\"], tm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20dedfd6-020c-4918-a69f-2970dfcdeff9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One panel per training size: MR estimate vs. large-sample value of the learned rule\n",
    "# across replicates; the red line is the value of the reference rule pi_opt.\n",
    "box_tick_labels = [r'$\\widehat{V}^{MR}(\\widehat{\\beta})$', r'${V}(\\widehat{\\beta})$']\n",
    "box_flierprops = {'marker': 'x', 'markersize': 3, 'markeredgewidth': .5, 'alpha': .7}\n",
    "\n",
    "fig, axes = plt.subplots(1, len(Ns), figsize=(5, 3), sharey=True, squeeze=False)\n",
    "for ax, n_train, V_N in zip(axes[0], Ns, Vall_opl_single):\n",
    "    ax.boxplot(V_N, tick_labels=box_tick_labels, flierprops=box_flierprops)\n",
    "    ax.set_title(f\"N={n_train}\")\n",
    "    ax.axhline(V_true_id_opt, color='r', ls=\"-\", lw=.75)\n",
    "axes[0, 0].set_ylim(0, 1.75)  # shared by all panels via sharey\n",
    "\n",
    "fig.suptitle(\"V (=N/D)\")\n",
    "fig.tight_layout();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c969c69-fb8b-4d75-889e-b5f24d3da3be",
   "metadata": {},
   "source": [
    "## Analytic confidence interval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6751f7f2-81a3-46ab-b047-7a3600e7052f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# ```python\n",
    "# Analytic standard error sqrt(var_MR / N) of the MR value estimate, per (N, replicate).\n",
    "# NOTE: despite its name, all_var_eif stores standard errors, not variances.\n",
    "all_var_eif = []\n",
    "for i, N in enumerate(Ns):\n",
    "    se_N = []\n",
    "    for m in tqdm(range(M)):\n",
    "        fpath = f\"./simdata/[2]data_NM_{N:04d}_{m:03d}_{eta2:01d}.pkl\"\n",
    "        dat_obs = read_pkl(fpath)['obs']\n",
    "        model_m = all_models_correct[i][m]\n",
    "        beta_mr_dat = beta_opt[m][i]  # axis 1 of beta_opt follows the order of Ns\n",
    "        \n",
    "        # fixed decisions\n",
    "        d1mr_dat = reg1(*beta_mr_dat[:2], dat_obs)\n",
    "        d2mr_dat = reg2(*beta_mr_dat[2:], dat_obs)\n",
    "\n",
    "        se_N.append( np.sqrt(var_MR(d1mr_dat, d2mr_dat, dat_obs, model_m) / N) )\n",
    "\n",
    "    all_var_eif.append( se_N )\n",
    "\n",
    "all_var_eif = np.array(all_var_eif)\n",
    "write_pkl(all_var_eif, \"var_opl_eif_sim2.pkl\")\n",
    "# ```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77fb15fd-d975-4de7-a937-dcce9add5edc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ```python\n",
    "# Monte Carlo reference value of rule pi_opt: N_REP_TRUE plug-in values, each on a\n",
    "# fresh sample of size 1,000,000; their mean is the coverage target below.\n",
    "N_REP_TRUE = 20\n",
    "Vs_true_opt = np.empty(N_REP_TRUE)\n",
    "for rep in tqdm(range(N_REP_TRUE)):\n",
    "    # index is not `_`: assigning `_` stops IPython from updating its last-output cache\n",
    "    dd = genData(N=1_000_000, eta2=eta2)\n",
    "    tmh = true_models(dd['obs'])\n",
    "    \n",
    "    d1dh = reg1(*pi_opt[:2], dd[\"obs\"])\n",
    "    d2dh = reg2(*pi_opt[2:], dd[\"obs\"])\n",
    "    \n",
    "    Vs_true_opt[rep] = V_plugin(d1dh, d2dh, dd['obs'], tmh)\n",
    "\n",
    "write_pkl(Vs_true_opt, \"Vs_true_opt_largesample_sim2.pkl\")\n",
    "# ```"
   ]
  },
  {
   "cell_type": "raw",
   "id": "91e9b8d6-89b1-4526-bc52-be7bf8b089cb",
   "metadata": {},
   "source": [
    "# reload the cached reference values (read_pkl: write_pkl needs an object and a path)\n",
    "Vs_true_opt = read_pkl(\"Vs_true_opt_largesample_sim2.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ae3c250-8575-4d37-85f3-9eacdb8c2bfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coverage of the 95% Wald intervals Vhat +/- 1.96*SE for the Monte Carlo reference\n",
    "# value of rule pi_opt, one panel per training size.\n",
    "V_ref = Vs_true_opt.mean()\n",
    "\n",
    "def coverage(lower, upper, target):\n",
    "    \"\"\"Fraction of replicates whose interval (lower, upper) strictly contains target.\"\"\"\n",
    "    return ((target - lower > 0) & (upper - target > 0)).mean(0)\n",
    "\n",
    "# bounds kept as named globals in case later cells read them\n",
    "L_eif_1k = Vhat_opl_single[0,:,0] - 1.96 * all_var_eif[0]\n",
    "U_eif_1k = Vhat_opl_single[0,:,0] + 1.96 * all_var_eif[0]\n",
    "L_eif_2k = Vhat_opl_single[1,:,0] - 1.96 * all_var_eif[1]\n",
    "U_eif_2k = Vhat_opl_single[1,:,0] + 1.96 * all_var_eif[1]\n",
    "L_eif_5k = Vhat_opl_single[2,:,0] - 1.96 * all_var_eif[2]\n",
    "U_eif_5k = Vhat_opl_single[2,:,0] + 1.96 * all_var_eif[2]\n",
    "coverages = [coverage(L, U, V_ref) for L, U in\n",
    "             [(L_eif_1k, U_eif_1k), (L_eif_2k, U_eif_2k), (L_eif_5k, U_eif_5k)]]\n",
    "\n",
    "plt.figure(figsize=(3, 4))\n",
    "for j, (n_train, cov_N) in enumerate(zip(Ns, coverages)):\n",
    "    plt.subplot(1, len(coverages), j + 1)\n",
    "    sns.barplot([cov_N])\n",
    "    plt.xticks(range(1), ['$\\\\phi$'])\n",
    "    plt.ylim(0.5, 1.01)\n",
    "    plt.title(f\"N={n_train}\")\n",
    "\n",
    "plt.suptitle(f\"Coverage rate (analytical, M={M})\")\n",
    "plt.tight_layout();\n",
    "\n",
    "tuple(coverages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ab60d69-aaf3-4d59-9ac0-3511b90f4432",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "4ef5d539-a3fa-4c1a-a209-2fc993d67009",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "source": [
    "## Percentage of correct decisions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23c276b0-fc29-4249-afad-9780ec0eb057",
   "metadata": {},
   "source": [
    "### ID formula for AS-PCD"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63ad1e1a-405c-4bc3-8d75-a26b00afe391",
   "metadata": {},
   "source": [
    "$$\n",
    "\\mathbb E[h(X_1) | U_2=1111] = \\mathbb E \\left[ h(X_1) \\frac{ p_1^0(X_1) \\mathbb E\\left( p_2^0(H_1) | X_1, A_1=0, C_1=0, S_1=1 \\right) }{ \\mathbb E\\left[ p_1^0(X_1) \\mathbb E\\left( p_2^0(H_1) | X_1, A_1=0, C_1=0, S_1=1 \\right) \\right] } \\right]\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "603b3f25-1a58-4197-8212-0eeda74c2279",
   "metadata": {},
   "source": [
    "$$\n",
    "h(X_1; \\widehat\\pi, \\pi^*) = \\mathbb E\\left[ \\mathbb{1}\\left\\{\\widehat{\\pi}(H_1) = \\pi^*(H_1)\\right\\} | X_1, A_1=\\widehat{\\pi}_1(X_1), C_1=0, S_1=1\\right]\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\pi(H_1^{\\pi_1}) = \\left( \\pi_1(X_1), \\pi_2\\left(X_1, \\pi_1(X_1), X_2^{\\pi_1(X_1)}\\right) \\right)\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4d0cd8b-d3eb-4550-b3a4-5d1c2c062bad",
   "metadata": {},
   "outputs": [],
   "source": [
    "def h_(d1, d2, d1_opt, d2_opt, data, models):\n",
    "    \"\"\"\n",
    "    Estimate h(X1): the probability that the evaluated rule (d1, d2) agrees with\n",
    "    the reference rule (d1_opt, d2_opt), given X1, A1 = d1, C1 = 0, S1 = 1.\n",
    "\n",
    "    The agreement indicator 1{d1 == d1_opt} * 1{d2 == d2_opt} is kept only on rows\n",
    "    with A1 == d1, C1 == 0, S1 == 1, inverse-weighted there by pcs1(d1), and then\n",
    "    regressed on X1 with a GAM; the in-sample fitted values are returned.\n",
    "\n",
    "    Args:\n",
    "        d1, d2: decisions of the rule being evaluated (presumably one per row).\n",
    "        d1_opt, d2_opt: decisions of the reference rule (presumably one per row).\n",
    "        data: pandas DataFrame with X1, X2, A1, C1, S1; a copy with NaN set to 0\n",
    "            is used for the regression.\n",
    "        models: dict of fitted models with keys 'phi1.hat', 'K1.hat', 'p1.hat'.\n",
    "\n",
    "    Returns:\n",
    "        Fitted h(X1), one value per row of data.\n",
    "    \"\"\"\n",
    "    def ps1(a1):\n",
    "        # fitted P(A1 = a1): 'phi1.hat' predicts P(A1 = 1), flipped where a1 == 0\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, len(data.X1))\n",
    "        fitdf = pd.DataFrame({\n",
    "            'X1': data['X1'],\n",
    "            'X2': data['X2'],\n",
    "            'A1': a1\n",
    "        })\n",
    "        pred = models['phi1.hat'].predict(fitdf).values\n",
    "        pred[(~np.isnan(a1)) & (a1 == 0)] = 1 - pred[(~np.isnan(a1)) & (a1 == 0)]\n",
    "        return pred\n",
    "\n",
    "    def cp1(a1):\n",
    "        # fitted stage-1 non-censoring probability ('K1.hat'; presumably models 1 - C1)\n",
    "        fitdf = pd.DataFrame({\n",
    "            'X1': data['X1'], 'X2': data['X2'],\n",
    "            'A1': a1\n",
    "        })\n",
    "        return models['K1.hat'].predict(fitdf).values\n",
    "\n",
    "    def sp1(a1):\n",
    "        # fitted stage-1 survival probability ('p1.hat') at A1 = a1\n",
    "        fitdf = pd.DataFrame({\n",
    "            'const': 1, 'X1': data.X1, 'A1': a1, \n",
    "        })\n",
    "        return models['p1.hat'].predict(fitdf).values\n",
    "\n",
    "    def pcs1(a1):\n",
    "        # product of treatment, non-censoring and survival probabilities: the inverse weight\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1)\n",
    "    \n",
    "    # Fill NA with 0\n",
    "    data_filled = data.copy()\n",
    "    data_filled[np.isnan(data_filled)] = 0\n",
    "    \n",
    "    # IPW target: agreement indicator on rows with A1 == d1, C1 == 0, S1 == 1, divided by pcs1(d1)\n",
    "    data_filled[\"target\"] = (data_filled.A1 == d1) * \\\n",
    "                            (data_filled.C1 == 0) * \\\n",
    "                            (data_filled.S1 == 1) / pcs1(d1) * \\\n",
    "                            (d1==d1_opt) * (d2==d2_opt)\n",
    "    # zero out any NaN the weighting may produce (e.g. NaN probability predictions)\n",
    "    data_filled[np.isnan(data_filled)] = 0\n",
    "    # data_filled[\"target\"] = m2(d1, d2)\n",
    "    # data_filtered = data_filled[(data_filled.A1 == d1) & \n",
    "    #                             (data_filled.C1 == 0) & \n",
    "    #                             (data_filled.S1 == 1)]\n",
    "\n",
    "    # GAM: intercept + X1 + X1^2 plus a cubic B-spline smoother in X1 (df = 3+3+1 = 7)\n",
    "    bs = BSplines(data_filled[\"X1\"], df=[3+3+1], degree=[3])\n",
    "    h_model = smf.glmgam(\"target ~ 1 + X1 + I(X1**2)\", data=data_filled, smoother=bs).fit()\n",
    "    # m_m2 = np.zeros(len(data_filled))\n",
    "    # exog_smooth=bs.x reuses the training spline basis, so these are in-sample fitted values\n",
    "    h_val = h_model.predict(data_filled, exog_smooth=bs.x)\n",
    "    \n",
    "    return h_val"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6ef8170-7b75-41c7-b1a7-6109b84dfe1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def N_pcd(d1, d2, d1_opt, d2_opt, data, models):\n",
    "    \"\"\"\n",
    "    Numerator of the always-survivor PCD identification formula, i.e. the\n",
    "    sample mean of  h(X1) * p_1^0(X1) * E[p_2^0(H1) | X1, A1=0, C1=0, S1=1].\n",
    "\n",
    "    Args:\n",
    "        d1, d2: decisions of the rule being evaluated.\n",
    "        d1_opt, d2_opt: decisions of the reference (optimal) rule.\n",
    "        data: pandas DataFrame of observed data.\n",
    "        models: dict of nuisance models; must provide the callable 'sp1'\n",
    "            (sp1(0) gives p_1^0 per row) plus whatever h_ and m_p200 read.\n",
    "\n",
    "    Returns:\n",
    "        A float: the estimated numerator. NaN terms count as 0 but still\n",
    "        contribute to the averaging denominator.\n",
    "    \"\"\"\n",
    "    # p_1^0(X1): stage-1 survival probability with A1 set to 0\n",
    "    p10x = models['sp1'](0)\n",
    "    # E[p_2^0(H1) | X1, A1=0, C1=0, S1=1]\n",
    "    E_p200 = m_p200(data, models)\n",
    "    # h(X1): probability that (d1, d2) agrees with (d1_opt, d2_opt)\n",
    "    h_val = h_(d1, d2, d1_opt, d2_opt, data, models)\n",
    "\n",
    "    Nhat = h_val * p10x * E_p200\n",
    "    Nhat[np.isnan(Nhat)] = 0\n",
    "    return np.mean(Nhat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2720495b-8356-403e-97fd-b2a3d64b0e48",
   "metadata": {},
   "outputs": [],
   "source": [
    "def PCD_AS(d1, d2, d1_opt, d2_opt, data, models):\n",
    "    \"\"\"Always-survivor PCD of rule (d1, d2) against (d1_opt, d2_opt): N_pcd / D_hat.\"\"\"\n",
    "    numerator = N_pcd(d1, d2, d1_opt, d2_opt, data, models)\n",
    "    denominator = D_hat(data, models)\n",
    "    return numerator / denominator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f72cb4a-998d-4b25-971b-7959f40556a1",
   "metadata": {},
   "source": [
    "### Simulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4799318b-b247-4275-a34d-5423ca4552f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ```python\n",
    "# Large evaluation sample (N = 100,000) for the PCD analysis; dd_as_obs keeps the observed\n",
    "# rows whose index appears in dd['as'] (presumably the always-survivors).\n",
    "dd = genData(N=100_000, eta2=eta2)\n",
    "dd_as_obs = dd['obs'].loc[dd[\"as\"].index, :]\n",
    "\n",
    "write_pkl(dd, \"large_obs_PCD_sim2.pkl\")\n",
    "write_pkl(dd_as_obs, \"large_as_obs_PCD_sim2.pkl\")\n",
    "# ```"
   ]
  },
  {
   "cell_type": "raw",
   "id": "dd9c110a-3b59-42cf-b0bf-4eb499de00f2",
   "metadata": {},
   "source": [
    "# reload the cached large evaluation sample (raw cell: not executed)\n",
    "dd = read_pkl(\"large_obs_PCD_sim2.pkl\")\n",
    "dd_as_obs = read_pkl(\"large_as_obs_PCD_sim2.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f09b11ac-e02a-45d2-ad4f-5fd428fd8780",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# ```python\n",
    "# Always-survivor PCD of each replicate's learned MR rule (one entry per training size)\n",
    "# against rule pi_opt, evaluated on the large sample dd.\n",
    "pcd_mr_all = []\n",
    "\n",
    "tm_dd = true_models(data=dd['obs'])\n",
    "\n",
    "# the reference decisions do not depend on m, so compute them once\n",
    "d1dh = reg1(*pi_opt[:2], dd['obs'])\n",
    "d2dh = reg2(*pi_opt[2:], dd['obs'])\n",
    "\n",
    "for m in tqdm(range(M)):\n",
    "    beta_opt_m = beta_opt[m]\n",
    "    pcd_mr = []\n",
    "    for i_N in range(len(Ns)):  # row i_N of beta_opt_m = rule learned with Ns[i_N]\n",
    "        d1mr = reg1(*beta_opt_m[i_N, :2], dd['obs'])\n",
    "        d2mr = reg2(*beta_opt_m[i_N, 2:], dd['obs'])\n",
    "        pcd_mr.append(PCD_AS(d1mr, d2mr, d1dh, d2dh, dd['obs'], tm_dd))\n",
    "    pcd_mr_all.append(pcd_mr)\n",
    "    \n",
    "pcd_mr_all = np.array(pcd_mr_all)\n",
    "\n",
    "write_pkl(pcd_mr_all, 'pcd_mr_all_sim2.pkl')\n",
    "# ```"
   ]
  },
  {
   "cell_type": "raw",
   "id": "a06d9d39-23af-40c9-bdc1-b36c0f041eca",
   "metadata": {},
   "source": [
    "# reload the cached PCD results (raw cell: not executed)\n",
    "pcd_mr_all = read_pkl('pcd_mr_all_sim2.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9b89173-ed8c-4069-93e5-47335fa32dcf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# one single-column (M, 1) array per training size, in the column order of pcd_mr_all\n",
    "pcd_1k_prod, pcd_2k_prod, pcd_5k_prod = (pcd_mr_all[:, [j]] for j in range(3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2caa35a8-beb4-463d-bf0d-86b8faaa2859",
   "metadata": {},
   "outputs": [],
   "source": [
    "# wrap each training size's PCD column in a one-column DataFrame labelled phi\n",
    "df_pcd_1k, df_pcd_2k, df_pcd_5k = (\n",
    "    pd.DataFrame(col, columns=[r\"$\\phi$\"])\n",
    "    for col in (pcd_1k_prod, pcd_2k_prod, pcd_5k_prod)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "292f2c1e-8d17-40df-8b78-3ddebddb15f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_pcd = np.concatenate([df_pcd_1k, df_pcd_2k, df_pcd_5k], axis=1)  # (M, 3) numpy array, not a DataFrame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d86a6e04-9321-4f10-85bf-b5db999a4fca",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(3,4))\n",
    "n_sizes = df_pcd.shape[1]  # one box per training size (column of df_pcd)\n",
    "plt.boxplot(df_pcd, positions=range(n_sizes),\n",
    "            flierprops={'marker': 'x', 'markersize': 3, \n",
    "                        'markeredgewidth': .5, 'alpha': .7})\n",
    "plt.xticks(range(n_sizes), Ns[:n_sizes])\n",
    "plt.title(\"Percentage of Correct Decisions \\non Always-Survivors\")\n",
    "plt.xlabel(\"Training N\")\n",
    "plt.ylabel(\"%\")\n",
    "# label each tick with its own percentage so ticks and labels cannot drift out of sync\n",
    "pcd_yticks = np.arange(0.7, 1.01, 0.02)\n",
    "plt.yticks(pcd_yticks, [f\"{100 * t:.0f}\" for t in pcd_yticks]);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43398a31-0cdd-4494-a91f-88a58d34ad3a",
   "metadata": {},
   "source": [
    "Evaluated on one large simulated sample of size 100,000 (`dd`), generated independently of the training datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2009350c-32d9-4f5c-9163-02a56100d5ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_pcd.mean(0)  # mean PCD over replicates, per training size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0846cf94-8f9b-4b21-abac-8944e18e0a84",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_pcd.std(0)  # numpy default ddof=0 (population SD) over replicates, per training size"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a04b69f1-f437-4b85-8cf9-acaf00b0449e",
   "metadata": {},
   "source": [
    "### AIPW"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5142c75d-00a3-4b83-b6d4-afc7c5a999ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_models_aipcw(data, phi1_false=False, phi2_false=False,\n",
    "               K1_false=False, K2_false=False,\n",
    "               p1_false=False, p2_false=False,\n",
    "               mu2_false=False, Ep2_false=False, \n",
    "               Emupi_false=False):\n",
    "    \"\"\"\n",
    "    Fit the nuisance models used by V_aipcw, optionally with alternative formulas.\n",
    "\n",
    "    Args:\n",
    "        data: A pandas DataFrame containing the observed data.\n",
    "        phi1_false, phi2_false, K1_false, K2_false, p1_false, p2_false, mu2_false:\n",
    "            If True, fit the corresponding model with the alternative\n",
    "            (presumably misspecified) formula in its else-branch.\n",
    "        Ep2_false, Emupi_false: Accepted but unused.\n",
    "\n",
    "    Returns:\n",
    "        A dict of fitted statsmodels results keyed 'phi1.hat', 'phi2.hat',\n",
    "        'K1.hat', 'K2.hat', 'p1.hat', 'p2.hat', 'mu1.hat', 'mu2.hat'.\n",
    "    \"\"\"\n",
    "\n",
    "    models = {}\n",
    "    # treatment models (logistic): P(A1 = 1), P(A2 = 1)\n",
    "    if not phi1_false:\n",
    "        models['phi1.hat'] = smf.glm(\"A1 ~ 1 + I(X1**2)\", \n",
    "                                     data=data, family=sm.families.Binomial()).fit()\n",
    "    else:\n",
    "        models['phi1.hat'] = smf.glm(\"A1 ~ 0 + X1 + I(X1**2)\", \n",
    "                                     data=data, family=sm.families.Binomial()).fit()\n",
    "\n",
    "    if not phi2_false:\n",
    "        models['phi2.hat'] = smf.glm(\"A2 ~ 1 + I(X1**2) + X2 + I(X2**2)\", \n",
    "                                     data=data, family=sm.families.Binomial()).fit()\n",
    "    else:\n",
    "        models['phi2.hat'] = smf.glm(\"A2 ~ 1 + X2\", \n",
    "                                     data=data, family=sm.families.Binomial()).fit()\n",
    "\n",
    "    # non-censoring models (logistic on 1 - C)\n",
    "    if not K1_false:\n",
    "        models['K1.hat'] = smf.glm(\"I(1-C1) ~ 1 + I(X1**2)\", \n",
    "                                   data=data, family=sm.families.Binomial()).fit()\n",
    "    else:\n",
    "        models['K1.hat'] = smf.glm(\"I(1-C1) ~ 0 + X1\", \n",
    "                                   data=data, family=sm.families.Binomial()).fit()\n",
    "\n",
    "    if not K2_false:\n",
    "        models['K2.hat'] = smf.glm(\"I(1-C2) ~ 1 + X1 + X2 + A2 + A2:X2\", \n",
    "                                   data=data, family=sm.families.Binomial()).fit()\n",
    "    else:\n",
    "        models['K2.hat'] = smf.glm(\"I(1-C2) ~ 1 + X2\", \n",
    "                                   data=data, family=sm.families.Binomial()).fit()\n",
    "\n",
    "    # survival models (logistic): P(S1 = 1), P(S2 = 1)\n",
    "    if not p1_false:\n",
    "        models['p1.hat'] = smf.glm(\"S1 ~ 1 + I(X1**2) + A1 + A1:X1\", \n",
    "                                   data=data, family=sm.families.Binomial()).fit()\n",
    "    else:\n",
    "        models['p1.hat'] = smf.glm(\"S1 ~ 0 + A1\", \n",
    "                                   data=data, family=sm.families.Binomial()).fit()\n",
    "        \n",
    "    if not p2_false:\n",
    "        models['p2.hat'] = smf.glm(\"S2 ~ 1 + X1 + X1:X2 + A1 + A2\", \n",
    "                                   #\"S2 ~ 1 + X1 + I(X2**2) + A1 + A2 + A2:I(X2**2)\", \n",
    "                                   # \"S2 ~ 1 + X1 + X2 + A1 + A2 + A2:X2\", \n",
    "                                   data=data, family=sm.families.Binomial()).fit()\n",
    "    else:\n",
    "        models['p2.hat'] = smf.glm(\"S2 ~ 0 + A1 + A2\",\n",
    "                                   data=data, family=sm.families.Binomial()).fit()\n",
    "\n",
    "    # outcome regressions for Y (OLS); mu1 always uses its single formula\n",
    "    models['mu1.hat'] = smf.ols(\"Y ~ 1 + I(X1**2) + A1 + A1:X1\", data=data).fit()\n",
    "    \n",
    "    if not mu2_false:\n",
    "        models['mu2.hat'] = smf.ols(\"Y ~ 1 + X1 + A1 + A1:X1 + I(np.exp(X2)) + A2 + A1:A2 + X2:A2\", \n",
    "                                    data=data).fit()\n",
    "    else:\n",
    "        models['mu2.hat'] = smf.ols(\"Y ~ 1 + X1 + X2\", \n",
    "                                    data=data).fit()\n",
    "\n",
    "    # Return the dictionary of fitted model results\n",
    "    return models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3b82ae1-8fd4-4c14-8c86-d1a4ed8816bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def V_aipcw(d1, d2, data, models):\n",
    "    \"\"\"\n",
    "    AIPCW (augmented inverse-probability-of-censoring-weighted) estimate of the\n",
    "    value of the two-stage rule (d1, d2).\n",
    "\n",
    "    Returns  mean(q1) + sum(w2 * (q2 - q1)) / sum(w2) + sum(w1 * (Y - q2)) / sum(w1),\n",
    "    where q1 = mu1(d1) and q2 = mu2(d1, d2) are fitted outcome regressions and\n",
    "        w2 = 1{A1 = d1} (1 - C1) S1 / [P(A1 = d1) K1 p1],\n",
    "        w1 = 1{A1 = d1, A2 = d2} (1 - C1)(1 - C2) S1 S2 / [P(A1 = d1) K1 p1 P(A2 = d2) K2 p2],\n",
    "    with K the fitted non-censoring and p the fitted survival probabilities.\n",
    "    Both augmentation terms are weight-normalised (Hajek form).\n",
    "\n",
    "    Args:\n",
    "        d1: Stage-1 decisions (presumably one per row of data).\n",
    "        d2: Stage-2 decisions (presumably one per row of data).\n",
    "        data: pandas DataFrame with X1, X2, A1, A2, C1, C2, S1, S2, Y.\n",
    "        models: dict of fitted models keyed 'phi1.hat', 'phi2.hat', 'K1.hat',\n",
    "            'K2.hat', 'p1.hat', 'p2.hat', 'mu1.hat', 'mu2.hat'\n",
    "            (as returned by fit_models_aipcw).\n",
    "\n",
    "    Returns:\n",
    "        A float, the estimated value of (d1, d2).\n",
    "    \"\"\"\n",
    "    # NaN -> 0 only for the indicators and Y; the models below still predict from `data`\n",
    "    data_filled = data.fillna(0)\n",
    "\n",
    "    def ps1(a1):\n",
    "        # fitted P(A1 = a1): 'phi1.hat' predicts P(A1 = 1), flipped where a1 == 0\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, len(data.X1))\n",
    "        fitdf = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'], 'A1': a1})\n",
    "        pred = models['phi1.hat'].predict(fitdf).values\n",
    "        is_zero = (~np.isnan(a1)) & (a1 == 0)\n",
    "        pred[is_zero] = 1 - pred[is_zero]\n",
    "        return pred\n",
    "\n",
    "    def ps2(a1, a2):\n",
    "        # fitted P(A2 = a2): 'phi2.hat' predicts P(A2 = 1), flipped where a2 == 0\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, len(data.X1))\n",
    "        if isinstance(a2, int):\n",
    "            a2 = np.repeat(a2, len(data.X1))\n",
    "        fitdf = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'], 'A1': a1, 'A2': a2})\n",
    "        pred = models['phi2.hat'].predict(fitdf).values\n",
    "        is_zero = (~np.isnan(a2)) & (a2 == 0)\n",
    "        pred[is_zero] = 1 - pred[is_zero]\n",
    "        return pred\n",
    "\n",
    "    def cp1(a1):\n",
    "        # fitted stage-1 non-censoring probability\n",
    "        fitdf = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'], 'A1': a1})\n",
    "        return models['K1.hat'].predict(fitdf).values\n",
    "\n",
    "    def cp2(a1, a2):\n",
    "        # fitted stage-2 non-censoring probability\n",
    "        fitdf = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'], 'A1': a1, 'A2': a2})\n",
    "        return models['K2.hat'].predict(fitdf).values\n",
    "\n",
    "    def sp1(a1):\n",
    "        # fitted stage-1 survival probability\n",
    "        fitdf = pd.DataFrame({'const': 1, 'X1': data.X1, 'A1': a1})\n",
    "        return models['p1.hat'].predict(fitdf).values\n",
    "\n",
    "    def sp2(a1, a2):\n",
    "        # fitted stage-2 survival probability\n",
    "        fitdf = pd.DataFrame({'const': 1, 'X1': data.X1, 'X2': data.X2, 'A1': a1, 'A2': a2})\n",
    "        return models['p2.hat'].predict(fitdf).values\n",
    "\n",
    "    def pcs1(a1):\n",
    "        # stage-1 weight denominator\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1)\n",
    "\n",
    "    def pcs1pcs2(a1, a2):\n",
    "        # stage-1 times stage-2 weight denominator\n",
    "        return pcs1(a1) * ps2(a1, a2) * cp2(a1, a2) * sp2(a1, a2)\n",
    "\n",
    "    def q1(a1):\n",
    "        # fitted stage-1 outcome regression\n",
    "        fitdf = pd.DataFrame({'X1': data['X1'], 'A1': a1})\n",
    "        return models['mu1.hat'].predict(fitdf).values\n",
    "\n",
    "    def q2(a1, a2):\n",
    "        # fitted stage-2 outcome regression\n",
    "        fitdf = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'], 'A1': a1, 'A2': a2})\n",
    "        return models['mu2.hat'].predict(fitdf).values\n",
    "\n",
    "    # attach data_filled\n",
    "    A1 = data_filled.A1; A2 = data_filled.A2\n",
    "    C1 = data_filled.C1; C2 = data_filled.C2\n",
    "    S1 = data_filled.S1; S2 = data_filled.S2; Y = data_filled.Y\n",
    "\n",
    "    e1val = pcs1(d1)\n",
    "    e2val = pcs1pcs2(d1, d2)\n",
    "    q1val = q1(d1)\n",
    "    q2val = q2(d1, d2)\n",
    "\n",
    "    w1 = (A1==d1)*(A2==d2)*(1-C1)*(1-C2)*S1*S2 / e2val\n",
    "    w2 = (A1==d1)*(1-C1)*S1 / e1val\n",
    "\n",
    "    # Hajek-normalised terms; pandas .sum() skips NaN (e.g. rows where q2 or a weight is NaN)\n",
    "    aipw1 = (w1 * (Y - q2val)).sum() / w1.sum()\n",
    "    aipw2 = (w2 * (q2val - q1val)).sum() / w2.sum()\n",
    "    aipw3 = q1val.mean()\n",
    "    \n",
    "    return aipw1 + aipw2 + aipw3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a737df66-a764-4c2d-a678-80e71bf13c8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# AIPCW nuisance models on XX_obs15 with the default (all flags False) formulas\n",
    "mdl_aipw = fit_models_aipcw(XX_obs15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6788ae1-18eb-41f4-a446-bb4addd3855e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# nuisance models for V_MR on the same data\n",
    "models = fit_models(XX_obs15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ccbc480-fa47-46ea-a5e7-19a1582ef302",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "# decisions of the rule eta_true on XX_obs15; the means are the share treated if decisions are 0/1\n",
    "d1_true = reg1(*eta_true[:2], XX_obs15)\n",
    "d2_true = reg2(*eta_true[2:], XX_obs15)\n",
    "\n",
    "d1_true.mean(), d2_true.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c06cba2b-e7aa-457c-9aab-e823c1e191d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# AIPCW value of the eta_true rule, to compare with V_MR in the next cell\n",
    "V_aipcw(d1_true, d2_true, XX_obs15, mdl_aipw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4003453e-cfc6-43ba-be3b-df70463cc840",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MR value of the same rule, for comparison with the AIPCW estimate\n",
    "V_MR(d1_true, d2_true, XX_obs15, models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f804a3c-a6cf-4e9a-9d0d-99d1dbcaab66",
   "metadata": {},
   "outputs": [],
   "source": [
    "def obj_aipw(etas, reg1, reg2, data, models):\n",
    "    \"\"\"\n",
    "    Negative AIPCW value estimate of the two-stage rule indexed by `etas`.\n",
    "\n",
    "    The sign is flipped so that a minimiser (e.g. the differential evolution\n",
    "    run in simulate_aipw) searches for the value-maximising rule parameters.\n",
    "\n",
    "    Args:\n",
    "      etas: Rule parameters; etas[:2] are passed to reg1 (stage 1) and\n",
    "        etas[2:] to reg2 (stage 2).\n",
    "      reg1, reg2: Functions mapping (*params, data) to the stage-1 / stage-2\n",
    "        decisions.\n",
    "      data: pandas DataFrame containing the observed data.\n",
    "      models: Fitted nuisance models, as returned by fit_models_aipcw.\n",
    "\n",
    "    Returns:\n",
    "      -V_aipcw(d1, d2, data, models), i.e. minus the AIPCW value estimate.\n",
    "    \"\"\"\n",
    "    # Decisions implied by the candidate parameters\n",
    "    d1 = reg1(*etas[:2], data)\n",
    "    d2 = reg2(*etas[2:], data)\n",
    "\n",
    "    # Negate because scipy's optimisers minimise\n",
    "    return -V_aipcw(d1, d2, data, models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e01e0dad-d8fd-4ec3-b05b-ef9160f9e25a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def simulate_aipw(m, eta2):\n",
    "    \"\"\"\n",
    "    Learn AIPCW-optimal rule parameters for simulation replicate `m`.\n",
    "\n",
    "    For every sample size N in the global `Ns`:\n",
    "      1. Load the replicate's data set from ./simdata, or generate it with\n",
    "         genData and cache it there.\n",
    "      2. Fit the AIPCW nuisance models.\n",
    "      3. Minimise obj_aipw with differential evolution over [-1, 1]^5,\n",
    "         subject to norm1const1 / norm1const2.\n",
    "\n",
    "    Args:\n",
    "      m: Index of the simulation replicate (part of the cache file name).\n",
    "      eta2: Setting passed to genData (also part of the cache file name).\n",
    "\n",
    "    Returns:\n",
    "      list with one entry per N in `Ns`; each entry is a one-element list\n",
    "      `[x]` holding the optimised parameter vector, so downstream code reads\n",
    "      it as `beta[i][0]`.\n",
    "    \"\"\"\n",
    "    # write_pkl below fails if the cache directory does not exist yet\n",
    "    os.makedirs(\"./simdata\", exist_ok=True)\n",
    "\n",
    "    beta_ = []\n",
    "    for N in Ns:\n",
    "        # Load the cached data set for (N, m, eta2), generating it on first use\n",
    "        fpath = f\"./simdata/[2]data_NM_{N:04d}_{m:03d}_{eta2:01d}.pkl\"\n",
    "        if os.path.exists(fpath):\n",
    "            dat = read_pkl(fpath)\n",
    "        else:\n",
    "            dat = genData(N, eta2)\n",
    "            write_pkl(dat, fpath)\n",
    "        XX_obs = dat['obs']\n",
    "\n",
    "        # Nuisance models are fitted once per data set and reused by every\n",
    "        # objective evaluation inside the optimiser\n",
    "        models = fit_models_aipcw(XX_obs)\n",
    "\n",
    "        res_aipw = sp.optimize.differential_evolution(\n",
    "            obj_aipw,\n",
    "            bounds=[(-1., 1.)] * 5,\n",
    "            args=(reg1, reg2, XX_obs, models),\n",
    "            x0=np.zeros(5),\n",
    "            constraints=[norm1const1, norm1const2],\n",
    "        )\n",
    "        # Wrapped in a list to keep the layout that downstream cells index into\n",
    "        beta_.append([res_aipw.x])\n",
    "\n",
    "    return beta_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4daaf672-e79e-4d23-96e4-7e0715f56e53",
   "metadata": {},
   "outputs": [],
   "source": [
    "beta_opt_aipw = []\n",
    "missing_aipw = []  # replicates whose result file could not be loaded\n",
    "for m in tqdm(range(M)):\n",
    "    fpath = f\"./simresult/aipw_opt_sim2_{m:03d}.pkl\"\n",
    "    try:\n",
    "        beta_opt_m = read_pkl(fpath)\n",
    "    except Exception:  # not a bare except, so KeyboardInterrupt still stops the loop\n",
    "        missing_aipw.append(m)\n",
    "        print(m, end=', ')\n",
    "        continue\n",
    "    beta_opt_aipw.append(beta_opt_m)\n",
    "\n",
    "beta_opt_aipw = np.array(beta_opt_aipw)\n",
    "\n",
    "if missing_aipw:\n",
    "    # Later cells index beta_opt_aipw[m] by replicate number, so a gap shifts\n",
    "    # every following replicate onto the wrong simulated data set.\n",
    "    warnings.warn(\n",
    "        f\"{len(missing_aipw)} AIPW result file(s) missing {missing_aipw}; \"\n",
    "        \"beta_opt_aipw[m] no longer corresponds to replicate m.\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef179d01-fb74-43c7-9a13-7d2b5649f3c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "Vhataipw_opl_single = []\n",
    "\n",
    "for i_n, N in enumerate(Ns):\n",
    "    for m in tqdm(range(M)):\n",
    "        # Training data of replicate m at size N (same cache file as simulate_aipw;\n",
    "        # eta2 is the global data-generating setting)\n",
    "        fpath = f\"./simdata/[2]data_NM_{N:04d}_{m:03d}_{eta2:01d}.pkl\"\n",
    "        dat_obs = read_pkl(fpath)['obs']\n",
    "\n",
    "        # Parameters learned on this data set: beta_opt_aipw[m][i_n] is `[x]` for Ns[i_n]\n",
    "        beta_aipw_dat = beta_opt_aipw[m][i_n][0]\n",
    "\n",
    "        # Decisions of the learned rule on its own training data\n",
    "        d1aipw_dat = reg1(*beta_aipw_dat[:2], dat_obs)\n",
    "        d2aipw_dat = reg2(*beta_aipw_dat[2:], dat_obs)\n",
    "\n",
    "        # In-sample AIPCW value estimate with freshly fitted nuisance models\n",
    "        model_dat = fit_models_aipcw(dat_obs)\n",
    "        vaipw_dat = V_aipcw(d1aipw_dat, d2aipw_dat, dat_obs, model_dat)\n",
    "\n",
    "        # Second slot is an unused placeholder that keeps the two-column layout\n",
    "        Vhataipw_opl_single.append([vaipw_dat, None])\n",
    "\n",
    "# (len(Ns), M, 2), in the same order as the loops above\n",
    "Vhataipw_opl_single = np.array(Vhataipw_opl_single).reshape(len(Ns), M, -1)\n",
    "write_pkl(Vhataipw_opl_single, \"Vhat_aipw_opl_single_sim2.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db2069d8-5507-4344-bcf0-93ad83980a5f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "Vaipw_opl_single = []\n",
    "\n",
    "# Score every learned rule with V_plugin + true_models on one fixed evaluation\n",
    "# sample (the first N_EVAL rows of df_large['obs']), so all rules share the same test data.\n",
    "N_EVAL = 100_000\n",
    "dat_eval = df_large['obs'].iloc[:N_EVAL, :]\n",
    "tm_eval = true_models(dat_eval)\n",
    "\n",
    "for i_n in range(len(Ns)):\n",
    "    for m in tqdm(range(M)):\n",
    "        # Parameters learned from replicate m at training size Ns[i_n]\n",
    "        beta_aipw_dat = beta_opt_aipw[m][i_n][0]\n",
    "\n",
    "        d1aipw_dat = reg1(*beta_aipw_dat[:2], dat_eval)\n",
    "        d2aipw_dat = reg2(*beta_aipw_dat[2:], dat_eval)\n",
    "        vaipw_dat = V_plugin(d1aipw_dat, d2aipw_dat, dat_eval, tm_eval)\n",
    "\n",
    "        # Second slot is an unused placeholder that keeps the two-column layout\n",
    "        Vaipw_opl_single.append([vaipw_dat, None])\n",
    "\n",
    "Vaipw_opl_single = np.array(Vaipw_opl_single).reshape(len(Ns), M, -1)  # (len(Ns), M, 2)\n",
    "write_pkl(Vaipw_opl_single, \"V_aipw_opl_single_sim2.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0b84512-a4bb-4933-bf7d-ef02a253e42c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the cached value arrays. The *_aipw_* files are written by the two cells\n",
    "# above (Vhat = in-sample V_aipcw estimate, V = V_plugin on the large evaluation\n",
    "# sample); the MR files presumably come from analogous cells earlier in the notebook.\n",
    "Vhat_opl_single = read_pkl(\"Vhat_opl_single_sim2.pkl\")\n",
    "Vhataipw_opl_single = read_pkl(\"Vhat_aipw_opl_single_sim2.pkl\")\n",
    "V_opl_single = read_pkl(\"V_opl_single_sim2.pkl\")\n",
    "Vaipw_opl_single = read_pkl(\"V_aipw_opl_single_sim2.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6132ca61-b711-4f58-a9b2-1171d5d6bb57",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference value (red line in the plots below): the rule stored in\n",
    "# linear_opt_rule_sim2.pkl, scored with V_plugin and tm on the full large sample.\n",
    "pi_opt = read_pkl(\"linear_opt_rule_sim2.pkl\")\n",
    "obs_large = df_large[\"obs\"]\n",
    "d1_opt, d2_opt = reg1(*pi_opt[:2], obs_large), reg2(*pi_opt[2:], obs_large)\n",
    "V_true_id_opt = V_plugin(d1_opt, d2_opt, obs_large, tm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9784fc18-2f43-49eb-b58f-ee02f3de3813",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column 0 of each array, side by side along the last axis:\n",
    "# [MR Vhat, AIPW Vhat, MR V, AIPW V]\n",
    "value_arrays = (Vhat_opl_single, Vhataipw_opl_single, V_opl_single, Vaipw_opl_single)\n",
    "Vall_opl_single = np.concatenate([arr[:, :, [0]] for arr in value_arrays], axis=-1)\n",
    "Vall_opl_single.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f8131ca-147b-43d7-ade8-c18ffe9f2ad1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Left panel: training size Ns[1]; right panel: Ns[2]. The red line is\n",
    "# V_true_id_opt; the black line separates the Vhat columns from the V columns.\n",
    "box_labels = ['MR\\n\\t  $\\\\widehat V$', 'AIPW', 'MR\\n\\t  $V$', 'AIPW']\n",
    "\n",
    "plt.figure(figsize=(4.3,2.2), dpi=1200)\n",
    "for subplot_pos, n_idx in ((121, 1), (122, 2)):\n",
    "    ax = plt.subplot(subplot_pos)\n",
    "    stylize_axes(ax)\n",
    "    plt.boxplot(\n",
    "        Vall_opl_single[n_idx,:,:], tick_labels=box_labels,\n",
    "        sym='x', flierprops={'markersize': 3, 'markeredgewidth': .25}\n",
    "    )\n",
    "    plt.ylim(0.25, 1)\n",
    "    if n_idx == 2:\n",
    "        plt.yticks([], [])  # right panel shares the left panel's y range\n",
    "    plt.title(f\"N={Ns[n_idx]}\")\n",
    "    plt.axhline(V_true_id_opt, color='r', ls=\"-\", lw=.75)\n",
    "    plt.axvline(2.5, c='k', lw=.5, alpha=.75)\n",
    "    if n_idx == 1:\n",
    "        plt.ylabel(\"Always-survivor value\")\n",
    "\n",
    "plt.tight_layout();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65aae2c4-1ef9-4f3c-a330-2d74b6b718c4",
   "metadata": {},
   "source": [
    "### PCD-AS of the AIPW-learned rules (percentage of correct decisions among always-survivors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7ceddd1-d6fe-400d-a9e6-95d77a41d3e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Large data sets for the PCD-AS evaluation below (dd['obs'] is the test sample)\n",
    "dd_as_obs = read_pkl(\"large_as_obs_PCD_sim2.pkl\")\n",
    "dd = read_pkl(\"large_obs_PCD_sim2.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bacacc29-5ca2-4c20-bad5-f9ea0ef53444",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# PCD-AS of every learned AIPW rule against the reference rule pi_opt on the\n",
    "# large test sample dd['obs']. The computation is expensive, so the result is\n",
    "# cached; delete the pickle to force a recomputation.\n",
    "PCD_CACHE = 'pcd_aipw_all_sim2.pkl'\n",
    "\n",
    "if os.path.exists(PCD_CACHE):\n",
    "    pcd_aipw_all = read_pkl(PCD_CACHE)\n",
    "else:\n",
    "    dd_obs = dd['obs']\n",
    "    tm_dd = true_models(data=dd_obs)\n",
    "\n",
    "    # Reference decisions are the same for every replicate: compute them once\n",
    "    d1dh = reg1(*pi_opt[:2], dd_obs)\n",
    "    d2dh = reg2(*pi_opt[2:], dd_obs)\n",
    "\n",
    "    pcd_aipw_all = []\n",
    "    for m in tqdm(range(M)):\n",
    "        pcd_aipw = []\n",
    "        # beta_opt_aipw[m][i_n][0] holds the parameters learned with N = Ns[i_n]\n",
    "        for i_n in range(len(Ns)):\n",
    "            beta = beta_opt_aipw[m][i_n][0]\n",
    "            d1 = reg1(*beta[:2], dd_obs)\n",
    "            d2 = reg2(*beta[2:], dd_obs)\n",
    "            pcd_aipw.append(PCD_AS(d1, d2, d1dh, d2dh, dd_obs, tm_dd))\n",
    "        pcd_aipw_all.append(pcd_aipw)\n",
    "\n",
    "    # (M, len(Ns)) when PCD_AS returns a scalar\n",
    "    pcd_aipw_all = np.array(pcd_aipw_all)\n",
    "    write_pkl(pcd_aipw_all, PCD_CACHE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "047e9565-4b57-466e-b87b-9595696ad63a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One single-column (M, 1) array per training size, from the columns of pcd_aipw_all\n",
    "pcd_1k_prod, pcd_2k_prod, pcd_5k_prod = (np.c_[pcd_aipw_all[:, j]] for j in range(3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c47cac1c-1c99-470e-a523-a42486cbdb1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap each training size's PCD-AS column in a one-column DataFrame\n",
    "df_pcd_1k, df_pcd_2k, df_pcd_5k = (\n",
    "    pd.DataFrame(pcd_prod, columns=[r\"$\\phi$\"])\n",
    "    for pcd_prod in (pcd_1k_prod, pcd_2k_prod, pcd_5k_prod)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae0a97b9-376f-4ba2-8439-da8d5cbdf275",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plain ndarray (despite the df_ prefix) with one column per training size\n",
    "df_pcd = np.concatenate([df_pcd_1k, df_pcd_2k, df_pcd_5k], axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bafaf5d-9382-482d-930b-323b3715cedf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distribution of PCD-AS over replicates, one box per training sample size\n",
    "fig, ax = plt.subplots(figsize=(3,4))\n",
    "ax.boxplot(df_pcd, positions=range(3),\n",
    "           flierprops={'marker': 'x', 'markersize': 3,\n",
    "                       'markeredgewidth': .5, 'alpha': .7})\n",
    "ax.set_xticks(range(3))\n",
    "ax.set_xticklabels(Ns)\n",
    "ax.set_title(\"Percentage of Correct Decisions \\non Always-Survivors\")\n",
    "ax.set_xlabel(\"Training N\")\n",
    "ax.set_ylabel(\"%\")\n",
    "# Proportions on the axis, shown as percentages in the tick labels\n",
    "ax.set_yticks(np.arange(0.9, 1.01, 0.02))\n",
    "ax.set_yticklabels(range(90, 101, 2));"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52289783-7bac-4379-990e-abc73ba3e4fc",
   "metadata": {},
   "source": [
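    "PCD-AS is evaluated on a test data set of size 100,000. The next two cells report its mean and standard deviation across replicates for each training size N."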
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5861a6b4-83ee-47cb-b197-63aeda350a1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean PCD-AS over replicates for each training size (columns follow Ns)\n",
    "df_pcd.mean(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e54a6fa-fb54-4835-b915-61323793f506",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Across-replicate SD per training size. df_pcd is an ndarray (np.hstack output),\n",
    "# so this is numpy's population SD (ddof=0), not pandas' default ddof=1.\n",
    "df_pcd.std(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53fa1fc6-c06b-4381-bc2a-9b64ea47f2a7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91281ac3-838a-4e8b-8b2b-27ad78ac1737",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8bae516-7d7c-4987-ac1d-d09711fe61e4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18645ef0-c4d3-4065-8b09-fdfed90992ab",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "061376ff-e586-4014-ac13-9925b8c83653",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {
    "height": "328px",
    "width": "477px"
   },
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "231022) Survivor-optimal DynTrtRegime",
   "title_sidebar": "Contents",
   "toc_cell": true,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "214.578px"
   },
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
