{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed26117a",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import pyreadr\n",
    "import numpy as np\n",
    "import scipy as sp\n",
    "import pandas as pd\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "# import rpy2.robjects as robjects\n",
    "import statsmodels.api as sm\n",
    "import statsmodels.formula.api as smf\n",
    "import multiprocess as mp\n",
    "from collections import namedtuple\n",
    "from multiprocess.pool import ThreadPool\n",
    "from sklearn.dummy import DummyRegressor\n",
    "from scipy.stats import norm\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48492432-fcde-4e30-8518-efa3daa83fef",
   "metadata": {},
   "outputs": [],
   "source": [
    "import signal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "145b6a9e-8689-4e6c-8b66-e6a4c8bfdeb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from statsmodels.gam.api import BSplines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b88d68f3-8893-45e9-9c7e-b26ea524021c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9c35f59-fce8-4ef0-b9e9-c7fac978f2c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import dill\n",
    "\n",
    "def write_pkl(obj, fname):\n",
    "    \"\"\"Serialize `obj` with dill and write it to the file `fname`.\"\"\"\n",
    "    with open(fname, mode='wb') as sink:\n",
    "        dill.dump(obj, sink)\n",
    "\n",
    "def read_pkl(fname):\n",
    "    \"\"\"Load and return the dill-serialized object stored in `fname`.\n",
    "\n",
    "    NOTE(review): dill, like pickle, can execute code while loading, so this\n",
    "    should only be pointed at files produced by this notebook.\n",
    "    \"\"\"\n",
    "    with open(fname, mode='rb') as source:\n",
    "        return dill.load(source)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70d576a5",
   "metadata": {},
   "source": [
    "<!--$$\\require{cancel}$$-->\n",
    "# Simulation scheme"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9dd25fae",
   "metadata": {},
   "source": [
    "## Data generating process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22aec556",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "np.random.seed(0)\n",
    "N = 10_000"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "851e4cc8",
   "metadata": {},
   "source": [
    "Predefine true parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37f5cef8",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "# True coefficients of the data-generating process. Each vector is listed in\n",
    "# the column order of the design matrix built by the function named after it.\n",
    "# stage 1\n",
    "beta_phi1 = np.array([0.3, 0.2])  # phi1: [1, X1]\n",
    "beta_K1 = np.array([0, 1, 1])  # K1: [1, X1, A1]; K1() currently uses beta_K1[:-1] only\n",
    "beta_p1 = np.array([0, 5, 3, 0.5])  # p1: [1, X1, A1, A1*X1]\n",
    "beta_mu1 = np.array([0.2, 0.3, 1.5, 0.75, 0])  # mu1: [1, X1, A1, A1*X1, unused]; mu1() drops the last entry (was 0.2)\n",
    "# stage 2\n",
    "beta_phi2 = np.array([0.7, 0.2, -0.2, -0.1])  # phi2: [1, X1, X2, X2**2]\n",
    "beta_K2 = np.array([-3, 1, 1, 0.5, 1])  # K2: [1, X1, X2, A2, A2*X2]\n",
    "beta_p2 = np.array([ 0.8 , -1.42,  0.8 , -0.65])  # p2: [1, X1, A1, A2]\n",
    "beta_mu2 = np.array([2.58, -1.04, 1.21, -0.92, 2.27, 1.18, 3.29, 3.95])  # mu2: [1, X1, A1, A1*X1, X2, A2, A2*A1, A2*X2]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cdde0bcc",
   "metadata": {},
   "source": [
    "### First stage\n",
    "* Baseline covariate\n",
    "    * $X_1 \\sim \\text{Uniform}(-.3, .7)$ \n",
    "    <!-- * $X_1 \\sim \\text{Bernoulli}(0.75)$ -->"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56a50760",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "X1 = np.random.uniform(low=-.3, high=.7, size=N)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f254c309",
   "metadata": {},
   "source": [
    "* Propensity\n",
    "    * $A_1 \\sim \\text{Bernoulli}(\\varphi_1(X_1))$ where $\\varphi_1(X_1) := \\mathbb P(A_1=1 | X_1)$ is the propensity score for $A_1$\n",
    "    * $\\text{logit}\\left( \\varphi_1(X_1) \\right) = 0.3 + 0.2X_1$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07f27cc0-90ab-44a6-9d3d-d7d3772d01ad",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "def expit(x):\n",
    "    \"\"\"Inverse logit function (sigmoid function).\n",
    "\n",
    "    The exponent is capped at 700 before calling np.exp. Without the cap,\n",
    "    np.exp overflows to inf for x above ~709 and exp(x) / (1 + exp(x))\n",
    "    evaluates to inf / inf = NaN. For x <= 700 the result is unchanged; above\n",
    "    that the ratio is exactly 1.0 in double precision.\n",
    "\n",
    "    Args:\n",
    "      x: A numeric value or NumPy array.\n",
    "\n",
    "    Returns:\n",
    "      The inverse logit of x.\n",
    "    \"\"\"\n",
    "    capped_exp = np.exp(np.minimum(x, 700.0))\n",
    "    return capped_exp / (1 + capped_exp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25cb898a-b809-4d5e-9d1d-9d2268c19168",
   "metadata": {},
   "outputs": [],
   "source": [
    "def phi1(x1):\n",
    "    \"\"\"Propensity score for A1, P(A1 = 1 | X1).\n",
    "\n",
    "    Evaluates expit(beta_phi1[0] + beta_phi1[1] * x1) with the module-level\n",
    "    coefficient vector beta_phi1.\n",
    "\n",
    "    Args:\n",
    "    x1: A numeric scalar (int, float or NumPy scalar) or a 1-D array.\n",
    "\n",
    "    Returns:\n",
    "    A 1-D NumPy array of propensity scores (length 1 for a scalar x1).\n",
    "    \"\"\"\n",
    "    # np.atleast_1d wraps every scalar type; the former isinstance(x1, int)\n",
    "    # guard let floats and NumPy integers reach len() and raise TypeError.\n",
    "    x1 = np.atleast_1d(x1)\n",
    "    return expit(np.c_[np.ones(len(x1)), x1] @ beta_phi1)\n",
    "\n",
    "def genA1(x1):\n",
    "    \"\"\"Draw the first-stage treatment A1 ~ Bernoulli(phi1(x1)).\n",
    "\n",
    "    Args:\n",
    "    x1: A numeric value or NumPy array.\n",
    "\n",
    "    Returns:\n",
    "    A binary array of A1 values.\n",
    "    \"\"\"\n",
    "    treat_prob = phi1(x1)\n",
    "    return np.random.binomial(n=1, p=treat_prob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5643160a-59d8-40e9-96af-df86ebbabcb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "A1 = genA1(X1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad9d8d33",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "print(A1.mean())  # about 61%\n",
    "plt.hist(A1, density=True);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b707fb0",
   "metadata": {},
   "source": [
    "* Censoring\n",
    "    * $C_1 \\sim \\text{Bernoulli}(1 - K_1^{A_1}( X_1))$ where $K_1^{A_1}( X_1) := \\mathbb P(C_1=0 |  X_1, A_1)$ is the non-censoring probability for $C_1$\n",
    "    * $\\text{logit}\\left( K_1^{A_1}( X_1) \\right) = X_1 + A_1 + \\eta_1$ for fixed $\\eta_1 = 3$ (note: the current `K1()` implementation omits the $A_1$ term and uses $X_1 + \\eta_1$)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "611a197d-b5cd-41f7-9eb0-2275c1cce845",
   "metadata": {},
   "outputs": [],
   "source": [
    "eta1 = 3\n",
    "\n",
    "def K1(x1, a1):\n",
    "    \"\"\"Non-censoring probability for C1, P(C1 = 0 | X1, A1).\n",
    "\n",
    "    Uses the module-level beta_K1 and eta1. Only the intercept and X1\n",
    "    coefficients (beta_K1[:-1]) enter the linear predictor, so a1 is accepted\n",
    "    but does not influence the result.\n",
    "    NOTE(review): to include the A1 effect use\n",
    "    expit(np.c_[np.ones(len(x1)), x1, a1] @ beta_K1 + eta1) -- confirm which\n",
    "    specification is intended.\n",
    "\n",
    "    Args:\n",
    "    x1: A numeric scalar or 1-D NumPy array.\n",
    "    a1: A numeric scalar or 1-D NumPy array (currently unused).\n",
    "\n",
    "    Returns:\n",
    "    A 1-D NumPy array with the probability of not being censored.\n",
    "    \"\"\"\n",
    "    # np.atleast_1d wraps every scalar type; the former isinstance(x1, int)\n",
    "    # guard let floats and NumPy integers reach len() and raise TypeError.\n",
    "    x1 = np.atleast_1d(x1)\n",
    "    return expit(np.c_[np.ones(len(x1)), x1] @ beta_K1[:-1] + eta1)\n",
    "\n",
    "def genC1(x1, a1):\n",
    "    \"\"\"Draw the first-stage censoring indicator C1.\n",
    "\n",
    "    K1 gives the probability of C1 = 0, so each draw is\n",
    "    Bernoulli(1 - K1(x1, a1)).\n",
    "\n",
    "    Args:\n",
    "      x1: A numeric value or NumPy array.\n",
    "      a1: A numeric value or NumPy array.\n",
    "\n",
    "    Returns:\n",
    "      A binary array of C1 values.\n",
    "    \"\"\"\n",
    "    censor_prob = 1 - K1(x1, a1)\n",
    "    return np.random.binomial(n=1, p=censor_prob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5aa279fd",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "def genC1a(x1):\n",
    "    \"\"\"Generate the potential censoring indicators C1^{a1} for a1 in {0, 1}.\n",
    "\n",
    "    Args:\n",
    "    x1: A numeric value or NumPy array.\n",
    "\n",
    "    Returns:\n",
    "    A NumPy array with two columns, C1_0 and C1_1.\n",
    "    \"\"\"\n",
    "    # Arm 0 is drawn before arm 1, which fixes the random-number order.\n",
    "    draws = [genC1(x1, arm) for arm in (0, 1)]\n",
    "    return np.c_[draws[0], draws[1]]\n",
    "\n",
    "C1a = genC1a(X1)\n",
    "C1_0 = C1a[:, 0]\n",
    "C1_1 = C1a[:, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "653e0a7a",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "print( np.mean(C1_0) )  # about 4%\n",
    "print( np.mean(C1_1) )  # about 4%\n",
    "plt.hist(C1_0, density=True, alpha=.5) \n",
    "plt.hist(C1_1, density=True, alpha=.5);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ef74b87",
   "metadata": {},
   "source": [
    "* Survival\n",
    "    * $S_1 \\sim \\text{Bernoulli}(p_1^{A_1}( X_1))$ where $p_1^{A_1}( X_1) := \\mathbb P(S_1=1 |  X_1, A_1, C_1=0)$ is the survival probability for $S_1$\n",
    "    * $\\text{logit}\\left( p_1^{A_1}( X_1) \\right) = 5X_1 + 3 A_1 + 0.5 A_1X_1$\n",
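    "    * Monotonicity holds for the survival probabilities, $p_1^{1}(X_1) \\ge p_1^{0}(X_1)$; the potential outcomes $S_{1i}^0$ and $S_{1i}^1$ are drawn independently, so $S_{1i}^1 \\ge S_{1i}^0,~ \\forall i=1,\\cdots,N$ is not enforced."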
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdf521b9",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "def p1(x1, a1):\n",
    "    \"\"\"Survival probability for S1, P(S1 = 1 | X1, A1).\n",
    "\n",
    "    Evaluates expit([1, x1, a1, a1 * x1] @ beta_p1) with the module-level\n",
    "    coefficient vector beta_p1.\n",
    "\n",
    "    Args:\n",
    "    x1: A numeric scalar or 1-D NumPy array.\n",
    "    a1: A numeric scalar or 1-D NumPy array; scalars are broadcast to x1.\n",
    "\n",
    "    Returns:\n",
    "    A 1-D NumPy array of survival probabilities.\n",
    "    \"\"\"\n",
    "    # Broadcasting replaces the isinstance(..., int) guards, which skipped\n",
    "    # float and NumPy scalars and then failed in len().\n",
    "    x1, a1 = np.broadcast_arrays(np.atleast_1d(x1), a1)\n",
    "    return expit(np.c_[np.ones(len(x1)), x1, a1, a1 * x1] @ beta_p1)\n",
    "\n",
    "def genS1(x1, a1):\n",
    "    \"\"\"Draw the first-stage survival indicator S1 ~ Bernoulli(p1(x1, a1)).\n",
    "\n",
    "    Args:\n",
    "    x1: A numeric value or NumPy array.\n",
    "    a1: A numeric value or NumPy array.\n",
    "\n",
    "    Returns:\n",
    "    A binary array of S1 values.\n",
    "    \"\"\"\n",
    "    survive_prob = p1(x1, a1)\n",
    "    return np.random.binomial(n=1, p=survive_prob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34964aed",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "def genS1a(x1):\n",
    "    \"\"\"Generate the potential survival indicators S1^{a1} for a1 in {0, 1}.\n",
    "\n",
    "    Args:\n",
    "    x1: A numeric value or NumPy array.\n",
    "\n",
    "    Returns:\n",
    "    A NumPy array with two columns, S1_0 and S1_1.\n",
    "    \"\"\"\n",
    "    # The two arms are sampled independently (arm 0 first). A resampling loop\n",
    "    # that forced S1_0 <= S1_1 for every unit used to live here and is\n",
    "    # disabled, so individual-level monotonicity is not enforced.\n",
    "    draws = [genS1(x1, arm) for arm in (0, 1)]\n",
    "    return np.c_[draws[0], draws[1]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "351c2958",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "S1a = genS1a(X1)\n",
    "S1_0 = S1a[:, 0]\n",
    "S1_1 = S1a[:, 1]\n",
    "\n",
    "print( np.mean(S1_0) )  # about 75%\n",
    "print( np.mean(S1_1) )  # about 97%\n",
    "\n",
    "print( \"monotonicity?\", np.mean(S1_0 <= S1_1) )  # monotonicity probabilistically?\n",
    "\n",
    "plt.figure(figsize=(4,3))\n",
    "plt.stem(S1a.mean(0))\n",
    "plt.xticks(range(2), [\"S1_0\", \"S1_1\"])\n",
    "plt.ylabel(\"probability\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10007cbb",
   "metadata": {},
   "source": [
    "* Outcome: \n",
    "    * $X_{2} | \\{X_1, A_1\\} \\sim \\mathcal N \\left(0.2 + 0.3X_{1} + 1.5A_1 + 0.75A_1X_{1},\\; 1.5^2\\right)$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f37ede9-aeed-46d9-8a73-b0f0ac8bfd29",
   "metadata": {},
   "outputs": [],
   "source": [
    "beta_mu1[:-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5ac562e",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "#beta_mu1 = c(0.2, 0.3, 1.5, 0.75, 0)\n",
    "\n",
    "def mu1(x1, a1, s1):\n",
    "    \"\"\"Conditional mean of the intermediate outcome X2 given (X1, A1).\n",
    "\n",
    "    Evaluates [1, x1, a1, a1 * x1] @ beta_mu1[:-1]; the last entry of\n",
    "    beta_mu1 is not used and s1 does not enter the mean.\n",
    "\n",
    "    Args:\n",
    "    x1: A numeric scalar or 1-D NumPy array.\n",
    "    a1: A numeric scalar or 1-D NumPy array; scalars are broadcast to x1.\n",
    "    s1: Accepted for interface compatibility; currently unused.\n",
    "\n",
    "    Returns:\n",
    "    A 1-D NumPy array with the mean of X2.\n",
    "    \"\"\"\n",
    "    # Broadcasting replaces the isinstance(..., int) guards, which skipped\n",
    "    # float and NumPy scalars and then failed in len().\n",
    "    x1, a1 = np.broadcast_arrays(np.atleast_1d(x1), a1)\n",
    "    return np.c_[np.ones(len(x1)), x1, a1, a1 * x1] @ beta_mu1[:-1]\n",
    "\n",
    "def genX2(x1, a1, s1):\n",
    "    \"\"\"Draw X2 ~ Normal(mu1(x1, a1, s1), 1.5**2), one value per entry of x1.\n",
    "\n",
    "    Args:\n",
    "    x1: A NumPy array.\n",
    "    a1: A numeric value or NumPy array.\n",
    "    s1: A numeric value or NumPy array.\n",
    "\n",
    "    Returns:\n",
    "    A NumPy array of X2 values.\n",
    "    \"\"\"\n",
    "    centre = mu1(x1, a1, s1)\n",
    "    return np.random.normal(loc=centre, scale=1.5, size=len(x1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3c77166",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "def genX2a(x1, s1a):\n",
    "    \"\"\"Generate the potential intermediate outcomes X2^{a1} for a1 in {0, 1}.\n",
    "\n",
    "    Args:\n",
    "    x1: A NumPy array.\n",
    "    s1a: A NumPy array with two columns, S1_0 and S1_1.\n",
    "\n",
    "    Returns:\n",
    "    A NumPy array with two columns, X2_0 and X2_1.\n",
    "    \"\"\"\n",
    "    # Column `arm` of s1a is paired with treatment `arm`; arm 0 is drawn first.\n",
    "    draws = [genX2(x1, arm, s1a[:, arm]) for arm in (0, 1)]\n",
    "    return np.c_[draws[0], draws[1]]\n",
    "\n",
    "X2a = genX2a(X1, S1a)\n",
    "X2_0 = X2a[:, 0]\n",
    "X2_1 = X2a[:, 1]\n",
    "# X2[C1 == 1 | S1 == 0] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4e1f995-e291-4b24-b273-e67dcac02207",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(4,3))\n",
    "sns.kdeplot(X2a);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "524b5adc",
   "metadata": {},
   "source": [
    "### Second stage\n",
    "\n",
    "* Propensity\n",
    "    * $A_2 \\sim \\text{Bernoulli}(\\varphi_2(\\overline{ X}_2, A_1))$ where $\\varphi_2(\\overline{ X}_2, A_1) := \\mathbb P(A_2=1 | \\overline{ X}_2, A_1, C_1=0, S_1=1)$ is the propensity score for $A_2$\n",
    "    * $\\text{logit}\\left( \\varphi_2(\\overline{ X}_2, A_1) \\right) = 0.7 + 0.2X_{1} - 0.2X_{2} -0.1 X_2^2$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "701d327e-6fcd-4299-b5df-94f5ffb61f0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def phi2(x1, x2):\n",
    "    \"\"\"Propensity score for A2 given the covariate history (X1, X2).\n",
    "\n",
    "    Evaluates expit([1, x1, x2, x2**2] @ beta_phi2) with the module-level\n",
    "    coefficient vector beta_phi2.\n",
    "\n",
    "    Args:\n",
    "        x1: A numeric scalar or 1-D NumPy array.\n",
    "        x2: A numeric scalar or 1-D NumPy array; broadcast against x1.\n",
    "\n",
    "    Returns:\n",
    "        A 1-D NumPy array of propensity scores.\n",
    "    \"\"\"\n",
    "    # Broadcasting replaces the isinstance(..., int) guards: they skipped\n",
    "    # float / NumPy scalars and gave mismatched lengths for scalar + array.\n",
    "    x1, x2 = np.broadcast_arrays(np.atleast_1d(x1), x2)\n",
    "    return expit(np.c_[np.ones(len(x1)), x1, x2, x2**2] @ beta_phi2)\n",
    "\n",
    "def genA2(x1, x2):\n",
    "    \"\"\"Draw the second-stage treatment A2 ~ Bernoulli(phi2(x1, x2)).\"\"\"\n",
    "    treat_prob = phi2(x1, x2)\n",
    "    return np.random.binomial(n=1, p=treat_prob)\n",
    "\n",
    "def genA2a(x1, x2a):\n",
    "    \"\"\"Generate A2 under each potential X2: columns A2_0 (from X2_0) and A2_1 (from X2_1).\"\"\"\n",
    "    # Column 0 of x2a is used before column 1, which fixes the draw order.\n",
    "    draws = [genA2(x1, x2a[:, arm]) for arm in (0, 1)]\n",
    "    return np.c_[draws[0], draws[1]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "059c5cbd-4032-44db-924b-eb0bb542d369",
   "metadata": {},
   "outputs": [],
   "source": [
    "A2a = genA2a(X1, X2a)\n",
    "A2_0 = A2a[:, 0]\n",
    "A2_1 = A2a[:, 1]\n",
    "# A2[C1 == 1 | S1 == 0] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e205d97a-8613-4699-88cb-d4810be42b45",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(4,3))\n",
    "plt.stem(A2a.mean(0))\n",
    "plt.xticks(range(2), [\"A2_0\", \"A2_1\"])\n",
    "plt.ylabel(\"probability\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9435ea41-119c-42df-bd43-f5c02d5f4cfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "(\n",
    "    A1.mean(),\n",
    "    A2_0[A1==0].mean(),\n",
    "    A2_1[A1==1].mean()\n",
    ")  # balanced."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6200984e",
   "metadata": {},
   "source": [
    "* Censor\n",
    "    * $C_2 \\sim \\text{Bernoulli}(1 - K_2^{A_1A_2}(\\overline{ X}_2))$ where $K_2^{A_1A_2}(\\overline{ X}_2) := \\mathbb P(C_2=0 | \\overline{ X}_2, \\overline A_2, C_1=0, S_1=1)$ is the non-censoring probability for $C_2$\n",
    "    * $\\text{logit}\\left( K_2^{A_1A_2}( \\overline X_2) \\right) = -3 + X_1 + X_2 + 0.5 A_2 + A_2X_2 + \\eta_2$ for fixed $\\eta_2=5$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "193b7c3a-1c0d-45c2-8620-87eee0730d4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "eta2 = 5  # Adjust eta2 for different censoring rates\n",
    "# eta2 = 5 : 15% censoring\n",
    "#      = 3 : 20%\n",
    "#      = 0 : 50%\n",
    "\n",
    "# beta_K2 = np.array([-3, 1, 1, 0.5, 1])\n",
    "\n",
    "def K2(x1, x2, a1, a2, eta2=5):\n",
    "    \"\"\"Non-censoring probability for C2, P(C2 = 0 | X1, X2, A2).\n",
    "\n",
    "    Evaluates expit([1, x1, x2, a2, a2 * x2] @ beta_K2 + eta2) with the\n",
    "    module-level coefficient vector beta_K2. a1 is accepted but does not\n",
    "    enter the linear predictor.\n",
    "\n",
    "    Args:\n",
    "        x1: A numeric scalar or 1-D NumPy array.\n",
    "        x2: A numeric scalar or 1-D NumPy array.\n",
    "        a1: Stage-1 treatment (currently unused).\n",
    "        a2: A numeric scalar or 1-D NumPy array of stage-2 treatments.\n",
    "        eta2: Offset added to the linear predictor; larger values mean less\n",
    "            censoring.\n",
    "\n",
    "    Returns:\n",
    "        A 1-D NumPy array with the probability of not being censored.\n",
    "    \"\"\"\n",
    "    # Broadcasting replaces the isinstance(..., int) guards, which skipped\n",
    "    # float and NumPy scalars and then failed in len().\n",
    "    x1, x2, a2 = np.broadcast_arrays(np.atleast_1d(x1), x2, a2)\n",
    "    return expit(np.c_[np.ones(len(x1)), x1, x2, a2, a2 * x2] @ beta_K2 + eta2)\n",
    "\n",
    "def genC2(x1, x2, a1, a2, eta2=5):\n",
    "    \"\"\"Draw the second-stage censoring indicator C2 ~ Bernoulli(1 - K2(...)).\"\"\"\n",
    "    censor_prob = 1 - K2(x1, x2, a1, a2, eta2)\n",
    "    return np.random.binomial(n=1, p=censor_prob)\n",
    "\n",
    "def genC2a(x1, x2a, c1a, eta2=5):\n",
    "    \"\"\"Generate the potential censoring indicators C2^{a1 a2}.\n",
    "\n",
    "    Returns a NumPy array with four columns ordered C2_00, C2_01, C2_10,\n",
    "    C2_11. Column `a1` of x2a and c1a is paired with stage-1 treatment a1.\n",
    "    \"\"\"\n",
    "    x2_by_arm = (x2a[:, 0], x2a[:, 1])\n",
    "    c1_by_arm = (c1a[:, 0], c1a[:, 1])\n",
    "\n",
    "    # All four arms are drawn first, in this order, to fix the draw sequence.\n",
    "    arms = ((0, 0), (0, 1), (1, 0), (1, 1))\n",
    "    raw = [genC2(x1, x2_by_arm[first], first, second, eta2) for first, second in arms]\n",
    "\n",
    "    # monotonicity btw C1 and C2: censored at stage 1 implies censored at stage 2\n",
    "    fixed = [np.where(draw < c1_by_arm[first], 1, draw)\n",
    "             for draw, (first, _) in zip(raw, arms)]\n",
    "    return np.c_[fixed[0], fixed[1], fixed[2], fixed[3]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "baf7d63c-376f-4efc-8437-54b2a84b0b70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass the module-level eta2 explicitly: without it genC2a falls back to its\n",
    "# own default (5), so changing eta2 would not alter these draws.\n",
    "C2a = genC2a(X1, X2a, C1a, eta2)\n",
    "C2_00 = C2a[:, 0]\n",
    "C2_01 = C2a[:, 1]\n",
    "C2_10 = C2a[:, 2]\n",
    "C2_11 = C2a[:, 3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30281b01-029d-4c74-9f77-b2a016e435cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Mean of C2_00:\", np.mean(C2_00, where=~np.isnan(C2_00)))\n",
    "print(\"Mean of C2_01:\", np.mean(C2_01, where=~np.isnan(C2_01)))\n",
    "print(\"Mean of C2_10:\", np.mean(C2_10, where=~np.isnan(C2_10)))\n",
    "print(\"Mean of C2_11:\", np.mean(C2_11, where=~np.isnan(C2_11)))\n",
    "\n",
    "plt.figure(figsize=(4,3))\n",
    "plt.stem(C2a.mean(0))\n",
    "plt.xticks(range(4), [\"C2_00\", \"C2_01\", \"C2_10\", \"C2_11\"])\n",
    "plt.ylabel(\"probability\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36101959",
   "metadata": {},
   "source": [
    "* Survival\n",
    "    * $S_2 \\sim \\text{Bernoulli}(p_2^{A_1A_2}(\\overline{X}_2))$ where $p_2^{A_1A_2}(\\overline{X}_2) := \\mathbb P(S_2=1 | \\overline{X}_2, \\overline A_2, C_2=0, S_1=1)$ is the survival probability for $S_2$\n",
    "    * $\\text{logit}\\left( p_2^{A_1A_2}(\\overline{X}_2) \\right) = 0.8 - 1.42 X_1 + 0.8 A_1 - 0.65 A_2$\n",
    "    * Monotonicity satisfied: $\\mathbb P(S_2 | \\overline X_2, A = (1,1)) \\ge \\mathbb P(S_2 | \\overline X_2, A = (0,1)), \\mathbb P(S_2 | \\overline X_2, A = (1,0)) \\ge \\mathbb P(S_2 | \\overline X_2, A = (0,0))$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd49d7ad-36ee-4559-97b9-bcf9af1648eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# beta_p2 = np.array([ 0.8 , -1.42,  0.8 , -0.65])\n",
    "\n",
    "def p2(x1, x2, a1, a2):\n",
    "    \"\"\"Survival probability for S2, P(S2 = 1 | X1, A1, A2).\n",
    "\n",
    "    Evaluates expit([1, x1, a1, a2] @ beta_p2) with the module-level\n",
    "    coefficient vector beta_p2. x2 is accepted but does not enter the linear\n",
    "    predictor.\n",
    "\n",
    "    Args:\n",
    "        x1: A numeric scalar or 1-D NumPy array.\n",
    "        x2: Intermediate outcome (currently unused).\n",
    "        a1: A numeric scalar or 1-D NumPy array of stage-1 treatments.\n",
    "        a2: A numeric scalar or 1-D NumPy array of stage-2 treatments.\n",
    "\n",
    "    Returns:\n",
    "        A 1-D NumPy array of survival probabilities.\n",
    "    \"\"\"\n",
    "    # Broadcasting replaces the isinstance(..., int) guards, which skipped\n",
    "    # float and NumPy scalars and then failed in len().\n",
    "    x1, a1, a2 = np.broadcast_arrays(np.atleast_1d(x1), a1, a2)\n",
    "    return expit(np.c_[np.ones(len(x1)), x1, a1, a2] @ beta_p2)\n",
    "\n",
    "def genS2(x1, x2, a1, a2):\n",
    "    \"\"\"Draw the second-stage survival indicator S2 ~ Bernoulli(p2(...)).\"\"\"\n",
    "    survive_prob = p2(x1, x2, a1, a2)\n",
    "    return np.random.binomial(n=1, p=survive_prob)\n",
    "\n",
    "def genS2a(x1, x2a, s1a):\n",
    "    \"\"\"Generate the potential survival indicators S2^{a1 a2}.\n",
    "\n",
    "    Args:\n",
    "        x1: A NumPy array of baseline covariates.\n",
    "        x2a: A NumPy array with two columns, X2_0 and X2_1.\n",
    "        s1a: Potential stage-1 survival; accepted for interface compatibility\n",
    "            and not used in the draws.\n",
    "\n",
    "    Returns:\n",
    "        A NumPy array with four columns ordered S2_00, S2_01, S2_10, S2_11.\n",
    "    \"\"\"\n",
    "    x2_0 = x2a[:, 0]\n",
    "    x2_1 = x2a[:, 1]\n",
    "\n",
    "    # Draw order 00, 01, 10, 11 is fixed so seeded runs stay reproducible.\n",
    "    s2_00 = genS2(x1, x2_0, 0, 0)\n",
    "    s2_01 = genS2(x1, x2_0, 0, 1)\n",
    "    s2_10 = genS2(x1, x2_1, 1, 0)\n",
    "    s2_11 = genS2(x1, x2_1, 1, 1)\n",
    "\n",
    "    return np.c_[s2_00, s2_01, s2_10, s2_11]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd659bb0-ca35-4828-aee6-b71fd8bb4e9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "S2a = genS2a(X1, X2a, S1a)\n",
    "S2_00 = S2a[:, 0]\n",
    "S2_01 = S2a[:, 1]\n",
    "S2_10 = S2a[:, 2]\n",
    "S2_11 = S2a[:, 3]\n",
    "# S2[C1 == 1 | C2 == 1 | S1 == 0] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d2447bc-3242-4a22-81fc-6df23b803096",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Mean of S2_00:\", np.mean(S2_00))\n",
    "print(\"Mean of S2_01:\", np.mean(S2_01))\n",
    "print(\"Mean of S2_10:\", np.mean(S2_10))\n",
    "print(\"Mean of S2_11:\", np.mean(S2_11))\n",
    "\n",
    "plt.figure(figsize=(4,3))\n",
    "plt.stem(S2a.mean(0))\n",
    "plt.xticks(range(4), [\"S2_00\", \"S2_01\", \"S2_10\", \"S2_11\"])\n",
    "plt.ylabel(\"probability\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f47a650",
   "metadata": {},
   "source": [
    "* Outcome: \n",
    "    * $Y | \\{\\bar X, \\bar A, \\bar C, \\bar S\\} = \\mu_2^{A_1A_2}(\\overline{\\mathbf X}_2) + \\epsilon_2$\n",
    "    * $\\mu_2^{A_1A_2}(\\overline{\\mathbf X}_2) = 2.58 - 1.04 X_1 + 1.21 A_1 - 0.92 A_1X_1 + 2.27 X_2 + 1.18 A_2 + 3.29 A_1A_2 + 3.95 A_2X_2$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a81ea5d-8626-4540-8fc2-f6711d8de6fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mu2(x1, x2, a1, a2):\n",
    "    \"\"\"Conditional mean of the final outcome Y.\n",
    "\n",
    "    Evaluates [1, x1, a1, a1*x1, x2, a2, a2*a1, a2*x2] @ beta_mu2 with the\n",
    "    module-level coefficient vector beta_mu2.\n",
    "\n",
    "    Args:\n",
    "        x1: A numeric scalar or 1-D NumPy array.\n",
    "        x2: A numeric scalar or 1-D NumPy array.\n",
    "        a1: A numeric scalar or 1-D NumPy array of stage-1 treatments.\n",
    "        a2: A numeric scalar or 1-D NumPy array of stage-2 treatments.\n",
    "\n",
    "    Returns:\n",
    "        A 1-D NumPy array with the mean of Y.\n",
    "    \"\"\"\n",
    "    # Broadcasting replaces the isinstance(..., int) guards, which skipped\n",
    "    # float and NumPy scalars and then failed in len().\n",
    "    x1, x2, a1, a2 = np.broadcast_arrays(np.atleast_1d(x1), x2, a1, a2)\n",
    "    design = np.c_[np.ones(len(x1)), x1, a1, a1*x1, x2, a2, a2*a1, a2*x2]\n",
    "    return np.dot(design, beta_mu2)\n",
    "\n",
    "def genY(x1, x2, a1, a2):\n",
    "    \"\"\"Draw Y = mu2(x1, x2, a1, a2) + Normal(0, 1.5**2) noise, one value per entry of x1.\"\"\"\n",
    "    mean = mu2(x1, x2, a1, a2)\n",
    "    noise = np.random.normal(loc=0, scale=1.5, size=len(x1))\n",
    "    return mean + noise\n",
    "\n",
    "def genYa(x1, x2a):\n",
    "    \"\"\"Generate the potential outcomes Y^{a1 a2}.\n",
    "\n",
    "    Returns a NumPy array with four columns ordered Y_00, Y_01, Y_10, Y_11;\n",
    "    column `a1` of x2a is paired with stage-1 treatment a1.\n",
    "    \"\"\"\n",
    "    x2_by_arm = (x2a[:, 0], x2a[:, 1])\n",
    "    arms = ((0, 0), (0, 1), (1, 0), (1, 1))\n",
    "    # The list is built in arm order, which fixes the random-number sequence.\n",
    "    outcomes = [genY(x1, x2_by_arm[first], first, second) for first, second in arms]\n",
    "    return np.column_stack(outcomes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc6492de-2c36-49a5-be49-4cfeec8446ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "Ya = genYa(X1, X2a)\n",
    "Y_00 = Ya[:, 0]\n",
    "Y_01 = Ya[:, 1]\n",
    "Y_10 = Ya[:, 2]\n",
    "Y_11 = Ya[:, 3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d0eaac7-c524-409e-b345-ab50936fdd54",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.kdeplot(Y_00, color='red', label='Y_00')\n",
    "sns.kdeplot(Y_01, color='green', label='Y_01')\n",
    "sns.kdeplot(Y_10, color='blue', label='Y_10')\n",
    "sns.kdeplot(Y_11, color='purple', label='Y_11')\n",
    "plt.xlabel('Y')\n",
    "plt.ylabel('Density')\n",
    "# plt.ylim(0, 0.12)\n",
    "# plt.xlim(-15, 15)\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8008ce93",
   "metadata": {},
   "source": [
    "## Functions `follow_a1()` and `follow_a1a2()`\n",
    "* Return potential outcomes of a variable if it were to follow specific treatment assignment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffc8e9d2-d567-4a15-9654-8d7af055a985",
   "metadata": {},
   "outputs": [],
   "source": [
    "def follow_a1(a1, alldata, var=\"a2\"):\n",
    "    \"\"\"Pick, row by row, the potential outcome of `var` under the given a1.\n",
    "\n",
    "    Row i of the result is the value of column '{var}_{int(a1[i])}' at row\n",
    "    position i. The lookup is a single positional NumPy fancy-index rather\n",
    "    than one .loc call per row, so it is fast on large frames and also works\n",
    "    when the index contains duplicate labels.\n",
    "\n",
    "    Args:\n",
    "        a1: Array-like of stage-1 treatments, one entry per row of alldata.\n",
    "        alldata: DataFrame holding the columns '{var}_0', '{var}_1', ...\n",
    "        var: Column prefix of the variable to follow.\n",
    "\n",
    "    Returns:\n",
    "        A 1-D NumPy array with one selected value per row.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if a treatment value has no matching column.\n",
    "    \"\"\"\n",
    "    arms, codes = np.unique(np.asarray(a1).astype(int), return_inverse=True)\n",
    "    # Only the columns of arms that actually occur are read.\n",
    "    values = alldata[[f\"{var}_{arm}\" for arm in arms]].to_numpy()\n",
    "    return values[np.arange(len(alldata)), codes.ravel()]\n",
    "\n",
    "def follow_a1a2(a1, a2, alldata, var=\"y\"):\n",
    "    \"\"\"Pick, row by row, the potential outcome of `var` under (a1, a2).\n",
    "\n",
    "    Row i of the result is the value of column\n",
    "    '{var}_{int(a1[i])}{int(a2[i])}' at row position i. The lookup is a\n",
    "    single positional NumPy fancy-index rather than one .loc call per row, so\n",
    "    it is fast on large frames and also works with duplicate index labels.\n",
    "\n",
    "    Args:\n",
    "        a1: Array-like of stage-1 treatments, one entry per row of alldata.\n",
    "        a2: Array-like of stage-2 treatments, broadcast against a1.\n",
    "        alldata: DataFrame holding the columns '{var}_00', '{var}_01', ...\n",
    "        var: Column prefix of the variable to follow.\n",
    "\n",
    "    Returns:\n",
    "        A 1-D NumPy array with one selected value per row.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if a treatment pair has no matching column.\n",
    "    \"\"\"\n",
    "    first, second = np.broadcast_arrays(np.asarray(a1).astype(int),\n",
    "                                        np.asarray(a2).astype(int))\n",
    "    pairs = np.stack([first.ravel(), second.ravel()], axis=1)\n",
    "    arms, codes = np.unique(pairs, axis=0, return_inverse=True)\n",
    "    # Only the columns of treatment pairs that actually occur are read.\n",
    "    values = alldata[[f\"{var}_{p}{q}\" for p, q in arms]].to_numpy()\n",
    "    return values[np.arange(len(alldata)), codes.ravel()]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa11026f",
   "metadata": {},
   "source": [
    "## Function `genData()`\n",
    "* Generates all the potential outcomes in trajectories."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ba9a504-0d44-482d-9d08-7930e3ce7b50",
   "metadata": {},
   "outputs": [],
   "source": [
    "def genData(N, eta2):\n",
    "    \"\"\"Simulate N units of the two-stage process with all potential outcomes.\n",
    "\n",
    "    Draws come from NumPy's global random state, so seed it beforehand for\n",
    "    reproducible output.\n",
    "\n",
    "    Args:\n",
    "        N: Number of units to simulate.\n",
    "        eta2: Offset forwarded to genC2a for the stage-2 censoring model.\n",
    "\n",
    "    Returns:\n",
    "        A dict with three DataFrames:\n",
    "        \"all\": potential outcomes (lower-case columns) plus the realized\n",
    "            trajectory (upper-case columns), where realized values after\n",
    "            censoring or death are set to 0;\n",
    "        \"obs\": the realized trajectory only, with NaN after censoring or death;\n",
    "        \"as\": the rows of \"all\" whose s2 draw under (a1, a2) = (0, 0) equals 1.\n",
    "    \"\"\"\n",
    "    # Stage 1\n",
    "    x1 = np.random.uniform(low=-.3, high=.7, size=N)\n",
    "    # (a Bernoulli(0.75) baseline covariate was used here previously)\n",
    "    a1 = genA1(x1)\n",
    "    c1a = genC1a(x1)\n",
    "    s1a = genS1a(x1)\n",
    "    x2a = genX2a(x1, s1a)\n",
    "\n",
    "    # Stage 2\n",
    "    a2a = genA2a(x1, x2a)\n",
    "    c2a = genC2a(x1, x2a, c1a, eta2)\n",
    "    s2a = genS2a(x1, x2a, s1a)\n",
    "    ya = genYa(x1, x2a)\n",
    "\n",
    "    # All data: one column per potential outcome, in the order named below\n",
    "    dfa = pd.DataFrame(np.c_[x1, x2a, a1, a2a, c1a, c2a, s1a, s2a, ya])\n",
    "    dfa.columns = ['x1', 'x2_0', 'x2_1', 'a1', 'a2_0', 'a2_1', \n",
    "                   'c1_0', 'c1_1', 'c2_00', 'c2_01', 'c2_10', 'c2_11',\n",
    "                   's1_0', 's1_1', 's2_00', 's2_01', 's2_10', 's2_11',\n",
    "                   'y_00', 'y_01', 'y_10', 'y_11']\n",
    "    \n",
    "    # Realized trajectory: follow the potential outcome of the treatment received\n",
    "    dfa[\"X1\"] = dfa[\"x1\"]\n",
    "    dfa[\"A1\"] = dfa[\"a1\"]\n",
    "    dfa[\"C1\"] = follow_a1(dfa.a1, dfa, var=\"c1\")\n",
    "    dfa[\"S1\"] = follow_a1(dfa.a1, dfa, var=\"s1\")\n",
    "    dfa[\"X2\"] = follow_a1(dfa.a1, dfa, var=\"x2\")\n",
    "    a2a1_a = follow_a1(dfa.a1, dfa, var=\"a2\")  # A2 under the realized A1\n",
    "    dfa[\"A2\"] = a2a1_a\n",
    "    dfa[\"C2\"] = follow_a1a2(dfa.a1, a2a1_a, dfa, var=\"c2\")\n",
    "    dfa[\"S2\"] = follow_a1a2(dfa.a1, a2a1_a, dfa, var=\"s2\")\n",
    "    dfa[\"Y\"] = follow_a1a2(dfa.a1, a2a1_a, dfa, var=\"y\")\n",
    "    # Zero out everything downstream of a censoring or death event, in time order\n",
    "    # (C2 itself is left untouched by the first two masks).\n",
    "    dfa.loc[dfa.C1==1, [\"S1\", \"X2\", \"A2\", \"S2\", \"Y\"]] = 0\n",
    "    dfa.loc[(~np.isnan(dfa.S1)) & (dfa.S1==0), [\"X2\", \"A2\", \"S2\", \"Y\"]] = 0\n",
    "    dfa.loc[(~np.isnan(dfa.C2)) & (dfa.C2==1), [\"S2\", \"Y\"]] = 0\n",
    "    dfa.loc[(~np.isnan(dfa.S2)) & (dfa.S2==0), \"Y\"] = 0\n",
    "\n",
    "    # Observed data: same events, but downstream values become missing (NaN)\n",
    "    dfs = dfa.loc[:, [\"X1\", \"A1\", \"C1\", \"S1\",\n",
    "                      \"X2\", \"A2\", \"C2\", \"S2\", \"Y\"]].copy()\n",
    "    dfs.loc[dfs.C1==1, [\"S1\", \"X2\", \"A2\", \"C2\", \"S2\", \"Y\"]] = np.nan  # Set NA for C1==1 rows\n",
    "    dfs.loc[(~np.isnan(dfs.S1)) & (dfs.S1==0), [\"X2\", \"A2\", \"C2\", \"S2\", \"Y\"]] = np.nan  # Set NA for S1==0 rows\n",
    "    dfs.loc[(~np.isnan(dfs.C2)) & (dfs.C2==1), [\"S2\", \"Y\"]] = np.nan  # Set NA for C2==1 rows\n",
    "    dfs.loc[(~np.isnan(dfs.S2)) & (dfs.S2==0), \"Y\"] = np.nan  # Set NA for S2==0 rows\n",
    "\n",
    "    # Always survivors\n",
    "    # NOTE(review): only the (0, 0) arm of s2a is checked here -- confirm this\n",
    "    # matches the intended definition of an always-survivor.\n",
    "    dfas = dfa[s2a[:, 0] == 1]  # Filter for always survivors\n",
    "    \n",
    "    # Return data\n",
    "    return {\"all\": dfa, \"obs\": dfs, \"as\": dfas}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be237808-fc5d-49b3-8ed7-013c7c2571e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(0)\n",
    "\n",
    "pd.options.mode.chained_assignment = None\n",
    "\n",
    "data_15 = genData(N=10_000, eta2=5)\n",
    "XX_all15 = data_15[\"all\"]\n",
    "XX_obs15 = data_15[\"obs\"]\n",
    "XX_as15 = data_15[\"as\"]\n",
    "\n",
    "# print(XX_all15.head())\n",
    "print(XX_all15.shape)\n",
    "\n",
    "pd.options.mode.chained_assignment = 'warn'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69ec2596-8cd0-4e08-b743-e71bca3a1483",
   "metadata": {},
   "outputs": [],
   "source": [
    "XX_obs15.C1.mean(), XX_obs15.C2.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdec7f6c-a1d6-4f20-80b8-e03150fa2f96",
   "metadata": {},
   "outputs": [],
   "source": [
    "1-XX_obs15.S1.mean(), 1-XX_obs15.S2.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9454d409-3cf7-47ef-86af-bc704d84c227",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(sum((XX_obs15.S2 == 1) & (~np.isnan(XX_obs15.S2))), \n",
    "      \"observed survivors / total\", XX_obs15.shape[0])\n",
    "print(XX_as15.shape[0], \n",
    "      \"always survivors / total\", XX_obs15.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee56883d-dfb2-4067-bd8f-4e1c22e469a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# linear decision rules\n",
    "def reg1(eta1, eta2, data):\n",
    "    \"\"\"Stage-1 linear rule: 1 where eta1 + eta2 * X1 >= 0, else 0 (integer array).\"\"\"\n",
    "    score = eta1 + eta2 * data[\"X1\"]\n",
    "    return np.asarray(score >= 0).astype(int)\n",
    "\n",
    "def reg2(eta1, eta2, eta3, data):\n",
    "    \"\"\"Stage-2 linear rule: 1 where eta1 + eta2 * X1 + eta3 * X2 >= 0, else 0 (integer array).\"\"\"\n",
    "    score = eta1 + eta2 * data[\"X1\"] + eta3 * data[\"X2\"]\n",
    "    return np.asarray(score >= 0).astype(int)\n",
    "\n",
    "eta_true = np.array([-0.14,  0.99, 0, -0.707, 0.707])\n",
    "# eta_true = np.array([0.511, 0.303, 0.378, 0.046, 0.642])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4399aa1-0fbf-4d2e-9f03-6d98bc2fa378",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw a large sample (500,000 units) and store the result dict with\n",
    "# write_pkl so it can be read back with read_pkl instead of regenerated.\n",
    "df_large = genData(500_000, eta2=5)\n",
    "df_all = df_large[\"all\"]\n",
    "write_pkl(df_large, \"df_large_nonmono.pkl\")"
   ]
  },
  {
   "cell_type": "raw",
   "id": "7613183c-04d8-4ecc-889c-110771bb9c5f",
   "metadata": {},
   "source": [
    "df_large = read_pkl(\"df_large_nonmono.pkl\")\n",
    "df_all = df_large[\"all\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab42c75d-ccb5-4524-91c9-7149ee9922a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(sum((df_large[\"obs\"].S2 == 1) & (~np.isnan(df_large[\"obs\"].S2))), \n",
    "      \"observed survivors / total\", df_large[\"obs\"].shape[0])\n",
    "print(df_large[\"as\"].shape[0], \n",
    "      \"always survivors / total\", df_large[\"obs\"].shape[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7cad1b3-79a9-44c0-a241-94aa5d3d3b5e",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "# Estimators"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32b8005e",
   "metadata": {},
   "source": [
    "## Fit all the nuisance models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ffb4eac-7653-4497-96d8-0765dbe9ab24",
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.options.mode.chained_assignment = None\n",
    "np.set_printoptions(3, suppress=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b5296a5-2604-4510-9b99-e6d72be63854",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_models(data,\n",
    "               phi1_false=False, phi2_false=False,\n",
    "               K1_false=False, K2_false=False,\n",
    "               p1_false=False, p2_false=False,\n",
    "               mu2_false=False, Ep2_false=False,\n",
    "               Emupi_false=False):\n",
    "    \"\"\"\n",
    "    Fit the nuisance models on `data` and build prediction helpers bound to it.\n",
    "\n",
    "    Every `*_false` flag switches one nuisance model from its default formula\n",
    "    to an alternative formula (both are listed in the spec table below).\n",
    "\n",
    "    Args:\n",
    "        data: A pandas DataFrame with columns X1, X2, A1, A2, C1, C2, S1, S2, Y.\n",
    "        phi1_false, phi2_false: Use the alternative formula for the A1 / A2 model.\n",
    "        K1_false, K2_false: Use the alternative formula for the (1-C1) / (1-C2) model.\n",
    "        p1_false, p2_false: Use the alternative formula for the S1 / S2 model.\n",
    "        mu2_false: Use the alternative formula for the Y model.\n",
    "        Ep2_false, Emupi_false: Not used for fitting; returned unchanged under\n",
    "            the keys 'Ep2.false' and 'Emupi.false'.\n",
    "\n",
    "    Returns:\n",
    "        A dictionary holding the fitted results ('phi1.hat', 'phi2.hat',\n",
    "        'K1.hat', 'K2.hat', 'p1.hat', 'p2.hat', 'mu2.hat'), the prediction\n",
    "        functions ('ps1', 'ps2', 'cp1', 'cp2', 'sp1', 'sp2', 'm2', 'pcs1pcs2',\n",
    "        'pcs1pc2', 'pcs1', 'pc1') and the two pass-through flags. All\n",
    "        prediction functions evaluate on the `data` passed to this call.\n",
    "    \"\"\"\n",
    "\n",
    "    # (result key, use alternative?, default formula, alternative formula)\n",
    "    binomial_specs = [\n",
    "        ('phi1.hat', phi1_false,\n",
    "         \"A1 ~ 1 + X1\", \"A1 ~ 0 + X1 + I(X1**2)\"),\n",
    "        ('phi2.hat', phi2_false,\n",
    "         \"A2 ~ 1 + X1 + X2 + I(X2**2)\", \"A2 ~ 1 + X2\"),\n",
    "        ('K1.hat', K1_false,\n",
    "         \"I(1-C1) ~ 1 + X1\", \"I(1-C1) ~ 0 + X1 + I(X1**2)\"),\n",
    "        ('K2.hat', K2_false,\n",
    "         \"I(1-C2) ~ 1 + X1 + X2 + A2 + A2:X2\", \"I(1-C2) ~ 1 + X1 + X2\"),\n",
    "        ('p1.hat', p1_false,\n",
    "         \"S1 ~ 1 + X1 + A1 + A1:X1\", \"S1 ~ 0 + A1\"),\n",
    "        ('p2.hat', p2_false,\n",
    "         \"S2 ~ 1 + X1 + A1 + A2\", \"S2 ~ 0 + A1 + A2\"),\n",
    "    ]\n",
    "\n",
    "    # Binomial GLMs for the treatment, censoring and survival models\n",
    "    models = {}\n",
    "    for key, use_alt, formula, alt_formula in binomial_specs:\n",
    "        models[key] = smf.glm(alt_formula if use_alt else formula,\n",
    "                              data=data, family=sm.families.Binomial()).fit()\n",
    "\n",
    "    # OLS for the outcome model\n",
    "    if mu2_false:\n",
    "        mu2_formula = \"Y ~ 1 + X1 + X2\"\n",
    "    else:\n",
    "        mu2_formula = \"Y ~ 1 + X1 + A1 + A1:X1 + X2 + A2 + A1:A2 + X2:A2\"\n",
    "    models['mu2.hat'] = smf.ols(mu2_formula, data=data).fit()\n",
    "\n",
    "    def _expand(arm):\n",
    "        # A plain int arm becomes one value per row; anything else is used as given.\n",
    "        if isinstance(arm, int):\n",
    "            return np.repeat(arm, len(data.X1))\n",
    "        return arm\n",
    "\n",
    "    def _flip_where_zero(pred, arm):\n",
    "        # The fitted model predicts the probability of arm 1; rows whose arm\n",
    "        # is 0 get the complement instead.\n",
    "        is_zero = (~np.isnan(arm)) & (arm == 0)\n",
    "        pred[is_zero] = 1 - pred[is_zero]\n",
    "        return pred\n",
    "\n",
    "    # Prediction functions for the fitted models\n",
    "    def ps1(a1):\n",
    "        a1 = _expand(a1)\n",
    "        fitdf = pd.DataFrame({\n",
    "            'X1': data['X1'],\n",
    "            'X2': data['X2'],\n",
    "            'A1': a1\n",
    "        })\n",
    "        return _flip_where_zero(models['phi1.hat'].predict(fitdf).values, a1)\n",
    "\n",
    "    def ps2(a1, a2):\n",
    "        a1 = _expand(a1)\n",
    "        a2 = _expand(a2)\n",
    "        fitdf = pd.DataFrame({\n",
    "            'X1': data['X1'], 'X2': data['X2'],\n",
    "            'A1': a1, 'A2': a2\n",
    "        })\n",
    "        return _flip_where_zero(models['phi2.hat'].predict(fitdf).values, a2)\n",
    "\n",
    "    def cp1(a1):\n",
    "        fitdf = pd.DataFrame({\n",
    "            'X1': data['X1'], 'X2': data['X2'],\n",
    "            'A1': a1\n",
    "        })\n",
    "        return models['K1.hat'].predict(fitdf).values\n",
    "\n",
    "    def cp2(a1, a2):\n",
    "        fitdf = pd.DataFrame({\n",
    "            'X1': data['X1'], 'X2': data['X2'],\n",
    "            'A1': a1, 'A2': a2\n",
    "        })\n",
    "        return models['K2.hat'].predict(fitdf).values\n",
    "\n",
    "    def sp1(a1):\n",
    "        fitdf = pd.DataFrame({\n",
    "            'const': 1, 'X1': data.X1, 'A1': a1,\n",
    "        })\n",
    "        return models['p1.hat'].predict(fitdf).values\n",
    "\n",
    "    def sp2(a1, a2):\n",
    "        fitdf = pd.DataFrame({\n",
    "            'const': 1, 'X1': data.X1, 'X2': data.X2,\n",
    "            'A1': a1, 'A2': a2\n",
    "        })\n",
    "        return models['p2.hat'].predict(fitdf).values\n",
    "\n",
    "    def m2(a1, a2):\n",
    "        fitdf = pd.DataFrame({\n",
    "            'X1': data['X1'], 'X2': data['X2'],\n",
    "            'A1': a1, 'A2': a2\n",
    "        })\n",
    "        return models['mu2.hat'].predict(fitdf).values\n",
    "\n",
    "    # Products of the stage-wise probabilities used as inverse weights\n",
    "    def pcs1pcs2(a1, a2):\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1) * ps2(a1, a2) * cp2(a1, a2) * sp2(a1, a2)\n",
    "\n",
    "    def pcs1pc2(a1, a2):\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1) * ps2(a1, a2) * cp2(a1, a2)\n",
    "\n",
    "    def pcs1(a1):\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1)\n",
    "\n",
    "    def pc1(a1):\n",
    "        return ps1(a1) * cp1(a1)\n",
    "\n",
    "    # Fitted results first, then the prediction functions and the flags\n",
    "    return {**models,\n",
    "            'ps1': ps1, 'ps2': ps2,\n",
    "            'cp1': cp1, 'cp2': cp2,\n",
    "            'sp1': sp1, 'sp2': sp2, 'm2': m2,\n",
    "            'pcs1pcs2': pcs1pcs2, 'pcs1pc2': pcs1pc2, 'pcs1': pcs1, 'pc1': pc1,\n",
    "            'Ep2.false': Ep2_false, 'Emupi.false': Emupi_false}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e6285b4",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "# Fit the nuisance models on XX_obs15 under several flag combinations.\n",
    "# NOTE(review): the numeric suffixes presumably list the models kept at their\n",
    "# default formula -- confirm the numbering against the write-up.\n",
    "# Ep2_false / Emupi_false do not change any fit inside fit_models; they are only\n",
    "# stored in the returned dictionary under 'Ep2.false' / 'Emupi.false'.\n",
    "models = fit_models(XX_obs15)\n",
    "m_123457 = fit_models(XX_obs15, mu2_false=True)\n",
    "m_12345 = fit_models(XX_obs15, mu2_false=True, Emupi_false=True)\n",
    "m_12367 = fit_models(XX_obs15, p2_false=True, Ep2_false=True)\n",
    "\n",
    "# NOTE(review): m_1234 is fitted with exactly the same arguments as m_123457\n",
    "# above (mu2_false=True only) -- confirm this is intended.\n",
    "m_1234 = fit_models(XX_obs15, mu2_false=True)\n",
    "m_1236 = fit_models(XX_obs15, p2_false=True)\n",
    "\n",
    "m_1235 = fit_models(XX_obs15, Emupi_false=True)\n",
    "m_1237 = fit_models(XX_obs15, Ep2_false=True)\n",
    "\n",
    "# Alternative formulas for both treatment models, both censoring models and p1\n",
    "m_4567 = fit_models(XX_obs15, phi1_false=True, K1_false=True,\n",
    "                    p1_false=True,\n",
    "                    phi2_false=True, K2_false=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57b5be2f",
   "metadata": {},
   "source": [
    "## Plug-in estimator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d76f2351-145c-4e07-b6b5-4c023ae5bda9",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "source": [
    "### plugin for D\n",
    "\n",
    "$$\n",
    "D = \\mathbb E \\bigg[ p_1^0(X_1) \\mathbb E\\big( p_2^0(H_1) \\big| X_1, A_1=0, C_1=0, S_1=1 \\big) \\bigg],\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "989af59c-9ef9-4386-a6de-04219704efdb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def m_p200(data, models):\n",
    "    \"\"\"\n",
    "    Regress the weighted p2(0, 0) prediction on X1 and return the fitted values.\n",
    "\n",
    "    The fitted objects in `models` are evaluated on `data` with both arms set\n",
    "    to 0. The regression target is\n",
    "    1{A1 == 0, C1 == 0, S1 == 1} / (ps1 * cp1 * sp1) * sp2, with NaN replaced by 0.\n",
    "\n",
    "    Args:\n",
    "        data: A pandas DataFrame with columns X1, X2, A1, C1, S1 (all numeric).\n",
    "        models: Dictionary with 'phi1.hat', 'K1.hat', 'p1.hat', 'p2.hat' and\n",
    "            the flag 'Ep2.false'.\n",
    "\n",
    "    Returns:\n",
    "        Fitted values of the regression, one per row of `data`: a spline GAM\n",
    "        when 'Ep2.false' is falsy, otherwise a no-intercept OLS on X1.\n",
    "    \"\"\"\n",
    "    use_alt_fit = models['Ep2.false']\n",
    "\n",
    "    # Stage-2 survival prediction with both arms at 0\n",
    "    surv2_frame = pd.DataFrame({\n",
    "        'const': 1, 'X1': data.X1, 'X2': data.X2,\n",
    "        'A1': 0, 'A2': 0\n",
    "    })\n",
    "    p200x = models['p2.hat'].predict(surv2_frame).values\n",
    "\n",
    "    # Probability of A1 = 0: complement of the fitted prediction\n",
    "    arm1 = np.repeat(0, len(data.X1))\n",
    "    trt_frame = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'], 'A1': arm1})\n",
    "    trt_prob = models['phi1.hat'].predict(trt_frame).values\n",
    "    zero_rows = (~np.isnan(arm1)) & (arm1 == 0)\n",
    "    trt_prob[zero_rows] = 1 - trt_prob[zero_rows]\n",
    "\n",
    "    # Stage-1 censoring and survival predictions with A1 = 0\n",
    "    cens_frame = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'], 'A1': 0})\n",
    "    cens_prob = models['K1.hat'].predict(cens_frame).values\n",
    "    surv1_frame = pd.DataFrame({'const': 1, 'X1': data.X1, 'A1': 0})\n",
    "    surv1_prob = models['p1.hat'].predict(surv1_frame).values\n",
    "\n",
    "    # Work on a copy in which missing entries are 0\n",
    "    filled = data.copy()\n",
    "    filled[np.isnan(filled)] = 0\n",
    "\n",
    "    # Weighted target; NaN products (e.g. from missing X2) are zeroed afterwards\n",
    "    in_stratum = (filled.A1 == 0) * (filled.C1 == 0) * (filled.S1 == 1)\n",
    "    filled[\"target\"] = in_stratum / (trt_prob * cens_prob * surv1_prob) * p200x\n",
    "    filled[np.isnan(filled)] = 0\n",
    "\n",
    "    if use_alt_fit:\n",
    "        ols_fit = smf.ols(\"target ~ 0 + X1\", data=filled).fit()\n",
    "        return ols_fit.predict(filled)\n",
    "\n",
    "    # GAM: cubic B-spline basis in X1 with 7 degrees of freedom\n",
    "    bs = BSplines(filled.X1, df=[7], degree=[3])\n",
    "    gam_fit = smf.glmgam(\"target ~ 1 + X1\", data=filled, smoother=bs).fit()\n",
    "    return gam_fit.predict(filled, exog_smooth=bs.x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73e45080-4831-4428-bbc6-f8fe195ce16f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def D_hat(data, models):\n",
    "    \"\"\"\n",
    "    Plug-in estimate of D: the sample mean of sp1(0) * m_p200(data, models).\n",
    "\n",
    "    Args:\n",
    "      data: A pandas DataFrame of observed data, passed on to m_p200.\n",
    "      models: Dictionary from fit_models / true_models. Only 'sp1' is read\n",
    "        here; m_p200 reads further entries (including 'Ep2.false').\n",
    "\n",
    "    Returns:\n",
    "      A float: the mean of the row-wise products, with NaN products counted as 0.\n",
    "\n",
    "    Note:\n",
    "      models['sp1'] predicts on the data the dictionary was built with, so\n",
    "      `data` is expected to be that same DataFrame.\n",
    "    \"\"\"\n",
    "    # Stage-1 survival prediction with A1 = 0\n",
    "    p10x = models['sp1'](0)\n",
    "\n",
    "    # Conditional mean of p_2^0(H_1) given X1\n",
    "    E_p200 = m_p200(data, models)\n",
    "\n",
    "    # Row-wise plug-in term; missing products contribute 0 to the mean\n",
    "    Dhat = p10x * E_p200\n",
    "    Dhat[np.isnan(Dhat)] = 0\n",
    "\n",
    "    return np.mean(Dhat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3382ace-52a1-4e03-9033-e53e19d0182f",
   "metadata": {},
   "outputs": [],
   "source": [
    "(\n",
    "    D_hat(XX_obs15, models),\n",
    "    D_hat(XX_obs15, m_123457),\n",
    "    D_hat(XX_obs15, m_12345),\n",
    "    D_hat(XX_obs15, m_4567)\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "12cb1ea8-8282-4a56-90e3-b26dc6e0b3dc",
   "metadata": {},
   "source": [
    "### Plugin for N\n",
    "$$\n",
    "N = \\mathbb E \\bigg[ g(X_1) p_1^0(X_1) \\mathbb E\\big( p_2^0(H_1) \\big| X_1, A_1=0, C_1=0, S_1=1 \\big) \\bigg],\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\begin{align*}\n",
    "g(X_1)\n",
    " &= \\sum_{a_1,a_2 \\in \\{0,1\\}} m_{\\mu_2, \\pi}(x_1, \\bar a) \\mathbf1\\{\\pi_1(X_1) = a_1\\} \\\\\n",
    " &= \\sum_{a_1,a_2 \\in \\{0,1\\}} \\mathbb E \\bigg[ \\mu_2^{a_2}(H_1) \\mathbf1\\{\\pi_2(H_1)=a_2\\} \\bigg| X_1, A_1=a_1, C_1=0, S_1=1 \\bigg] \\mathbf1\\{\\pi_1(X_1) = a_1\\} \\\\\n",
    " &= \\sum_{a_1 \\in \\{0,1\\}} \\mathbb E \\bigg[ \\sum_{a_2} \\mu_2^{a_2}(H_1) \\mathbf1\\{\\pi_2(H_1)=a_2\\} \\bigg| X_1, A_1=a_1, C_1=0, S_1=1 \\bigg] \\mathbf1\\{\\pi_1(X_1) = a_1\\}\n",
    "\\end{align*}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9d787c7-c5c0-4d81-93cd-cae88b0991b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def g_(d1, d2, data, models):\n",
    "    \"\"\"\n",
    "    Regress the weighted mu2 prediction under (d1, d2) on X1; return fitted values.\n",
    "\n",
    "    The fitted objects in `models` are evaluated on `data` with A1 set to d1\n",
    "    and A2 set to d2. The regression target is\n",
    "    1{A1 == d1, C1 == 0, S1 == 1} / (ps1 * cp1 * sp1) * mu2, with NaN replaced by 0.\n",
    "\n",
    "    Args:\n",
    "        d1, d2: Stage-1 / stage-2 treatment values (an int or one value per row).\n",
    "        data: A pandas DataFrame with columns X1, X2, A1, C1, S1 (all numeric).\n",
    "        models: Dictionary with 'phi1.hat', 'K1.hat', 'p1.hat', 'mu2.hat' and\n",
    "            the flag 'Emupi.false'.\n",
    "\n",
    "    Returns:\n",
    "        Fitted values of the regression, one per row of `data`: a spline GAM\n",
    "        when 'Emupi.false' is falsy, otherwise a no-intercept OLS on X1.\n",
    "    \"\"\"\n",
    "    use_alt_fit = models['Emupi.false']\n",
    "\n",
    "    # Probability of receiving d1: complement of the prediction where d1 is 0\n",
    "    arm1 = np.repeat(d1, len(data.X1)) if isinstance(d1, int) else d1\n",
    "    trt_frame = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'], 'A1': arm1})\n",
    "    trt_prob = models['phi1.hat'].predict(trt_frame).values\n",
    "    zero_rows = (~np.isnan(arm1)) & (arm1 == 0)\n",
    "    trt_prob[zero_rows] = 1 - trt_prob[zero_rows]\n",
    "\n",
    "    # Stage-1 censoring and survival predictions under d1\n",
    "    cens_frame = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'], 'A1': d1})\n",
    "    cens_prob = models['K1.hat'].predict(cens_frame).values\n",
    "    surv1_frame = pd.DataFrame({'const': 1, 'X1': data.X1, 'A1': d1})\n",
    "    surv1_prob = models['p1.hat'].predict(surv1_frame).values\n",
    "\n",
    "    # Outcome prediction under (d1, d2)\n",
    "    outcome_frame = pd.DataFrame({\n",
    "        'X1': data['X1'], 'X2': data['X2'],\n",
    "        'A1': d1, 'A2': d2\n",
    "    })\n",
    "    mu2_pred = models['mu2.hat'].predict(outcome_frame).values\n",
    "\n",
    "    # Work on a copy in which missing entries are 0\n",
    "    filled = data.copy()\n",
    "    filled[np.isnan(filled)] = 0\n",
    "\n",
    "    # Weighted target; NaN products are zeroed afterwards\n",
    "    in_stratum = (filled.A1 == d1) * (filled.C1 == 0) * (filled.S1 == 1)\n",
    "    filled[\"target\"] = in_stratum / (trt_prob * cens_prob * surv1_prob) * mu2_pred\n",
    "    filled[np.isnan(filled)] = 0\n",
    "\n",
    "    if use_alt_fit:\n",
    "        ols_fit = smf.ols(\"target ~ 0 + X1\", data=filled).fit()\n",
    "        return ols_fit.predict(filled)\n",
    "\n",
    "    # GAM: cubic B-spline basis in X1 with 7 degrees of freedom\n",
    "    bs = BSplines(filled[\"X1\"], df=[7], degree=[3])\n",
    "    gam_fit = smf.glmgam(\"target ~ 1 + X1\", data=filled, smoother=bs).fit()\n",
    "    return gam_fit.predict(filled, exog_smooth=bs.x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c0a4d23-a18e-48e2-8e6a-e5c441fb122e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pygam import LinearGAM, s, l"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3748666-ac98-4864-b88c-5154818b445b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def m_mu_pi(d1, d2, a1, a2, data, models):\n",
    "    \"\"\"\n",
    "    Regress the weighted mu2(a1, a2) * 1{d2 == a2} on X1; return fitted values.\n",
    "\n",
    "    The regression target is\n",
    "    1{A1 == a1, C1 == 0, S1 == 1} / (ps1 * cp1 * sp1) * mu2 * (d2 == a2),\n",
    "    with NaN replaced by 0. When the target mean is exactly 0 or 1 no model\n",
    "    is fitted and an all-zero array is returned.\n",
    "\n",
    "    Args:\n",
    "        d1: Not used in the body.\n",
    "        d2: Stage-2 decisions compared against a2.\n",
    "        a1, a2: Stage-1 / stage-2 treatment values (an int or one value per row).\n",
    "        data: A pandas DataFrame with columns X1, X2, A1, C1, S1 (all numeric).\n",
    "        models: Dictionary with 'phi1.hat', 'K1.hat', 'p1.hat', 'mu2.hat' and\n",
    "            the flag 'Emupi.false'.\n",
    "\n",
    "    Returns:\n",
    "        Fitted values, one per row of `data`: a pygam LinearGAM in X1 when\n",
    "        'Emupi.false' is falsy, otherwise a no-intercept OLS on X1.\n",
    "    \"\"\"\n",
    "    use_alt_fit = models['Emupi.false']\n",
    "\n",
    "    # Probability of receiving a1: complement of the prediction where a1 is 0\n",
    "    arm1 = np.repeat(a1, len(data.X1)) if isinstance(a1, int) else a1\n",
    "    trt_frame = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'], 'A1': arm1})\n",
    "    trt_prob = models['phi1.hat'].predict(trt_frame).values\n",
    "    zero_rows = (~np.isnan(arm1)) & (arm1 == 0)\n",
    "    trt_prob[zero_rows] = 1 - trt_prob[zero_rows]\n",
    "\n",
    "    # Stage-1 censoring and survival predictions under a1\n",
    "    cens_frame = pd.DataFrame({'X1': data['X1'], 'X2': data['X2'], 'A1': a1})\n",
    "    cens_prob = models['K1.hat'].predict(cens_frame).values\n",
    "    surv1_frame = pd.DataFrame({'const': 1, 'X1': data.X1, 'A1': a1})\n",
    "    surv1_prob = models['p1.hat'].predict(surv1_frame).values\n",
    "\n",
    "    # Outcome prediction under (a1, a2)\n",
    "    outcome_frame = pd.DataFrame({\n",
    "        'X1': data['X1'], 'X2': data['X2'],\n",
    "        'A1': a1, 'A2': a2\n",
    "    })\n",
    "    mu2_pred = models['mu2.hat'].predict(outcome_frame).values\n",
    "\n",
    "    # Work on a copy in which missing entries are 0\n",
    "    filled = data.copy()\n",
    "    filled[np.isnan(filled)] = 0\n",
    "\n",
    "    # Weighted target; NaN products are zeroed afterwards\n",
    "    in_stratum = (filled.A1 == a1) * (filled.C1 == 0) * (filled.S1 == 1)\n",
    "    filled[\"target\"] = (in_stratum / (trt_prob * cens_prob * surv1_prob)\n",
    "                        * mu2_pred * (d2 == a2))\n",
    "    filled[np.isnan(filled)] = 0\n",
    "\n",
    "    # Degenerate target: skip the fit\n",
    "    target_mean = filled.target.mean()\n",
    "    if target_mean == 0 or target_mean == 1:\n",
    "        return np.zeros(filled.shape[0])\n",
    "\n",
    "    if use_alt_fit:\n",
    "        ols_fit = smf.ols(\"target ~ 0 + X1\", data=filled).fit()\n",
    "        return ols_fit.predict(filled)\n",
    "\n",
    "    # GAM with a single spline term in X1\n",
    "    x1_column = filled.X1.values.reshape(-1, 1)\n",
    "    gam_fit = LinearGAM(s(0)).fit(x1_column, filled.target)\n",
    "    return gam_fit.predict(filled.X1.values.reshape(-1, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b253a53b-1170-4a45-9dda-a7ebee079704",
   "metadata": {},
   "outputs": [],
   "source": [
    "def N_hat(d1, d2, data, models):\n",
    "    \"\"\"\n",
    "    Plug-in estimate of N: the sample mean of\n",
    "    g_(d1, d2, data, models) * sp1(0) * m_p200(data, models).\n",
    "\n",
    "    Args:\n",
    "        d1: Stage-1 decisions, passed on to g_.\n",
    "        d2: Stage-2 decisions, passed on to g_.\n",
    "        data: A pandas DataFrame of observed data, passed on to g_ and m_p200.\n",
    "        models: Dictionary from fit_models / true_models. Only 'sp1' is read\n",
    "            here; g_ and m_p200 read further entries (including the\n",
    "            'Emupi.false' and 'Ep2.false' flags).\n",
    "\n",
    "    Returns:\n",
    "        A float: the mean of the row-wise products, with NaN products counted as 0.\n",
    "\n",
    "    Note:\n",
    "        models['sp1'] predicts on the data the dictionary was built with, so\n",
    "        `data` is expected to be that same DataFrame.\n",
    "    \"\"\"\n",
    "    # Stage-1 survival prediction with A1 = 0\n",
    "    p10x = models['sp1'](0)\n",
    "\n",
    "    # Conditional mean of p_2^0(H_1) given X1\n",
    "    E_p200 = m_p200(data, models)\n",
    "\n",
    "    # g(X1)\n",
    "    g = g_(d1, d2, data, models)\n",
    "\n",
    "    # Row-wise plug-in term; missing products contribute 0 to the mean\n",
    "    Nhat = g * p10x * E_p200\n",
    "    Nhat[np.isnan(Nhat)] = 0\n",
    "\n",
    "    return np.mean(Nhat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fadae74b-ec0b-47f3-9418-7439f90491a4",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "d1_true = reg1(*eta_true[:2], XX_obs15)\n",
    "d2_true = reg2(*eta_true[2:], XX_obs15)\n",
    "\n",
    "d1_true.mean(), d2_true.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a561769-a42d-4e87-90aa-f4f60eefc9d0",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "(\n",
    "    N_hat(d1_true, d2_true, XX_obs15, models),\n",
    "    N_hat(d1_true, d2_true, XX_obs15, m_4567)\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "63f86b86-46aa-497a-9c48-fdee60f717cf",
   "metadata": {},
   "source": [
    "### Plugin for V=N/D\n",
    "\n",
    "$$\n",
    "V_{1111}(\\pi) = \\mathbb E \\left[ g(X_1) \\frac{ p_1^0(X_1) \\mathbb E\\left( p_2^0(H_1) | X_1, A_1=0, C_1=0, S_1=1 \\right) }{ \\mathbb E\\left[ p_1^0(X_1) \\mathbb E\\left( p_2^0(H_1) | X_1, A_1=0, C_1=0, S_1=1 \\right) \\right] } \\right],\n",
    "$$\n",
    "\n",
    "$$\n",
    "g(X_1) = \\sum_{a_1,a_2 \\in \\{0,1\\}} \\mathbb E \\left[ \\mu_2^{a_2}(H_1) \\mathbf1\\{\\pi_2(H_1)=a_2\\} | X_1, A_1=a_1, C_1=0, S_1=1 \\right] \\mathbf1\\{\\pi_1(X_1) = a_1\\}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb624d62",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "def V_plugin(d1, d2, data, models):\n",
    "    \"\"\"\n",
    "    Plug-in value estimate: N_hat(d1, d2, data, models) / D_hat(data, models).\n",
    "\n",
    "    Args:\n",
    "        d1, d2: Stage-1 / stage-2 decisions, passed on to N_hat.\n",
    "        data: A pandas DataFrame of observed data.\n",
    "        models: Fitted models as returned by fit_models.\n",
    "\n",
    "    Returns:\n",
    "        The ratio of the two plug-in estimates.\n",
    "    \"\"\"\n",
    "    numerator = N_hat(d1, d2, data, models)\n",
    "    denominator = D_hat(data, models)\n",
    "    return numerator / denominator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "715b3b99-7ff2-4d33-807c-d9485af63024",
   "metadata": {},
   "outputs": [],
   "source": [
    "V_hat = V_plugin(d1_true, d2_true, XX_obs15, models)\n",
    "V_hat#, V_true, V_true2, V_true_def"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44c02631-117b-40e1-89bb-e56f2b4ece41",
   "metadata": {},
   "source": [
    "## \"True\" value using ID & true model\n",
    "- Plugin estimator with true & flexible models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09a61fdb-a676-4994-8812-a4d497b1f97e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspect"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb72f170-92ff-408e-88ef-0b753478f94d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def true_models(data):\n",
    "    \"\"\"\n",
    "    Build a models dictionary with the same keys as fit_models, backed by the\n",
    "    functions phi1, phi2, K1, K2, p1, p2 and mu2 instead of fitted models.\n",
    "\n",
    "    Those seven functions are not defined in this function; they are looked up\n",
    "    in the enclosing (notebook) scope -- presumably the data-generating\n",
    "    functions; verify against their definitions.\n",
    "\n",
    "    Args:\n",
    "        data: A pandas DataFrame with columns X1 and X2; the returned\n",
    "            prediction functions evaluate on this DataFrame.\n",
    "\n",
    "    Returns:\n",
    "        A dictionary with wrapper objects exposing predict(fitdf) under the\n",
    "        '*.hat' keys, the prediction functions ('ps1', 'ps2', 'cp1', 'cp2',\n",
    "        'sp1', 'sp2', 'm2', 'pcs1pcs2', 'pcs1pc2', 'pcs1', 'pc1'), and the\n",
    "        flags 'Ep2.false' and 'Emupi.false', both set to False.\n",
    "    \"\"\"\n",
    "    models = {}\n",
    "    \n",
    "    # The four wrapper classes differ only in which columns of `fitdf` they\n",
    "    # pass to the wrapped function.\n",
    "    class trueModel1():\n",
    "        # Wraps a function of X1.\n",
    "        def __init__(self, md):\n",
    "            self.model_func = md\n",
    "            \n",
    "        def predict(self, fitdf):\n",
    "            return pd.Series(self.model_func(fitdf.X1))\n",
    "\n",
    "    class trueModel11():\n",
    "        # Wraps a function of (X1, A1).\n",
    "        def __init__(self, md):\n",
    "            self.model_func = md\n",
    "            \n",
    "        def predict(self, fitdf):\n",
    "            return pd.Series(self.model_func(fitdf.X1, fitdf.A1))\n",
    "\n",
    "    class trueModel2():\n",
    "        # Wraps a function of (X1, X2).\n",
    "        def __init__(self, md):\n",
    "            self.model_func = md\n",
    "            \n",
    "        def predict(self, fitdf):\n",
    "            return pd.Series(self.model_func(fitdf.X1, fitdf.X2))\n",
    "\n",
    "    class trueModel22():\n",
    "        # Wraps a function of (X1, X2, A1, A2).\n",
    "        def __init__(self, md):\n",
    "            self.model_func = md\n",
    "            \n",
    "        def predict(self, fitdf):\n",
    "            return pd.Series(self.model_func(fitdf.X1, fitdf.X2, fitdf.A1, fitdf.A2))\n",
    "\n",
    "    models['phi1.hat'] = trueModel1(phi1)\n",
    "    models['phi2.hat'] = trueModel2(phi2)\n",
    "    models['K1.hat'] = trueModel11(K1)\n",
    "    models['K2.hat'] = trueModel22(K2)\n",
    "    models['p1.hat'] = trueModel11(p1)\n",
    "    models['p2.hat'] = trueModel22(p2)\n",
    "    models['mu2.hat'] = trueModel22(mu2)\n",
    "        \n",
    "    # In the helpers below a plain int arm is repeated to one value per row.\n",
    "    def ps1(a1):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, len(data.X1))\n",
    "        \n",
    "        pred = phi1(data['X1'])  #models['phi1.hat'].predict(fitdf).values\n",
    "        # Rows with a1 == 0 get the complement of the phi1 value.\n",
    "        pred[(~np.isnan(a1)) & (a1 == 0)] = 1 - pred[(~np.isnan(a1)) & (a1 == 0)]\n",
    "        return pred\n",
    "    \n",
    "    def ps2(a1, a2):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, len(data.X1))\n",
    "        if isinstance(a2, int):\n",
    "            a2 = np.repeat(a2, len(data.X1))\n",
    "        \n",
    "        pred = phi2(data.X1, data.X2)  #models['phi2.hat'].predict(fitdf).values\n",
    "        # Rows with a2 == 0 get the complement of the phi2 value.\n",
    "        pred[(~np.isnan(a2)) & (a2 == 0)] = 1 - pred[(~np.isnan(a2)) & (a2 == 0)]\n",
    "        return pred\n",
    "    \n",
    "    def cp1(a1):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, len(data.X1))\n",
    "        return K1(data.X1, a1)  #models['K1.hat'].predict(fitdf).values\n",
    "    \n",
    "    def cp2(a1, a2):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, len(data.X1))\n",
    "        if isinstance(a2, int):\n",
    "            a2 = np.repeat(a2, len(data.X1))\n",
    "        # NOTE(review): `eta2` is neither a parameter nor a local of true_models, so\n",
    "        # it is resolved from the enclosing scope when cp2 is called -- confirm such\n",
    "        # a name exists and matches the value used to generate `data`. K2 receives\n",
    "        # five arguments here but only four in trueModel22.predict.\n",
    "        return K2(data.X1, data.X2, a1, a2, eta2)  #models['K2.hat'].predict(fitdf).values\n",
    "\n",
    "    def sp1(a1):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, len(data.X1))\n",
    "        return p1(data.X1, a1)\n",
    "    \n",
    "    def sp2(a1, a2):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, len(data.X1))\n",
    "        if isinstance(a2, int):\n",
    "            a2 = np.repeat(a2, len(data.X1))\n",
    "        return p2(data.X1, data.X2, a1, a2)\n",
    "\n",
    "    # Products of the stage-wise probabilities\n",
    "    def pcs1pcs2(a1, a2):\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1) * ps2(a1, a2) * cp2(a1, a2) * sp2(a1, a2)\n",
    "    \n",
    "    def pcs1pc2(a1, a2):\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1) * ps2(a1, a2) * cp2(a1, a2)\n",
    "\n",
    "    def pcs1(a1):\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1)\n",
    "\n",
    "    def pc1(a1):\n",
    "        return ps1(a1) * cp1(a1)\n",
    "    \n",
    "    def m2(a1, a2):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, len(data.X1))\n",
    "        if isinstance(a2, int):\n",
    "            a2 = np.repeat(a2, len(data.X1))\n",
    "        return mu2(data.X1, data.X2, a1, a2)  #models['mu2.hat'].predict(fitdf).values\n",
    "\n",
    "    # Return dictionary with models and prediction functions\n",
    "    return {**models,\n",
    "            'ps1': ps1, 'ps2': ps2,  # Prediction functions\n",
    "            'cp1': cp1, 'cp2': cp2,\n",
    "            'sp1': sp1, 'sp2': sp2, 'm2': m2,\n",
    "            'pcs1pcs2': pcs1pcs2, 'pcs1pc2': pcs1pc2, 'pcs1': pcs1, 'pc1': pc1,\n",
    "            'Ep2.false': False, 'Emupi.false': False}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fcb2a67-30ed-431c-a2d4-ef4e31868ce9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ```python\n",
    "tm = true_models(df_large[\"obs\"])  # true + flexible model\n",
    "write_pkl(tm, \"tm_df_large_nonmono.pkl\")\n",
    "# ```"
   ]
  },
  {
   "cell_type": "raw",
   "id": "efccc473-b76c-4849-a1e2-c55f371f7e37",
   "metadata": {},
   "source": [
    "tm = read_pkl(\"tm_df_large_nonmono.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee76f62d-ac05-4f4a-a7ab-59fa8d42875e",
   "metadata": {},
   "outputs": [],
   "source": [
    "D_true_id = D_hat(df_large[\"obs\"], tm)\n",
    "D_true_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "894f0d19-9ca1-4cbc-874f-bb2d9edf416e",
   "metadata": {},
   "outputs": [],
   "source": [
    "d1_id = reg1(*eta_true[:2], df_large[\"obs\"])\n",
    "d2_id = reg2(*eta_true[2:], df_large[\"obs\"])\n",
    "N_true_id = N_hat(d1_id, d2_id, df_large[\"obs\"], tm)\n",
    "N_true_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b412e34d-2b70-4a09-b1e4-9695af8da1e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "V_true_id = N_true_id / D_true_id\n",
    "V_true_id"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db2abb54",
   "metadata": {},
   "source": [
    "## SD-MR (multiply robust)\n",
    "$$\\hat V_\\text{SD-MR}(\\pi) = \\frac{ \\mathbb P_n \\hat{\\mathcal V}_N(O) }{ \\mathbb P_n \\hat{\\mathcal V}_D(O) }$$\n",
    "where $\\mathbb{P}_n$ denotes the empirical mean over the sample."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7109d746",
   "metadata": {},
   "source": [
    "### D-MR\n",
    "* EIF-based estimator $\\mathbb P_n \\hat{\\mathcal V}_D(O)$ where\n",
    "$$\n",
    "\\begin{align}\n",
    "{\\mathcal V}_D(O)\n",
    " &= \n",
    "   \\frac{(1-A_1)(1-C_1)}{(\\varphi_1^0K_1^0)(X_1)} \\left\\{S_1 - p_1^0(X_1)\\right\\} m_{p_2}^{00}(X_1) \\\\\n",
    " &\\quad+\n",
    "   p_1^0(X_1) \\bigg[ \\frac{(1-A_1)(1-A_2)(1-C_1)(1-C_2)S_1}{(\\varphi_1^0K_1^0p_1^0)(X_1)(\\varphi_2^0K_2^0)(H_1)} \\big(S_2 - p_2^0(H_1)\\big) \\\\\n",
    " &\\quad\\qquad\\qquad+ \\frac{(1-A_1)(1-C_1)S_1}{(\\varphi_1^0K_1^0p_1^0)(X_1)} \\left\\{ p_2^0(H_1) - m_{p_2}^{00}(X_1) \\right\\} \\bigg] \\\\\n",
    " &\\quad+ p_1^0(X_1) m_{p_2}^{00}(X_1)\n",
    "\\end{align}\n",
    "$$\n",
    "and $$m_{p_2}^{a_1,a_2}(x_1) = \\mathbb E[ p_2^{a_2}(H_1) | X_1=x_1, A_1=a_1, C_1=0, S_1=1 ].$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5991a4c2-7894-4c3d-b72a-388cf42cf978",
   "metadata": {},
   "outputs": [],
   "source": [
    "def phi_D_terms(data, models):\n",
    "    \"\"\"\n",
    "    Compute the four per-observation terms of the EIF-based (efficient\n",
    "    influence function) estimator of D.\n",
    "\n",
    "    The terms correspond, in order, to the four summands of the V_D(O)\n",
    "    formula in the markdown cell above.\n",
    "\n",
    "    Args:\n",
    "        data: A pandas DataFrame with columns A1, A2, C1, C2, S1, S2 (plus the\n",
    "            columns m_p200 needs). Missing values are treated as 0.\n",
    "        models: Dictionary from fit_models / true_models; the entries 'sp1',\n",
    "            'sp2', 'pcs1pc2', 'pcs1' and 'pc1' are used here, and m_p200 reads\n",
    "            further entries.\n",
    "\n",
    "    Returns:\n",
    "        A tuple (D1, D2, D3, D4) of per-row terms.\n",
    "\n",
    "    Note:\n",
    "        The prediction functions in `models` evaluate on the data the\n",
    "        dictionary was built with, so `data` is expected to be that same\n",
    "        DataFrame.\n",
    "    \"\"\"\n",
    "    sp1 = models['sp1']; sp2 = models['sp2']\n",
    "    pcs1pc2 = models['pcs1pc2']\n",
    "    pcs1 = models['pcs1']\n",
    "    pc1 = models['pc1']\n",
    "\n",
    "    # Missing values become 0 on a new frame; `data` itself is left untouched\n",
    "    data_filled = data.fillna(0)\n",
    "\n",
    "    # Survival predictions with all arms at 0, and the conditional mean of p_2^00\n",
    "    p10x = sp1(0)\n",
    "    p200x = sp2(0, 0)\n",
    "    E_p200 = m_p200(data, models)\n",
    "\n",
    "    A1 = data_filled.A1; A2 = data_filled.A2; C1 = data_filled.C1\n",
    "    C2 = data_filled.C2; S1 = data_filled.S1; S2 = data_filled.S2\n",
    "\n",
    "    def _nan_to_zero(term):\n",
    "        # Each weighted piece is zeroed separately before differencing, so a\n",
    "        # NaN in one piece does not wipe out the other.\n",
    "        term[np.isnan(term)] = 0\n",
    "        return term\n",
    "\n",
    "    # Term 1: weight for A1 = 0, C1 = 0 applied to S1 - p1\n",
    "    w1 = (1-A1)*(1-C1) / pc1(0)\n",
    "    D1 = (_nan_to_zero(w1 * S1) - _nan_to_zero(w1 * p10x)) * E_p200\n",
    "\n",
    "    # Term 2: weight for both stages at 0 applied to S2 - p2\n",
    "    w2 = (1-A1)*(1-A2)*(1-C1)*(1-C2)*S1 / pcs1pc2(0, 0)\n",
    "    D2 = p10x * (_nan_to_zero(w2 * S2) - _nan_to_zero(w2 * p200x))\n",
    "\n",
    "    # Term 3: stage-1 weight applied to p2 minus its conditional mean\n",
    "    w3 = (1-A1)*(1-C1)*S1 / pcs1(0)\n",
    "    D3 = p10x * (_nan_to_zero(w3 * p200x) - _nan_to_zero(w3 * E_p200))\n",
    "\n",
    "    # Term 4: plug-in term\n",
    "    D4 = p10x * E_p200\n",
    "\n",
    "    return D1, D2, D3, D4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7176dd8e-882c-49f6-aa41-4eb881db4348",
   "metadata": {},
   "outputs": [],
   "source": [
    "def phi_D(data, models):\n",
    "    \"\"\"Adds the four terms returned by phi_D_terms into one per-observation value.\"\"\"\n",
    "    t1, t2, t3, t4 = phi_D_terms(data, models)\n",
    "    return t1 + t2 + t3 + t4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b65d5234-7fb9-4c82-acc9-e1ec16ff91d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def D_MR(data, models):\n",
    "    \"\"\"\n",
    "    Stabilized (self-normalized) EIF-based estimate of the denominator D.\n",
    "\n",
    "    Each weighted term returned by phi_D_terms is summed and divided by the\n",
    "    sum of its own weights instead of by the sample size; the plug-in term D4\n",
    "    is averaged as usual.\n",
    "\n",
    "    Args:\n",
    "        data: A pandas DataFrame with columns A1, A2, C1, C2 and S1.\n",
    "        models: A dictionary of fitted nuisance models. This function reads\n",
    "            'pc1', 'pcs1' and 'pcs1pc2', and passes the whole dictionary on\n",
    "            to phi_D_terms.\n",
    "\n",
    "    Returns:\n",
    "        The estimate of D as a single number.\n",
    "    \"\"\"\n",
    "    pcs1pc2 = models['pcs1pc2']\n",
    "    pcs1 = models['pcs1']\n",
    "    pc1 = models['pc1']\n",
    "\n",
    "    # Fill NA with 0 on a new frame; the caller's DataFrame is not modified\n",
    "    data_filled = data.fillna(0)\n",
    "\n",
    "    # attach data_filled\n",
    "    A1 = data_filled.A1; A2 = data_filled.A2; C1 = data_filled.C1\n",
    "    C2 = data_filled.C2; S1 = data_filled.S1\n",
    "\n",
    "    D1, D2, D3, D4 = phi_D_terms(data, models)\n",
    "\n",
    "    def weight_total(w):\n",
    "        # NaN weights count as 0; vectorized sum instead of the Python builtin\n",
    "        w[np.isnan(w)] = 0\n",
    "        return w.sum()\n",
    "\n",
    "    # small constant that keeps the ratios finite when a weight total is 0\n",
    "    eps = 1e-9\n",
    "\n",
    "    # stabilized estimator\n",
    "    w1 = weight_total((1-A1)*(1-C1) / pc1(0))\n",
    "    D1mean = D1.sum() / (w1 + eps)\n",
    "\n",
    "    w2 = weight_total((1-A1)*(1-A2)*(1-C1)*(1-C2)*S1 / pcs1pc2(0, 0))\n",
    "    D2mean = D2.sum() / (w2 + eps)\n",
    "\n",
    "    w3 = weight_total((1-A1)*(1-C1)*S1 / pcs1(0))\n",
    "    D3mean = D3.sum() / (w3 + eps)\n",
    "\n",
    "    return D1mean + D2mean + D3mean + D4.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2f4afb2-37f2-499c-9bab-ea4f32f0a2e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "tm15 = true_models(XX_obs15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e8d3bfc-dccd-4352-9900-73d20ece3e87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stabilized D_MR on XX_obs15 under each set of fitted models, then under tm15,\n",
    "# followed by the plug-in D_hat and the reference value D_true_id.\n",
    "(\n",
    "    D_MR(XX_obs15, models),  # EIF of D + D\n",
    "    D_MR(XX_obs15, m_12345),\n",
    "    D_MR(XX_obs15, m_12367),\n",
    "    \">>>\",\n",
    "    D_MR(XX_obs15, m_1234),\n",
    "    D_MR(XX_obs15, m_1236),\n",
    "    D_MR(XX_obs15, m_1235),\n",
    "    D_MR(XX_obs15, m_1237),\n",
    "    \">>>\",\n",
    "    D_MR(XX_obs15, tm15),\n",
    "    D_hat(XX_obs15, models),\n",
    "    D_true_id\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dff157c2",
   "metadata": {},
   "source": [
    "### N-MR\n",
    "\n",
    "* EIF-based estimator $\\mathbb P_n \\hat{\\mathcal V}_N(O)$ where\n",
    "$$\n",
    "\\begin{align}\n",
    "{\\mathcal V}_N(O)\n",
    " &= \\sum_{a_1,a_2} \\bigg[ \\frac{\\mathbb1\\{\\bar A=\\bar a\\}(1-C_1)(1-C_2)S_1S_2}{(\\varphi_1K_1p_1)^{a_1}(X_1)(\\varphi_2K_2p_2)^{a_2}(H_1)} \\big\\{Y - \\mu_2^{a_2}(H_1)\\big\\} \\cdot \\mathbb1\\{\\pi(H_1)=\\bar a\\} \\\\\n",
    " &\\qquad\\qquad+ \\frac{\\mathbb1\\{A_1=a_1\\}(1-C_1)S_1}{(\\varphi_1K_1p_1)^{a_1}(X_1)} \\left\\{ \\mu_2^{a_2}(H_1) \\mathbb1\\{\\pi_2(H_1)=a_2\\} - m_{\\mu_2, \\pi}^{a_1a_2}(X_1) \\right\\} \\cdot \\mathbb1\\{\\pi_1(X_1)=a_1\\} \\bigg] \\\\\n",
    " &\\qquad\\quad \\times p_1^0(X_1) \\cdot m_{p_2}^{00}(X_1) \\\\\n",
    " &+ m_{\\mu_2, \\pi}(X_1) \\cdot \\left[ \\frac{(1-A_1)(1-C_1)}{(\\varphi_1^0K_1^0)(X_1)} \\left(S_1 - p_1^0(X_1)\\right) \\right] \\cdot m_{p_2}^{00}(X_1) \\\\\n",
    " &+ m_{\\mu_2, \\pi}(X_1) \\cdot p_1^0(X_1) \\times \\bigg[ \\frac{(1-A_1)(1-A_2)(1-C_1)(1-C_2)S_1}{(\\varphi_1K_1p_1)^0(X_1)(\\varphi_2^0K_2^0)(H_1)} \\big(S_2 - p_2^0(H_1)\\big) \\\\\n",
    " &\\quad\\qquad\\qquad\\qquad\\qquad\\qquad+ \\frac{(1-A_1)(1-C_1)S_1}{(\\varphi_1K_1p_1)^0(X_1)} \\left\\{ p_2^0(H_1) - m_{p_2}^{00}(X_1) \\right\\} \\bigg] \\\\\n",
    " &+ m_{\\mu_2, \\pi}(X_1) \\cdot p_1^0(X_1) \\cdot m_{p_2}^{00}(X_1)\n",
    "\\end{align}\n",
    "$$\n",
    "and $$m_{\\mu_2, \\pi}^{a_1a_2}(x_1) = \\mathbb E[\\mu_2^{a_2}(H_1) \\mathbb1\\{\\pi_2(H_1)=a_2\\} | X_1=x_1, A_1=a_1, C_1=0, S_1=1 ],$$\n",
    "$$\n",
    "m_{\\mu_2, \\pi}(x_1) = \\sum_{a_1, a_2} m_{\\mu_2, \\pi}^{a_1a_2}(x_1) \\mathbb1\\{\\pi_1(x_1)=a_1\\}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59dc1b6c-503d-4b61-a5ac-a01bf383842b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def phi_N_terms(d1, d2, data, models):\n",
    "    \"\"\"\n",
    "    Computes the per-observation terms of the efficient influence function\n",
    "    (EIF) based estimator of the numerator N.\n",
    "\n",
    "    Args:\n",
    "        d1: First-stage decisions, compared elementwise with A1.\n",
    "        d2: Second-stage decisions, compared elementwise with A2.\n",
    "        data: A pandas DataFrame with columns A1, A2, C1, C2, S1, S2 and Y.\n",
    "        models: A dictionary of fitted nuisance models. This function reads\n",
    "            'sp1', 'sp2', 'm2', 'pc1', 'pcs1', 'pcs1pc2' and 'pcs1pcs2', and\n",
    "            passes the whole dictionary on to m_p200, g_ and m_mu_pi.\n",
    "\n",
    "    Returns:\n",
    "        A tuple (N11, N12, N21, N22, N3, N4) of per-observation terms;\n",
    "        phi_N adds them up.\n",
    "    \"\"\"\n",
    "\n",
    "    # Only the prediction functions that are actually used are extracted\n",
    "    sp1 = models['sp1']; sp2 = models['sp2']; m2 = models['m2']\n",
    "    pcs1pcs2 = models['pcs1pcs2']\n",
    "    pcs1pc2 = models['pcs1pc2']\n",
    "    pcs1 = models['pcs1']\n",
    "    pc1 = models['pc1']\n",
    "\n",
    "    def zero_nan(term):\n",
    "        # NaN entries are replaced by 0 (in place) so they drop out of the sums\n",
    "        term[np.isnan(term)] = 0\n",
    "        return term\n",
    "\n",
    "    # Conditional means\n",
    "    p10x = sp1(0)\n",
    "    p200x = sp2(0, 0)\n",
    "    E_p200 = m_p200(data, models)\n",
    "    E_mu2pi = g_(d1, d2, data, models)\n",
    "\n",
    "    # The raw columns are used here and NaNs are zeroed term by term below.\n",
    "    # NOTE(review): phi_D_terms and N_MR read from a NaN-filled copy instead;\n",
    "    # confirm both conventions give the same terms on the simulated data.\n",
    "    A1 = data.A1; A2 = data.A2\n",
    "    C1 = data.C1; C2 = data.C2\n",
    "    S1 = data.S1; S2 = data.S2; Y = data.Y\n",
    "\n",
    "    def stage1_term(a1, a2):\n",
    "        # second line of the EIF of N for one treatment pair (a1, a2)\n",
    "        val = (A1==a1)*(1-C1)*S1 / pcs1(a1) * (m2(a1,a2)*(d2==a2)\n",
    "                                               - m_mu_pi(d1, d2, a1, a2,\n",
    "                                                         data, models)) * (d1==a1)\n",
    "        return zero_nan(val)\n",
    "\n",
    "    # N11: outcome residual, evaluated directly at the followed rule (d1, d2)\n",
    "    w11 = (A1==d1)*(A2==d2)*(1-C1)*(1-C2)*S1*S2 / pcs1pcs2(d1,d2)\n",
    "    N111_ = zero_nan(w11 * Y)\n",
    "    N112_ = zero_nan(w11 * m2(d1,d2))\n",
    "    N11_ = N111_ - N112_\n",
    "\n",
    "    # N12: sum of the stage-1 terms over the four treatment pairs\n",
    "    N12_ = zero_nan(stage1_term(0, 0) + stage1_term(1, 0)\n",
    "                    + stage1_term(0, 1) + stage1_term(1, 1))\n",
    "\n",
    "    N11 = (N11_) * p10x * E_p200\n",
    "    N12 = (N12_) * p10x * E_p200\n",
    "\n",
    "    # N21 / N22: same weighted residuals as D2 / D3 in phi_D_terms\n",
    "    w21 = (1-A1)*(1-A2)*(1-C1)*(1-C2)*S1 / pcs1pc2(0, 0)\n",
    "    N211 = zero_nan(w21 * S2)\n",
    "    N212 = zero_nan(w21 * p200x)\n",
    "    N21_ = N211 - N212\n",
    "\n",
    "    w22 = (1-A1)*(1-C1)*S1 / pcs1(0)\n",
    "    N221 = zero_nan(w22 * p200x)\n",
    "    N222 = zero_nan(w22 * E_p200)\n",
    "    N22_ = N221 - N222\n",
    "\n",
    "    N21 = E_mu2pi * p10x * (N21_)\n",
    "    N22 = E_mu2pi * p10x * (N22_)\n",
    "\n",
    "    # N3: first-stage survival residual\n",
    "    w3 = (1-A1)*(1-C1) / pc1(0)\n",
    "    N311 = zero_nan(w3 * S1)\n",
    "    N312 = zero_nan(w3 * p10x)\n",
    "    N31 = N311 - N312\n",
    "    N3 = E_mu2pi * N31 * E_p200\n",
    "\n",
    "    # N4: plug-in term\n",
    "    N4 = E_mu2pi * p10x * E_p200\n",
    "\n",
    "    return N11, N12, N21, N22, N3, N4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "606ccac3-6421-48a1-8640-3e93b301e92b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def phi_N(d1, d2, data, models):\n",
    "    \"\"\"Adds the six terms returned by phi_N_terms into one per-observation value.\"\"\"\n",
    "    t11, t12, t21, t22, t3, t4 = phi_N_terms(d1, d2, data, models)\n",
    "    return t11 + t12 + t21 + t22 + t3 + t4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed48c733-64f5-41da-a5b1-1ebe09bb44f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def N_MR(d1, d2, data, models):\n",
    "    \"\"\"\n",
    "    Stabilized (self-normalized) EIF-based estimate of the numerator N.\n",
    "\n",
    "    Each weighted term returned by phi_N_terms is summed and divided by the\n",
    "    sum of its own weights instead of by the sample size; the plug-in term N4\n",
    "    is averaged as usual.\n",
    "\n",
    "    Args:\n",
    "        d1: First-stage decisions, compared elementwise with A1.\n",
    "        d2: Second-stage decisions, compared elementwise with A2.\n",
    "        data: A pandas DataFrame with columns A1, A2, C1, C2, S1 and S2.\n",
    "        models: A dictionary of fitted nuisance models. This function reads\n",
    "            'pc1', 'pcs1', 'pcs1pc2' and 'pcs1pcs2', and passes the whole\n",
    "            dictionary on to phi_N_terms.\n",
    "\n",
    "    Returns:\n",
    "        The estimate of N as a single number.\n",
    "    \"\"\"\n",
    "    pcs1pcs2 = models['pcs1pcs2']\n",
    "    pcs1pc2 = models['pcs1pc2']\n",
    "    pcs1 = models['pcs1']\n",
    "    pc1 = models['pc1']\n",
    "\n",
    "    # Fill NA with 0 on a new frame; the caller's DataFrame is not modified\n",
    "    data_filled = data.fillna(0)\n",
    "\n",
    "    # attach data_filled\n",
    "    A1 = data_filled.A1; A2 = data_filled.A2\n",
    "    C1 = data_filled.C1; C2 = data_filled.C2\n",
    "    S1 = data_filled.S1; S2 = data_filled.S2\n",
    "\n",
    "    def weight_total(w):\n",
    "        # NaN weights count as 0; vectorized sum instead of the Python builtin\n",
    "        w[np.isnan(w)] = 0\n",
    "        return w.sum()\n",
    "\n",
    "    # stabilize mean of terms\n",
    "    N11, N12, N21, N22, N3, N4 = phi_N_terms(d1, d2, data, models)\n",
    "    w11 = weight_total((A1==d1)*(A2==d2)*(1-C1)*(1-C2)*S1*S2 / pcs1pcs2(d1,d2))\n",
    "    w12 = weight_total((A1==d1)*(1-C1)*S1 / pcs1(d1))\n",
    "    w21 = weight_total((1-A1)*(1-A2)*(1-C1)*(1-C2)*S1 / pcs1pc2(0, 0))\n",
    "    w22 = weight_total((1-A1)*(1-C1)*S1 / pcs1(0))\n",
    "    w3 = weight_total((1-A1)*(1-C1) / pc1(0))\n",
    "\n",
    "    # small constant that keeps the ratios finite when a weight total is 0\n",
    "    eps = 1e-9\n",
    "\n",
    "    N11mean = N11.sum() / (w11 + eps)\n",
    "    N12mean = N12.sum() / (w12 + eps)\n",
    "    N21mean = N21.sum() / (w21 + eps)\n",
    "    N22mean = N22.sum() / (w22 + eps)\n",
    "    N3mean = N3.sum() / (w3 + eps)\n",
    "\n",
    "    return N11mean + N12mean + N21mean + N22mean + N3mean + N4.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b321063b",
   "metadata": {
    "scrolled": true,
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "# N_MR at the rule (d1_true, d2_true) under each set of fitted models, with\n",
    "# the plug-in N_hat and the reference value N_true_id for comparison.\n",
    "(\n",
    "    N_MR(d1_true, d2_true, XX_obs15, models),\n",
    "    N_MR(d1_true, d2_true, XX_obs15, m_12345),\n",
    "    N_MR(d1_true, d2_true, XX_obs15, m_123457),\n",
    "    N_MR(d1_true, d2_true, XX_obs15, m_12367),\n",
    "    \">>>\", \n",
    "    N_MR(d1_true, d2_true, XX_obs15, m_1234),\n",
    "    N_MR(d1_true, d2_true, XX_obs15, m_1236),\n",
    "    N_MR(d1_true, d2_true, XX_obs15, m_1235),\n",
    "    N_MR(d1_true, d2_true, XX_obs15, m_1237),\n",
    "    \">>>\",\n",
    "    N_MR(d1_true, d2_true, XX_obs15, m_4567),\n",
    "    N_hat(d1_true, d2_true, XX_obs15, m_4567),\n",
    "    \">>>\",\n",
    "    N_MR(d1_true, d2_true, XX_obs15, tm15),\n",
    "    N_hat(d1_true, d2_true, XX_obs15, models),\n",
    "    N_true_id,\n",
    "    # N_true2\n",
    ")\n",
    "# np.set_printoptions()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "deb3797c",
   "metadata": {},
   "source": [
    "### $V^{MR}(\\pi) = N/D$\n",
    "$$\n",
    "\\hat V^\\text{MR}(\\pi) = \\frac{ \\mathbb P_n \\hat{\\mathcal V}_N }{ \\mathbb P_n \\hat{\\mathcal V}_D }\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2337c575-bec9-4d0a-a11c-ca72be462846",
   "metadata": {},
   "outputs": [],
   "source": [
    "def V_MR(d1, d2, data, models):\n",
    "    \"\"\"EIF-based estimator of V(pi) = N/D: the ratio of N_MR to D_MR.\"\"\"\n",
    "    numerator = N_MR(d1, d2, data, models)\n",
    "    denominator = D_MR(data, models)\n",
    "    return numerator / denominator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbedd337",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "# V_MR at the rule (d1_true, d2_true) under several fitted model sets, with\n",
    "# the plug-in V_plugin and the reference value V_true_id for comparison.\n",
    "(\n",
    "    V_MR(d1_true, d2_true, XX_obs15, models),\n",
    "    V_MR(d1_true, d2_true, XX_obs15, m_12345),\n",
    "    V_MR(d1_true, d2_true, XX_obs15, m_123457),\n",
    "    \">>>\",\n",
    "    V_MR(d1_true, d2_true, XX_obs15, m_4567),\n",
    "    V_plugin(d1_true, d2_true, XX_obs15, m_4567),\n",
    "    \">>>\",\n",
    "    V_plugin(d1_true, d2_true, XX_obs15, models),\n",
    "    V_true_id,\n",
    "    # V_true2,\n",
    "    # V_true_def\n",
    ")\n",
    "# V.MR2(d1.true, d2.true, XX.obs15, models)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "394379e4",
   "metadata": {},
   "source": [
    "---\n",
    "# OPE with fixed decision rules"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f168e241",
   "metadata": {},
   "source": [
    "## Function `simulate_fixpi()`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81d48fc5-ddeb-45ba-88cb-34abc7202f36",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pre-generate the simulated datasets and cache them under ./simdata.\n",
    "# NOTE(review): eta2 must already be defined when this cell runs; in this part\n",
    "# of the notebook it is only assigned in the simulation cell further down --\n",
    "# confirm that an earlier cell sets it.\n",
    "Ns = [1000, 2000, 5000]\n",
    "\n",
    "# write_pkl opens the target file directly, so the directory has to exist\n",
    "os.makedirs(\"./simdata\", exist_ok=True)\n",
    "\n",
    "for m in range(500):\n",
    "    for N in Ns:\n",
    "        fpath = f\"./simdata/data_NM_{N:04d}_{m:03d}_{eta2:01d}.pkl\"\n",
    "        if not os.path.exists(fpath):\n",
    "            # only missing files are generated; existing ones are kept as they are\n",
    "            dat = genData(N, eta2)\n",
    "            write_pkl(dat, fpath)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee4b7b18-7550-4ff0-b6de-2419f1fa173e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def simulate_fixpi(m, eta2, eta_true=eta_true):\n",
    "    \"\"\"\n",
    "    Runs replication `m` of the fixed-rule simulation for every sample size in Ns.\n",
    "\n",
    "    For each N in the notebook-level list Ns, the dataset is read from\n",
    "    ./simdata (or generated with genData and cached there), the nuisance\n",
    "    models are fitted under six specifications, and N, D and V = N / D are\n",
    "    estimated with the MR and the plug-in estimators.\n",
    "\n",
    "    Args:\n",
    "      m: Index of the simulation replication. It selects the cached data file\n",
    "         and, when eta_true is 4-dimensional, the entry eta_true[m].\n",
    "      eta2: Passed to genData and encoded in the cache file name.\n",
    "      eta_true: Decision-rule parameters (the default is the notebook-level\n",
    "         eta_true at the time this function is defined). Either one parameter\n",
    "         vector used for both estimators, or a 4-dimensional array-like where\n",
    "         eta_true[m][0] holds the (MR, plug-in) parameters used for Ns[0] and\n",
    "         eta_true[m][1] those used for all later sample sizes.\n",
    "\n",
    "    Returns:\n",
    "      A namedtuple Result(value_hat, Nhats, Dhats) of arrays with shape\n",
    "      (len(Ns), 13): columns 0-5 hold the MR estimates for the six model sets,\n",
    "      columns 6-11 the plug-in estimates for the same model sets, and the last\n",
    "      column the plug-in estimate evaluated with true_models.\n",
    "    \"\"\"\n",
    "\n",
    "    # Checked once: does eta_true carry one entry per replication?\n",
    "    per_run_eta = len(np.array(eta_true).shape) == 4\n",
    "    if per_run_eta:\n",
    "        beta_mr_1k, beta_pg_1k = eta_true[m][0]\n",
    "        beta_mr_2k, beta_pg_2k = eta_true[m][1]\n",
    "\n",
    "    # Define namedtuple for results\n",
    "    Result = namedtuple('Result', ['value_hat', 'Nhats', 'Dhats'])\n",
    "\n",
    "    # One column per (estimator, model set) pair plus one reference column\n",
    "    n_model = 6\n",
    "    n_col = 2 * n_model + 1\n",
    "    value_hat = np.zeros((len(Ns), n_col))\n",
    "    Nhats = np.zeros((len(Ns), n_col))\n",
    "    Dhats = np.zeros((len(Ns), n_col))\n",
    "\n",
    "    for i, N in enumerate(Ns):\n",
    "        # Load the cached dataset of this replication, or generate and cache it\n",
    "        fpath = f\"./simdata/data_NM_{N:04d}_{m:03d}_{eta2:01d}.pkl\"\n",
    "        if os.path.exists(fpath):\n",
    "            dat = read_pkl(fpath)\n",
    "        else:\n",
    "            dat = genData(N, eta2)\n",
    "            # write_pkl opens the file directly, so the directory has to exist\n",
    "            os.makedirs(os.path.dirname(fpath), exist_ok=True)\n",
    "            write_pkl(dat, fpath)\n",
    "\n",
    "        if per_run_eta:\n",
    "            if i == 0:\n",
    "                beta_mr, beta_pg = beta_mr_1k, beta_pg_1k\n",
    "            else:\n",
    "                # NOTE(review): every sample size after Ns[0] uses the second\n",
    "                # entry -- confirm this is intended for Ns[2] as well.\n",
    "                beta_mr, beta_pg = beta_mr_2k, beta_pg_2k\n",
    "        else:\n",
    "            beta_mr = beta_pg = eta_true\n",
    "\n",
    "        # Observed data of this replication\n",
    "        XX_obs = dat['obs']\n",
    "\n",
    "        # Fit nuisance models\n",
    "        models = fit_models(XX_obs)\n",
    "        tm = true_models(XX_obs)\n",
    "        # settings that should lead to a consistent estimator\n",
    "        models_12367 = fit_models(XX_obs,\n",
    "                                  p2_false=True,\n",
    "                                  Ep2_false=True)  # 45 wrong\n",
    "        models_12467 = fit_models(XX_obs,\n",
    "                                  phi2_false=True, K2_false=True,\n",
    "                                  Ep2_false=True)  # 35 wrong\n",
    "        models_4567 = fit_models(XX_obs,\n",
    "                                 phi2_false=True, K2_false=True,\n",
    "                                 phi1_false=True, K1_false=True,\n",
    "                                 p1_false=True)  # 123 wrong\n",
    "        models_123457 = fit_models(XX_obs,\n",
    "                                   mu2_false=True)  # 6 wrong\n",
    "\n",
    "        # setting that is not guaranteed to be consistent\n",
    "        models_12345 = fit_models(XX_obs,\n",
    "                                  mu2_false=True,\n",
    "                                  Emupi_false=True)  # 67 wrong\n",
    "\n",
    "        # fixed decisions\n",
    "        d1mr = reg1(*beta_mr[:2], XX_obs)\n",
    "        d2mr = reg2(*beta_mr[2:], XX_obs)\n",
    "\n",
    "        d1pg = reg1(*beta_pg[:2], XX_obs)\n",
    "        d2pg = reg2(*beta_pg[2:], XX_obs)\n",
    "\n",
    "        # Estimators with different models\n",
    "        all_models = [models,\n",
    "                      models_12367, models_12467, models_4567, models_123457,\n",
    "                      models_12345]\n",
    "        # the result arrays were sized for exactly this many model sets\n",
    "        assert len(all_models) == n_model\n",
    "\n",
    "        # MR estimators\n",
    "        for mm, mdl in enumerate(all_models):\n",
    "            Nhats[i, mm] = N_MR(d1mr, d2mr, XX_obs, mdl)\n",
    "            Dhats[i, mm] = D_MR(XX_obs, mdl)\n",
    "            value_hat[i, mm] = Nhats[i, mm] / Dhats[i, mm]\n",
    "\n",
    "        # Plugin estimators\n",
    "        for mm, mdl in enumerate(all_models):\n",
    "            col = mm + n_model\n",
    "            Nhats[i, col] = N_hat(d1pg, d2pg, XX_obs, mdl)\n",
    "            Dhats[i, col] = D_hat(XX_obs, mdl)\n",
    "            value_hat[i, col] = Nhats[i, col] / Dhats[i, col]\n",
    "\n",
    "        # Reference column: plug-in estimator evaluated with true_models(XX_obs)\n",
    "        Nhats[i, -1] = N_hat(d1pg, d2pg, XX_obs, tm)\n",
    "        Dhats[i, -1] = D_hat(XX_obs, tm)\n",
    "        value_hat[i, -1] = Nhats[i, -1] / Dhats[i, -1]\n",
    "\n",
    "    return Result(value_hat, Nhats, Dhats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d73cc2a-161d-4407-88a2-84d0fbd6c806",
   "metadata": {},
   "outputs": [],
   "source": [
    "def initializer():\n",
    "    \"\"\"\n",
    "    Ignore CTRL+C in the worker process.\n",
    "    \"\"\"\n",
    "    signal.signal(signal.SIGINT, signal.SIG_IGN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e83620a8-aa39-447a-95aa-1be9f0c4ea97",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def simulate_fixpi_mp(m):\n",
    "    \"\"\"Single-argument wrapper around simulate_fixpi that uses the notebook-level eta2.\"\"\"\n",
    "    result = simulate_fixpi(m, eta2)\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "112b81ec-b8e6-4f7a-9ba4-132e274ae9d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# The cell magic has to be the first line of the cell, otherwise IPython rejects it.\n",
    "# Serial run of all M replications; see the timing note below for the cost.\n",
    "\n",
    "Ns = [1000, 2000, 5000]\n",
    "M = 500\n",
    "\n",
    "# Censoring rate ~= 15%\n",
    "eta2 = 5\n",
    "\n",
    "# Chained-assignment warnings are silenced only while the simulation runs\n",
    "pd.options.mode.chained_assignment = None\n",
    "try:\n",
    "    sim_15 = []\n",
    "    for m in tqdm(range(M)):\n",
    "        sim_15.append( simulate_fixpi(m, eta2) )\n",
    "finally:\n",
    "    # restored even if a replication fails\n",
    "    pd.options.mode.chained_assignment = 'warn'\n",
    "\n",
    "write_pkl(sim_15, \"sim_15_contiX2_nonmono.pkl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48a659bb-d3de-4830-8ff9-5b21ae175fad",
   "metadata": {},
   "source": [
    "Timing of the serial run above (500 replications, each with N=1000, 2000 and 5000):\n",
    "```\n",
    "100%|███████████████████████████████████████| 500/500 [9:10:03<00:00, 66.01s/it]\n",
    "CPU times: user 23h 3min 39s, sys: 13h 56min 10s, total: 1d 12h 59min 50s\n",
    "Wall time: 9h 10min 3s\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a919a6e-c4eb-471b-89bc-71ce39876b0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "Ns = [1000, 2000, 5000]\n",
    "M = 500"
   ]
  },
  {
   "cell_type": "raw",
   "id": "dad4a835-12b0-4407-ae92-9bcc8913ce01",
   "metadata": {},
   "source": [
    "sim_15 = read_pkl(\"sim_15_contiX2_nonmono.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3c64ac7-3888-4cc7-8fcc-41909fff4b2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-replication result matrices along a new leading axis\n",
    "D_15, N_15, V_15 = (\n",
    "    np.array([getattr(sim, field) for sim in sim_15])\n",
    "    for field in ('Dhats', 'Nhats', 'value_hat')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f7670f9-84b7-4834-be75-4517f9a56705",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Labels M1..M6 for the six model sets; mdl_lbl is what the plot and the\n",
    "# summary table below use. The commented-out lines are earlier label variants.\n",
    "# mdls = [\"\\n12367\", \"\\n12467\", \"\\n4567\", \"\\n123457\", \"\\n12345\"]#, \"\\n2345\"]\n",
    "mdls = [f'M{i+1}' for i in range(6)]\n",
    "# mdls = [\"\\n45\", \"\\n35\", \"\\n123\", \"\\n6\", \"\\n67\"]#, \"\\n2345\"]\n",
    "# mdl_lbl = [r\"MR\"] + mdls + [\"plugin\"] + mdls + [\"True\"] #+ [\"plugin2\"]\n",
    "# mrmdls = [f'M{i+1}' if i != 2 else f'M{i+1}\\n      MR' for i in range(6)]\n",
    "mrmdls = [f'M{i+1}' for i in range(6)]\n",
    "# same labels, with an extra \"Plugin\" line under the third one\n",
    "pgmdls = [f'M{i+1}' if i != 2 else f'M{i+1}\\n      Plugin' for i in range(6)]\n",
    "mdl_lbl = mrmdls #+ pgmdls + [\"True\"] #+ [\"plugin2\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfd6272f-bff7-4dcd-8edf-704ff99e197e",
   "metadata": {},
   "outputs": [],
   "source": [
    "mdl_lbl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08ae9cec-2452-4f04-8548-f749f1f0a68a",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_large['obs'].C1.mean(), df_large['obs'].C2.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "268c3810-11c5-46f6-a1de-27defa3175bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_large['obs'].S1.mean(), df_large['obs'].S2.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b01d8bf-b309-49c6-bbb8-bc4751dbb49b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def stylize_axes(ax):\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    ax.spines['right'].set_visible(False)\n",
    "\n",
    "    # ax.xaxis.set_tick_params(top='off', direction='out', width=1)\n",
    "    # ax.yaxis.set_tick_params(right='off', direction='out', width=1)\n",
    "\n",
    "    ax.yaxis.set_ticks_position('left') \n",
    "    ax.xaxis.set_ticks_position('bottom')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "511c785e-580f-459d-8ec5-e15b75011c5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Box plots of the MR value estimates for the sample sizes Ns[1] and Ns[2]\n",
    "plt.figure(figsize=(4.3,2.2), dpi=1200)\n",
    "\n",
    "# Columns of V_15 in the order that matches the tick labels M1..M6\n",
    "plot_cols = [0,1,2,4,5,3]\n",
    "\n",
    "for panel, n_idx in enumerate([1, 2]):\n",
    "    stylize_axes(plt.subplot(1, 2, panel + 1))\n",
    "    plt.boxplot(\n",
    "        V_15[:, n_idx, plot_cols], tick_labels=mdl_lbl[:],\n",
    "        sym='x', flierprops={'markersize': 3, 'markeredgewidth': .25}\n",
    "    )\n",
    "    plt.xticks(rotation=0)\n",
    "    if panel == 0:\n",
    "        plt.ylabel(\"Always-survivor value\")\n",
    "    else:\n",
    "        # both panels use the same y-limits, so the right one drops its ticks\n",
    "        plt.yticks([], [])\n",
    "    plt.ylim(V_true_id*.4, V_true_id*1.5)\n",
    "    plt.title(f\"N={Ns[n_idx]}\")\n",
    "    plt.axhline(V_true_id, color='r', ls=\"-\", lw=.75, label='True V')\n",
    "\n",
    "plt.tight_layout();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbea08de-ffae-4482-8ee9-a22fa6df4309",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bias and standard deviation (over replications) of the value estimates.\n",
    "# The model columns are permuted exactly as in the box plots above, so that the\n",
    "# labels M1..M6 refer to the same model sets in the figure and in this table.\n",
    "model_order = [0, 1, 2, 4, 5, 3]\n",
    "# EIF block first, then the plug-in block in the same model order\n",
    "table_cols = model_order + [len(model_order) + k for k in model_order]\n",
    "\n",
    "def bias_se_rows(n_idx):\n",
    "    \"\"\"Two-row frame for sample size Ns[n_idx]: mean minus V_true_id, then std over replications.\"\"\"\n",
    "    V_n = V_15[:, n_idx, table_cols]\n",
    "    return pd.DataFrame(np.c_[V_n.mean(0) - V_true_id, V_n.std(0)]).T\n",
    "\n",
    "df_15_1000 = bias_se_rows(0)\n",
    "df_15_2000 = bias_se_rows(1)\n",
    "df_15_5000 = bias_se_rows(2)\n",
    "\n",
    "df_15 = pd.concat([df_15_1000, df_15_2000, df_15_5000]).T\n",
    "df_15.columns = pd.MultiIndex.from_tuples([(f\"N={n}\", stat)\n",
    "                                           for n in Ns for stat in (\"Bias\", \"SE\")])\n",
    "df_15.index = pd.MultiIndex.from_tuples([(\"EIF\", lbl.strip()) for lbl in mdl_lbl[:6]] + \\\n",
    "                                        [(\"Plugin\", lbl.strip()) for lbl in mdl_lbl[:6]])\n",
    "df_15.round(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71191eda-db21-43d2-afdb-9462be3ad040",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6af900c6-77f6-491f-b8b7-87d5071962f2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12435433-0e35-438b-809e-9308f8627997",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "bae27802-729c-4606-af71-694251b79bb3",
   "metadata": {},
   "source": [
    "## Analytical confidence interval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ff48732-0388-4646-8208-610345ca18a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def eif_MR(d1, d2, data, models):\n",
    "    \"\"\"\n",
    "    Per-observation influence-function values for the ratio V = N / D.\n",
    "\n",
    "    phi_N and phi_D are centered at the plug-in estimates N_hat and D_hat and\n",
    "    combined as (phi_N - V * phi_D) / D with V = N_hat / D_hat.\n",
    "    \"\"\"\n",
    "    D_plug = D_hat(data, models)\n",
    "    N_plug = N_hat(d1, d2, data, models)\n",
    "    V_plug = N_plug / D_plug\n",
    "\n",
    "    centered_D = phi_D(data, models) - D_plug\n",
    "    centered_N = phi_N(d1, d2, data, models) - N_plug\n",
    "\n",
    "    return (centered_N - V_plug * centered_D) / D_plug\n",
    "\n",
    "def var_MR(d1, d2, data, models, trim=0.0):\n",
    "    \"\"\"\n",
    "    Trimmed mean of the squared eif_MR values.\n",
    "\n",
    "    `trim` is passed as both the lower and the upper limit to\n",
    "    scipy.stats.mstats.trimmed_mean, so that outliers can be removed.\n",
    "    \"\"\"\n",
    "    squared_eif = eif_MR(d1, d2, data, models)**2\n",
    "    return sp.stats.mstats.trimmed_mean(squared_eif, (trim, trim))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5beb25d7-3d75-4fb5-8d34-791c020bfacb",
   "metadata": {},
   "outputs": [],
   "source": [
    "XX_obs2k = genData(2000, 5)['obs']\n",
    "d1_2k = reg1(*eta_true[:2], XX_obs2k)\n",
    "d2_2k = reg2(*eta_true[2:], XX_obs2k)\n",
    "models_2k = fit_models(XX_obs2k)\n",
    "\n",
    "V_MR(d1_2k, d2_2k, XX_obs2k, models_2k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7407a057-6bd1-4608-a198-36a07b3329ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "var_MR(d1_2k, d2_2k, XX_obs2k, models_2k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50e0f109-1de5-4557-8d75-967f0c43dc24",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "Vs_2k = np.empty(1000)\n",
    "for i in tqdm(range(1000)):\n",
    "    XX_obs2ki = genData(2000, 5)['obs']\n",
    "    d1_2ki = reg1(*eta_true[:2], XX_obs2ki)\n",
    "    d2_2ki = reg2(*eta_true[2:], XX_obs2ki)\n",
    "    models_2ki = fit_models(XX_obs2ki)\n",
    "    \n",
    "    Vs_2k[i] = V_MR(d1_2ki, d2_2ki, XX_obs2ki, models_2ki)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1b5e322-a011-4f28-a29f-dea02abe5ede",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "Vs_boot_2k = np.empty(500)\n",
    "for i in tqdm(range(500)):\n",
    "    XX_obs2ki = XX_obs2k.iloc[np.random.choice(len(XX_obs2k), size=len(XX_obs2k), \n",
    "                                               replace=True)].reset_index(drop=True)\n",
    "    \n",
    "    d1_2ki = reg1(*eta_true[:2], XX_obs2ki)\n",
    "    d2_2ki = reg2(*eta_true[2:], XX_obs2ki)\n",
    "    models_2ki = fit_models(XX_obs2ki)\n",
    "    \n",
    "    Vs_boot_2k[i] = V_MR(d1_2ki, d2_2ki, XX_obs2ki, models_2ki)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fececb70-7101-4843-97f6-8f3fbed838cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(3,5))\n",
    "\n",
    "sns.stripplot(Vs_2k, s=2.5, alpha=.5, c=\"gray\")\n",
    "plt.errorbar(-.05, V_MR(d1_2k, d2_2k, XX_obs2k, models_2k),\n",
    "             1.96 * np.sqrt(var_MR(d1_2k, d2_2k, XX_obs2k, models_2k) / 2000),\n",
    "             capsize=5, elinewidth=2, c=\"C3\", label=\"EIF\")\n",
    "plt.errorbar(0.05, V_MR(d1_2k, d2_2k, XX_obs2k, models_2k),\n",
    "             1.96 * Vs_boot_2k.std(),\n",
    "             capsize=5, elinewidth=2, c=\"C1\", label=\"bootstrap M=500\")\n",
    "\n",
    "plt.legend()\n",
    "plt.xticks([0], [\"repeat=1000\"])\n",
    "plt.title(\"95% Confidence Interval, N=2000\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bec9a442-bd59-499d-bd99-b5d3a5b4d2a9",
   "metadata": {},
   "source": [
    "### sim"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7a8295c-a4fb-4ec6-a65c-a51a592e99bf",
   "metadata": {},
   "source": [
    "Save all correct models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4bad7ac-2b57-4e60-b25c-8b624bc980e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "Ns = [1000, 2000, 5000]\n",
    "M = 500\n",
    "\n",
    "all_models_correct = []\n",
    "\n",
    "for N in Ns:\n",
    "    model_N = []\n",
    "    for m in tqdm(range(M)):\n",
    "        fpath = f\"./simdata/data_NM_{N:04d}_{m:03d}_{eta2:01d}.pkl\"\n",
    "        dat_obs = read_pkl(fpath)['obs']\n",
    "\n",
    "        model_N.append( fit_models(dat_obs) )\n",
    "    all_models_correct.append( model_N )\n",
    "\n",
    "write_pkl(all_models_correct, 'all_models_correct_nm.pkl')\n",
    "# ```"
   ]
  },
  {
   "cell_type": "raw",
   "id": "388c437a-b5d0-4bd5-b613-b88ccd91e5f3",
   "metadata": {},
   "source": [
    "all_models_correct = read_pkl('all_models_correct_nm.pkl')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6faee159-d301-4cb9-aff4-ec8cb84beceb",
   "metadata": {},
   "source": [
    "Also computing \"true\" value using average of 10 large (1M) samples here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "005d9628-aecd-4895-ae1a-48b436796d05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ```python\n",
    "Vs_true = np.empty(20)\n",
    "for _ in tqdm(range(20)):\n",
    "    dd = genData(N=1_000_000, eta2=5)\n",
    "    tmh = true_models(dd['obs'])\n",
    "    \n",
    "    d1dh = reg1(*eta_true[:2], dd['obs'])\n",
    "    d2dh = reg2(*eta_true[2:], dd['obs'])\n",
    "    \n",
    "    Vs_true[_] = V_plugin(d1dh, d2dh, dd['obs'], tmh)\n",
    "\n",
    "write_pkl(Vs_true, \"Vs_true_largesample.pkl\")\n",
    "# ```"
   ]
  },
  {
   "cell_type": "raw",
   "id": "19e70479-aa89-461d-888f-fec72d6d807a",
   "metadata": {},
   "source": [
    "Vs_true = read_pkl(\"Vs_true_largesample.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a0d4453-9ac4-43d7-95b2-060d8903cf59",
   "metadata": {},
   "outputs": [],
   "source": [
    "Vs_true.mean(), Vs_true.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be79176b-9653-4307-8645-9413ab121a7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ```python\n",
    "all_var = []\n",
    "for i in range(len(Ns)):\n",
    "    N = Ns[i]\n",
    "    var_N = []\n",
    "    for m in tqdm(range(M)):\n",
    "        fpath = f\"./simdata/data_NM_{N:04d}_{m:03d}_{eta2:01d}.pkl\"\n",
    "        dat_obs = read_pkl(fpath)['obs']\n",
    "        model_m = all_models_correct[i][m]\n",
    "\n",
    "        d1_m = reg1(*eta_true[:2], dat_obs)\n",
    "        d2_m = reg2(*eta_true[2:], dat_obs)\n",
    "\n",
    "        var_N.append( np.sqrt(var_MR(d1_m, d2_m, dat_obs, model_m) / N) )\n",
    "\n",
    "    all_var.append( var_N )\n",
    "\n",
    "all_var = np.array(all_var)\n",
    "write_pkl(all_var, \"var_eif.pkl\")\n",
    "# ```"
   ]
  },
  {
   "cell_type": "raw",
   "id": "9138721f-efe7-4f8c-bee3-8484f7758089",
   "metadata": {},
   "source": [
    "all_var = read_pkl(\"var_eif.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a26b8c85-030d-4ead-ba20-438686bb151e",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_var[0].mean(), all_var[1].mean(), all_var[2].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f0da160-f6e9-465f-ad8c-47ed9a3b3b6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "V_true_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16fd5b2e-8db0-4f42-970a-217466d2cec8",
   "metadata": {},
   "outputs": [],
   "source": [
    "V_15[:, :, 0].mean(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddd6bf0e-bce0-47e1-b4d0-06b47f58bbf3",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(3, 4))\n",
    "\n",
    "ax1 = plt.subplot(131)\n",
    "L_eif_1k = V_15[:, 0, 0] - 1.96 * all_var[0]\n",
    "U_eif_1k = V_15[:, 0, 0] + 1.96 * all_var[0]\n",
    "sns.barplot([((Vs_true.mean() - L_eif_1k > 0) & (U_eif_1k - Vs_true.mean() > 0)).mean(0)])\n",
    "plt.xticks(range(1), ['$\\\\phi$'])\n",
    "plt.ylim(0.4, 1.01)\n",
    "plt.title(f\"N={Ns[0]}\")\n",
    "\n",
    "ax2 = plt.subplot(132)\n",
    "ax2.sharey(ax1)\n",
    "L_eif_2k = V_15[:, 1, 0] - 1.96 * all_var[1]\n",
    "U_eif_2k = V_15[:, 1, 0] + 1.96 * all_var[1]\n",
    "sns.barplot([((Vs_true.mean() - L_eif_2k > 0) & (U_eif_2k - Vs_true.mean() > 0)).mean(0)])\n",
    "plt.xticks(range(1), ['$\\\\phi$'])\n",
    "# plt.ylim(0.5, 1.01)\n",
    "plt.title(f\"N={Ns[1]}\")\n",
    "\n",
    "ax3 = plt.subplot(133)\n",
    "ax3.sharey(ax1)\n",
    "L_eif_5k = V_15[:, 2, 0] - 1.96 * all_var[2]\n",
    "U_eif_5k = V_15[:, 2, 0] + 1.96 * all_var[2]\n",
    "sns.barplot([((Vs_true.mean() - L_eif_5k > 0) & (U_eif_5k - Vs_true.mean() > 0)).mean(0)])\n",
    "plt.xticks(range(1), ['$\\\\phi$'])\n",
    "# plt.ylim(0.5, 1.01)\n",
    "plt.title(f\"N={Ns[2]}\")\n",
    "\n",
    "plt.setp(ax2.get_yticklabels(), visible=False)\n",
    "plt.setp(ax3.get_yticklabels(), visible=False)\n",
    "\n",
    "plt.suptitle(f\"Coverage rate (analytical, M={M})\")\n",
    "plt.tight_layout();\n",
    "\n",
    "(((Vs_true.mean() - L_eif_1k > 0) & (U_eif_1k - Vs_true.mean() > 0)).mean(0), \n",
    " ((Vs_true.mean() - L_eif_2k > 0) & (U_eif_2k - Vs_true.mean() > 0)).mean(0),\n",
    " ((Vs_true.mean() - L_eif_5k > 0) & (U_eif_5k - Vs_true.mean() > 0)).mean(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b1c173a-008d-4363-8144-3932c7281165",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1413058b-ff34-487c-a835-3da0eb40e05b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "33690ea2",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "# OPL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ae9d9d9-8b8c-46c2-89d4-1264b4e54a87",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "438fa8ff-ed65-49ee-b40b-30c675d619c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def norm1check1(etas):\n",
    "    eta1s = etas[:2]\n",
    "    return np.linalg.norm(eta1s, 2)\n",
    "\n",
    "def norm1check2(etas):\n",
    "    eta2s = etas[2:]\n",
    "    return np.linalg.norm(eta2s, 2)\n",
    "    \n",
    "norm1const1 = sp.optimize.NonlinearConstraint(norm1check1, 0., 1.)\n",
    "norm1const2 = sp.optimize.NonlinearConstraint(norm1check2, 0., 1.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47807c7f-1c32-4cc6-a9a4-9cb953bf7951",
   "metadata": {},
   "outputs": [],
   "source": [
    "Decision = namedtuple('decision', ['d1', 'd2'])\n",
    "\n",
    "def decision(reg1, reg2, ga_sol, data):\n",
    "    \"\"\"\n",
    "    Makes decisions based on decision rules and GA solution.\n",
    "    \n",
    "    Args:\n",
    "      reg1: Function for the first decision rule.\n",
    "      reg2: Function for the second decision rule.\n",
    "      ga_sol: GA solution.\n",
    "      data: Data to use for decision making.\n",
    "    \n",
    "    Returns:\n",
    "      A dictionary containing d1 and d2 decisions.\n",
    "    \"\"\"\n",
    "    # Make decisions using the estimated parameters\n",
    "    d1 = reg1(*ga_sol[:2], data)\n",
    "    d2 = reg2(*ga_sol[2:], data)\n",
    "    \n",
    "    return Decision(d1, d2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af319190",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "def decision_opt(df_as):\n",
    "    \"\"\"\n",
    "    Creates a new column in the DataFrame with the integer representation of the maximum subscript of y,\n",
    "    separated into individual digits.\n",
    "    \n",
    "    Args:\n",
    "      df: The input DataFrame with always-survivors.\n",
    "    \n",
    "    Returns:\n",
    "      The modified DataFrame with the new column.\n",
    "    \"\"\"    \n",
    "    # optimal decision\n",
    "    d1_opt = df_as[['y_00', 'y_01', 'y_10', 'y_11']].idxmax(axis=1).str.get(2).astype(int).values\n",
    "    d2_opt = df_as[['y_00', 'y_01', 'y_10', 'y_11']].idxmax(axis=1).str.get(3).astype(int).values\n",
    "\n",
    "    return Decision(d1_opt, d2_opt)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "959dce0a-78b8-4f54-a9ee-ac7be903d0bc",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "source": [
    "```python\n",
    "d_opt = decision_opt(XX_as15)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3860da7b-bc80-40a2-9bf5-a5f13e93406c",
   "metadata": {},
   "source": [
    "### V_plugin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b628c7f4-4d24-40a8-aa53-f234341ff24a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def obj_plugin(etas, reg1, reg2, data, models):\n",
    "    # Estimate decision rules\n",
    "    d1 = reg1(*etas[:2], data)\n",
    "    d2 = reg2(*etas[2:], data)\n",
    "    \n",
    "    # Calculate MR estimator (replace V.MR with your implementation)\n",
    "    value = V_plugin(d1, d2, data, models)\n",
    "    \n",
    "    return -value"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29236fcb-eba3-46c9-aa0d-fc19300388f1",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "source": [
    "```python\n",
    "%%time\n",
    "with ThreadPool(processes=mp.cpu_count()-1) as pool:\n",
    "    res_pg = sp.optimize.differential_evolution(obj_plugin, bounds=[(-1., 1.)]*len(eta_true),\n",
    "                                                args=[reg1, reg2, XX_obs15, models],\n",
    "                                                workers=pool.map,\n",
    "                                                constraints=norm1const)\n",
    "\n",
    "d_pg = decision(reg1, reg2, res_pg.x, XX_obs15)\n",
    "# d_fixed = decision(reg1, reg2, eta_true, XX_obs15)\n",
    "\n",
    "d_pg.d1[XX_as15.index], d_pg.d2[XX_as15.index]\n",
    "\n",
    "# Percentage of correct decision\n",
    "\n",
    "np.round(\n",
    "    [(d_pg.d1[XX_as15.index] == d_opt.d1).mean(),\n",
    "     (d_pg.d2[XX_as15.index] == d_opt.d2).mean()],\n",
    "    4\n",
    ") * 100\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6a21964",
   "metadata": {},
   "source": [
    "### V_MR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27260f9c-4cdf-4026-8b2f-dceb5915a82e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def obj_sdmr(etas, reg1, reg2, data, models):\n",
    "    \"\"\"\n",
    "    Calculates the SDMR estimator.\n",
    "    \n",
    "    Args:\n",
    "      etas: Parameters for decision rules.\n",
    "      reg1, reg2: Functions to estimate decision rules.\n",
    "      data: pandas DataFrame containing the data.\n",
    "      models: Dictionary containing fitted nuisance models.\n",
    "      apply_penalty: Boolean flag indicating whether to apply the constraints.\n",
    "    \n",
    "    Returns:\n",
    "      A float representing the SDMR estimator.\n",
    "    \"\"\"\n",
    "    # Estimate decision rules\n",
    "    d1 = reg1(*etas[:2], data)\n",
    "    d2 = reg2(*etas[2:], data)\n",
    "    \n",
    "    # Calculate MR estimator (replace V.MR with your implementation)\n",
    "    value = V_MR(d1, d2, data, models)\n",
    "    \n",
    "    # # Apply penalty (L2 norm)\n",
    "    # if apply_penalty:\n",
    "    #     # Penalty term (adjust if needed)\n",
    "    #     pen = np.sqrt(np.finfo(float).max)\n",
    "        \n",
    "    #     # Create lists for eta1s and eta2s\n",
    "    #     eta1s = etas[:2]\n",
    "    #     eta2s = etas[2:]\n",
    "        \n",
    "    #     # Calculate penalty\n",
    "    #     penalty = np.abs(np.sum(np.square(eta1s)) - 1) * np.abs(np.sum(np.square(eta2s)) - 1) * pen\n",
    "        \n",
    "    #     # Apply penalty to value\n",
    "    #     value -= penalty\n",
    "    \n",
    "    return -value"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "720347bb-0e2e-4662-9243-470b60b44ae7",
   "metadata": {},
   "source": [
    "```python\n",
    "%%time\n",
    "with ThreadPool(processes=mp.cpu_count()-1) as pool:\n",
    "    res = sp.optimize.differential_evolution(obj_sdmr, bounds=[(-1., 1.)]*len(eta_true),\n",
    "                                             args=[reg1, reg2, XX_obs15, models],\n",
    "                                             workers=pool.map,\n",
    "                                             constraints=norm1const)\n",
    "\n",
    "eta_true, res.x\n",
    "\n",
    "d_sdmr = decision(reg1, reg2, res.x, XX_obs15)\n",
    "# d_fixed = decision(reg1, reg2, eta_true, XX_obs15)\n",
    "\n",
    "d_sdmr.d1[XX_as15.index], d_sdmr.d2[XX_as15.index]\n",
    "\n",
    "\n",
    "# Percentage of correct decision\n",
    "np.round(\n",
    "    [(d_sdmr.d1[XX_as15.index] == d_opt.d1).mean(),\n",
    "     (d_sdmr.d2[XX_as15.index] == d_opt.d2).mean()],\n",
    "    4\n",
    ") * 100\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4af03d7d-0aab-4db9-8406-669636a9ebd1",
   "metadata": {},
   "source": [
    "### Optimal linear rules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e07ac9d-c48d-4232-8222-bb12d3abd60d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def obj_linear(etas, reg1, reg2, data, tm):\n",
    "    \"\"\"\n",
    "    Calculates oracle value from optimal linear decision rules.\n",
    "    \n",
    "    Args:\n",
    "      etas: Parameters for decision rules.\n",
    "      reg1, reg2: Functions to estimate decision rules.\n",
    "      data: pandas DataFrame containing the obs data.\n",
    "    \n",
    "    Returns:\n",
    "      A float representing the SDMR estimator.\n",
    "    \"\"\"\n",
    "    # Estimate decision rules\n",
    "    d1 = reg1(*etas[:2], data)\n",
    "    d2 = reg2(*etas[2:], data)\n",
    "    \n",
    "    # compute the true AS Value\n",
    "    value = V_plugin(d1, d2, data, tm)\n",
    "    \n",
    "    return -value"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93c6fe86-006b-4bf9-b6c3-c0587302a81a",
   "metadata": {},
   "source": [
    "```ipython\n",
    "%%time\n",
    "with ThreadPool(processes=mp.cpu_count()-1) as pool:\n",
    "    res_opt = sp.optimize.differential_evolution(obj_linear, bounds=[(-1., 1.)]*len(eta_true),\n",
    "                                                 args=[reg1, reg2, data_15['as']],\n",
    "                                                 workers=pool.map, popsize=50,\n",
    "                                                 constraints=[norm1const1, norm1const2])\n",
    "```\n",
    "\n",
    "```\n",
    "CPU times: total: 37.4 s\n",
    "Wall time: 1min 42s\n",
    "\n",
    "> res_opt.x\n",
    "array([ 0.16369132,  0.81932788,  0.64243744, -0.16940193,  0.13017567 ])\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b577085-9c74-49c7-8f6c-c6937d50ffb6",
   "metadata": {},
   "source": [
    "### True linear optimal decision rule (w/ large sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb1d636a-a4d1-43fc-af3a-962fa10e8c97",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "def get_opt_pi(df_large, eta2=5):\n",
    "    \"\"\"\n",
    "    Simulates data and calculates estimators under different settings.\n",
    "    \n",
    "    Args:\n",
    "      m: Number of simulations.\n",
    "      \n",
    "    Returns:\n",
    "      A namedtuple containing simulation results:\n",
    "          value_hat: Matrix of estimated values (N / D) for different settings.\n",
    "          Nhats: Matrix of N estimates for different settings.\n",
    "          Dhats: Matrix of D estimates for different settings.\n",
    "    \"\"\"\n",
    "\n",
    "    # Separate observed data and all possible outcomes\n",
    "    XX_obs = df_large['obs']\n",
    "    \n",
    "    # Fit nuisance models\n",
    "    models = fit_models(XX_obs)\n",
    "    tm = true_models(XX_obs)\n",
    "\n",
    "    with ThreadPool(processes=15) as pool:\n",
    "        res_linear = sp.optimize.differential_evolution(obj_linear, bounds=[(-1., 1.)]*5,\n",
    "                                                        args=[reg1, reg2, XX_obs, tm],\n",
    "                                                        workers=pool.map,\n",
    "                                                        constraints=[norm1const1, norm1const2])\n",
    "\n",
    "    # return Result(value_hat, decisions)\n",
    "    return res_linear.x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3144a33c-4b5f-4e53-838c-94d89e7988ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ```python\n",
    "pi_opt = get_opt_pi(df_large)\n",
    "write_pkl(pi_opt, \"linear_opt_rule.pkl\")\n",
    "# ```"
   ]
  },
  {
   "cell_type": "raw",
   "id": "176a934f-0a16-4da5-b4d9-dc6c7fe3ac12",
   "metadata": {},
   "source": [
    "pi_opt = read_pkl(\"linear_opt_rule.pkl\")  # opt rule from `V_true_id()`\n",
    "pi_opt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9377fa18-0879-4ca6-be50-7abe45d8b776",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "source": [
    "### Simulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7559a55",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "def simulate(m, eta2):\n",
    "    \"\"\"\n",
    "    Simulates data and calculates estimators under different settings.\n",
    "    \n",
    "    Args:\n",
    "      m: Number of simulations.\n",
    "      \n",
    "    Returns:\n",
    "      A namedtuple containing simulation results:\n",
    "          value_hat: Matrix of estimated values (N / D) for different settings.\n",
    "          Nhats: Matrix of N estimates for different settings.\n",
    "          Dhats: Matrix of D estimates for different settings.\n",
    "    \"\"\"\n",
    "    \n",
    "    # Define namedtuple for results\n",
    "    Result = namedtuple('Result', ['value_hat', 'decision'])\n",
    "    \n",
    "    # Initialize results matrices\n",
    "    value_hat = np.zeros((len(Ns), 3+1))\n",
    "    decisions = []\n",
    "\n",
    "    beta_ = []\n",
    "    for i in range(len(Ns)):\n",
    "        # Simulate data\n",
    "        N = Ns[i]\n",
    "\n",
    "        fpath = f\"./simdata/data_NM_{N:04d}_{m:03d}_{eta2:01d}.pkl\"\n",
    "        if os.path.exists(fpath):\n",
    "            dat = read_pkl(fpath)\n",
    "        else:\n",
    "            dat = genData(N, eta2)\n",
    "            write_pkl(dat, fpath)\n",
    "        \n",
    "        # Separate observed data and all possible outcomes\n",
    "        XX_obs = dat['obs']\n",
    "        XX_all = dat['all']\n",
    "        XX_as = dat['as']\n",
    "        \n",
    "        # Fit nuisance models\n",
    "        models = fit_models(XX_obs)  # Replace ... with arguments\n",
    "\n",
    "        with ThreadPool(processes=mp.cpu_count()-1) as pool:\n",
    "            res_mr = sp.optimize.differential_evolution(obj_sdmr, bounds=[(-1., 1.)]*5,\n",
    "                                            args=(reg1, reg2, XX_obs, models),\n",
    "                                            # workers=pool.map,\n",
    "                                          x0=np.random.uniform(high=1, low=-1, size=5),\n",
    "                                          # method=\"SLSQP\",\n",
    "                                            constraints=[norm1const1, norm1const2])\n",
    "            res_pg = sp.optimize.differential_evolution(obj_plugin, bounds=[(-1., 1.)]*5,\n",
    "                                            args=(reg1, reg2, XX_obs, models),\n",
    "                                            # workers=pool.map,\n",
    "                                          x0=np.random.uniform(high=1, low=-1, size=5),\n",
    "                                          # method=\"SLSQP\",\n",
    "                                            constraints=[norm1const1, norm1const2])\n",
    "        \n",
    "        beta_.append([res_mr.x, res_pg.x])\n",
    "\n",
    "    # return Result(value_hat, decisions)\n",
    "    return beta_"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2987e9bf-b3a5-404b-a214-d172dbcfc055",
   "metadata": {},
   "source": [
    "Initialized with `pi_opt + uniform(-.25, .25)`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d393824-c521-400c-b6d1-d73836a345c2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# ```python\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "beta_opt = []\n",
    "for m in tqdm(range(M)):\n",
    "    beta_opt.append( simulate(m, eta2) )\n",
    "\n",
    "warnings.resetwarnings()\n",
    "\n",
    "beta_opt = np.array(beta_opt)\n",
    "write_pkl(beta_opt, \"mr,pg_opt.pkl\")\n",
    "# ```"
   ]
  },
  {
   "cell_type": "raw",
   "id": "d399d648-6334-492e-a4e0-f60246889968",
   "metadata": {},
   "source": [
    "beta_opt = read_pkl(\"mr,pg_opt.pkl\")"
   ]
  },
  {
   "cell_type": "raw",
   "id": "2c6728c2-13d6-44bd-9bd2-1b2a678326c0",
   "metadata": {},
   "source": [
    "beta_opt = []\n",
    "for m in range(M):\n",
    "    try:\n",
    "        fpath = f\"./simresult/mr,pg_opt_sim1_{m:03d}.pkl\"\n",
    "        beta_opt_m = read_pkl(fpath)\n",
    "        beta_opt_m = [x[0] for x in beta_opt_m]  # only MR. no PG result\n",
    "        beta_opt.append(beta_opt_m)\n",
    "    except:\n",
    "        print(m, end=', ')\n",
    "\n",
    "beta_opt = np.array(beta_opt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3418a50-466e-451c-bef0-04d69b5a6dab",
   "metadata": {},
   "outputs": [],
   "source": [
    "beta_opt.shape  # (n_dataset, n_Ns, n_eta)  # only MR. no PG result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44bdfa53-b8e3-43b9-bed3-612f8b2edbfd",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# ```python\n",
    "Vhat_opl_single = []\n",
    "\n",
    "for N in Ns:\n",
    "    for m in tqdm(range(M)):\n",
    "        # load data and beta_opt\n",
    "        fpath = f\"./simdata/data_NM_{N:04d}_{m:03d}_{eta2:01d}.pkl\"\n",
    "        dat_obs = read_pkl(fpath)['obs']\n",
    "        beta_opt_m = beta_opt[m]\n",
    "        \n",
    "        # decision rule params\n",
    "        if N == Ns[0]:\n",
    "            beta_mr_dat = beta_opt_m[0]  # N==1000\n",
    "        elif N == Ns[1]:\n",
    "            beta_mr_dat = beta_opt_m[1]  # N==2000\n",
    "        else:\n",
    "            beta_mr_dat = beta_opt_m[2]  # N==5000\n",
    "        \n",
    "        # fixed decisions\n",
    "        d1mr_dat = reg1(*beta_mr_dat[:2], dat_obs)\n",
    "        d2mr_dat = reg2(*beta_mr_dat[2:], dat_obs)\n",
    "        \n",
    "        # d1pg_dat = reg1(*beta_pg_dat[:2], dat_obs)\n",
    "        # d2pg_dat = reg2(*beta_pg_dat[2:], dat_obs)\n",
    "        \n",
    "        # Fit nuisance models\n",
    "        model_dat = fit_models(dat_obs)  # Replace ... with arguments\n",
    "        \n",
    "        # MR estimator\n",
    "        # Plugin estimators (replace function names with yours)\n",
    "        vmr_dat = V_MR(d1mr_dat, d2mr_dat, dat_obs, model_dat)\n",
    "        # vpg_dat = V_plugin(d1pg_dat, d2pg_dat, dat_obs, model_dat)\n",
    "\n",
    "        Vhat_opl_single.append([vmr_dat, None])\n",
    "\n",
    "Vhat_opl_single = np.array(Vhat_opl_single).reshape(len(Ns), M, -1)  # Ns, M, mrpg\n",
    "write_pkl(Vhat_opl_single, \"Vhat_opl_single_nm.pkl\")\n",
    "# ```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a87f3e79-3700-41e2-b830-b8f133f94599",
   "metadata": {},
   "source": [
    "```\n",
    "100%|█████████████████████████████████████████| 500/500 [02:46<00:00,  2.99it/s]\n",
    "100%|█████████████████████████████████████████| 500/500 [03:05<00:00,  2.70it/s]\n",
    "100%|█████████████████████████████████████████| 500/500 [04:12<00:00,  1.98it/s]\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "101a19d0-650f-4713-8189-434a6760dc1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ```python\n",
    "V_opl_single = []\n",
    "\n",
    "for N in Ns:\n",
    "    for m in tqdm(range(M)):\n",
    "        # load data and beta_opt\n",
    "        # fpath = f\"./simdata/data_NM_{N:04d}_{m:03d}_{eta2:01d}.pkl\"\n",
    "        # dat_obs = read_pkl(fpath)['obs']\n",
    "        dat_obs = df_large['obs']\n",
    "        beta_opt_m = beta_opt[m]\n",
    "        \n",
    "        # decision rule params\n",
    "        if N == Ns[0]:\n",
    "            beta_mr_dat = beta_opt_m[0]  # N==1000\n",
    "        elif N == Ns[1]:\n",
    "            beta_mr_dat = beta_opt_m[1]  # N==2000\n",
    "        else:\n",
    "            beta_mr_dat = beta_opt_m[2]  # N==5000\n",
    "        \n",
    "        # fixed decisions\n",
    "        d1mr_dat = reg1(*beta_mr_dat[:2], dat_obs)\n",
    "        d2mr_dat = reg2(*beta_mr_dat[2:], dat_obs)\n",
    "        \n",
    "        # d1pg_dat = reg1(*beta_pg_dat[:2], dat_obs)\n",
    "        # d2pg_dat = reg2(*beta_pg_dat[2:], dat_obs)\n",
    "        \n",
    "        # # Fit nuisance models\n",
    "        # tm_dat = true_models(dat_obs)\n",
    "        \n",
    "        # MR estimator\n",
    "        # 'true' value under best learned decision rules\n",
    "        vmr_dat = V_plugin(d1mr_dat, d2mr_dat, dat_obs, tm)\n",
    "        # vpg_dat = V_plugin(d1pg_dat, d2pg_dat, dat_obs, tm_dat)\n",
    "\n",
    "        V_opl_single.append([vmr_dat, None])\n",
    "\n",
    "\n",
    "V_opl_single = np.array(V_opl_single).reshape(len(Ns), M, -1)  # Ns, M, mrpg\n",
    "write_pkl(V_opl_single, \"V_opl_single_nm.pkl\")\n",
    "# ```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14e35f10-9d58-4145-80ac-f7c0987d9883",
   "metadata": {},
   "source": [
    "```\n",
    "100%|█████████████████████████████████████████| 500/500 [45:35<00:00,  5.47s/it]\n",
    "100%|█████████████████████████████████████████| 500/500 [39:28<00:00,  4.74s/it]\n",
    "100%|█████████████████████████████████████████| 500/500 [35:17<00:00,  4.24s/it]\n",
    "```"
   ]
  },
  {
   "cell_type": "raw",
   "id": "682deac8-526d-4434-bdd6-5f17034f58ff",
   "metadata": {},
   "source": [
    "Vhat_opl_single = read_pkl(\"Vhat_opl_single_nm.pkl\")\n",
    "V_opl_single = read_pkl(\"V_opl_single_nm.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7df54f3a-4337-444c-8b73-6a97303c265c",
   "metadata": {},
   "outputs": [],
   "source": [
    "Vhat_opl_single.shape, V_opl_single.shape  # Ns, M, mrpg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52a75322-5b09-4414-a111-c50f1769aedd",
   "metadata": {},
   "outputs": [],
   "source": [
    "Vall_opl_single = np.c_[Vhat_opl_single[:,:,[0]], V_opl_single[:,:,[0]]]\n",
    "Vall_opl_single.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0cf7610-486f-4cda-95ea-478b327bfb6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "d1_opt = reg1(*pi_opt[:2], df_large[\"obs\"])\n",
    "d2_opt = reg2(*pi_opt[2:], df_large[\"obs\"])\n",
    "V_true_id_opt = V_plugin(d1_opt, d2_opt, df_large[\"obs\"], tm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20dedfd6-020c-4918-a69f-2970dfcdeff9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Box plots of the two columns of Vall_opl_single, one panel per training size;\n",
    "# the red horizontal line marks V_true_id_opt.\n",
    "value_labels = [r'$\\widehat{V}^{MR}(\\widehat{\\beta})$',\n",
    "                r'${V}(\\widehat{\\beta})$']\n",
    "outlier_style = {'marker': 'x', 'markersize': 3, 'markeredgewidth': .5, 'alpha': .7}\n",
    "\n",
    "plt.figure(figsize=(5,3))\n",
    "\n",
    "panels = []\n",
    "for idx in range(3):\n",
    "    ax = plt.subplot(1, 3, idx + 1)\n",
    "    if panels:\n",
    "        ax.sharey(panels[0])  # every panel shares the y-axis of the first one\n",
    "    panels.append(ax)\n",
    "    plt.boxplot(Vall_opl_single[idx,:,:], tick_labels=value_labels,\n",
    "                flierprops=outlier_style)\n",
    "    plt.title(f\"N={Ns[idx]}\")\n",
    "    plt.axhline(V_true_id_opt, color='r', ls=\"-\", lw=.75)\n",
    "\n",
    "ax1, ax2, ax3 = panels  # keep the per-panel axis handles under their original names\n",
    "\n",
    "plt.suptitle(\"V (=N/D)\")\n",
    "plt.tight_layout();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d34219e6-a350-4ea9-bd4a-5a293186c4c6",
   "metadata": {},
   "source": [
    "## Analytic confidence interval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8b621a3-b0a8-4df6-957e-93b95518a846",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ```python\n",
    "# For every training size in Ns and every replicate m, store\n",
    "# sqrt(var_MR(...) / N); the result has one row per training size.\n",
    "all_var_eif = []\n",
    "for i, N in enumerate(Ns):\n",
    "    var_N = []\n",
    "    for m in tqdm(range(M)):\n",
    "        fpath = f\"./simdata/data_NM_{N:04d}_{m:03d}_{eta2:01d}.pkl\"\n",
    "        dat_obs = read_pkl(fpath)['obs']\n",
    "        # NOTE(review): always reads slot 2 of all_models_correct, whatever the\n",
    "        # training size is -- confirm that this should not be indexed by i.\n",
    "        model_m = all_models_correct[2][m]\n",
    "\n",
    "        # decision rule params: row i of beta_opt[m] belongs to training size Ns[i]\n",
    "        beta_opt_m = beta_opt[m]\n",
    "        beta_mr_dat = beta_opt_m[i]\n",
    "\n",
    "        # fixed decisions\n",
    "        d1mr_dat = reg1(*beta_mr_dat[:2], dat_obs)\n",
    "        d2mr_dat = reg2(*beta_mr_dat[2:], dat_obs)\n",
    "\n",
    "        var_N.append( np.sqrt(var_MR(d1mr_dat, d2mr_dat, dat_obs, model_m) / N) )\n",
    "\n",
    "    all_var_eif.append( var_N )\n",
    "\n",
    "all_var_eif = np.array(all_var_eif)\n",
    "write_pkl(all_var_eif, \"var_opl_eif.pkl\")\n",
    "# ```"
   ]
  },
  {
   "cell_type": "raw",
   "id": "2485dcd6-f4a2-4f7e-bb82-bdb7b2c4dbf8",
   "metadata": {},
   "source": [
    "all_var_eif = read_pkl(\"var_opl_eif.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ee383ee-33ec-46d5-8395-3e361b512215",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ```python\n",
    "# Plug-in value of the pi_opt rule on freshly generated large samples, one\n",
    "# value per repetition, using the models returned by true_models.\n",
    "n_rep = 20\n",
    "Vs_true_opt = np.empty(n_rep)\n",
    "# `rep` rather than `_`: in IPython `_` is the last-output variable and\n",
    "# assigning to it stops it from being updated.\n",
    "for rep in tqdm(range(n_rep)):\n",
    "    dd = genData(N=1_000_000, eta2=eta2)\n",
    "    tmh = true_models(dd['obs'])\n",
    "\n",
    "    d1dh = reg1(*pi_opt[:2], dd[\"obs\"])\n",
    "    d2dh = reg2(*pi_opt[2:], dd[\"obs\"])\n",
    "\n",
    "    Vs_true_opt[rep] = V_plugin(d1dh, d2dh, dd['obs'], tmh)\n",
    "\n",
    "write_pkl(Vs_true_opt, \"Vs_true_opt_largesample.pkl\")\n",
    "# ```"
   ]
  },
  {
   "cell_type": "raw",
   "id": "23a0de11-1f6b-4514-9054-d5b05a6b235e",
   "metadata": {},
   "source": [
    "Vs_true_opt = read_pkl(\"Vs_true_opt_largesample.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bba7ce1-4904-409c-bdb7-b1b7db248486",
   "metadata": {},
   "outputs": [],
   "source": [
    "Vs_true_opt.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80f507a5-a4b8-4ff1-8082-b7b038afc9ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "Vhat_opl_single[0,:,0].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "112392e2-4d21-4876-92c6-5cebe2e17e27",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5999da64-6223-4870-a68d-896888e1f210",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(3, 4))\n",
    "\n",
    "plt.subplot(131)\n",
    "L_eif_1k = Vhat_opl_single[0,:,0] - 1.96 * all_var_eif[0]\n",
    "U_eif_1k = Vhat_opl_single[0,:,0] + 1.96 * all_var_eif[0]\n",
    "sns.barplot([((Vs_true_opt.mean() - L_eif_1k > 0) & (U_eif_1k - Vs_true_opt.mean() > 0)).mean(0)])\n",
    "plt.xticks(range(1), ['$\\\\phi$'])\n",
    "plt.ylim(0.5, 1.01)\n",
    "plt.title(f\"N={Ns[0]}\")\n",
    "\n",
    "plt.subplot(132)\n",
    "L_eif_2k = Vhat_opl_single[1,:,0] - 1.96 * all_var_eif[1]\n",
    "U_eif_2k = Vhat_opl_single[1,:,0] + 1.96 * all_var_eif[1]\n",
    "sns.barplot([((Vs_true_opt.mean() - L_eif_2k > 0) & (U_eif_2k - Vs_true_opt.mean() > 0)).mean(0)])\n",
    "plt.xticks(range(1), ['$\\\\phi$'])\n",
    "plt.ylim(0.5, 1.01)\n",
    "plt.title(f\"N={Ns[1]}\")\n",
    "\n",
    "plt.subplot(133)\n",
    "L_eif_5k = Vhat_opl_single[2,:,0] - 1.96 * all_var_eif[2]\n",
    "U_eif_5k = Vhat_opl_single[2,:,0] + 1.96 * all_var_eif[2]\n",
    "sns.barplot([((Vs_true_opt.mean() - L_eif_5k > 0) & (U_eif_5k - Vs_true_opt.mean() > 0)).mean(0)])\n",
    "plt.xticks(range(1), ['$\\\\phi$'])\n",
    "plt.ylim(0.5, 1.01)\n",
    "plt.title(f\"N={Ns[2]}\")\n",
    "\n",
    "plt.suptitle(f\"Coverage rate (analytical, M={M})\")\n",
    "plt.tight_layout();\n",
    "\n",
    "(((Vs_true_opt.mean() - L_eif_1k > 0) & (U_eif_1k - Vs_true_opt.mean() > 0)).mean(0), \n",
    " ((Vs_true_opt.mean() - L_eif_2k > 0) & (U_eif_2k - Vs_true_opt.mean() > 0)).mean(0),\n",
    " ((Vs_true_opt.mean() - L_eif_5k > 0) & (U_eif_5k - Vs_true_opt.mean() > 0)).mean(0))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db2c67be-dae4-4a78-91bd-523f756eb7ae",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "source": [
    "## Percentage of correct decisions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65db921f-8dac-403d-90de-c33b044655c3",
   "metadata": {},
   "source": [
    "### ID formula for AS-PCD"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "754f5d41-ef64-4baf-a007-236ea636b64d",
   "metadata": {},
   "source": [
    "$$\n",
    "\\mathbb E[h(X_1) | U_2=1111] = \\mathbb E \\left[ h(X_1) \\frac{ p_1^0(X_1) \\mathbb E\\left( p_2^0(H_1) | X_1, A_1=0, C_1=0, S_1=1 \\right) }{ \\mathbb E\\left[ p_1^0(X_1) \\mathbb E\\left( p_2^0(H_1) | X_1, A_1=0, C_1=0, S_1=1 \\right) \\right] } \\right]\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb1a16cd-a249-4e09-81c0-06021d8b967d",
   "metadata": {},
   "source": [
    "$$\n",
    "h(X_1; \\widehat\\pi, \\pi^*) = \\mathbb E\\left[ \\mathbb1\\left\\{\\widehat{\\pi}(H_1) = \\pi^*(H_1)\\right\\} | X_1, A_1=\\pi_1(X_1), C_1=0, S_1=1\\right]\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\pi(H_1^{\\pi_1}) = \\left( \\pi_1(X_1), \\pi_2\\left(X_1, \\pi_1(X_1), X_2^{\\pi_1(X_1)}\\right) \\right)\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4d0cd8b-d3eb-4550-b3a4-5d1c2c062bad",
   "metadata": {},
   "outputs": [],
   "source": [
    "def h_(d1, d2, d1_opt, d2_opt, data, models):\n",
    "    \"\"\"\n",
    "    Estimate h(X1) by smoothing a weighted agreement indicator over X1.\n",
    "\n",
    "    Builds the per-row target\n",
    "        1{A1 == d1} * 1{C1 == 0} * 1{S1 == 1} / pcs1(d1) * 1{d1 == d1_opt} * 1{d2 == d2_opt},\n",
    "    where pcs1 is the product of the predictions of models['phi1.hat'],\n",
    "    models['K1.hat'] and models['p1.hat'], and regresses it on X1 with a GAM.\n",
    "\n",
    "    Args:\n",
    "        d1, d2: Decisions of the evaluated rule (presumably one entry per row of `data`).\n",
    "        d1_opt, d2_opt: Decisions of the reference rule, compared element-wise with d1, d2.\n",
    "        data: DataFrame with (at least) the columns X1, X2, A1, C1 and S1.\n",
    "        models: Dictionary holding the fitted models 'phi1.hat', 'K1.hat' and 'p1.hat'.\n",
    "\n",
    "    Returns:\n",
    "        The fitted GAM's predictions for every row of `data`.\n",
    "    \"\"\"\n",
    "    def ps1(a1):\n",
    "        # NOTE(review): isinstance(a1, int) is False for NumPy integer scalars,\n",
    "        # so only plain Python ints are broadcast here.\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, len(data.X1))\n",
    "        fitdf = pd.DataFrame({\n",
    "            'X1': data['X1'],\n",
    "            'X2': data['X2'],\n",
    "            'A1': a1\n",
    "        })\n",
    "        pred = models['phi1.hat'].predict(fitdf).values\n",
    "        # Where a1 == 0 use the complement of the model's prediction\n",
    "        pred[(~np.isnan(a1)) & (a1 == 0)] = 1 - pred[(~np.isnan(a1)) & (a1 == 0)]\n",
    "        return pred\n",
    "\n",
    "    def cp1(a1):\n",
    "        fitdf = pd.DataFrame({\n",
    "            'X1': data['X1'], 'X2': data['X2'],\n",
    "            'A1': a1\n",
    "        })\n",
    "        return models['K1.hat'].predict(fitdf).values\n",
    "\n",
    "    def sp1(a1):\n",
    "        fitdf = pd.DataFrame({\n",
    "            'const': 1, 'X1': data.X1, 'A1': a1, \n",
    "        })\n",
    "        return models['p1.hat'].predict(fitdf).values\n",
    "\n",
    "    def pcs1(a1):\n",
    "        # Product of the three stage-1 model predictions for action a1\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1)\n",
    "    \n",
    "    # Fill NA with 0\n",
    "    data_filled = data.copy()\n",
    "    data_filled[np.isnan(data_filled)] = 0\n",
    "    \n",
    "    # The helpers above read the unfilled `data`; only the observed indicators\n",
    "    # below come from the NaN-filled copy.\n",
    "    data_filled[\"target\"] = (data_filled.A1 == d1) * \\\n",
    "                            (data_filled.C1 == 0) * \\\n",
    "                            (data_filled.S1 == 1) / pcs1(d1) * \\\n",
    "                            (d1==d1_opt) * (d2==d2_opt)\n",
    "    # Second fill: zero out any NaN the target computation produced\n",
    "    data_filled[np.isnan(data_filled)] = 0\n",
    "    # data_filled[\"target\"] = m2(d1, d2)\n",
    "    # data_filtered = data_filled[(data_filled.A1 == d1) & \n",
    "    #                             (data_filled.C1 == 0) & \n",
    "    #                             (data_filled.S1 == 1)]\n",
    "\n",
    "    # GAM: parametric part 1 + X1 + X1^2 plus a cubic B-spline smoother on X1 (df=7)\n",
    "    bs = BSplines(data_filled[\"X1\"], df=[3+3+1], degree=[3])\n",
    "    h_model = smf.glmgam(\"target ~ 1 + X1 + I(X1**2)\", data=data_filled, smoother=bs).fit()\n",
    "    # m_m2 = np.zeros(len(data_filled))\n",
    "    # Predict on the same rows the model was fitted on\n",
    "    h_val = h_model.predict(data_filled, exog_smooth=bs.x)\n",
    "    \n",
    "    return h_val"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6ef8170-7b75-41c7-b1a7-6109b84dfe1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def N_pcd(d1, d2, d1_opt, d2_opt, data, models):\n",
    "    \"\"\"\n",
    "    Numerator of the PCD ratio estimator.\n",
    "\n",
    "    Averages, over the rows of `data`, the product of h(X1) (from `h_`),\n",
    "    models['sp1'](0) and m_p200(data, models).\n",
    "\n",
    "    Args:\n",
    "        d1, d2: Decisions of the evaluated rule (forwarded to `h_`).\n",
    "        d1_opt, d2_opt: Decisions of the reference rule (forwarded to `h_`).\n",
    "        data: Observed data, forwarded to `m_p200` and `h_`.\n",
    "        models: Dictionary that provides the callable 'sp1' in addition to\n",
    "            whatever `m_p200` and `h_` read from it.\n",
    "\n",
    "    Returns:\n",
    "        A float: the mean of the per-row products, with NaN products counted as 0.\n",
    "    \"\"\"\n",
    "    # sp1 evaluated at a1 = 0 (presumably p_1^0(X_1) of the ID formula -- confirm)\n",
    "    p10x = models['sp1'](0)\n",
    "\n",
    "    # Conditional mean of p_2^0(H_1)\n",
    "    E_p200 = m_p200(data, models)\n",
    "\n",
    "    # h(X1)\n",
    "    h_val = h_(d1, d2, d1_opt, d2_opt, data, models)\n",
    "\n",
    "    # N-hat calculation (empirical version)\n",
    "    Nhat = h_val * p10x * E_p200\n",
    "\n",
    "    # Rows whose product is undefined contribute 0 to the mean\n",
    "    Nhat[np.isnan(Nhat)] = 0\n",
    "    return np.mean(Nhat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2720495b-8356-403e-97fd-b2a3d64b0e48",
   "metadata": {},
   "outputs": [],
   "source": [
    "def PCD_AS(d1, d2, d1_opt, d2_opt, data, models):\n",
    "    \"\"\"Ratio estimator: N_pcd(d1, d2, d1_opt, d2_opt, data, models) / D_hat(data, models).\"\"\"\n",
    "    numerator = N_pcd(d1, d2, d1_opt, d2_opt, data, models)\n",
    "    denominator = D_hat(data, models)\n",
    "    return numerator / denominator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98546bcf-3719-449f-b36c-df3b68d48f63",
   "metadata": {},
   "source": [
    "### Sim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "567c95b1-d30d-41cc-9d7c-c800991b82bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ```python\n",
    "dd = genData(N=100_000, eta2=eta2)\n",
    "dd_as_obs = dd['obs'].loc[dd[\"as\"].index, :]\n",
    "\n",
    "write_pkl(dd, \"large_obs_PCD.pkl\")\n",
    "write_pkl(dd_as_obs, \"large_as_obs_PCD.pkl\")\n",
    "# ```"
   ]
  },
  {
   "cell_type": "raw",
   "id": "8a480f80-f492-4f6d-9363-2d2008f6d1e1",
   "metadata": {},
   "source": [
    "dd = read_pkl(\"large_obs_PCD.pkl\")\n",
    "dd_as_obs = read_pkl(\"large_as_obs_PCD.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4296f1f9-7703-4874-a7a4-db0e59c24bec",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# ```python\n",
    "# PCD_AS of every learned rule against the pi_opt rule, evaluated on dd['obs'];\n",
    "# one row per replicate and one column per training size.\n",
    "pcd_mr_all = []\n",
    "\n",
    "tm_dd = true_models(data=dd['obs'])\n",
    "\n",
    "# The reference decisions do not depend on the replicate, so compute them once.\n",
    "d1dh = reg1(*pi_opt[:2], dd['obs'])\n",
    "d2dh = reg2(*pi_opt[2:], dd['obs'])\n",
    "\n",
    "for m in tqdm(range(M)):\n",
    "    beta_opt_m = beta_opt[m]\n",
    "\n",
    "    # Row i of beta_opt_m holds the rule learned at training size Ns[i].\n",
    "    pcd_mr = tuple(\n",
    "        PCD_AS(reg1(*beta_opt_m[i, :2], dd['obs']),\n",
    "               reg2(*beta_opt_m[i, 2:], dd['obs']),\n",
    "               d1dh, d2dh, dd['obs'], tm_dd)\n",
    "        for i in range(len(Ns))\n",
    "    )\n",
    "\n",
    "    pcd_mr_all.append(pcd_mr)\n",
    "\n",
    "pcd_mr_all = np.array(pcd_mr_all)\n",
    "\n",
    "write_pkl(pcd_mr_all, 'pcd_mr_all_sim1.pkl')\n",
    "# ```"
   ]
  },
  {
   "cell_type": "raw",
   "id": "4249c282-2652-42b8-af2a-a8f44ecd0ec7",
   "metadata": {},
   "source": [
    "pcd_mr_all = read_pkl('pcd_mr_all_sim1.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7dc78566-84bd-4aa6-af6e-122e6ff17201",
   "metadata": {},
   "outputs": [],
   "source": [
    "pcd_1k_prod = np.c_[\n",
    "    pcd_mr_all[:, 0]\n",
    "]\n",
    "pcd_2k_prod = np.c_[\n",
    "    pcd_mr_all[:, 1]\n",
    "]\n",
    "pcd_5k_prod = np.c_[\n",
    "    pcd_mr_all[:, 2]\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e76094a-67ed-44e1-b9dd-97f8cdfc477f",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_pcd_1k = pd.DataFrame(pcd_1k_prod, columns=[r\"$\\phi$\"])\n",
    "df_pcd_2k = pd.DataFrame(pcd_2k_prod, columns=[r\"$\\phi$\"])\n",
    "df_pcd_5k = pd.DataFrame(pcd_5k_prod, columns=[r\"$\\phi$\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e02539f-d880-4f1e-b49d-48a84d716ddc",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_pcd = np.hstack([df_pcd_1k, df_pcd_2k, df_pcd_5k])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eefb7208-5850-41f7-89c8-9add92ffab0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(3,4))\n",
    "# sns.stripplot(df_pcd, size=3, alpha=.05, color='k')\n",
    "plt.boxplot(df_pcd, positions=range(3),\n",
    "            flierprops={'marker': 'x', 'markersize': 3, \n",
    "                        'markeredgewidth': .5, 'alpha': .7})\n",
    "plt.xticks(range(3), Ns)\n",
    "plt.title(\"Percentage of Correct Decisions \\non Always-Survivors\")\n",
    "# plt.ylim(0.895, 1.005)\n",
    "plt.xlabel(\"Training N\")\n",
    "plt.ylabel(\"%\")\n",
    "plt.yticks(np.arange(0.9, 1.01, 0.02), range(90, 101, 2));"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a23691c6-1fc7-4df0-b976-87637af65dbf",
   "metadata": {},
   "source": [
    "Tested on data of size 100,000."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0f1ea27-4672-4d77-9081-c07b945c1818",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_pcd.mean(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7a3aeaa-1dd3-40f1-a857-3e0ad7cb4b99",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_pcd.std(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0197c630-ed76-480e-be93-99660c13684a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "214035fc-501c-4498-a007-b8ec56ccf19c",
   "metadata": {},
   "source": [
    "### AIPW"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04b5d15e-9669-40dc-84d0-6a101ae9963e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_models_aipcw(data, phi1_false=False, phi2_false=False,\n",
    "                     K1_false=False, K2_false=False,\n",
    "                     p1_false=False, p2_false=False,\n",
    "                     mu2_false=False, Ep2_false=False,\n",
    "                     Emupi_false=False):\n",
    "    \"\"\"\n",
    "    Fit the nuisance models used by the AIPW value estimator.\n",
    "\n",
    "    Each `*_false` flag replaces the default formula of one model with an\n",
    "    alternative formula; the model is still fitted to `data`.\n",
    "\n",
    "    Args:\n",
    "        data: A pandas DataFrame containing the observed data.\n",
    "        phi1_false, phi2_false: Use the alternative formula for the A1 / A2 models.\n",
    "        K1_false, K2_false: Use the alternative formula for the I(1-C1) / I(1-C2) models.\n",
    "        p1_false, p2_false: Use the alternative formula for the S1 / S2 models.\n",
    "        mu2_false: Use the alternative formula for the 'mu2.hat' outcome regression.\n",
    "        Ep2_false, Emupi_false: Accepted but not used.\n",
    "\n",
    "    Returns:\n",
    "        A dictionary of fitted statsmodels results keyed by 'phi1.hat',\n",
    "        'phi2.hat', 'K1.hat', 'K2.hat', 'p1.hat', 'p2.hat', 'mu1.hat' and 'mu2.hat'.\n",
    "    \"\"\"\n",
    "\n",
    "    def fit_binomial(default_formula, alternative_formula, use_alternative):\n",
    "        # Binomial GLM on `data` with whichever of the two formulas the flag selects\n",
    "        formula = alternative_formula if use_alternative else default_formula\n",
    "        return smf.glm(formula, data=data, family=sm.families.Binomial()).fit()\n",
    "\n",
    "    models = {}\n",
    "    models['phi1.hat'] = fit_binomial(\"A1 ~ 1 + I(X1**2)\",\n",
    "                                      \"A1 ~ 0 + X1 + I(X1**2)\", phi1_false)\n",
    "    models['phi2.hat'] = fit_binomial(\"A2 ~ 1 + I(X1**2) + X2 + I(X2**2)\",\n",
    "                                      \"A2 ~ 1 + X2\", phi2_false)\n",
    "    models['K1.hat'] = fit_binomial(\"I(1-C1) ~ 1 + I(X1**2)\",\n",
    "                                    \"I(1-C1) ~ 0 + X1\", K1_false)\n",
    "    models['K2.hat'] = fit_binomial(\"I(1-C2) ~ 1 + X1 + X2 + A2 + A2:X2\",\n",
    "                                    \"I(1-C2) ~ 1 + X2\", K2_false)\n",
    "    models['p1.hat'] = fit_binomial(\"S1 ~ 1 + I(X1**2) + A1 + A1:X1\",\n",
    "                                    \"S1 ~ 0 + A1\", p1_false)\n",
    "    models['p2.hat'] = fit_binomial(\"S2 ~ 1 + X1 + X1:X2 + A1 + A2\",\n",
    "                                    \"S2 ~ 0 + A1 + A2\", p2_false)\n",
    "\n",
    "    # The 'mu1.hat' outcome regression has no alternative formula\n",
    "    models['mu1.hat'] = smf.ols(\"Y ~ 1 + I(X1**2) + A1 + A1:X1\", data=data).fit()\n",
    "\n",
    "    if mu2_false:\n",
    "        mu2_formula = \"Y ~ 1 + X1 + X2\"\n",
    "    else:\n",
    "        mu2_formula = \"Y ~ 1 + X1 + A1 + A1:X1 + I(np.exp(X2)) + A2 + A1:A2 + X2:A2\"\n",
    "    models['mu2.hat'] = smf.ols(mu2_formula, data=data).fit()\n",
    "\n",
    "    return models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0568e2ad-85a7-4e91-b0c7-8d892ad13931",
   "metadata": {},
   "outputs": [],
   "source": [
    "def V_aipcw(d1, d2, data, models):\n",
    "    \"\"\"\n",
    "    Augmented inverse-probability-weighted value estimate of the rule (d1, d2).\n",
    "\n",
    "    The estimate is the sum of three terms:\n",
    "      * the weighted mean of Y - q2 with weights\n",
    "        1{A1=d1, A2=d2} (1-C1)(1-C2) S1 S2 / e2,\n",
    "      * the weighted mean of q2 - q1 with weights 1{A1=d1} (1-C1) S1 / e1,\n",
    "      * the plain mean of q1,\n",
    "    where each weighted mean is the weighted sum divided by the sum of the\n",
    "    weights, e1 / e2 are products of the predictions of the 'phi', 'K' and 'p'\n",
    "    models, and q1 / q2 are the predictions of 'mu1.hat' / 'mu2.hat'.\n",
    "\n",
    "    Args:\n",
    "        d1: Stage-1 decisions of the rule.\n",
    "        d2: Stage-2 decisions of the rule.\n",
    "        data: A pandas DataFrame with columns X1, X2, A1, A2, C1, C2, S1, S2, Y.\n",
    "        models: Dictionary of fitted models with keys 'phi1.hat', 'phi2.hat',\n",
    "            'K1.hat', 'K2.hat', 'p1.hat', 'p2.hat', 'mu1.hat' and 'mu2.hat'.\n",
    "\n",
    "    Returns:\n",
    "        A float: the sum of the three terms.\n",
    "    \"\"\"\n",
    "    # Observed columns with NA set to 0; the prediction helpers below keep\n",
    "    # reading the unfilled `data`.\n",
    "    data_filled = data.fillna(0)\n",
    "\n",
    "    def ps1(a1):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, len(data.X1))\n",
    "\n",
    "        fitdf = pd.DataFrame({\n",
    "            'X1': data['X1'],\n",
    "            'X2': data['X2'],\n",
    "            'A1': a1\n",
    "        })\n",
    "        pred = models['phi1.hat'].predict(fitdf).values\n",
    "        # Where a1 == 0 use the complement of the model's prediction\n",
    "        pred[(~np.isnan(a1)) & (a1 == 0)] = 1 - pred[(~np.isnan(a1)) & (a1 == 0)]\n",
    "        return pred\n",
    "\n",
    "    def ps2(a1, a2):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, len(data.X1))\n",
    "        if isinstance(a2, int):\n",
    "            a2 = np.repeat(a2, len(data.X1))\n",
    "\n",
    "        fitdf = pd.DataFrame({\n",
    "            'X1': data['X1'], 'X2': data['X2'],\n",
    "            'A1': a1, 'A2': a2\n",
    "        })\n",
    "        pred = models['phi2.hat'].predict(fitdf).values\n",
    "        # Where a2 == 0 use the complement of the model's prediction\n",
    "        pred[(~np.isnan(a2)) & (a2 == 0)] = 1 - pred[(~np.isnan(a2)) & (a2 == 0)]\n",
    "        return pred\n",
    "\n",
    "    def cp1(a1):\n",
    "        fitdf = pd.DataFrame({\n",
    "            'X1': data['X1'], 'X2': data['X2'],\n",
    "            'A1': a1\n",
    "        })\n",
    "        return models['K1.hat'].predict(fitdf).values\n",
    "\n",
    "    def cp2(a1, a2):\n",
    "        fitdf = pd.DataFrame({\n",
    "            'X1': data['X1'], 'X2': data['X2'],\n",
    "            'A1': a1, 'A2': a2\n",
    "        })\n",
    "        return models['K2.hat'].predict(fitdf).values\n",
    "\n",
    "    def sp1(a1):\n",
    "        fitdf = pd.DataFrame({\n",
    "            'const': 1, 'X1': data.X1, 'A1': a1,\n",
    "        })\n",
    "        return models['p1.hat'].predict(fitdf).values\n",
    "\n",
    "    def sp2(a1, a2):\n",
    "        fitdf = pd.DataFrame({\n",
    "            'const': 1, 'X1': data.X1, 'X2': data.X2,\n",
    "            'A1': a1, 'A2': a2\n",
    "        })\n",
    "        return models['p2.hat'].predict(fitdf).values\n",
    "\n",
    "    def pcs1(a1):\n",
    "        # Product of the three stage-1 model predictions\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1)\n",
    "\n",
    "    def pcs1pcs2(a1, a2):\n",
    "        # Product of all six stage-1 and stage-2 model predictions\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1) * ps2(a1,a2) * cp2(a1,a2) * sp2(a1,a2)\n",
    "\n",
    "    def q1(a1):\n",
    "        fitdf = pd.DataFrame({\n",
    "            'X1': data['X1'],\n",
    "            'A1': a1,\n",
    "        })\n",
    "        return models['mu1.hat'].predict(fitdf).values\n",
    "\n",
    "    def q2(a1, a2):\n",
    "        fitdf = pd.DataFrame({\n",
    "            'X1': data['X1'], 'X2': data['X2'],\n",
    "            'A1': a1, 'A2': a2\n",
    "        })\n",
    "        return models['mu2.hat'].predict(fitdf).values\n",
    "\n",
    "    # attach data_filled\n",
    "    A1 = data_filled.A1; A2 = data_filled.A2\n",
    "    C1 = data_filled.C1; C2 = data_filled.C2\n",
    "    S1 = data_filled.S1; S2 = data_filled.S2; Y = data_filled.Y\n",
    "\n",
    "    e1val = pcs1(d1)\n",
    "    e2val = pcs1pcs2(d1, d2)\n",
    "    q1val = q1(d1)\n",
    "    q2val = q2(d1, d2)\n",
    "\n",
    "    # Weights are non-zero only on rows that followed the rule and have all\n",
    "    # the required indicators equal to the values selected below.\n",
    "    w1 = (A1==d1)*(A2==d2)*(1-C1)*(1-C2)*S1*S2 / e2val\n",
    "    w2 = (A1==d1)*(1-C1)*S1 / e1val\n",
    "    aipw1 = w1 * (Y - q2val)\n",
    "    aipw2 = w2 * (q2val - q1val)\n",
    "    aipw3 = q1val\n",
    "\n",
    "    # Weighted terms are normalised by the sum of their weights\n",
    "    aipw1 = aipw1.sum() / w1.sum()\n",
    "    aipw2 = aipw2.sum() / w2.sum()\n",
    "    aipw3 = aipw3.mean()\n",
    "\n",
    "    return aipw1 + aipw2 + aipw3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dfa3a9d-5aff-4358-8332-3ae8b104c06a",
   "metadata": {},
   "outputs": [],
   "source": [
    "mdl_aipw = fit_models_aipcw(XX_obs15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a920c2e4-7c05-4521-b0dd-a35e9f72860e",
   "metadata": {
    "vscode": {
     "languageId": "r"
    }
   },
   "outputs": [],
   "source": [
    "d1_true = reg1(*eta_true[:2], XX_obs15)\n",
    "d2_true = reg2(*eta_true[2:], XX_obs15)\n",
    "\n",
    "d1_true.mean(), d2_true.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96774e79-8b59-4b16-9d8a-e19958cecbf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "V_aipcw(d1_true, d2_true, XX_obs15, mdl_aipw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5465e07-9b9c-41ce-a89d-00be2bc5ed06",
   "metadata": {},
   "outputs": [],
   "source": [
    "V_MR(d1_true, d2_true, XX_obs15, models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7879b078-4eaf-4e3c-8c8c-0d853fec92d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def obj_aipw(etas, reg1, reg2, data, models):\n",
    "    \"\"\"\n",
    "    Objective for a minimiser: the negated V_aipcw value of the rule given by `etas`.\n",
    "\n",
    "    Args:\n",
    "      etas: Rule parameters; the first two go to `reg1`, the rest to `reg2`.\n",
    "      reg1, reg2: Callables mapping (*params, data) to stage-1 / stage-2 decisions.\n",
    "      data: pandas DataFrame containing the data.\n",
    "      models: Dictionary containing fitted nuisance models.\n",
    "\n",
    "    Returns:\n",
    "      The value returned by V_aipcw with its sign flipped.\n",
    "    \"\"\"\n",
    "    stage1_decisions = reg1(*etas[:2], data)\n",
    "    stage2_decisions = reg2(*etas[2:], data)\n",
    "    return -V_aipcw(stage1_decisions, stage2_decisions, data, models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e654d858-422a-40e0-8ea4-80198553fa50",
   "metadata": {},
   "outputs": [],
   "source": [
    "def simulate_aipw(m, eta2):\n",
    "    \"\"\"\n",
    "    Optimise the decision-rule parameters under the AIPW objective for replicate `m`.\n",
    "\n",
    "    For every training size in `Ns`, loads the cached data set for\n",
    "    (N, m, eta2) -- generating and caching it when the file does not exist --\n",
    "    fits the nuisance models with `fit_models_aipcw` and minimises `obj_aipw`\n",
    "    with differential evolution.\n",
    "\n",
    "    Args:\n",
    "      m: Replicate index (used in the cache file name).\n",
    "      eta2: Passed to `genData` and used in the cache file name.\n",
    "\n",
    "    Returns:\n",
    "      A list with one entry per element of `Ns`; each entry is a one-element\n",
    "      list holding the optimiser's solution vector (5 parameters).\n",
    "    \"\"\"\n",
    "    beta_ = []\n",
    "    for N in Ns:\n",
    "        # Load the cached data set, or simulate and cache it\n",
    "        fpath = f\"./simdata/[2]data_NM_{N:04d}_{m:03d}_{eta2:01d}.pkl\"\n",
    "        if os.path.exists(fpath):\n",
    "            dat = read_pkl(fpath)\n",
    "        else:\n",
    "            dat = genData(N, eta2)\n",
    "            write_pkl(dat, fpath)\n",
    "\n",
    "        XX_obs = dat['obs']\n",
    "\n",
    "        # Fit nuisance models (all default formulas)\n",
    "        models = fit_models_aipcw(XX_obs)\n",
    "\n",
    "        # NOTE(review): no seed is passed to the optimiser, so repeated runs\n",
    "        # can return different solutions.\n",
    "        res_aipw = sp.optimize.differential_evolution(\n",
    "            obj_aipw, bounds=[(-1., 1.)]*5,\n",
    "            args=(reg1, reg2, XX_obs, models),\n",
    "            x0=np.zeros(5),\n",
    "            constraints=[norm1const1, norm1const2])\n",
    "\n",
    "        beta_.append([res_aipw.x])\n",
    "\n",
    "    return beta_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57d7d37f-2d39-4026-a27e-f4536460f57a",
   "metadata": {},
   "outputs": [],
   "source": [
    "Ns = [1000, 2000, 5000]\n",
    "M = 500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e8b53f5-27b1-420c-8da1-1ef6dd186685",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the saved AIPW optimisation results of all replicates.\n",
    "beta_opt_aipw = []\n",
    "for m in tqdm(range(M)):\n",
    "    fpath = f\"./simresult/aipw_opt_sim1_{m:03d}.pkl\"\n",
    "    try:\n",
    "        beta_opt_m = read_pkl(fpath)\n",
    "    except Exception:\n",
    "        # Best effort: print the replicate that could not be loaded and go on.\n",
    "        # `Exception` (not a bare except) so Ctrl-C still interrupts the loop.\n",
    "        # NOTE(review): a skipped replicate shifts every later entry, so\n",
    "        # beta_opt_aipw[m] is replicate m only if nothing was skipped.\n",
    "        print(m, end=', ')\n",
    "    else:\n",
    "        beta_opt_aipw.append(beta_opt_m)\n",
    "\n",
    "beta_opt_aipw = np.array(beta_opt_aipw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28ca02a7-ebe0-4beb-8252-11d66ebfca07",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ```python\n",
    "# V_aipcw of every loaded rule on the data set of the same (N, m), with the\n",
    "# nuisance models refitted on that data set.\n",
    "Vhataipw_opl_single = []\n",
    "\n",
    "for i, N in enumerate(Ns):\n",
    "    for m in tqdm(range(M)):\n",
    "        # load data and beta_opt\n",
    "        # NOTE(review): simulate_aipw reads \"./simdata/[2]data_NM_...\" files,\n",
    "        # this cell reads \"./simdata/data_NM_...\" -- confirm they should differ.\n",
    "        fpath = f\"./simdata/data_NM_{N:04d}_{m:03d}_{eta2:01d}.pkl\"\n",
    "        dat_obs = read_pkl(fpath)['obs']\n",
    "        beta_opt_m = beta_opt_aipw[m]\n",
    "\n",
    "        # decision rule params: entry i belongs to training size Ns[i] and is\n",
    "        # wrapped in a one-element list\n",
    "        beta_aipw_dat = beta_opt_m[i][0]\n",
    "\n",
    "        # fixed decisions\n",
    "        d1aipw_dat = reg1(*beta_aipw_dat[:2], dat_obs)\n",
    "        d2aipw_dat = reg2(*beta_aipw_dat[2:], dat_obs)\n",
    "\n",
    "        # Fit nuisance models\n",
    "        model_dat = fit_models_aipcw(dat_obs)\n",
    "\n",
    "        # AIPW value estimate\n",
    "        vaipw_dat = V_aipcw(d1aipw_dat, d2aipw_dat, dat_obs, model_dat)\n",
    "\n",
    "        # The second slot is a None placeholder, which makes the array below\n",
    "        # object-typed.\n",
    "        Vhataipw_opl_single.append([vaipw_dat, None])\n",
    "\n",
    "Vhataipw_opl_single = np.array(Vhataipw_opl_single).reshape(len(Ns), M, -1)  # Ns, M, mrpg\n",
    "write_pkl(Vhataipw_opl_single, \"Vhat_aipw_opl_single_sim1.pkl\")\n",
    "# ```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bbc77c0-0453-4079-a3c1-043df1880d7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ```python\n",
    "# Evaluate every learned AIPW rule with V_plugin on one common large sample.\n",
    "Vaipw_opl_single = []\n",
    "\n",
    "# Evaluation sample: first 100,000 rows of df_large['obs']\n",
    "dat_obs = df_large['obs'].iloc[:100000, :]\n",
    "# NOTE(review): the name says 10k but the slice above is 100,000 rows -- confirm intended size\n",
    "tm10k = true_models(dat_obs)\n",
    "\n",
    "# n_idx is the position of N in Ns; beta_opt_m stores one entry per N in that order\n",
    "for n_idx, N in enumerate(Ns):\n",
    "    for m in tqdm(range(M)):\n",
    "        beta_opt_m = beta_opt_aipw[m]\n",
    "\n",
    "        # decision rule params fitted at this N\n",
    "        beta_aipw_dat = beta_opt_m[n_idx][0]\n",
    "\n",
    "        # fixed decisions: first two params go to reg1, the rest to reg2\n",
    "        d1aipw_dat = reg1(*beta_aipw_dat[:2], dat_obs)\n",
    "        d2aipw_dat = reg2(*beta_aipw_dat[2:], dat_obs)\n",
    "\n",
    "        # 'true' value under best learned decision rules\n",
    "        vaipw_dat = V_plugin(d1aipw_dat, d2aipw_dat, dat_obs, tm10k)\n",
    "\n",
    "        # Second slot is a placeholder; no second (PG) value is computed here\n",
    "        Vaipw_opl_single.append([vaipw_dat, None])\n",
    "\n",
    "\n",
    "Vaipw_opl_single = np.array(Vaipw_opl_single).reshape(len(Ns), M, -1)  # Ns, M, mrpg\n",
    "write_pkl(Vaipw_opl_single, \"V_aipw_opl_single_sim1.pkl\")\n",
    "# ```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ba50f63-dda9-4e98-9af8-1ff0aa19af3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the cached value arrays from disk (files are read in the order listed)\n",
    "(Vhat_opl_single, V_opl_single,\n",
    " Vhataipw_opl_single, Vaipw_opl_single) = (\n",
    "    read_pkl(pkl_name)\n",
    "    for pkl_name in (\n",
    "        \"Vhat_opl_single_nm.pkl\",\n",
    "        \"V_opl_single_nm.pkl\",\n",
    "        \"Vhat_aipw_opl_single_sim1.pkl\",\n",
    "        \"V_aipw_opl_single_sim1.pkl\",\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d4de0f7-81b2-40b4-93da-1ef847941415",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plug-in value of the stored linear rule, evaluated on all of df_large[\"obs\"]\n",
    "pi_opt = read_pkl(\"linear_opt_rule.pkl\")\n",
    "_obs_full = df_large[\"obs\"]\n",
    "# first two params go to reg1, the remaining ones to reg2\n",
    "d1_opt = reg1(*pi_opt[:2], _obs_full)\n",
    "d2_opt = reg2(*pi_opt[2:], _obs_full)\n",
    "V_true_id_opt = V_plugin(d1_opt, d2_opt, _obs_full, tm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c318f61-cf59-4c36-99fa-7d520ccc40b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep slot 0 of the last axis of each array and join them along that axis;\n",
    "# resulting column order: Vhat, Vhat (AIPW), V, V (AIPW)\n",
    "_value_sources = (Vhat_opl_single, Vhataipw_opl_single, V_opl_single, Vaipw_opl_single)\n",
    "Vall_opl_single = np.c_[tuple(arr[:, :, [0]] for arr in _value_sources)]\n",
    "Vall_opl_single.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab563388-337e-460c-8611-61d36c0bcb4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,2.4), dpi=1200)\n",
    "\n",
    "# Tick labels shared by both panels (one per column of Vall_opl_single)\n",
    "box_labels = ['MR\\n\\t  $\\\\widehat V$', 'AIPW',\n",
    "              'MR\\n\\t  $V$', 'AIPW']\n",
    "\n",
    "panel_axes = []\n",
    "# (subplot position, index into Ns and into the first axis of Vall_opl_single)\n",
    "for position, size_idx in ((121, 1), (122, 2)):\n",
    "    panel_ax = plt.subplot(position)\n",
    "    panel_axes.append(panel_ax)\n",
    "    stylize_axes(panel_ax)\n",
    "    plt.boxplot(\n",
    "        Vall_opl_single[size_idx,:,:], tick_labels=box_labels,\n",
    "        sym='x', flierprops={'markersize': 3, 'markeredgewidth': .25}\n",
    "    )\n",
    "    plt.ylim(18.5, 21)\n",
    "    plt.title(f\"N={Ns[size_idx]}\")\n",
    "    # horizontal reference line at V_true_id_opt\n",
    "    plt.axhline(V_true_id_opt, color='r', ls=\"-\", lw=.75)\n",
    "    # vertical divider between the first two and the last two boxes\n",
    "    plt.axvline(2.5, c='k', lw=.5, alpha=.75)\n",
    "    if position == 121:\n",
    "        # only the left panel carries the y-axis label\n",
    "        plt.ylabel(\"Always-survivor value\")\n",
    "    else:\n",
    "        # right panel uses the same y-limits, so its ticks are hidden\n",
    "        plt.yticks([], [])\n",
    "\n",
    "ax2, ax3 = panel_axes\n",
    "plt.tight_layout();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f81975b0-3c79-497d-bf5c-1cbb701979c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Means of columns C1 and C2 of df_large['obs']\n",
    "tuple(df_large['obs'][col].mean() for col in ('C1', 'C2'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60d5b0ae-6429-4778-a6ae-0ae1276c6d71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Means of columns S1 and S2 of df_large['obs']\n",
    "tuple(df_large['obs'][col].mean() for col in ('S1', 'S2'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7909c01-13e2-4e05-9659-3e0d0fe63933",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34b34ce1-0415-4bf5-9c90-20ba3824b78d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {
    "height": "328px",
    "width": "477px"
   },
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "231022) Survivor-optimal DynTrtRegime",
   "title_sidebar": "Contents",
   "toc_cell": true,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "214.578px"
   },
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
