{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ],
   "outputs": [],
   "metadata": {
    "executionInfo": {
     "elapsed": 27,
     "status": "ok",
     "timestamp": 1626715360966,
     "user": {
      "displayName": "Víctor Quintas Martínez",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgzDmne50Fc-Rgd6ii0jOAYJv9OzNuPlF4x0TOY2g=s64",
      "userId": "01033527572468555224"
     },
     "user_tz": 240
    },
    "id": "G_qWGM2RvR3a"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# BHP: RieszNet"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Library Imports"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "from pathlib import Path\n",
    "import os\n",
    "import glob\n",
    "from joblib import dump, load\n",
    "import pandas as pd\n",
    "import scipy\n",
    "import scipy.stats\n",
    "import scipy.special\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import math\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "from sklearn.model_selection import cross_val_predict\n",
    "from utils.NN_avgmom_sim import sim_fun\n",
    "from utils.moments import avg_small_diff"
   ],
   "outputs": [],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1011,
     "status": "ok",
     "timestamp": 1626715362111,
     "user": {
      "displayName": "Víctor Quintas Martínez",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgzDmne50Fc-Rgd6ii0jOAYJv9OzNuPlF4x0TOY2g=s64",
      "userId": "01033527572468555224"
     },
     "user_tz": 240
    },
    "id": "tpLT3n3Vvhjb",
    "outputId": "99d93f81-f8fb-44f2-9e02-8f70936fbe2c"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## NN settings"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Architecture settings, forwarded to sim_fun as drop_prob / n_hidden\n",
    "drop_prob = 0.0  # dropout prob of dropout layers throughout notebook\n",
    "n_hidden = 100  # width of hidden layers throughout notebook\n",
    "\n",
    "# Training params (collected into train_opt; fast_train_opt overrides some of them)\n",
    "learner_lr = 1e-4  # presumably the optimizer step size -- confirm in utils.NN_avgmom_sim\n",
    "learner_l2 = 1e-3  # presumably an L2 penalty weight -- confirm in utils.NN_avgmom_sim\n",
    "n_epochs = 300\n",
    "earlystop_rounds = 20 # how many epochs to wait for an out-of-sample improvement\n",
    "earlystop_delta = 1e-3  # presumably the minimum change that counts as an improvement -- confirm\n",
    "bs = 64  # presumably the mini-batch size -- confirm"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Options for the main training run, assembled from the settings defined above.\n",
    "train_opt = dict(\n",
    "    earlystop_rounds=earlystop_rounds,\n",
    "    earlystop_delta=earlystop_delta,\n",
    "    learner_lr=learner_lr,\n",
    "    learner_l2=learner_l2,\n",
    "    learner_l1=0.0,\n",
    "    n_epochs=n_epochs,\n",
    "    bs=bs,\n",
    "    target_reg=1,\n",
    "    riesz_weight=0.1,\n",
    "    optimizer='adam',\n",
    ")\n",
    "\n",
    "# Fast variant: identical except for fixed 'earlystop_rounds', 'learner_lr' and\n",
    "# 'n_epochs'. Overriding existing keys keeps the key order of train_opt.\n",
    "fast_train_opt = {**train_opt, 'earlystop_rounds': 2, 'learner_lr': 1e-3, 'n_epochs': 100}"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Read Data"
   ],
   "metadata": {
    "id": "5NE5F8z7Sjhq"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Load the BHP data and keep only rows with log_p > log(1.2) and log_y > log(15000).\n",
    "df = pd.read_csv('./data/BHP/data_BHP2.csv')\n",
    "df = df[df[\"log_p\"] > math.log(1.2)]\n",
    "df = df[df[\"log_y\"] > math.log(15000)]\n",
    "Xdf = df.iloc[:,1:]  # every column except the first one in the CSV\n",
    "X_nostatedum = Xdf.drop([\"distance_oil1000\", \"share\"], axis=1).values\n",
    "columns = Xdf.columns\n",
    "\n",
    "# One-hot encode the state identifier. dtype=float is requested explicitly because\n",
    "# pandas >= 2.0 returns bool dummies by default, and mixing bool with float columns\n",
    "# turns the `.values` arrays below into object arrays instead of numeric ones.\n",
    "state_dum = pd.get_dummies(Xdf['state_fips'], prefix=\"state\", dtype=float)\n",
    "Xdf = pd.concat([Xdf, state_dum], axis = 1)\n",
    "Xdf = Xdf.drop([\"distance_oil1000\", \"state_fips\", \"share\"], axis=1)\n",
    "\n",
    "# W: every remaining column except log_p; T: the log_p column.\n",
    "W = Xdf.drop([\"log_p\"], axis=1).values\n",
    "T = Xdf['log_p'].values"
   ],
   "outputs": [],
   "metadata": {
    "executionInfo": {
     "elapsed": 138,
     "status": "ok",
     "timestamp": 1626715362224,
     "user": {
      "displayName": "Víctor Quintas Martínez",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgzDmne50Fc-Rgd6ii0jOAYJv9OzNuPlF4x0TOY2g=s64",
      "userId": "01033527572468555224"
     },
     "user_tz": 240
    },
    "id": "0Vt_gLC8viPt"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Generate Semi-Synthetic Treatment"
   ],
   "metadata": {
    "id": "-DregF_xfSXp"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Conditional Mean\n",
    "mu_T = RandomForestRegressor(n_estimators = 100, min_samples_leaf = 50, random_state = 123)\n",
    "mu_T.fit(W, T)\n",
    "\n",
    "# Conditional Variance\n",
    "sigma2_T = RandomForestRegressor(n_estimators = 100, min_samples_leaf = 50, max_depth = 5, random_state = 123)\n",
    "e_T = T - cross_val_predict(mu_T, W, T)\n",
    "sigma2_T.fit(W, e_T ** 2)"
   ],
   "outputs": [],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 16440,
     "status": "ok",
     "timestamp": 1626715378660,
     "user": {
      "displayName": "Víctor Quintas Martínez",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgzDmne50Fc-Rgd6ii0jOAYJv9OzNuPlF4x0TOY2g=s64",
      "userId": "01033527572468555224"
     },
     "user_tz": 240
    },
    "id": "7tYXj8zufSEn",
    "outputId": "0af4e0c2-bfef-47ac-e64a-33fa53d94e06"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "def gen_T(W):\n",
    "    \"\"\"Draw one treatment per row of W from N(mu(W), sigma^2(W)).\n",
    "\n",
    "    The mean comes from the fitted `mu_T` and the variance from the fitted\n",
    "    `sigma2_T`; the draws are returned as a column of shape (n, 1).\n",
    "    \"\"\"\n",
    "    n_rows = W.shape[0]\n",
    "    cond_mean = mu_T.predict(W)\n",
    "    cond_std = np.sqrt(sigma2_T.predict(W))\n",
    "    shocks = np.random.normal(size=(n_rows,))\n",
    "    return (cond_mean + cond_std * shocks).reshape(-1, 1)\n",
    "\n",
    "def true_rr(X):\n",
    "    \"\"\"Return (T - mu(W)) / sigma^2(W), with T = X[:, 0] and W = X[:, 1:].\n",
    "\n",
    "    `mu_T` and `sigma2_T` are the forests fitted earlier in the notebook.\n",
    "    \"\"\"\n",
    "    treatment, controls = X[:, 0], X[:, 1:]\n",
    "    residual = treatment - mu_T.predict(controls)\n",
    "    return residual / sigma2_T.predict(controls)"
   ],
   "outputs": [],
   "metadata": {
    "executionInfo": {
     "elapsed": 34,
     "status": "ok",
     "timestamp": 1626715378662,
     "user": {
      "displayName": "Víctor Quintas Martínez",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgzDmne50Fc-Rgd6ii0jOAYJv9OzNuPlF4x0TOY2g=s64",
      "userId": "01033527572468555224"
     },
     "user_tz": 240
    },
    "id": "jH17wYhSlU65"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Run Simulations"
   ],
   "metadata": {
    "id": "hNfu5_rpSzZm"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# One pass per seed: draw fresh confounding coefficients, build the six outcome\n",
    "# regressions and run the RieszNet simulation for each of them.\n",
    "for i in range(10):\n",
    "    np.random.seed(i)\n",
    "\n",
    "    # Coefficients of the linear confounding terms, redrawn for every seed.\n",
    "    b = np.random.uniform(-0.5, 0.5, size=(20, 1))\n",
    "    c = np.random.uniform(-0.2, 0.2, size=(8, 1))\n",
    "\n",
    "    def nonlin(X):\n",
    "        # Two steep logistic terms in columns 6 and 8 of X.\n",
    "        return 1.5*scipy.special.expit(10 * X[:, 6]) + 1.5*scipy.special.expit(10 * X[:, 8])\n",
    "\n",
    "    def true_f_simple(X):\n",
    "        # Linear in column 0 only.\n",
    "        return -0.6 * X[:, 0]\n",
    "\n",
    "    def true_f_simple_lin_conf(X):\n",
    "        # Simple design plus a linear term in columns 1..20 weighted by b.\n",
    "        return true_f_simple(X) + np.matmul(X[:, 1:21], b).flatten()\n",
    "\n",
    "    def true_f_simple_nonlin_conf(X):\n",
    "        return true_f_simple_lin_conf(X) + nonlin(X)\n",
    "\n",
    "    def true_f_compl(X):\n",
    "        # Cubic in column 0 with a slope that depends on column 1.\n",
    "        return -0.5 * (X[:, 1]**2/10 + .5) * X[:, 0]**3 / 3\n",
    "\n",
    "    def true_f_compl_lin_conf(X):\n",
    "        # As true_f_compl, with columns 1..8 (weights c) entering the slope and\n",
    "        # columns 1..20 (weights b) entering additively.\n",
    "        return -0.5 * (X[:, 1]**2/10 + np.matmul(X[:, 1:9], c).flatten() + .5) * X[:, 0]**3 / 3 + np.matmul(X[:, 1:21], b).flatten()\n",
    "\n",
    "    def true_f_compl_nonlin_conf(X):\n",
    "        return true_f_compl_lin_conf(X) + nonlin(X)\n",
    "\n",
    "    for true_f in [true_f_simple, true_f_simple_lin_conf, true_f_simple_nonlin_conf,\n",
    "                   true_f_compl, true_f_compl_lin_conf, true_f_compl_nonlin_conf]:\n",
    "        print(\"Now trying \" + true_f.__name__)\n",
    "\n",
    "        def gen_y(X):\n",
    "            # Outcome = regression value + Gaussian noise whose variance is 5.6 times\n",
    "            # the sample variance of the regression values. true_f(X) is evaluated\n",
    "            # once and reused for both the mean and the noise scale.\n",
    "            signal = true_f(X)\n",
    "            noise_sd = np.sqrt(5.6 * np.var(signal))\n",
    "            return signal + np.random.normal(0, noise_sd, size = (X.shape[0],))\n",
    "\n",
    "        path = './results/BHP/RieszNet/' + true_f.__name__\n",
    "        # exist_ok=True replaces the check-then-create pattern and is a no-op on re-runs.\n",
    "        os.makedirs(path, exist_ok=True)\n",
    "\n",
    "        namedata = path + \"/seed\" + str(i) + '.joblib'\n",
    "        nameplot = path + \"/seed\" + str(i) + '.pdf'\n",
    "        sim_fun(W, moment_fn = avg_small_diff, n_hidden = n_hidden, drop_prob = drop_prob,\n",
    "                true_reg = true_f, true_rr = true_rr, gen_y = gen_y, gen_T = gen_T,\n",
    "                N_sim = 100, fast_train_opt = fast_train_opt, train_opt = train_opt,\n",
    "                seed = i, verbose = 1, plot = True, save = namedata, saveplot = nameplot)"
   ],
   "outputs": [],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "5HWQ_81xplDM",
    "outputId": "8b567dbb-44c0-4250-89b0-758a0141606c"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Summary Outputs"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### LaTeX Table"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Row headings of the LaTeX table, one per outcome design (same order as true_fs).\n",
    "f_string = [\"1. Simple $f$\",\n",
    "            \"2. Simple $f$ with linear confound.\",\n",
    "            \"3. Simple $f$ with linear and non-linear confound.\",\n",
    "            \"4. Complex $f$\",\n",
    "            \"5. Complex $f$ with linear confound.\",\n",
    "            \"6. Complex $f$ with linear and non-linear confound.\"]\n",
    "\n",
    "# Sub-directory names under ./results/BHP/RieszNet/, one per design.\n",
    "true_fs = ['true_f_simple', 'true_f_simple_lin_conf', 'true_f_simple_nonlin_conf',\n",
    "           'true_f_compl', 'true_f_compl_lin_conf', 'true_f_compl_nonlin_conf']\n",
    "\n",
    "# Keys of the per-method results stored in loaded[0].\n",
    "methods = ['direct', 'ips', 'dr']\n",
    "    \n",
    "# Write the table: a header, then for each design a heading line and one row of\n",
    "# statistics averaged over the 10 seeds.\n",
    "with open(\"./results/BHP/RieszNet/res_avg_der_NN.tex\", \"w\") as f:\n",
    "    f.write(\"\\\\begin{tabular}{*{11}{r}} \\n\" +\n",
    "            \"\\\\toprule \\n\" +\n",
    "            \"&& \\\\multicolumn{3}{c}{Direct} & \\\\multicolumn{3}{c}{IPS} & \\\\multicolumn{3}{c}{DR} \\\\\\\\ \\n\" +\n",
    "            \"\\\\cmidrule(lr){3-5} \\\\cmidrule(lr){6-8} \\\\cmidrule(lr){9-11} \\n\" +\n",
    "            \"reg $R^2$ &  rr $R^2$ &  Bias &  RMSE &  Cov. &  Bias &  RMSE &  Cov. &  Bias &  RMSE &  Cov. \\\\\\\\ \\n\" +\n",
    "            \"\\\\midrule \\n\")\n",
    "    \n",
    "    for f_i, true_f in enumerate(true_fs):\n",
    "        path = './results/BHP/RieszNet/' + true_f\n",
    "        f.write(\"\\\\addlinespace \\n \\\\multicolumn{11}{l}{\\\\textbf{\" + f_string[f_i] + \"}} \\\\\\\\ \\n\")\n",
    "        \n",
    "        # Per-design accumulators, reset for every design.\n",
    "        r2_reg, r2_rr = [], []\n",
    "        res = {}\n",
    "        for method in methods:\n",
    "            res[method] = {'bias': [], 'rmse': [], 'cov': []}\n",
    "            \n",
    "        for i in range(10):\n",
    "            namedata = path + '/seed' + str(i) + '.joblib'\n",
    "            # NOTE(review): load() unpickles the file; only use it on results written by this notebook.\n",
    "            loaded = load(namedata)\n",
    "            # loaded[2] and loaded[4] end up under the reg R^2 and rr R^2 column headers.\n",
    "            r2_reg = np.append(r2_reg, loaded[2])\n",
    "            r2_rr = np.append(r2_rr, loaded[4])\n",
    "                \n",
    "            for method in methods:\n",
    "                res[method]['bias'].append(loaded[0][method]['bias'])\n",
    "                res[method]['rmse'].append(loaded[0][method]['rmse'])\n",
    "                res[method]['cov'].append(loaded[0][method]['cov'])\n",
    "            \n",
    "        # One row: the two R^2 means, then bias / rmse / cov means for each method, 3 decimals each.\n",
    "        f.write(\" & \".join([\"${:.3f}$\".format(np.mean(x)) for x in [r2_reg, r2_rr]]) + \" & \")\n",
    "        f.write(\" & \".join([\"${:.3f}$\".format(np.mean(res[method][x])) for method in methods\n",
    "                            for x in ['bias', 'rmse', 'cov']]) + \" \\\\\\\\ \\n\")\n",
    "\n",
    "    f.write(\"\\\\bottomrule \\n \\\\end{tabular}\")"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Histograms over 10 Seeds"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# For every design, pool the saved results of the 10 seeds, plot a histogram of the\n",
    "# point estimates of each method and save it as all.pdf in the design's folder.\n",
    "for true_f in true_fs:\n",
    "    path = './results/BHP/RieszNet/' + true_f\n",
    "\n",
    "    rmse_reg, r2_reg, rmse_rr, r2_rr, ipsbias, drbias, truth = [], [], [], [], [], [], []\n",
    "    res = {method: {'point': [], 'bias': [], 'rmse': [], 'cov': []} for method in methods}\n",
    "\n",
    "    for i in range(10):\n",
    "        namedata = path + '/seed' + str(i) + '.joblib'\n",
    "        # NOTE(review): load() unpickles the file; only use it on results written by this notebook.\n",
    "        loaded = load(namedata)\n",
    "        rmse_reg = np.append(rmse_reg, loaded[1])\n",
    "        r2_reg = np.append(r2_reg, loaded[2])\n",
    "        rmse_rr = np.append(rmse_rr, loaded[3])\n",
    "        r2_rr = np.append(r2_rr, loaded[4])\n",
    "        ipsbias = np.append(ipsbias, loaded[5])\n",
    "        drbias = np.append(drbias, loaded[6])\n",
    "        truth = np.append(truth, loaded[7])\n",
    "\n",
    "        for method in methods:\n",
    "            # Point estimates are pooled across seeds; the summary statistics are kept per seed.\n",
    "            res[method]['point'] = np.append(res[method]['point'], loaded[0][method]['point'])\n",
    "            for stat in ['bias', 'rmse', 'cov']:\n",
    "                res[method][stat].append(loaded[0][method][stat])\n",
    "\n",
    "    nuisance_str = (\"reg RMSE: {:.3f}, R2: {:.3f}, rr RMSE: {:.3f}, R2: {:.3f}\\n\"\n",
    "                    \"IPS orthogonality: {:.3f}, DR orthogonality: {:.3f}\").format(np.mean(rmse_reg), np.mean(r2_reg),\n",
    "                                                                                          np.mean(rmse_rr), np.mean(r2_rr),\n",
    "                                                                                          np.mean(ipsbias), np.mean(drbias))\n",
    "    method_strs = [\"{}. Bias: {:.3f}, RMSE: {:.3f}, Coverage: {:.3f}\".format(method, np.mean(d['bias']), np.mean(d['rmse']), np.mean(d['cov']))\n",
    "                    for method, d in res.items()]\n",
    "\n",
    "    # Use an explicit figure per design and close it afterwards, so the plots never\n",
    "    # accumulate on a shared implicit figure when the backend does not close on show().\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.set_title(\"\\n\".join([nuisance_str] + method_strs))\n",
    "    for method, d in res.items():\n",
    "        ax.hist(np.array(d['point']), alpha=.5, label=method)\n",
    "    ax.axvline(x = np.mean(truth), label='true', color='red')\n",
    "    ax.legend()\n",
    "    nameplot = path + '/all.pdf'\n",
    "    fig.savefig(nameplot, bbox_inches='tight')\n",
    "    plt.show()\n",
    "    plt.close(fig)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "BHP_semisynth.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}