{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "13404891",
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "import numpy as np\n",
    "import csv\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "plt.style.use(\"ggplot\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c82d63c2",
   "metadata": {},
   "source": [
    "# IMC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "97385607",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _read_wandb_csv(path, n_runs, optimise_iter, allow_missing=False):\n",
    "    \"\"\"Parse one W&B metric export into an (n_runs, optimise_iter) array.\n",
    "\n",
    "    The first row is skipped as a header. In every following row, column 1\n",
    "    is collected as the run name and columns 2 onward are parsed as floats.\n",
    "\n",
    "    Args:\n",
    "        path: CSV file to read.\n",
    "        n_runs: number of rows of the returned array; the file must not\n",
    "            contain more data rows than this.\n",
    "        optimise_iter: number of columns of the returned array; every data\n",
    "            row must hold exactly this many values after column 1.\n",
    "        allow_missing: if True, empty cells become np.nan; otherwise an\n",
    "            empty cell makes float() raise ValueError.\n",
    "\n",
    "    Returns:\n",
    "        (names, values): the column-1 entries and the filled array. Rows\n",
    "        past the last data row of the file are left at 0.\n",
    "    \"\"\"\n",
    "    values = np.zeros((n_runs, optimise_iter))\n",
    "    names = []\n",
    "    with open(path, 'r', newline='') as csvfile:\n",
    "        # NOTE(review): quotechar='|' differs from the csv default; kept so the\n",
    "        # exports are parsed exactly as before.\n",
    "        wandb_reader = csv.reader(csvfile, delimiter=',', quotechar='|')\n",
    "        next(wandb_reader, None)  # skip the header row\n",
    "        for i, row in enumerate(wandb_reader):\n",
    "            names.append(row[1])\n",
    "            values[i, :] = [np.nan if allow_missing and k == '' else float(k)\n",
    "                            for k in row[2:]]\n",
    "    return names, values\n",
    "\n",
    "\n",
    "def load_csv(path_ari, path_nmi, n_runs, optimise_iter, allow_missing=False):\n",
    "    \"\"\"Load a matched pair of ARI / NMI W&B exports.\n",
    "\n",
    "    Args:\n",
    "        path_ari: CSV export of the ARI metric (path must contain 'ari').\n",
    "        path_nmi: CSV export of the NMI metric (path must contain 'nmi').\n",
    "        n_runs: number of runs (rows) per export.\n",
    "        optimise_iter: number of iterations (value columns) per run.\n",
    "        allow_missing: parse empty cells as NaN instead of raising ValueError.\n",
    "\n",
    "    Returns:\n",
    "        header: run names from column 1 of the ARI file.\n",
    "        ARI, NMI: float arrays of shape (n_runs, optimise_iter).\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if a path does not name the expected metric.\n",
    "    \"\"\"\n",
    "    # Both messages interpolate the offending path so a failure is actionable.\n",
    "    assert 'ari' in path_ari, f\"wrong path for ARI, {path_ari}\"\n",
    "    assert 'nmi' in path_nmi, f\"wrong path for NMI, {path_nmi}\"\n",
    "    header, ARI = _read_wandb_csv(path_ari, n_runs, optimise_iter, allow_missing)\n",
    "    _, NMI = _read_wandb_csv(path_nmi, n_runs, optimise_iter, allow_missing)\n",
    "    return header, ARI, NMI\n",
    "\n",
    "\n",
    "def load_csv_nan(path_ari, path_nmi, n_runs, optimise_iter):\n",
    "    \"\"\"Same as load_csv, but empty cells are parsed as NaN instead of raising.\"\"\"\n",
    "    return load_csv(path_ari, path_nmi, n_runs, optimise_iter, allow_missing=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "45f14eba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every IMC export is named <stem>_<metric>_98_runs_35_iter.csv; derive both\n",
    "# the file names and the array shape from one pair of constants.\n",
    "path = 'meta_obj_csv/'\n",
    "IMC_N_RUNS = 98\n",
    "IMC_N_ITER = 35\n",
    "\n",
    "\n",
    "def load_imc(stem, loader=load_csv, directory=path):\n",
    "    \"\"\"Load the ARI/NMI exports of one IMC method with `loader`.\n",
    "\n",
    "    Pass load_csv_nan as `loader` for exports that contain empty cells.\n",
    "    \"\"\"\n",
    "    suffix = f'_{IMC_N_RUNS}_runs_{IMC_N_ITER}_iter.csv'\n",
    "    return loader(directory + stem + '_ari' + suffix,\n",
    "                  directory + stem + '_nmi' + suffix,\n",
    "                  IMC_N_RUNS, IMC_N_ITER)\n",
    "\n",
    "\n",
    "_, imc_ARI_manatee_all, imc_NMI_manatee_all = load_imc('final_imc_manatee')\n",
    "_, imc_ARI_ucb_scal_exhaust_all, imc_NMI_ucb_scal_exhaust_all = load_imc('imc_ucb_scal_exhaust_latest')\n",
    "_, imc_ARI_random_prob_all, imc_NMI_random_prob_all = load_imc('final_imc_random_prob')\n",
    "_, imc_ARI_random_loc_all, imc_NMI_random_loc_all = load_imc('final_imc_random_loc')\n",
    "_, imc_ARI_qparego_all, imc_NMI_qparego_all = load_imc('imc_qparego', load_csv_nan)\n",
    "_, imc_ARI_qnehvi_all, imc_NMI_qnehvi_all = load_imc('imc_qnehvi', load_csv_nan)\n",
    "_, imc_ARI_usemo_all, imc_NMI_usemo_all = load_imc('imc_usemo')\n",
    "_, imc_ARI_manatee_no_max, imc_NMI_manatee_no_max = load_imc('imc_manatee_no_max', load_csv_nan)\n",
    "_, imc_ARI_manatee_no_cor, imc_NMI_manatee_no_cor = load_imc('imc_manatee_no_cor', load_csv_nan)\n",
    "_, imc_ARI_manatee_no_noise, imc_NMI_manatee_no_noise = load_imc('imc_manatee_no_noise', load_csv_nan)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ef2e63e4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Best value observed for each metric across all methods and runs; this is the\n",
    "# reference optimum for the regrets below. np.nanmax is applied to every array\n",
    "# because a plain .max() returns NaN as soon as a single cell is NaN.\n",
    "imc_true_ARI_max = np.nanmax([np.nanmax(scores) for scores in (\n",
    "    imc_ARI_manatee_no_max,\n",
    "    imc_ARI_manatee_no_cor,\n",
    "    imc_ARI_manatee_no_noise,\n",
    "    imc_ARI_manatee_all,\n",
    "    imc_ARI_ucb_scal_exhaust_all,\n",
    "    imc_ARI_random_prob_all,\n",
    "    imc_ARI_random_loc_all,\n",
    "    imc_ARI_usemo_all,\n",
    "    imc_ARI_qparego_all,\n",
    "    imc_ARI_qnehvi_all,\n",
    ")])\n",
    "\n",
    "imc_true_NMI_max = np.nanmax([np.nanmax(scores) for scores in (\n",
    "    imc_NMI_manatee_no_max,\n",
    "    imc_NMI_manatee_no_cor,\n",
    "    imc_NMI_manatee_no_noise,\n",
    "    imc_NMI_manatee_all,\n",
    "    imc_NMI_ucb_scal_exhaust_all,\n",
    "    imc_NMI_random_prob_all,\n",
    "    imc_NMI_random_loc_all,\n",
    "    imc_NMI_usemo_all,\n",
    "    imc_NMI_qparego_all,\n",
    "    imc_NMI_qnehvi_all,\n",
    ")])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e9b65fe2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cumulative_regret(true_max, solutions):\n",
    "    \"\"\"Per-run mean instantaneous regret: mean over t of true_max - solutions[run, t].\n",
    "\n",
    "    Equals the cumulative regret divided by the number of non-NaN iterations.\n",
    "\n",
    "    Args:\n",
    "        true_max: reference optimum (scalar).\n",
    "        solutions: 2-D array with one row per run and one column per iteration;\n",
    "            NaN cells are ignored.\n",
    "\n",
    "    Returns:\n",
    "        1-D float array with one entry per run (NaN for an all-NaN run).\n",
    "    \"\"\"\n",
    "    solutions = np.asarray(solutions, dtype=float)\n",
    "    return np.nanmean(true_max - solutions, axis=1)\n",
    "\n",
    "def get_full_regret(true_max, solutions):\n",
    "    \"\"\"Per-run regret of the best observed value: true_max - max over t.\n",
    "\n",
    "    NaN cells are ignored (np.nanmax), consistent with the other two regrets;\n",
    "    a plain .max() would turn any run containing a NaN into NaN and silently\n",
    "    drop it from the NaN-aware summary statistics.\n",
    "\n",
    "    Returns:\n",
    "        1-D float array with one entry per run (NaN for an all-NaN run).\n",
    "    \"\"\"\n",
    "    solutions = np.asarray(solutions, dtype=float)\n",
    "    return true_max - np.nanmax(solutions, axis=1)\n",
    "\n",
    "def get_bayes_regret(true_max, solutions):\n",
    "    \"\"\"Per-run mean, over iterations, of the regret of the best-so-far value.\n",
    "\n",
    "    np.maximum.accumulate propagates NaN, so iterations from a run's first NaN\n",
    "    onward are excluded from that run's mean.\n",
    "\n",
    "    Returns:\n",
    "        1-D float array with one entry per run.\n",
    "    \"\"\"\n",
    "    solutions = np.asarray(solutions, dtype=float)\n",
    "    best_so_far = np.maximum.accumulate(solutions, axis=1)\n",
    "    return np.nanmean(true_max - best_so_far, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "50a62e63",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_regret_stats(regrets, name, method, metric):\n",
    "    \"\"\"Print the NaN-aware mean, std and median of `regrets` on one line.\n",
    "\n",
    "    Format: <name> <metric> for <method>: <mean> +- <std>, median: <median>,\n",
    "    with every statistic rounded to three decimals.\n",
    "    \"\"\"\n",
    "    mean = np.nanmean(regrets)\n",
    "    std = np.nanstd(regrets)\n",
    "    median = np.nanmedian(regrets)\n",
    "\n",
    "    # Implicit concatenation rather than a backslash continuation inside the\n",
    "    # literal, so the source indentation does not leak into the message.\n",
    "    print(f\"{name} {metric} for {method}: \"\n",
    "          f\"{mean:.3f} +- {std:.3f}, median: {median:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "47595c44",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_regrets(true_max, metaobj_all, metaobj_ucb_scal_exhaust,\n",
    "                    metaobj_no_max, metaobj_no_cor, metaobj_no_noise,\n",
    "                    metaobj_random_prob_all, metaobj_random_loc_all,\n",
    "                    metaobj_qparego_all, metaobj_qnehvi_all, metaobj_usemo_all):\n",
    "    \"\"\"Compute cumulative, full and Bayes regret arrays for every method.\n",
    "\n",
    "    Each metaobj_* argument holds one method's scores, passed unchanged to the\n",
    "    get_* regret functions. The result maps '<kind>_<suffix>' to the per-run\n",
    "    array returned by the matching function, where kind is cumulative_regret,\n",
    "    full_regret or bayes_regret3 and suffix identifies the method (all,\n",
    "    ucb_scal_exhaust, no_max, no_cor, no_noise, random_prob_all,\n",
    "    random_loc_all, qparego_all, qnehvi_all, usemo_all).\n",
    "    \"\"\"\n",
    "    scores_by_suffix = (\n",
    "        ('all', metaobj_all),\n",
    "        ('ucb_scal_exhaust', metaobj_ucb_scal_exhaust),\n",
    "        ('no_max', metaobj_no_max),\n",
    "        ('no_cor', metaobj_no_cor),\n",
    "        ('no_noise', metaobj_no_noise),\n",
    "        ('random_prob_all', metaobj_random_prob_all),\n",
    "        ('random_loc_all', metaobj_random_loc_all),\n",
    "        ('qparego_all', metaobj_qparego_all),\n",
    "        ('qnehvi_all', metaobj_qnehvi_all),\n",
    "        ('usemo_all', metaobj_usemo_all),\n",
    "    )\n",
    "    regret_by_kind = (\n",
    "        ('cumulative_regret', get_cumulative_regret),\n",
    "        ('full_regret', get_full_regret),\n",
    "        ('bayes_regret3', get_bayes_regret),\n",
    "    )\n",
    "    # Method-major iteration: same computation and key insertion order as\n",
    "    # writing every assignment out by hand.\n",
    "    return {f'{kind}_{suffix}': regret_fn(true_max, scores)\n",
    "            for suffix, scores in scores_by_suffix\n",
    "            for kind, regret_fn in regret_by_kind}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "fdeea3fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2715212/3068444794.py:6: RuntimeWarning: Mean of empty slice\n",
      "  cumulative_regret[run] = np.nanmean(true_max - solutions[run,:])\n",
      "/tmp/ipykernel_2715212/3068444794.py:22: RuntimeWarning: Mean of empty slice\n",
      "  bayes_regret[run] = np.nanmean(true_max - np.maximum.accumulate(solutions[run,:]))\n"
     ]
    }
   ],
   "source": [
    "# Keyword arguments make the method -> parameter mapping explicit.\n",
    "imc_ARI_regrets = compute_regrets(\n",
    "    imc_true_ARI_max,\n",
    "    metaobj_all=imc_ARI_manatee_all,\n",
    "    metaobj_ucb_scal_exhaust=imc_ARI_ucb_scal_exhaust_all,\n",
    "    metaobj_no_max=imc_ARI_manatee_no_max,\n",
    "    metaobj_no_cor=imc_ARI_manatee_no_cor,\n",
    "    metaobj_no_noise=imc_ARI_manatee_no_noise,\n",
    "    metaobj_random_prob_all=imc_ARI_random_prob_all,\n",
    "    metaobj_random_loc_all=imc_ARI_random_loc_all,\n",
    "    metaobj_qparego_all=imc_ARI_qparego_all,\n",
    "    metaobj_qnehvi_all=imc_ARI_qnehvi_all,\n",
    "    metaobj_usemo_all=imc_ARI_usemo_all,\n",
    ")\n",
    "\n",
    "imc_NMI_regrets = compute_regrets(\n",
    "    imc_true_NMI_max,\n",
    "    metaobj_all=imc_NMI_manatee_all,\n",
    "    metaobj_ucb_scal_exhaust=imc_NMI_ucb_scal_exhaust_all,\n",
    "    metaobj_no_max=imc_NMI_manatee_no_max,\n",
    "    metaobj_no_cor=imc_NMI_manatee_no_cor,\n",
    "    metaobj_no_noise=imc_NMI_manatee_no_noise,\n",
    "    metaobj_random_prob_all=imc_NMI_random_prob_all,\n",
    "    metaobj_random_loc_all=imc_NMI_random_loc_all,\n",
    "    metaobj_qparego_all=imc_NMI_qparego_all,\n",
    "    metaobj_qnehvi_all=imc_NMI_qnehvi_all,\n",
    "    metaobj_usemo_all=imc_NMI_usemo_all,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f00e4b15",
   "metadata": {},
   "outputs": [],
   "source": [
    "def latex_stats(regrets):\n",
    "    \"\"\"Format the NaN-aware mean and std of `regrets` as $mean (std)$.\"\"\"\n",
    "    avg, spread = np.nanmean(regrets), np.nanstd(regrets)\n",
    "    return f\"${avg:.3f} ({spread:.3f})$\"\n",
    "\n",
    "def latex_stats_all_regrets(regrets1, regrets2, regrets3):\n",
    "    \"\"\"Apply latex_stats to each of the three regret arrays, keeping order.\"\"\"\n",
    "    return [latex_stats(r) for r in (regrets1, regrets2, regrets3)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "73855b5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def output_latex_table(imc_ARI_regrets, imc_NMI_regrets, title, postfix):\n",
    "    \"\"\"Build one LaTeX table row: title, 3 ARI cells, 3 NMI cells.\n",
    "\n",
    "    Args:\n",
    "        imc_ARI_regrets, imc_NMI_regrets: dicts returned by compute_regrets\n",
    "            (any dataset, despite the imc_ prefix).\n",
    "        title: label for the first column.\n",
    "        postfix: method suffix used in the dict keys, e.g. 'all'.\n",
    "\n",
    "    Returns:\n",
    "        The cells joined by ' & ' (cumulative, full and Bayes regret for ARI,\n",
    "        then the same for NMI), terminated by a LaTeX row break.\n",
    "    \"\"\"\n",
    "    kinds = ('cumulative_regret', 'full_regret', 'bayes_regret3')\n",
    "    string_ari = ' & '.join(latex_stats_all_regrets(\n",
    "        *(imc_ARI_regrets[f'{kind}_{postfix}'] for kind in kinds)))\n",
    "    string_nmi = ' & '.join(latex_stats_all_regrets(\n",
    "        *(imc_NMI_regrets[f'{kind}_{postfix}'] for kind in kinds)))\n",
    "\n",
    "    string = ' & '.join([f'{title}', string_ari, string_nmi])\n",
    "    # A LaTeX row break is two backslashes; the raw string keeps both of them.\n",
    "    string += r' \\\\'\n",
    "    return string"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "d13ff92d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "M-SA & $0.017 (0.005)$ & $0.003 (0.003)$ & $0.007 (0.006)$ & $0.019 (0.009)$ & $0.002 (0.005)$ & $0.008 (0.009)$ \\\n",
      "M-AS & $0.021 (0.010)$ & $0.006 (0.010)$ & $0.011 (0.011)$ & $0.025 (0.017)$ & $0.007 (0.016)$ & $0.015 (0.017)$ \\\n",
      "RA & $0.045 (0.001)$ & $0.024 (0.013)$ & $0.031 (0.008)$ & $0.065 (0.003)$ & $0.026 (0.013)$ & $0.036 (0.009)$ \\\n",
      "RS & $0.021 (0.004)$ & $0.003 (0.004)$ & $0.008 (0.005)$ & $0.025 (0.006)$ & $0.003 (0.006)$ & $0.008 (0.007)$ \\\n",
      "qNEHVI & $0.042 (0.004)$ & $0.011 (0.007)$ & $0.021 (0.011)$ & $0.061 (0.007)$ & $0.011 (0.007)$ & $0.026 (0.017)$ \\\n",
      "qNParEGO & $0.037 (0.005)$ & $0.002 (0.003)$ & $0.017 (0.012)$ & $0.049 (0.009)$ & $0.005 (0.003)$ & $0.023 (0.017)$ \\\n",
      "USeMO & $0.043 (0.003)$ & $0.010 (0.009)$ & $0.016 (0.009)$ & $0.062 (0.005)$ & $0.012 (0.011)$ & $0.018 (0.013)$ \\\n"
     ]
    }
   ],
   "source": [
    "# Main comparison table: one LaTeX row per method.\n",
    "print('\\n'.join(\n",
    "    output_latex_table(imc_ARI_regrets, imc_NMI_regrets, title, postfix)\n",
    "    for title, postfix in (\n",
    "        ('M-SA', 'all'),\n",
    "        ('M-AS', 'ucb_scal_exhaust'),\n",
    "        ('RA', 'random_loc_all'),\n",
    "        ('RS', 'random_prob_all'),\n",
    "        ('qNEHVI', 'qnehvi_all'),\n",
    "        ('qNParEGO', 'qparego_all'),\n",
    "        ('USeMO', 'usemo_all'),\n",
    "    )\n",
    "))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "c95d21c2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Explainability & $0.016 (0.004)$ & $0.002 (0.003)$ & $0.007 (0.004)$ & $0.017 (0.006)$ & $0.002 (0.005)$ & $0.007 (0.007)$ \\\n",
      "Inter-obj agreement & $0.018 (0.005)$ & $0.003 (0.005)$ & $0.007 (0.006)$ & $0.020 (0.009)$ & $0.003 (0.008)$ & $0.008 (0.010)$ \\\n",
      "Max not at boundary & $0.018 (0.007)$ & $0.004 (0.006)$ & $0.008 (0.007)$ & $0.020 (0.012)$ & $0.004 (0.011)$ & $0.009 (0.012)$ \\\n"
     ]
    }
   ],
   "source": [
    "# Ablation rows (manatee_no_* runs).\n",
    "print('\\n'.join(\n",
    "    output_latex_table(imc_ARI_regrets, imc_NMI_regrets, title, postfix)\n",
    "    for title, postfix in (\n",
    "        ('Explainability', 'no_noise'),\n",
    "        ('Inter-obj agreement', 'no_cor'),\n",
    "        ('Max not at boundary', 'no_max'),\n",
    "    )\n",
    "))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f34ba431",
   "metadata": {},
   "source": [
    "# CITE-seq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "73a28f21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every CITE-seq export is named new_citeseq_<method>_<metric>_100_runs_36_iter.csv;\n",
    "# derive both the file names and the array shape from one pair of constants.\n",
    "path = 'meta_obj_csv/'\n",
    "CITESEQ_N_RUNS = 100\n",
    "CITESEQ_N_ITER = 36\n",
    "\n",
    "\n",
    "def load_citeseq(method, loader=load_csv, directory=path):\n",
    "    \"\"\"Load the ARI/NMI exports of one CITE-seq method with `loader`.\n",
    "\n",
    "    Pass load_csv_nan as `loader` for exports that contain empty cells.\n",
    "    \"\"\"\n",
    "    prefix = directory + 'new_citeseq_' + method\n",
    "    suffix = f'_{CITESEQ_N_RUNS}_runs_{CITESEQ_N_ITER}_iter.csv'\n",
    "    return loader(prefix + '_ari' + suffix, prefix + '_nmi' + suffix,\n",
    "                  CITESEQ_N_RUNS, CITESEQ_N_ITER)\n",
    "\n",
    "\n",
    "_, citeseq_ARI_manatee_all, citeseq_NMI_manatee_all = load_citeseq('msa')\n",
    "_, citeseq_ARI_ucb_scal_exhaust_all, citeseq_NMI_ucb_scal_exhaust_all = load_citeseq('mas')\n",
    "_, citeseq_ARI_random_prob_all, citeseq_NMI_random_prob_all = load_citeseq('rs')\n",
    "_, citeseq_ARI_random_loc_all, citeseq_NMI_random_loc_all = load_citeseq('ra')\n",
    "_, citeseq_ARI_qparego_all, citeseq_NMI_qparego_all = load_citeseq('qparego', load_csv_nan)\n",
    "_, citeseq_ARI_qnehvi_all, citeseq_NMI_qnehvi_all = load_citeseq('qnehvi', load_csv_nan)\n",
    "_, citeseq_ARI_usemo_all, citeseq_NMI_usemo_all = load_citeseq('usemo', load_csv_nan)\n",
    "_, citeseq_ARI_manatee_no_max, citeseq_NMI_manatee_no_max = load_citeseq('manatee_no_max', load_csv_nan)\n",
    "_, citeseq_ARI_manatee_no_cor, citeseq_NMI_manatee_no_cor = load_citeseq('manatee_no_cor', load_csv_nan)\n",
    "_, citeseq_ARI_manatee_no_noise, citeseq_NMI_manatee_no_noise = load_citeseq('manatee_no_noise', load_csv_nan)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "ed1f3180",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Best value observed for each metric across all methods and runs; this is the\n",
    "# reference optimum for the regrets below. np.nanmax is applied to every array\n",
    "# because a plain .max() returns NaN as soon as a single cell is NaN.\n",
    "citeseq_true_ARI_max = np.nanmax([np.nanmax(scores) for scores in (\n",
    "    citeseq_ARI_manatee_no_max,\n",
    "    citeseq_ARI_manatee_no_cor,\n",
    "    citeseq_ARI_manatee_no_noise,\n",
    "    citeseq_ARI_manatee_all,\n",
    "    citeseq_ARI_ucb_scal_exhaust_all,\n",
    "    citeseq_ARI_random_prob_all,\n",
    "    citeseq_ARI_random_loc_all,\n",
    "    citeseq_ARI_usemo_all,\n",
    "    citeseq_ARI_qparego_all,\n",
    "    citeseq_ARI_qnehvi_all,\n",
    ")])\n",
    "\n",
    "citeseq_true_NMI_max = np.nanmax([np.nanmax(scores) for scores in (\n",
    "    citeseq_NMI_manatee_no_max,\n",
    "    citeseq_NMI_manatee_no_cor,\n",
    "    citeseq_NMI_manatee_no_noise,\n",
    "    citeseq_NMI_manatee_all,\n",
    "    citeseq_NMI_ucb_scal_exhaust_all,\n",
    "    citeseq_NMI_random_prob_all,\n",
    "    citeseq_NMI_random_loc_all,\n",
    "    citeseq_NMI_usemo_all,\n",
    "    citeseq_NMI_qparego_all,\n",
    "    citeseq_NMI_qnehvi_all,\n",
    ")])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "a688cd7f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2715212/3068444794.py:6: RuntimeWarning: Mean of empty slice\n",
      "  cumulative_regret[run] = np.nanmean(true_max - solutions[run,:])\n",
      "/tmp/ipykernel_2715212/3068444794.py:22: RuntimeWarning: Mean of empty slice\n",
      "  bayes_regret[run] = np.nanmean(true_max - np.maximum.accumulate(solutions[run,:]))\n"
     ]
    }
   ],
   "source": [
    "# Keyword arguments make the method -> parameter mapping explicit.\n",
    "citeseq_ARI_regrets = compute_regrets(\n",
    "    citeseq_true_ARI_max,\n",
    "    metaobj_all=citeseq_ARI_manatee_all,\n",
    "    metaobj_ucb_scal_exhaust=citeseq_ARI_ucb_scal_exhaust_all,\n",
    "    metaobj_no_max=citeseq_ARI_manatee_no_max,\n",
    "    metaobj_no_cor=citeseq_ARI_manatee_no_cor,\n",
    "    metaobj_no_noise=citeseq_ARI_manatee_no_noise,\n",
    "    metaobj_random_prob_all=citeseq_ARI_random_prob_all,\n",
    "    metaobj_random_loc_all=citeseq_ARI_random_loc_all,\n",
    "    metaobj_qparego_all=citeseq_ARI_qparego_all,\n",
    "    metaobj_qnehvi_all=citeseq_ARI_qnehvi_all,\n",
    "    metaobj_usemo_all=citeseq_ARI_usemo_all,\n",
    ")\n",
    "\n",
    "citeseq_NMI_regrets = compute_regrets(\n",
    "    citeseq_true_NMI_max,\n",
    "    metaobj_all=citeseq_NMI_manatee_all,\n",
    "    metaobj_ucb_scal_exhaust=citeseq_NMI_ucb_scal_exhaust_all,\n",
    "    metaobj_no_max=citeseq_NMI_manatee_no_max,\n",
    "    metaobj_no_cor=citeseq_NMI_manatee_no_cor,\n",
    "    metaobj_no_noise=citeseq_NMI_manatee_no_noise,\n",
    "    metaobj_random_prob_all=citeseq_NMI_random_prob_all,\n",
    "    metaobj_random_loc_all=citeseq_NMI_random_loc_all,\n",
    "    metaobj_qparego_all=citeseq_NMI_qparego_all,\n",
    "    metaobj_qnehvi_all=citeseq_NMI_qnehvi_all,\n",
    "    metaobj_usemo_all=citeseq_NMI_usemo_all,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "8d52ea2c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "M-SA & $0.126 (0.019)$ & $0.056 (0.010)$ & $0.064 (0.012)$ & $0.125 (0.026)$ & $0.022 (0.010)$ & $0.031 (0.013)$ \\\n",
      "M-AS & $0.127 (0.026)$ & $0.058 (0.015)$ & $0.067 (0.015)$ & $0.124 (0.036)$ & $0.023 (0.013)$ & $0.034 (0.017)$ \\\n",
      "RA & $0.192 (0.020)$ & $0.053 (0.009)$ & $0.067 (0.013)$ & $0.199 (0.025)$ & $0.025 (0.009)$ & $0.040 (0.015)$ \\\n",
      "RS & $0.140 (0.018)$ & $0.049 (0.008)$ & $0.059 (0.010)$ & $0.130 (0.021)$ & $0.014 (0.007)$ & $0.026 (0.012)$ \\\n",
      "qNEHVI & $0.186 (0.038)$ & $0.050 (0.008)$ & $0.092 (0.045)$ & $0.191 (0.047)$ & $0.023 (0.009)$ & $0.073 (0.057)$ \\\n",
      "qNParEGO & $0.161 (0.039)$ & $0.055 (0.009)$ & $0.091 (0.048)$ & $0.158 (0.049)$ & $0.023 (0.009)$ & $0.070 (0.064)$ \\\n",
      "USeMO & $0.218 (0.031)$ & $0.057 (0.013)$ & $0.071 (0.014)$ & $0.231 (0.040)$ & $0.027 (0.014)$ & $0.042 (0.019)$ \\\n"
     ]
    }
   ],
   "source": [
    "# Main comparison table: one LaTeX row per method.\n",
    "print('\\n'.join(\n",
    "    output_latex_table(citeseq_ARI_regrets, citeseq_NMI_regrets, title, postfix)\n",
    "    for title, postfix in (\n",
    "        ('M-SA', 'all'),\n",
    "        ('M-AS', 'ucb_scal_exhaust'),\n",
    "        ('RA', 'random_loc_all'),\n",
    "        ('RS', 'random_prob_all'),\n",
    "        ('qNEHVI', 'qnehvi_all'),\n",
    "        ('qNParEGO', 'qparego_all'),\n",
    "        ('USeMO', 'usemo_all'),\n",
    "    )\n",
    "))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "9779976a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Explainability & $0.126 (0.023)$ & $0.055 (0.012)$ & $0.063 (0.012)$ & $0.121 (0.032)$ & $0.021 (0.009)$ & $0.030 (0.012)$ \\\n",
      "Inter-obj agreement & $0.129 (0.018)$ & $0.055 (0.011)$ & $0.063 (0.012)$ & $0.126 (0.025)$ & $0.020 (0.010)$ & $0.030 (0.014)$ \\\n",
      "Max not at boundary & $0.120 (0.022)$ & $0.055 (0.012)$ & $0.063 (0.013)$ & $0.113 (0.033)$ & $0.020 (0.010)$ & $0.030 (0.013)$ \\\n"
     ]
    }
   ],
   "source": [
    "# Ablation rows (manatee_no_* runs).\n",
    "print('\\n'.join(\n",
    "    output_latex_table(citeseq_ARI_regrets, citeseq_NMI_regrets, title, postfix)\n",
    "    for title, postfix in (\n",
    "        ('Explainability', 'no_noise'),\n",
    "        ('Inter-obj agreement', 'no_cor'),\n",
    "        ('Max not at boundary', 'no_max'),\n",
    "    )\n",
    "))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1b7bfc6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
