{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "13404891",
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "321728ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cumulative_regret_fold(true_max, solutions):\n",
    "    return np.nanmean(true_max - solutions)\n",
    "\n",
    "def get_full_regret_fold(true_max, solutions):\n",
    "    return true_max - solutions.max()\n",
    "\n",
    "def get_bayes_regret_fold(true_max, solutions):\n",
    "        return np.nanmean(true_max - np.maximum.accumulate(solutions))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e71d1a68",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_regrets(folds_list):\n",
    "\n",
    "    cumulative_regrets = []\n",
    "    bayes_regrets = []\n",
    "\n",
    "    for run in range(len(folds_list)):\n",
    "        cumul_reg = np.zeros(5)\n",
    "        bayes_reg = np.zeros(5)\n",
    "        \n",
    "        for fold in range(5):\n",
    "            fold_max = folds_list[run][fold].max()\n",
    "            cumul_reg[fold] = get_cumulative_regret_fold(fold_max, folds_list[run][fold])\n",
    "            bayes_reg[fold] = get_bayes_regret_fold(fold_max, folds_list[run][fold])\n",
    "            \n",
    "        cumulative_regrets.append(cumul_reg.mean())\n",
    "        bayes_regrets.append(bayes_reg.mean())\n",
    "    \n",
    "    return cumulative_regrets, bayes_regrets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "50a62e63",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_regret_stats(regrets, name, method, metric):\n",
    "    mean = np.nanmean(regrets)\n",
    "    std = np.nanstd(regrets)\n",
    "    median = np.nanmedian(regrets)\n",
    "    \n",
    "    print(f\"{name} {metric} for {method}: \\\n",
    "    {mean:.3f} +- {std:.3f}, median: {median:.3f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32332488",
   "metadata": {},
   "source": [
    "# IMC:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba3f1d38",
   "metadata": {},
   "source": [
    "The ARI and NMI values from the cross-validation experiments are saved in `.pkl` files:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8796eb84",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "path='crossval_pkl/'\n",
    "with open(path+\"cv_imc_ari_train.pkl\", \"rb\") as f:\n",
    "    ARI_train_manatee = pickle.load(f)\n",
    "    \n",
    "with open(path+\"cv_imc_ari_test.pkl\", \"rb\") as f:\n",
    "    ARI_test_manatee = pickle.load(f)\n",
    "    \n",
    "with open(path+\"cv_imc_nmi_train.pkl\", \"rb\") as f:\n",
    "    NMI_train_manatee = pickle.load(f)\n",
    "    \n",
    "with open(path+\"cv_imc_nmi_test.pkl\", \"rb\") as f:\n",
    "    NMI_test_manatee = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d9e43723",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cumulative regret on train ARI for manatee:     0.015 +- 0.003, median: 0.015\n",
      "Cumulative regret on test ARI for manatee:     0.017 +- 0.003, median: 0.017\n",
      "Cumulative regret on train NMI for manatee:     0.021 +- 0.005, median: 0.021\n",
      "Cumulative regret on test NMI for manatee:     0.023 +- 0.005, median: 0.023\n"
     ]
    }
   ],
   "source": [
    "cumulative_regrets_manatee_train, bayes_regrets_manatee_train = compute_regrets(ARI_train_manatee)\n",
    "cumulative_regrets_manatee_test, bayes_regrets_manatee_test = compute_regrets(ARI_test_manatee)\n",
    "\n",
    "cumulative_regrets_manatee_train_nmi, bayes_regrets_manatee_train_nmi = compute_regrets(NMI_train_manatee)\n",
    "cumulative_regrets_manatee_test_nmi, bayes_regrets_manatee_test_nmi = compute_regrets(NMI_test_manatee)\n",
    "\n",
    "print_regret_stats(cumulative_regrets_manatee_train, 'Cumulative regret on train', 'manatee', 'ARI')\n",
    "print_regret_stats(cumulative_regrets_manatee_test, 'Cumulative regret on test', 'manatee', 'ARI')\n",
    "print_regret_stats(cumulative_regrets_manatee_train_nmi, 'Cumulative regret on train', 'manatee', 'NMI')\n",
    "print_regret_stats(cumulative_regrets_manatee_test_nmi, 'Cumulative regret on test', 'manatee', 'NMI')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "140345cf",
   "metadata": {},
   "source": [
    "# CITE-seq:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "cac7ac3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "path='crossval_pkl/'\n",
    "with open(path+\"cv_citeseq_ari_train.pkl\", \"rb\") as f:\n",
    "    ARI_train_manatee = pickle.load(f)\n",
    "    \n",
    "with open(path+\"cv_citeseq_ari_test.pkl\", \"rb\") as f:\n",
    "    ARI_test_manatee = pickle.load(f)\n",
    "    \n",
    "with open(path+\"cv_citeseq_nmi_train.pkl\", \"rb\") as f:\n",
    "    NMI_train_manatee = pickle.load(f)\n",
    "    \n",
    "with open(path+\"cv_citeseq_nmi_test.pkl\", \"rb\") as f:\n",
    "    NMI_test_manatee = pickle.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48ff69d8",
   "metadata": {},
   "source": [
    "Compute regrets:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f202e9be",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cumulative regret on train ARI for manatee:     0.095 +- 0.015, median: 0.094\n",
      "Cumulative regret on test ARI for manatee:     0.089 +- 0.015, median: 0.090\n",
      "Cumulative regret on train NMI for manatee:     0.126 +- 0.021, median: 0.125\n",
      "Cumulative regret on test NMI for manatee:     0.126 +- 0.020, median: 0.123\n"
     ]
    }
   ],
   "source": [
    "cumulative_regrets_manatee_train, bayes_regrets_manatee_train = compute_regrets(ARI_train_manatee)\n",
    "cumulative_regrets_manatee_test, bayes_regrets_manatee_test = compute_regrets(ARI_test_manatee)\n",
    "\n",
    "cumulative_regrets_manatee_train_nmi, bayes_regrets_manatee_train_nmi = compute_regrets(NMI_train_manatee)\n",
    "cumulative_regrets_manatee_test_nmi, bayes_regrets_manatee_test_nmi = compute_regrets(NMI_test_manatee)\n",
    "\n",
    "print_regret_stats(cumulative_regrets_manatee_train, 'Cumulative regret on train', 'manatee', 'ARI')\n",
    "print_regret_stats(cumulative_regrets_manatee_test, 'Cumulative regret on test', 'manatee', 'ARI')\n",
    "print_regret_stats(cumulative_regrets_manatee_train_nmi, 'Cumulative regret on train', 'manatee', 'NMI')\n",
    "print_regret_stats(cumulative_regrets_manatee_test_nmi, 'Cumulative regret on test', 'manatee', 'NMI')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "878d7f4b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
