{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib\n",
    "from matplotlib import pyplot as plt\n",
    "import os\n",
    "from sklearn.metrics import mean_squared_error as mse_f\n",
    "from scipy import sparse\n",
    "from scipy.stats import gamma\n",
    "from scipy.stats import ttest_ind\n",
    "import warnings\n",
    "import seaborn as sns\n",
    "sns.set(style=\"darkgrid\")\n",
    "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)\n",
    "warnings.filterwarnings(\"ignore\", category=FutureWarning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_set_overlap(Beta_p, Beta, k=50):\n",
    "    \"\"\"Jaccard similarity between the top-k indices of two score vectors.\n",
    "\n",
    "    Args:\n",
    "        Beta_p: predicted scores (1-D array-like).\n",
    "        Beta: reference (true) scores (1-D array-like).\n",
    "        k: number of highest-scoring indices taken from each vector; must be >= 1.\n",
    "\n",
    "    Returns:\n",
    "        Size of the intersection divided by size of the union of the two top-k\n",
    "        index sets, a float in [0, 1].\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if k < 1 (a ``[-k:]`` slice would select *every* index for\n",
    "            k == 0 and report a perfect overlap).\n",
    "    \"\"\"\n",
    "    if k < 1:\n",
    "        raise ValueError('k must be a positive integer, got %r' % (k,))\n",
    "    # Descending argsort, first k: the same index set as argsort(...)[-k:] for k >= 1.\n",
    "    top = np.argsort(Beta)[::-1][:k]\n",
    "    top_p = np.argsort(Beta_p)[::-1][:k]\n",
    "    return np.intersect1d(top, top_p).size / np.union1d(top, top_p).size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_most_least_influential(Beta, n=100):\n",
    "    \"\"\"Indices of the n lowest- and n highest-scoring entries of Beta.\n",
    "\n",
    "    Args:\n",
    "        Beta: 1-D array-like of influence scores.\n",
    "        n: number of indices to return from each end; must be >= 0.\n",
    "\n",
    "    Returns:\n",
    "        (lowest, highest): index arrays, each ordered by ascending score.\n",
    "        If n >= len(Beta), both contain every index.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if n < 0.\n",
    "    \"\"\"\n",
    "    if n < 0:\n",
    "        raise ValueError('n must be non-negative, got %r' % (n,))\n",
    "    sorted_inf = np.argsort(Beta)\n",
    "    lowest = sorted_inf[:n]\n",
    "    # Explicit start index: sorted_inf[-n:] would return *all* indices when n == 0.\n",
    "    highest = sorted_inf[max(len(sorted_inf) - n, 0):]\n",
    "    return lowest, highest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fitted parameters and true influences for every experiment / model / setting.\n",
    "# Earlier output directories:\n",
    "# pokec_out = '../pokec_mixture_covariates_out'\n",
    "# pokec_out = '../pokec_k=5_mixture_covariates_out'\n",
    "pokec_out = '../pokec_k=5_p=10_mixture_covariates_out/'\n",
    "exps = 10\n",
    "models = ['pif', 'spf', 'unadjusted', 'network_pref_only']\n",
    "probs = [0., 0.25, 0.5, 0.75, 1.]\n",
    "confounding_strengths = [(10., 10.), (10., 100.), (100., 100.), (100., 10.), (50., 50.)]\n",
    "\n",
    "# exp_results[(p, (cov1conf, cov2conf))][model] is a list of (fitted params, truth)\n",
    "# pairs, appended in experiment order 1..exps.\n",
    "exp_results = {}\n",
    "# Every loaded true-influence value, across all experiments, models and settings.\n",
    "true_influences = []\n",
    "for i in range(1, exps + 1):\n",
    "    for model in models:\n",
    "        params_dir = os.path.join(pokec_out, str(i), model + '_model_fitted_params')\n",
    "        for conf in confounding_strengths:\n",
    "            for p in probs:\n",
    "                stem = 'conf=' + str(conf) + ';p=' + str(p)\n",
    "                params = np.loadtxt(os.path.join(params_dir, stem + '_fitted.gz'))\n",
    "                truth = np.loadtxt(os.path.join(params_dir, stem + '_true.gz'))\n",
    "                true_influences.extend(truth)\n",
    "                per_model = exp_results.setdefault((p, conf), {})\n",
    "                per_model.setdefault(model, []).append((params, truth))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.00000000e+000, 0.00000000e+000, 2.57590711e-280, 4.21238520e-104,\n",
       "       8.94647462e-001, 8.94647462e-001, 2.81629488e+000])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Percentiles (min, quartiles, then the upper tail) of every loaded true influence,\n",
    "# used to sanity-check the hand-picked thresholds below.\n",
    "quartiles = np.percentile(np.array(true_influences), [0, 25, 50, 75, 90, 95, 99.9])\n",
    "# Users with true influence >= high_val count as 'high influence' and those with\n",
    "# influence <= low_val as 'low influence'. In the recorded run the 75th percentile\n",
    "# was ~4e-104 (so at least 75% of values fall below low_val) and the 95th was ~0.89.\n",
    "high_val = 0.5\n",
    "low_val = 0.001\n",
    "quartiles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mixture probability: 0.0\n",
      "Model: pif\n",
      "[0.0002692186220390563, 0.00021405996709202534, 0.00019348260828905146, 0.00035630196927439035, 0.00016204946691756605, 0.00027323989768318526, 0.0012745451471670667, 0.00021195501700674418, 0.00017410227432835, 0.00021898499764814818]\n",
      "Mean MSE: 0.0003347939967445584 and st dev.: 0.00031786681336027104\n",
      "Mean MSE for most influential: 0.0018956662915991115 and st dev.: 0.0050670758094908215\n",
      "Mean MSE for less influential: 0.00020485997421523433 and st dev.: 3.67519765050844e-05\n",
      "------------------------------------------------------------\n",
      "Mixture probability: 0.0\n",
      "Model: spf\n",
      "[0.0002671289110055427, 0.00024285815993459772, 0.00016515430299888676, 0.0006757799348381926, 0.000233290643429223, 0.000193768949353639, 0.00027428183736962473, 0.00029617115524518153, 0.0003308720231791265, 0.000288865414054749]\n",
      "Mean MSE: 0.0002968171331408763 and st dev.: 0.00013458329484974173\n",
      "Mean MSE for most influential: 0.001176039528121623 and st dev.: 0.0024785770482029434\n",
      "Mean MSE for less influential: 0.0002384109580749945 and st dev.: 4.1415847706065136e-05\n",
      "------------------------------------------------------------\n",
      "Mixture probability: 0.0\n",
      "Model: unadjusted\n",
      "[0.004195585842048887, 0.004525773065551254, 0.004544247681743799, 0.004516317766610696, 0.004723130969520451, 0.0047453792433282006, 0.004359687934010557, 0.004625664086559123, 0.0046028952563006735, 0.004781325409425433]\n",
      "Mean MSE: 0.0045620007255099074 and st dev.: 0.00017097491586905289\n",
      "Mean MSE for most influential: 0.0007569354302778087 and st dev.: 0.0008518237293657025\n",
      "Mean MSE for less influential: 0.004769614067432952 and st dev.: 0.00018997009566527862\n",
      "------------------------------------------------------------\n",
      "Mixture probability: 0.0\n",
      "Model: network_pref_only\n",
      "[0.0007011121224973115, 0.001033457375492617, 0.0010110333770155027, 0.0007339486985911336, 0.0007895612769210388, 0.000821490277264575, 0.0009520687344924569, 0.0008381356525509776, 0.0008364713774583563, 0.0010468552322947004]\n",
      "Mean MSE: 0.0008764134124578671 and st dev.: 0.00011925041598826098\n",
      "Mean MSE for most influential: 0.0015355265657062377 and st dev.: 0.0017491256651790278\n",
      "Mean MSE for less influential: 0.0008261032030331185 and st dev.: 0.00010483976582047345\n",
      "------------------------------------------------------------\n",
      "------------------------------------------------------------\n",
      "Mixture probability: 0.25\n",
      "Model: pif\n",
      "[0.00031583603998452145, 0.00025266676447081724, 0.00030884245869448873, 0.0002176782026719904, 0.0017383275832989567, 0.0002895534412035757, 0.00027006569846642266, 0.0024374671296384506, 0.0002734837765662434, 0.00024853462921611423]\n",
      "Mean MSE: 0.000635245572421158 and st dev.: 0.0007434621240599139\n",
      "Mean MSE for most influential: 0.0075577038510158765 and st dev.: 0.014599250107072909\n",
      "Mean MSE for less influential: 0.00024348236905362913 and st dev.: 2.4637335930723832e-05\n",
      "------------------------------------------------------------\n",
      "Mixture probability: 0.25\n",
      "Model: spf\n",
      "[0.00019555407223632, 0.0002144182273553102, 0.00015534383477600256, 0.0006411055962503256, 0.00022629555099171875, 0.00022302169184560826, 0.00016246793806171597, 0.0006991272167438579, 0.00023188283483787972, 0.000268670800191633]\n",
      "Mean MSE: 0.0003017887763290372 and st dev.: 0.0001872636388211684\n",
      "Mean MSE for most influential: 0.0022653632849914413 and st dev.: 0.003926693409208782\n",
      "Mean MSE for less influential: 0.00019708095619612687 and st dev.: 2.3490991894773396e-05\n",
      "------------------------------------------------------------\n",
      "Mixture probability: 0.25\n",
      "Model: unadjusted\n",
      "[0.003696561441918285, 0.0036366207603429554, 0.0037570797381787017, 0.0038264311219489374, 0.0037538727539759495, 0.0038761577761187225, 0.0037971810041565854, 0.003722193671492137, 0.0037796897229463184, 0.0037030204816157477]\n",
      "Mean MSE: 0.0037548808472694347 and st dev.: 6.586556036547141e-05\n",
      "Mean MSE for most influential: 0.000594234530950943 and st dev.: 0.0004857589229806296\n",
      "Mean MSE for less influential: 0.003925233762252825 and st dev.: 7.608332588728101e-05\n",
      "------------------------------------------------------------\n",
      "Mixture probability: 0.25\n",
      "Model: network_pref_only\n",
      "[0.0005167311689814923, 0.0008305358215084495, 0.0005171581414028201, 0.0008663007024372976, 0.000553036124086509, 0.0007388443829020199, 0.0007340365119861982, 0.0006339935064080599, 0.0006076051099684779, 0.0007394827727808884]\n",
      "Mean MSE: 0.0006737724242462213 and st dev.: 0.00011981385145651245\n",
      "Mean MSE for most influential: 0.002222672629605775 and st dev.: 0.0019481287097823494\n",
      "Mean MSE for less influential: 0.0005728929531531082 and st dev.: 3.6573504448202004e-05\n",
      "------------------------------------------------------------\n",
      "------------------------------------------------------------\n",
      "Mixture probability: 0.5\n",
      "Model: pif\n",
      "[0.00030068767446856255, 0.0005211222120757019, 0.00030528793441987454, 0.0002572445021839285, 0.00022878020479465404, 0.0035900424181216934, 0.00034618792295138646, 0.0002448489469660669, 0.00022845960918787915, 0.00042859700492835207]\n",
      "Mean MSE: 0.00064512584300981 and st dev.: 0.0009856933893265987\n",
      "Mean MSE for most influential: 0.007639979164812976 and st dev.: 0.020235701000123025\n",
      "Mean MSE for less influential: 0.00026511045082004437 and st dev.: 2.2411731361874443e-05\n",
      "------------------------------------------------------------\n",
      "Mixture probability: 0.5\n",
      "Model: spf\n",
      "[0.0002339083646050949, 0.00020703861527280337, 0.0002747939403342217, 0.00018031437717394208, 0.0002722726733747032, 0.00019887546249947066, 0.0002241058962288962, 0.00018515546629967864, 0.0002666946872342467, 0.0002009708627186922]\n",
      "Mean MSE: 0.00022441303457417495 and st dev.: 3.4199702009678906e-05\n",
      "Mean MSE for most influential: 0.000325460407290633 and st dev.: 0.00042539837597759916\n",
      "Mean MSE for less influential: 0.0002086071107247097 and st dev.: 1.458077783964644e-05\n",
      "------------------------------------------------------------\n",
      "Mixture probability: 0.5\n",
      "Model: unadjusted\n",
      "[0.0032903273594403603, 0.0033961635502797543, 0.0033078147318000213, 0.0033754251319663897, 0.0032961952853607333, 0.003482185012007331, 0.003386599418959068, 0.0033657220642236956, 0.003254798416515696, 0.0032638888564661466]\n",
      "Mean MSE: 0.0033419119827019193 and st dev.: 6.774715299056981e-05\n",
      "Mean MSE for most influential: 0.0008511432329585776 and st dev.: 0.0007165443654959441\n",
      "Mean MSE for less influential: 0.0034801089277636922 and st dev.: 7.353795614722289e-05\n",
      "------------------------------------------------------------\n",
      "Mixture probability: 0.5\n",
      "Model: network_pref_only\n",
      "[0.0003759581153842339, 0.0009491269564051198, 0.00043393040344439607, 0.0005740384992986706, 0.0004401973559347377, 0.0005273805097337629, 0.0004969469270745592, 0.0004892923412409715, 0.00044129680938043966, 0.0008003998093171813]\n",
      "Mean MSE: 0.0005528567727214073 and st dev.: 0.00017237924835621378\n",
      "Mean MSE for most influential: 0.0019523871306777545 and st dev.: 0.0025477670431692068\n",
      "Mean MSE for less influential: 0.0004465356349955936 and st dev.: 4.4951822002428686e-05\n",
      "------------------------------------------------------------\n",
      "------------------------------------------------------------\n",
      "Mixture probability: 0.75\n",
      "Model: pif\n",
      "[0.00043566931456074547, 0.0004116934963082552, 0.00032538955748652535, 0.00038393539399923267, 0.0003056293390376491, 0.0003864351223450818, 0.0004196915050331086, 0.00033876409944668935, 0.0003385370639607991, 0.00045079426967796124]\n",
      "Mean MSE: 0.00037965391618560485 and st dev.: 4.7622856982042946e-05\n",
      "Mean MSE for most influential: 0.0005429301023660222 and st dev.: 0.0006034565747596906\n",
      "Mean MSE for less influential: 0.0003551497407934738 and st dev.: 3.4950893908149083e-05\n",
      "------------------------------------------------------------\n",
      "Mixture probability: 0.75\n",
      "Model: spf\n",
      "[0.00023951625398825252, 0.00019904930228211074, 0.002307309169361754, 0.00018523397673301048, 0.00028861649995336777, 0.00028837647623736353, 0.00028992829771925426, 0.00023165503115338366, 0.00028639647080126344, 0.0002697442843126616]\n",
      "Mean MSE: 0.00045858257625424226 and st dev.: 0.0006173300545216613\n",
      "Mean MSE for most influential: 0.004509923090830592 and st dev.: 0.012381969591650933\n",
      "Mean MSE for less influential: 0.0002420431939616268 and st dev.: 3.298086708255567e-05\n",
      "------------------------------------------------------------\n",
      "Mixture probability: 0.75\n",
      "Model: unadjusted\n",
      "[0.0033131782653662133, 0.003404782023365231, 0.003349767057568906, 0.0033735361349868844, 0.0035507248850953272, 0.0035706570471438588, 0.0034020824062349145, 0.0033861602196649594, 0.0035664493166173306, 0.003366929596116562]\n",
      "Mean MSE: 0.0034284266952160185 and st dev.: 9.140050179165285e-05\n",
      "Mean MSE for most influential: 0.0004833241829803457 and st dev.: 0.0006035151994437971\n",
      "Mean MSE for less influential: 0.003589085530554703 and st dev.: 0.00010315594654425667\n",
      "------------------------------------------------------------\n",
      "Mixture probability: 0.75\n",
      "Model: network_pref_only\n",
      "[0.0007216825865557372, 0.0007642201881832619, 0.000584069294255665, 0.0005813706648214514, 0.0005896680295844307, 0.0006208319568424618, 0.0012465580813731173, 0.0005724256927599505, 0.0005541869498339934, 0.0006890693700161319]\n",
      "Mean MSE: 0.0006924082814226201 and st dev.: 0.0001965587596942755\n",
      "Mean MSE for most influential: 0.0025469161687006526 and st dev.: 0.003984860064772518\n",
      "Mean MSE for less influential: 0.0005799373223772715 and st dev.: 7.554526370702561e-05\n",
      "------------------------------------------------------------\n",
      "------------------------------------------------------------\n",
      "Mixture probability: 1.0\n",
      "Model: pif\n",
      "[0.0004856880949949214, 0.00040480955347525034, 0.00047828066518534917, 0.0003442854956041918, 0.0004347211306028622, 0.0003370176292229688, 0.0004661746261256054, 0.00035483425438447637, 0.0004179063332872914, 0.0004258912024433487]\n",
      "Mean MSE: 0.0004149608985326266 and st dev.: 5.190531563920026e-05\n",
      "Mean MSE for most influential: 0.0004390379497507972 and st dev.: 0.0004480391658104898\n",
      "Mean MSE for less influential: 0.0004013298397765772 and st dev.: 3.678299133745802e-05\n",
      "------------------------------------------------------------\n",
      "Mixture probability: 1.0\n",
      "Model: spf\n",
      "[0.0004110502712596922, 0.0004076393861601767, 0.00037236465195783123, 0.0009499049358838112, 0.00045746696331903296, 0.0003880127637742017, 0.00033500871552982265, 0.0003932020898824648, 0.00037694575123963664, 0.0005741303598606266]\n",
      "Mean MSE: 0.00046657258888672964 and st dev.: 0.00017250786322118243\n",
      "Mean MSE for most influential: 0.001571087317358304 and st dev.: 0.0030758355917028255\n",
      "Mean MSE for less influential: 0.00038787476715368196 and st dev.: 4.723896667596763e-05\n",
      "------------------------------------------------------------\n",
      "Mixture probability: 1.0\n",
      "Model: unadjusted\n",
      "[0.004575555660237567, 0.005204712037160888, 0.004562639159086594, 0.004830330361717376, 0.004606239608162604, 0.004925393557210137, 0.0044864483704387895, 0.004756409680334518, 0.004591898079336901, 0.004771789327118254]\n",
      "Mean MSE: 0.004731141584080363 and st dev.: 0.00020532803153433616\n",
      "Mean MSE for most influential: 0.0016841796046438672 and st dev.: 0.002896612295182674\n",
      "Mean MSE for less influential: 0.004896267996133377 and st dev.: 0.000116731454011444\n",
      "------------------------------------------------------------\n",
      "Mixture probability: 1.0\n",
      "Model: network_pref_only\n",
      "[0.0010256540134946882, 0.0011883378973126413, 0.0013447601467855349, 0.0013522279996677248, 0.001076449153271084, 0.0010479576905755804, 0.0012963354119515812, 0.0012764899388232082, 0.0009664446720748892, 0.0012023874245589482]\n",
      "Mean MSE: 0.001177704434851588 and st dev.: 0.00013332560080298503\n",
      "Mean MSE for most influential: 0.001832160011052257 and st dev.: 0.0018620336996460194\n",
      "Mean MSE for less influential: 0.001129543092941799 and st dev.: 0.00015798411928823113\n",
      "------------------------------------------------------------\n",
      "------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "def _masked_mse(y_true, y_pred, mask):\n",
    "    \"\"\"MSE over the entries selected by boolean `mask`; NaN if it selects none.\n",
    "\n",
    "    sklearn's mean_squared_error raises ValueError on empty inputs, so without this\n",
    "    guard a single experiment with no user above high_val (or below low_val) would\n",
    "    abort the whole summary.\n",
    "    \"\"\"\n",
    "    if not np.any(mask):\n",
    "        return np.nan\n",
    "    return mse_f(y_true[mask], y_pred[mask])\n",
    "\n",
    "\n",
    "# Per-model estimation error at confounding strength (10, 10) for each mixture\n",
    "# probability: over all users, and over high- (truth >= high_val) and low-\n",
    "# (truth <= low_val) influence users. Subset means/std devs skip experiments\n",
    "# in which the subset is empty.\n",
    "mean_mse_results = {m: [] for m in models}\n",
    "config = (10., 10.)\n",
    "for p in probs:\n",
    "    for model in models:\n",
    "        print(\"Mixture probability:\", p)\n",
    "        print(\"Model:\", model)\n",
    "        all_mses = []\n",
    "        most_mses = []\n",
    "        least_mses = []\n",
    "        for i in range(exps):\n",
    "            beta_hat, beta = exp_results[(p, config)][model][i]\n",
    "            most = beta >= high_val\n",
    "            least = beta <= low_val\n",
    "            all_mses.append(mse_f(beta, beta_hat))\n",
    "            most_mses.append(_masked_mse(beta, beta_hat, most))\n",
    "            least_mses.append(_masked_mse(beta, beta_hat, least))\n",
    "        print(all_mses)\n",
    "        print(\"Mean MSE:\", np.mean(all_mses), \"and st dev.:\", np.std(all_mses))\n",
    "        print(\"Mean MSE for most influential:\", np.nanmean(most_mses), \"and st dev.:\", np.nanstd(most_mses))\n",
    "        print(\"Mean MSE for less influential:\", np.nanmean(least_mses), \"and st dev.:\", np.nanstd(least_mses))\n",
    "        mean_mse_results[model].append(np.mean(all_mses))\n",
    "        print('-' * 60)\n",
    "    print('-' * 60)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fix confounding strength at (50, 50); vary the mixture probability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda3/lib/python3.6/site-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: \n",
      "Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n",
      "  \"Adding an axes using the same arguments as a previous axes \"\n",
      "/anaconda3/lib/python3.6/site-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: \n",
      "Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n",
      "  \"Adding an axes using the same arguments as a previous axes \"\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Per-experiment MSE boxplots at confounding strength (50, 50) for three mixture\n",
    "# regimes; one figure per user group (high / low true influence), saved as PDFs.\n",
    "outdir = '../doc/icml/draft/img/mixture/'\n",
    "os.makedirs(outdir, exist_ok=True)\n",
    "config = (50., 50.)\n",
    "regimes = {'exog. confounding only': (0., config), 'both confounding': (0.5, config), 'homophily only': (1.0, config)}\n",
    "settings = {'high': 'High Influence Users', 'low': 'Low Influence Users'}\n",
    "# Box order inside each regime group; colors[j] is the color of plot_models[j].\n",
    "plot_models = ['unadjusted', 'network_pref_only', 'spf', 'pif']\n",
    "colors = ['slategrey', 'navy', 'maroon', 'black']\n",
    "legend_proxies = [(colors[3], 'PIF'), (colors[2], 'SPF'), (colors[0], 'Unadjusted'), (colors[1], 'Network Pref. Only')]\n",
    "for setting in settings:\n",
    "    mse_results = {}\n",
    "    for regime, key in regimes.items():\n",
    "        for model in models:\n",
    "            mse = np.zeros(exps)\n",
    "            for i in range(exps):\n",
    "                beta_predicted, truth = exp_results[key][model][i]\n",
    "                if setting == 'high':\n",
    "                    mask = (truth >= high_val)\n",
    "                else:\n",
    "                    mask = (truth <= low_val)\n",
    "                mse[i] = ((beta_predicted[mask] - truth[mask]) ** 2).mean()\n",
    "            mse_results[(model, regime)] = mse\n",
    "\n",
    "    # One group of four boxes per regime, at x = 2-5, 7-10 and 12-15.\n",
    "    fig, ax = plt.subplots()\n",
    "    bp = ax.boxplot([mse_results[(m, regime)] for regime in regimes for m in plot_models],\n",
    "                    positions=[2, 3, 4, 5, 7, 8, 9, 10, 12, 13, 14, 15],\n",
    "                    showfliers=False, patch_artist=True)\n",
    "    for b_idx, box in enumerate(bp['boxes']):\n",
    "        plt.setp(box, color=colors[b_idx % len(colors)])\n",
    "    plt.setp(bp['medians'], color='white')\n",
    "    # Draw on the boxplot's own Axes (plt.axes() adds a new, empty Axes in current\n",
    "    # matplotlib) and set tick positions before labels: three labels cannot be\n",
    "    # applied to the twelve box positions the boxplot installed.\n",
    "    ax.set_xticks([3.5, 8.5, 13.5])\n",
    "    ax.set_xticklabels(['Exog. Only', 'Both', 'Homophily Only'])\n",
    "    # Legend proxies: lines at y=40, far outside the y-limits, so they never show.\n",
    "    # The legend itself is not drawn; call ax.legend() to add it.\n",
    "    for color, label in legend_proxies:\n",
    "        ax.plot([40, 40], c=color, label=label)\n",
    "    ax.set_ylim((-0.0001, 0.01))\n",
    "    ax.set_title(settings[setting])\n",
    "    fig.savefig(os.path.join(outdir, 'setting=' + setting + '.pdf'))\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fix the mixture probability at 0.0 (exogenous confounding only); vary confounding strength"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda3/lib/python3.6/site-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: \n",
      "Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n",
      "  \"Adding an axes using the same arguments as a previous axes \"\n",
      "/anaconda3/lib/python3.6/site-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: \n",
      "Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n",
      "  \"Adding an axes using the same arguments as a previous axes \"\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Per-experiment MSE boxplots with the mixture probability fixed at 0.0 (exogenous\n",
    "# confounding only) for three confounding strengths; one figure per user group\n",
    "# (high / low true influence), saved as PDFs.\n",
    "outdir = '../doc/icml/draft/img/exog_only/'\n",
    "os.makedirs(outdir, exist_ok=True)\n",
    "mixture_prob = 0.0\n",
    "regimes = {'low': (mixture_prob, (10., 10.)), 'med': (mixture_prob, (50., 50.)), 'high': (mixture_prob, (100., 100.))}\n",
    "settings = {'high': 'High Influence Users', 'low': 'Low Influence Users'}\n",
    "# Box order inside each regime group; colors[j] is the color of plot_models[j].\n",
    "plot_models = ['unadjusted', 'network_pref_only', 'spf', 'pif']\n",
    "colors = ['slategrey', 'navy', 'maroon', 'black']\n",
    "legend_proxies = [(colors[3], 'PIF'), (colors[2], 'SPF'), (colors[0], 'Unadjusted'), (colors[1], 'Network Pref. Only')]\n",
    "for setting in settings:\n",
    "    mse_results = {}\n",
    "    for regime, key in regimes.items():\n",
    "        for model in models:\n",
    "            mse = np.zeros(exps)\n",
    "            for i in range(exps):\n",
    "                beta_predicted, truth = exp_results[key][model][i]\n",
    "                if setting == 'high':\n",
    "                    mask = (truth >= high_val)\n",
    "                else:\n",
    "                    mask = (truth <= low_val)\n",
    "                mse[i] = ((beta_predicted[mask] - truth[mask]) ** 2).mean()\n",
    "            mse_results[(model, regime)] = mse\n",
    "\n",
    "    # One group of four boxes per confounding strength, at x = 2-5, 7-10 and 12-15.\n",
    "    fig, ax = plt.subplots()\n",
    "    bp = ax.boxplot([mse_results[(m, regime)] for regime in regimes for m in plot_models],\n",
    "                    positions=[2, 3, 4, 5, 7, 8, 9, 10, 12, 13, 14, 15],\n",
    "                    showfliers=False, patch_artist=True)\n",
    "    for b_idx, box in enumerate(bp['boxes']):\n",
    "        plt.setp(box, color=colors[b_idx % len(colors)])\n",
    "    plt.setp(bp['medians'], color='white')\n",
    "    # Draw on the boxplot's own Axes (plt.axes() adds a new, empty Axes in current\n",
    "    # matplotlib) and set tick positions before labels: three labels cannot be\n",
    "    # applied to the twelve box positions the boxplot installed.\n",
    "    ax.set_xticks([3.5, 8.5, 13.5])\n",
    "    ax.set_xticklabels(['Low', 'Medium', 'High'])\n",
    "    # Legend proxies: lines at y=40, far outside the y-limits, so they never show.\n",
    "    # The legend itself is not drawn; call ax.legend() to add it.\n",
    "    for color, label in legend_proxies:\n",
    "        ax.plot([40, 40], c=color, label=label)\n",
    "    ax.set_ylim((-0.0001, 0.01))\n",
    "    ax.set_title(settings[setting])\n",
    "    fig.savefig(os.path.join(outdir, 'setting=' + setting + '.pdf'))\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fix the mixture probability at 1.0 (homophily confounding only); vary confounding strength"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda3/lib/python3.6/site-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: \n",
      "Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n",
      "  \"Adding an axes using the same arguments as a previous axes \"\n",
      "/anaconda3/lib/python3.6/site-packages/matplotlib/figure.py:98: MatplotlibDeprecationWarning: \n",
      "Adding an axes using the same arguments as a previous axes currently reuses the earlier instance.  In a future version, a new instance will always be created and returned.  Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n",
      "  \"Adding an axes using the same arguments as a previous axes \"\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Per-experiment MSE boxplots with the mixture probability fixed at 1.0 (homophily\n",
    "# confounding only) for three confounding strengths; one figure per user group\n",
    "# (high / low true influence), saved as PDFs.\n",
    "outdir = '../doc/icml/draft/img/homophily_only/'\n",
    "os.makedirs(outdir, exist_ok=True)\n",
    "mixture_prob = 1.0\n",
    "regimes = {'low': (mixture_prob, (10., 10.)), 'med': (mixture_prob, (50., 50.)), 'high': (mixture_prob, (100., 100.))}\n",
    "settings = {'high': 'High Influence Users', 'low': 'Low Influence Users'}\n",
    "# Box order inside each regime group; colors[j] is the color of plot_models[j].\n",
    "plot_models = ['unadjusted', 'network_pref_only', 'spf', 'pif']\n",
    "colors = ['slategrey', 'navy', 'maroon', 'black']\n",
    "legend_proxies = [(colors[3], 'PIF'), (colors[2], 'SPF'), (colors[0], 'Unadjusted'), (colors[1], 'Network Pref. Only')]\n",
    "for setting in settings:\n",
    "    mse_results = {}\n",
    "    for regime, key in regimes.items():\n",
    "        for model in models:\n",
    "            mse = np.zeros(exps)\n",
    "            for i in range(exps):\n",
    "                beta_predicted, truth = exp_results[key][model][i]\n",
    "                if setting == 'high':\n",
    "                    mask = (truth >= high_val)\n",
    "                else:\n",
    "                    mask = (truth <= low_val)\n",
    "                mse[i] = ((beta_predicted[mask] - truth[mask]) ** 2).mean()\n",
    "            mse_results[(model, regime)] = mse\n",
    "\n",
    "    # One group of four boxes per confounding strength, at x = 2-5, 7-10 and 12-15.\n",
    "    fig, ax = plt.subplots()\n",
    "    bp = ax.boxplot([mse_results[(m, regime)] for regime in regimes for m in plot_models],\n",
    "                    positions=[2, 3, 4, 5, 7, 8, 9, 10, 12, 13, 14, 15],\n",
    "                    showfliers=False, patch_artist=True)\n",
    "    for b_idx, box in enumerate(bp['boxes']):\n",
    "        plt.setp(box, color=colors[b_idx % len(colors)])\n",
    "    plt.setp(bp['medians'], color='white')\n",
    "    # Draw on the boxplot's own Axes (plt.axes() adds a new, empty Axes in current\n",
    "    # matplotlib) and set tick positions before labels: three labels cannot be\n",
    "    # applied to the twelve box positions the boxplot installed.\n",
    "    ax.set_xticks([3.5, 8.5, 13.5])\n",
    "    ax.set_xticklabels(['Low', 'Medium', 'High'])\n",
    "    # Legend proxies: lines at y=40, far outside the y-limits, so they never show.\n",
    "    # The legend itself is not drawn; call ax.legend() to add it.\n",
    "    for color, label in legend_proxies:\n",
    "        ax.plot([40, 40], c=color, label=label)\n",
    "    # High-influence errors in this regime need slightly more headroom.\n",
    "    ax.set_ylim((-0.0001, 0.012 if setting == 'high' else 0.01))\n",
    "    ax.set_title(settings[setting])\n",
    "    fig.savefig(os.path.join(outdir, 'setting=' + setting + '.pdf'))\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
