{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# IHDP: RieszNet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CumqS__UdZQj"
   },
   "source": [
    "## Library Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 851,
     "status": "ok",
     "timestamp": 1626014142199,
     "user": {
      "displayName": "Víctor Quintas Martínez",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgzDmne50Fc-Rgd6ii0jOAYJv9OzNuPlF4x0TOY2g=s64",
      "userId": "01033527572468555224"
     },
     "user_tz": 240
    },
    "id": "Tb0UNXId8Sbh"
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import os\n",
    "import glob\n",
    "from joblib import dump, load\n",
    "import pandas as pd\n",
    "import scipy\n",
    "import scipy.stats\n",
    "import scipy.special\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.model_selection import train_test_split\n",
    "from utils.riesznet import RieszNet\n",
    "from utils.moments import ate_moment_fn\n",
    "from utils.ihdp_data import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Moment Definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "moment_fn = ate_moment_fn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "u-OfQSQislV8"
   },
   "source": [
    "## MAE Experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 1030,
     "status": "ok",
     "timestamp": 1626014143226,
     "user": {
      "displayName": "Víctor Quintas Martínez",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgzDmne50Fc-Rgd6ii0jOAYJv9OzNuPlF4x0TOY2g=s64",
      "userId": "01033527572468555224"
     },
     "user_tz": 240
    },
    "id": "QKiyRUJTxMBe"
   },
   "outputs": [],
   "source": [
    "data_base_dir = \"./data/IHDP/sim_data\"\n",
    "simulation_files = sorted(glob.glob(\"{}/*.csv\".format(data_base_dir)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ytutDbXdY6ch"
   },
   "source": [
    "### Estimator Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 12,
     "status": "ok",
     "timestamp": 1626014143228,
     "user": {
      "displayName": "Víctor Quintas Martínez",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgzDmne50Fc-Rgd6ii0jOAYJv9OzNuPlF4x0TOY2g=s64",
      "userId": "01033527572468555224"
     },
     "user_tz": 240
    },
    "id": "Xzye0tURZ4Mc",
    "outputId": "95a7f986-3019-44c9-c5ef-71aa1c147105"
   },
   "outputs": [],
   "source": [
    "drop_prob = 0.0  # dropout prob of dropout layers throughout notebook\n",
    "n_hidden = 100  # width of hidden layers throughout notebook\n",
    "\n",
    "# Training params\n",
    "learner_lr = 1e-5\n",
    "learner_l2 = 1e-3\n",
    "learner_l1 = 0.0\n",
    "n_epochs = 600\n",
    "earlystop_rounds = 40 # how many epochs to wait for an out-of-sample improvement\n",
    "earlystop_delta = 1e-4\n",
    "target_reg = 1.0\n",
    "riesz_weight = 0.1\n",
    "\n",
    "bs = 64\n",
    "device = torch.cuda.current_device() if torch.cuda.is_available() else None\n",
    "print(\"GPU:\", torch.cuda.is_available())\n",
    "\n",
    "from itertools import chain, combinations\n",
    "from itertools import combinations_with_replacement as combinations_w_r\n",
    "\n",
    "def _combinations(n_features, degree, interaction_only):\n",
    "        comb = (combinations if interaction_only else combinations_w_r)\n",
    "        return chain.from_iterable(comb(range(n_features), i)\n",
    "                                   for i in range(0, degree + 1))\n",
    "\n",
    "class Learner(nn.Module):\n",
    "\n",
    "    def __init__(self, n_t, n_hidden, p, degree, interaction_only=False):\n",
    "        super().__init__()\n",
    "        n_common = 200\n",
    "        self.monomials = list(_combinations(n_t, degree, interaction_only))\n",
    "        self.common = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_t, n_common), nn.ELU(),\n",
    "                                    nn.Dropout(p=p), nn.Linear(n_common, n_common), nn.ELU(),\n",
    "                                    nn.Dropout(p=p), nn.Linear(n_common, n_common), nn.ELU())\n",
    "        self.riesz_nn = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, 1))\n",
    "        self.riesz_poly = nn.Sequential(nn.Linear(len(self.monomials), 1))\n",
    "        self.reg_nn0 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, n_hidden), nn.ELU(),\n",
    "                                    nn.Dropout(p=p), nn.Linear(n_hidden, n_hidden), nn.ELU(),\n",
    "                                    nn.Dropout(p=p), nn.Linear(n_hidden, 1))\n",
    "        self.reg_nn1 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, n_hidden), nn.ELU(),\n",
    "                                    nn.Dropout(p=p), nn.Linear(n_hidden, n_hidden), nn.ELU(),\n",
    "                                    nn.Dropout(p=p), nn.Linear(n_hidden, 1))\n",
    "        self.reg_poly = nn.Sequential(nn.Linear(len(self.monomials), 1))\n",
    "\n",
    "\n",
    "    def forward(self, x):\n",
    "        poly = torch.cat([torch.prod(x[:, t], dim=1, keepdim=True)\n",
    "                          for t in self.monomials], dim=1)\n",
    "        feats = self.common(x)\n",
    "        riesz = self.riesz_nn(feats) + self.riesz_poly(poly)\n",
    "        reg = self.reg_nn0(feats) * (1 - x[:, [0]]) + self.reg_nn1(feats) * x[:, [0]] + self.reg_poly(poly)\n",
    "        return torch.cat([reg, riesz], dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nsims = 1000\n",
    "np.random.seed(123)\n",
    "sim_ids = np.random.choice(len(simulation_files), nsims, replace=False)\n",
    "methods = ['dr', 'direct', 'ips']\n",
    "srr = {'dr' : True, 'direct' : False, 'ips' : True}\n",
    "\n",
    "true_ATEs = []\n",
    "results = []\n",
    "\n",
    "for it, sim in enumerate(sim_ids):\n",
    "    simulation_file = simulation_files[sim]\n",
    "    x = load_and_format_covariates(simulation_file, delimiter=' ')\n",
    "    t, y, y_cf, mu_0, mu_1 = load_other_stuff(simulation_file, delimiter=' ')\n",
    "    X = np.c_[t, x]\n",
    "    true_ATE = np.mean(mu_1 - mu_0)\n",
    "    true_ATEs.append(true_ATE)\n",
    "\n",
    "    y_scaler = StandardScaler(with_mean=True).fit(y)\n",
    "    y = y_scaler.transform(y)\n",
    "    \n",
    "    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2)\n",
    "\n",
    "    torch.cuda.empty_cache()\n",
    "    learner = Learner(X_train.shape[1], n_hidden, drop_prob, 0, interaction_only=True)\n",
    "    agmm = RieszNet(learner, moment_fn)\n",
    "    # Fast training\n",
    "    agmm.fit(X_train, y_train, Xval=X_test, yval=y_test,\n",
    "             earlystop_rounds=2, earlystop_delta=earlystop_delta,\n",
    "             learner_lr=1e-4, learner_l2=learner_l2, learner_l1=learner_l1,\n",
    "             n_epochs=100, bs=bs, target_reg=target_reg,\n",
    "             riesz_weight=riesz_weight, optimizer='adam',\n",
    "             model_dir=str(Path.home()), device=device, verbose=0)\n",
    "    # Fine tune\n",
    "    agmm.fit(X_train, y_train, Xval=X_test, yval=y_test,\n",
    "             earlystop_rounds=earlystop_rounds, earlystop_delta=earlystop_delta,\n",
    "             learner_lr=learner_lr, learner_l2=learner_l2, learner_l1=learner_l1,\n",
    "             n_epochs=600, bs=bs, target_reg=target_reg,\n",
    "             riesz_weight=riesz_weight, optimizer='adam', warm_start=True,\n",
    "             model_dir=str(Path.home()), device=device, verbose=0)\n",
    "    \n",
    "    params = tuple(x * y_scaler.scale_[0] for method in methods\n",
    "                   for x in agmm.predict_avg_moment(X, y,  model='earlystop', method = method, srr = srr[method])) + (true_ATE, )\n",
    "                        \n",
    "    results.append(params)\n",
    "\n",
    "res = tuple(np.array(x) for x in zip(*results))\n",
    "truth = res[-1:]\n",
    "res_dict = {}\n",
    "for it, method in enumerate(methods):\n",
    "    point, lb, ub = res[it * 3: (it + 1)*3]\n",
    "    res_dict[method] = {'point': point, 'lb': lb, 'ub': ub,\n",
    "                        'MAE': np.mean(np.abs(point - truth)),\n",
    "                        'std. err.': np.std(np.abs(point - truth)) / np.sqrt(nsims),\n",
    "                        }\n",
    "    print(\"{} : MAE = {:.3f} +/- {:.3f}\".format(method, res_dict[method]['MAE'], res_dict[method]['std. err.']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = './results/IHDP/RieszNet/MAE'\n",
    "\n",
    "if not os.path.exists(path):\n",
    "    os.makedirs(path)\n",
    "            \n",
    "dump(res_dict, path + '/IHDP_MAE_NN.joblib')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = './results/IHDP/RieszNet/MAE'\n",
    "\n",
    "if not os.path.exists(path):\n",
    "    os.makedirs(path)\n",
    "\n",
    "methods_str = [\"DR\", \"Direct\", \"IPS\"] \n",
    "\n",
    "with open(path + '/IHDP_MAE_NN.tex', \"w\") as f:\n",
    "    f.write(\"\\\\begin{tabular}{lc} \\n\" +\n",
    "            \"\\\\toprule \\n\" +\n",
    "            \"& MAE $\\\\pm$ std. err. \\\\\\\\ \\n\" +\n",
    "            \"\\\\midrule \\n\")\n",
    "    \n",
    "    for i, method in enumerate(methods):\n",
    "        f.write(\" & \".join([methods_str[i], \"{:.3f} $\\\\pm$ {:.3f}\".format(res_dict[method]['MAE'], \n",
    "                                                                          res_dict[method]['std. err.'])]) + \" \\n\")\n",
    "\n",
    "    f.write(\"\\\\bottomrule \\n \\\\end{tabular}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Coverage Experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_base_dir = \"./data/IHDP/sim_data_redraw_T\"\n",
    "simulation_files = sorted(glob.glob(\"{}/*.csv\".format(data_base_dir)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rmse_fn(y_pred, y_true):\n",
    "    return np.sqrt(np.mean((y_pred - y_true)**2))\n",
    "\n",
    "nsims = 100\n",
    "np.random.seed(123)\n",
    "sim_ids = np.random.choice(len(simulation_files), nsims, replace=False)\n",
    "methods = ['dr', 'direct', 'ips']\n",
    "srr = {'dr' : True, 'direct' : False, 'ips' : True}\n",
    "\n",
    "true_ATEs = []\n",
    "results = []\n",
    "\n",
    "for it, sim in enumerate(sim_ids):\n",
    "    simulation_file = simulation_files[sim]\n",
    "    x = load_and_format_covariates(simulation_file, delimiter=' ')\n",
    "    t, y, y_cf, mu_0, mu_1 = load_other_stuff(simulation_file, delimiter=' ')\n",
    "    X = np.c_[t, x]\n",
    "    true_ATE = np.mean(mu_1 - mu_0)\n",
    "    true_ATEs.append(true_ATE)\n",
    "\n",
    "    y_scaler = StandardScaler(with_mean=True).fit(y)\n",
    "    y = y_scaler.transform(y)\n",
    "    \n",
    "    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2)\n",
    "\n",
    "    torch.cuda.empty_cache()\n",
    "    learner = Learner(X_train.shape[1], n_hidden, drop_prob, 0, interaction_only=True)\n",
    "    agmm = RieszNet(learner, moment_fn)\n",
    "    # Fast training\n",
    "    agmm.fit(X_train, y_train, Xval=X_test, yval=y_test,\n",
    "             earlystop_rounds=2, earlystop_delta=earlystop_delta,\n",
    "             learner_lr=1e-4, learner_l2=learner_l2, learner_l1=learner_l1,\n",
    "             n_epochs=100, bs=bs, target_reg=target_reg,\n",
    "             riesz_weight=riesz_weight, optimizer='adam',\n",
    "             model_dir=str(Path.home()), device=device, verbose=0)\n",
    "    # Fine tune\n",
    "    agmm.fit(X_train, y_train, Xval=X_test, yval=y_test,\n",
    "             earlystop_rounds=earlystop_rounds, earlystop_delta=earlystop_delta,\n",
    "             learner_lr=learner_lr, learner_l2=learner_l2, learner_l1=learner_l1,\n",
    "             n_epochs=600, bs=bs, target_reg=target_reg,\n",
    "             riesz_weight=riesz_weight, optimizer='adam', warm_start=True,\n",
    "             model_dir=str(Path.home()), device=device, verbose=0)\n",
    "    \n",
    "    params = tuple(x * y_scaler.scale_[0] for method in methods\n",
    "                   for x in agmm.predict_avg_moment(X, y,  model='earlystop', method = method, srr = srr[method])) + (true_ATE, )\n",
    "                        \n",
    "    results.append(params)\n",
    "                        \n",
    "res = tuple(np.array(x) for x in zip(*results))\n",
    "truth = res[-1:]\n",
    "res_dict = {}\n",
    "for it, method in enumerate(methods):\n",
    "    point, lb, ub = res[it * 3: (it + 1)*3]\n",
    "    res_dict[method] = {'point': point, 'lb': lb, 'ub': ub,\n",
    "                        'cov': np.mean(np.logical_and(truth >= lb, truth <= ub)),\n",
    "                        'bias': np.mean(point - truth),\n",
    "                        'rmse': rmse_fn(point, truth)\n",
    "                        }\n",
    "    print(\"{} : bias = {:.3f}, rmse = {:.3f}, cov = {:.3f}\".format(method, res_dict[method]['bias'], res_dict[method]['rmse'], res_dict[method]['cov']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = './results/IHDP/RieszNet/coverage'\n",
    "\n",
    "if not os.path.exists(path):\n",
    "    os.makedirs(path)\n",
    "    \n",
    "dump(res_dict, path + '/IHDP_coverage_NN.joblib')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Histogram"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = './results/IHDP/RieszNet/coverage'\n",
    "\n",
    "if not os.path.exists(path):\n",
    "    os.makedirs(path)\n",
    "    \n",
    "method_strs = [\"{}. Bias: {:.3f}, RMSE: {:.3f}, Coverage: {:.3f}\".format(method, d['bias'], d['rmse'], d['cov'])\n",
    "               for method, d in res_dict.items()]\n",
    "plt.title(\"\\n\".join(method_strs))\n",
    "for method, d in res_dict.items():\n",
    "    plt.hist(np.array(d['point']), alpha=.5, label=method)\n",
    "plt.axvline(x = np.mean(truth), label='true', color='red')\n",
    "plt.legend()\n",
    "plt.savefig(path + '/IHDP_coverage_NN.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "NNRiesz_IHDP.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
