{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# This notebook contains the figures relatives to the Table 1 and 8 of the paper"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The performances may slightly differ from the figures reported in the paper as we did not fix the seed to draw the datasets. However, it does not change any conclusions!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Library/Java/JavaVirtualMachines/jdk1.8.0_161.jdk/Contents/Home/jre/lib/jli/libjli.dylib\n"
     ]
    }
   ],
   "source": [
    "import utils\n",
    "import numpy as np\n",
    "from IPython.display import HTML, display, Markdown\n",
    "import pandas as pd\n",
    "\n",
    "import ot\n",
    "import partial_gw as pgw\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## UCI dataset - SCAR scenario"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_unl = 800\n",
    "n_pos = 800\n",
    "nb_reps = 10\n",
    "nb_dummies = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Partial-W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "prior = 0.518\n",
    "perfs_mushrooms_pw, perfs_list_mushrooms_pw = pgw.compute_perf_emd('mushrooms', 'mushrooms', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_mushrooms_pw =  perfs_mushrooms_pw['emd_groups']\n",
    "avg_perfs_mushrooms_pw_time = np.mean(perfs_mushrooms_pw['time'])\n",
    "\n",
    "prior = 0.786\n",
    "perfs_shuttle_pw, perfs_list_shuttle_pw = pgw.compute_perf_emd('shuttle', 'shuttle', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_shuttle_pw =  perfs_shuttle_pw['emd_groups']\n",
    "avg_perfs_shuttle_pw_time = np.mean(perfs_shuttle_pw['time'])\n",
    "\n",
    "prior = 0.898\n",
    "perfs_pageblocks_pw, perfs_list_pageblocks_pw = pgw.compute_perf_emd('pageblocks', 'pageblocks', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_pageblocks_pw =  perfs_pageblocks_pw['emd_groups']\n",
    "avg_perfs_pageblocks_pw_time = np.mean(perfs_pageblocks_pw['time'])\n",
    "\n",
    "prior = 0.167\n",
    "perfs_usps_pw, perfs_list_usps_pw = pgw.compute_perf_emd('usps', 'usps', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_usps_pw =  perfs_usps_pw['emd_groups']\n",
    "avg_perfs_usps_pw_time = np.mean(perfs_usps_pw['time'])\n",
    "\n",
    "prior = 0.658\n",
    "perfs_connect4_pw, perfs_list_connect4_pw = pgw.compute_perf_emd('connect-4', 'connect-4', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_connect4_pw =  perfs_connect4_pw['emd_groups']\n",
    "avg_perfs_connect4_pw_time = np.mean(perfs_connect4_pw['time'])\n",
    "\n",
    "prior = 0.394\n",
    "perfs_spambase_pw, perfs_list_spambase_pw = pgw.compute_perf_emd('spambase', 'spambase', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_spambase_pw =  perfs_spambase_pw['emd_groups']\n",
    "avg_perfs_spambase_pw_time = np.mean(perfs_spambase_pw['time'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "avg_perfs_mushrooms_pw_std = np.std(perfs_list_mushrooms_pw['emd'])\n",
    "perfs_mushrooms_pw_time_std = np.std(perfs_list_mushrooms_pw['time'])\n",
    "\n",
    "avg_perfs_shuttle_pw_std = np.std(perfs_list_shuttle_pw['emd'])\n",
    "perfs_shuttle_pw_time_std = np.std(perfs_list_shuttle_pw['time'])\n",
    "\n",
    "avg_perfs_pageblocks_pw_std = np.std(perfs_list_pageblocks_pw['emd'])\n",
    "perfs_pageblocks_pw_time_std = np.std(perfs_list_pageblocks_pw['time'])\n",
    "\n",
    "avg_perfs_usps_pw_std = np.std(perfs_list_usps_pw['emd'])\n",
    "perfs_usps_pw_time_std = np.std(perfs_list_usps_pw['time'])\n",
    "\n",
    "avg_perfs_connect4_pw_std = np.std(perfs_list_connect4_pw['emd'])\n",
    "perfs_connect4_pw_time_std = np.std(perfs_list_connect4_pw['time'])\n",
    "\n",
    "avg_perfs_spambase_pw_std = np.std(perfs_list_spambase_pw['emd'])\n",
    "perfs_spambase_pw_time_std = np.std(perfs_list_spambase_pw['time'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Partial-GW"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/kaiyizhang/opt/anaconda3/envs/OT-GAN/lib/python3.6/site-packages/ot/lp/__init__.py:329: UserWarning: numItermax reached before optimality. Try to increase numItermax.\n",
      "  result_code_string = check_result(result_code)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error in the EMD!!!!!!!\n"
     ]
    }
   ],
   "source": [
    "prior = 0.518\n",
    "perfs_mushrooms_gw, perfs_list_mushrooms_gw = pgw.compute_perf_pgw('mushrooms', 'mushrooms', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_mushrooms_gw =  perfs_mushrooms_gw['pgw']\n",
    "avg_perfs_mushrooms_gw_time = np.mean(perfs_mushrooms_gw['time'])\n",
    "\n",
    "prior = 0.786\n",
    "perfs_shuttle_gw, perfs_list_shuttle_gw = pgw.compute_perf_pgw('shuttle', 'shuttle', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_shuttle_gw =  perfs_shuttle_gw['pgw']\n",
    "avg_perfs_shuttle_gw_time = np.mean(perfs_shuttle_gw['time'])\n",
    "\n",
    "prior = 0.898\n",
    "perfs_pageblocks_gw, perfs_list_pageblocks_gw = pgw.compute_perf_pgw('pageblocks', 'pageblocks', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_pageblocks_gw =  perfs_pageblocks_gw['pgw']\n",
    "avg_perfs_pageblocks_gw_time = np.mean(perfs_pageblocks_gw['time'])\n",
    "\n",
    "prior = 0.167\n",
    "perfs_usps_gw, perfs_list_usps_gw = pgw.compute_perf_pgw('usps', 'usps', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_usps_gw =  perfs_usps_gw['pgw']\n",
    "avg_perfs_usps_gw_time = np.mean(perfs_usps_gw['time'])\n",
    "\n",
    "prior = 0.658\n",
    "perfs_connect4_gw, perfs_list_connect4_gw = pgw.compute_perf_pgw('connect-4', 'connect-4', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_connect4_gw =  perfs_connect4_gw['pgw']\n",
    "avg_perfs_connect4_gw_time = np.mean(perfs_connect4_gw['time'])\n",
    "\n",
    "prior = 0.394\n",
    "perfs_spambase_gw, perfs_list_spambase_gw = pgw.compute_perf_pgw('spambase', 'spambase', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_spambase_gw =  perfs_spambase_gw['pgw']\n",
    "avg_perfs_spambase_gw_time = np.mean(perfs_spambase_gw['time'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# np.savez('PGW_res.npz', perfs_mushrooms_gw=perfs_mushrooms_gw, perfs_shuttle_gw=perfs_shuttle_gw, perfs_list_shuttle_gw=perfs_list_shuttle_gw, perfs_pageblocks_gw=perfs_pageblocks_gw, perfs_list_pageblocks_gw=perfs_list_pageblocks_gw, perfs_usps_gw=perfs_usps_gw, perfs_list_usps_gw=perfs_list_usps_gw, perfs_connect4_gw=perfs_connect4_gw, perfs_list_connect4_gw=perfs_list_connect4_gw, perfs_spambase_gw=perfs_spambase_gw, perfs_list_spambase_gw=perfs_list_spambase_gw)\n",
    "np.savez('PGW_res.npz', perfs_mushrooms_gw=perfs_mushrooms_gw, perfs_list_mushrooms_gw=perfs_list_mushrooms_gw, perfs_shuttle_gw=perfs_shuttle_gw, perfs_list_shuttle_gw=perfs_list_shuttle_gw, perfs_pageblocks_gw=perfs_pageblocks_gw, perfs_list_pageblocks_gw=perfs_list_pageblocks_gw, perfs_usps_gw=perfs_usps_gw, perfs_list_usps_gw=perfs_list_usps_gw, perfs_connect4_gw=perfs_connect4_gw, perfs_list_connect4_gw=perfs_list_connect4_gw, perfs_spambase_gw=perfs_spambase_gw, perfs_list_spambase_gw=perfs_list_spambase_gw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "avg_perfs_mushrooms_gw_std = np.std(perfs_list_mushrooms_gw['pgw'])\n",
    "perfs_mushrooms_gw_time_std = np.std(perfs_mushrooms_gw['time'])\n",
    "\n",
    "avg_perfs_shuttle_gw_std = np.std(perfs_list_shuttle_gw['pgw'])\n",
    "perfs_shuttle_gw_time_std = np.std(perfs_shuttle_gw['time'])\n",
    "\n",
    "avg_perfs_pageblocks_gw_std = np.std(perfs_list_pageblocks_gw['pgw'])\n",
    "perfs_pageblocks_gw_time_std = np.std(perfs_pageblocks_gw['time'])\n",
    "\n",
    "avg_perfs_usps_gw_std = np.std(perfs_list_usps_gw['pgw'])\n",
    "perfs_usps_gw_time_std = np.std(perfs_usps_gw['time'])\n",
    "\n",
    "avg_perfs_connect4_gw_std = np.std(perfs_list_connect4_gw['pgw'])\n",
    "perfs_connect4_gw_time_std = np.std(perfs_connect4_gw['time'])\n",
    "\n",
    "avg_perfs_spambase_gw_std = np.std(perfs_list_spambase_gw['pgw'])\n",
    "perfs_spambase_gw_time_std = np.std(perfs_spambase_gw['time'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "OT-Profile"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "OTP-wo-prior method"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acc_mushrooms:0.9973750000000001(0.0013050383136138399)\n",
      "time_mushrooms:0.6405527114868164(0.24822855480452705)\n",
      "pr_mushrooms:0.517375(0.0016250000000000106)\n",
      "acc_shuttle:0.9773750000000001(0.008337603072826152)\n",
      "time_shuttle:0.5603338718414307(0.07920472362583303)\n",
      "pr_shuttle:0.785440401266508(0.00784180784778209)\n",
      "acc_pageblocks:0.93075(0.011141139977578601)\n",
      "time_pageblocks:0.3030731678009033(0.02413955381854405)\n",
      "pr_pageblocks:0.9247081736203524(0.030287614149567748)\n",
      "acc_usps:0.924625(0.016686914783745992)\n",
      "time_usps:7.347267413139344(0.2137716847377636)\n",
      "pr_usps:0.09215094345830252(0.01755037190568004)\n",
      "acc_connect-4:0.67325(0.010142731387550421)\n",
      "time_connect-4:3.4880244016647337(0.24537069107899218)\n",
      "pr_connect-4:0.8600360920039085(0.04288851443778899)\n",
      "acc_spambase:0.659(0.12683256679575636)\n",
      "time_spambase:0.71748366355896(0.047196674425450356)\n",
      "pr_spambase:0.30064049569593987(0.34382244248024224)\n"
     ]
    }
   ],
   "source": [
    "prior = 0.518\n",
    "avg_perfs_mushrooms_otp, perfs_mushrooms_otp_std, avg_perfs_mushrooms_otp_time, avg_perfs_mushrooms_otp_time_std, avg_mushrooms_detected_priors, avg_mushrooms_detected_priors_std = pgw.compute_a_case('mushrooms', 'mushrooms', n_unl, n_pos, prior, nb_reps, True)\n",
    "\n",
    "prior = 0.786\n",
    "avg_perfs_shuttle_otp, perfs_shuttle_otp_std, avg_perfs_shuttle_otp_time, avg_perfs_shuttle_otp_time_std, avg_shuttle_detected_priors, avg_shuttle_detected_priors_std = pgw.compute_a_case('shuttle', 'shuttle', n_unl, n_pos, prior, nb_reps, True)\n",
    "\n",
    "prior = 0.898\n",
    "avg_perfs_pageblocks_otp, perfs_pageblocks_otp_std, avg_perfs_pageblocks_otp_time, avg_perfs_pageblocks_otp_time_std, avg_pageblocks_detected_priors, avg_pageblocks_detected_priors_std = pgw.compute_a_case('pageblocks', 'pageblocks', n_unl, n_pos, prior, nb_reps, True)\n",
    "\n",
    "prior = 0.167\n",
    "avg_perfs_usps_otp, perfs_usps_otp_std, avg_perfs_usps_otp_time, avg_perfs_usps_otp_time_std, avg_usps_detected_priors, avg_usps_detected_priors_std = pgw.compute_a_case('usps', 'usps', n_unl, n_pos, prior, nb_reps, True)\n",
    "\n",
    "prior = 0.658\n",
    "avg_perfs_connect4_otp, perfs_connect4_otp_std, avg_perfs_connect4_otp_time, avg_perfs_connect4_otp_time_std, avg_connect4_detected_priors, avg_connect4_detected_priors_std = pgw.compute_a_case('connect-4', 'connect-4', n_unl, n_pos, prior, nb_reps, True)\n",
    "\n",
    "prior = 0.394\n",
    "avg_perfs_spambase_otp, perfs_spambase_otp_std, avg_perfs_spambase_otp_time, avg_perfs_spambase_otp_time_std, avg_spambase_detected_priors, avg_spambase_detected_priors_std = pgw.compute_a_case('spambase', 'spambase', n_unl, n_pos, prior, nb_reps, True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "OTP-w-prior method"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acc_mushrooms:0.9974999999999999(0.002738612787525823)\n",
      "time_mushrooms:0.5525274276733398(0.016350265209670634)\n",
      "pr_mushrooms:0.518(0.0)\n",
      "acc_shuttle:0.9792500000000001(0.008591420138719798)\n",
      "time_shuttle:0.5472853422164917(0.012078060730152744)\n",
      "pr_shuttle:0.7859999999999999(1.1102230246251565e-16)\n",
      "acc_pageblocks:0.93125(0.007004462863060927)\n",
      "time_pageblocks:0.30846025943756106(0.02155138064152986)\n",
      "pr_pageblocks:0.898(0.0)\n",
      "acc_usps:0.9727500000000001(0.004930770730829002)\n",
      "time_usps:7.4426860332489015(0.20957486580545923)\n",
      "pr_usps:0.167(0.0)\n",
      "acc_connect-4:0.747(0.014949916387726052)\n",
      "time_connect-4:3.4601741790771485(0.22995151101870942)\n",
      "pr_connect-4:0.6580000000000001(1.1102230246251565e-16)\n",
      "acc_spambase:0.73425(0.012847665157529602)\n",
      "time_spambase:0.7209930181503296(0.0637429141092371)\n",
      "pr_spambase:0.394(0.0)\n"
     ]
    }
   ],
   "source": [
    "prior = 0.518\n",
    "avg_perfs_mushrooms_otp_, perfs_mushrooms_otp_std_, avg_perfs_mushrooms_otp_time_, avg_perfs_mushrooms_otp_time_std_, avg_mushrooms_detected_priors_, avg_mushrooms_detected_priors_std_ = pgw.compute_a_case('mushrooms', 'mushrooms', n_unl, n_pos, prior, nb_reps, False)\n",
    "\n",
    "prior = 0.786\n",
    "avg_perfs_shuttle_otp_, perfs_shuttle_otp_std_, avg_perfs_shuttle_otp_time_, avg_perfs_shuttle_otp_time_std_, avg_shuttle_detected_priors_, avg_shuttle_detected_priors_std_ = pgw.compute_a_case('shuttle', 'shuttle', n_unl, n_pos, prior, nb_reps, False)\n",
    "\n",
    "prior = 0.898\n",
    "avg_perfs_pageblocks_otp_, perfs_pageblocks_otp_std_, avg_perfs_pageblocks_otp_time_, avg_perfs_pageblocks_otp_time_std_, avg_pageblocks_detected_priors_, avg_pageblocks_detected_priors_std_ = pgw.compute_a_case('pageblocks', 'pageblocks', n_unl, n_pos, prior, nb_reps, False)\n",
    "\n",
    "prior = 0.167\n",
    "avg_perfs_usps_otp_, perfs_usps_otp_std_, avg_perfs_usps_otp_time_, avg_perfs_usps_otp_time_std_, avg_usps_detected_priors_, avg_usps_detected_priors_std_ = pgw.compute_a_case('usps', 'usps', n_unl, n_pos, prior, nb_reps, False)\n",
    "\n",
    "prior = 0.658\n",
    "avg_perfs_connect4_otp_, perfs_connect4_otp_std_, avg_perfs_connect4_otp_time_, avg_perfs_connect4_otp_time_std_, avg_connect4_detected_priors_, avg_connect4_detected_priors_std_ = pgw.compute_a_case('connect-4', 'connect-4', n_unl, n_pos, prior, nb_reps, False)\n",
    "\n",
    "prior = 0.394\n",
    "avg_perfs_spambase_otp_, perfs_spambase_otp_std_, avg_perfs_spambase_otp_time_, avg_perfs_spambase_otp_time_std_, avg_spambase_detected_priors_, avg_spambase_detected_priors_std_ = pgw.compute_a_case('spambase', 'spambase', n_unl, n_pos, prior, nb_reps, False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tab.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Accuracy Table "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrrrr}\n",
      "\\toprule\n",
      "    dataset &  \\$\\textbackslash pi\\$ &   p-W &  p-GW &  OTP-w-prior &  OTP-wo-prior &  \\$\\textbackslash hat\\{\\textbackslash pi\\}\\$ \\\\\n",
      "\\midrule\n",
      "  mushrooms &  0.518 &  96.1 &  95.2 &         99.8 &          99.7 &     0.517375 \\\\\n",
      "    shuttle &  0.786 &  96.3 &  95.5 &         97.9 &          97.7 &     0.785440 \\\\\n",
      " pageblocks &  0.898 &  92.4 &  90.6 &         93.1 &          93.1 &     0.924708 \\\\\n",
      "       usps &  0.167 &  98.6 &  95.7 &         97.3 &          92.5 &     0.092151 \\\\\n",
      "  connect-4 &  0.658 &  61.0 &  55.4 &         74.7 &          67.3 &     0.860036 \\\\\n",
      "   spambase &  0.394 &  79.8 &  71.1 &         73.4 &          65.9 &     0.300640 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "results_UCI_SCAR = {'dataset':['mushrooms', 'shuttle', 'pageblocks', 'usps', 'connect-4', 'spambase'], \n",
    "               '$\\\\pi$': [0.518, 0.786, 0.898, 0.167, 0.658, 0.394],\n",
    "               'p-W': [avg_perfs_mushrooms_pw, avg_perfs_shuttle_pw, avg_perfs_pageblocks_pw, avg_perfs_usps_pw, avg_perfs_connect4_pw, avg_perfs_spambase_pw],\n",
    "               'p-GW': [avg_perfs_mushrooms_gw, avg_perfs_shuttle_gw, avg_perfs_pageblocks_gw, avg_perfs_usps_gw, avg_perfs_connect4_gw, avg_perfs_spambase_gw],\n",
    "               'OTP-w-prior': [avg_perfs_mushrooms_otp_, avg_perfs_shuttle_otp_, avg_perfs_pageblocks_otp_, avg_perfs_usps_otp_, avg_perfs_connect4_otp_, avg_perfs_spambase_otp_],\n",
    "               'OTP-wo-prior': [avg_perfs_mushrooms_otp, avg_perfs_shuttle_otp, avg_perfs_pageblocks_otp, avg_perfs_usps_otp, avg_perfs_connect4_otp, avg_perfs_spambase_otp],\n",
    "               '$\\\\hat{\\\\pi}$': [avg_mushrooms_detected_priors, avg_shuttle_detected_priors, avg_pageblocks_detected_priors, avg_usps_detected_priors, avg_connect4_detected_priors, avg_spambase_detected_priors]\n",
    "              }\n",
    "results_UCI_SCAR = pd.DataFrame(data=results_UCI_SCAR)\n",
    "myscale = lambda x : round(x*100,1)\n",
    "results_UCI_SCAR[['p-W','p-GW','OTP-w-prior','OTP-wo-prior']] = results_UCI_SCAR[['p-W','p-GW','OTP-w-prior','OTP-wo-prior']].apply(myscale,axis=1)\n",
    "results_UCI_SCAR.to_csv('results_UCI_SCAR.csv')\n",
    "print(results_UCI_SCAR.to_latex(index=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    dataset  $\\pi$   p-W  p-GW  OTP-w-prior  OTP-wo-prior  $\\hat{\\pi}$\n",
      "  mushrooms  0.518  96.1  95.2         99.8          99.7     0.517375\n",
      "    shuttle  0.786  96.3  95.5         97.9          97.7     0.785440\n",
      " pageblocks  0.898  92.4  90.6         93.1          93.1     0.924708\n",
      "       usps  0.167  98.6  95.7         97.3          92.5     0.092151\n",
      "  connect-4  0.658  61.0  55.4         74.7          67.3     0.860036\n",
      "   spambase  0.394  79.8  71.1         73.4          65.9     0.300640\n"
     ]
    }
   ],
   "source": [
    "print(results_UCI_SCAR.to_string(index=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "accuracy std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrrr}\n",
      "\\toprule\n",
      "    dataset &  p-W &  p-GW &  OTP-w-prior &  OTP-wo-prior &  \\$\\textbackslash hat\\{\\textbackslash pi\\}\\$ \\\\\n",
      "\\midrule\n",
      "  mushrooms &  0.8 &   1.0 &          0.3 &           0.1 &     0.001625 \\\\\n",
      "    shuttle &  1.1 &   1.4 &          0.9 &           0.8 &     0.007842 \\\\\n",
      " pageblocks &  0.9 &   1.2 &          0.7 &           1.1 &     0.030288 \\\\\n",
      "       usps &  0.5 &   1.1 &          0.5 &           1.7 &     0.017550 \\\\\n",
      "  connect-4 &  2.0 &   1.7 &          1.5 &           1.0 &     0.042889 \\\\\n",
      "   spambase &  1.8 &   1.5 &          1.3 &          12.7 &     0.343822 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "results_UCI_SCAR_std = {'dataset':['mushrooms', 'shuttle', 'pageblocks', 'usps', 'connect-4', 'spambase'], \n",
    "               'p-W': [avg_perfs_mushrooms_pw_std, avg_perfs_shuttle_pw_std, avg_perfs_pageblocks_pw_std, avg_perfs_usps_pw_std, avg_perfs_connect4_pw_std, avg_perfs_spambase_pw_std],\n",
    "               'p-GW': [avg_perfs_mushrooms_gw_std, avg_perfs_shuttle_gw_std, avg_perfs_pageblocks_gw_std, avg_perfs_usps_gw_std, avg_perfs_connect4_gw_std, avg_perfs_spambase_gw_std],\n",
    "               'OTP-w-prior': [perfs_mushrooms_otp_std_, perfs_shuttle_otp_std_, perfs_pageblocks_otp_std_, perfs_usps_otp_std_, perfs_connect4_otp_std_, perfs_spambase_otp_std_],\n",
    "               'OTP-wo-prior': [perfs_mushrooms_otp_std, perfs_shuttle_otp_std, perfs_pageblocks_otp_std, perfs_usps_otp_std, perfs_connect4_otp_std, perfs_spambase_otp_std],\n",
    "               '$\\\\hat{\\\\pi}$': [avg_mushrooms_detected_priors_std, avg_shuttle_detected_priors_std, avg_pageblocks_detected_priors_std, avg_usps_detected_priors_std, avg_connect4_detected_priors_std, avg_spambase_detected_priors_std]\n",
    "              }\n",
    "results_UCI_SCAR_std = pd.DataFrame(data=results_UCI_SCAR_std)\n",
    "myscale = lambda x : round(x*100,1)\n",
    "results_UCI_SCAR_std[['p-W','p-GW','OTP-w-prior','OTP-wo-prior']] = results_UCI_SCAR_std[['p-W','p-GW','OTP-w-prior','OTP-wo-prior']].apply(myscale,axis=1)\n",
    "results_UCI_SCAR_std.to_csv('results_UCI.csv')\n",
    "print(results_UCI_SCAR_std.to_latex(index=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    dataset  p-W  p-GW  OTP-w-prior  OTP-wo-prior  $\\hat{\\pi}$\n",
      "  mushrooms  0.8   1.0          0.3           0.1     0.001625\n",
      "    shuttle  1.1   1.4          0.9           0.8     0.007842\n",
      " pageblocks  0.9   1.2          0.7           1.1     0.030288\n",
      "       usps  0.5   1.1          0.5           1.7     0.017550\n",
      "  connect-4  2.0   1.7          1.5           1.0     0.042889\n",
      "   spambase  1.8   1.5          1.3          12.7     0.343822\n"
     ]
    }
   ],
   "source": [
    "print(results_UCI_SCAR_std.to_string(index=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Running Time Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrr}\n",
      "\\toprule\n",
      "    dataset &   p-W &  p-GW &  OTP-w-prior &  OTP-wo-prior \\\\\n",
      "\\midrule\n",
      "  mushrooms &  13.6 &  83.9 &          0.0 &           0.2 \\\\\n",
      "    shuttle &  11.3 &  82.0 &          0.0 &           0.1 \\\\\n",
      " pageblocks &  14.1 &  92.0 &          0.0 &           0.0 \\\\\n",
      "       usps &  27.3 &  80.4 &          0.2 &           0.2 \\\\\n",
      "  connect-4 &  37.3 &  83.4 &          0.2 &           0.2 \\\\\n",
      "   spambase &  13.1 &  80.0 &          0.1 &           0.0 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "times_UCI_SCAR = {'dataset':['mushrooms', 'shuttle', 'pageblocks', 'usps', 'connect-4', 'spambase'], \n",
    "               'p-W': [avg_perfs_mushrooms_pw_time, avg_perfs_shuttle_pw_time, avg_perfs_pageblocks_pw_time, avg_perfs_usps_pw_time, avg_perfs_connect4_pw_time, avg_perfs_spambase_pw_time],\n",
    "               'p-GW': [avg_perfs_mushrooms_gw_time, avg_perfs_shuttle_gw_time, avg_perfs_pageblocks_gw_time, avg_perfs_usps_gw_time, avg_perfs_connect4_gw_time, avg_perfs_spambase_gw_time],\n",
    "               'OTP-w-prior': [avg_perfs_mushrooms_otp_time_std_, avg_perfs_shuttle_otp_time_std_, avg_perfs_pageblocks_otp_time_std_, avg_perfs_usps_otp_time_std_, avg_perfs_connect4_otp_time_std_, avg_perfs_spambase_otp_time_std_],\n",
    "               'OTP-wo-prior': [avg_perfs_mushrooms_otp_time_std, avg_perfs_shuttle_otp_time_std, avg_perfs_pageblocks_otp_time_std, avg_perfs_usps_otp_time_std, avg_perfs_connect4_otp_time_std, avg_perfs_spambase_otp_time_std]\n",
    "              }\n",
    "times_UCI_SCAR = pd.DataFrame(data=times_UCI_SCAR)\n",
    "myscale = lambda x : round(x,1)\n",
    "times_UCI_SCAR[['p-W','p-GW','OTP-w-prior','OTP-wo-prior']] = times_UCI_SCAR[['p-W','p-GW','OTP-w-prior','OTP-wo-prior']].apply(myscale,axis=1)\n",
    "times_UCI_SCAR.to_csv('times_UCI_SCAR.csv')\n",
    "print(times_UCI_SCAR.to_latex(index=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrr}\n",
      "\\toprule\n",
      "    dataset &    p-W &   p-GW &  OTP-w-prior &  OTP-wo-prior \\\\\n",
      "\\midrule\n",
      "  mushrooms &  0.185 &  0.428 &        0.016 &         0.248 \\\\\n",
      "    shuttle &  0.027 &  1.396 &        0.012 &         0.079 \\\\\n",
      " pageblocks &  0.197 &  2.365 &        0.022 &         0.024 \\\\\n",
      "       usps &  0.187 &  2.145 &        0.210 &         0.214 \\\\\n",
      "  connect-4 &  0.157 &  1.087 &        0.230 &         0.245 \\\\\n",
      "   spambase &  0.100 &  0.766 &        0.064 &         0.047 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "times_UCI_SCAR_std = {'dataset':['mushrooms', 'shuttle', 'pageblocks', 'usps', 'connect-4', 'spambase'], \n",
    "               'p-W': [perfs_mushrooms_pw_time_std, perfs_shuttle_pw_time_std, perfs_pageblocks_pw_time_std, perfs_usps_pw_time_std, perfs_connect4_pw_time_std, perfs_spambase_pw_time_std],\n",
    "               'p-GW': [perfs_mushrooms_gw_time_std, perfs_shuttle_gw_time_std, perfs_pageblocks_gw_time_std, perfs_usps_gw_time_std, perfs_connect4_gw_time_std, perfs_spambase_gw_time_std],\n",
    "               'OTP-w-prior': [avg_perfs_mushrooms_otp_time_std_, avg_perfs_shuttle_otp_time_std_, avg_perfs_pageblocks_otp_time_std_, avg_perfs_usps_otp_time_std_, avg_perfs_connect4_otp_time_std_, avg_perfs_spambase_otp_time_std_],\n",
    "               'OTP-wo-prior': [avg_perfs_mushrooms_otp_time_std, avg_perfs_shuttle_otp_time_std, avg_perfs_pageblocks_otp_time_std, avg_perfs_usps_otp_time_std, avg_perfs_connect4_otp_time_std, avg_perfs_spambase_otp_time_std]\n",
    "              }\n",
    "times_UCI_SCAR_std = pd.DataFrame(data=times_UCI_SCAR_std)\n",
    "myscale = lambda x : round(x,3)\n",
    "times_UCI_SCAR_std[['p-W','p-GW','OTP-w-prior','OTP-wo-prior']] = times_UCI_SCAR_std[['p-W','p-GW','OTP-w-prior','OTP-wo-prior']].apply(myscale,axis=1)\n",
    "times_UCI_SCAR_std.to_csv('times_UCI_SCAR_std.csv')\n",
    "print(times_UCI_SCAR_std.to_latex(index=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Colored MNIST dataset - SAR scenario"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_unl = 800\n",
    "n_pos = 800\n",
    "nb_reps = 10\n",
    "nb_dummies = 10\n",
    "prior = 0.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Partial-W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "perfs_mnist_pw, perfs_list_mnist_pw = pgw.compute_perf_emd('mnist', 'mnist', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_mnist_pw =  perfs_mnist_pw['emd_groups']\n",
    "avg_perfs_mnist_pw_time = np.mean(perfs_mnist_pw['time'])\n",
    "\n",
    "perfs_color_mnist_pw, perfs_list_color_mnist_pw = pgw.compute_perf_emd('mnist_color_change_p', 'mnist_color_change_u', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_color_mnist_pw =  perfs_color_mnist_pw['emd_groups']\n",
    "avg_perfs_color_mnist_pw_time = np.mean(perfs_color_mnist_pw['time'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "avg_perfs_mnist_pw_std = np.std(perfs_list_mnist_pw['emd'])\n",
    "perfs_mnist_pw_time_std = np.std(perfs_mnist_pw['time'])\n",
    "\n",
    "avg_perfs_color_mnist_pw_std = np.std(perfs_list_color_mnist_pw['emd'])\n",
    "perfs_color_mnist_pw_time_std = np.std(perfs_color_mnist_pw['time'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Partial-GW"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "perfs_mnist_gw, perfs_list_mnist_gw = pgw.compute_perf_pgw('mnist', 'mnist', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_mnist_gw =  perfs_mnist_gw['pgw']\n",
    "avg_perfs_mnist_gw_time = np.mean(perfs_mnist_gw['time'])\n",
    "\n",
    "perfs_color_mnist_gw, perfs_list_color_mnist_gw = pgw.compute_perf_pgw('mnist_color_change_p', 'mnist_color_change_u', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_color_mnist_gw =  perfs_color_mnist_gw['pgw']\n",
    "avg_perfs_color_mnist_gw_time = np.mean(perfs_color_mnist_gw['time'])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "avg_perfs_mnist_gw_std = np.std(perfs_list_mnist_gw['pgw'])\n",
    "perfs_mnist_gw_time_std = np.std(perfs_mnist_gw['time'])\n",
    "\n",
    "avg_perfs_color_mnist_gw_std = np.std(perfs_list_color_mnist_gw['pgw'])\n",
    "perfs_color_mnist_gw_time_std = np.std(perfs_color_mnist_gw['time'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "OT-profile"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "OTP-wo-prior"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acc_mnist:0.9806250000000001(0.007969669064647533)\n",
      "time_mnist:2.665975284576416(0.2548475324483216)\n",
      "pr_mnist:0.08064315524812742(0.007972062661151778)\n",
      "acc_mnist_color_change_p:0.9758749999999999(0.010232454495378898)\n",
      "time_mnist_color_change_p:2.8132790088653565(0.22833344261537275)\n",
      "pr_mnist_color_change_p:0.07589310655124905(0.010238159163898083)\n"
     ]
    }
   ],
   "source": [
    "avg_perfs_mnist_otp, perfs_mnist_otp_std, avg_perfs_mnist_otp_time, avg_perfs_mnist_otp_time_std, avg_mnist_detected_priors, avg_mnist_detected_priors_std = pgw.compute_a_case('mnist', 'mnist', n_unl, n_pos, prior, nb_reps, True)\n",
    "\n",
    "avg_perfs_mnist_color_change_p_otp, perfs_mnist_color_change_p_otp_std, avg_perfs_mnist_color_change_p_otp_time, avg_perfs_mnist_color_change_p_otp_time_std, avg_mnist_color_change_p_detected_priors, avg_mnist_color_change_p_detected_priors_std = pgw.compute_a_case('mnist_color_change_p', 'mnist_color_change_p', n_unl, n_pos, prior, nb_reps, True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "OTP-w-prior"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acc_mnist:0.9932500000000001(0.002750000000000015)\n",
      "time_mnist:2.571903371810913(0.04984347327518463)\n",
      "pr_mnist:0.1(0.0)\n",
      "acc_mnist_color_change_p:0.99275(0.002358495283014162)\n",
      "time_mnist_color_change_p:2.832256722450256(0.17523188084049945)\n",
      "pr_mnist_color_change_p:0.1(0.0)\n"
     ]
    }
   ],
   "source": [
    "avg_perfs_mnist_otp_, perfs_mnist_otp_std_, avg_perfs_mnist_otp_time_, avg_perfs_mnist_otp_time_std_, avg_mnist_detected_priors_, avg_mnist_detected_priors_std_ = pgw.compute_a_case('mnist', 'mnist', n_unl, n_pos, prior, nb_reps, False)\n",
    "\n",
    "avg_perfs_mnist_color_change_p_otp_, perfs_mnist_color_change_p_otp_std_, avg_perfs_mnist_color_change_p_otp_time_, avg_perfs_mnist_color_change_p_otp_time_std_, avg_mnist_color_change_p_detected_priors_, avg_mnist_color_change_p_detected_priors_std_ = pgw.compute_a_case('mnist_color_change_p', 'mnist_color_change_p', n_unl, n_pos, prior, nb_reps, False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tab. 1 (continued)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrrrr}\n",
      "\\toprule\n",
      "       dataset &  \\$\\textbackslash pi\\$ &   p-W &  p-GW &  OTP-w-prior &  OTP-wo-prior &  detected priors from OT-profile \\\\\n",
      "\\midrule\n",
      "         mnist &    0.1 &  99.1 &  98.4 &         99.3 &          98.1 &                         0.080643 \\\\\n",
      " colored mnist &    0.1 &  91.6 &  97.5 &         99.3 &          97.6 &                         0.075893 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "results_MNIST_SAR = {'dataset':['mnist', 'colored mnist'], \n",
    "               '$\\pi$': [0.1, 0.1],\n",
    "               'p-W': [avg_perfs_mnist_pw, avg_perfs_color_mnist_pw],\n",
    "               'p-GW': [avg_perfs_mnist_gw, avg_perfs_color_mnist_gw],\n",
    "               'OTP-w-prior': [avg_perfs_mnist_otp_, avg_perfs_mnist_color_change_p_otp_],\n",
    "               'OTP-wo-prior': [avg_perfs_mnist_otp, avg_perfs_mnist_color_change_p_otp],\n",
    "               'detected priors from OT-profile': [avg_mnist_detected_priors, avg_mnist_color_change_p_detected_priors]\n",
    "              }\n",
    "results_MNIST_SAR = pd.DataFrame(data=results_MNIST_SAR)\n",
    "myscale = lambda x : round(x*100,1)\n",
    "results_MNIST_SAR[['p-W','p-GW','OTP-w-prior','OTP-wo-prior']] = results_MNIST_SAR[['p-W','p-GW','OTP-w-prior','OTP-wo-prior']].apply(myscale,axis=1)\n",
    "results_MNIST_SAR.to_csv('results_MNIST_SAR.csv')\n",
    "print(results_MNIST_SAR.to_latex(index=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "       dataset  $\\pi$   p-W  p-GW  OTP-w-prior  OTP-wo-prior  detected priors from OT-profile\n",
      "         mnist    0.1  99.1  98.4         99.3          98.1                         0.080643\n",
      " colored mnist    0.1  91.6  97.5         99.3          97.6                         0.075893\n"
     ]
    }
   ],
   "source": [
    "print(results_MNIST_SAR.to_string(index=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrrr}\n",
      "\\toprule\n",
      "       dataset &    p-W &    p-GW &  OTP-w-prior &  OTP-wo-prior &  \\$\\textbackslash hat\\{\\textbackslash pi\\}\\$ \\\\\n",
      "\\midrule\n",
      "         mnist &  0.004 &   0.004 &        0.003 &         0.003 &     0.007972 \\\\\n",
      " colored mnist &  0.004 &  81.835 &        0.002 &         0.002 &     0.010238 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "time_MNIST_SAR_std = {'dataset':['mnist', 'colored mnist'], \n",
    "               'p-W': [avg_perfs_mnist_pw_std, avg_perfs_color_mnist_pw_std],\n",
    "               'p-GW': [avg_perfs_mnist_gw_std, avg_perfs_color_mnist_gw_time],\n",
    "               'OTP-w-prior': [perfs_mnist_otp_std_, perfs_mnist_color_change_p_otp_std_],\n",
    "               'OTP-wo-prior': [perfs_mnist_otp_std_, perfs_mnist_color_change_p_otp_std_],\n",
    "               '$\\\\hat{\\\\pi}$': [avg_mnist_detected_priors_std, avg_mnist_color_change_p_detected_priors_std]\n",
    "              }\n",
    "time_MNIST_SAR_std = pd.DataFrame(data=time_MNIST_SAR_std)\n",
    "myscale = lambda x : round(x,3)\n",
    "time_MNIST_SAR_std[['p-W','p-GW','OTP-w-prior','OTP-wo-prior']] = time_MNIST_SAR_std[['p-W','p-GW','OTP-w-prior','OTP-wo-prior']].apply(myscale,axis=1)\n",
    "time_MNIST_SAR_std.to_csv('times_MNIST.csv')\n",
    "print(time_MNIST_SAR_std.to_latex(index=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrr}\n",
      "\\toprule\n",
      "       dataset &   p-W &  p-GW &  OTP-w-prior &  OTP-wo-prior \\\\\n",
      "\\midrule\n",
      "         mnist &  17.5 &  81.6 &          0.0 &           0.2 \\\\\n",
      " colored mnist &  19.6 &  81.8 &          2.8 &           2.8 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "times_MNIST_SAR = {'dataset':['mnist', 'colored mnist'], \n",
    "               'p-W': [avg_perfs_mnist_pw_time, avg_perfs_color_mnist_pw_time],\n",
    "               'p-GW': [avg_perfs_mnist_gw_time, avg_perfs_color_mnist_gw_time],\n",
    "               'OTP-w-prior': [avg_perfs_mushrooms_otp_time_std_, avg_perfs_mnist_color_change_p_otp_time_],\n",
    "               'OTP-wo-prior': [avg_perfs_mushrooms_otp_time_std, avg_perfs_mnist_color_change_p_otp_time]\n",
    "              }\n",
    "times_MNIST_SAR = pd.DataFrame(data=times_MNIST_SAR)\n",
    "myscale = lambda x : round(x,1)\n",
    "times_MNIST_SAR[['p-W','p-GW','OTP-w-prior','OTP-wo-prior']] = times_MNIST_SAR[['p-W','p-GW','OTP-w-prior','OTP-wo-prior']].apply(myscale,axis=1)\n",
    "times_MNIST_SAR.to_csv('times_MNIST_SAR.csv')\n",
    "print(times_MNIST_SAR.to_latex(index=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrr}\n",
      "\\toprule\n",
      "       dataset &    p-W &   p-GW &  OTP-w-prior &  OTP-wo-prior \\\\\n",
      "\\midrule\n",
      "         mnist &  0.000 &  1.063 &        0.050 &         0.255 \\\\\n",
      " colored mnist &  0.027 &  1.396 &        0.012 &         0.079 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "times_MNIST_SAR_std = {'dataset':['mnist', 'colored mnist'], \n",
    "               'p-W': [perfs_mnist_pw_time_std, perfs_shuttle_pw_time_std],\n",
    "               'p-GW': [perfs_mnist_gw_time_std, perfs_shuttle_gw_time_std],\n",
    "               'OTP-w-prior': [avg_perfs_mnist_otp_time_std_, avg_perfs_shuttle_otp_time_std_],\n",
    "               'OTP-wo-prior': [avg_perfs_mnist_otp_time_std, avg_perfs_shuttle_otp_time_std]\n",
    "              }\n",
    "times_MNIST_SAR_std = pd.DataFrame(data=times_MNIST_SAR_std)\n",
    "myscale = lambda x : round(x,3)\n",
    "times_MNIST_SAR_std[['p-W','p-GW','OTP-w-prior','OTP-wo-prior']] = times_MNIST_SAR_std[['p-W','p-GW','OTP-w-prior','OTP-wo-prior']].apply(myscale,axis=1)\n",
    "times_MNIST_SAR_std.to_csv('times_MNIST_SAR_std.csv')\n",
    "print(times_MNIST_SAR_std.to_latex(index=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "af8a8dc5d9d5fe20562b2a82edc81ddace8b482cbd5fbca7942e2cde7b4f1d64"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
