{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib\n",
    "from matplotlib import pyplot as plt\n",
    "import os\n",
    "from sklearn.metrics import mean_squared_error as mse, auc\n",
    "from scipy import sparse\n",
    "from scipy.stats import gamma\n",
    "from scipy.stats import ttest_ind\n",
    "import warnings\n",
    "import seaborn as sns\n",
    "sns.set(style=\"darkgrid\")\n",
    "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)\n",
    "warnings.filterwarnings(\"ignore\", category=FutureWarning)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pokec simulation results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the saved '_scores.npy' array of every (experiment, K) run under\n",
    "# pokec_out and store its first three entries in A_perf, YP_perf and Y_perf.\n",
    "pokec_out = '../pokec_pifonly_predictive_check/'\n",
    "num_exps=10\n",
    "mixture_prob = 0.5\n",
    "confounding_configs = (20., 20.)\n",
    "Ks = [5, 10, 15]\n",
    "# One row per experiment, one column per entry of Ks. The width is taken from\n",
    "# Ks itself so that editing Ks cannot leave the tables with a stale column count.\n",
    "num_Ks = len(Ks)\n",
    "A_perf = np.zeros((num_exps, num_Ks))\n",
    "YP_perf = np.zeros((num_exps, num_Ks))\n",
    "Y_perf = np.zeros((num_exps, num_Ks))\n",
    "for exp_iter in range(num_exps):\n",
    "    for k_iter, K in enumerate(Ks):\n",
    "        f_name = 'conf=' + str(confounding_configs) + ';mixture_prob=' + str(mixture_prob) + ';K=' + str(K) + '_scores.npy'\n",
    "        # Experiment sub-directories are numbered from 1, hence exp_iter+1.\n",
    "        result_file = os.path.join(pokec_out, str(exp_iter+1), 'pif', f_name)\n",
    "        results = np.load(result_file)\n",
    "        A_perf[exp_iter,k_iter] = results[0]\n",
    "        YP_perf[exp_iter,k_iter] = results[1]\n",
    "        Y_perf[exp_iter,k_iter] = results[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the saved '_scores.npy' array of every (experiment, K) run under\n",
    "# pokec_out and store its first entry in A_perf.\n",
    "pokec_out = '../out/pif_ppc/'\n",
    "num_exps=10\n",
    "mixture_prob = 0.\n",
    "confounding_configs = (100., 100.)\n",
    "Ks = [5, 10]\n",
    "# One row per experiment, one column per entry of Ks; the width follows Ks\n",
    "# instead of a separately hard-coded number.\n",
    "A_perf = np.zeros((num_exps, len(Ks)))\n",
    "for exp_iter in range(num_exps):\n",
    "    for k_iter, K in enumerate(Ks):\n",
    "        f_name = 'conf=' + str(confounding_configs) + ';mixture_prob=' + str(mixture_prob) + ';K=' + str(K) + '_scores.npy'\n",
    "        # Experiment sub-directories are numbered from 1, hence exp_iter+1.\n",
    "        result_file = os.path.join(pokec_out, str(exp_iter+1), f_name)\n",
    "        results = np.load(result_file)\n",
    "        A_perf[exp_iter,k_iter] = results[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the saved '_scores.npy' array of every (experiment, K) run and store\n",
    "# its second entry in YP_perf.\n",
    "# NOTE: pokec_out and num_exps are not set in this cell; they must already be\n",
    "# defined by a previously executed cell.\n",
    "mixture_prob = 0.\n",
    "confounding_configs = (100., 100.)\n",
    "Ks = [5, 10]\n",
    "# One row per experiment, one column per entry of Ks; the width follows Ks\n",
    "# instead of a separately hard-coded number.\n",
    "YP_perf = np.zeros((num_exps, len(Ks)))\n",
    "for exp_iter in range(num_exps):\n",
    "    for k_iter, K in enumerate(Ks):\n",
    "        f_name = 'conf=' + str(confounding_configs) + ';mixture_prob=' + str(mixture_prob) + ';K=' + str(K) + '_scores.npy'\n",
    "        # Experiment sub-directories are numbered from 1, hence exp_iter+1.\n",
    "        result_file = os.path.join(pokec_out, str(exp_iter+1), f_name)\n",
    "        results = np.load(result_file)\n",
    "        YP_perf[exp_iter,k_iter] = results[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([0.631, 0.978]), array([0.416, 0.376]))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Optional paired t-test between column 0 and column 1 of each table (disabled):\n",
    "# from scipy.stats import ttest_rel\n",
    "# ttest_rel(A_perf[:,0], A_perf[:,1]), ttest_rel(YP_perf[:,0], YP_perf[:,1])\n",
    "\n",
    "# Column-wise averages over the experiment axis (axis 0); shown as the cell output.\n",
    "np.mean(A_perf, axis=0), np.mean(YP_perf, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5 & 68.3 $\\pm$ 0.86&83.5 $\\pm$ 0.5&85.7 $\\pm$ 0.4 && 81.4 $\\pm$ 0.69&85.0 $\\pm$ 0.14&86.8 $\\pm$ 0.08 \\\\\n",
      "10 & 65.3 $\\pm$ 0.86&83.8 $\\pm$ 0.13&87.3 $\\pm$ 0.09 && 86.9 $\\pm$ 0.76&85.1 $\\pm$ 0.08&87.1 $\\pm$ 0.1 \\\\\n",
      "15 & 67.5 $\\pm$ 0.64&84.4 $\\pm$ 0.09&87.4 $\\pm$ 0.09 && 86.9 $\\pm$ 0.58&84.6 $\\pm$ 0.15&86.8 $\\pm$ 0.08 \\\\\n"
     ]
    }
   ],
   "source": [
    "# Models reported in the table, in left-to-right column-group order.\n",
    "models = ['spf', 'multi_cause']\n",
    "# First two components of the (u_c, i_c, K) key used to index exp_results.\n",
    "u_c, i_c = 20., 20.\n",
    "\n",
    "def fmt_mean_std(mean, std):\n",
    "    \"\"\"Render a (mean, std) pair as a single LaTeX table entry.\n",
    "\n",
    "    The mean is multiplied by 100 and rounded to 1 decimal place; the std is\n",
    "    multiplied by 10 and rounded to 2 decimal places. The two numbers are\n",
    "    joined by a LaTeX plus-minus symbol in math mode.\n",
    "\n",
    "    Returns:\n",
    "        str: the formatted entry.\n",
    "    \"\"\"\n",
    "    # NOTE(review): mean and std use different scale factors (100 vs 10).\n",
    "    # Left unchanged so reported numbers stay the same -- confirm it is intended.\n",
    "    # Raw string literal: backslash-p is not a valid escape sequence in a\n",
    "    # normal string and triggers a warning on newer Python versions.\n",
    "    return str(round(mean*100, 1)) + r' $\\pm$ ' + str(round(std*10,2))\n",
    "\n",
    "# Print one LaTeX table row per K: three formatted scores for each of the\n",
    "# first two models, with the two model groups separated by \"&&\".\n",
    "for K in [5,10,15]:\n",
    "    row_groups = []\n",
    "    for model in models:\n",
    "        # NOTE(review): exp_results is not defined in any visible cell; it must\n",
    "        # already exist in the kernel for this cell to run.\n",
    "        runs = [exp_results[(u_c, i_c, K)][model][i] for i in range(num_exps)]\n",
    "        cells = []\n",
    "        # Entries 0, 1 and 2 of each run are summarised separately across runs.\n",
    "        for metric_idx in range(3):\n",
    "            scores = [run[metric_idx] for run in runs]\n",
    "            cells.append(fmt_mean_std(np.mean(scores), np.std(scores)))\n",
    "        row_groups.append(\"&\".join(cells))\n",
    "    print(K, \"&\", row_groups[0], \"&&\", row_groups[1], \"\\\\\\\\\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
