{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "from sklearn.metrics import mean_squared_error as mse_f\n",
    "from scipy import sparse\n",
    "from scipy.stats import gamma\n",
    "from scipy.stats import ttest_ind\n",
    "import warnings\n",
    "import pandas as pd\n",
    "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)\n",
    "warnings.filterwarnings(\"ignore\", category=FutureWarning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Internal model identifier (the prefix of each '<model>_model_fitted_params'\n",
    "# result directory) -> display name used in the paper tables.\n",
    "paper_model_names = {\n",
    "    'pif.z-only': 'PIF-Item+Net',\n",
    "    'pif.z-theta-joint': 'PIF-Item+Net+Purchase',\n",
    "    'spf.main': 'MSPF',\n",
    "    'unadjusted.main': 'Unadjusted',\n",
    "    'network_pref_only.main': 'Network Only',\n",
    "    'item_only.main': 'PIF-Item',\n",
    "    'no_unobs.main': 'BD Adj.',\n",
    "    'item_only_oracle.main': 'PIF-Item (No region)',\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_table(exp_results, regimes, models, exps=10, model_names=None):\n",
    "    \"\"\"Tabulate mean and std of per-experiment MSE (scaled by 1000).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    exp_results : dict\n",
    "        Maps a regime key to {model: [(fitted, truth), ...]}, one pair per\n",
    "        loaded experiment (as built by the result-loading cell).\n",
    "    regimes : dict\n",
    "        Ordered mapping of column label -> key into ``exp_results``.\n",
    "    models : list of str\n",
    "        Internal model names, one table row each, in this order.\n",
    "    exps : int\n",
    "        Number of experiments to read per (regime, model).\n",
    "    model_names : dict, optional\n",
    "        Internal name -> display name; defaults to ``paper_model_names``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pandas.DataFrame\n",
    "        Indexed by display name with one column per ``regimes`` label; each\n",
    "        cell is the string '<mean>' + a LaTeX plus-minus sign + '<std>', both\n",
    "        rounded to 2 decimals. Missing experiments are reported and skipped\n",
    "        (not scored as 0); a cell with no data at all shows 'nan'.\n",
    "    \"\"\"\n",
    "    if model_names is None:\n",
    "        model_names = paper_model_names\n",
    "    col_names = list(regimes.keys())\n",
    "    shape = (len(models), len(col_names))\n",
    "    results = np.full(shape, np.nan)\n",
    "    std = np.full(shape, np.nan)\n",
    "\n",
    "    for col_idx, c in enumerate(regimes.values()):\n",
    "        for row_idx, model in enumerate(models):\n",
    "            # NaN, not 0, marks a missing experiment: a 0 would be averaged in\n",
    "            # as a perfect fit and bias the reported MSE downwards.\n",
    "            mse = np.full(exps, np.nan)\n",
    "            for i in range(exps):\n",
    "                try:\n",
    "                    beta_predicted = exp_results[c][model][i][0]\n",
    "                    truth = exp_results[c][model][i][1]\n",
    "                except (KeyError, IndexError):\n",
    "                    print(model, 'exp', i, 'not found')\n",
    "                    continue\n",
    "                mse[i] = ((beta_predicted - truth) ** 2).mean()\n",
    "            if np.isnan(mse).all():\n",
    "                continue  # no data for this cell; leave it NaN\n",
    "            results[row_idx, col_idx] = round(np.nanmean(mse) * 1000, 2)\n",
    "            std[row_idx, col_idx] = round(np.nanstd(mse) * 1000, 2)\n",
    "\n",
    "    proper_names = [model_names[m] for m in models]\n",
    "    df = pd.DataFrame(results, index=proper_names, columns=col_names, dtype=str)\n",
    "    std_df = pd.DataFrame(std, index=proper_names, columns=col_names, dtype=str)\n",
    "    # Raw string: '$\\pm$' in a plain literal relies on an invalid escape sequence.\n",
    "    return df + r'$\\pm$' + std_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "pokec_out = '../../out/pokec_paper_results'\n",
    "exps = 10\n",
    "embed = 'user'  # not referenced elsewhere in this notebook\n",
    "models = ['no_unobs.main',\n",
    "          'unadjusted.main',\n",
    "          'spf.main',\n",
    "          'network_pref_only.main',\n",
    "          'item_only.main',\n",
    "          'pif.z-only',\n",
    "          'pif.z-theta-joint',\n",
    "          'item_only_oracle.main']\n",
    "\n",
    "conf_types = ['homophily', 'exog', 'both']\n",
    "confounding_strengths = [(50, 10), (50, 50), (50, 100)]\n",
    "\n",
    "# exp_results[(conf_type, (cov1conf, cov2conf))][model] is a list of\n",
    "# (fitted, true) parameter pairs, one per result file that loaded, in\n",
    "# increasing experiment order.\n",
    "exp_results = {}\n",
    "\n",
    "for i in range(1, exps + 1):\n",
    "    for model in models:\n",
    "        for (cov1conf, cov2conf) in confounding_strengths:\n",
    "            for ct in conf_types:\n",
    "                base_file_name = 'conf=' + str((cov1conf, cov2conf)) + ';conf_type=' + ct + '.npz'\n",
    "                result_file = os.path.join(pokec_out, str(i), model + '_model_fitted_params', base_file_name)\n",
    "                try:\n",
    "                    # The context manager closes the .npz archive deterministically\n",
    "                    # instead of leaving it to garbage collection.\n",
    "                    with np.load(result_file) as res:\n",
    "                        params = res['fitted']\n",
    "                        truth = res['true']\n",
    "                except OSError:\n",
    "                    # Only a missing/unreadable file is skipped; other errors surface.\n",
    "                    print(result_file, ' not found')\n",
    "                    continue\n",
    "                except KeyError as err:\n",
    "                    print(result_file, 'lacks array', err)\n",
    "                    continue\n",
    "                regime_results = exp_results.setdefault((ct, (cov1conf, cov2conf)), {})\n",
    "                regime_results.setdefault(model, []).append((params, truth))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "confounding_type = 'exog'\n",
    "regime1 = {label: (confounding_type, strength)\n",
    "           for label, strength in zip(['Low', 'Med.', 'High'], confounding_strengths)}\n",
    "# Take the row set from the 'Low' regime, as the homophily and 'both' tables do,\n",
    "# so all three column groups of the combined table use one selection rule.\n",
    "models = list(exp_results[regime1['Low']].keys())\n",
    "\n",
    "df1 = print_table(exp_results, regime1, models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "confounding_type = 'homophily'\n",
    "# Row set: every model with results in the lowest-strength regime.\n",
    "models = list(exp_results[(confounding_type, confounding_strengths[0])])\n",
    "regime1 = dict(zip(['Low', 'Med.', 'High'],\n",
    "                   [(confounding_type, strength) for strength in confounding_strengths]))\n",
    "\n",
    "df2 = print_table(exp_results, regime1, models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "confounding_type = 'both'\n",
    "# Row set: every model with results in the lowest-strength regime.\n",
    "models = list(exp_results[(confounding_type, confounding_strengths[0])])\n",
    "regime1 = dict(zip(['Low', 'Med.', 'High'],\n",
    "                   [(confounding_type, strength) for strength in confounding_strengths]))\n",
    "\n",
    "df3 = print_table(exp_results, regime1, models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Low</th>\n",
       "      <th>Med.</th>\n",
       "      <th>High</th>\n",
       "      <th>Low</th>\n",
       "      <th>Med.</th>\n",
       "      <th>High</th>\n",
       "      <th>Low</th>\n",
       "      <th>Med.</th>\n",
       "      <th>High</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>BD Adj.</th>\n",
       "      <td>0.17$\\pm$0.07</td>\n",
       "      <td>0.17$\\pm$0.16</td>\n",
       "      <td>0.2$\\pm$0.17</td>\n",
       "      <td>0.29$\\pm$0.06</td>\n",
       "      <td>0.27$\\pm$0.08</td>\n",
       "      <td>0.25$\\pm$0.06</td>\n",
       "      <td>0.13$\\pm$0.06</td>\n",
       "      <td>0.13$\\pm$0.09</td>\n",
       "      <td>0.12$\\pm$0.06</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Unadjusted</th>\n",
       "      <td>1.56$\\pm$0.21</td>\n",
       "      <td>2.0$\\pm$0.35</td>\n",
       "      <td>2.16$\\pm$0.31</td>\n",
       "      <td>1.91$\\pm$0.25</td>\n",
       "      <td>2.4$\\pm$0.4</td>\n",
       "      <td>2.44$\\pm$0.33</td>\n",
       "      <td>2.21$\\pm$0.25</td>\n",
       "      <td>2.49$\\pm$0.35</td>\n",
       "      <td>2.56$\\pm$0.35</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MSPF</th>\n",
       "      <td>0.53$\\pm$0.42</td>\n",
       "      <td>0.56$\\pm$0.35</td>\n",
       "      <td>1.1$\\pm$1.55</td>\n",
       "      <td>0.66$\\pm$0.38</td>\n",
       "      <td>0.67$\\pm$0.37</td>\n",
       "      <td>0.76$\\pm$0.48</td>\n",
       "      <td>0.55$\\pm$0.34</td>\n",
       "      <td>0.62$\\pm$0.43</td>\n",
       "      <td>1.5$\\pm$2.67</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Network Only</th>\n",
       "      <td>0.48$\\pm$0.14</td>\n",
       "      <td>0.84$\\pm$0.22</td>\n",
       "      <td>0.84$\\pm$0.25</td>\n",
       "      <td>0.76$\\pm$0.2</td>\n",
       "      <td>1.08$\\pm$0.36</td>\n",
       "      <td>1.16$\\pm$0.36</td>\n",
       "      <td>0.55$\\pm$0.11</td>\n",
       "      <td>0.73$\\pm$0.15</td>\n",
       "      <td>0.78$\\pm$0.17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>PIF-Item</th>\n",
       "      <td>0.17$\\pm$0.11</td>\n",
       "      <td>0.15$\\pm$0.12</td>\n",
       "      <td>0.21$\\pm$0.12</td>\n",
       "      <td>0.28$\\pm$0.08</td>\n",
       "      <td>0.38$\\pm$0.17</td>\n",
       "      <td>0.48$\\pm$0.23</td>\n",
       "      <td>0.24$\\pm$0.12</td>\n",
       "      <td>0.21$\\pm$0.12</td>\n",
       "      <td>0.23$\\pm$0.12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>PIF-Item+Net</th>\n",
       "      <td>0.31$\\pm$0.2</td>\n",
       "      <td>0.31$\\pm$0.22</td>\n",
       "      <td>0.35$\\pm$0.18</td>\n",
       "      <td>0.43$\\pm$0.2</td>\n",
       "      <td>0.53$\\pm$0.22</td>\n",
       "      <td>0.5$\\pm$0.23</td>\n",
       "      <td>0.32$\\pm$0.21</td>\n",
       "      <td>0.4$\\pm$0.2</td>\n",
       "      <td>0.37$\\pm$0.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>PIF-Item+Net+Purchase</th>\n",
       "      <td>0.26$\\pm$0.19</td>\n",
       "      <td>0.2$\\pm$0.15</td>\n",
       "      <td>0.3$\\pm$0.18</td>\n",
       "      <td>0.38$\\pm$0.19</td>\n",
       "      <td>0.46$\\pm$0.22</td>\n",
       "      <td>0.42$\\pm$0.24</td>\n",
       "      <td>0.29$\\pm$0.2</td>\n",
       "      <td>0.35$\\pm$0.19</td>\n",
       "      <td>0.32$\\pm$0.19</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>PIF-Item (No region)</th>\n",
       "      <td>0.1$\\pm$0.06</td>\n",
       "      <td>0.11$\\pm$0.04</td>\n",
       "      <td>0.13$\\pm$0.07</td>\n",
       "      <td>1.15$\\pm$0.14</td>\n",
       "      <td>1.65$\\pm$0.27</td>\n",
       "      <td>1.7$\\pm$0.22</td>\n",
       "      <td>0.3$\\pm$0.04</td>\n",
       "      <td>0.35$\\pm$0.05</td>\n",
       "      <td>0.38$\\pm$0.05</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                 Low           Med.           High  \\\n",
       "BD Adj.                0.17$\\pm$0.07  0.17$\\pm$0.16   0.2$\\pm$0.17   \n",
       "Unadjusted             1.56$\\pm$0.21   2.0$\\pm$0.35  2.16$\\pm$0.31   \n",
       "MSPF                   0.53$\\pm$0.42  0.56$\\pm$0.35   1.1$\\pm$1.55   \n",
       "Network Only           0.48$\\pm$0.14  0.84$\\pm$0.22  0.84$\\pm$0.25   \n",
       "PIF-Item               0.17$\\pm$0.11  0.15$\\pm$0.12  0.21$\\pm$0.12   \n",
       "PIF-Item+Net            0.31$\\pm$0.2  0.31$\\pm$0.22  0.35$\\pm$0.18   \n",
       "PIF-Item+Net+Purchase  0.26$\\pm$0.19   0.2$\\pm$0.15   0.3$\\pm$0.18   \n",
       "PIF-Item (No region)    0.1$\\pm$0.06  0.11$\\pm$0.04  0.13$\\pm$0.07   \n",
       "\n",
       "                                 Low           Med.           High  \\\n",
       "BD Adj.                0.29$\\pm$0.06  0.27$\\pm$0.08  0.25$\\pm$0.06   \n",
       "Unadjusted             1.91$\\pm$0.25    2.4$\\pm$0.4  2.44$\\pm$0.33   \n",
       "MSPF                   0.66$\\pm$0.38  0.67$\\pm$0.37  0.76$\\pm$0.48   \n",
       "Network Only            0.76$\\pm$0.2  1.08$\\pm$0.36  1.16$\\pm$0.36   \n",
       "PIF-Item               0.28$\\pm$0.08  0.38$\\pm$0.17  0.48$\\pm$0.23   \n",
       "PIF-Item+Net            0.43$\\pm$0.2  0.53$\\pm$0.22   0.5$\\pm$0.23   \n",
       "PIF-Item+Net+Purchase  0.38$\\pm$0.19  0.46$\\pm$0.22  0.42$\\pm$0.24   \n",
       "PIF-Item (No region)   1.15$\\pm$0.14  1.65$\\pm$0.27   1.7$\\pm$0.22   \n",
       "\n",
       "                                 Low           Med.           High  \n",
       "BD Adj.                0.13$\\pm$0.06  0.13$\\pm$0.09  0.12$\\pm$0.06  \n",
       "Unadjusted             2.21$\\pm$0.25  2.49$\\pm$0.35  2.56$\\pm$0.35  \n",
       "MSPF                   0.55$\\pm$0.34  0.62$\\pm$0.43   1.5$\\pm$2.67  \n",
       "Network Only           0.55$\\pm$0.11  0.73$\\pm$0.15  0.78$\\pm$0.17  \n",
       "PIF-Item               0.24$\\pm$0.12  0.21$\\pm$0.12  0.23$\\pm$0.12  \n",
       "PIF-Item+Net           0.32$\\pm$0.21    0.4$\\pm$0.2   0.37$\\pm$0.2  \n",
       "PIF-Item+Net+Purchase   0.29$\\pm$0.2  0.35$\\pm$0.19  0.32$\\pm$0.19  \n",
       "PIF-Item (No region)    0.3$\\pm$0.04  0.35$\\pm$0.05  0.38$\\pm$0.05  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# df1, df2, df3 share the column labels Low/Med./High; keying the concat adds\n",
    "# a top column level naming each group's confounding type.\n",
    "all_results = pd.concat([df1, df2, df3], axis=1, keys=['exog', 'homophily', 'both'])\n",
    "all_results"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
