{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "from datetime import datetime\n",
    "\n",
    "sys.path.append('..')\n",
    "\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "\n",
    "from data.twins.data import TwinsData\n",
    "\n",
    "\n",
    "from algorithm.general import PearsonConfounderTest, GConfounderTest\n",
    "from algorithm.kernel.test import KernelConfounderTest\n",
    "from algorithm.misc import PearsonPrognosticConfounderTest, KernelPrognosticConfounderTest\n",
    "\n",
    "\n",
    "from experiment.plot import set_mpl_default_settings, marker_dict, color_dict\n",
    "set_mpl_default_settings()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from econml.dml import LinearDML\n",
    "from sklearn.linear_model import LinearRegression\n",
    "\n",
    "def estimate_ate(data, covariates):\n",
    "    \"\"\"Estimate the average treatment effect of T on Y, adjusting for `covariates`.\n",
    "\n",
    "    Args:\n",
    "        data: DataFrame-like table with a treatment column 'T', an outcome column 'Y'\n",
    "            and the covariate columns.\n",
    "        covariates: names of the columns to adjust for; may be empty.\n",
    "\n",
    "    Returns:\n",
    "        The estimated effect of moving T from 0 to 1. With no covariates this is the\n",
    "        slope of an ordinary least-squares regression of Y on T; otherwise it comes from\n",
    "        LinearDML with linear nuisance models and the covariates passed as controls W.\n",
    "    \"\"\"\n",
    "    treatment, outcome = data['T'], data['Y']\n",
    "\n",
    "    if not len(covariates):\n",
    "        # Unadjusted estimate; sklearn expects a 2-D design matrix, hence the reshape.\n",
    "        ols = LinearRegression()\n",
    "        ols.fit(treatment.values.reshape(-1, 1), outcome.values)\n",
    "        return ols.coef_[0]\n",
    "\n",
    "    controls = data[covariates]\n",
    "    dml = LinearDML(model_y=LinearRegression(), model_t=LinearRegression(), discrete_treatment=False)\n",
    "    dml.fit(outcome, treatment, W=controls)\n",
    "    # effect() returns an array; its first entry is used as the ATE estimate.\n",
    "    return dml.effect(T0=0, T1=1)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def experiment(method, conf_strength, alpha, nbr_iter, resample_mechnisms, nbr_confounders=None, nbr_observed_confounders=0, compute_bias=True, max_tests=None):\n",
    "    \"\"\"Run `nbr_iter` repetitions of a confounder test on semi-synthetic Twins data.\n",
    "\n",
    "    Each repetition draws a new set of confounders, samples observational data with\n",
    "    confounding strength `conf_strength`, picks `nbr_observed_confounders` of the drawn\n",
    "    confounders at random as the adjustment set, and runs `method.test` at level `alpha`.\n",
    "\n",
    "    Args:\n",
    "        method: test object whose `.test(...)` returns a dict with 'pval' and 'threshold'\n",
    "            (plus the statistic under 'X' for the tests listed in `statistic_tests`).\n",
    "        conf_strength: confounding strength passed to `TwinsData.sample`.\n",
    "        alpha: significance level; a repetition counts as a detection when pval < alpha.\n",
    "        nbr_iter: number of repetitions, must be at least 1.\n",
    "        resample_mechnisms: forwarded to `TwinsData.sample` as `resample_mechanisms`\n",
    "            (the misspelled parameter name is kept so existing calls keep working).\n",
    "        nbr_confounders: forwarded to the `TwinsData` constructor.\n",
    "        nbr_observed_confounders: size of the adjustment set drawn in each repetition.\n",
    "        compute_bias: if False, every bias entry is the placeholder -1.\n",
    "        max_tests: forwarded to `method.test`.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (detection_rate, bias_mean, bias_std, stat_mean, stat_std, threshold), where\n",
    "        the bias is the absolute error of `estimate_ate` for each key of out['true_ate'] and\n",
    "        `threshold` comes from the last repetition's test result.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `nbr_iter` < 1.\n",
    "    \"\"\"\n",
    "    if nbr_iter < 1:\n",
    "        raise ValueError(f'nbr_iter must be at least 1, got {nbr_iter}')\n",
    "\n",
    "    # Tests that report their test statistic under the key 'X'; other tests record 0\n",
    "    statistic_tests = {'PearsonConfounderTest', 'KernelConfounderTest', 'KernelPrognosticConfounderTest', 'PearsonPrognosticConfounderTest'}\n",
    "\n",
    "    data_gen = TwinsData('../data/twins/files', nbr_confounders=nbr_confounders)\n",
    "    conf_rate = 0\n",
    "    bias = []\n",
    "    statistic = []\n",
    "    threshold = None\n",
    "    for _ in range(nbr_iter):\n",
    "\n",
    "        # sample data with a freshly drawn set of confounders\n",
    "        data_gen.randomly_select_confounders()\n",
    "        confounder_list = data_gen.covar_list\n",
    "        out = data_gen.sample(conf_strength, resample_mechanisms=resample_mechnisms, binary=False, nbr_changes=0)\n",
    "        data = out['observational_data']\n",
    "\n",
    "        # choose nbr_observed_confounders from confounder_list, without replacement\n",
    "        observed_covariates = list(np.random.choice(confounder_list, size=nbr_observed_confounders, replace=False))\n",
    "\n",
    "        # run test\n",
    "        test_res = method.test(data, observed_covariates=observed_covariates, max_tests=max_tests, alpha=alpha)\n",
    "        threshold = test_res['threshold']\n",
    "        statistic.append(test_res['X'] if type(method).__name__ in statistic_tests else 0)\n",
    "        if test_res['pval'] < alpha:\n",
    "            conf_rate += 1\n",
    "\n",
    "        # absolute bias of the ATE estimate, one entry per key of out['true_ate']\n",
    "        for e in out['true_ate']:\n",
    "            if compute_bias:\n",
    "                est_ate = estimate_ate(data[e], observed_covariates)\n",
    "                bias.append(np.abs(est_ate - out['true_ate'][e]))\n",
    "            else:\n",
    "                bias.append(-1)\n",
    "\n",
    "    return conf_rate/nbr_iter, np.mean(bias), np.std(bias), np.mean(statistic), np.std(statistic), threshold"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experiment\n",
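    "For each method, we vary\n",
    "- the confounding strength,\n",
    "- the number of observed confounders (the ones adjusted for),\n",
    "- the maximum number of tests combined in Fisher's method (`max_tests`),\n",
    "- the total number of confounders (`nbr_confounders`).\n",
    "\n",
    "The number of changes (sparsity) is not varied: `experiment` samples with `nbr_changes=0`."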
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "load = False  # True: skip the simulation and reload a previously saved results CSV\n",
    "\n",
    "# Experiment parameters\n",
    "methods = [KernelConfounderTest()]\n",
    "conf_strengths = [0, 0.25, 1, 2.5, 5]\n",
    "nbr_observed_confounders = [0, 1, 2, 3, 4, 5]\n",
    "max_tests = [1, 5, 50]\n",
    "alpha = 0.05\n",
    "nbr_iter = 50\n",
    "nbr_confounders = [3, 5]\n",
    "resample_mechanisms = True\n",
    "compute_bias = True\n",
    "\n",
    "# Seed the global NumPy RNG so the random choice of observed confounders is reproducible\n",
    "# (presumably TwinsData sampling also draws from this RNG -- TODO confirm)\n",
    "seed = 0\n",
    "np.random.seed(seed)\n",
    "\n",
    "# Get timestamp for experiment (used to name the results CSV and the figures)\n",
    "now = datetime.now()\n",
    "timestamp = now.strftime(\"%m%d%H%M\")\n",
    "print('Timestamp:', timestamp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure the output folders used by this and the plotting cells exist\n",
    "os.makedirs('results/figures', exist_ok=True)\n",
    "\n",
    "if not load:\n",
    "    # Column order of the saved results table\n",
    "    result_columns = ['detection_rate', 'bias_mean', 'bias_std', 'method', 'conf_strength',\n",
    "                      'nbr_observed_confounders', 'max_tests', 'resample_mechanisms', 'alpha',\n",
    "                      'nbr_iter', 'stat_mean', 'stat_std', 'threshold', 'nbr_confounders']\n",
    "    records = []\n",
    "\n",
    "    for method in methods:\n",
    "        method_name = type(method).__name__\n",
    "        for conf_strength in tqdm(conf_strengths, desc=method_name):\n",
    "            for n_obs_conf in nbr_observed_confounders:\n",
    "                for max_test in max_tests:\n",
    "                    for nbr_conf in nbr_confounders:\n",
    "                        # cannot adjust for more confounders than exist\n",
    "                        if nbr_conf < n_obs_conf:\n",
    "                            continue\n",
    "                        print('Method:', method_name, 'Confounder strength:', conf_strength, 'Nbr observed confounders:', n_obs_conf, 'Max tests:', max_test, 'Nbr confounders:', nbr_conf)\n",
    "                        conf_rate, bias_mean, bias_std, stat_mean, stat_std, threshold = experiment(\n",
    "                            method=method,\n",
    "                            conf_strength=conf_strength,\n",
    "                            alpha=alpha,\n",
    "                            nbr_iter=nbr_iter,\n",
    "                            resample_mechnisms=resample_mechanisms,\n",
    "                            nbr_confounders=nbr_conf,\n",
    "                            nbr_observed_confounders=n_obs_conf,\n",
    "                            compute_bias=compute_bias,\n",
    "                            max_tests=max_test)\n",
    "                        records.append({'detection_rate': conf_rate,\n",
    "                                        'bias_mean': bias_mean,\n",
    "                                        'bias_std': bias_std,\n",
    "                                        'method': method_name,\n",
    "                                        'conf_strength': conf_strength,\n",
    "                                        'nbr_observed_confounders': n_obs_conf,\n",
    "                                        'max_tests': max_test,\n",
    "                                        'resample_mechanisms': resample_mechanisms,\n",
    "                                        'alpha': alpha,\n",
    "                                        'nbr_iter': nbr_iter,\n",
    "                                        'stat_mean': stat_mean,\n",
    "                                        'stat_std': stat_std,\n",
    "                                        'threshold': threshold,\n",
    "                                        'nbr_confounders': nbr_conf})\n",
    "\n",
    "    # Save results as DataFrame and into csv file\n",
    "    df = pd.DataFrame(records, columns=result_columns)\n",
    "    df.to_csv(f'results/twins_{timestamp}.csv')\n",
    "\n",
    "else:\n",
    "    # Keep the timestamp a string: int() would drop the leading zero of months 01-09\n",
    "    # and the path would no longer match the saved file name\n",
    "    timestamp = \"11011016\"\n",
    "    path = f'results/twins_{timestamp}.csv'\n",
    "    # The CSV was saved with the DataFrame index as its first column; read it back as the index\n",
    "    df = pd.read_csv(path, index_col=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary statistics of the numeric result columns (count, mean, std, min, quartiles, max)\n",
    "df.describe(percentiles=[.25, .5, .75])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Detection rate as a function of confounding strength"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fix_nbr_observed_confounders = 0\n",
    "fix_max_tests = 50\n",
    "fix_nbr_conf = 5\n",
    "tmp_df = df[(df['nbr_observed_confounders'] == fix_nbr_observed_confounders)\n",
    "            & (df['max_tests'] == fix_max_tests)\n",
    "            & (df['nbr_confounders'] == fix_nbr_conf)]\n",
    "\n",
    "plt.figure(figsize=(6,4))\n",
    "for method in tmp_df.method.unique():\n",
    "    method_df = tmp_df[tmp_df.method == method]\n",
    "    rate = method_df.detection_rate.values\n",
    "    # +/- one binomial standard error; nbr_iter is read from the results so it stays correct for reloaded runs\n",
    "    se = np.sqrt(rate * (1 - rate) / method_df.nbr_iter.values)\n",
    "    plt.plot(method_df.conf_strength, rate, label=method, marker=marker_dict[method], color=color_dict[method])\n",
    "    plt.fill_between(method_df.conf_strength, rate - se, rate + se, alpha=0.2, color=color_dict[method])\n",
    "plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')\n",
    "plt.xlabel('Confounder strength')\n",
    "plt.ylabel('Detection rate')\n",
    "plt.title(f'Nbr observed confounders: {fix_nbr_observed_confounders}, Max tests: {fix_max_tests}, Nbr confounders: {fix_nbr_conf}')\n",
    "plt.ylim(-.1,1.1)\n",
    "\n",
    "plt.savefig(f'results/figures/twins_vary_conf_{timestamp}.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test statistic versus omitted-variable bias for the rows selected in tmp_df\n",
    "plt.figure(figsize=(6,4))\n",
    "for method in tmp_df.method.unique():\n",
    "    method_df = tmp_df[tmp_df.method == method]\n",
    "    plt.errorbar(method_df.bias_mean,\n",
    "                 method_df.stat_mean,\n",
    "                 xerr=method_df.bias_std,\n",
    "                 yerr=method_df.stat_std,\n",
    "                 capsize=4,\n",
    "                 capthick=2,\n",
    "                 linestyle='--',\n",
    "                 label=method,\n",
    "                 marker=marker_dict[method],\n",
    "                 color=color_dict[method])\n",
    "# Threshold reported by the test (first result row); presumably the rejection threshold at level alpha\n",
    "plt.axhline(tmp_df.threshold.values[0], color='black', linestyle='-.', alpha=0.9, label='Threshold')\n",
    "plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')\n",
    "plt.xlabel('Average omitted variable bias')\n",
    "plt.ylabel('Average test statistic')\n",
    "plt.savefig(f'results/figures/twins_vary_bias_{timestamp}.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Detection rate as a function of the number of observed confounders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fix_conf_strength = 5\n",
    "fix_max_tests = 50\n",
    "fix_nbr_conf = 3\n",
    "tmp_df = df[(df['conf_strength'] == fix_conf_strength)\n",
    "            & (df['max_tests'] == fix_max_tests)\n",
    "            & (df['nbr_confounders'] == fix_nbr_conf)]\n",
    "\n",
    "plt.figure(figsize=(6,4))\n",
    "for method in tmp_df.method.unique():\n",
    "    method_df = tmp_df[tmp_df.method == method]\n",
    "    rate = method_df.detection_rate.values\n",
    "    # +/- one binomial standard error, using the nbr_iter stored with each result row\n",
    "    se = np.sqrt(rate * (1 - rate) / method_df.nbr_iter.values)\n",
    "    plt.plot(method_df.nbr_observed_confounders, rate, label=method, marker=marker_dict[method], color=color_dict[method])\n",
    "    plt.fill_between(method_df.nbr_observed_confounders, rate - se, rate + se, alpha=0.2, color=color_dict[method])\n",
    "plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')\n",
    "plt.xlabel('Number of confounders adjusted for')\n",
    "plt.ylabel('Detection rate')\n",
    "plt.title(f'Conf strength: {fix_conf_strength}, Max tests: {fix_max_tests}, Nbr confounders: {fix_nbr_conf}')\n",
    "plt.ylim(-.1,1.1)\n",
    "\n",
    "plt.savefig(f'results/figures/twins_vary_obs_{timestamp}.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean test statistic (+/- one std) versus number of adjusted confounders, for the rows in tmp_df\n",
    "plt.figure(figsize=(6,4))\n",
    "for method in tmp_df.method.unique():\n",
    "    rows = tmp_df[tmp_df.method == method]\n",
    "    x = rows.nbr_observed_confounders\n",
    "    mean, spread = rows.stat_mean.values, rows.stat_std.values\n",
    "    plt.plot(x, rows.stat_mean, label=method, marker=marker_dict[method], color=color_dict[method])\n",
    "    plt.fill_between(x, mean - spread, mean + spread, alpha=0.2, color=color_dict[method])\n",
    "\n",
    "plt.xlabel('Number of confounders adjusted for')\n",
    "plt.ylabel('Average value of test statistic')\n",
    "plt.title(f'Conf strength: {fix_conf_strength}, Max tests: {fix_max_tests}')\n",
    "plt.axhline(tmp_df.threshold.values[0], color='black', linestyle='--', label='Reject at $\\\\alpha=0.05$')\n",
    "plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The ATE bias comes from estimate_ate, which does not use the test method,\n",
    "# so only the first method's curve is drawn.\n",
    "plt.figure(figsize=(6,4))\n",
    "for method in tmp_df.method.unique()[:1]:\n",
    "    method_df = tmp_df[tmp_df.method == method]\n",
    "    mean, spread = method_df.bias_mean.values, method_df.bias_std.values\n",
    "    # Draw the line in the same colour as its +/- one std band\n",
    "    plt.plot(method_df.nbr_observed_confounders, mean, color=color_dict[method])\n",
    "    plt.fill_between(method_df.nbr_observed_confounders, mean - spread, mean + spread, alpha=0.2, color=color_dict[method])\n",
    "plt.xlabel('Number of confounders adjusted for')\n",
    "plt.ylabel('Omitted variable bias')\n",
    "plt.title(f'Conf strength: {fix_conf_strength}, Max tests: {fix_max_tests}')\n",
    "plt.ylim(0,10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How the test depends on the number of tests combined in Fisher's method\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fix_conf_strength = 5\n",
    "fix_nbr_conf = 5\n",
    "fix_max_tests_list = [1,5,50]\n",
    "tmp_df = df[(df['conf_strength'] == fix_conf_strength) & (df['nbr_confounders'] == fix_nbr_conf)]\n",
    "\n",
    "plt.figure(figsize=(6,4))\n",
    "for n_c in fix_max_tests_list:\n",
    "    tmp2_df = tmp_df[tmp_df['max_tests'] == n_c]\n",
    "    for method in tmp2_df.method.unique():\n",
    "        # Build the method mask on tmp2_df itself so it is aligned with the rows being plotted\n",
    "        method_df = tmp2_df[tmp2_df.method == method]\n",
    "        mean, spread = method_df.stat_mean.values, method_df.stat_std.values\n",
    "        p = plt.plot(method_df.nbr_observed_confounders, mean, label=method + f' $n_c=${n_c}', marker=marker_dict[method])\n",
    "        line_color = p[-1].get_color()\n",
    "        # Dashed line in the same colour: the threshold reported by the test for this n_c\n",
    "        plt.plot(method_df.nbr_observed_confounders, method_df.threshold, linestyle='--', color=line_color)\n",
    "        plt.fill_between(method_df.nbr_observed_confounders, mean - spread, mean + spread, alpha=0.2, color=line_color)\n",
    "\n",
    "plt.xlabel('Number of confounders adjusted for')\n",
    "plt.ylabel('Average value of test statistic')\n",
    "# n_c varies within this figure, so the title lists only the fixed parameters\n",
    "plt.title(f'Conf strength: {fix_conf_strength}, Nbr confounders: {fix_nbr_conf}')\n",
    "plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Detection rate versus number of adjusted confounders for each n_c, using tmp_df and fix_max_tests_list\n",
    "plt.figure(figsize=(6,4))\n",
    "colors = ['green', 'blue', 'orange']\n",
    "for i, n_c in enumerate(fix_max_tests_list):\n",
    "    color = colors[i % len(colors)]\n",
    "    tmp2_df = tmp_df[tmp_df['max_tests'] == n_c]\n",
    "    for method in tmp2_df.method.unique():\n",
    "        # Build the method mask on tmp2_df itself so it is aligned with the rows being plotted\n",
    "        method_df = tmp2_df[tmp2_df.method == method]\n",
    "        rate = method_df.detection_rate.values\n",
    "        # +/- one binomial standard error, using the nbr_iter stored with each result row\n",
    "        se = np.sqrt(rate * (1 - rate) / method_df.nbr_iter.values)\n",
    "        plt.plot(method_df.nbr_observed_confounders, rate, label=f'$n_c=${n_c}', marker=marker_dict[method], color=color)\n",
    "        plt.fill_between(method_df.nbr_observed_confounders, rate - se, rate + se, alpha=0.2, color=color)\n",
    "plt.legend()\n",
    "plt.xlabel('Number of confounders adjusted for')\n",
    "plt.ylabel('Detection rate')\n",
    "plt.ylim(-.1,1.1)\n",
    "max_obs = tmp_df.nbr_observed_confounders.max()\n",
    "# Nominal significance level used in the run\n",
    "plt.hlines(tmp_df['alpha'].iloc[0], 0, max_obs, linestyle='--', color='black')\n",
    "plt.savefig(f'./results/figures/twins_obs_conf_{timestamp}_conf{max_obs}.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Influence of adjusting for more confounders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fix_conf_strength = 5\n",
    "fix_max_tests = 50\n",
    "nbr_confounder_list = [1,3,5]\n",
    "\n",
    "tmp_df = df[(df['conf_strength'] == fix_conf_strength) & (df['max_tests'] == fix_max_tests)]\n",
    "\n",
    "plt.figure(figsize=(6,4))\n",
    "for method in tmp_df.method.unique():\n",
    "    tmp2_df = tmp_df[tmp_df['method'] == method]\n",
    "    x_list, rate_list, std_list = [], [], []\n",
    "    for nbr_conf in nbr_confounder_list:\n",
    "        # Runs in which all nbr_conf confounders are adjusted for\n",
    "        tmp3_df = tmp2_df[(tmp2_df['nbr_confounders'] == nbr_conf) & (tmp2_df['nbr_observed_confounders'] == nbr_conf)]\n",
    "        if tmp3_df.empty:\n",
    "            # Not every entry of nbr_confounder_list need have been part of the run; skip missing ones\n",
    "            print(f'{method}: no results for nbr_confounders={nbr_conf}, skipping')\n",
    "            continue\n",
    "        rate = tmp3_df.detection_rate.values[0]\n",
    "        # one binomial standard error, using the nbr_iter stored with the result row\n",
    "        std = np.sqrt(rate * (1 - rate) / tmp3_df.nbr_iter.values[0])\n",
    "        x_list.append(nbr_conf)\n",
    "        rate_list.append(rate)\n",
    "        std_list.append(std)\n",
    "\n",
    "    print(method, rate_list, std_list)\n",
    "\n",
    "    plt.plot(x_list, rate_list, label=method + f' $n_c=${fix_max_tests}', marker=marker_dict[method], color=color_dict[method])\n",
    "    lb = [r - s for r, s in zip(rate_list, std_list)]\n",
    "    ub = [r + s for r, s in zip(rate_list, std_list)]\n",
    "    plt.fill_between(x_list, lb, ub, alpha=0.2, color=color_dict[method])\n",
    "\n",
    "plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')\n",
    "plt.xlabel('Number of confounders adjusted for')\n",
    "plt.ylabel('Detection rate')\n",
    "plt.title(f'Conf strength: {fix_conf_strength}, Max tests: {fix_max_tests}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.6 ('CI2')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "1119f8550b2138f5de574d3adbc9d9c628b005f552be6d04a225ae36781ad7c3"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
