{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aad825dd-a045-4afe-aee0-9e42d84c5a98",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d98178e8-c192-4ae9-9aff-71de57c8d94f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import scipy as sp\n",
    "import pandas as pd\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "# import rpy2.robjects as robjects\n",
    "import statsmodels.api as sm\n",
    "import statsmodels.formula.api as smf\n",
    "from collections import namedtuple\n",
    "from scipy.stats import norm\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "652969e2-9669-47fd-8e4d-2e0bb8d6e710",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pygam import LinearGAM, LogisticGAM, s, l\n",
    "from pygam.terms import TermList, SplineTerm, LinearTerm, FactorTerm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "988c3bdf-786d-4956-beef-5442211c046a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.ensemble import RandomForestRegressor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b48e1e7c-a4c7-42e3-bfc0-e94bcae6574b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import dill\n",
    "\n",
    "def write_pkl(obj, fname):\n",
    "    \"\"\"Serialise `obj` with dill into the file `fname` (binary mode, overwritten).\"\"\"\n",
    "    with open(fname, 'wb') as fh:\n",
    "        dill.dump(obj, fh)\n",
    "\n",
    "\n",
    "def read_pkl(fname):\n",
    "    \"\"\"\n",
    "    Deserialise and return the object stored by dill in the file `fname`.\n",
    "\n",
    "    Security note: like pickle, dill can run arbitrary code while loading,\n",
    "    so only read files that you created yourself or otherwise trust.\n",
    "    \"\"\"\n",
    "    with open(fname, 'rb') as fh:\n",
    "        return dill.load(fh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d45e28b-5148-47a4-b099-183931f54b75",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Silence pandas' SettingWithCopyWarning for the whole session (pandas default: 'warn').\n",
    "# NOTE: this also hides genuine chained-assignment mistakes.\n",
    "pd.set_option(\"mode.chained_assignment\", None)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6317f85-4b9a-4d12-b8d7-c5adb9122ac9",
   "metadata": {},
   "source": [
    "# Functions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f79ae9fa-64ab-4f73-953b-76087a74edfc",
   "metadata": {},
   "source": [
    "## SDTR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1644a7bd-00c2-4f87-bfbd-23d7abe76b61",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_models(data, X1lab, X2lab, random_state=None):\n",
    "    \"\"\"\n",
    "    Fit the nuisance models used by the estimators in this notebook.\n",
    "\n",
    "    Stage-1 models condition on the covariates in X1lab; stage-2 models\n",
    "    condition on X1lab + X2lab and A1. The binary-response models are\n",
    "    logistic GLMs fitted through the statsmodels formula API, which drops\n",
    "    rows with NaN in a formula's columns.\n",
    "\n",
    "    Args:\n",
    "        data: pandas DataFrame with the columns in X1lab and X2lab plus\n",
    "            A1, C1, S1, A2, C2, S2 and Y.\n",
    "        X1lab: list of stage-1 covariate column names.\n",
    "        X2lab: list of stage-2 covariate column names.\n",
    "        random_state: optional seed forwarded to RandomForestRegressor so\n",
    "            that mu2.hat is reproducible; None (default) leaves it unseeded,\n",
    "            as before.\n",
    "\n",
    "    Returns:\n",
    "        dict with the fitted models\n",
    "            'phi1.hat'  logistic GLM  A1      ~ X1\n",
    "            'K1.hat'    logistic GLM  I(1-C1) ~ X1 + A1\n",
    "            'p1.hat'    logistic GLM  S1      ~ X1 + A1\n",
    "            'phi2.hat'  logistic GLM  A2      ~ X1 + X2 + A1\n",
    "            'K2.hat'    logistic GLM  I(1-C2) ~ X1 + X2 + A1 + A2\n",
    "            'p2.hat'    logistic GLM  S2      ~ X1 + X2 + A1 + A2\n",
    "            'mu2.hat'   RandomForestRegressor of Y on X1, X2, A1, A2,\n",
    "                        fitted on the rows with C2 == 0 and S2 == 1\n",
    "        plus the bookkeeping entries 'ndata', 'n1', 'n2', 'X1lab', 'X2lab'.\n",
    "    \"\"\"\n",
    "    def rhs(terms):\n",
    "        # Intercept first; an empty term list still gives a valid \"y ~ 1\".\n",
    "        return \" + \".join([\"1\"] + list(terms))\n",
    "\n",
    "    def fit_logistic(formula):\n",
    "        return smf.glm(formula, data=data, family=sm.families.Binomial()).fit()\n",
    "\n",
    "    stage1 = list(X1lab)\n",
    "    stage2 = list(X1lab) + list(X2lab) + ['A1']\n",
    "\n",
    "    models = {}\n",
    "    models['phi1.hat'] = fit_logistic(\"A1 ~ \" + rhs(stage1))\n",
    "    models['K1.hat'] = fit_logistic(\"I(1-C1) ~ \" + rhs(stage1 + ['A1']))\n",
    "    models['p1.hat'] = fit_logistic(\"S1 ~ \" + rhs(stage1 + ['A1']))\n",
    "\n",
    "    models['phi2.hat'] = fit_logistic(\"A2 ~ \" + rhs(stage2))\n",
    "    models['K2.hat'] = fit_logistic(\"I(1-C2) ~ \" + rhs(stage2 + ['A2']))\n",
    "    models['p2.hat'] = fit_logistic(\"S2 ~ \" + rhs(stage2 + ['A2']))\n",
    "\n",
    "    # Outcome regression on the rows with C2 == 0 and S2 == 1 only\n",
    "    # (presumably the rows where Y is observed).\n",
    "    fit_rows = data.loc[(data.C2 == 0) & (data.S2 == 1), :]\n",
    "    features = list(X1lab) + list(X2lab) + ['A1', 'A2']\n",
    "    mu2_model = RandomForestRegressor(random_state=random_state)\n",
    "    models['mu2.hat'] = mu2_model.fit(fit_rows[features], fit_rows.Y)\n",
    "\n",
    "    models['ndata'] = data.shape[0]\n",
    "    models['n1'] = len(X1lab)\n",
    "    models['n2'] = len(X2lab)\n",
    "    models['X1lab'] = X1lab\n",
    "    models['X2lab'] = X2lab\n",
    "    return models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ede25ae-835f-4482-8fc9-82c821272784",
   "metadata": {},
   "outputs": [],
   "source": [
    "def m_p200(data, models):\n",
    "    \"\"\"\n",
    "    Smooth the stage-2 probability p2(H2; A1=0, A2=0) over the stage-1 covariates.\n",
    "\n",
    "    The 'p2.hat' GLM is evaluated for every row with A1 and A2 set to 0, and\n",
    "    those predictions are regressed on X1 with a LinearGAM fitted on the rows\n",
    "    with A1 == 0, C1 == 0 and S1 == 1. NaN in `data` is treated as 0.\n",
    "\n",
    "    Args:\n",
    "        data: pandas DataFrame of observed data (columns as in fit_models()).\n",
    "        models: dict returned by fit_models().\n",
    "\n",
    "    Returns:\n",
    "        Array with one GAM prediction per row of `data`.\n",
    "    \"\"\"\n",
    "    X1lab = models['X1lab']\n",
    "    X2lab = models['X2lab']\n",
    "\n",
    "    data_filled = data.fillna(0)\n",
    "\n",
    "    # p2.hat at the reference treatments (A1 = 0, A2 = 0) for every row.\n",
    "    reference_design = data_filled[X1lab + X2lab].assign(A1=0, A2=0)\n",
    "    p200x = models['p2.hat'].predict(reference_design)\n",
    "\n",
    "    with_target = data_filled.assign(target=p200x)\n",
    "    fit_mask = ((with_target.A1 == 0) &\n",
    "                (with_target.C1 == 0) &\n",
    "                (with_target.S1 == 1))\n",
    "    fit_rows = with_target.loc[fit_mask]\n",
    "\n",
    "    gam = LinearGAM().fit(fit_rows[X1lab], fit_rows.target)\n",
    "    return gam.predict(data_filled[X1lab])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd4e9520-dd5d-4c15-9ed9-532b0f1ee2a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def D_hat(data, models):\n",
    "    \"\"\"\n",
    "    Plug-in estimate of the denominator D.\n",
    "\n",
    "    Averages p1(X1; A1=0) * m_p200(X1) over the rows of `data`, where p1 is the\n",
    "    'p1.hat' GLM evaluated with A1 set to 0 and m_p200 is the GAM smoother\n",
    "    returned by m_p200(). NaN products count as 0 in the average.\n",
    "\n",
    "    Args:\n",
    "        data: pandas DataFrame of observed data (NaN is treated as 0).\n",
    "        models: dict returned by fit_models().\n",
    "\n",
    "    Returns:\n",
    "        Float, the D-hat estimate.\n",
    "    \"\"\"\n",
    "    X1lab = models['X1lab']\n",
    "    data_filled = data.fillna(0)\n",
    "\n",
    "    p10x = models['p1.hat'].predict(data_filled[X1lab].assign(A1=0))\n",
    "    E_p200 = m_p200(data, models)\n",
    "\n",
    "    Dhat = p10x * E_p200\n",
    "    # NaN products are set to 0 but still count towards the number of rows.\n",
    "    Dhat[np.isnan(Dhat)] = 0\n",
    "    return np.mean(Dhat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f68196c-9438-49be-a740-8405ab342711",
   "metadata": {},
   "outputs": [],
   "source": [
    "def g_(d1, d2, data, models):\n",
    "    \"\"\"\n",
    "    Smooth the outcome model under the regime (d1, d2) over the stage-1 covariates.\n",
    "\n",
    "    mu2.hat is evaluated for every row with A1 = d1 and A2 = d2; those\n",
    "    predictions are regressed on X1 with a LinearGAM fitted on the rows with\n",
    "    A1 == d1, C1 == 0 and S1 == 1. NaN in `data` is treated as 0.\n",
    "\n",
    "    Args:\n",
    "        d1: stage-1 treatment, a scalar or an array-like with one entry per row.\n",
    "        d2: stage-2 treatment, a scalar or an array-like with one entry per row.\n",
    "        data: pandas DataFrame of observed data.\n",
    "        models: dict returned by fit_models().\n",
    "\n",
    "    Returns:\n",
    "        Array with one GAM prediction per row of `data`.\n",
    "    \"\"\"\n",
    "    X1lab = models['X1lab']\n",
    "    X2lab = models['X2lab']\n",
    "\n",
    "    data_filled = data.fillna(0)\n",
    "\n",
    "    # Same column order as used to fit mu2.hat: X1lab + X2lab, then A1, A2.\n",
    "    regime_design = data_filled[X1lab + X2lab].assign(A1=d1, A2=d2)\n",
    "    with_target = data_filled.assign(target=models['mu2.hat'].predict(regime_design))\n",
    "\n",
    "    fit_mask = ((with_target.A1 == d1) &\n",
    "                (with_target.C1 == 0) &\n",
    "                (with_target.S1 == 1))\n",
    "    fit_rows = with_target.loc[fit_mask]\n",
    "\n",
    "    gam = LinearGAM().fit(fit_rows[X1lab], fit_rows.target)\n",
    "    return gam.predict(data_filled[X1lab])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49c0b80f-7985-4c3d-8a5a-208a33cdc2c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def m_mu_pi(d1, d2, a1, a2, data, models):\n",
    "    \"\"\"\n",
    "    Smooth mu2(H2; a1, a2) * 1{d2 == a2} over X1 among rows with A1 == a1.\n",
    "\n",
    "    Args:\n",
    "        d1: stage-1 regime; not used here (kept for a uniform signature).\n",
    "        d2: stage-2 regime, a scalar or an array-like with one entry per row.\n",
    "        a1, a2: treatment levels at which mu2.hat is evaluated.\n",
    "        data: pandas DataFrame of observed data (NaN is treated as 0).\n",
    "        models: dict returned by fit_models().\n",
    "\n",
    "    Returns:\n",
    "        Array with one value per row of `data`: predictions of a LinearGAM of\n",
    "        the target on X1, fitted on the rows with A1 == a1, C1 == 0, S1 == 1.\n",
    "        Degenerate cases skip the GAM: if no row qualifies the result is all\n",
    "        zeros, and if the target is constant on the qualifying rows (e.g. all\n",
    "        zero because d2 never equals a2 there) that constant is returned.\n",
    "    \"\"\"\n",
    "    X1lab = models['X1lab']\n",
    "    X2lab = models['X2lab']\n",
    "    ndata = data.shape[0]\n",
    "\n",
    "    data_filled = data.fillna(0)\n",
    "\n",
    "    # Same column order as used to fit mu2.hat: X1lab + X2lab, then A1, A2.\n",
    "    design = data_filled[X1lab + X2lab].assign(A1=a1, A2=a2)\n",
    "    target_all = models['mu2.hat'].predict(design) * (d2 == a2)\n",
    "    with_target = data_filled.assign(target=target_all)\n",
    "\n",
    "    fit_mask = ((with_target.A1 == a1) &\n",
    "                (with_target.C1 == 0) &\n",
    "                (with_target.S1 == 1))\n",
    "    fit_rows = with_target.loc[fit_mask]\n",
    "    target = fit_rows.target\n",
    "\n",
    "    if target.empty:\n",
    "        return np.zeros(ndata)\n",
    "    if target.nunique() == 1:\n",
    "        # A constant response is its own conditional mean; no GAM needed.\n",
    "        return np.full(ndata, float(target.iloc[0]))\n",
    "\n",
    "    gam = LinearGAM().fit(fit_rows[X1lab], target)\n",
    "    return gam.predict(data_filled[X1lab])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b795285-3ba9-41d0-b9f0-e84628b5f2be",
   "metadata": {},
   "outputs": [],
   "source": [
    "def N_hat(d1, d2, data, models):\n",
    "    \"\"\"\n",
    "    Plug-in estimate of the numerator N for the regime (d1, d2).\n",
    "\n",
    "    Averages g_(d1, d2) * p1(X1; A1=0) * m_p200(X1) over the rows of `data`,\n",
    "    where p1 is the 'p1.hat' GLM evaluated with A1 set to 0. NaN products\n",
    "    count as 0 in the average.\n",
    "\n",
    "    Args:\n",
    "        d1: stage-1 treatment, a scalar or an array-like with one entry per row.\n",
    "        d2: stage-2 treatment, a scalar or an array-like with one entry per row.\n",
    "        data: pandas DataFrame of observed data (NaN is treated as 0).\n",
    "        models: dict returned by fit_models().\n",
    "\n",
    "    Returns:\n",
    "        Float, the N-hat estimate.\n",
    "    \"\"\"\n",
    "    X1lab = models['X1lab']\n",
    "    data_filled = data.fillna(0)\n",
    "\n",
    "    p10x = models['p1.hat'].predict(data_filled[X1lab].assign(A1=0))\n",
    "    E_p200 = m_p200(data, models)\n",
    "    g = g_(d1, d2, data, models)\n",
    "\n",
    "    Nhat = g * p10x * E_p200\n",
    "    # NaN products are set to 0 but still count towards the number of rows.\n",
    "    Nhat[np.isnan(Nhat)] = 0\n",
    "    return np.mean(Nhat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf2708b2-7165-4f95-b880-9100f7d5fbe9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def V_plugin(d1, d2, data, models):\n",
    "    \"\"\"\n",
    "    Plug-in value estimate for the regime (d1, d2): N_hat / D_hat.\n",
    "\n",
    "    Args:\n",
    "        d1: stage-1 decisions, a scalar or an array-like with one entry per row.\n",
    "        d2: stage-2 decisions, a scalar or an array-like with one entry per row.\n",
    "        data: pandas DataFrame of observed data.\n",
    "        models: dict returned by fit_models().\n",
    "\n",
    "    Returns:\n",
    "        N_hat(d1, d2, data, models) / D_hat(data, models).\n",
    "    \"\"\"\n",
    "    numerator = N_hat(d1, d2, data, models)\n",
    "    denominator = D_hat(data, models)\n",
    "    return numerator / denominator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43f38bb7-a1cc-4c64-9cf5-29ffbcf42835",
   "metadata": {},
   "outputs": [],
   "source": [
    "def phi_D_terms(data, models):\n",
    "    \"\"\"\n",
    "    Per-row efficient-influence-function (EIF) components for D.\n",
    "\n",
    "    All nuisance quantities are evaluated at the reference treatments\n",
    "    A1 = 0, A2 = 0. With p10x = p1(X1; 0), p200x = p2(H2; 0, 0) and\n",
    "    E_p200 = m_p200(X1):\n",
    "        D1 = (IPW of S1 - p10x among A1 = 0, C1 = 0) * E_p200\n",
    "        D2 = p10x * (IPW of S2 - p200x among A1 = A2 = 0, C1 = C2 = 0, S1 = 1)\n",
    "        D3 = p10x * (IPW of p200x - E_p200 among A1 = 0, C1 = 0, S1 = 1)\n",
    "        D4 = p10x * E_p200  (plug-in term)\n",
    "    NaN in the weighted pieces is replaced by 0 before they are combined.\n",
    "\n",
    "    Args:\n",
    "        data: pandas DataFrame of observed data (NaN is treated as 0).\n",
    "        models: dict returned by fit_models().\n",
    "\n",
    "    Returns:\n",
    "        Tuple (D1, D2, D3, D4) of per-row arrays.\n",
    "    \"\"\"\n",
    "    X1lab = models['X1lab']\n",
    "    X2lab = models['X2lab']\n",
    "\n",
    "    data_filled = data.fillna(0)\n",
    "    X1 = data_filled[X1lab]\n",
    "    X12 = data_filled[X1lab + X2lab]\n",
    "\n",
    "    # Nuisance predictions at A1 = 0 (stage 1) and A1 = A2 = 0 (stage 2).\n",
    "    # phi*.hat predict P(A = 1 | .), so 1 - prediction is P(A = 0 | .).\n",
    "    ps1_0 = 1 - models['phi1.hat'].predict(X1)\n",
    "    cp1_0 = models['K1.hat'].predict(X1.assign(A1=0))\n",
    "    p10x = models['p1.hat'].predict(X1.assign(A1=0))\n",
    "    ps2_00 = 1 - models['phi2.hat'].predict(X12.assign(A1=0))\n",
    "    cp2_00 = models['K2.hat'].predict(X12.assign(A1=0, A2=0))\n",
    "    p200x = models['p2.hat'].predict(X12.assign(A1=0, A2=0))\n",
    "    E_p200 = m_p200(data, models)\n",
    "\n",
    "    # Inverse-weight denominators, each computed once and reused.\n",
    "    pc1_0 = ps1_0 * cp1_0\n",
    "    pcs1_0 = pc1_0 * p10x\n",
    "    pcs1pc2_00 = pcs1_0 * ps2_00 * cp2_00\n",
    "\n",
    "    A1 = data_filled.A1; A2 = data_filled.A2\n",
    "    C1 = data_filled.C1; C2 = data_filled.C2\n",
    "    S1 = data_filled.S1; S2 = data_filled.S2\n",
    "\n",
    "    w1 = (1-A1)*(1-C1) / pc1_0\n",
    "    w2 = (1-A1)*(1-A2)*(1-C1)*(1-C2)*S1 / pcs1pc2_00\n",
    "    w3 = (1-A1)*(1-C1)*S1 / pcs1_0\n",
    "\n",
    "    D1 = ((w1 * S1).fillna(0) - (w1 * p10x).fillna(0)) * E_p200\n",
    "    D2 = p10x * ((w2 * S2).fillna(0) - (w2 * p200x).fillna(0))\n",
    "    D3 = p10x * ((w3 * p200x).fillna(0) - (w3 * E_p200).fillna(0))\n",
    "    D4 = p10x * E_p200\n",
    "\n",
    "    return D1, D2, D3, D4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd2c357e-3b40-4a5e-8288-2a8f91b3c55a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def phi_D(data, models):\n",
    "    \"\"\"Per-row EIF of D: the sum of the four components from phi_D_terms().\"\"\"\n",
    "    first, second, third, fourth = phi_D_terms(data, models)\n",
    "    partial = first + second\n",
    "    partial = partial + third\n",
    "    return partial + fourth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3274eeb1-8c14-4c3e-a84e-d9e7e2d31e0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def D_MR(data, models):\n",
    "    \"\"\"\n",
    "    Stabilised (weight-normalised) EIF-based estimate of D.\n",
    "\n",
    "    Each augmentation term D1, D2, D3 from phi_D_terms() is summed and divided\n",
    "    by the sum of its own inverse-probability weights instead of by the number\n",
    "    of rows; the plug-in term D4 is averaged. NaN weights count as 0.\n",
    "\n",
    "    Args:\n",
    "        data: pandas DataFrame of observed data (NaN is treated as 0).\n",
    "        models: dict returned by fit_models().\n",
    "\n",
    "    Returns:\n",
    "        Float, the stabilised estimate of D.\n",
    "    \"\"\"\n",
    "    X1lab = models['X1lab']\n",
    "    X2lab = models['X2lab']\n",
    "\n",
    "    data_filled = data.fillna(0)\n",
    "    X1 = data_filled[X1lab]\n",
    "    X12 = data_filled[X1lab + X2lab]\n",
    "\n",
    "    # Weight denominators at A1 = 0 and A1 = A2 = 0 (same as in phi_D_terms).\n",
    "    ps1_0 = 1 - models['phi1.hat'].predict(X1)\n",
    "    cp1_0 = models['K1.hat'].predict(X1.assign(A1=0))\n",
    "    sp1_0 = models['p1.hat'].predict(X1.assign(A1=0))\n",
    "    ps2_00 = 1 - models['phi2.hat'].predict(X12.assign(A1=0))\n",
    "    cp2_00 = models['K2.hat'].predict(X12.assign(A1=0, A2=0))\n",
    "\n",
    "    pc1_0 = ps1_0 * cp1_0\n",
    "    pcs1_0 = pc1_0 * sp1_0\n",
    "    pcs1pc2_00 = pcs1_0 * ps2_00 * cp2_00\n",
    "\n",
    "    A1 = data_filled.A1; A2 = data_filled.A2\n",
    "    C1 = data_filled.C1; C2 = data_filled.C2\n",
    "    S1 = data_filled.S1\n",
    "\n",
    "    D1, D2, D3, D4 = phi_D_terms(data, models)\n",
    "\n",
    "    # Normalising constants: total IPW weight behind D1, D2 and D3.\n",
    "    w1 = ((1-A1)*(1-C1) / pc1_0).fillna(0).sum()\n",
    "    w2 = ((1-A1)*(1-A2)*(1-C1)*(1-C2)*S1 / pcs1pc2_00).fillna(0).sum()\n",
    "    w3 = ((1-A1)*(1-C1)*S1 / pcs1_0).fillna(0).sum()\n",
    "\n",
    "    return D1.sum() / w1 + D2.sum() / w2 + D3.sum() / w3 + D4.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e28e88d6-167f-4c42-97a7-30815055bea3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def phi_N_terms(d1, d2, data, models):\n",
    "    \"\"\"\n",
    "    Per-row efficient-influence-function (EIF) components for N.\n",
    "\n",
    "    With p10x = p1(X1; 0), p200x = p2(H2; 0, 0), E_p200 = m_p200(X1) and\n",
    "    E_mu2pi = g_(d1, d2)(X1):\n",
    "        N11 = IPW of (Y - mu2(d1, d2)) among A1 = d1, A2 = d2, C1 = C2 = 0,\n",
    "              S1 = S2 = 1, times p10x * E_p200\n",
    "        N12 = sum over (a1, a2) in {0,1}^2 of the IPW of\n",
    "              mu2(a1, a2) 1{d2 = a2} - m_mu_pi(a1, a2) among A1 = a1, C1 = 0,\n",
    "              S1 = 1, kept where d1 = a1, times p10x * E_p200\n",
    "        N21 = E_mu2pi * p10x * (IPW of S2 - p200x among A1 = A2 = 0,\n",
    "              C1 = C2 = 0, S1 = 1)\n",
    "        N22 = E_mu2pi * p10x * (IPW of p200x - E_p200 among A1 = 0, C1 = 0, S1 = 1)\n",
    "        N3  = E_mu2pi * (IPW of S1 - p10x among A1 = 0, C1 = 0) * E_p200\n",
    "        N4  = E_mu2pi * p10x * E_p200  (plug-in term)\n",
    "    NaN in the weighted pieces is replaced by 0 before they are combined.\n",
    "\n",
    "    Args:\n",
    "        d1: stage-1 decisions, a scalar or an array-like with one entry per row.\n",
    "        d2: stage-2 decisions, a scalar or an array-like with one entry per row.\n",
    "        data: pandas DataFrame of observed data.\n",
    "        models: dict returned by fit_models().\n",
    "\n",
    "    Returns:\n",
    "        Tuple (N11, N12, N21, N22, N3, N4) of per-row arrays.\n",
    "    \"\"\"\n",
    "    ndata = data.shape[0]\n",
    "    X1lab = models['X1lab']\n",
    "    X2lab = models['X2lab']\n",
    "\n",
    "    data_filled = data.fillna(0)\n",
    "    X1 = data_filled[X1lab]\n",
    "    X12 = data_filled[X1lab + X2lab]\n",
    "\n",
    "    def broadcast(a):\n",
    "        # Any scalar (Python or numpy) becomes one entry per row.\n",
    "        return np.repeat(a, ndata) if np.ndim(a) == 0 else a\n",
    "\n",
    "    def flip_where_zero(pred, a):\n",
    "        # Turn P(A = 1 | .) into P(A = a | .) for the rows where a is 0.\n",
    "        a = broadcast(a)\n",
    "        mask = (~np.isnan(a)) & (a == 0)\n",
    "        pred[mask] = 1 - pred[mask]\n",
    "        return pred\n",
    "\n",
    "    def ps1(a1):\n",
    "        return flip_where_zero(models['phi1.hat'].predict(X1), a1)\n",
    "\n",
    "    def cp1(a1):\n",
    "        return models['K1.hat'].predict(X1.assign(A1=broadcast(a1)))\n",
    "\n",
    "    def sp1(a1):\n",
    "        return models['p1.hat'].predict(X1.assign(A1=broadcast(a1)))\n",
    "\n",
    "    def ps2(a1, a2):\n",
    "        pred = models['phi2.hat'].predict(X12.assign(A1=broadcast(a1)))\n",
    "        return flip_where_zero(pred, a2)\n",
    "\n",
    "    def stage2_design(a1, a2):\n",
    "        # Same column order as used to fit mu2.hat: X1lab + X2lab, then A1, A2.\n",
    "        return X12.assign(A1=broadcast(a1), A2=broadcast(a2))\n",
    "\n",
    "    def cp2(a1, a2):\n",
    "        return models['K2.hat'].predict(stage2_design(a1, a2))\n",
    "\n",
    "    def sp2(a1, a2):\n",
    "        return models['p2.hat'].predict(stage2_design(a1, a2))\n",
    "\n",
    "    def m2(a1, a2):\n",
    "        return models['mu2.hat'].predict(stage2_design(a1, a2))\n",
    "\n",
    "    def pcs1(a1):\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1)\n",
    "\n",
    "    # Conditional means.\n",
    "    p10x = sp1(0)\n",
    "    p200x = sp2(0, 0)\n",
    "    E_p200 = m_p200(data, models)\n",
    "    E_mu2pi = g_(d1, d2, data, models)\n",
    "\n",
    "    # Inverse-weight denominators, each computed once and reused.\n",
    "    pc1_0 = ps1(0) * cp1(0)\n",
    "    pcs1_by_a1 = {0: pc1_0 * p10x, 1: pcs1(1)}\n",
    "    pcs1pc2_00 = pcs1_by_a1[0] * ps2(0, 0) * cp2(0, 0)\n",
    "    pcs1pcs2_d = pcs1(d1) * ps2(d1, d2) * cp2(d1, d2) * sp2(d1, d2)\n",
    "\n",
    "    # Observed columns are taken from the raw (unfilled) data; NaN in them\n",
    "    # propagates into the products below, which are zeroed with fillna(0).\n",
    "    A1 = data.A1; A2 = data.A2\n",
    "    C1 = data.C1; C2 = data.C2\n",
    "    S1 = data.S1; S2 = data.S2; Y = data.Y\n",
    "\n",
    "    # N11: outcome residual for rows that followed (d1, d2) through stage 2.\n",
    "    w_y = (A1==d1)*(A2==d2)*(1-C1)*(1-C2)*S1*S2 / pcs1pcs2_d\n",
    "    N11_ = (w_y * Y).fillna(0) - (w_y * m2(d1, d2)).fillna(0)\n",
    "\n",
    "    # N12: stage-1 augmentation for mu2, one term per treatment pair (a1, a2).\n",
    "    def mu2_augmentation(a1, a2):\n",
    "        residual = m2(a1, a2)*(d2==a2) - m_mu_pi(d1, d2, a1, a2, data, models)\n",
    "        val = (A1==a1)*(1-C1)*S1 / pcs1_by_a1[a1] * residual * (d1==a1)\n",
    "        return val.fillna(0)\n",
    "\n",
    "    N12_ = (mu2_augmentation(0, 0) + mu2_augmentation(1, 0)\n",
    "            + mu2_augmentation(0, 1) + mu2_augmentation(1, 1)).fillna(0)\n",
    "\n",
    "    N11 = N11_ * p10x * E_p200\n",
    "    N12 = N12_ * p10x * E_p200\n",
    "\n",
    "    # N21 / N22: augmentation for p2 at stage 2 and stage 1.\n",
    "    w2 = (1-A1)*(1-A2)*(1-C1)*(1-C2)*S1 / pcs1pc2_00\n",
    "    w3 = (1-A1)*(1-C1)*S1 / pcs1_by_a1[0]\n",
    "    N21_ = (w2 * S2).fillna(0) - (w2 * p200x).fillna(0)\n",
    "    N22_ = (w3 * p200x).fillna(0) - (w3 * E_p200).fillna(0)\n",
    "    N21 = E_mu2pi * p10x * N21_\n",
    "    N22 = E_mu2pi * p10x * N22_\n",
    "\n",
    "    # N3: augmentation for S1 at stage 1.\n",
    "    w1 = (1-A1)*(1-C1) / pc1_0\n",
    "    N31 = (w1 * S1).fillna(0) - (w1 * p10x).fillna(0)\n",
    "    N3 = E_mu2pi * N31 * E_p200\n",
    "\n",
    "    N4 = E_mu2pi * p10x * E_p200\n",
    "\n",
    "    return N11, N12, N21, N22, N3, N4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89edc671-3c48-45e1-99b8-c4cc769d2e3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def phi_N(d1, d2, data, models):\n",
    "    N11, N12, N21, N22, N3, N4 = phi_N_terms(d1, d2, data, models)\n",
    "    return N11 + N12 + N21 + N22 + N3 + N4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a41d11e8-5c24-473d-9d60-b7de8ff90ee7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def N_MR(d1, d2, data, models):\n",
    "    \"\"\"\n",
    "    Calculates the Expected Information Fraction (EIF) based estimator of N.\n",
    "\n",
    "    Args:\n",
    "        d1: First decision rule estimate.\n",
    "        d2: Second decision rule estimate.\n",
    "        data: A pandas DataFrame containing the data.\n",
    "        models: A dictionary containing fitted models (various ps, cp, sp, m2, etc.).\n",
    "\n",
    "    Returns:\n",
    "        A float representing the EIF estimator of N.\n",
    "    \"\"\"\n",
    "    n1 = models['n1']\n",
    "    n2 = models['n2']\n",
    "    ndata = data.shape[0]\n",
    "    X1lab = models['X1lab']\n",
    "    X2lab = models['X2lab']\n",
    "\n",
    "    # Fill NA with 0\n",
    "    data_filled = data.copy()\n",
    "    data_filled.fillna(0, inplace=True)\n",
    "\n",
    "    def ps1(a1):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, ndata)\n",
    "        fitdf = data_filled[X1lab]\n",
    "        pred = models['phi1.hat'].predict(fitdf)\n",
    "        pred[(~np.isnan(a1)) & (a1 == 0)] = 1 - pred[(~np.isnan(a1)) & (a1 == 0)]\n",
    "        return pred\n",
    "\n",
    "    def cp1(a1):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, ndata)\n",
    "        fitdf = data_filled[X1lab]\n",
    "        fitdf['A1'] = a1\n",
    "        return models['K1.hat'].predict(fitdf)\n",
    "\n",
    "    def sp1(a1):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, ndata)\n",
    "        fitdf = data_filled[X1lab]\n",
    "        fitdf['A1'] = a1\n",
    "        return models['p1.hat'].predict(fitdf)\n",
    "\n",
    "    def ps2(a1, a2):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, ndata)\n",
    "\n",
    "        if isinstance(a2, int):\n",
    "            a2 = np.repeat(a2, ndata)\n",
    "            \n",
    "        fitdf = data_filled[X1lab + X2lab]\n",
    "        fitdf['A1'] = a1\n",
    "        pred = models['phi2.hat'].predict(fitdf)\n",
    "        pred[(~np.isnan(a2)) & (a2 == 0)] = 1 - pred[(~np.isnan(a2)) & (a2 == 0)]\n",
    "        return pred\n",
    "\n",
    "    def cp2(a1, a2):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, ndata)\n",
    "\n",
    "        if isinstance(a2, int):\n",
    "            a2 = np.repeat(a2, ndata)\n",
    "            \n",
    "        fitdf = data_filled[X1lab + X2lab]\n",
    "        fitdf['A1'] = a1\n",
    "        fitdf['A2'] = a2\n",
    "        return models['K2.hat'].predict(fitdf)\n",
    "\n",
    "    def sp2(a1, a2):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, ndata)\n",
    "\n",
    "        if isinstance(a2, int):\n",
    "            a2 = np.repeat(a2, ndata)\n",
    "\n",
    "        fitdf = data_filled[X1lab + X2lab]\n",
    "        fitdf['A1'] = a1\n",
    "        fitdf['A2'] = a2\n",
    "        return models['p2.hat'].predict(fitdf)\n",
    "\n",
    "    def m2(a1, a2):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, ndata)\n",
    "\n",
    "        if isinstance(a2, int):\n",
    "            a2 = np.repeat(a2, ndata)\n",
    "            \n",
    "        fitdf = data_filled[X1lab + X2lab]\n",
    "        fitdf['A1'] = a1\n",
    "        fitdf['A2'] = a2\n",
    "        return models['mu2.hat'].predict(fitdf)\n",
    "\n",
    "    # Define functions for cs1pcs2, cs1pc2, and cs1\n",
    "    def pcs1pcs2(a1, a2):\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1) * ps2(a1,a2) * cp2(a1,a2) * sp2(a1,a2)\n",
    "\n",
    "    def pcs1pc2(a1, a2):\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1) * ps2(a1,a2) * cp2(a1,a2)\n",
    "\n",
    "    def pcs1(a1):\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1)\n",
    "\n",
    "    def pc1(a1):\n",
    "        return ps1(a1) * cp1(a1)\n",
    "    \n",
    "    # attach data_filled\n",
    "    A1 = data_filled.A1; A2 = data_filled.A2; \n",
    "    C1 = data_filled.C1; C2 = data_filled.C2; \n",
    "    S1 = data_filled.S1; S2 = data_filled.S2; Y = data_filled.Y\n",
    "\n",
    "    # stabilize mean of terms\n",
    "    N11, N12, N21, N22, N3, N4 = phi_N_terms(d1, d2, data, models)\n",
    "    w11 = (A1==d1)*(A2==d2)*(1-C1)*(1-C2)*S1*S2 / pcs1pcs2(d1,d2)\n",
    "    w12 = (A1==d1)*(1-C1)*S1 / pcs1(d1)\n",
    "    w21 = (1-A1)*(1-A2)*(1-C1)*(1-C2)*S1 / pcs1pc2(0, 0)\n",
    "    w22 = (1-A1)*(1-C1)*S1 / pcs1(0)\n",
    "    w3 = (1-A1)*(1-C1) / pc1(0)\n",
    "\n",
    "    w11[np.isnan(w11)] = 0\n",
    "    w12[np.isnan(w12)] = 0\n",
    "    w21[np.isnan(w21)] = 0\n",
    "    w22[np.isnan(w22)] = 0\n",
    "    w3[np.isnan(w3)] = 0\n",
    "\n",
    "    w11 = sum(w11) + 1e-9; w21 = sum(w21) + 1e-9; w3 = sum(w3) + 1e-9\n",
    "    w12 = sum(w12) + 1e-9; w22 = sum(w22) + 1e-9\n",
    "\n",
    "    N11mean = N11.sum() / w11\n",
    "    N12mean = N12.sum() / w12\n",
    "    N21mean = N21.sum() / w21\n",
    "    N22mean = N22.sum() / w22\n",
    "    N3mean = N3.sum() / w3\n",
    "    \n",
    "    return N11mean + N12mean + N21mean + N22mean + N3mean + N4.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f68eb0b6-27f3-4099-b20f-c666204f544b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# EIF-based estimator of V(\\pi) = N/D\n",
    "def V_MR(d1, d2, data, models):\n",
    "    return N_MR(d1, d2, data, models) / D_MR(data, models)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "929d157c-e329-4a07-a8b8-38ed615e9142",
   "metadata": {},
   "source": [
    "## AIPCW"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6a03659-3678-4d43-9f5b-d4c3b8090690",
   "metadata": {},
   "source": [
    "$\\widetilde C = C(1-S)$. i.e. consider death as missing."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ba034c6-3db7-458f-ba51-04c80aa3f301",
   "metadata": {},
   "source": [
    "$$\n",
    "\\widehat{\\mathbb E} \n",
    "  \\left[ \\frac{(\\bar A=\\bar d)(\\bar1-\\bar{\\widetilde C})}\n",
    "              {\\bar\\varphi\\bar K} (Y - Q_2) +\n",
    "         \\frac{(A_1=d_1)(1-\\widetilde C_1)}\n",
    "              {\\varphi_1 K_1} (Q_2 - Q_1) +\n",
    "         Q_1 \\right]\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06231a8d-74b8-4228-bd94-df42f5109089",
   "metadata": {},
   "source": [
    "$$\n",
    "\\begin{align}\n",
    "Q_2(\\bar x, \\bar a) = \\mathbb E[Y | \\bar x, \\bar A=\\bar a, \\bar{\\widetilde C} = \\bar0],\\\\\n",
    "Q_1(x_1, a_1) = \\mathbb E[Y | x_1, A=a_1, \\bar{\\widetilde C}_1 = 0]\n",
    "\\end{align}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5716644f-e976-4d5d-81b9-d778a13a3f9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_models_aipcw(data, X1lab, X2lab):\n",
    "    \"\"\"\n",
    "    Fits various nuisance models based on boolean flags.\n",
    "\n",
    "    Args:\n",
    "        data: A pandas DataFrame containing the observed data.\n",
    "        phi1_false, phi2_false, etc.: Boolean flags indicating whether to use\n",
    "            a point-five model instead of fitting the corresponding model.\n",
    "        Ep2_false, Emu2p2_false: (Currently unused)\n",
    "\n",
    "    Returns:\n",
    "        A dictionary containing fitted models and prediction functions.\n",
    "    \"\"\"\n",
    "    n1 = len(X1lab)\n",
    "    n2 = len(X2lab)\n",
    "    ndata = data.shape[0]\n",
    "\n",
    "    models = {}\n",
    "    \n",
    "    term_list1 = \" + \".join(X1lab)\n",
    "    term_list1a = term_list1 + ' + A1'\n",
    "    models['phi1.hat'] = smf.glm(\"A1 ~ 1 + \" + term_list1, data=data, \n",
    "                                 family=sm.families.Binomial()).fit()\n",
    "    models['K1.hat'] = smf.glm(\"I(1-C1) ~ 1 + \" + term_list1a, data=data, \n",
    "                               family=sm.families.Binomial()).fit()\n",
    "    models['p1.hat'] = smf.glm(\"S1 ~ 1 + \" + term_list1a, data=data, \n",
    "                               family=sm.families.Binomial()).fit()\n",
    "\n",
    "    data_2 = data.loc[(data.C1 == 0) & (data.S1 == 1), :]\n",
    "    \n",
    "    term_list2 = \" + \".join(X1lab+X2lab+['A1'])\n",
    "    term_list2a = term_list2 + ' + A2'\n",
    "    models['phi2.hat'] = smf.glm(\"A2 ~ 1 + \" + term_list2, data=data, \n",
    "                                 family=sm.families.Binomial()).fit()\n",
    "    models['K2.hat'] = smf.glm(\"I(1-C2) ~ 1 + \" + term_list2a, data=data, \n",
    "                               family=sm.families.Binomial()).fit()\n",
    "    models['p2.hat'] = smf.glm(\"S2 ~ 1 + \" + term_list2a, data=data, \n",
    "                               family=sm.families.Binomial()).fit()\n",
    "\n",
    "    data_y1 = data.loc[(data.C1 == 0) & (data.S1 == 1), :]\n",
    "    data_y1 = data_y1.fillna(0)\n",
    "    mu1_model = RandomForestRegressor()\n",
    "    models['q1.hat'] = mu1_model.fit(data_y1[X1lab+['A1']], data_y1.Y)\n",
    "\n",
    "    data_y2 = data.loc[(data.C2 == 0) & (data.S2 == 1), :]\n",
    "    mu2_model = RandomForestRegressor()\n",
    "    models['q2.hat'] = mu2_model.fit(data_y2[X1lab+X2lab+['A1', 'A2']], data_y2.Y)\n",
    "\n",
    "    data_filled = data.copy()\n",
    "    data_filled.fillna(0, inplace=True)\n",
    "\n",
    "    models['ndata'] = ndata\n",
    "    models['n1'] = n1\n",
    "    models['n2'] = n2\n",
    "    models['X1lab'] = X1lab\n",
    "    models['X2lab'] = X2lab\n",
    "\n",
    "    # Return dictionary with models and prediction functions\n",
    "    return models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c93ea7fc-2ed2-43ef-9b1b-61589472c91f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def V_aipcw(d1, d2, data, models):\n",
    "    \"\"\"\n",
    "    Calculates the Expected Information Fraction (EIF) based estimator of N.\n",
    "\n",
    "    Args:\n",
    "        d1: First decision rule estimate.\n",
    "        d2: Second decision rule estimate.\n",
    "        data: A pandas DataFrame containing the data.\n",
    "        models: A dictionary containing fitted models (various ps, cp, sp, m2, etc.).\n",
    "\n",
    "    Returns:\n",
    "        A float representing the EIF estimator of N.\n",
    "    \"\"\"\n",
    "    n1 = models['n1']\n",
    "    n2 = models['n2']\n",
    "    ndata = data.shape[0]\n",
    "    X1lab = models['X1lab']\n",
    "    X2lab = models['X2lab']\n",
    "\n",
    "    # Fill NA with 0\n",
    "    data_filled = data.copy()\n",
    "    data_filled.fillna(0, inplace=True)\n",
    "\n",
    "    def ps1(a1):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, ndata)\n",
    "        fitdf = data_filled[X1lab]\n",
    "        pred = models['phi1.hat'].predict(fitdf)\n",
    "        pred[(~np.isnan(a1)) & (a1 == 0)] = 1 - pred[(~np.isnan(a1)) & (a1 == 0)]\n",
    "        return pred\n",
    "\n",
    "    def cp1(a1):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, ndata)\n",
    "        fitdf = data_filled[X1lab]\n",
    "        fitdf['A1'] = a1\n",
    "        return models['K1.hat'].predict(fitdf)\n",
    "\n",
    "    def sp1(a1):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, ndata)\n",
    "        fitdf = data_filled[X1lab]\n",
    "        fitdf['A1'] = a1\n",
    "        return models['p1.hat'].predict(fitdf)\n",
    "\n",
    "    def ps2(a1, a2):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, ndata)\n",
    "\n",
    "        if isinstance(a2, int):\n",
    "            a2 = np.repeat(a2, ndata)\n",
    "            \n",
    "        fitdf = data_filled[X1lab + X2lab]\n",
    "        fitdf['A1'] = a1\n",
    "        pred = models['phi2.hat'].predict(fitdf)\n",
    "        pred[(~np.isnan(a2)) & (a2 == 0)] = 1 - pred[(~np.isnan(a2)) & (a2 == 0)]\n",
    "        return pred\n",
    "\n",
    "    def cp2(a1, a2):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, ndata)\n",
    "\n",
    "        if isinstance(a2, int):\n",
    "            a2 = np.repeat(a2, ndata)\n",
    "            \n",
    "        fitdf = data_filled[X1lab + X2lab]\n",
    "        fitdf['A1'] = a1\n",
    "        fitdf['A2'] = a2\n",
    "        return models['K2.hat'].predict(fitdf)\n",
    "\n",
    "    def sp2(a1, a2):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, ndata)\n",
    "\n",
    "        if isinstance(a2, int):\n",
    "            a2 = np.repeat(a2, ndata)\n",
    "\n",
    "        fitdf = data_filled[X1lab + X2lab]\n",
    "        fitdf['A1'] = a1\n",
    "        fitdf['A2'] = a2\n",
    "        return models['p2.hat'].predict(fitdf)\n",
    "\n",
    "    def q1(a1):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, ndata)\n",
    "            \n",
    "        fitdf = data_filled[X1lab]\n",
    "        fitdf['A1'] = a1\n",
    "        return models['q1.hat'].predict(fitdf)\n",
    "\n",
    "    def q2(a1, a2):\n",
    "        if isinstance(a1, int):\n",
    "            a1 = np.repeat(a1, ndata)\n",
    "\n",
    "        if isinstance(a2, int):\n",
    "            a2 = np.repeat(a2, ndata)\n",
    "            \n",
    "        fitdf = data_filled[X1lab + X2lab]\n",
    "        fitdf['A1'] = a1\n",
    "        fitdf['A2'] = a2\n",
    "        return models['q2.hat'].predict(fitdf)\n",
    "    \n",
    "    # Define functions for cs1pcs2, cs1pc2, and cs1\n",
    "    def pcs1pcs2(a1, a2):\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1) * ps2(a1,a2) * cp2(a1,a2) * sp2(a1,a2)\n",
    "\n",
    "    def pcs1(a1):\n",
    "        return ps1(a1) * cp1(a1) * sp1(a1)\n",
    "\n",
    "    # attach data_filled\n",
    "    A1 = data_filled.A1; A2 = data_filled.A2; \n",
    "    C1 = data_filled.C1; C2 = data_filled.C2; \n",
    "    S1 = data_filled.S1; S2 = data_filled.S2; Y = data_filled.Y\n",
    "\n",
    "    e1val = pcs1(d1)\n",
    "    e2val = pcs1pcs2(d1, d2)\n",
    "    q1val = q1(d1)\n",
    "    q2val = q2(d1, d2)\n",
    "\n",
    "    w1 = (A1==d1)*(A2==d2)*(1-C1)*(1-C2)*S1*S2 / (e2val + 1e-9)\n",
    "    w2 = (A1==d1)*(1-C1)*S1 / (e1val + 1e-9)\n",
    "    aipw1 = w1 * (Y - q2val)\n",
    "    aipw2 = w2 * (q2val - q1val)\n",
    "    aipw3 = q1val\n",
    "\n",
    "    aipw1 = aipw1.sum() / w1.sum()\n",
    "    aipw2 = aipw2.sum() / w2.sum()\n",
    "    aipw3 = aipw3.mean()\n",
    "\n",
    "    return aipw1 + aipw2 + aipw3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "713456fb-7d17-4cfb-bfe1-ccc30fb107ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "def obj_aipw(betas, reg1, reg2, data, models):\n",
    "    \"\"\"\n",
    "    Calculates the SDMR estimator.\n",
    "    \n",
    "    Args:\n",
    "      etas: Parameters for decision rules.\n",
    "      reg1, reg2: Functions to estimate decision rules.\n",
    "      data: pandas DataFrame containing the data.\n",
    "      models: Dictionary containing fitted nuisance models.\n",
    "      apply_penalty: Boolean flag indicating whether to apply the constraints.\n",
    "    \n",
    "    Returns:\n",
    "      A float representing the SDMR estimator.\n",
    "    \"\"\"\n",
    "    X1lab = models['X1lab']\n",
    "    X2lab = models['X2lab']\n",
    "    n1 = models['n1']\n",
    "    X1 = data[X1lab]\n",
    "    X1X2 = data[X1lab+X2lab]\n",
    "\n",
    "    # Estimate decision rules\n",
    "    d1 = reg1(betas[:(n1+1)], X1)\n",
    "    d2 = reg2(betas[(n1+1):], X1X2)\n",
    "    \n",
    "    # Calculate MR estimator (replace V.MR with your implementation)\n",
    "    value = V_aipcw(d1, d2, data, models)\n",
    "    \n",
    "    return -value"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "949bee09-8c48-463a-9477-934aa54f3761",
   "metadata": {},
   "source": [
    "## Functions related to policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1f55d28-20e3-4a21-845c-e2c6e9adfccc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# linear decision rules\n",
    "def reg1(beta1, X1):\n",
    "    criteria = np.c_[np.ones(X1.shape[0]), X1] @ beta1 >= 0\n",
    "    d1 = np.array(criteria, dtype=int)\n",
    "    return d1\n",
    "\n",
    "def reg2(beta2, X1X2):\n",
    "    criteria = np.c_[np.ones(X1X2.shape[0]), X1X2] @ beta2 >= 0\n",
    "    d2 = np.array(criteria, dtype=int)\n",
    "    return d2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ecab42ba-bfca-4e2e-8a9c-7e92b65ea63e",
   "metadata": {},
   "source": [
    "Length of $\\boldsymbol\\beta = \\begin{pmatrix}\\boldsymbol\\beta_1 \\\\ \\boldsymbol\\beta_2\\end{pmatrix}$ \n",
    "should be $(8+1)+(1+15)=25$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7759eb79-c88d-448d-945b-1790cb8f89d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def norm1check1(betas):\n",
    "    beta1s = betas[:9]\n",
    "    return np.linalg.norm(beta1s, 2)\n",
    "\n",
    "def norm1check2(betas):\n",
    "    beta2s = betas[9:]\n",
    "    return np.linalg.norm(beta2s, 2)\n",
    "    \n",
    "norm1const1 = sp.optimize.NonlinearConstraint(norm1check1, 0.99, 1.)\n",
    "norm1const2 = sp.optimize.NonlinearConstraint(norm1check2, 0.99, 1.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b9065e8-db36-4373-8fc0-104958253e1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "Decision = namedtuple('decision', ['d1', 'd2'])\n",
    "\n",
    "def decision(reg1, reg2, ga_sol, data, X1lab, X2lab):\n",
    "    \"\"\"\n",
    "    Makes decisions based on decision rules and GA solution.\n",
    "    \n",
    "    Args:\n",
    "      reg1: Function for the first decision rule.\n",
    "      reg2: Function for the second decision rule.\n",
    "      ga_sol: GA solution.\n",
    "      data: Data to use for decision making.\n",
    "    \n",
    "    Returns:\n",
    "      A dictionary containing d1 and d2 decisions.\n",
    "    \"\"\"\n",
    "    # Make decisions using the estimated parameters\n",
    "    x1 = data[X1lab]\n",
    "    x1x2 = data[X1lab+X2lab]\n",
    "    d1 = reg1(ga_sol[:9], x1)\n",
    "    d2 = reg2(ga_sol[9:], x1x2)\n",
    "    \n",
    "    return Decision(d1, d2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2b5e885-dd39-4885-8ac0-66010356af47",
   "metadata": {},
   "outputs": [],
   "source": [
    "def obj_sdmr(betas, reg1, reg2, data, models):\n",
    "    \"\"\"\n",
    "    Calculates the SDMR estimator.\n",
    "    \n",
    "    Args:\n",
    "      etas: Parameters for decision rules.\n",
    "      reg1, reg2: Functions to estimate decision rules.\n",
    "      data: pandas DataFrame containing the data.\n",
    "      models: Dictionary containing fitted nuisance models.\n",
    "      apply_penalty: Boolean flag indicating whether to apply the constraints.\n",
    "    \n",
    "    Returns:\n",
    "      A float representing the SDMR estimator.\n",
    "    \"\"\"\n",
    "    X1lab = models['X1lab']\n",
    "    X2lab = models['X2lab']\n",
    "    n1 = models['n1']\n",
    "    X1 = data[X1lab]\n",
    "    X1X2 = data[X1lab+X2lab]\n",
    "\n",
    "    # Estimate decision rules\n",
    "    d1 = reg1(betas[:(n1+1)], X1)\n",
    "    d2 = reg2(betas[(n1+1):], X1X2)\n",
    "    \n",
    "    # Calculate MR estimator (replace V.MR with your implementation)\n",
    "    value = V_MR(d1, d2, data, models)\n",
    "    \n",
    "    return -value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0f99f82-0356-48da-bcbc-06cbbcdefe03",
   "metadata": {},
   "outputs": [],
   "source": [
    "def obj_pg(betas, reg1, reg2, data, models):\n",
    "    \"\"\"\n",
    "    Calculates the SDMR estimator.\n",
    "    \n",
    "    Args:\n",
    "      etas: Parameters for decision rules.\n",
    "      reg1, reg2: Functions to estimate decision rules.\n",
    "      data: pandas DataFrame containing the data.\n",
    "      models: Dictionary containing fitted nuisance models.\n",
    "      apply_penalty: Boolean flag indicating whether to apply the constraints.\n",
    "    \n",
    "    Returns:\n",
    "      A float representing the SDMR estimator.\n",
    "    \"\"\"\n",
    "    X1lab = models['X1lab']\n",
    "    X2lab = models['X2lab']\n",
    "    n1 = models['n1']\n",
    "    X1 = data[X1lab]\n",
    "    X1X2 = data[X1lab+X2lab]\n",
    "\n",
    "    # Estimate decision rules\n",
    "    d1 = reg1(betas[:(n1+1)], X1)\n",
    "    d2 = reg2(betas[(n1+1):], X1X2)\n",
    "    \n",
    "    # Calculate MR estimator (replace V.MR with your implementation)\n",
    "    value = V_plugin(d1, d2, data, models)\n",
    "    \n",
    "    return -value"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e186c763-92f1-41d9-9c1d-8398bc53c52a",
   "metadata": {},
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84cbb658-288b-49e5-a169-fca6d9e4f87e",
   "metadata": {},
   "source": [
    "# Run simulation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eecbfe9c",
   "metadata": {},
   "source": [
    "## Load processed data\n",
    "* Refer to `MIMIC3 preprocess.ipynb` for data processing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d834d70-81ac-4971-b068-6e8353a6f37b",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv('MIMIC-III_sepsis_sdtr_sofa8.csv', index_col=0)\n",
    "df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f7a621b-d4ed-456c-8c65-1f7f5a4b753c",
   "metadata": {},
   "outputs": [],
   "source": [
    "X1lab = [i for i in df.columns if i.startswith('X') and i.endswith('1')]\n",
    "X2lab = [i for i in df.columns if i.startswith('X') and i.endswith('2')]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21f50d0b-a6f9-442f-ab8e-23cbfa07a5ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_beta = len(X1lab) + 1 + len(X1lab)+len(X2lab) + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c0d02b6-231b-44ed-88f6-e849a74745cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a65e0da3-6d9d-470a-91f8-60bcbf9b65fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "ttr, tte = train_test_split(df, test_size=0.5, \n",
    "                            stratify=df[['A1', 'A2', 'C1', 'C2', 'S1', 'S2']].fillna(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8a11045-a274-4f72-887a-cb4532547707",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.C1.mean(), np.isnan(df.C1).mean(), df.S1.mean(), np.isnan(df.S1).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d238abad-adc9-4b76-a735-5df4afdaad2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "tte.C1.mean(), tte.C2.mean(), tte.S1.mean(), tte.S2.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62d0f235-1215-470f-8902-9e33976024ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "ttr.C1.mean(), ttr.C2.mean(), ttr.S1.mean(), ttr.S2.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dcb372e8-154b-42f8-84b2-1b967740dd16",
   "metadata": {},
   "source": [
    "## AIPW"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab100d0e-f0a5-4b38-b2e2-f2ada76d8f6b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# ```python\n",
    "res_aipw_train = []\n",
    "res_aipw_test = []\n",
    "\n",
    "for i in tqdm(range(50)):\n",
    "    df_train, df_test = train_test_split(df, test_size=0.5, random_state=i+1,\n",
    "                                         stratify=df[['C1', 'C2', 'S1', 'S2']].fillna(0))\n",
    "    \n",
    "    # train model\n",
    "    models_train = fit_models_aipcw(df_train, X1lab, X2lab)\n",
    "    models_test = fit_models_aipcw(df_test, X1lab, X2lab)\n",
    "    \n",
    "    # OPL\n",
    "    n_beta = len(X1lab) + 1 + len(X1lab)+len(X2lab) + 1\n",
    "    \n",
    "    sol_aipw_train = sp.optimize.minimize(\n",
    "        obj_aipw, bounds=[(-1., 1.)] * n_beta, \n",
    "        x0=np.ones(n_beta),\n",
    "        args=(reg1, reg2, df_train, models_train),\n",
    "        constraints=[norm1const1, norm1const2]\n",
    "    )\n",
    "    sol_aipw_test = sp.optimize.minimize(\n",
    "        obj_aipw, bounds=[(-1., 1.)] * n_beta, \n",
    "        x0=np.ones(n_beta),\n",
    "        args=(reg1, reg2, df_test, models_test),\n",
    "        constraints=[norm1const1, norm1const2]\n",
    "    )\n",
    "\n",
    "    res_aipw_train.append(sol_aipw_train)\n",
    "    res_aipw_test.append(sol_aipw_test)\n",
    "    \n",
    "write_pkl(res_aipw_train, 'res_aipw_MIMIC(8)_train.pkl')\n",
    "write_pkl(res_aipw_test, 'res_aipw_MIMIC(8)_test.pkl')\n",
    "# ```"
   ]
  },
  {
   "cell_type": "raw",
   "id": "5b7c979d-0314-4bfc-9fdc-6375b5e7192d",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "res_aipw_train = read_pkl('res_aipw_MIMIC(8)_train.pkl')\n",
    "res_aipw_test = read_pkl('res_aipw_MIMIC(8)_test.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5acdf1a5-4990-473c-88f3-dda011a2d565",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# ```python\n",
    "beta_aipw_train = np.zeros((50, n_beta)); beta_aipw_test = np.zeros((50, n_beta))\n",
    "Vaipw_train = np.zeros(50); Vaipw_test = np.zeros(50); Vaipw_hat_test = np.zeros(50)\n",
    "\n",
    "for i in tqdm(range(50)):\n",
    "    df_train, df_test = train_test_split(df, test_size=0.5, random_state=i+1,\n",
    "                                         stratify=df[['C1', 'C2', 'S1', 'S2']].fillna(0))\n",
    "    \n",
    "    models_train = fit_models_aipcw(df_train, X1lab, X2lab)\n",
    "    models_test_mr = fit_models(df_test, X1lab, X2lab)\n",
    "\n",
    "    beta_aipw_train[i, :] = res_aipw_train[i].x\n",
    "    beta_aipw_test[i, :] = res_aipw_test[i].x\n",
    "\n",
    "    X1test = df_test[X1lab]\n",
    "    X1X2test = df_test[X1lab+X2lab]\n",
    "    d1_train = reg1(beta_aipw_train[i, :9], X1test)\n",
    "    d2_train = reg2(beta_aipw_train[i, 9:], X1X2test)\n",
    "    Vaipw_hat_test[i] = -V_aipcw(d1_train, d2_train, df_test, models_train)\n",
    "    Vaipw_test[i] = -V_MR(d1_train, d2_train, df_test, models_test_mr)\n",
    "\n",
    "write_pkl(Vaipw_test, 'Vaipw_test_SOFA8.pkl')\n",
    "write_pkl(Vaipw_hat_test, 'Vaipw_hat_test_SOFA8.pkl')\n",
    "# ```"
   ]
  },
  {
   "cell_type": "raw",
   "id": "1b986415-39db-44a8-b94d-e28cf0257d5d",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "Vaipw_test = read_pkl('Vaipw_test_SOFA8.pkl')\n",
    "Vaipw_hat_test = read_pkl('Vaipw_hat_test_SOFA8.pkl')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50c1bde3-14fc-471d-95b4-3e0c1e4df672",
   "metadata": {},
   "source": [
    "## MR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d445e7f-9878-4507-9b68-5d438e434469",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# ```python\n",
    "res_mr_train = []\n",
    "res_mr_test = []\n",
    "for i in tqdm(range(50)):\n",
    "    df_train, df_test = train_test_split(df, test_size=0.5, random_state=i+1,\n",
    "                                         stratify=df[['C1', 'C2', 'S1', 'S2']].fillna(0))\n",
    "    \n",
    "    # train model\n",
    "    models_train = fit_models(df_train, X1lab, X2lab)\n",
    "    models_test = fit_models(df_test, X1lab, X2lab)\n",
    "    \n",
    "    # OPL\n",
    "    n_beta = len(X1lab) + 1 + len(X1lab)+len(X2lab) + 1\n",
    "    \n",
    "    res_mr_train.append(\n",
    "        sp.optimize.differential_evolution(\n",
    "            obj_sdmr, bounds=[(-1., 1.)] * n_beta, \n",
    "            args=[reg1, reg2, df_train, models_train],\n",
    "            # workers=2,\n",
    "            constraints=[norm1const1, norm1const2])\n",
    "    )\n",
    "    res_mr_test.append(\n",
    "        sp.optimize.differential_evolution(\n",
    "            obj_sdmr, bounds=[(-1., 1.)] * n_beta, \n",
    "            args=[reg1, reg2, df_test, models_test],\n",
    "            # workers=2,\n",
    "            constraints=[norm1const1, norm1const2])\n",
    "    )\n",
    "\n",
    "write_pkl(res_mr_train, 'res_mr(MR)_MIMIC(8)_train.pkl')\n",
    "write_pkl(res_mr_test, 'res_mr(MR)_MIMIC(8)_test.pkl')\n",
    "# ```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e82f183-b954-480b-911c-dbfd45531dd0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# ```python\n",
    "# Evaluate the split-sample rules on the held-out half (df_test) of each split.\n",
    "# res_mr_*[i].x stacks [stage-1 intercept, X1lab coefs, stage-2 intercept, X1lab+X2lab coefs].\n",
    "n_x1_coef = len(X1lab) + 1  # length of the stage-1 block (was a hard-coded 9)\n",
    "Vmr_hat_test, Vmr_test, Vmr_test_learned = [], [], []\n",
    "for i in tqdm(range(len(res_mr_train))):\n",
    "    # same split as the optimization cell (identical random_state and stratification)\n",
    "    df_train, df_test = train_test_split(df, test_size=0.5, random_state=i+1,\n",
    "                                         stratify=df[['C1', 'C2', 'S1', 'S2']].fillna(0))\n",
    "\n",
    "    models_train_mr = fit_models(df_train, X1lab, X2lab)\n",
    "    models_test_mr = fit_models(df_test, X1lab, X2lab)\n",
    "\n",
    "    X1test = df_test[X1lab]\n",
    "    X1X2test = df_test[X1lab+X2lab]\n",
    "    # rule learned on the training half, applied to the test covariates\n",
    "    d1_train = reg1(res_mr_train[i].x[:n_x1_coef], X1test)\n",
    "    d2_train = reg2(res_mr_train[i].x[n_x1_coef:], X1X2test)\n",
    "    Vmr_hat_test.append(-V_MR(d1_train, d2_train, df_test, models_train_mr))\n",
    "\n",
    "    # rule learned on the test half itself\n",
    "    d1_test = reg1(res_mr_test[i].x[:n_x1_coef], X1test)\n",
    "    d2_test = reg2(res_mr_test[i].x[n_x1_coef:], X1X2test)\n",
    "    Vmr_test.append(-V_MR(d1_test, d2_test, df_test, models_test_mr))\n",
    "\n",
    "    # training-half rule scored with the test-half models\n",
    "    Vmr_test_learned.append(-V_MR(d1_train, d2_train, df_test, models_test_mr))\n",
    "\n",
    "Vmr_hat_test = np.array(Vmr_hat_test)\n",
    "Vmr_test_learned = np.array(Vmr_test_learned)\n",
    "Vmr_test = np.array(Vmr_test)\n",
    "\n",
    "write_pkl(Vmr_hat_test, 'Vmr(MR)_hat_test_mimic(8).pkl')\n",
    "write_pkl(Vmr_test, 'Vmr(MR)_test(8).pkl')\n",
    "write_pkl(Vmr_test_learned, 'Vmr(MR)_test_learned(8).pkl')\n",
    "# ```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f863b597-e707-41fb-8395-7917a3c3b280",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload cached optimizer results and value estimates so the figures below can be\n",
    "# produced without re-running the expensive cells.\n",
    "# NOTE: read_pkl uses dill.load, which can execute arbitrary code -- only load trusted files.\n",
    "res_mr_train, res_mr_test = (read_pkl('res_mr(MR)_MIMIC(8)_train.pkl'),\n",
    "                             read_pkl('res_mr(MR)_MIMIC(8)_test.pkl'))\n",
    "\n",
    "# AIPW value estimates (presumably produced elsewhere in this notebook -- verify)\n",
    "Vaipw_test, Vaipw_hat_test = (read_pkl('Vaipw_test_SOFA8.pkl'),\n",
    "                              read_pkl('Vaipw_hat_test_SOFA8.pkl'))\n",
    "\n",
    "# MR value estimates\n",
    "Vmr_hat_test, Vmr_test, Vmr_test_learned = (read_pkl('Vmr(MR)_hat_test_mimic(8).pkl'),\n",
    "                                            read_pkl('Vmr(MR)_test(8).pkl'),\n",
    "                                            read_pkl('Vmr(MR)_test_learned(8).pkl'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f2d8aea-bbf4-40ee-aad5-614a88c4660f",
   "metadata": {},
   "source": [
    "### Plot results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65fbffc8-dc4f-4e63-a659-57e9327e00a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def stylize_axes(ax):\n",
    "    \"\"\"Hide the top/right spines and draw ticks only on the left and bottom axes.\"\"\"\n",
    "    for side in ('top', 'right'):\n",
    "        ax.spines[side].set_visible(False)\n",
    "\n",
    "    ax.yaxis.set_ticks_position('left')\n",
    "    ax.xaxis.set_ticks_position('bottom')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8827fb87-b267-4eec-9e53-1e0a01b7643b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _gap_boxplot(ax, gaps, ylabel):\n",
    "    \"\"\"Style `ax` and draw one box per column of `gaps` (MR, then AIPW) plus a zero line.\"\"\"\n",
    "    stylize_axes(ax)\n",
    "    ax.boxplot(gaps, sym='x', flierprops={'markersize': 3, 'markeredgewidth': .25},\n",
    "               widths=.5)\n",
    "    ax.set_xticks(range(1, 2+1))\n",
    "    ax.set_xticklabels(['MR', 'AIPW'])\n",
    "    ax.set_ylabel(ylabel)\n",
    "    ax.axhline(0, c=\"r\")\n",
    "\n",
    "# NOTE(review): confirm the sign convention of the stored value estimates so the\n",
    "# differences plotted below match the y-axis labels.\n",
    "fig = plt.figure(figsize=(3,2.75), dpi=1200)\n",
    "\n",
    "ax1 = fig.add_subplot(121)\n",
    "_gap_boxplot(ax1, np.c_[Vmr_test - Vmr_hat_test, Vmr_test - Vaipw_hat_test],\n",
    "             r'$V_{train}(\\widehat\\beta) - V_{test}(\\beta^*)$')\n",
    "\n",
    "ax2 = fig.add_subplot(122, sharey=ax1)\n",
    "_gap_boxplot(ax2, np.c_[Vmr_test - Vmr_test_learned, Vmr_test - Vaipw_test],\n",
    "             r'$V_{test}(\\widehat\\beta) - V_{test}(\\beta^*)$')\n",
    "\n",
    "fig.tight_layout();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "afbdfce2-7a08-420b-943d-6f391d6a6b34",
   "metadata": {},
   "source": [
    "### Visualize learned DTR"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "745f1601-d6be-47f3-87a6-f8c6711c20ce",
   "metadata": {},
   "source": [
    "Both stages use linear decision rules: `d = 1` when `[1, X] @ beta >= 0`, otherwise `d = 0`. The stacked vector `betas` holds the stage-1 coefficients first (`n1 + 1` entries, with `n1 = len(X1lab)`), followed by the stage-2 coefficients.\n",
    "\n",
    "```python\n",
    "# linear decision rules\n",
    "def reg1(beta1, X1):\n",
    "    criteria = np.c_[np.ones(X1.shape[0]), X1] @ beta1 >= 0\n",
    "    d1 = np.array(criteria, dtype=int)\n",
    "    return d1\n",
    "\n",
    "def reg2(beta2, X1X2):\n",
    "    criteria = np.c_[np.ones(X1X2.shape[0]), X1X2] @ beta2 >= 0\n",
    "    d2 = np.array(criteria, dtype=int)\n",
    "    return d2\n",
    "\n",
    "X1 = data[X1lab]\n",
    "X1X2 = data[X1lab+X2lab]\n",
    "\n",
    "# Estimate decision rules\n",
    "d1 = reg1(betas[:(n1+1)], X1)\n",
    "d2 = reg2(betas[(n1+1):], X1X2)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75b7fb1a-12f1-4241-8bd6-47905759b8b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the DTR on the full data set (no sample splitting).\n",
    "# n_beta is recomputed here so this cell does not depend on having run the\n",
    "# (expensive, optional) split-sample optimization cell.\n",
    "n_beta = len(X1lab) + 1 + len(X1lab) + len(X2lab) + 1\n",
    "# NOTE(review): the split-sample experiments optimize obj_sdmr; confirm obj_pg is the\n",
    "# intended objective for this full-data fit.\n",
    "dtr_all = sp.optimize.differential_evolution(\n",
    "    obj_pg, bounds=[(-1., 1.)] * n_beta,\n",
    "    args=[reg1, reg2, df, fit_models(df, X1lab, X2lab)],\n",
    "    constraints=[norm1const1, norm1const2],\n",
    "    seed=0)  # fixed seed so the learned rule is reproducible"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27ec8c8b-89f2-4e4b-b0dc-f1542e4996c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full-data DTR under the AIPW objective.\n",
    "# n_beta is recomputed here so this cell does not depend on having run the\n",
    "# (expensive, optional) split-sample optimization cell.\n",
    "n_beta = len(X1lab) + 1 + len(X1lab) + len(X2lab) + 1\n",
    "dtr_aipw = sp.optimize.differential_evolution(\n",
    "    obj_aipw, bounds=[(-1., 1.)] * n_beta,\n",
    "    args=[reg1, reg2, df, fit_models_aipcw(df, X1lab, X2lab)],\n",
    "    constraints=[norm1const1, norm1const2],\n",
    "    seed=0)  # fixed seed so the learned rule is reproducible"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39f1e1e6-33dd-443b-a389-6df1a3792610",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side by side: mean split-sample MR value, full-data optimizer objective, mean AIPW value.\n",
    "# NOTE(review): dtr_all.fun is the minimized objective while Vmr_hat_test holds -V_MR values;\n",
    "# confirm the sign conventions before comparing these numbers directly.\n",
    "(Vmr_hat_test.mean(),\n",
    " dtr_all.fun,\n",
    " Vaipw_hat_test.mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51314c31-8aae-45a4-9c5a-6522540c5fc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same axes styling as the earlier definition; repeated here so this section also runs\n",
    "# when the results-plotting section above is skipped.\n",
    "def stylize_axes(ax):\n",
    "    \"\"\"Remove the top and right spines; keep ticks on the left and bottom axes only.\"\"\"\n",
    "    for spine in ('top', 'right'):\n",
    "        ax.spines[spine].set_visible(False)\n",
    "    for axis, side in ((ax.yaxis, 'left'), (ax.xaxis, 'bottom')):\n",
    "        axis.set_ticks_position(side)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6af5df1-ec37-4633-a8b0-f052cba646e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-covariate standard deviations; the coefficient standardization below rescales by these.\n",
    "df[X1lab + X2lab].std(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5381cb8a-c71b-4285-a392-25d2da0546b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display names for the stage-1 (X1lab) and stage-2 (X2lab) covariates,\n",
    "# presumably in the same order as X1lab / X2lab -- verify against their definitions.\n",
    "X1lab_pretty = ['Age1', 'Weight1', 'Temp1', 'Glucose1', 'BUN1',\n",
    "                'Creatinine1', 'WBC1', 'SOFA1']\n",
    "X2lab_pretty = ['Weight2', 'Temp2', 'Glucose2', 'BUN2',\n",
    "                'Creatinine2', 'WBC2', 'SOFA2']\n",
    "# The coefficient plot pairs these labels with coefficients position by position.\n",
    "assert len(X1lab_pretty) == len(X1lab) and len(X2lab_pretty) == len(X2lab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f06b5c35-415d-462d-a401-efc416e6945c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-express the full-data DTR on standardized covariates z = (x - mean) / std:\n",
    "#   b0 + x @ b == (b0 + mean @ b) + z @ (b * std)\n",
    "# dtr_all.x layout: [stage-1 intercept, X1lab coefs, stage-2 intercept, X1lab+X2lab coefs].\n",
    "n_x1_coef = len(X1lab) + 1  # stage-1 block size; slice by covariates, not display labels\n",
    "dtr_all_raw1, dtr_all_raw2 = dtr_all.x[:n_x1_coef], dtr_all.x[n_x1_coef:]\n",
    "X1_df, X12_df = df[X1lab], df[X1lab+X2lab]\n",
    "\n",
    "intp1 = dtr_all_raw1[0] + X1_df.mean(0) @ dtr_all_raw1[1:]\n",
    "intp2 = dtr_all_raw2[0] + X12_df.mean(0) @ dtr_all_raw2[1:]\n",
    "\n",
    "dtr_all1 = np.concatenate(([intp1], dtr_all_raw1[1:] * X1_df.std(0).to_numpy()))\n",
    "dtr_all2 = np.concatenate(([intp2], dtr_all_raw2[1:] * X12_df.std(0).to_numpy()))\n",
    "\n",
    "# Scale each stage's coefficient vector to unit L2 norm (signs are unchanged).\n",
    "dtr_all1 /= np.linalg.norm(dtr_all1, 2)\n",
    "dtr_all2 /= np.linalg.norm(dtr_all2, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b58800a8-15ec-4574-ae21-5591fae8cff1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standardized, unit-norm coefficients of the full-data DTR. Both panels share the same\n",
    "# x positions; the stage-1 bars occupy the first 1 + len(X1lab_pretty) slots.\n",
    "coef_labels = ['Intercept'] + X1lab_pretty + X2lab_pretty\n",
    "n_coef = len(coef_labels)  # derived instead of the former hard-coded 15 + 1\n",
    "assert len(dtr_all2) == n_coef, 'tick labels must match the stage-2 coefficients'\n",
    "\n",
    "plt.figure(figsize=(4,2.3), dpi=150)\n",
    "\n",
    "stylize_axes(plt.subplot(211))\n",
    "plt.bar(np.arange(1+len(X1lab_pretty)), dtr_all1,\n",
    "        width=.35, label=r'$\\widehat{\\beta}_1$')\n",
    "plt.axhline(0, c='k', lw=.25)\n",
    "plt.xticks(range(n_coef), [], rotation=90)  # tick labels only on the bottom panel\n",
    "plt.xlim(-.5, n_coef - .5)\n",
    "plt.grid(alpha=.2)\n",
    "plt.ylabel(r\"$\\widehat{\\beta}_1$\")\n",
    "\n",
    "stylize_axes(plt.subplot(212))\n",
    "plt.bar(np.arange(n_coef), dtr_all2,\n",
    "        width=.35, label=r'$\\widehat{\\beta}_2$', color='C1')\n",
    "plt.xticks(range(n_coef), coef_labels, rotation=90)\n",
    "plt.xlim(-.5, n_coef - .5)\n",
    "plt.grid(alpha=.2)\n",
    "plt.axhline(0, c='k', lw=.25)\n",
    "plt.ylabel(r\"$\\widehat{\\beta}_2$\")\n",
    "\n",
    "plt.xlabel(\"Standardized Variables\");\n",
    "plt.suptitle(' ');  # NOTE(review): blank suptitle, presumably to pad the top margin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21d09b2b-1b06-4569-b910-ee95c642ec6e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
