{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 3168,
     "status": "ok",
     "timestamp": 1626715360964,
     "user": {
      "displayName": "Víctor Quintas Martínez",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgzDmne50Fc-Rgd6ii0jOAYJv9OzNuPlF4x0TOY2g=s64",
      "userId": "01033527572468555224"
     },
     "user_tz": 240
    },
    "id": "wrXy_r_iuXjH",
    "outputId": "c7b167fe-0678-4fc8-c771-3597b7337fa1"
   },
   "outputs": [],
   "source": [
    "# !pip install econml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 27,
     "status": "ok",
     "timestamp": 1626715360966,
     "user": {
      "displayName": "Víctor Quintas Martínez",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgzDmne50Fc-Rgd6ii0jOAYJv9OzNuPlF4x0TOY2g=s64",
      "userId": "01033527572468555224"
     },
     "user_tz": 240
    },
    "id": "G_qWGM2RvR3a"
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BHP: ForestRiesz"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Library Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1011,
     "status": "ok",
     "timestamp": 1626715362111,
     "user": {
      "displayName": "Víctor Quintas Martínez",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgzDmne50Fc-Rgd6ii0jOAYJv9OzNuPlF4x0TOY2g=s64",
      "userId": "01033527572468555224"
     },
     "user_tz": 240
    },
    "id": "tpLT3n3Vvhjb",
    "outputId": "99d93f81-f8fb-44f2-9e02-8f70936fbe2c"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import glob\n",
    "from joblib import dump, load\n",
    "import pandas as pd\n",
    "import scipy\n",
    "import scipy.stats\n",
    "import scipy.special\n",
    "import numpy as np\n",
    "import math\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "from sklearn.model_selection import cross_val_predict\n",
    "from utils.forestriesz import poly_feature_fns\n",
    "from utils.RF_avgmom_sim import sim_fun"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Moment Definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def moment_fn(X, test_fn, epsilon=0.01):\n",
    "    \"\"\"Central finite-difference moment for the average derivative in column 0.\n",
    "\n",
    "    Evaluates ``test_fn`` on two copies of ``X`` whose first column is shifted by\n",
    "    ``+epsilon`` and ``-epsilon`` (all other columns are left untouched) and\n",
    "    returns the symmetric difference quotient.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : 2-D array whose column 0 is the variable being perturbed.\n",
    "    test_fn : callable applied to arrays with the same column layout as ``X``.\n",
    "    epsilon : non-zero half-width of the finite-difference step. Defaults to 0.01,\n",
    "        the value that used to be hard-coded.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    ``(test_fn(X_plus) - test_fn(X_minus)) / (2 * epsilon)``\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError : if ``epsilon`` is zero (the quotient would be undefined).\n",
    "    \"\"\"\n",
    "    if epsilon == 0:\n",
    "        raise ValueError(\"epsilon must be non-zero\")\n",
    "    X_plus = np.hstack([X[:, [0]] + epsilon, X[:, 1:]])\n",
    "    X_minus = np.hstack([X[:, [0]] - epsilon, X[:, 1:]])\n",
    "    return (test_fn(X_plus) - test_fn(X_minus)) / (2 * epsilon)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4tc5D26DwbY4"
   },
   "source": [
    "## Settings for RF Estimators"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 32,
     "status": "ok",
     "timestamp": 1626715362113,
     "user": {
      "displayName": "Víctor Quintas Martínez",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgzDmne50Fc-Rgd6ii0jOAYJv9OzNuPlF4x0TOY2g=s64",
      "userId": "01033527572468555224"
     },
     "user_tz": 240
    },
    "id": "bg4cSbJmwRmH"
   },
   "outputs": [],
   "source": [
    "# Forest hyper-parameters shared verbatim by the three estimators below.\n",
    "# They are kept in a single dict so the three option sets cannot drift apart;\n",
    "# every value is immutable, so sharing the entries is safe.\n",
    "RF_common_opt = {\n",
    "    'l2': 1e-3,\n",
    "    'criterion': 'mse',\n",
    "    'n_estimators': 100,\n",
    "    'min_samples_leaf': 50,\n",
    "    'min_var_fraction_leaf': 0.1,\n",
    "    'min_var_leaf_on_val': True,\n",
    "    'min_impurity_decrease': 0.001,\n",
    "    'max_samples': .65,\n",
    "    'max_depth': None,\n",
    "    'warm_start': False,\n",
    "    'inference': False,\n",
    "    'subforest_size': 1,\n",
    "    'honest': True,\n",
    "    'verbose': 0,\n",
    "    'n_jobs': -1,\n",
    "    'random_state': 123\n",
    "}\n",
    "\n",
    "# RFRiesz Settings: regression features, Riesz features and the moment functional.\n",
    "# NOTE(review): the argument of poly_feature_fns is presumably a polynomial degree - confirm in utils.forestriesz.\n",
    "ForestRiesz_opt = {\n",
    "    'reg_feature_fns': poly_feature_fns(1),\n",
    "    'riesz_feature_fns': poly_feature_fns(3),\n",
    "    'moment_fn': moment_fn,\n",
    "    **RF_common_opt\n",
    "}\n",
    "\n",
    "# RFreg Settings: regression features only.\n",
    "RFreg_opt = {\n",
    "    'reg_feature_fns': poly_feature_fns(1),\n",
    "    **RF_common_opt\n",
    "}\n",
    "\n",
    "# RFrr Settings: Riesz features and the moment functional only.\n",
    "RFrr_opt = {\n",
    "    'riesz_feature_fns': poly_feature_fns(3),\n",
    "    'moment_fn': moment_fn,\n",
    "    **RF_common_opt\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5NE5F8z7Sjhq"
   },
   "source": [
    "## Read Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 138,
     "status": "ok",
     "timestamp": 1626715362224,
     "user": {
      "displayName": "Víctor Quintas Martínez",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgzDmne50Fc-Rgd6ii0jOAYJv9OzNuPlF4x0TOY2g=s64",
      "userId": "01033527572468555224"
     },
     "user_tz": 240
    },
    "id": "0Vt_gLC8viPt"
   },
   "outputs": [],
   "source": [
    "# Load the BHP data and keep only rows above both thresholds.\n",
    "# NOTE(review): log_p / log_y are presumably log price and log income - confirm against the data dictionary.\n",
    "df = pd.read_csv('./data/BHP/data_BHP2.csv')\n",
    "df = df[df[\"log_p\"] > math.log(1.2)]\n",
    "df = df[df[\"log_y\"] > math.log(15000)]\n",
    "\n",
    "# Everything except the first column of the file.\n",
    "Xdf = df.iloc[:,1:]\n",
    "X_nostatedum = Xdf.drop([\"distance_oil1000\", \"share\"], axis=1).values\n",
    "columns = Xdf.columns\n",
    "\n",
    "# One indicator column per state. dtype=float is requested explicitly: since pandas 2.0\n",
    "# get_dummies returns bool columns by default, and mixing bool with float columns makes\n",
    "# `.values` fall back to an object array, which breaks numeric ufuncs applied to W later.\n",
    "state_dum = pd.get_dummies(Xdf['state_fips'], prefix=\"state\", dtype=float)\n",
    "Xdf = pd.concat([Xdf, state_dum], axis = 1)\n",
    "Xdf = Xdf.drop([\"distance_oil1000\", \"state_fips\", \"share\"], axis=1)\n",
    "\n",
    "# W: all remaining columns except log_p; T: the log_p column.\n",
    "W = Xdf.drop([\"log_p\"], axis=1).values\n",
    "T = Xdf['log_p'].values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-DregF_xfSXp"
   },
   "source": [
    "## Generate Semi-Synthetic Treatment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 16440,
     "status": "ok",
     "timestamp": 1626715378660,
     "user": {
      "displayName": "Víctor Quintas Martínez",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgzDmne50Fc-Rgd6ii0jOAYJv9OzNuPlF4x0TOY2g=s64",
      "userId": "01033527572468555224"
     },
     "user_tz": 240
    },
    "id": "7tYXj8zufSEn",
    "outputId": "0af4e0c2-bfef-47ac-e64a-33fa53d94e06"
   },
   "outputs": [],
   "source": [
    "# Conditional mean model: random forest regression of T on W.\n",
    "mu_T = RandomForestRegressor(n_estimators=100, min_samples_leaf=50, random_state=123).fit(W, T)\n",
    "\n",
    "# Residuals of T around cross-validated predictions from the same forest specification.\n",
    "e_T = T - cross_val_predict(mu_T, W, T)\n",
    "\n",
    "# Conditional variance model: shallower forest fitted to the squared residuals.\n",
    "sigma2_T = RandomForestRegressor(n_estimators=100, min_samples_leaf=50, max_depth=5, random_state=123)\n",
    "sigma2_T.fit(W, np.square(e_T))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 34,
     "status": "ok",
     "timestamp": 1626715378662,
     "user": {
      "displayName": "Víctor Quintas Martínez",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgzDmne50Fc-Rgd6ii0jOAYJv9OzNuPlF4x0TOY2g=s64",
      "userId": "01033527572468555224"
     },
     "user_tz": 240
    },
    "id": "jH17wYhSlU65"
   },
   "outputs": [],
   "source": [
    "def gen_T(W):\n",
    "    \"\"\"Draw one synthetic treatment per row of W: T ~ N(mu(W), sigma^2(W)).\n",
    "\n",
    "    The mean comes from ``mu_T.predict`` and the variance from ``sigma2_T.predict``;\n",
    "    the standard-normal noise is drawn from NumPy's global random state.\n",
    "    Returns a column vector of shape (n, 1).\n",
    "    \"\"\"\n",
    "    n_obs = W.shape[0]\n",
    "    cond_mean = mu_T.predict(W)\n",
    "    cond_sd = np.sqrt(sigma2_T.predict(W))\n",
    "    noise = np.random.normal(size=(n_obs,))\n",
    "    return (cond_mean + cond_sd * noise).reshape(-1, 1)\n",
    "\n",
    "def true_rr(X):\n",
    "    \"\"\"Return (T - mu(W)) / sigma^2(W) for each row of X = [T, W].\n",
    "\n",
    "    Column 0 of X is the treatment and the remaining columns are the covariates fed\n",
    "    to ``mu_T`` and ``sigma2_T``. For the Gaussian treatment drawn by ``gen_T`` this\n",
    "    ratio is the negative derivative of the log conditional density in T.\n",
    "    \"\"\"\n",
    "    treatment, covariates = X[:, 0], X[:, 1:]\n",
    "    residual = treatment - mu_T.predict(covariates)\n",
    "    return residual / sigma2_T.predict(covariates)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run Simulations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "5HWQ_81xplDM",
    "outputId": "8b567dbb-44c0-4250-89b0-758a0141606c"
   },
   "outputs": [],
   "source": [
    "for i in range(10):\n",
    "    # Re-seed per replication so the coefficient draws below are reproducible.\n",
    "    np.random.seed(i)\n",
    "\n",
    "    # Random coefficients of the linear confounding terms (b: 20 columns, c: 8 columns).\n",
    "    b = np.random.uniform(-0.5, 0.5, size=(20, 1))\n",
    "    c = np.random.uniform(-0.2, 0.2, size=(8, 1))\n",
    "\n",
    "    def nonlin(X):\n",
    "        \"\"\"Non-linear confounding term: scaled logistic functions of columns 6 and 8.\"\"\"\n",
    "        return 1.5*scipy.special.expit(10 * X[:, 6]) + 1.5*scipy.special.expit(10 * X[:, 8])\n",
    "\n",
    "    def true_f_simple(X):\n",
    "        \"\"\"Linear in the treatment (column 0), no confounding.\"\"\"\n",
    "        return -0.6 * X[:, 0]\n",
    "\n",
    "    def true_f_simple_lin_conf(X):\n",
    "        \"\"\"true_f_simple plus a linear term in columns 1..20 with coefficients b.\"\"\"\n",
    "        return true_f_simple(X) + np.matmul(X[:, 1:21], b).flatten()\n",
    "\n",
    "    def true_f_simple_nonlin_conf(X):\n",
    "        \"\"\"true_f_simple_lin_conf plus the non-linear term.\"\"\"\n",
    "        return true_f_simple_lin_conf(X) + nonlin(X)\n",
    "\n",
    "    def true_f_compl(X):\n",
    "        \"\"\"Cubic in the treatment, with a slope that depends on column 1.\"\"\"\n",
    "        return -0.5 * (X[:, 1]**2/10 + .5) * X[:, 0]**3 / 3\n",
    "\n",
    "    def true_f_compl_lin_conf(X):\n",
    "        \"\"\"Cubic design whose slope also depends on columns 1..8 (via c), plus the linear term in b.\"\"\"\n",
    "        return -0.5 * (X[:, 1]**2/10 + np.matmul(X[:, 1:9], c).flatten() + .5) * X[:, 0]**3 / 3 + np.matmul(X[:, 1:21], b).flatten()\n",
    "\n",
    "    def true_f_compl_nonlin_conf(X):\n",
    "        \"\"\"true_f_compl_lin_conf plus the non-linear term.\"\"\"\n",
    "        return true_f_compl_lin_conf(X) + nonlin(X)\n",
    "\n",
    "    for true_f in [true_f_simple, true_f_simple_lin_conf, true_f_simple_nonlin_conf,\n",
    "                   true_f_compl, true_f_compl_lin_conf, true_f_compl_nonlin_conf]:\n",
    "        print(\"Now trying \" + true_f.__name__)\n",
    "\n",
    "        def gen_y(X):\n",
    "            \"\"\"Outcome draw: true_f(X) plus Gaussian noise with variance 5.6 * Var(true_f(X)).\"\"\"\n",
    "            n = X.shape[0]\n",
    "            f_X = true_f(X)  # evaluated once and reused for both the signal and the noise scale\n",
    "            return f_X + np.random.normal(0, np.sqrt(5.6 * np.var(f_X)), size = (n,))\n",
    "\n",
    "        # Results are grouped in one directory per design, named after the function.\n",
    "        path = './results/BHP/ForestRiesz/' + true_f.__name__\n",
    "\n",
    "        # exist_ok replaces the separate existence check, which could race with another process.\n",
    "        os.makedirs(path, exist_ok=True)\n",
    "\n",
    "        for xfit, mult in [(0, True), (0, False), (1, True), (1, False), (2, False)]:\n",
    "            # Saved results and the plot share the same file-name stem.\n",
    "            stem = path + '/xfit_' + str(xfit) + '_mult_' + str(int(mult)) + '_seed_' + str(i)\n",
    "            namedata = stem + '.joblib'\n",
    "            nameplot = stem + '.pdf'\n",
    "            sim_fun(W, moment_fn = moment_fn, true_reg = true_f, true_rr = true_rr, gen_y = gen_y, gen_T = gen_T, N_sim = 100,\n",
    "                    xfit = xfit, multitasking = mult, ForestRiesz_opt = ForestRiesz_opt, RFreg_opt = RFreg_opt, RFrr_opt = RFrr_opt,\n",
    "                    seed = i, verbose = 0, plot = True, save = namedata, saveplot = nameplot)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary Outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### LaTeX Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Captions of the six row groups, in the same order as `true_fs`.\n",
    "f_string = [\n",
    "    \"1. Simple $f$\",\n",
    "    \"2. Simple $f$ with linear confound.\",\n",
    "    \"3. Simple $f$ with linear and non-linear confound.\",\n",
    "    \"4. Complex $f$\",\n",
    "    \"5. Complex $f$ with linear confound.\",\n",
    "    \"6. Complex $f$ with linear and non-linear confound.\",\n",
    "]\n",
    "\n",
    "# Result sub-directories (one per outcome design) and the estimators reported.\n",
    "true_fs = [\n",
    "    'true_f_simple', 'true_f_simple_lin_conf', 'true_f_simple_nonlin_conf',\n",
    "    'true_f_compl', 'true_f_compl_lin_conf', 'true_f_compl_nonlin_conf',\n",
    "]\n",
    "\n",
    "methods = ['reg', 'ips', 'dr', 'tmle']\n",
    "\n",
    "with open(\"./results/BHP/ForestRiesz/res_avg_der_RF.tex\", \"w\") as f:\n",
    "    # Table preamble and column headers.\n",
    "    f.write(\"\\\\begin{tabular}{*{16}{r}} \\n\" +\n",
    "            \"\\\\toprule \\n\" +\n",
    "            \"&&&& \\\\multicolumn{3}{c}{Direct} & \\\\multicolumn{3}{c}{IPS} & \\\\multicolumn{3}{c}{DR} & \\\\multicolumn{3}{c}{TMLE-DR} \\\\\\\\ \\n\" +\n",
    "            \"\\\\cmidrule(lr){5-7} \\\\cmidrule(lr){8-10} \\\\cmidrule(lr){11-13} \\\\cmidrule(lr){14-16} \\n\" +\n",
    "            \"x-fit & multit. & reg $R^2$ &  rr $R^2$ &  Bias &  RMSE &  Cov. &  Bias &  RMSE &  Cov. &  Bias &  RMSE &  Cov. &  Bias &  RMSE &  Cov. \\\\\\\\ \\n\" +\n",
    "            \"\\\\midrule \\n\")\n",
    "\n",
    "    for caption, true_f in zip(f_string, true_fs):\n",
    "        path = './results/BHP/ForestRiesz/' + true_f\n",
    "        f.write(\"\\\\addlinespace \\n \\\\multicolumn{16}{l}{\\\\textbf{\" + caption + \"}} \\\\\\\\ \\n\")\n",
    "\n",
    "        for xfit, mult in [(0, True), (0, False), (1, True), (1, False), (2, False)]:\n",
    "            # First two cells of the row: cross-fitting setting and multitasking flag.\n",
    "            mult_str = (\"No\", \"Yes\")[bool(mult)]\n",
    "            f.write(\" & \".join([str(xfit), mult_str]) + \" & \")\n",
    "\n",
    "            # Accumulators over the 10 seeds.\n",
    "            r2_reg, r2_rr = np.array([]), np.array([])\n",
    "            res = {method: {'bias': [], 'rmse': [], 'cov': []} for method in methods}\n",
    "\n",
    "            for i in range(10):\n",
    "                namedata = path + '/xfit_' + str(xfit) + '_mult_' + str(int(mult)) + '_seed_' + str(i) + '.joblib'\n",
    "                loaded = load(namedata)\n",
    "                # Positions 2 and 4 hold the R^2 of the regression and of the Riesz representer;\n",
    "                # position 0 holds the per-method results.\n",
    "                r2_reg = np.append(r2_reg, loaded[2])\n",
    "                r2_rr = np.append(r2_rr, loaded[4])\n",
    "\n",
    "                for method in methods:\n",
    "                    for stat in ('bias', 'rmse', 'cov'):\n",
    "                        res[method][stat].append(loaded[0][method][stat])\n",
    "\n",
    "            # Seed-averaged nuisance R^2 cells ...\n",
    "            r2_cells = [\"${:.3f}$\".format(np.mean(x)) for x in [r2_reg, r2_rr]]\n",
    "            f.write(\" & \".join(r2_cells) + \" & \")\n",
    "            # ... followed by bias / RMSE / coverage for each method, ending the row.\n",
    "            metric_cells = [\"${:.3f}$\".format(np.mean(res[method][x])) for method in methods\n",
    "                            for x in ['bias', 'rmse', 'cov']]\n",
    "            f.write(\" & \".join(metric_cells) + \" \\\\\\\\ \\n\")\n",
    "\n",
    "    f.write(\"\\\\bottomrule \\n \\\\end{tabular}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Histograms over 10 Seeds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for true_f in true_fs:\n",
    "\n",
    "    path = './results/BHP/ForestRiesz/' + true_f\n",
    "\n",
    "    for xfit, mult in [(0, True), (0, False), (1, True), (1, False), (2, False)]:\n",
    "        # Accumulators over the 10 seeds for this (design, xfit, multitasking) configuration.\n",
    "        rmse_reg, r2_reg, rmse_rr, r2_rr, ipsbias, drbias, truth = [], [], [], [], [], [], []\n",
    "        res = {}\n",
    "\n",
    "        for method in methods:\n",
    "            res[method] = {'point' : [], 'bias': [], 'rmse': [], 'cov': []}\n",
    "\n",
    "        for i in range(10):\n",
    "            namedata = path + '/xfit_' + str(xfit) + '_mult_' + str(int(mult)) + '_seed_' + str(i) + '.joblib'\n",
    "            loaded = load(namedata)\n",
    "            # NOTE(review): positions 1..7 are read as saved by sim_fun - confirm the layout in utils.RF_avgmom_sim.\n",
    "            rmse_reg = np.append(rmse_reg, loaded[1])\n",
    "            r2_reg = np.append(r2_reg, loaded[2])\n",
    "            rmse_rr = np.append(rmse_rr, loaded[3])\n",
    "            r2_rr = np.append(r2_rr, loaded[4])\n",
    "            ipsbias = np.append(ipsbias, loaded[5])\n",
    "            drbias = np.append(drbias, loaded[6])\n",
    "            truth = np.append(truth, loaded[7])\n",
    "\n",
    "            for method in methods:\n",
    "                res[method]['point'] = np.append(res[method]['point'], loaded[0][method]['point'])\n",
    "                res[method]['bias'].append(loaded[0][method]['bias'])\n",
    "                res[method]['rmse'].append(loaded[0][method]['rmse'])\n",
    "                res[method]['cov'].append(loaded[0][method]['cov'])\n",
    "\n",
    "        # Title: seed-averaged nuisance diagnostics, then one summary line per method.\n",
    "        nuisance_str = (\"reg RMSE: {:.3f}, R2: {:.3f}, rr RMSE: {:.3f}, R2: {:.3f}\\n\"\n",
    "                        \"IPS orthogonality: {:.3f}, DR orthogonality: {:.3f}\").format(np.mean(rmse_reg), np.mean(r2_reg),\n",
    "                                                                                      np.mean(rmse_rr), np.mean(r2_rr),\n",
    "                                                                                      np.mean(ipsbias), np.mean(drbias))\n",
    "        method_strs = [\"{}. Bias: {:.3f}, RMSE: {:.3f}, Coverage: {:.3f}\".format(method, np.mean(d['bias']), np.mean(d['rmse']), np.mean(d['cov']))\n",
    "                       for method, d in res.items()]\n",
    "\n",
    "        # A fresh figure per configuration: without it the plots only stay separate when\n",
    "        # plt.show() happens to discard the implicit current figure (as the inline backend does);\n",
    "        # under a non-interactive backend all histograms would pile up on the same axes.\n",
    "        fig, ax = plt.subplots()\n",
    "        ax.set_title(\"\\n\".join([nuisance_str] + method_strs))\n",
    "        for method, d in res.items():\n",
    "            ax.hist(np.array(d['point']), alpha=.5, label=method)\n",
    "        ax.axvline(x = np.mean(truth), label='true', color='red')\n",
    "        ax.legend()\n",
    "        nameplot = path + '/xfit_' + str(xfit) + '_mult_' + str(int(mult)) + '_all.pdf'\n",
    "        fig.savefig(nameplot, bbox_inches='tight')\n",
    "        plt.show()\n",
    "        plt.close(fig)  # release the figure so 30 of them do not stay open"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "BHP_semisynth.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
