{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib\n",
    "from matplotlib import pyplot as plt\n",
    "import os\n",
    "from sklearn.metrics import mean_squared_error as mse_f\n",
    "from scipy import sparse\n",
    "from scipy.stats import gamma\n",
    "from scipy.stats import ttest_ind\n",
    "import warnings\n",
    "import seaborn as sns\n",
    "sns.set(style=\"darkgrid\")\n",
    "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)\n",
    "warnings.filterwarnings(\"ignore\", category=FutureWarning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_set_overlap(Beta_p, Beta, k=50):\n",
    "    top = np.argsort(Beta)[-k:]\n",
    "    top_p = np.argsort(Beta_p)[-k:]\n",
    "    return np.intersect1d(top, top_p).shape[0]/np.union1d(top, top_p).shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_most_least_influential(Beta,n=100):\n",
    "    sorted_inf = np.argsort(Beta)\n",
    "    highest = sorted_inf[-n:]\n",
    "    lowest = sorted_inf[:n]\n",
    "    return lowest, highest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "pokec_out = '../pokec_multi_covar_out'\n",
    "exps = {1:set(range(1,11))-{7,10}, 0:[1,2,9]}\n",
    "models = ['multi_cause', 'spf', 'naive']\n",
    "probs = [0., 0.3, 0.5, 0.8, 1.]\n",
    "confounding_strengths = [10., 100.]\n",
    "exp_results = {}\n",
    "for idx, d in enumerate([pokec_out, pokec_out + '_v1']):\n",
    "    for i in exps[idx]:\n",
    "        for model in models:\n",
    "            for c in confounding_strengths:\n",
    "                for p in probs:\n",
    "                    f_name = 'conf_age=' + str(c) +';p=' + str(p) +'_fitted.gz'\n",
    "                    truth_fname = 'conf_age=' + str(c) + ';p=' + str(p) + '_true.gz'\n",
    "                    result_file = os.path.join(d, str(i), model + '_model_fitted_params', f_name)\n",
    "                    truth_file = os.path.join(d, str(i), model + '_model_fitted_params', truth_fname)\n",
    "                    params = np.loadtxt(result_file)\n",
    "                    truth = np.loadtxt(truth_file)\n",
    "                    if (p, c) in exp_results:\n",
    "                        if model in exp_results[(p, c)]:\n",
    "                            exp_results[(p, c)][model].append((params, truth))\n",
    "                        else:\n",
    "                            exp_results[(p, c)][model]= [(params, truth)]\n",
    "                    else:\n",
    "                        exp_results[(p, c)] = {model:[(params, truth)]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Config: (0.0, 10.0)\n",
      "Model: multi_cause\n",
      "Mean MSE: 0.05226528737971172 and st dev.: 0.029131245129401986\n",
      "Mean MSE for most influential: 0.07637779958998815 and st dev.: 0.5811549220439773\n",
      "Mean MSE for less influential: 0.040097631710122365 and st dev.: 0.02171267921079188\n",
      "------------------------------------------------------------\n",
      "Config: (0.0, 10.0)\n",
      "Model: spf\n",
      "Mean MSE: 0.08112092667087024 and st dev.: 0.33249154892756805\n",
      "Mean MSE for most influential: 0.7191176608341598 and st dev.: 6.719145142035612\n",
      "Mean MSE for less influential: 0.050038047019529734 and st dev.: 0.02150017294464184\n",
      "------------------------------------------------------------\n",
      "Config: (0.0, 10.0)\n",
      "Model: naive\n",
      "Mean MSE: 0.05933955300330302 and st dev.: 0.01158701626347028\n",
      "Mean MSE for most influential: 0.05848818973020711 and st dev.: 0.24672842669885534\n",
      "Mean MSE for less influential: 0.056957966439163996 and st dev.: 0.026203797618173753\n",
      "------------------------------------------------------------\n",
      "------------------------------------------------------------\n",
      "Config: (0.3, 100.0)\n",
      "Model: multi_cause\n",
      "Mean MSE: 0.050751420085781686 and st dev.: 0.024335343385303132\n",
      "Mean MSE for most influential: 0.07846343129512133 and st dev.: 0.4841525845287791\n",
      "Mean MSE for less influential: 0.040044635249943894 and st dev.: 0.021741312105359976\n",
      "------------------------------------------------------------\n",
      "Config: (0.3, 100.0)\n",
      "Model: spf\n",
      "Mean MSE: 0.0617935126100293 and st dev.: 0.12679849458755757\n",
      "Mean MSE for most influential: 0.24791958627611543 and st dev.: 2.55568881369346\n",
      "Mean MSE for less influential: 0.050133264888584855 and st dev.: 0.02147898729164354\n",
      "------------------------------------------------------------\n",
      "Config: (0.3, 100.0)\n",
      "Model: naive\n",
      "Mean MSE: 0.06460063029700414 and st dev.: 0.011517437419205585\n",
      "Mean MSE for most influential: 0.06189831185627512 and st dev.: 0.24829125319560172\n",
      "Mean MSE for less influential: 0.05948744611697141 and st dev.: 0.02769961726056849\n",
      "------------------------------------------------------------\n",
      "------------------------------------------------------------\n",
      "Config: (0.0, 100)\n",
      "Model: multi_cause\n",
      "Mean MSE: 0.055102113842400165 and st dev.: 0.040756706348973354\n",
      "Mean MSE for most influential: 0.22056678863596108 and st dev.: 0.8414256541817721\n",
      "Mean MSE for less influential: 0.04020564682793256 and st dev.: 0.021664539074128537\n",
      "------------------------------------------------------------\n",
      "Config: (0.0, 100)\n",
      "Model: spf\n",
      "Mean MSE: 0.05918714891437932 and st dev.: 0.011278598454685282\n",
      "Mean MSE for most influential: 0.192416381075721 and st dev.: 0.24594036717083928\n",
      "Mean MSE for less influential: 0.05040713720825725 and st dev.: 0.021425128404536652\n",
      "------------------------------------------------------------\n",
      "Config: (0.0, 100)\n",
      "Model: naive\n",
      "Mean MSE: 0.078200692274091 and st dev.: 0.011604707663063126\n",
      "Mean MSE for most influential: 0.08430386962641263 and st dev.: 0.24819408960886918\n",
      "Mean MSE for less influential: 0.073900384626646 and st dev.: 0.03379573321401781\n",
      "------------------------------------------------------------\n",
      "------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "models = ['multi_cause', 'spf', 'naive']\n",
    "mean_mse_results = {m:[] for m in models}\n",
    "\n",
    "configs = [(0., 10.), (0.3, 100.), (0., 100,)]\n",
    "\n",
    "for config in configs:\n",
    "    for model in models:\n",
    "        print(\"Config:\", config)\n",
    "        print(\"Model:\", model)\n",
    "        all_mses = []\n",
    "        most_mses = []\n",
    "        least_mses = []\n",
    "        for i in range(len(exp_results[config][model])):\n",
    "            beta_hat = exp_results[config][model][i][0]\n",
    "            beta = exp_results[config][model][i][1]\n",
    "            least, most = get_most_least_influential(beta)\n",
    "            all_mses.append(mse_f(beta, beta_hat))\n",
    "            most_mses.append(mse_f(beta[most], beta_hat[most]))\n",
    "            least_mses.append(mse_f(beta[least], beta_hat[least]))\n",
    "                \n",
    "        print(\"Mean MSE:\", np.median(all_mses), \"and st dev.:\", np.std(all_mses))\n",
    "        print(\"Mean MSE for most influential:\", np.median(most_mses), \"and st dev.:\", np.std(most_mses))\n",
    "        print(\"Mean MSE for less influential:\", np.median(least_mses), \"and st dev.:\", np.std(least_mses))\n",
    "            \n",
    "        mean_mse_results[model].append(np.mean(all_mses))\n",
    "        \n",
    "        print('-'*60)\n",
    "    print('-'*60)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "outdir = '../doc/aistats20/img/multicovariate_plots/'\n",
    "os.makedirs(outdir, exist_ok=True)\n",
    "from scipy.stats import gaussian_kde\n",
    "regimes = [(0., 10.),(0.3, 100.), (0., 100.)]\n",
    "for regime in regimes:\n",
    "    for model in ['naive', 'multi_cause']:\n",
    "        per_user_mse = []\n",
    "        predicted = []\n",
    "        true_influence = []\n",
    "        for i in range(len(exp_results[regime][model])):\n",
    "            beta_predicted = exp_results[regime][model][i][0]\n",
    "            truth = exp_results[regime][model][i][1]\n",
    "            mses = (beta_predicted - truth) ** 2\n",
    "            true_influence += list(truth)\n",
    "            per_user_mse += list(mses) \n",
    "            predicted += list(beta_predicted)\n",
    "\n",
    "        true_influence = np.array(true_influence)\n",
    "        per_user_mse = np.array(per_user_mse)\n",
    "        \n",
    "        highest = [true_influence >5]\n",
    "        lowest = [true_influence < 0.01]\n",
    "        plt.boxplot([per_user_mse[highest], per_user_mse[lowest]], labels=['High Influence Users', 'Low Influence Users'], showfliers=False)\n",
    "        plt.ylim((-0.005, 0.17))\n",
    "#         plt.title(\"Confounding strengths = \" + str(regime) + ';model = '+ model)\n",
    "#         plt.show()\n",
    "        plt.savefig(outdir + 'model=' + model + ';regime=' + str(regime) + '.pdf')\n",
    "        plt.clf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda3/lib/python3.6/site-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: \n",
      "Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n",
      "  \"Adding an axes using the same arguments as a previous axes \"\n",
      "/anaconda3/lib/python3.6/site-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: \n",
      "Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n",
      "  \"Adding an axes using the same arguments as a previous axes \"\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "outdir = '../doc/aistats20/img/multicovariate_plots/'\n",
    "os.makedirs(outdir, exist_ok=True)\n",
    "from scipy.stats import gaussian_kde\n",
    "regimes = {'low':(0., 10.), 'med':(0.3, 100.), 'high':(0., 100.)}\n",
    "settings = {'high':\"High Influence Users\", 'low':\"Low Influence Users\"}\n",
    "for setting in settings:\n",
    "    mse_results={}\n",
    "    for regime, config in regimes.items():\n",
    "        for model in ['naive', 'spf', 'multi_cause']:\n",
    "            per_user_mse = []\n",
    "            predicted = []\n",
    "            true_influence = []\n",
    "            for i in range(len(exp_results[config][model])):\n",
    "                beta_predicted = exp_results[config][model][i][0]\n",
    "                truth = exp_results[config][model][i][1]\n",
    "                mses = (beta_predicted - truth) ** 2\n",
    "                true_influence += list(truth)\n",
    "                per_user_mse += list(mses) \n",
    "                predicted += list(beta_predicted)\n",
    "\n",
    "            true_influence = np.array(true_influence)\n",
    "            per_user_mse = np.array(per_user_mse)\n",
    "\n",
    "            if setting == \"high\":\n",
    "                mask = [true_influence > 5]\n",
    "            else:\n",
    "                mask = [true_influence < 0.01]\n",
    "\n",
    "            mse_results[(model, regime)] = per_user_mse[mask]\n",
    "\n",
    "    bp = plt.boxplot([mse_results[('naive', 'low')], mse_results[('multi_cause', 'low')],\n",
    "                mse_results[('naive', 'med')], mse_results[('multi_cause', 'med')],\n",
    "                mse_results[('naive', 'high')],  mse_results[('multi_cause', 'high')]],\n",
    "                positions=[3, 4, 7,8, 11,12],\n",
    "                showfliers=False, patch_artist=True)\n",
    "\n",
    "    colors = ['slategrey', 'maroon']\n",
    "    for b_idx in range(6):\n",
    "        for item in ['boxes']:\n",
    "            c = colors[b_idx % 2]\n",
    "            plt.setp(bp[item][b_idx], color=c)\n",
    "    axes = plt.axes()\n",
    "    axes.set_xticklabels(['Low', 'Medium', 'High'])\n",
    "    axes.set_xticks([3.5, 7.5, 11.5])\n",
    "    plt.setp(bp['medians'], color='white')\n",
    "\n",
    "    grey = plt.plot([1,1], c=colors[0], label='Naive')\n",
    "    maroon = plt.plot([1,1], c=colors[1], label='PIF')\n",
    "    plt.legend()\n",
    "    plt.ylabel(\"MSE\")\n",
    "    if setting=='high':\n",
    "        ymax=0.3\n",
    "    else:\n",
    "        ymax=0.1\n",
    "    plt.ylim((-0.005, ymax))\n",
    "    plt.title(settings[setting])\n",
    "#     plt.show()\n",
    "    plt.savefig(outdir + 'setting=' + setting+ '.pdf')\n",
    "    plt.clf()\n",
    "    axes.clear()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda3/lib/python3.6/site-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: \n",
      "Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n",
      "  \"Adding an axes using the same arguments as a previous axes \"\n",
      "/anaconda3/lib/python3.6/site-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: \n",
      "Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n",
      "  \"Adding an axes using the same arguments as a previous axes \"\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "outdir = '../doc/aistats20/img/multicovariate_plots/'\n",
    "os.makedirs(outdir, exist_ok=True)\n",
    "from scipy.stats import gaussian_kde\n",
    "regimes = {'low':(0., 10.), 'med':(0.3, 100.), 'high':(0., 100.)}\n",
    "settings = {'high':\"High Influence Users\", 'low':\"Low Influence Users\"}\n",
    "for setting in settings:\n",
    "    mse_results={}\n",
    "    for regime, config in regimes.items():\n",
    "        for model in ['naive', 'spf', 'multi_cause']:\n",
    "            per_user_mse = []\n",
    "            predicted = []\n",
    "            true_influence = []\n",
    "            for i in range(len(exp_results[config][model])):\n",
    "                beta_predicted = exp_results[config][model][i][0]\n",
    "                truth = exp_results[config][model][i][1]\n",
    "                mses = (beta_predicted - truth) ** 2\n",
    "                true_influence += list(truth)\n",
    "                per_user_mse += list(mses) \n",
    "                predicted += list(beta_predicted)\n",
    "\n",
    "            true_influence = np.array(true_influence)\n",
    "            per_user_mse = np.array(per_user_mse)\n",
    "            predicted = np.array(predicted)\n",
    "\n",
    "            if setting == \"high\":\n",
    "                mask = [(true_influence >=0.1) & (true_influence <=1)]\n",
    "            else:\n",
    "                mask = [true_influence < 0.001]\n",
    "            \n",
    "            mse_results[(model, regime)] = per_user_mse[mask]\n",
    "\n",
    "    bp = plt.boxplot([mse_results[('naive', 'low')], mse_results[('spf', 'low')] ,mse_results[('multi_cause', 'low')],\n",
    "                mse_results[('naive', 'med')],mse_results[('spf', 'med')], mse_results[('multi_cause', 'med')],\n",
    "                mse_results[('naive', 'high')], mse_results[('spf', 'high')], mse_results[('multi_cause', 'high')]],\n",
    "                positions=[3, 4,5, 8,9,10, 13,14,15],\n",
    "                showfliers=False, patch_artist=True)\n",
    "\n",
    "    colors = ['slategrey', 'navy', 'maroon']\n",
    "    for b_idx in range(9):\n",
    "        for item in ['boxes']:\n",
    "            c = colors[b_idx % 3]\n",
    "            plt.setp(bp[item][b_idx], color=c)\n",
    "    axes = plt.axes()\n",
    "    axes.set_xticklabels(['Low', 'Medium', 'High'])\n",
    "    axes.set_xticks([3.5, 7.5, 11.5])\n",
    "    plt.setp(bp['medians'], color='white')\n",
    "\n",
    "    grey = plt.plot([1,1], c=colors[0], label='Naive')\n",
    "    navy = plt.plot([1,1], c=colors[0], label='mSPF')\n",
    "    maroon = plt.plot([1,1], c=colors[1], label='PIF')\n",
    "    plt.legend()\n",
    "    plt.ylabel(\"Squared Error\")\n",
    "    if setting=='high':\n",
    "        ymax=0.3\n",
    "    else:\n",
    "        ymax=0.1\n",
    "    plt.ylim((-0.005, ymax))\n",
    "    plt.title(settings[setting])\n",
    "#     plt.show()\n",
    "    plt.savefig(outdir + 'mSPF-included-setting=' + setting+ '.pdf')\n",
    "    plt.clf()\n",
    "    axes.clear()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda3/lib/python3.6/site-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: \n",
      "Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n",
      "  \"Adding an axes using the same arguments as a previous axes \"\n",
      "/anaconda3/lib/python3.6/site-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: \n",
      "Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n",
      "  \"Adding an axes using the same arguments as a previous axes \"\n",
      "/anaconda3/lib/python3.6/site-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: \n",
      "Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n",
      "  \"Adding an axes using the same arguments as a previous axes \"\n",
      "/anaconda3/lib/python3.6/site-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: \n",
      "Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n",
      "  \"Adding an axes using the same arguments as a previous axes \"\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "outdir = '../doc/aistats20/img/multicovariate_plots/'\n",
    "os.makedirs(outdir, exist_ok=True)\n",
    "from scipy.stats import gaussian_kde\n",
    "regimes = {'low':(0., 10.), 'med':(0.3, 100.), 'high':(0., 100.)}\n",
    "settings = {'high':\"High Influence Users\", 'low':\"Low Influence Users\"}\n",
    "for setting in settings:\n",
    "    mse_results={}\n",
    "    for regime, config in regimes.items():\n",
    "        for model in ['naive', 'spf', 'multi_cause']:\n",
    "            per_user_mse = []\n",
    "            predicted = []\n",
    "            true_influence = []\n",
    "            for i in range(len(exp_results[config][model])):\n",
    "                beta_predicted = exp_results[config][model][i][0]\n",
    "                truth = exp_results[config][model][i][1]\n",
    "                mses = (beta_predicted - truth) ** 2\n",
    "                true_influence += list(truth)\n",
    "                per_user_mse += list(mses) \n",
    "                predicted += list(beta_predicted)\n",
    "\n",
    "            true_influence = np.array(true_influence)\n",
    "            per_user_mse = np.array(per_user_mse)\n",
    "            predicted = np.array(predicted)\n",
    "\n",
    "            if setting == \"high\":\n",
    "                mask = [(true_influence >=0.1) & (true_influence <=1)]\n",
    "            else:\n",
    "                mask = [true_influence < 0.01]\n",
    "            \n",
    "            mse_results[(model, regime)] = per_user_mse[mask]\n",
    "    for comparison in ['naive', 'spf']:\n",
    "        bp1 = plt.boxplot([mse_results[(comparison, 'low')], mse_results[('multi_cause', 'low')],\n",
    "                    mse_results[(comparison, 'med')], mse_results[('multi_cause', 'med')], \n",
    "                    mse_results[(comparison, 'high')], mse_results[('multi_cause', 'high')]],\n",
    "                    positions=[3, 4, 5,6, 7,8],\n",
    "                    showfliers=False, patch_artist=True)\n",
    "        axes = plt.axes()\n",
    "        colors = ['slategrey', 'maroon'] if comparison=='naive' else ['navy', 'maroon'] \n",
    "        for b_idx in range(6):\n",
    "            for item in ['boxes']:\n",
    "                mod = 2\n",
    "                c = colors[b_idx % mod]\n",
    "                plt.setp(bp1[item][b_idx], color=c)\n",
    "        axes.set_xticklabels(['Low', 'Medium', 'High'])\n",
    "        axes.set_xticks([3.5, 5.5, 7.5])\n",
    "        plt.setp(bp1['medians'], color='white')\n",
    "\n",
    "        model = plt.plot([1,1], c=colors[1], label='PIF')\n",
    "        if comparison == 'spf':\n",
    "            comp = plt.plot([1,1], c=colors[0], label='mSPF')\n",
    "        else:\n",
    "            comp = plt.plot([1,1], c=colors[0], label='Naive')\n",
    "        plt.legend()\n",
    "        plt.ylabel(\"Squared Error\")\n",
    "        ymax = 0.1 if setting =='high' else 0.08\n",
    "        ymin = -0.005\n",
    "        if setting == 'high' and comparison == 'spf':\n",
    "            ymax = 0.25\n",
    "        elif setting == 'low' and comparison=='spf':\n",
    "            ymax = 0.01\n",
    "            ymin=0.\n",
    "            \n",
    "        plt.ylim((ymin, ymax))\n",
    "        plt.title(settings[setting])\n",
    "\n",
    "#         plt.show()\n",
    "        plt.savefig(outdir + 'split-plot-setting=' + setting+ ';comparison=' +comparison+ '.pdf')\n",
    "        plt.clf()\n",
    "        axes.clear()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda3/lib/python3.6/site-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: \n",
      "Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n",
      "  \"Adding an axes using the same arguments as a previous axes \"\n",
      "/anaconda3/lib/python3.6/site-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: \n",
      "Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n",
      "  \"Adding an axes using the same arguments as a previous axes \"\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "outdir = '../doc/aistats20/img/outcome_conf_plots/'\n",
    "os.makedirs(outdir, exist_ok=True)\n",
    "from scipy.stats import gaussian_kde\n",
    "regimes = {'low':(0., 10.), 'med':(0.3, 100.), 'high':(0., 100.)}\n",
    "settings = {'high':\"High Influence Users\", 'low':\"Low Influence Users\"}\n",
    "for setting in settings:\n",
    "    mse_results={}\n",
    "    for regime, config in regimes.items():\n",
    "        for model in ['naive', 'spf', 'multi_cause']:\n",
    "            n_exp = len(exp_results[config][model])\n",
    "            mse = np.zeros(n_exp)\n",
    "            for i in range(n_exp):\n",
    "                beta_predicted = exp_results[config][model][i][0]\n",
    "                truth = exp_results[config][model][i][1]\n",
    "                if setting == \"high\":\n",
    "                    mask = [(truth >= 0.5) &( truth <=2)]\n",
    "                else:\n",
    "                    mask = [truth < 0.001]\n",
    "                sq_err = (beta_predicted[mask] - truth[mask]) ** 2\n",
    "                mse[i] = sq_err.mean()\n",
    "            mse_results[(model, regime)] = mse\n",
    "    bp = plt.boxplot([mse_results[('naive', 'low')], mse_results[('spf', 'low')] ,mse_results[('multi_cause', 'low')],\n",
    "                mse_results[('naive', 'med')],mse_results[('spf', 'med')], mse_results[('multi_cause', 'med')],\n",
    "                mse_results[('naive', 'high')], mse_results[('spf', 'high')], mse_results[('multi_cause', 'high')]],\n",
    "                positions=[3, 4,5, 8,9,10, 13,14,15],\n",
    "                showfliers=False, patch_artist=True)\n",
    "\n",
    "    colors = ['slategrey', 'navy', 'maroon']\n",
    "    for b_idx in range(9):\n",
    "        for item in ['boxes']:\n",
    "            c = colors[b_idx % 3]\n",
    "            plt.setp(bp[item][b_idx], color=c)\n",
    "    axes = plt.axes()\n",
    "    axes.set_xticklabels(['Low', 'Medium', 'High'])\n",
    "    axes.set_xticks([4, 9, 14])\n",
    "    plt.setp(bp['medians'], color='white')\n",
    "    \n",
    "    model = plt.plot([40,40], c=colors[2], label='PIF')\n",
    "    comp1 = plt.plot([40,40], c=colors[1], label='mSPF')\n",
    "    comp2 = plt.plot([40,40], c=colors[0], label='Naive')\n",
    "    \n",
    "#     plt.legend()\n",
    "    plt.ylabel(\"Squared Error\")\n",
    "\n",
    "    ymin = -0.05 if setting=='high' else 0.04\n",
    "    ymax = 1.3 if setting=='high' else 0.1\n",
    "    plt.ylim((ymin, ymax))\n",
    "    plt.title(settings[setting])\n",
    "    plt.savefig(outdir+'setting='+ setting +'.pdf')\n",
    "    plt.clf()\n",
    "    axes.clear()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
