{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# !pip install econml"
   ],
   "outputs": [],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 15032,
     "status": "ok",
     "timestamp": 1626040785027,
     "user": {
      "displayName": "Víctor Quintas Martínez",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgzDmne50Fc-Rgd6ii0jOAYJv9OzNuPlF4x0TOY2g=s64",
      "userId": "01033527572468555224"
     },
     "user_tz": 240
    },
    "id": "Px1d4VFblOVY",
    "outputId": "0f6d21f1-621a-4851-da4c-5c22bab36f19"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# IHDP: ForestRiesz"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Library Imports"
   ],
   "metadata": {
    "id": "CumqS__UdZQj"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "from pathlib import Path\n",
    "import os\n",
    "import glob\n",
    "from joblib import dump, load\n",
    "import pandas as pd\n",
    "import scipy\n",
    "import scipy.stats\n",
    "import scipy.special\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from utils.forestriesz import ForestRieszATE\n",
    "from utils.ihdp_data import *"
   ],
   "outputs": [],
   "metadata": {
    "executionInfo": {
     "elapsed": 461,
     "status": "ok",
     "timestamp": 1626040879451,
     "user": {
      "displayName": "Víctor Quintas Martínez",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgzDmne50Fc-Rgd6ii0jOAYJv9OzNuPlF4x0TOY2g=s64",
      "userId": "01033527572468555224"
     },
     "user_tz": 240
    },
    "id": "Tb0UNXId8Sbh"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Moment Definition"
   ],
   "metadata": {
    "id": "lGjlQO03CBbu"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "def moment_fn(x, test_fn): # Returns the moment for the ATE example, for each sample in x\n",
    "    t1 = np.hstack([np.ones((x.shape[0], 1)), x[:, 1:]])\n",
    "    t0 = np.hstack([np.zeros((x.shape[0], 1)), x[:, 1:]])\n",
    "    return test_fn(t1) - test_fn(t0)"
   ],
   "outputs": [],
   "metadata": {
    "executionInfo": {
     "elapsed": 58,
     "status": "ok",
     "timestamp": 1626040879892,
     "user": {
      "displayName": "Víctor Quintas Martínez",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgzDmne50Fc-Rgd6ii0jOAYJv9OzNuPlF4x0TOY2g=s64",
      "userId": "01033527572468555224"
     },
     "user_tz": 240
    },
    "id": "6EsyZwYtCAKa"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## MAE Experiment"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "data_base_dir = \"./data/IHDP/sim_data\"\n",
    "simulation_files = sorted(glob.glob(\"{}/*.csv\".format(data_base_dir)))"
   ],
   "outputs": [],
   "metadata": {
    "executionInfo": {
     "elapsed": 260,
     "status": "aborted",
     "timestamp": 1626041846619,
     "user": {
      "displayName": "Víctor Quintas Martínez",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgzDmne50Fc-Rgd6ii0jOAYJv9OzNuPlF4x0TOY2g=s64",
      "userId": "01033527572468555224"
     },
     "user_tz": 240
    },
    "id": "wFQ-JQm6lFDC"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "nsims = 1000\n",
    "np.random.seed(123)\n",
    "sim_ids = np.random.choice(len(simulation_files), nsims, replace=False)\n",
    "methods = ['dr', 'direct', 'ips']\n",
    "\n",
    "true_ATEs = []\n",
    "results = []\n",
    "\n",
    "for it, sim in enumerate(sim_ids):\n",
    "    simulation_file = simulation_files[sim]\n",
    "    x = load_and_format_covariates(simulation_file, delimiter=' ')\n",
    "    t, y, y_cf, mu_0, mu_1 = load_other_stuff(simulation_file, delimiter=' ')\n",
    "    X = np.c_[t, x]\n",
    "    true_ATE = np.mean(mu_1 - mu_0)\n",
    "    true_ATEs.append(true_ATE)\n",
    "\n",
    "    y_scaler = StandardScaler(with_mean=True).fit(y)\n",
    "    y = y_scaler.transform(y)\n",
    "    est = ForestRieszATE(criterion='het', n_estimators=1000, min_samples_leaf=2,\n",
    "                         min_var_fraction_leaf=0.001, min_var_leaf_on_val=True,\n",
    "                         min_impurity_decrease = 0.01, max_samples=.8, max_depth=None,\n",
    "                         warm_start=False, inference=False, subforest_size=1,\n",
    "                         honest=True, verbose=0, n_jobs=-1, random_state=123)\n",
    "    est.fit(X[:, 1:], X[:, [0]], y.reshape(-1, 1))\n",
    "    \n",
    "    params = tuple(x * y_scaler.scale_[0] for method in methods\n",
    "                   for x in est.predict_ate(X, y, method = method)) + (true_ATE, )\n",
    "                        \n",
    "    results.append(params)\n",
    "\n",
    "res = tuple(np.array(x) for x in zip(*results))\n",
    "truth = res[-1:]\n",
    "res_dict = {}\n",
    "for it, method in enumerate(methods):\n",
    "    point, lb, ub = res[it * 3: (it + 1)*3]\n",
    "    res_dict[method] = {'point': point, 'lb': lb, 'ub': ub,\n",
    "                        'MAE': np.mean(np.abs(point - truth)),\n",
    "                        'std. err.': np.std(np.abs(point - truth)) / np.sqrt(nsims),\n",
    "                        }\n",
    "    print(\"{} : MAE = {:.3f} +/- {:.3f}\".format(method, res_dict[method]['MAE'], res_dict[method]['std. err.']))"
   ],
   "outputs": [],
   "metadata": {
    "executionInfo": {
     "elapsed": 262,
     "status": "aborted",
     "timestamp": 1626041846621,
     "user": {
      "displayName": "Víctor Quintas Martínez",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgzDmne50Fc-Rgd6ii0jOAYJv9OzNuPlF4x0TOY2g=s64",
      "userId": "01033527572468555224"
     },
     "user_tz": 240
    },
    "id": "rXunLlURlFDD"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "path = './results/IHDP/ForestRiesz/MAE'\n",
    "\n",
    "if not os.path.exists(path):\n",
    "    os.makedirs(path)\n",
    "            \n",
    "dump(res_dict, path + '/IHDP_MAE_RF.joblib')"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Table"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "path = './results/IHDP/ForestRiesz/MAE'\n",
    "\n",
    "if not os.path.exists(path):\n",
    "    os.makedirs(path)\n",
    "\n",
    "methods_str = [\"DR\", \"Direct\", \"IPS\"] \n",
    "\n",
    "with open(path + '/IHDP_MAE_RF.tex', \"w\") as f:\n",
    "    f.write(\"\\\\begin{tabular}{lc} \\n\" +\n",
    "            \"\\\\toprule \\n\" +\n",
    "            \"& MAE $\\\\pm$ std. err. \\\\\\\\ \\n\" +\n",
    "            \"\\\\midrule \\n\")\n",
    "    \n",
    "    for i, method in enumerate(methods):\n",
    "        f.write(\" & \".join([methods_str[i], \"{:.3f} $\\\\pm$ {:.3f}\".format(res_dict[method]['MAE'], \n",
    "                                                                          res_dict[method]['std. err.'])]) + \" \\n\")\n",
    "\n",
    "    f.write(\"\\\\bottomrule \\n \\\\end{tabular}\")"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Coverage Experiment"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "data_base_dir = \"./data/IHDP/sim_data_redraw_T\"\n",
    "simulation_files = sorted(glob.glob(\"{}/*.csv\".format(data_base_dir)))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "def rmse_fn(y_pred, y_true):\n",
    "    return np.sqrt(np.mean((y_pred - y_true)**2))\n",
    "\n",
    "nsims = 100\n",
    "np.random.seed(123)\n",
    "sim_ids = np.random.choice(len(simulation_files), nsims, replace=False)\n",
    "methods = ['dr', 'direct', 'ips']\n",
    "\n",
    "true_ATEs = []\n",
    "results = []\n",
    "\n",
    "sim_ids = np.random.choice(len(simulation_files), nsims, replace=False)\n",
    "for it, sim in enumerate(sim_ids):\n",
    "    simulation_file = simulation_files[sim]\n",
    "    x = load_and_format_covariates(simulation_file, delimiter=' ')\n",
    "    t, y, y_cf, mu_0, mu_1 = load_other_stuff(simulation_file, delimiter=' ')\n",
    "    X = np.c_[t, x]\n",
    "    true_ATE = np.mean(mu_1 - mu_0)\n",
    "    true_ATEs.append(true_ATE)\n",
    "\n",
    "    y_scaler = StandardScaler(with_mean=True).fit(y)\n",
    "    y = y_scaler.transform(y)\n",
    "    est = ForestRieszATE(criterion='het', n_estimators=100, min_samples_leaf=2,\n",
    "                         min_var_fraction_leaf=0.001, min_var_leaf_on_val=True,\n",
    "                         min_impurity_decrease = 0.01, max_samples=.8, max_depth=None,\n",
    "                         warm_start=False, inference=False, subforest_size=1,\n",
    "                         honest=True, verbose=0, n_jobs=-1, random_state=123)\n",
    "    est.fit(X[:, 1:], X[:, [0]], y.reshape(-1, 1))\n",
    "\n",
    "                        \n",
    "    params = tuple(x * y_scaler.scale_[0] for method in methods\n",
    "                   for x in est.predict_ate(X, y, method = method)) + (true_ATE, )\n",
    "                        \n",
    "    results.append(params)\n",
    "                        \n",
    "res = tuple(np.array(x) for x in zip(*results))\n",
    "truth = res[-1:]\n",
    "res_dict = {}\n",
    "for it, method in enumerate(methods):\n",
    "    point, lb, ub = res[it * 3: (it + 1)*3]\n",
    "    res_dict[method] = {'point': point, 'lb': lb, 'ub': ub,\n",
    "                        'cov': np.mean(np.logical_and(truth >= lb, truth <= ub)),\n",
    "                        'bias': np.mean(point - truth),\n",
    "                        'rmse': rmse_fn(point, truth)\n",
    "                        }\n",
    "    print(\"{} : bias = {:.3f}, rmse = {:.3f}, cov = {:.3f}\".format(method, res_dict[method]['bias'], res_dict[method]['rmse'], res_dict[method]['cov']))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "path = './results/IHDP/ForestRiesz/coverage'\n",
    "\n",
    "if not os.path.exists(path):\n",
    "    os.makedirs(path)\n",
    "    \n",
    "dump(res_dict, path + '/IHDP_coverage_RF.joblib')"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Histogram"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "path = './results/IHDP/ForestRiesz/coverage'\n",
    "\n",
    "if not os.path.exists(path):\n",
    "    os.makedirs(path)\n",
    "    \n",
    "method_strs = [\"{}. Bias: {:.3f}, RMSE: {:.3f}, Coverage: {:.3f}\".format(method, d['bias'], d['rmse'], d['cov'])\n",
    "               for method, d in res_dict.items()]\n",
    "plt.title(\"\\n\".join(method_strs))\n",
    "for method, d in res_dict.items():\n",
    "    plt.hist(np.array(d['point']), alpha=.5, label=method)\n",
    "plt.axvline(x = np.mean(truth), label='true', color='red')\n",
    "plt.legend()\n",
    "plt.savefig(path + '/IHDP_coverage_RF.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "RandomForestRiesz.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}