{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# This notebook contains the figures relatives to the Table 1 and 2 of the paper"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The performances may slightly differ from the figures reported in the paper as we did not fix the seed to draw the datasets. However, it does not change any conclusions!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Library/Java/JavaVirtualMachines/jdk1.8.0_161.jdk/Contents/Home/jre/lib/jli/libjli.dylib\n"
     ]
    }
   ],
   "source": [
    "import utils\n",
    "import numpy as np\n",
    "from IPython.display import HTML, display, Markdown\n",
    "import pandas as pd\n",
    "\n",
    "import ot\n",
    "import partial_gw as pgw\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## UCI dataset - SCAR scenario"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_unl = 800\n",
    "n_pos = 800\n",
    "nb_reps = 10\n",
    "nb_dummies = 10\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Partial-W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "prior = 0.518\n",
    "perfs_mushrooms_pw, perfs_list_mushrooms_pw = pgw.compute_perf_emd('mushrooms', 'mushrooms', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_mushrooms_pw =  perfs_mushrooms_pw['emd_groups']\n",
    "avg_perfs_mushrooms_pw_time = np.mean(perfs_mushrooms_pw['time'])\n",
    "\n",
    "prior = 0.786\n",
    "perfs_shuttle_pw, perfs_list_shuttle_pw = pgw.compute_perf_emd('shuttle', 'shuttle', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_shuttle_pw =  perfs_shuttle_pw['emd_groups']\n",
    "avg_perfs_shuttle_pw_time = np.mean(perfs_shuttle_pw['time'])\n",
    "\n",
    "prior = 0.898\n",
    "perfs_pageblocks_pw, perfs_list_pageblocks_pw = pgw.compute_perf_emd('pageblocks', 'pageblocks', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_pageblocks_pw =  perfs_pageblocks_pw['emd_groups']\n",
    "avg_perfs_pageblocks_pw_time = np.mean(perfs_pageblocks_pw['time'])\n",
    "\n",
    "prior = 0.167\n",
    "perfs_usps_pw, perfs_list_usps_pw = pgw.compute_perf_emd('usps', 'usps', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_usps_pw =  perfs_usps_pw['emd_groups']\n",
    "avg_perfs_usps_pw_time = np.mean(perfs_usps_pw['time'])\n",
    "\n",
    "prior = 0.658\n",
    "perfs_connect4_pw, perfs_list_connect4_pw = pgw.compute_perf_emd('connect-4', 'connect-4', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_connect4_pw =  perfs_connect4_pw['emd_groups']\n",
    "avg_perfs_connect4_pw_time = np.mean(perfs_connect4_pw['time'])\n",
    "\n",
    "prior = 0.394\n",
    "perfs_spambase_pw, perfs_list_spambase_pw = pgw.compute_perf_emd('spambase', 'spambase', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_spambase_pw =  perfs_spambase_pw['emd_groups']\n",
    "avg_perfs_spambase_pw_time = np.mean(perfs_spambase_pw['time'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Partial-GW"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/kaiyizhang/opt/anaconda3/envs/OT-GAN/lib/python3.6/site-packages/ot/lp/__init__.py:329: UserWarning: numItermax reached before optimality. Try to increase numItermax.\n",
      "  result_code_string = check_result(result_code)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error in the EMD!!!!!!!\n"
     ]
    }
   ],
   "source": [
    "prior = 0.518\n",
    "perfs_mushrooms_gw, perfs_list_mushrooms_gw = pgw.compute_perf_pgw('mushrooms', 'mushrooms', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_mushrooms_gw =  perfs_mushrooms_gw['pgw']\n",
    "avg_perfs_mushrooms_gw_time = np.mean(perfs_mushrooms_gw['time'])\n",
    "\n",
    "prior = 0.786\n",
    "perfs_shuttle_gw, perfs_list_shuttle_gw = pgw.compute_perf_pgw('shuttle', 'shuttle', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_shuttle_gw =  perfs_shuttle_gw['pgw']\n",
    "avg_perfs_shuttle_gw_time = np.mean(perfs_shuttle_gw['time'])\n",
    "\n",
    "prior = 0.898\n",
    "perfs_pageblocks_gw, perfs_list_pageblocks_gw = pgw.compute_perf_pgw('pageblocks', 'pageblocks', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_pageblocks_gw =  perfs_pageblocks_gw['pgw']\n",
    "avg_perfs_pageblocks_gw_time = np.mean(perfs_pageblocks_gw['time'])\n",
    "\n",
    "prior = 0.167\n",
    "perfs_usps_gw, perfs_list_usps_gw = pgw.compute_perf_pgw('usps', 'usps', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_usps_gw =  perfs_usps_gw['pgw']\n",
    "avg_perfs_usps_gw_time = np.mean(perfs_usps_gw['time'])\n",
    "\n",
    "prior = 0.658\n",
    "perfs_connect4_gw, perfs_list_connect4_gw = pgw.compute_perf_pgw('connect-4', 'connect-4', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_connect4_gw =  perfs_connect4_gw['pgw']\n",
    "avg_perfs_connect4_gw_time = np.mean(perfs_connect4_gw['time'])\n",
    "\n",
    "prior = 0.394\n",
    "perfs_spambase_gw, perfs_list_spambase_gw = pgw.compute_perf_pgw('spambase', 'spambase', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_spambase_gw =  perfs_spambase_gw['pgw']\n",
    "avg_perfs_spambase_gw_time = np.mean(perfs_spambase_gw['time'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# OTP"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "w prior"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "prior = 0.518\n",
    "perfs_mushrooms_otp, perfs_list_mushrooms_otp, detected_priors_mushrooms = pgw.compute_perf_GTOT('mushrooms', 'mushrooms', n_unl, n_pos, prior, nb_reps)\n",
    "avg_perfs_mushrooms_otp =  perfs_mushrooms_otp['GTOT']\n",
    "avg_perfs_mushrooms_otp_time = np.mean(perfs_mushrooms_otp['time'])\n",
    "avg_detected_priors_mushrooms = np.mean(detected_priors_mushrooms)\n",
    "\n",
    "prior = 0.786\n",
    "perfs_shuttle_otp, perfs_list_shuttle_otp, detected_priors_shuttle = pgw.compute_perf_GTOT('shuttle', 'shuttle', n_unl, n_pos, prior, nb_reps)\n",
    "avg_perfs_shuttle_otp =  perfs_shuttle_otp['GTOT']\n",
    "avg_perfs_shuttle_otp_time = np.mean(perfs_shuttle_otp['time'])\n",
    "avg_detected_priors_shuttle = np.mean(detected_priors_shuttle)\n",
    "\n",
    "prior = 0.898\n",
    "perfs_pageblocks_otp, perfs_list_pageblocks_otp, detected_priors_pageblocks = pgw.compute_perf_GTOT('pageblocks', 'pageblocks', n_unl, n_pos, prior, nb_reps)\n",
    "avg_perfs_pageblocks_otp =  perfs_pageblocks_otp['GTOT']\n",
    "avg_perfs_pageblocks_otp_time = np.mean(perfs_pageblocks_otp['time'])\n",
    "avg_detected_priors_pageblocks = np.mean(detected_priors_pageblocks)\n",
    "\n",
    "prior = 0.167\n",
    "perfs_usps_otp, perfs_list_usps_otp, detected_priors_usps = pgw.compute_perf_GTOT('usps', 'usps', n_unl, n_pos, prior, nb_reps)\n",
    "avg_perfs_usps_otp =  perfs_usps_otp['GTOT']\n",
    "avg_perfs_usps_otp_time = np.mean(perfs_usps_otp['time'])\n",
    "avg_detected_priors_usps = np.mean(detected_priors_usps)\n",
    "\n",
    "prior = 0.658\n",
    "perfs_connect4_otp, perfs_list_connect4_otp, detected_priors_connect4 = pgw.compute_perf_GTOT('connect-4', 'connect-4', n_unl, n_pos, prior, nb_reps)\n",
    "avg_perfs_connect4_otp =  perfs_connect4_otp['GTOT']\n",
    "avg_perfs_connect4_otp_time = np.mean(perfs_connect4_otp['time'])\n",
    "avg_detected_priors_connect4 = np.mean(detected_priors_connect4)\n",
    "\n",
    "prior = 0.394\n",
    "perfs_spambase_otp, perfs_list_spambase_otp, detected_priors_spambase = pgw.compute_perf_GTOT('spambase', 'spambase', n_unl, n_pos, prior, nb_reps)\n",
    "avg_perfs_spambase_otp =  perfs_spambase_otp['GTOT']\n",
    "avg_perfs_spambase_otp_time = np.mean(perfs_spambase_otp['time'])\n",
    "avg_detected_priors_spambase = np.mean(detected_priors_spambase)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "wo prior"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "prior = 0.518\n",
    "perfs_mushrooms_otp_, perfs_list_mushrooms_otp_, detected_priors_mushrooms_ = pgw.compute_perf_GTOT('mushrooms', 'mushrooms', n_unl, n_pos, prior, nb_reps, False)\n",
    "avg_perfs_mushrooms_otp_ =  perfs_mushrooms_otp_['GTOT']\n",
    "avg_perfs_mushrooms_otp_time_ = np.mean(perfs_mushrooms_otp_['time'])\n",
    "\n",
    "prior = 0.786\n",
    "perfs_shuttle_otp_, perfs_list_shuttle_otp_, detected_priors_shuttle_ = pgw.compute_perf_GTOT('shuttle', 'shuttle', n_unl, n_pos, prior, nb_reps, False)\n",
    "avg_perfs_shuttle_otp_ =  perfs_shuttle_otp_['GTOT']\n",
    "avg_perfs_shuttle_otp_time_ = np.mean(perfs_shuttle_otp_['time'])\n",
    "\n",
    "prior = 0.898\n",
    "perfs_pageblocks_otp_, perfs_list_pageblocks_otp_, detected_priors_pageblocks_ = pgw.compute_perf_GTOT('pageblocks', 'pageblocks', n_unl, n_pos, prior, nb_reps, False)\n",
    "avg_perfs_pageblocks_otp_ =  perfs_pageblocks_otp_['GTOT']\n",
    "avg_perfs_pageblocks_otp_time_ = np.mean(perfs_pageblocks_otp_['time'])\n",
    "\n",
    "prior = 0.167\n",
    "perfs_usps_otp_, perfs_list_usps_otp_, detected_priors_usps_ = pgw.compute_perf_GTOT('usps', 'usps', n_unl, n_pos, prior, nb_reps, False)\n",
    "avg_perfs_usps_otp_ =  perfs_usps_otp_['GTOT']\n",
    "avg_perfs_usps_otp_time_ = np.mean(perfs_usps_otp_['time'])\n",
    "\n",
    "prior = 0.658\n",
    "perfs_connect4_otp_, perfs_list_connect4_otp_, detected_priors_connect4_ = pgw.compute_perf_GTOT('connect-4', 'connect-4', n_unl, n_pos, prior, nb_reps, False)\n",
    "avg_perfs_connect4_otp_ =  perfs_connect4_otp_['GTOT']\n",
    "avg_perfs_connect4_otp_time_ = np.mean(perfs_connect4_otp_['time'])\n",
    "\n",
    "prior = 0.394\n",
    "perfs_spambase_otp_, perfs_list_spambase_otp_, detected_priors_spambase_ = pgw.compute_perf_GTOT('spambase', 'spambase', n_unl, n_pos, prior, nb_reps, False)\n",
    "avg_perfs_spambase_otp_ =  perfs_spambase_otp_['GTOT']\n",
    "avg_perfs_spambase_otp_time_ = np.mean(perfs_spambase_otp_['time'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tab.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset</th>\n",
       "      <th>$\\pi$</th>\n",
       "      <th>p-W</th>\n",
       "      <th>p-GW</th>\n",
       "      <th>ot-profile</th>\n",
       "      <th>$\\hat{\\pi}$ estimated by OT profile</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>mushrooms</td>\n",
       "      <td>0.518</td>\n",
       "      <td>0.96100</td>\n",
       "      <td>0.95250</td>\n",
       "      <td>0.994000</td>\n",
       "      <td>0.512125</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>shuttle</td>\n",
       "      <td>0.786</td>\n",
       "      <td>0.96275</td>\n",
       "      <td>0.95525</td>\n",
       "      <td>0.981750</td>\n",
       "      <td>0.786942</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>pageblocks</td>\n",
       "      <td>0.898</td>\n",
       "      <td>0.92375</td>\n",
       "      <td>0.90575</td>\n",
       "      <td>0.926750</td>\n",
       "      <td>0.893845</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>usps</td>\n",
       "      <td>0.167</td>\n",
       "      <td>0.98600</td>\n",
       "      <td>0.95725</td>\n",
       "      <td>0.967250</td>\n",
       "      <td>0.165422</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>connect-4</td>\n",
       "      <td>0.658</td>\n",
       "      <td>0.60975</td>\n",
       "      <td>0.55375</td>\n",
       "      <td>0.682875</td>\n",
       "      <td>0.859036</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>spambase</td>\n",
       "      <td>0.394</td>\n",
       "      <td>0.79850</td>\n",
       "      <td>0.71075</td>\n",
       "      <td>0.713375</td>\n",
       "      <td>0.139417</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      dataset  $\\pi$      p-W     p-GW  ot-profile  \\\n",
       "0   mushrooms  0.518  0.96100  0.95250    0.994000   \n",
       "1     shuttle  0.786  0.96275  0.95525    0.981750   \n",
       "2  pageblocks  0.898  0.92375  0.90575    0.926750   \n",
       "3        usps  0.167  0.98600  0.95725    0.967250   \n",
       "4   connect-4  0.658  0.60975  0.55375    0.682875   \n",
       "5    spambase  0.394  0.79850  0.71075    0.713375   \n",
       "\n",
       "   $\\hat{\\pi}$ estimated by OT profile  \n",
       "0                             0.512125  \n",
       "1                             0.786942  \n",
       "2                             0.893845  \n",
       "3                             0.165422  \n",
       "4                             0.859036  \n",
       "5                             0.139417  "
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_UCI = {'dataset':['mushrooms', 'shuttle', 'pageblocks', 'usps', 'connect-4', 'spambase'], \n",
    "               '$\\\\pi$': [0.518, 0.786, 0.898, 0.167, 0.658, 0.394],\n",
    "               'p-W': [avg_perfs_mushrooms_pw, avg_perfs_shuttle_pw, avg_perfs_pageblocks_pw, avg_perfs_usps_pw, avg_perfs_connect4_pw, avg_perfs_spambase_pw],\n",
    "               'p-GW': [avg_perfs_mushrooms_gw, avg_perfs_shuttle_gw, avg_perfs_pageblocks_gw, avg_perfs_usps_gw, avg_perfs_connect4_gw, avg_perfs_spambase_gw],\n",
    "               'ot-profile': [avg_perfs_mushrooms_otp, avg_perfs_shuttle_otp, avg_perfs_pageblocks_otp, avg_perfs_usps_otp, avg_perfs_connect4_otp, avg_perfs_spambase_otp],\n",
    "               '$\\\\hat{\\\\pi}$ estimated by OT profile': [avg_detected_priors_mushrooms, avg_detected_priors_shuttle, avg_detected_priors_pageblocks, avg_detected_priors_usps, avg_detected_priors_connect4, avg_detected_priors_spambase]\n",
    "              }\n",
    "results_UCI = pd.DataFrame(data=results_UCI)\n",
    "results_UCI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "myscale = lambda x : round(x*100,1)\n",
    "results_UCI[['p-W','p-GW','ot-profile']] = results_UCI[['p-W','p-GW','ot-profile']].apply(myscale,axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrr}\n",
      "\\toprule\n",
      "    dataset &  \\$\\textbackslash pi\\$ &  ot-profile \\\\\n",
      "\\midrule\n",
      "  mushrooms &  0.518 &        99.8 \\\\\n",
      "    shuttle &  0.786 &        97.9 \\\\\n",
      " pageblocks &  0.898 &        93.1 \\\\\n",
      "       usps &  0.167 &        97.3 \\\\\n",
      "  connect-4 &  0.658 &        74.7 \\\\\n",
      "   spambase &  0.394 &        73.4 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "results_UCI_wprior = {'dataset':['mushrooms', 'shuttle', 'pageblocks', 'usps', 'connect-4', 'spambase'], \n",
    "               '$\\\\pi$': [0.518, 0.786, 0.898, 0.167, 0.658, 0.394],\n",
    "               'ot-profile': [avg_perfs_mushrooms_otp_, avg_perfs_shuttle_otp_, avg_perfs_pageblocks_otp_, avg_perfs_usps_otp_, avg_perfs_connect4_otp_, avg_perfs_spambase_otp_]\n",
    "              }\n",
    "results_UCI_wprior = pd.DataFrame(data=results_UCI_wprior)\n",
    "myscale = lambda x : round(x*100,1)\n",
    "results_UCI_wprior[['ot-profile']] = results_UCI_wprior[['ot-profile']].apply(myscale,axis=1)\n",
    "results_UCI_wprior.to_csv('results_UCI.csv')\n",
    "print(results_UCI_wprior.to_latex(index=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lr}\n",
      "\\toprule\n",
      "    dataset &  ot-profile \\\\\n",
      "\\midrule\n",
      "  mushrooms &         0.6 \\\\\n",
      "    shuttle &         0.5 \\\\\n",
      " pageblocks &         0.3 \\\\\n",
      "       usps &         8.1 \\\\\n",
      "  connect-4 &         1.9 \\\\\n",
      "   spambase &         0.8 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "times_UCI_wprrior = {'dataset':['mushrooms', 'shuttle', 'pageblocks', 'usps', 'connect-4', 'spambase'], \n",
    "               'ot-profile': [avg_perfs_mushrooms_otp_time_, avg_perfs_shuttle_otp_time_, avg_perfs_pageblocks_otp_time_, avg_perfs_usps_otp_time_, avg_perfs_connect4_otp_time_, avg_perfs_spambase_otp_time_]\n",
    "              }\n",
    "times_UCI_wprrior = pd.DataFrame(data=times_UCI_wprrior)\n",
    "myscale = lambda x : round(x/nb_reps,1)\n",
    "times_UCI_wprrior[['ot-profile']] = times_UCI_wprrior[['ot-profile']].apply(myscale,axis=1)\n",
    "times_UCI_wprrior.to_csv('times_UCI_wprrior.csv')\n",
    "print(times_UCI_wprrior.to_latex(index=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Colored MNIST dataset - SAR scenario"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_unl = 800\n",
    "n_pos = 800\n",
    "nb_reps = 10\n",
    "nb_dummies = 10\n",
    "prior = 0.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Partial-W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "perfs_mnist_pw, perfs_list_mnist_pw = pgw.compute_perf_emd('mnist', 'mnist', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_mnist_pw =  perfs_mnist_pw['emd_groups']\n",
    "avg_perfs_mnist_pw_time = np.mean(perfs_mnist_pw['time'])\n",
    "\n",
    "perfs_color_mnist_pw, perfs_list_color_mnist_pw = pgw.compute_perf_emd('mnist_color_change_p', 'mnist_color_change_u', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_color_mnist_pw =  perfs_color_mnist_pw['emd_groups']\n",
    "avg_perfs_color_mnist_pw_time = np.mean(perfs_color_mnist_pw['time'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Partial-GW"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "perfs_mnist_gw, perfs_list_mnist_gw = pgw.compute_perf_pgw('mnist', 'mnist', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_mnist_gw =  perfs_mnist_gw['pgw']\n",
    "avg_perfs_mnist_gw_time = np.mean(perfs_mnist_gw['time'])\n",
    "\n",
    "perfs_color_mnist_gw, perfs_list_color_mnist_gw = pgw.compute_perf_pgw('mnist_color_change_p', 'mnist_color_change_u', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "avg_perfs_color_mnist_gw =  perfs_color_mnist_gw['pgw']\n",
    "avg_perfs_color_mnist_gw_time = np.mean(perfs_color_mnist_gw['time'])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "OT-profile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "perfs_mnist_otp, perfs_list_mnist_otp, detected_priors_mnist = pgw.compute_perf_GTOT('mnist', 'mnist', n_unl, n_pos, prior, nb_reps)\n",
    "avg_perfs_mnist_otp =  perfs_mnist_otp['GTOT']\n",
    "avg_detected_priors_mnist = np.mean(detected_priors_mnist)\n",
    "avg_perfs_mnist_otp_time = np.mean(perfs_mnist_otp['time'])\n",
    "\n",
    "# perfs_mnist, perfs_list_mnist = pgw.compute_perf_emd('mnist_color_change_p', 'mnist_color_change_u', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "perfs_color_mnist_otp, perfs_list_color_mnist_otp, detected_priors_color_mnist = pgw.compute_perf_GTOT('mnist_color_change_p', 'mnist_color_change_p', n_unl, n_pos, prior, nb_reps)\n",
    "avg_perfs_color_mnist_otp =  perfs_color_mnist_otp['GTOT']\n",
    "avg_detected_priors_color_mnist = np.mean(detected_priors_color_mnist)\n",
    "avg_perfs_color_mnist_otp_time = np.mean(perfs_color_mnist_otp['time'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "perfs_mnist_otp_, perfs_list_mnist_otp_, detected_priors_mnist_ = pgw.compute_perf_GTOT('mnist', 'mnist', n_unl, n_pos, prior, nb_reps, False)\n",
    "avg_perfs_mnist_otp_ =  perfs_mnist_otp_['GTOT']\n",
    "avg_perfs_mnist_otp_time_ = np.mean(perfs_mnist_otp_['time'])\n",
    "\n",
    "# perfs_mnist, perfs_list_mnist = pgw.compute_perf_emd('mnist_color_change_p', 'mnist_color_change_u', n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "perfs_color_mnist_otp_, perfs_list_color_mnist_otp_, detected_priors_color_mnist_ = pgw.compute_perf_GTOT('mnist_color_change_p', 'mnist_color_change_p', n_unl, n_pos, prior, nb_reps, False)\n",
    "avg_perfs_color_mnist_otp_ =  perfs_color_mnist_otp_['GTOT']\n",
    "avg_perfs_color_mnist_otp_time_ = np.mean(perfs_color_mnist_otp_['time'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tab. 1 (continued)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset</th>\n",
       "      <th>$\\pi$</th>\n",
       "      <th>p-W</th>\n",
       "      <th>p-GW</th>\n",
       "      <th>OT-profile</th>\n",
       "      <th>detected priors from OT-profile</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>mnist</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.9915</td>\n",
       "      <td>0.984</td>\n",
       "      <td>0.960125</td>\n",
       "      <td>0.109025</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>colored mnist</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.9160</td>\n",
       "      <td>0.975</td>\n",
       "      <td>0.957500</td>\n",
       "      <td>0.117403</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         dataset  $\\pi$     p-W   p-GW  OT-profile  \\\n",
       "0          mnist    0.1  0.9915  0.984    0.960125   \n",
       "1  colored mnist    0.1  0.9160  0.975    0.957500   \n",
       "\n",
       "   detected priors from OT-profile  \n",
       "0                         0.109025  \n",
       "1                         0.117403  "
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_MNIST = {'dataset':['mnist', 'colored mnist'], \n",
    "               '$\\pi$': [0.1, 0.1],\n",
    "               'p-W': [avg_perfs_mnist_pw, avg_perfs_color_mnist_pw],\n",
    "               'p-GW': [avg_perfs_mnist_gw, avg_perfs_color_mnist_gw],\n",
    "               'OT-profile': [avg_perfs_mnist_otp, avg_perfs_color_mnist_otp],\n",
    "               'detected priors from OT-profile': [avg_detected_priors_mnist, avg_detected_priors_color_mnist]\n",
    "              }\n",
    "results_MNIST = pd.DataFrame(data=results_MNIST)\n",
    "results_MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrr}\n",
      "\\toprule\n",
      "       dataset &  \\$\\textbackslash pi\\$ &  OT-profile \\\\\n",
      "\\midrule\n",
      "         mnist &    0.1 &        99.3 \\\\\n",
      " colored mnist &    0.1 &        99.3 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "results_MNIST_wprior = {'dataset':['mnist', 'colored mnist'], \n",
    "               '$\\pi$': [0.1, 0.1],\n",
    "               'OT-profile': [avg_perfs_mnist_otp_, avg_perfs_color_mnist_otp_]\n",
    "              }\n",
    "results_MNIST_wprior = pd.DataFrame(data=results_MNIST_wprior)\n",
    "myscale = lambda x : round(x*100,1)\n",
    "results_MNIST_wprior[['OT-profile']] = results_MNIST_wprior[['OT-profile']].apply(myscale,axis=1)\n",
    "results_MNIST_wprior.to_csv('results_MNIST_wprior.csv')\n",
    "print(results_MNIST_wprior.to_latex(index=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset</th>\n",
       "      <th>p-W (s)</th>\n",
       "      <th>p-GW (s)</th>\n",
       "      <th>OT-profile (s)</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>mnist</td>\n",
       "      <td>17.576264</td>\n",
       "      <td>822.676631</td>\n",
       "      <td>25.683767</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>colored mnist</td>\n",
       "      <td>19.491232</td>\n",
       "      <td>818.308240</td>\n",
       "      <td>29.046227</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         dataset    p-W (s)    p-GW (s)  OT-profile (s)\n",
       "0          mnist  17.576264  822.676631       25.683767\n",
       "1  colored mnist  19.491232  818.308240       29.046227"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "time_MNIST = {'dataset':['mnist', 'colored mnist'], \n",
    "               'p-W (s)': [avg_perfs_mnist_pw_time, avg_perfs_color_mnist_pw_time],\n",
    "               'p-GW (s)': [avg_perfs_mnist_gw_time, avg_perfs_color_mnist_gw_time],\n",
    "               'OT-profile (s)': [avg_perfs_mnist_otp_time, avg_perfs_color_mnist_otp_time],\n",
    "              }\n",
    "time_MNIST = pd.DataFrame(data=time_MNIST)\n",
    "time_MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lr}\n",
      "\\toprule\n",
      "       dataset &  ot-profile \\\\\n",
      "\\midrule\n",
      "         mnist &         3.1 \\\\\n",
      " colored mnist &         3.2 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "time_MNIST_wprior = {'dataset':['mnist', 'colored mnist'], \n",
    "               'ot-profile': [avg_perfs_mnist_otp_time_, avg_perfs_color_mnist_otp_time_],\n",
    "              }\n",
    "time_MNIST_wprior = pd.DataFrame(data=time_MNIST_wprior)\n",
    "myscale = lambda x : round(x/nb_reps,1)\n",
    "time_MNIST_wprior[['ot-profile']] = time_MNIST_wprior[['ot-profile']].apply(myscale,axis=1)\n",
    "time_MNIST_wprior.to_csv('times_MNIST.csv')\n",
    "print(time_MNIST_wprior.to_latex(index=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_MNIST.to_csv('results_MNIST')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Caltech dataset - PU on different domains"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_unl = 100\n",
    "n_pos = 100\n",
    "nb_reps = 10\n",
    "nb_dummies = 10\n",
    "prior = 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "avg_caltech_emd_groups = []\n",
    "avg_caltech_gw_groups = []\n",
    "domain_u = ['surf_Caltech10', 'surf_amazon', 'surf_webcam', 'surf_dslr']\n",
    "for d in domain_u:\n",
    "    perfs_caltech_surf, perfs_list_caltech_surf = pgw.compute_perf_emd('surf_Caltech10', d, n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "    avg_caltech_emd_groups.append(perfs_caltech_surf['emd_groups'])\n",
    "    perfs_caltech_surf, perfs_list_caltech_surf = pgw.compute_perf_pgw('surf_Caltech10', d, n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "    avg_caltech_gw_groups.append(perfs_caltech_surf['pgw'])\n",
    "domain_u = ['decaf_caltech', 'decaf_amazon', 'decaf_webcam', 'decaf_dslr']\n",
    "for d in domain_u:\n",
    "    perfs_caltech_surf, perfs_list_caltech_surf = pgw.compute_perf_emd('decaf_caltech', d, n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "    avg_caltech_emd_groups.append(perfs_caltech_surf['emd_groups'])\n",
    "    perfs_caltech_surf, perfs_list_caltech_surf = pgw.compute_perf_pgw('decaf_caltech', d, n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "    avg_caltech_gw_groups.append(perfs_caltech_surf['pgw'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tab.1 (continued)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset</th>\n",
       "      <th>$\\pi$</th>\n",
       "      <th>p-W</th>\n",
       "      <th>p-GW</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>surf C -&gt; surf C</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.900</td>\n",
       "      <td>0.864</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>surf C -&gt; surf A</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.816</td>\n",
       "      <td>0.870</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>surf C -&gt; surf W</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.822</td>\n",
       "      <td>0.862</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>surf C -&gt; surf D</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.800</td>\n",
       "      <td>0.880</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>decaf C -&gt; decaf C</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.940</td>\n",
       "      <td>0.862</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>decaf C -&gt; decaf A</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.802</td>\n",
       "      <td>0.878</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>decaf C -&gt; decaf W</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.802</td>\n",
       "      <td>0.886</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>decaf C -&gt; decaf D</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.808</td>\n",
       "      <td>0.924</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              dataset  $\\pi$    p-W   p-GW\n",
       "0    surf C -> surf C    0.1  0.900  0.864\n",
       "1    surf C -> surf A    0.1  0.816  0.870\n",
       "2    surf C -> surf W    0.1  0.822  0.862\n",
       "3    surf C -> surf D    0.1  0.800  0.880\n",
       "4  decaf C -> decaf C    0.1  0.940  0.862\n",
       "5  decaf C -> decaf A    0.1  0.802  0.878\n",
       "6  decaf C -> decaf W    0.1  0.802  0.886\n",
       "7  decaf C -> decaf D    0.1  0.808  0.924"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_caltech_diff_domains = {'dataset':['surf C -> surf C', 'surf C -> surf A', 'surf C -> surf W', 'surf C -> surf D', 'decaf C -> decaf C', 'decaf C -> decaf A', 'decaf C -> decaf W', 'decaf C -> decaf D'], \n",
    "               '$\\pi$': [0.1]*8,\n",
    "               'p-W': avg_caltech_emd_groups,\n",
    "               'p-GW': avg_caltech_gw_groups\n",
    "              }\n",
    "results_caltech_diff_domains = pd.DataFrame(data=results_caltech_diff_domains)\n",
    "results_caltech_diff_domains"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Caltech dataset - PU on different feature spaces"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "avg_caltech_gw_groups_surf = []\n",
    "avg_caltech_gw_groups_decaf = []\n",
    "domain_u = ['surf_Caltech10', 'surf_amazon', 'surf_webcam', 'surf_dslr']\n",
    "for d in domain_u:\n",
    "    perfs_caltech_surf, perfs_list_caltech_surf = pgw.compute_perf_pgw('decaf_caltech', d, n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "    avg_caltech_gw_groups_decaf.append(perfs_caltech_surf['pgw'])\n",
    "domain_u = ['decaf_caltech', 'decaf_amazon', 'decaf_webcam', 'decaf_dslr']\n",
    "for d in domain_u:\n",
    "    perfs_caltech_surf, perfs_list_caltech_surf = pgw.compute_perf_pgw('surf_Caltech10', d, n_unl, n_pos, prior, nb_reps, nb_dummies)\n",
    "    avg_caltech_gw_groups_surf.append(perfs_caltech_surf['pgw'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tab. 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>domains</th>\n",
       "      <th>surf C -&gt; decaf *</th>\n",
       "      <th>decaf C -&gt; surf *</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>*= C</td>\n",
       "      <td>0.874</td>\n",
       "      <td>0.860</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>*= A</td>\n",
       "      <td>0.940</td>\n",
       "      <td>0.866</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>*= W</td>\n",
       "      <td>0.944</td>\n",
       "      <td>0.894</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>*= D</td>\n",
       "      <td>0.966</td>\n",
       "      <td>0.866</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  domains  surf C -> decaf *  decaf C -> surf *\n",
       "0    *= C              0.874              0.860\n",
       "1    *= A              0.940              0.866\n",
       "2    *= W              0.944              0.894\n",
       "3    *= D              0.966              0.866"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_caltech_diff_features = {'domains':['*= C', '*= A', '*= W', '*= D'], \n",
    "               'surf C -> decaf *': avg_caltech_gw_groups_surf,\n",
    "               'decaf C -> surf *': avg_caltech_gw_groups_decaf,\n",
    "              }\n",
    "results_caltech_diff_features = pd.DataFrame(data=results_caltech_diff_features)\n",
    "results_caltech_diff_features"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "af8a8dc5d9d5fe20562b2a82edc81ddace8b482cbd5fbca7942e2cde7b4f1d64"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
