{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from econml.dml import LinearDML\n",
    "from econml.dynamic.dml import DynamicDML\n",
    "from sklearn.linear_model import LassoCV, LinearRegression, MultiTaskLassoCV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "rs = 1000  # random seed passed to create_instance and observational_data below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dynamic_panel_dgp must be importable from the notebook's working directory / sys.path\n",
    "from dynamic_panel_dgp import DynamicPanelDGP, LongRangeDynamicPanelDGP\n",
    "\n",
    "# Panel dimensions\n",
    "n_units = 10000\n",
    "n_periods = 4\n",
    "n_treatments = 2\n",
    "n_x = 100\n",
    "\n",
    "# DGP parameters (passed positionally to create_instance / observational_data)\n",
    "s_x = 10\n",
    "s_t = 10\n",
    "sigma_x = .5\n",
    "sigma_t = .5\n",
    "sigma_y = .5\n",
    "gamma = .0\n",
    "autoreg = .1\n",
    "state_effect = .1\n",
    "conf_str = 6\n",
    "hetero_strength = 0\n",
    "hetero_inds = None\n",
    "\n",
    "# Set to True to simulate from LongRangeDynamicPanelDGP instead of DynamicPanelDGP\n",
    "USE_LONG_RANGE_DGP = False\n",
    "dgp_class = LongRangeDynamicPanelDGP if USE_LONG_RANGE_DGP else DynamicPanelDGP\n",
    "dgp = dgp_class(n_periods, n_treatments, n_x).create_instance(\n",
    "    s_x, sigma_x, sigma_y,\n",
    "    conf_str, hetero_strength, hetero_inds,\n",
    "    autoreg, state_effect,\n",
    "    random_seed=rs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 133,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulate observational data from the DGP; rows are reshaped into per-unit panels further down\n",
    "Y, T, X, groups = dgp.observational_data(n_units, gamma, s_t, sigma_t,\n",
    "                                         random_seed=rs)\n",
    "true_effect = dgp.true_effect  # ground-truth effects, compared against the estimates below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 134,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAQx0lEQVR4nO3df4xd513n8feHiaMdCpXZzVDicSABoqkqWtbVEEBB5UcJTgrCbhdpE3b5tUghEmFBCNMYEAjxRwuWEEIKRFYbttWWRoW6xiphZ2ELKqgteFyXmrQdarKgzLiQocVAYEQc8+UPX4exO2PP+B773Huf90sazT3POT7PR5b18bnPPffeVBWSpMn3OX0HkCTdGBa+JDXCwpekRlj4ktQIC1+SGnFT3wGu5JZbbqnbb7+97xiSNDZOnDjxt1U1s9G+kS7822+/ncXFxb5jSNLYSPJXm+1zSUeSGmHhS1IjLHxJaoSFL0mNsPAlqREjfZfOtTh6coVDC0ucObvGrp3THNg7x/49s33HkqTeTVThHz25wsEjp1g7dx6AlbNrHDxyCsDSl9S8iVrSObSw9GLZX7R27jyHFpZ6SiRJo2OiCv/M2bVtjUtSSyaq8HftnN7WuCS1ZKIK/8DeOaZ3TF0yNr1jigN753pKJEmjY6JetL34wqx36UjSZ5uowocLpW/BS9Jn62RJJ8m9SZaSnE7yyBWO+6ok55N8RxfzSpK2bujCTzIFPArcB7wCeCDJKzY57ueBhWHnlCRtXxdX+HcBp6vq6ap6HngC2LfBcT8EvBt4toM5JUnb1EXhzwLPrNteHoy9KMks8HrgsaudLMmDSRaTLK6urnYQT5IE3RR+Nhiry7Z/CXhjVZ3f4NhL/2DV4aqar6r5mZkNv6VLknQNurhLZxm4bd32buDMZcfMA08kAbgFeF2SF6rqaAfzS5K2oIvCPw7cmeQOYAW4H/jO9QdU1R0XHyf5X8B7LXtJurGGLvyqeiHJw1y4+2YKeLyqnkry0GD/VdftJUnXXydvvKqqJ4EnLxvbsOir6nu7mFOStD0T9Vk6kqTNWfiS1AgLX5IaYeFLUiMsfElqhIUvSY2w8CWpERa+JDXCwpekRlj4ktQIC1+SGmHhS1IjLHxJaoSFL0mNsPAlqREWviQ1wsKXpEZY+JLUCAtfkhph4UtSIyx8SWqEhS9JjbDwJakRFr4kNcLCl6RGdFL4Se5NspTkdJJHNti/L8lHk3wkyWKSr+tiXknS1t007AmSTAGPAvcAy8DxJMeq6mPrDvt/wLGqqiSvAt4FvHzYuSVJW9fFFf5dwOmqerqqngeeAPatP6CqnquqGmy+BCgkSTdUF4U/Czyzbnt5MHaJJK9P8gngt4H/sdnJkjw4WPZZXF1d7SCeJAm6KfxsMPZZV/BV9Z6qejmwH/i5zU5WVYerar6q5mdmZjqIJ0mCbgp/Gbht3fZu4MxmB1fV+4EvS3JLB3NLkraoi8I/DtyZ5I4kNwP3A8fWH5Dky5Nk8PjVwM3ApzuYW5K0RUPfpVNVLyR5GFgApoDHq+qpJA8N9j8G/Bfgu5OcA9aA/7ruRVxJ0g2QUe7d+fn5Wlxc7DuGJI2NJCeqan6jfb7TVpIaYeFLUiMsfElqhIUvSY2w8CWpERa+JDXCwpekRlj4ktQIC1+SGmHhS1IjLHxJaoSFL0mNsPAlqREWviQ1wsKXpEZY+JLUCAtfkhph4UtSIyx8SWqEhS9JjbDwJakRFr4kNcLCl6RGWPiS1AgLX5Ia0UnhJ7k3yVKS00ke2WD/f0vy0cHPB5J8ZRfzSpK2bujCTzIFPArcB7wCeCDJKy477P8DX19VrwJ+Djg87LySpO3p4gr/LuB0VT1dVc8DTwD71h9QVR+oqr8bbH4I2N3BvJKkbeii8GeBZ9ZtLw/GNvP9wO9stjPJg0kWkyyurq52EE+SBN0UfjYYqw0PTL6RC4X/xs1OVlWHq2q+quZn
ZmY6iCdJAripg3MsA7et294NnLn8oCSvAt4C3FdVn+5gXknSNnRxhX8cuDPJHUluBu4Hjq0/IMkXA0eA76qqP+9gTknSNg19hV9VLyR5GFgApoDHq+qpJA8N9j8G/DTwn4BfSQLwQlXNDzu3JGnrUrXhcvtImJ+fr8XFxb5jSNLYSHJiswtq32krSY2w8CWpERa+JDXCwpekRlj4ktQIC1+SGmHhS1IjLHxJaoSFL0mNsPAlqREWviQ1wsKXpEZY+JLUCAtfkhph4UtSIyx8SWqEhS9JjbDwJakRFr4kNcLCl6RGWPiS1AgLX5IaYeFLUiMsfElqhIUvSY3opPCT3JtkKcnpJI9ssP/lST6Y5F+S/FgXc0qStuemYU+QZAp4FLgHWAaOJzlWVR9bd9hngP8J7B92PknSteniCv8u4HRVPV1VzwNPAPvWH1BVz1bVceBcB/NJkq5BF4U/Czyzbnt5MHZNkjyYZDHJ4urq6tDhJEkXdFH42WCsrvVkVXW4quaran5mZmaIWJKk9YZew+fCFf1t67Z3A2c6OK9GzNGTKxxaWOLM2TV27ZzmwN459u+55idzkm6wLgr/OHBnkjuAFeB+4Ds7OK9GyNGTKxw8coq1c+cBWDm7xsEjpwAsfWlMDL2kU1UvAA8DC8DHgXdV1VNJHkryEECSL0qyDPwo8FNJlpO8dNi5deMcWlh6sewvWjt3nkMLSz0lkrRdXVzhU1VPAk9eNvbYusd/zYWlHo2pM2fXtjUuafT4Tlttya6d09salzR6LHxtyYG9c0zvmLpkbHrHFAf2zvWUSNJ2dbKko8l38YVZ79KRxpeFry3bv2fWgpfGmEs6ktQIC1+SGmHhS1IjLHxJaoSFL0mNsPAlqREWviQ1wsKXpEZY+JLUCAtfkhph4UtSIyx8SWqEhS9JjbDwJakRFr4kNcLCl6RGWPiS1AgLX5IaYeFLUiP8TltpBBw9ueIXxOu6s/Clnh09ucLBI6dYO3cegJWzaxw8cgrA0lenOlnSSXJvkqUkp5M8ssH+JPnlwf6PJnl1F/NKk+DQwtKLZX/R2rnzHFpY6inRZDl6coW73/w+7njkt7n7ze/j6MmVviNt6npnHfoKP8kU8ChwD7AMHE9yrKo+tu6w+4A7Bz9fDfzq4LfUvDNn17Y1rq0bp2dPNyJrF1f4dwGnq+rpqnoeeALYd9kx+4C31wUfAnYmubWDuaWxt2vn9LbGtXXj9OzpRmTtovBngWfWbS8PxrZ7DABJHkyymGRxdXW1g3jSaDuwd47pHVOXjE3vmOLA3rmeEl3ZOC2RjNOzpxuRtYvCzwZjdQ3HXBisOlxV81U1PzMzM3Q4adTt3zPLm97wSmZ3ThNgduc0b3rDK0duyQH+fdlh5ewaxb8vO4xq6Y/Ts6cbkbWLu3SWgdvWbe8GzlzDMVKz9u+ZHcmCv9yVlh1GMf+BvXOXrIvD6D57uhFZu7jCPw7cmeSOJDcD9wPHLjvmGPDdg7t1vgb4+6r6VAdzS7qBxmmJBMbr2dONyDr0FX5VvZDkYWABmAIer6qnkjw02P8Y8CTwOuA08M/A9w07r6Qbb9fOaVY2KPdRXCK5aFyePcH1z9rJG6+q6kkulPr6scfWPS7gB7uYS1J/xmmJRJ/Nd9pK2rKLV59+DMR4svAlbcs4LZHoUn5apiQ1wsKXpEZY+JLUCAtfkhph4UtSIyx8SWqEhS9JjbDwJakRFr4kNcLCl6RGWPiS1AgLX5IaYeFLUiMsfElqhIUvSY2w8CWpERa+JDXCb7zSRDp6csWv4ZMuY+Fr4hw9uXLJF22vnF3j4JFTAJa+muaSjibOoYWlF8v+orVz5zm0sNRTImk0WPiaOGfOrm1rXGqFha+Js2vn9LbGpVZY+Jo4B/bOMb1j6pKx6R1THNg711MiaTQMVfhJ/mOS303yycHvL9jkuMeTPJvkz4aZT9qK/XtmedMbXsnszmkCzO6c5k1veKUv2Kp5qapr/8PJLwCf
qao3J3kE+IKqeuMGx70GeA54e1V9xVbPPz8/X4uLi9ecT5Jak+REVc1vtG/YJZ19wNsGj98G7N/ooKp6P/CZIeeSJA1h2MJ/WVV9CmDw+wuHDZTkwSSLSRZXV1eHPZ0kaeCqb7xK8nvAF22w6ye7jwNVdRg4DBeWdK7HHJLUoqsWflV982b7kvxNklur6lNJbgWe7TSdJKkzwy7pHAO+Z/D4e4DfGvJ8kqTrZNjCfzNwT5JPAvcMtkmyK8mTFw9K8k7gg8BckuUk3z/kvJKkbRrqw9Oq6tPAazcYPwO8bt32A8PMI0kanu+0laRGWPiS1AgLX5IaYeFLUiMsfElqhIUvSY2w8CWpERa+JDXCwpekRlj4ktQIC1+SGjHUZ+loOEdPrnBoYYkzZ9fYtXOaA3vn/N5VSdeNhd+ToydXOHjkFGvnzgOwcnaNg0dOAVj6kq4Ll3R6cmhh6cWyv2jt3HkOLSz1lEjSpLPwe3Lm7Nq2xiVpWBZ+T3btnN7WuCQNy8LvyYG9c0zvmLpkbHrHFAf2zvWUSNKk80Xbnlx8Yda7dCTdKBZ+j/bvmbXgJd0wLulIUiMsfElqhIUvSY2w8CWpERa+JDUiVdV3hk0lWQX+6hr/+C3A33YY53oap6wwXnnHKSuMV95xygrjlXeYrF9SVTMb7Rjpwh9GksWqmu87x1aMU1YYr7zjlBXGK+84ZYXxynu9srqkI0mNsPAlqRGTXPiH+w6wDeOUFcYr7zhlhfHKO05ZYbzyXpesE7uGL0m61CRf4UuS1rHwJakRE1f4Se5NspTkdJJH+s5zJUkeT/Jskj/rO8vVJLktye8n+XiSp5L8cN+ZriTJf0jyJ0n+dJD3Z/vOdDVJppKcTPLevrNcTZK/THIqyUeSLPad50qS7Ezym0k+Mfj3+7V9Z9pMkrnB3+nFn39I8iOdnX+S1vCTTAF/DtwDLAPHgQeq6mO9BttEktcAzwFvr6qv6DvPlSS5Fbi1qj6c5POBE8D+Ef67DfCSqnouyQ7gj4AfrqoP9RxtU0l+FJgHXlpV39Z3nitJ8pfAfFWN/BuZkrwN+MOqekuSm4HPraqzPce6qkGfrQBfXVXX+gbUS0zaFf5dwOmqerqqngeeAPb1nGlTVfV+4DN959iKqvpUVX148PgfgY8DI/th/nXBc4PNHYOfkb26SbIb+FbgLX1nmSRJXgq8BngrQFU9Pw5lP/Ba4C+6KnuYvMKfBZ5Zt73MCJfSuEpyO7AH+OOeo1zRYInkI8CzwO9W1Sjn/SXgx4F/7TnHVhXwf5OcSPJg32Gu4EuBVeDXBstlb0nykr5DbdH9wDu7POGkFX42GBvZq7pxlOTzgHcDP1JV/9B3niupqvNV9Z+B3cBdSUZy2SzJtwHPVtWJvrNsw91V9WrgPuAHB8uTo+gm4NXAr1bVHuCfgJF+bQ9gsPT07cBvdHneSSv8ZeC2ddu7gTM9ZZk4g7XwdwPvqKojfefZqsFT+D8A7u03yabuBr59sC7+BPBNSf53v5GurKrODH4/C7yHC8upo2gZWF737O43ufAfwKi7D/hwVf1NlyedtMI/DtyZ5I7B/5D3A8d6zjQRBi+CvhX4eFX9Yt95ribJTJKdg8fTwDcDn+g11Caq6mBV7a6q27nwb/Z9VfXfe461qSQvGbxwz2B55FuAkbzTrKr+Gngmydxg6LXASN5ocJkH6Hg5BybsS8yr6oUkDwMLwBTweFU91XOsTSV5J/ANwC1JloGfqaq39ptqU3cD3wWcGqyLA/xEVT3ZX6QruhV42+BOh88B3lVVI3+745h4GfCeC9cA3AT8elX9n34jXdEPAe8YXAQ+DXxfz3muKMnncuFOwx/o/NyTdFumJGlzk7akI0nahIUvSY2w8CWpERa+JDXCwpekRlj4ktQIC1+SGvFvDaUSF2O4M7QAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the flattened true effects; entry index = lag * n_treatments + treatment\n",
    "true_effect = true_effect.flatten()\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(true_effect, 'o')\n",
    "ax.set(title='True dynamic treatment effects',\n",
    "       xlabel='parameter index (lag * n_treatments + treatment)',\n",
    "       ylabel='true effect')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.42178248, -0.15785924,  0.08474795, -0.081972  ,  0.00700877,\n",
       "       -0.00741936,  0.00055594, -0.00070558])"
      ]
     },
     "execution_count": 135,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# flattened true effects, indexed as lag * n_treatments + treatment\n",
    "true_effect"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 136,
   "metadata": {},
   "outputs": [],
   "source": [
    "# one row per lag, one column per treatment\n",
    "true_effect = true_effect.reshape(n_periods, n_treatments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 137,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.51409515, -0.24795617])"
      ]
     },
     "execution_count": 137,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# total true effect of each treatment, summed over all lags\n",
    "true_effect.sum(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 138,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Panel views: one row per unit, axis 1 = period.\n",
    "# NOTE(review): assumes rows are stacked unit-by-unit, period-by-period - confirm with the DGP.\n",
    "panelX, panelT = (X.reshape(-1, n_periods, n_x),\n",
    "                  T.reshape(-1, n_periods, n_treatments))\n",
    "panelY = Y.reshape(-1, n_periods)\n",
    "panelGroups = groups.reshape(-1, n_periods)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LinearRegression, LassoCV, Lasso, MultiTaskLasso, MultiTaskLassoCV\n",
    "from sklearn.model_selection import GroupKFold\n",
    "import warnings\n",
    "import numpy as np\n",
    "\n",
    "# NOTE: silences every warning (including convergence warnings) for the rest of the session\n",
    "warnings.simplefilter('ignore')\n",
    "\n",
    "np.random.seed(123)\n",
    "\n",
    "\n",
    "def lasso_model(lr):\n",
    "    \"\"\"First-stage regressor for a single-output target (used as model_y).\n",
    "\n",
    "    lr=True gives plain OLS; otherwise a 3-fold cross-validated Lasso.\n",
    "    \"\"\"\n",
    "    return LinearRegression() if lr else LassoCV(cv=3, n_alphas=10, max_iter=2000)\n",
    "\n",
    "\n",
    "def mlasso_model(lr):\n",
    "    \"\"\"First-stage regressor for a multi-output target (used as model_t).\n",
    "\n",
    "    lr=True gives plain OLS; otherwise a 3-fold cross-validated multi-task Lasso.\n",
    "    \"\"\"\n",
    "    return LinearRegression() if lr else MultiTaskLassoCV(cv=3, n_alphas=10, max_iter=2000)\n",
    "\n",
    "\n",
    "# True -> OLS first stages everywhere; False -> the Lasso variants above\n",
    "lr = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate dynamic effects: one DynamicDML fit per horizon t, on the first t periods of every unit\n",
    "from econml.dynamic.dml import DynamicDML\n",
    "\n",
    "ests = []\n",
    "for t in range(1, n_periods + 1):\n",
    "    est_t = DynamicDML(model_t=mlasso_model(lr), model_y=lasso_model(lr))\n",
    "    est_t.fit(panelY[:, :t].reshape(-1),\n",
    "              panelT[:, :t, :].reshape(-1, n_treatments),\n",
    "              X=None,\n",
    "              W=panelX[:, :t, :].reshape(-1, n_x),\n",
    "              groups=panelGroups[:, :t].reshape(-1),\n",
    "              inference='auto',\n",
    "              cache_values=True)  # bool flag, not the string 'True'\n",
    "    ests.append(est_t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Matrices of dynamic effects: effect[s, p, k] is the effect of treatment k given in period s\n",
    "# on the outcome in period p (NaN where s > p); true_eff holds the matching ground truth.\n",
    "# NOTE(review): assumes ests[p].intercept_ is laid out as lag * n_treatments + treatment\n",
    "# (lag 0 first), like true_effect - confirm against DynamicDML's parameter ordering.\n",
    "true_effect_flat = np.ravel(true_effect)  # local flat view; leaves the global true_effect untouched\n",
    "effect = np.full((n_periods, n_periods, n_treatments), np.nan)\n",
    "true_eff = np.full((n_periods, n_periods, n_treatments), np.nan)\n",
    "for p in range(n_periods):\n",
    "    param_hat = ests[p].intercept_\n",
    "    for kappa in range(p + 1):\n",
    "        for t in range(n_treatments):\n",
    "            param_ind = kappa * n_treatments + t\n",
    "            effect[p - kappa, p, t] = param_hat[param_ind]\n",
    "            true_eff[p - kappa, p, t] = true_effect_flat[param_ind]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[ 0.42178248, -0.15785924],\n",
       "        [ 0.08474795, -0.081972  ],\n",
       "        [ 0.05607016, -0.0593549 ],\n",
       "        [ 0.0355804 , -0.04515682]],\n",
       "\n",
       "       [[        nan,         nan],\n",
       "        [ 0.42178248, -0.15785924],\n",
       "        [ 0.08474795, -0.081972  ],\n",
       "        [ 0.05607016, -0.0593549 ]],\n",
       "\n",
       "       [[        nan,         nan],\n",
       "        [        nan,         nan],\n",
       "        [ 0.42178248, -0.15785924],\n",
       "        [ 0.08474795, -0.081972  ]],\n",
       "\n",
       "       [[        nan,         nan],\n",
       "        [        nan,         nan],\n",
       "        [        nan,         nan],\n",
       "        [ 0.42178248, -0.15785924]]])"
      ]
     },
     "execution_count": 125,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# true_eff[s, p, k]: true effect of treatment k given in period s on the period-p outcome (NaN when s > p)\n",
    "true_eff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Truth: total effect of the period-0 treatment, summed over all outcome periods\n",
    "true_long_range_effects = np.nansum(true_eff[0], axis=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
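    "# Proposed Method: Adjusted Surrogate Index\n",
    "\n",
    "Subtract the estimated effects of treatments from periods 1 onward (taken from the `DynamicDML` fits above) from each period's outcome, build a surrogate index on the adjusted outcomes, and estimate the effect of the period-0 treatment on that index with `LinearDML`."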
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the estimated effects of treatments from periods >= 1 on every outcome period,\n",
    "# using effect[s, p] (treatment period s, outcome period p) estimated above.\n",
    "from econml.dml import LinearDML\n",
    "\n",
    "panelYadj = panelY.copy()\n",
    "for period in range(n_periods):\n",
    "    for lag in range(period):\n",
    "        treat_period = period - lag  # runs over period..1, never 0\n",
    "        panelYadj[:, period] -= panelT[:, treat_period, :] @ effect[treat_period, period]\n",
    "TotalYadj = panelYadj.sum(axis=1)\n",
    "\n",
    "# Surrogate index: predict the total adjusted outcome from the period-1 state and the period-0 adjusted outcome\n",
    "XS = np.hstack([panelX[:, 1], panelYadj[:, :1]])\n",
    "proxy_model = (LinearRegression() if lr else LassoCV()).fit(XS, TotalYadj)\n",
    "sindex_adj = proxy_model.predict(XS)\n",
    "\n",
    "# Effect of the period-0 treatment on the adjusted surrogate index\n",
    "est1 = LinearDML(model_t=mlasso_model(lr),\n",
    "                 model_y=lasso_model(lr),\n",
    "                 linear_first_stages=False)\n",
    "est1.fit(sindex_adj, panelT[:, 0], W=panelX[:, 0])\n",
    "est_on_adj_surr_effects = est1.intercept__inference()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Coefficient Results:  X is None, please call intercept_inference to learn the constant!\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table class=\"simpletable\">\n",
       "<caption>CATE Intercept Results</caption>\n",
       "<tr>\n",
       "          <td></td>          <th>point_estimate</th> <th>stderr</th>  <th>zstat</th> <th>pvalue</th> <th>ci_lower</th> <th>ci_upper</th>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>cate_intercept|T0</th>      <td>0.635</td>      <td>0.056</td> <td>11.375</td>   <td>0.0</td>    <td>0.526</td>    <td>0.745</td> \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>cate_intercept|T1</th>     <td>-0.396</td>      <td>0.056</td> <td>-7.071</td>   <td>0.0</td>   <td>-0.505</td>   <td>-0.286</td> \n",
       "</tr>\n",
       "</table><br/><br/><sub>A linear parametric conditional average treatment effect (CATE) model was fitted:<br/>$Y = \\Theta(X)\\cdot T + g(X, W) + \\epsilon$<br/>where for every outcome $i$ and treatment $j$ the CATE $\\Theta_{ij}(X)$ has the form:<br/>$\\Theta_{ij}(X) = \\phi(X)' coef_{ij} + cate\\_intercept_{ij}$<br/>where $\\phi(X)$ is the output of the `featurizer` or $X$ if `featurizer`=None. Coefficient Results table portrays the $coef_{ij}$ parameter vector for each outcome $i$ and treatment $j$. Intercept Results table portrays the $cate\\_intercept_{ij}$ parameter.</sub>"
      ],
      "text/plain": [
       "<class 'econml.utilities.Summary'>\n",
       "\"\"\"\n",
       "                         CATE Intercept Results                        \n",
       "=======================================================================\n",
       "                  point_estimate stderr zstat  pvalue ci_lower ci_upper\n",
       "-----------------------------------------------------------------------\n",
       "cate_intercept|T0          0.635  0.056 11.375    0.0    0.526    0.745\n",
       "cate_intercept|T1         -0.396  0.056 -7.071    0.0   -0.505   -0.286\n",
       "-----------------------------------------------------------------------\n",
       "\n",
       "<sub>A linear parametric conditional average treatment effect (CATE) model was fitted:\n",
       "$Y = \\Theta(X)\\cdot T + g(X, W) + \\epsilon$\n",
       "where for every outcome $i$ and treatment $j$ the CATE $\\Theta_{ij}(X)$ has the form:\n",
       "$\\Theta_{ij}(X) = \\phi(X)' coef_{ij} + cate\\_intercept_{ij}$\n",
       "where $\\phi(X)$ is the output of the `featurizer` or $X$ if `featurizer`=None. Coefficient Results table portrays the $coef_{ij}$ parameter vector for each outcome $i$ and treatment $j$. Intercept Results table portrays the $cate\\_intercept_{ij}$ parameter.</sub>\n",
       "\"\"\""
      ]
     },
     "execution_count": 128,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# inference summary: period-0 treatment effect on the adjusted surrogate index\n",
    "est1.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.598181  , -0.34434296])"
      ]
     },
     "execution_count": 129,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# ground truth to compare against the adjusted estimate above\n",
    "true_long_range_effects"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
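    "# Baseline: Unadjusted Surrogate Index\n",
    "\n",
    "For comparison, build the surrogate index from the raw outcomes, without removing the effects of later-period treatments."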
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline surrogate index built from the raw (unadjusted) outcomes\n",
    "TotalY = panelY.sum(axis=1)\n",
    "XS = np.hstack([panelX[:, 1], panelY[:, :1]])\n",
    "proxy_cls = LinearRegression if lr else LassoCV\n",
    "unadjusted_proxy_model = proxy_cls().fit(XS, TotalY)\n",
    "sindex = unadjusted_proxy_model.predict(XS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Coefficient Results:  X is None, please call intercept_inference to learn the constant!\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table class=\"simpletable\">\n",
       "<caption>CATE Intercept Results</caption>\n",
       "<tr>\n",
       "          <td></td>          <th>point_estimate</th> <th>stderr</th>  <th>zstat</th> <th>pvalue</th> <th>ci_lower</th> <th>ci_upper</th>\n",
       "</tr>\n",
       "<tr>\n",
       "  <th>cate_intercept|T0</th>      <td>0.688</td>      <td>0.056</td>  <td>12.39</td>   <td>0.0</td>    <td>0.579</td>    <td>0.796</td> \n",
       "</tr>\n",
       "<tr>\n",
       "  <th>cate_intercept|T1</th>      <td>-0.38</td>      <td>0.056</td> <td>-6.802</td>   <td>0.0</td>   <td>-0.489</td>    <td>-0.27</td> \n",
       "</tr>\n",
       "</table><br/><br/><sub>A linear parametric conditional average treatment effect (CATE) model was fitted:<br/>$Y = \\Theta(X)\\cdot T + g(X, W) + \\epsilon$<br/>where for every outcome $i$ and treatment $j$ the CATE $\\Theta_{ij}(X)$ has the form:<br/>$\\Theta_{ij}(X) = \\phi(X)' coef_{ij} + cate\\_intercept_{ij}$<br/>where $\\phi(X)$ is the output of the `featurizer` or $X$ if `featurizer`=None. Coefficient Results table portrays the $coef_{ij}$ parameter vector for each outcome $i$ and treatment $j$. Intercept Results table portrays the $cate\\_intercept_{ij}$ parameter.</sub>"
      ],
      "text/plain": [
       "<class 'econml.utilities.Summary'>\n",
       "\"\"\"\n",
       "                         CATE Intercept Results                        \n",
       "=======================================================================\n",
       "                  point_estimate stderr zstat  pvalue ci_lower ci_upper\n",
       "-----------------------------------------------------------------------\n",
       "cate_intercept|T0          0.688  0.056  12.39    0.0    0.579    0.796\n",
       "cate_intercept|T1          -0.38  0.056 -6.802    0.0   -0.489    -0.27\n",
       "-----------------------------------------------------------------------\n",
       "\n",
       "<sub>A linear parametric conditional average treatment effect (CATE) model was fitted:\n",
       "$Y = \\Theta(X)\\cdot T + g(X, W) + \\epsilon$\n",
       "where for every outcome $i$ and treatment $j$ the CATE $\\Theta_{ij}(X)$ has the form:\n",
       "$\\Theta_{ij}(X) = \\phi(X)' coef_{ij} + cate\\_intercept_{ij}$\n",
       "where $\\phi(X)$ is the output of the `featurizer` or $X$ if `featurizer`=None. Coefficient Results table portrays the $coef_{ij}$ parameter vector for each outcome $i$ and treatment $j$. Intercept Results table portrays the $cate\\_intercept_{ij}$ parameter.</sub>\n",
       "\"\"\""
      ]
     },
     "execution_count": 131,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Period-0 treatment effect on the unadjusted surrogate index\n",
    "est = LinearDML(model_y=lasso_model(lr),\n",
    "                model_t=mlasso_model(lr),\n",
    "                linear_first_stages=False)\n",
    "est.fit(sindex, panelT[:, 0], W=panelX[:, 0])\n",
    "est.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side-by-side: truth vs. adjusted (est1) vs. unadjusted (est) surrogate-index estimates\n",
    "pd.DataFrame({'truth': true_long_range_effects,\n",
    "              'adjusted surrogate index': est1.intercept_,\n",
    "              'unadjusted surrogate index': est.intercept_},\n",
    "             index=[f'T{i}' for i in range(n_treatments)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
