{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# This notebook reproduces the results reported in Tables 1 and 2 of the paper"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Results may differ slightly from the figures reported in the paper because no random seed was fixed when drawing the datasets. This does not change any of the conclusions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Library/Java/JavaVirtualMachines/jdk1.8.0_161.jdk/Contents/Home/jre/lib/jli/libjli.dylib\n"
     ]
    }
   ],
   "source": [
    "import utils\n",
    "import numpy as np\n",
    "from IPython.display import HTML, display, Markdown\n",
    "import pandas as pd\n",
    "\n",
    "import ot\n",
    "import partial_gw as pgw\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## UCI dataset - SCAR scenario"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experimental settings shared by all UCI runs below\n",
    "n_unl, n_pos = 800, 800\n",
    "nb_reps = 10      # also used below to rescale the reported times\n",
    "nb_dummies = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Partial-W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Partial-W (p-W) on each UCI dataset, with the dataset name passed for both data arguments.\n",
    "def _run_pw_uci(dataset, prior):\n",
    "    \"\"\"Call pgw.compute_perf_emd on `dataset`; return (perfs, perfs_list, perfs['emd_groups'], mean of perfs['time']).\"\"\"\n",
    "    perfs, perfs_list = pgw.compute_perf_emd(dataset, dataset, n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "    return perfs, perfs_list, perfs['emd_groups'], np.mean(perfs['time'])\n",
    "\n",
    "prior = 0.518\n",
    "perfs_mushrooms_pw, perfs_list_mushrooms_pw, avg_perfs_mushrooms_pw, avg_perfs_mushrooms_pw_time = _run_pw_uci('mushrooms', prior)\n",
    "\n",
    "prior = 0.786\n",
    "perfs_shuttle_pw, perfs_list_shuttle_pw, avg_perfs_shuttle_pw, avg_perfs_shuttle_pw_time = _run_pw_uci('shuttle', prior)\n",
    "\n",
    "prior = 0.898\n",
    "perfs_pageblocks_pw, perfs_list_pageblocks_pw, avg_perfs_pageblocks_pw, avg_perfs_pageblocks_pw_time = _run_pw_uci('pageblocks', prior)\n",
    "\n",
    "prior = 0.167\n",
    "perfs_usps_pw, perfs_list_usps_pw, avg_perfs_usps_pw, avg_perfs_usps_pw_time = _run_pw_uci('usps', prior)\n",
    "\n",
    "prior = 0.658\n",
    "perfs_connect4_pw, perfs_list_connect4_pw, avg_perfs_connect4_pw, avg_perfs_connect4_pw_time = _run_pw_uci('connect-4', prior)\n",
    "\n",
    "prior = 0.394\n",
    "perfs_spambase_pw, perfs_list_spambase_pw, avg_perfs_spambase_pw, avg_perfs_spambase_pw_time = _run_pw_uci('spambase', prior)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Partial-GW"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/kaiyizhang/opt/anaconda3/envs/OT-GAN/lib/python3.6/site-packages/ot/lp/__init__.py:329: UserWarning: numItermax reached before optimality. Try to increase numItermax.\n",
      "  result_code_string = check_result(result_code)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error in the EMD!!!!!!!\n"
     ]
    }
   ],
   "source": [
    "# Partial-GW (p-GW) on each UCI dataset, with the dataset name passed for both data arguments.\n",
    "def _run_pgw_uci(dataset, prior):\n",
    "    \"\"\"Call pgw.compute_perf_pgw on `dataset`; return (perfs, perfs_list, perfs['pgw'], mean of perfs['time']).\"\"\"\n",
    "    perfs, perfs_list = pgw.compute_perf_pgw(dataset, dataset, n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "    return perfs, perfs_list, perfs['pgw'], np.mean(perfs['time'])\n",
    "\n",
    "prior = 0.518\n",
    "perfs_mushrooms_gw, perfs_list_mushrooms_gw, avg_perfs_mushrooms_gw, avg_perfs_mushrooms_gw_time = _run_pgw_uci('mushrooms', prior)\n",
    "\n",
    "prior = 0.786\n",
    "perfs_shuttle_gw, perfs_list_shuttle_gw, avg_perfs_shuttle_gw, avg_perfs_shuttle_gw_time = _run_pgw_uci('shuttle', prior)\n",
    "\n",
    "prior = 0.898\n",
    "perfs_pageblocks_gw, perfs_list_pageblocks_gw, avg_perfs_pageblocks_gw, avg_perfs_pageblocks_gw_time = _run_pgw_uci('pageblocks', prior)\n",
    "\n",
    "prior = 0.167\n",
    "perfs_usps_gw, perfs_list_usps_gw, avg_perfs_usps_gw, avg_perfs_usps_gw_time = _run_pgw_uci('usps', prior)\n",
    "\n",
    "prior = 0.658\n",
    "perfs_connect4_gw, perfs_list_connect4_gw, avg_perfs_connect4_gw, avg_perfs_connect4_gw_time = _run_pgw_uci('connect-4', prior)\n",
    "\n",
    "prior = 0.394\n",
    "perfs_spambase_gw, perfs_list_spambase_gw, avg_perfs_spambase_gw, avg_perfs_spambase_gw_time = _run_pgw_uci('spambase', prior)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### OT-profile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# OT-profile (compute_perf_GTOT without a 7th argument) on each UCI dataset; also averages the detected priors.\n",
    "def _run_otp_uci(dataset, prior):\n",
    "    \"\"\"Call pgw.compute_perf_GTOT with `dataset` for both data arguments.\n",
    "\n",
    "    Returns (perfs, perfs_list, detected_priors, perfs['GTOT'], mean of perfs['time'], mean of detected_priors).\n",
    "    \"\"\"\n",
    "    perfs, perfs_list, detected = pgw.compute_perf_GTOT(dataset, dataset, n_unl, n_pos, prior, nb_reps)\n",
    "    return perfs, perfs_list, detected, perfs['GTOT'], np.mean(perfs['time']), np.mean(detected)\n",
    "\n",
    "prior = 0.518\n",
    "(perfs_mushrooms_otp, perfs_list_mushrooms_otp, detected_priors_mushrooms,\n",
    " avg_perfs_mushrooms_otp, avg_perfs_mushrooms_otp_time, avg_detected_priors_mushrooms) = _run_otp_uci('mushrooms', prior)\n",
    "\n",
    "prior = 0.786\n",
    "(perfs_shuttle_otp, perfs_list_shuttle_otp, detected_priors_shuttle,\n",
    " avg_perfs_shuttle_otp, avg_perfs_shuttle_otp_time, avg_detected_priors_shuttle) = _run_otp_uci('shuttle', prior)\n",
    "\n",
    "prior = 0.898\n",
    "(perfs_pageblocks_otp, perfs_list_pageblocks_otp, detected_priors_pageblocks,\n",
    " avg_perfs_pageblocks_otp, avg_perfs_pageblocks_otp_time, avg_detected_priors_pageblocks) = _run_otp_uci('pageblocks', prior)\n",
    "\n",
    "prior = 0.167\n",
    "(perfs_usps_otp, perfs_list_usps_otp, detected_priors_usps,\n",
    " avg_perfs_usps_otp, avg_perfs_usps_otp_time, avg_detected_priors_usps) = _run_otp_uci('usps', prior)\n",
    "\n",
    "prior = 0.658\n",
    "(perfs_connect4_otp, perfs_list_connect4_otp, detected_priors_connect4,\n",
    " avg_perfs_connect4_otp, avg_perfs_connect4_otp_time, avg_detected_priors_connect4) = _run_otp_uci('connect-4', prior)\n",
    "\n",
    "prior = 0.394\n",
    "(perfs_spambase_otp, perfs_list_spambase_otp, detected_priors_spambase,\n",
    " avg_perfs_spambase_otp, avg_perfs_spambase_otp_time, avg_detected_priors_spambase) = _run_otp_uci('spambase', prior)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# OT-profile with False passed as the 7th argument of compute_perf_GTOT, on each UCI dataset.\n",
    "def _run_otp_uci_wprior(dataset, prior):\n",
    "    \"\"\"Call pgw.compute_perf_GTOT(..., False); return (perfs, perfs_list, detected_priors, perfs['GTOT'], mean of perfs['time']).\"\"\"\n",
    "    perfs, perfs_list, detected = pgw.compute_perf_GTOT(dataset, dataset, n_unl, n_pos, prior, nb_reps, False)\n",
    "    return perfs, perfs_list, detected, perfs['GTOT'], np.mean(perfs['time'])\n",
    "\n",
    "prior = 0.518\n",
    "(perfs_mushrooms_otp_, perfs_list_mushrooms_otp_, detected_priors_mushrooms_,\n",
    " avg_perfs_mushrooms_otp_, avg_perfs_mushrooms_otp_time_) = _run_otp_uci_wprior('mushrooms', prior)\n",
    "\n",
    "prior = 0.786\n",
    "(perfs_shuttle_otp_, perfs_list_shuttle_otp_, detected_priors_shuttle_,\n",
    " avg_perfs_shuttle_otp_, avg_perfs_shuttle_otp_time_) = _run_otp_uci_wprior('shuttle', prior)\n",
    "\n",
    "prior = 0.898\n",
    "(perfs_pageblocks_otp_, perfs_list_pageblocks_otp_, detected_priors_pageblocks_,\n",
    " avg_perfs_pageblocks_otp_, avg_perfs_pageblocks_otp_time_) = _run_otp_uci_wprior('pageblocks', prior)\n",
    "\n",
    "prior = 0.167\n",
    "(perfs_usps_otp_, perfs_list_usps_otp_, detected_priors_usps_,\n",
    " avg_perfs_usps_otp_, avg_perfs_usps_otp_time_) = _run_otp_uci_wprior('usps', prior)\n",
    "\n",
    "prior = 0.658\n",
    "(perfs_connect4_otp_, perfs_list_connect4_otp_, detected_priors_connect4_,\n",
    " avg_perfs_connect4_otp_, avg_perfs_connect4_otp_time_) = _run_otp_uci_wprior('connect-4', prior)\n",
    "\n",
    "prior = 0.394\n",
    "(perfs_spambase_otp_, perfs_list_spambase_otp_, detected_priors_spambase_,\n",
    " avg_perfs_spambase_otp_, avg_perfs_spambase_otp_time_) = _run_otp_uci_wprior('spambase', prior)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tab.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrrr}\n",
      "\\toprule\n",
      "    dataset &  \\$\\textbackslash pi\\$ &   p-W &  p-GW &  ot-profile &  \\$\\textbackslash hat\\{\\textbackslash pi\\}\\$ estimated by OT profile \\\\\n",
      "\\midrule\n",
      "  mushrooms &  0.518 &  96.1 &  95.2 &        99.4 &                             0.512125 \\\\\n",
      "    shuttle &  0.786 &  96.3 &  95.5 &        98.2 &                             0.786942 \\\\\n",
      " pageblocks &  0.898 &  92.4 &  90.6 &        92.7 &                             0.893845 \\\\\n",
      "       usps &  0.167 &  98.6 &  95.7 &        96.7 &                             0.165422 \\\\\n",
      "  connect-4 &  0.658 &  61.0 &  55.4 &        68.3 &                             0.859036 \\\\\n",
      "   spambase &  0.394 &  79.8 &  71.1 &        71.3 &                             0.139417 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Tab. 1, UCI part: p-W / p-GW / OT-profile scores shown as percentages (one decimal)\n",
    "results_UCI = {'dataset':['mushrooms', 'shuttle', 'pageblocks', 'usps', 'connect-4', 'spambase'], \n",
    "               '$\\\\pi$': [0.518, 0.786, 0.898, 0.167, 0.658, 0.394],\n",
    "               'p-W': [avg_perfs_mushrooms_pw, avg_perfs_shuttle_pw, avg_perfs_pageblocks_pw, avg_perfs_usps_pw, avg_perfs_connect4_pw, avg_perfs_spambase_pw],\n",
    "               'p-GW': [avg_perfs_mushrooms_gw, avg_perfs_shuttle_gw, avg_perfs_pageblocks_gw, avg_perfs_usps_gw, avg_perfs_connect4_gw, avg_perfs_spambase_gw],\n",
    "               'ot-profile': [avg_perfs_mushrooms_otp, avg_perfs_shuttle_otp, avg_perfs_pageblocks_otp, avg_perfs_usps_otp, avg_perfs_connect4_otp, avg_perfs_spambase_otp],\n",
    "               '$\\\\hat{\\\\pi}$ estimated by OT profile': [avg_detected_priors_mushrooms, avg_detected_priors_shuttle, avg_detected_priors_pageblocks, avg_detected_priors_usps, avg_detected_priors_connect4, avg_detected_priors_spambase]\n",
    "              }\n",
    "results_UCI = pd.DataFrame(data=results_UCI)\n",
    "myscale = lambda x : round(x*100,1)\n",
    "results_UCI[['p-W','p-GW','ot-profile']] = results_UCI[['p-W','p-GW','ot-profile']].apply(myscale,axis=1)\n",
    "results_UCI.to_csv('results_UCI.csv')\n",
    "# escape=False keeps the '$...$' column headers as LaTeX math; with escaping they print as literal text\n",
    "print(results_UCI.to_latex(index=False, escape=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    dataset  $\\pi$   p-W  p-GW  ot-profile  $\\hat{\\pi}$ estimated by OT profile\n",
      "  mushrooms  0.518  96.1  95.2        99.4                             0.512125\n",
      "    shuttle  0.786  96.3  95.5        98.2                             0.786942\n",
      " pageblocks  0.898  92.4  90.6        92.7                             0.893845\n",
      "       usps  0.167  98.6  95.7        96.7                             0.165422\n",
      "  connect-4  0.658  61.0  55.4        68.3                             0.859036\n",
      "   spambase  0.394  79.8  71.1        71.3                             0.139417\n"
     ]
    }
   ],
   "source": [
    "# Plain-text view of the UCI part of Tab. 1\n",
    "_uci_table_text = results_UCI.to_string(index=False)\n",
    "print(_uci_table_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrr}\n",
      "\\toprule\n",
      "    dataset &  p-W &    p-GW &  ot-profile \\\\\n",
      "\\midrule\n",
      "  mushrooms &  1.3 &    88.8 &         0.7 \\\\\n",
      "    shuttle &  1.1 &    87.5 &         0.6 \\\\\n",
      " pageblocks &  1.5 &    94.9 &         0.3 \\\\\n",
      "       usps &  2.3 &    82.4 &         8.1 \\\\\n",
      "  connect-4 &  2.4 &  1179.7 &         2.1 \\\\\n",
      "   spambase &  1.3 &    96.0 &         0.8 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Running times per method on the UCI datasets, divided by nb_reps and rounded to 0.1\n",
    "times_UCI = pd.DataFrame(data={\n",
    "    'dataset': ['mushrooms', 'shuttle', 'pageblocks', 'usps', 'connect-4', 'spambase'],\n",
    "    'p-W': [avg_perfs_mushrooms_pw_time, avg_perfs_shuttle_pw_time, avg_perfs_pageblocks_pw_time,\n",
    "            avg_perfs_usps_pw_time, avg_perfs_connect4_pw_time, avg_perfs_spambase_pw_time],\n",
    "    'p-GW': [avg_perfs_mushrooms_gw_time, avg_perfs_shuttle_gw_time, avg_perfs_pageblocks_gw_time,\n",
    "             avg_perfs_usps_gw_time, avg_perfs_connect4_gw_time, avg_perfs_spambase_gw_time],\n",
    "    'ot-profile': [avg_perfs_mushrooms_otp_time, avg_perfs_shuttle_otp_time, avg_perfs_pageblocks_otp_time,\n",
    "                   avg_perfs_usps_otp_time, avg_perfs_connect4_otp_time, avg_perfs_spambase_otp_time],\n",
    "})\n",
    "\n",
    "# NOTE(review): the *_time values are already np.mean over perfs['time']; confirm the extra division by nb_reps is intended\n",
    "def myscale(x):\n",
    "    return round(x / nb_reps, 1)\n",
    "\n",
    "times_UCI[['p-W', 'p-GW', 'ot-profile']] = times_UCI[['p-W', 'p-GW', 'ot-profile']].apply(myscale, axis=1)\n",
    "times_UCI.to_csv('times_UCI.csv')\n",
    "print(times_UCI.to_latex(index=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrr}\n",
      "\\toprule\n",
      "    dataset &  \\$\\textbackslash pi\\$ &  ot-profile \\\\\n",
      "\\midrule\n",
      "  mushrooms &  0.518 &        99.8 \\\\\n",
      "    shuttle &  0.786 &        97.9 \\\\\n",
      " pageblocks &  0.898 &        93.1 \\\\\n",
      "       usps &  0.167 &        97.3 \\\\\n",
      "  connect-4 &  0.658 &        74.7 \\\\\n",
      "   spambase &  0.394 &        73.4 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# OT-profile scores from the `*_otp_` runs (compute_perf_GTOT called with False), as percentages\n",
    "results_UCI_wprior = {'dataset':['mushrooms', 'shuttle', 'pageblocks', 'usps', 'connect-4', 'spambase'], \n",
    "               '$\\\\pi$': [0.518, 0.786, 0.898, 0.167, 0.658, 0.394],\n",
    "               'ot-profile': [avg_perfs_mushrooms_otp_, avg_perfs_shuttle_otp_, avg_perfs_pageblocks_otp_, avg_perfs_usps_otp_, avg_perfs_connect4_otp_, avg_perfs_spambase_otp_]\n",
    "              }\n",
    "results_UCI_wprior = pd.DataFrame(data=results_UCI_wprior)\n",
    "myscale = lambda x : round(x*100,1)\n",
    "results_UCI_wprior[['ot-profile']] = results_UCI_wprior[['ot-profile']].apply(myscale,axis=1)\n",
    "# Own file name: writing to 'results_UCI.csv' would overwrite the Tab. 1 table saved earlier\n",
    "results_UCI_wprior.to_csv('results_UCI_wprior.csv')\n",
    "# escape=False keeps the '$...$' header as LaTeX math instead of escaped literal text\n",
    "print(results_UCI_wprior.to_latex(index=False, escape=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lr}\n",
      "\\toprule\n",
      "    dataset &  ot-profile \\\\\n",
      "\\midrule\n",
      "  mushrooms &         0.6 \\\\\n",
      "    shuttle &         0.5 \\\\\n",
      " pageblocks &         0.3 \\\\\n",
      "       usps &         8.1 \\\\\n",
      "  connect-4 &         2.0 \\\\\n",
      "   spambase &         0.8 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# OT-profile running times for the `*_otp_` runs, divided by nb_reps and rounded to 0.1\n",
    "# NOTE(review): 'wprrior' looks like a typo of 'wprior'; variable and CSV names left unchanged in case other code relies on them\n",
    "times_UCI_wprrior = pd.DataFrame(data={\n",
    "    'dataset': ['mushrooms', 'shuttle', 'pageblocks', 'usps', 'connect-4', 'spambase'],\n",
    "    'ot-profile': [avg_perfs_mushrooms_otp_time_, avg_perfs_shuttle_otp_time_, avg_perfs_pageblocks_otp_time_,\n",
    "                   avg_perfs_usps_otp_time_, avg_perfs_connect4_otp_time_, avg_perfs_spambase_otp_time_],\n",
    "})\n",
    "\n",
    "def myscale(x):\n",
    "    return round(x / nb_reps, 1)\n",
    "\n",
    "times_UCI_wprrior[['ot-profile']] = times_UCI_wprrior[['ot-profile']].apply(myscale, axis=1)\n",
    "times_UCI_wprrior.to_csv('times_UCI_wprrior.csv')\n",
    "print(times_UCI_wprrior.to_latex(index=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Colored MNIST dataset - SAR scenario"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for the MNIST experiments (prior is shared by every call in this section)\n",
    "n_unl, n_pos = 800, 800\n",
    "nb_reps, nb_dummies = 10, 10\n",
    "prior = 0.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Partial-W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# p-W on plain MNIST, then on the colored '_p' / '_u' variants (presumably positive / unlabeled sides)\n",
    "perfs_mnist_pw, perfs_list_mnist_pw = pgw.compute_perf_emd('mnist', 'mnist', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_mnist_pw, avg_perfs_mnist_pw_time = perfs_mnist_pw['emd_groups'], np.mean(perfs_mnist_pw['time'])\n",
    "\n",
    "perfs_color_mnist_pw, perfs_list_color_mnist_pw = pgw.compute_perf_emd(\n",
    "    'mnist_color_change_p', 'mnist_color_change_u', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_color_mnist_pw, avg_perfs_color_mnist_pw_time = perfs_color_mnist_pw['emd_groups'], np.mean(perfs_color_mnist_pw['time'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Partial-GW"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# p-GW on plain MNIST, then on the colored '_p' / '_u' variants\n",
    "perfs_mnist_gw, perfs_list_mnist_gw = pgw.compute_perf_pgw('mnist', 'mnist', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_mnist_gw, avg_perfs_mnist_gw_time = perfs_mnist_gw['pgw'], np.mean(perfs_mnist_gw['time'])\n",
    "\n",
    "perfs_color_mnist_gw, perfs_list_color_mnist_gw = pgw.compute_perf_pgw(\n",
    "    'mnist_color_change_p', 'mnist_color_change_u', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_color_mnist_gw, avg_perfs_color_mnist_gw_time = perfs_color_mnist_gw['pgw'], np.mean(perfs_color_mnist_gw['time'])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### OT-profile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# OT-profile (compute_perf_GTOT without a 7th argument) on plain MNIST and on colored MNIST.\n",
    "perfs_mnist_otp, perfs_list_mnist_otp, detected_priors_mnist = pgw.compute_perf_GTOT('mnist', 'mnist', n_unl, n_pos, prior, nb_reps)\n",
    "avg_perfs_mnist_otp =  perfs_mnist_otp['GTOT']\n",
    "avg_detected_priors_mnist = np.mean(detected_priors_mnist)\n",
    "avg_perfs_mnist_otp_time = np.mean(perfs_mnist_otp['time'])\n",
    "\n",
    "# Use the same ('_p', '_u') pair as the p-W / p-GW runs so the 'colored mnist' row compares methods on one task\n",
    "perfs_color_mnist_otp, perfs_list_color_mnist_otp, detected_priors_color_mnist = pgw.compute_perf_GTOT('mnist_color_change_p', 'mnist_color_change_u', n_unl, n_pos, prior, nb_reps)\n",
    "avg_perfs_color_mnist_otp =  perfs_color_mnist_otp['GTOT']\n",
    "avg_detected_priors_color_mnist = np.mean(detected_priors_color_mnist)\n",
    "avg_perfs_color_mnist_otp_time = np.mean(perfs_color_mnist_otp['time'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# OT-profile with False as the 7th argument of compute_perf_GTOT, on plain MNIST and on colored MNIST.\n",
    "perfs_mnist_otp_, perfs_list_mnist_otp_, detected_priors_mnist_ = pgw.compute_perf_GTOT('mnist', 'mnist', n_unl, n_pos, prior, nb_reps, False)\n",
    "avg_perfs_mnist_otp_ =  perfs_mnist_otp_['GTOT']\n",
    "avg_perfs_mnist_otp_time_ = np.mean(perfs_mnist_otp_['time'])\n",
    "\n",
    "# Same ('_p', '_u') pair as the p-W / p-GW runs, so every 'colored mnist' row describes one task\n",
    "perfs_color_mnist_otp_, perfs_list_color_mnist_otp_, detected_priors_color_mnist_ = pgw.compute_perf_GTOT('mnist_color_change_p', 'mnist_color_change_u', n_unl, n_pos, prior, nb_reps, False)\n",
    "avg_perfs_color_mnist_otp_ =  perfs_color_mnist_otp_['GTOT']\n",
    "avg_perfs_color_mnist_otp_time_ = np.mean(perfs_color_mnist_otp_['time'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tab. 1 (continued)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrrrr}\n",
      "\\toprule\n",
      "       dataset &  \\$\\textbackslash pi\\$ &   p-W &  p-GW &  OT-profile &  detected priors from OT-profile \\\\\n",
      "\\midrule\n",
      "         mnist &    0.1 &  99.1 &  98.4 &        96.0 &                         0.109025 \\\\\n",
      " colored mnist &    0.1 &  91.6 &  97.5 &        95.8 &                         0.117403 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Tab. 1, MNIST part: p-W / p-GW / OT-profile scores as percentages (one decimal)\n",
    "# The header uses a doubled backslash, like the UCI table: a single one is an invalid string escape\n",
    "results_MNIST = {'dataset':['mnist', 'colored mnist'], \n",
    "               '$\\\\pi$': [0.1, 0.1],\n",
    "               'p-W': [avg_perfs_mnist_pw, avg_perfs_color_mnist_pw],\n",
    "               'p-GW': [avg_perfs_mnist_gw, avg_perfs_color_mnist_gw],\n",
    "               'OT-profile': [avg_perfs_mnist_otp, avg_perfs_color_mnist_otp],\n",
    "               'detected priors from OT-profile': [avg_detected_priors_mnist, avg_detected_priors_color_mnist]\n",
    "              }\n",
    "results_MNIST = pd.DataFrame(data=results_MNIST)\n",
    "myscale = lambda x : round(x*100,1)\n",
    "results_MNIST[['p-W','p-GW','OT-profile']] = results_MNIST[['p-W','p-GW','OT-profile']].apply(myscale,axis=1)\n",
    "results_MNIST.to_csv('results_MNIST.csv')\n",
    "# escape=False keeps the '$...$' header as LaTeX math instead of escaped literal text\n",
    "print(results_MNIST.to_latex(index=False, escape=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "       dataset  $\\pi$   p-W  p-GW  OT-profile  detected priors from OT-profile\n",
      "         mnist    0.1  99.1  98.4        96.0                         0.109025\n",
      " colored mnist    0.1  91.6  97.5        95.8                         0.117403\n"
     ]
    }
   ],
   "source": [
    "# Plain-text view of the MNIST part of Tab. 1\n",
    "_mnist_table_text = results_MNIST.to_string(index=False)\n",
    "print(_mnist_table_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrr}\n",
      "\\toprule\n",
      "       dataset &  p-W &  p-GW &  ot-profile \\\\\n",
      "\\midrule\n",
      "         mnist &  2.1 &  97.7 &         3.0 \\\\\n",
      " colored mnist &  2.3 &  97.6 &         3.3 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Running times for the MNIST experiments, divided by nb_reps and rounded to 0.1\n",
    "time_MNIST = pd.DataFrame(data={\n",
    "    'dataset': ['mnist', 'colored mnist'],\n",
    "    'p-W': [avg_perfs_mnist_pw_time, avg_perfs_color_mnist_pw_time],\n",
    "    'p-GW': [avg_perfs_mnist_gw_time, avg_perfs_color_mnist_gw_time],\n",
    "    'ot-profile': [avg_perfs_mnist_otp_time, avg_perfs_color_mnist_otp_time],\n",
    "})\n",
    "\n",
    "def myscale(x):\n",
    "    return round(x / nb_reps, 1)\n",
    "\n",
    "time_MNIST[['p-W', 'p-GW', 'ot-profile']] = time_MNIST[['p-W', 'p-GW', 'ot-profile']].apply(myscale, axis=1)\n",
    "time_MNIST.to_csv('times_MNIST.csv')\n",
    "print(time_MNIST.to_latex(index=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrr}\n",
      "\\toprule\n",
      "       dataset &  \\$\\textbackslash pi\\$ &  OT-profile \\\\\n",
      "\\midrule\n",
      "         mnist &    0.1 &        99.3 \\\\\n",
      " colored mnist &    0.1 &        99.3 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# OT-profile scores for the `*_otp_` MNIST runs (compute_perf_GTOT called with False), as percentages\n",
    "# The header uses a doubled backslash: a single one is an invalid string escape\n",
    "results_MNIST_wprior = {'dataset':['mnist', 'colored mnist'], \n",
    "               '$\\\\pi$': [0.1, 0.1],\n",
    "               'OT-profile': [avg_perfs_mnist_otp_, avg_perfs_color_mnist_otp_]\n",
    "              }\n",
    "results_MNIST_wprior = pd.DataFrame(data=results_MNIST_wprior)\n",
    "myscale = lambda x : round(x*100,1)\n",
    "results_MNIST_wprior[['OT-profile']] = results_MNIST_wprior[['OT-profile']].apply(myscale,axis=1)\n",
    "results_MNIST_wprior.to_csv('results_MNIST_wprior.csv')\n",
    "# escape=False keeps the '$...$' header as LaTeX math instead of escaped literal text\n",
    "print(results_MNIST_wprior.to_latex(index=False, escape=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lr}\n",
      "\\toprule\n",
      "       dataset &  ot-profile \\\\\n",
      "\\midrule\n",
      "         mnist &         3.0 \\\\\n",
      " colored mnist &         3.3 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# OT-profile running times for the `*_otp_` MNIST runs, divided by nb_reps and rounded to 0.1\n",
    "time_MNIST_wprior = {'dataset':['mnist', 'colored mnist'], \n",
    "               'ot-profile': [avg_perfs_mnist_otp_time_, avg_perfs_color_mnist_otp_time_],\n",
    "              }\n",
    "time_MNIST_wprior = pd.DataFrame(data=time_MNIST_wprior)\n",
    "myscale = lambda x : round(x/nb_reps,1)\n",
    "time_MNIST_wprior[['ot-profile']] = time_MNIST_wprior[['ot-profile']].apply(myscale,axis=1)\n",
    "# Own file name: 'times_MNIST.csv' already holds the p-W / p-GW / OT-profile timing table\n",
    "time_MNIST_wprior.to_csv('times_MNIST_wprior.csv')\n",
    "print(time_MNIST_wprior.to_latex(index=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Caltech dataset - PU on different domains"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Caltech settings: smaller samples (100) than the UCI / MNIST runs (800)\n",
    "n_unl, n_pos = 100, 100\n",
    "nb_reps, nb_dummies = 10, 10\n",
    "prior = 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# p-W then p-GW with the first data argument fixed to the Caltech set and the second one ranging over\n",
    "# the four domains, first for SURF features and then for DeCAF features (row order of the table below).\n",
    "avg_caltech_emd_groups = []\n",
    "avg_caltech_gw_groups = []\n",
    "for source, domain_u in (('surf_Caltech10', ['surf_Caltech10', 'surf_amazon', 'surf_webcam', 'surf_dslr']),\n",
    "                         ('decaf_caltech', ['decaf_caltech', 'decaf_amazon', 'decaf_webcam', 'decaf_dslr'])):\n",
    "    for d in domain_u:\n",
    "        # `perfs_caltech_surf` is reused for the DeCAF runs as well\n",
    "        perfs_caltech_surf, perfs_list_caltech_surf = pgw.compute_perf_emd(source, d, n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "        avg_caltech_emd_groups.append(perfs_caltech_surf['emd_groups'])\n",
    "        perfs_caltech_surf, perfs_list_caltech_surf = pgw.compute_perf_pgw(source, d, n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "        avg_caltech_gw_groups.append(perfs_caltech_surf['pgw'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tab.1 (continued)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset</th>\n",
       "      <th>$\\pi$</th>\n",
       "      <th>p-W</th>\n",
       "      <th>p-GW</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>surf C -&gt; surf C</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.900</td>\n",
       "      <td>0.864</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>surf C -&gt; surf A</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.816</td>\n",
       "      <td>0.870</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>surf C -&gt; surf W</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.822</td>\n",
       "      <td>0.862</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>surf C -&gt; surf D</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.800</td>\n",
       "      <td>0.880</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>decaf C -&gt; decaf C</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.940</td>\n",
       "      <td>0.862</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>decaf C -&gt; decaf A</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.802</td>\n",
       "      <td>0.878</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>decaf C -&gt; decaf W</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.802</td>\n",
       "      <td>0.886</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>decaf C -&gt; decaf D</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.808</td>\n",
       "      <td>0.924</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              dataset  $\\pi$    p-W   p-GW\n",
       "0    surf C -> surf C    0.1  0.900  0.864\n",
       "1    surf C -> surf A    0.1  0.816  0.870\n",
       "2    surf C -> surf W    0.1  0.822  0.862\n",
       "3    surf C -> surf D    0.1  0.800  0.880\n",
       "4  decaf C -> decaf C    0.1  0.940  0.862\n",
       "5  decaf C -> decaf A    0.1  0.802  0.878\n",
       "6  decaf C -> decaf W    0.1  0.802  0.886\n",
       "7  decaf C -> decaf D    0.1  0.808  0.924"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Tab. 1: partial-W (p-W) vs partial-GW (p-GW) scores for each source -> target Caltech domain pair.\n",
    "# The pi column reports the `prior` variable (the value the Caltech runs pass to the solvers,\n",
    "# presumably 0.1) instead of a hard-coded constant, and its length follows the dataset list.\n",
    "caltech_domain_pairs = ['surf C -> surf C', 'surf C -> surf A', 'surf C -> surf W', 'surf C -> surf D',\n",
    "                        'decaf C -> decaf C', 'decaf C -> decaf A', 'decaf C -> decaf W', 'decaf C -> decaf D']\n",
    "results_caltech_diff_domains = {'dataset': caltech_domain_pairs,\n",
    "               r'$\\pi$': [prior] * len(caltech_domain_pairs),  # raw string: '\\p' is not a valid escape\n",
    "               'p-W': avg_caltech_emd_groups,\n",
    "               'p-GW': avg_caltech_gw_groups\n",
    "              }\n",
    "results_caltech_diff_domains = pd.DataFrame(data=results_caltech_diff_domains)\n",
    "results_caltech_diff_domains"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Caltech dataset - PU on different feature spaces"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tab. 2 runs: p-GW between Caltech data in one feature space (first argument of\n",
    "# compute_perf_pgw) and each domain C/A/W/D in the other feature space (second argument).\n",
    "# n_unl, n_pos, nb_reps and nb_dummies come from the configuration cell at the top; `prior`\n",
    "# is whatever an earlier cell last assigned (presumably 0.1, as in Tab. 1) -- TODO: set it explicitly.\n",
    "surf_domains = ['surf_Caltech10', 'surf_amazon', 'surf_webcam', 'surf_dslr']\n",
    "decaf_domains = ['decaf_caltech', 'decaf_amazon', 'decaf_webcam', 'decaf_dslr']\n",
    "\n",
    "# decaf Caltech vs surf domains -> column 'decaf C -> surf *' of Tab. 2\n",
    "avg_caltech_gw_groups_decaf = []\n",
    "for d in surf_domains:\n",
    "    perfs_decaf_to_surf, _perfs_list = pgw.compute_perf_pgw('decaf_caltech', d, n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "    avg_caltech_gw_groups_decaf.append(perfs_decaf_to_surf['pgw'])\n",
    "\n",
    "# surf Caltech vs decaf domains -> column 'surf C -> decaf *' of Tab. 2\n",
    "avg_caltech_gw_groups_surf = []\n",
    "for d in decaf_domains:\n",
    "    perfs_surf_to_decaf, _perfs_list = pgw.compute_perf_pgw('surf_Caltech10', d, n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "    avg_caltech_gw_groups_surf.append(perfs_surf_to_decaf['pgw'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tab. 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>domains</th>\n",
       "      <th>surf C -&gt; decaf *</th>\n",
       "      <th>decaf C -&gt; surf *</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>*= C</td>\n",
       "      <td>0.874</td>\n",
       "      <td>0.860</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>*= A</td>\n",
       "      <td>0.940</td>\n",
       "      <td>0.866</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>*= W</td>\n",
       "      <td>0.944</td>\n",
       "      <td>0.894</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>*= D</td>\n",
       "      <td>0.966</td>\n",
       "      <td>0.866</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  domains  surf C -> decaf *  decaf C -> surf *\n",
       "0    *= C              0.874              0.860\n",
       "1    *= A              0.940              0.866\n",
       "2    *= W              0.944              0.894\n",
       "3    *= D              0.966              0.866"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Tab. 2: p-GW scores when the source and target data are described in different feature spaces.\n",
    "tab2_columns = {\n",
    "    'domains': ['*= C', '*= A', '*= W', '*= D'],\n",
    "    'surf C -> decaf *': avg_caltech_gw_groups_surf,\n",
    "    'decaf C -> surf *': avg_caltech_gw_groups_decaf,\n",
    "}\n",
    "results_caltech_diff_features = pd.DataFrame(tab2_columns)\n",
    "results_caltech_diff_features"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "af8a8dc5d9d5fe20562b2a82edc81ddace8b482cbd5fbca7942e2cde7b4f1d64"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
