{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from metrics import Metrics, MetricsDouble\n",
    "from sklearn.neighbors import KernelDensity\n",
    "from tqdm import tqdm\n",
    "from typing import List\n",
    "from utils import load2\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import numpy.typing as npt\n",
    "import os\n",
    "\n",
    "# Apply the qualitative style sheet used by all figures in this notebook\n",
    "plt.style.use('qualitative.mplstyle')\n",
    "# Colors of the active property cycle, in cycle order\n",
    "colors = [entry['color'] for entry in plt.rcParams['axes.prop_cycle']]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _plot_kde_contours(axis: plt.Axes,\n",
    "                       points: npt.NDArray[float],\n",
    "                       grid_x: npt.NDArray[float],\n",
    "                       grid_y: npt.NDArray[float],\n",
    "                       bandwidth: float,\n",
    "                       cmap,\n",
    "                       levels_count: int,\n",
    "                       fill_alpha: float,\n",
    "                       linewidths: float):\n",
    "    \"\"\"\n",
    "    Fits an exponential-kernel density to `points` (column 0: X, column 1: Y) and draws its filled and line contours.\n",
    "\n",
    "    Nothing is drawn when `points` is empty (the estimator cannot be fitted) or when the density is not positive\n",
    "    anywhere on the grid (all contour levels would collapse to zero).\n",
    "    \"\"\"\n",
    "    if len(points) == 0:\n",
    "        return\n",
    "\n",
    "    kde = KernelDensity(bandwidth=bandwidth, metric='euclidean', kernel='exponential')\n",
    "    kde.fit(points)\n",
    "    grid_points = np.column_stack([grid_x.ravel(), grid_y.ravel()])\n",
    "    # score_samples returns log-density, hence the exp\n",
    "    density = np.exp(kde.score_samples(grid_points)).reshape(grid_x.shape)\n",
    "\n",
    "    peak = density.max()\n",
    "    if not peak > 0:\n",
    "        return\n",
    "\n",
    "    levels = np.linspace(0, peak, levels_count)\n",
    "    axis.contourf(grid_x, grid_y, density, levels=levels,\n",
    "                  cmap=cmap, alpha=fill_alpha, extend='both')\n",
    "    axis.contour(grid_x, grid_y, density, levels=levels,\n",
    "                 linewidths=linewidths, cmap=cmap)\n",
    "\n",
    "\n",
    "def plot_density(score_1: npt.NDArray[float],\n",
    "                 score_2: npt.NDArray[float],\n",
    "                 ood: npt.NDArray[bool],\n",
    "                 axis: plt.Axes,\n",
    "                 correct: npt.NDArray[bool] = None,\n",
    "                 ngridx: int = 30,\n",
    "                 ngridy: int = 30,\n",
    "                 bandwidth: float = 0.03,\n",
    "                 contourf_alpha: float = 0.5,\n",
    "                 countour_levels: int = 10,\n",
    "                 scatter_alpha: float = 0.4,\n",
    "                 scatter_size: float = 0.7,\n",
    "                 scatter_color_intensity: float = 0.7,\n",
    "                 linewidths: float = 0.7):\n",
    "    \"\"\"\n",
    "    Estimates density of ID/OOD data and visualizes it.\n",
    "\n",
    "    `score_1` is plotted on the X axis and `score_2` on the Y axis. OOD density and samples are drawn in red,\n",
    "    ID density and samples in blue; the samples are scattered on top of the density contours.\n",
    "\n",
    "    Args:\n",
    "        score_1 (npt.NDArray[float]): Array of score from some method.\n",
    "        score_2 (npt.NDArray[float]): Array of score from some other method.\n",
    "        ood (npt.NDArray[bool]): Boolean array indicating whether samples are OOD (True) or ID (False).\n",
    "        axis (plt.Axes): Axis to plot into.\n",
    "        correct (npt.NDArray[bool], optional): If not None, is expected to be boolean array indicating whether prediction on ID data is correct. Defaults to None.\n",
    "        ngridx (int, optional): Density X axis grid. Defaults to 30.\n",
    "        ngridy (int, optional): Density Y axis grid. Defaults to 30.\n",
    "        bandwidth (float, optional): Density estimation bandwidth. Defaults to 0.03.\n",
    "        contourf_alpha (float, optional): Transparency of filled contour plot of density. Defaults to 0.5.\n",
    "        countour_levels (int, optional): Number of levels of contour to show. Defaults to 10.\n",
    "        scatter_alpha (float, optional): Transparency of data samples. Defaults to 0.4.\n",
    "        scatter_size (float, optional): Size of data samples. Defaults to 0.7.\n",
    "        scatter_color_intensity (float, optional): Color (from colormap) of data samples. Defaults to 0.7.\n",
    "        linewidths (float, optional): Width of contour and scatter lines. Defaults to 0.7.\n",
    "    \"\"\"\n",
    "    score_1 = np.asarray(score_1, dtype=float)\n",
    "    score_2 = np.asarray(score_2, dtype=float)\n",
    "    # Coerce the mask to bool so that `~` is a logical (not bitwise integer) negation\n",
    "    ood = np.asarray(ood, dtype=bool)\n",
    "\n",
    "    # Evaluation grid spanning the observed range of both scores\n",
    "    grid_x, grid_y = np.meshgrid(np.linspace(score_1.min(), score_1.max(), ngridx),\n",
    "                                 np.linspace(score_2.min(), score_2.max(), ngridy))\n",
    "    points = np.column_stack([score_1, score_2])\n",
    "\n",
    "    # Density contours: OOD in red, ID in blue\n",
    "    _plot_kde_contours(axis, points[ood], grid_x, grid_y, bandwidth, plt.cm.Reds,\n",
    "                       countour_levels, contourf_alpha, linewidths)\n",
    "    _plot_kde_contours(axis, points[~ood], grid_x, grid_y, bandwidth, plt.cm.Blues,\n",
    "                       countour_levels, contourf_alpha, linewidths)\n",
    "\n",
    "    # Scatter data points after the contours; otherwise the contours cover up the points\n",
    "    ood_color = plt.cm.Reds(scatter_color_intensity)\n",
    "    id_color = plt.cm.Blues(scatter_color_intensity)\n",
    "    id_count = np.count_nonzero(~ood)\n",
    "    axis.scatter(score_1[ood], score_2[ood], alpha=scatter_alpha, s=scatter_size,\n",
    "                 marker='.', color=ood_color, linewidths=linewidths, label=f'OOD: {np.count_nonzero(ood)}')\n",
    "    if correct is not None:\n",
    "        correct = np.asarray(correct, dtype=bool)\n",
    "        # Correctly classified ID samples as dots, misclassified ones as crosses (the crosses carry the legend entry)\n",
    "        axis.scatter(score_1[~ood & correct], score_2[~ood & correct], alpha=scatter_alpha,\n",
    "                     marker='.', s=scatter_size, color=id_color, linewidths=linewidths)\n",
    "        axis.scatter(score_1[~ood & ~correct], score_2[~ood & ~correct], alpha=scatter_alpha,\n",
    "                     marker='x', s=scatter_size+1, color=id_color, linewidths=linewidths, label=f'ID: {id_count}')\n",
    "    else:\n",
    "        axis.scatter(score_1[~ood], score_2[~ood], alpha=scatter_alpha, marker='.',\n",
    "                     s=scatter_size, color=id_color, linewidths=linewidths, label=f'ID: {id_count}')\n",
    "\n",
    "    legend = axis.legend(loc='lower right')\n",
    "    # Matplotlib 3.7 renamed `legendHandles` to `legend_handles`; the old name was removed in 3.9\n",
    "    handles = getattr(legend, 'legend_handles', None)\n",
    "    if handles is None:\n",
    "        handles = legend.legendHandles\n",
    "    # Enlarge the legend markers, which would otherwise be as tiny as the data points\n",
    "    for handle in handles:\n",
    "        if hasattr(handle, 'set_sizes'):\n",
    "            handle.set_sizes([30])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _annotate_corner(axis: plt.Axes, text: str):\n",
    "    \"\"\"Writes `text` in a white, black-edged box at the top-left corner of `axis`.\"\"\"\n",
    "    axis.annotate(text,\n",
    "                  xy=(0, 1),\n",
    "                  xytext=(12, -12),\n",
    "                  va='top',\n",
    "                  xycoords='axes fraction',\n",
    "                  textcoords='offset points',\n",
    "                  bbox=dict(facecolor='white', edgecolor='k', pad=5.0))\n",
    "\n",
    "\n",
    "def _draw_selection_boundary(axis: plt.Axes,\n",
    "                             weight_1: float,\n",
    "                             weight_2: float,\n",
    "                             threshold: float,\n",
    "                             **line_kwargs):\n",
    "    \"\"\"\n",
    "    Draws the line weight_1*x + weight_2*y = threshold and shades the half-plane where the combined score\n",
    "    weight_1*score_1 + weight_2*score_2 is below `threshold`. The current axis limits are kept.\n",
    "\n",
    "    Args:\n",
    "        axis (plt.Axes): Axis to draw into (X: score_1, Y: score_2).\n",
    "        weight_1 (float): Weight of score_1 in the combined score.\n",
    "        weight_2 (float): Weight of score_2 in the combined score.\n",
    "        threshold (float): Threshold on the combined score.\n",
    "        **line_kwargs: Extra keyword arguments for the boundary line (e.g. linestyle).\n",
    "    \"\"\"\n",
    "    left, right = axis.get_xlim()\n",
    "    bottom, top = axis.get_ylim()\n",
    "\n",
    "    if weight_2 == 0:\n",
    "        # Vertical boundary: the combined score depends on score_1 only\n",
    "        boundary_x = threshold / weight_1\n",
    "        axis.axvline(boundary_x, c='k', **line_kwargs)\n",
    "        shaded = [left, boundary_x] if weight_1 > 0 else [boundary_x, right]\n",
    "        axis.fill_between(shaded, [bottom, bottom], [top, top], alpha=0.15, color='k')\n",
    "    else:\n",
    "        x = np.array([left, right])\n",
    "        y = (threshold - weight_1*x)/weight_2\n",
    "        axis.plot(x, y, c='k', **line_kwargs)\n",
    "        # The below-threshold side is under the line for a positive score_2 weight and above it for a negative one\n",
    "        edge = bottom if weight_2 > 0 else top\n",
    "        axis.fill_between(x, y, [edge, edge], alpha=0.15, color='k')\n",
    "\n",
    "    # Plotting may have autoscaled the view; restore the original limits\n",
    "    axis.set_xlim([left, right])\n",
    "    axis.set_ylim([bottom, top])\n",
    "\n",
    "\n",
    "def _best_operating_point(y_true, y_pred, score, max_fpr: float, min_coverage: float):\n",
    "    \"\"\"\n",
    "    Evaluates the single score `score` and returns (TPR, FPR, selective risk) of the lowest-risk point satisfying\n",
    "    FPR <= max_fpr and coverage >= min_coverage, or None when no point satisfies both constraints.\n",
    "    \"\"\"\n",
    "    metric = Metrics(y_true=y_true, y_pred=y_pred, score=score)\n",
    "    sel_risk = metric.sel_risk\n",
    "    coverage = metric.coverage\n",
    "    fpr = metric.fpr\n",
    "\n",
    "    feasible = (fpr <= max_fpr) & (coverage >= min_coverage)\n",
    "    if not np.any(feasible):\n",
    "        return None\n",
    "\n",
    "    best = np.argmin(sel_risk[feasible])\n",
    "    return coverage[feasible][best], fpr[feasible][best], sel_risk[feasible][best]\n",
    "\n",
    "\n",
    "def plot_scores(in_dataset: str = 'cifar10',\n",
    "                out_dataset: str = 'cifar100',\n",
    "                methods: List[str] = ['msp', 'odin'],\n",
    "                max_fpr: float = 0.35,\n",
    "                min_coverage: float = 0.9,\n",
    "                use_alpha_parameterization: bool = True):\n",
    "    \"\"\"\n",
    "    Loads data for all specified methods, evaluated on the given ID and OOD dataset and visualizes the data distribution\n",
    "    and the optimal double-score selective function.\n",
    "\n",
    "    One subplot is created per non-reference method; `methods` itself is never modified.\n",
    "\n",
    "    Args:\n",
    "        in_dataset (str, optional): ID dataset. Defaults to 'cifar10'.\n",
    "        out_dataset (str, optional): OOD dataset. Defaults to 'cifar100'.\n",
    "        methods (List[str], optional): List of methods to compare. The first method (index 0) is combined with all remaining. Defaults to ['msp', 'odin'].\n",
    "        max_fpr (float, optional): Maximal admissible FPR. Defaults to 0.35.\n",
    "        min_coverage (float, optional): Minimal admissible TPR. Defaults to 0.9.\n",
    "        use_alpha_parameterization (bool, optional): If True uses parameterization of separating hyperplane angle. Defaults to True.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If fewer than two methods are given.\n",
    "    \"\"\"\n",
    "    if len(methods) < 2:\n",
    "        raise ValueError('At least two methods are required: a reference method and one to combine it with.')\n",
    "\n",
    "    reference = methods[0]\n",
    "    n_pairs = len(methods) - 1\n",
    "\n",
    "    # squeeze=False keeps `axes` two-dimensional even for a single subplot (e.g. the default two methods),\n",
    "    # so axes[0] is always an iterable row of axes\n",
    "    _, axes = plt.subplots(1, n_pairs, figsize=(3*len(methods), 4), squeeze=False)\n",
    "\n",
    "    # Load data for reference method\n",
    "    y_gt_1, y_pred_1, score_1 = load2(f'data/{in_dataset}_{reference}/{in_dataset}.npz',\n",
    "                                      f'data/{in_dataset}_{reference}/{out_dataset}.npz')\n",
    "\n",
    "    for method, axis in tqdm(zip(methods[1:], axes[0]), total=n_pairs):\n",
    "        # Load scores of the other method (ground truth and predictions are taken from the reference method)\n",
    "        _, _, score_2 = load2(f'data/{in_dataset}_{method}/{in_dataset}.npz',\n",
    "                              f'data/{in_dataset}_{method}/{out_dataset}.npz')\n",
    "\n",
    "        # Plot ID and OOD kernel estimated density; samples labelled -1 are treated as OOD\n",
    "        plot_density(score_1, score_2, y_gt_1 == -1, axis)\n",
    "\n",
    "        # Optimal double-score selective function (resolution depends on the sampling defined in metrics.py)\n",
    "        double_metric = MetricsDouble(\n",
    "            y_true=y_gt_1,\n",
    "            y_pred=y_pred_1,\n",
    "            scoreA=score_1,\n",
    "            scoreB=score_2)\n",
    "        if use_alpha_parameterization:\n",
    "            # Angle parameterization: score = cos(alpha)*score_1 + sin(alpha)*score_2\n",
    "            alpha, threshold = double_metric.get_alpha(max_fpr=max_fpr, min_coverage=min_coverage)\n",
    "            weights = None if alpha is None else (np.cos(alpha), np.sin(alpha))\n",
    "            line_kwargs = {}\n",
    "        else:\n",
    "            # Mu parameterization: score = (1-mu)*score_1 + mu*score_2\n",
    "            mu, threshold = double_metric.get_mu(max_fpr=max_fpr, min_coverage=min_coverage)\n",
    "            weights = None if mu is None else (1 - mu, mu)\n",
    "            line_kwargs = {'linestyle': '--'}\n",
    "\n",
    "        operating_point = None\n",
    "        if weights is not None:\n",
    "            weight_1, weight_2 = weights\n",
    "            _draw_selection_boundary(axis, weight_1, weight_2, threshold, **line_kwargs)\n",
    "            operating_point = _best_operating_point(y_gt_1, y_pred_1, weight_1*score_1 + weight_2*score_2,\n",
    "                                                    max_fpr, min_coverage)\n",
    "\n",
    "        if operating_point is None:\n",
    "            # Either no feasible combination was found, or the chosen one has no admissible threshold\n",
    "            _annotate_corner(axis, \"Infeasible\")\n",
    "        else:\n",
    "            tpr, fpr, risk = operating_point\n",
    "            _annotate_corner(axis, f\"TPR {np.round(tpr, 3)}, FPR {np.round(fpr, 3)}, Risk {np.round(risk, 5)}\")\n",
    "\n",
    "        # Annotate X and Y Axes\n",
    "        axis.set_xlabel(reference.upper())\n",
    "        axis.set_ylabel(method.upper())\n",
    "        axis.grid(True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare ID/OOD dataset pairs: each ID dataset is paired with every one of its OOD datasets\n",
    "id_ood_pairs = [(id_dataset, ood_dataset)\n",
    "                for id_dataset, ood_datasets in [('cifar10', ['cifar100', 'tin', 'mnist']),\n",
    "                                                 ('mnist', ['notmnist', 'fashionmnist', 'cifar10'])]\n",
    "                for ood_dataset in ood_datasets]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Methods to compare; methods[0] is the reference that gets combined with each of the others\n",
    "methods = ['msp', 'knn', 'vim']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maximal admissible FPR for each (ID, OOD) pair\n",
    "max_fprs = {('cifar10', 'cifar100'): 0.21,\n",
    "            ('cifar10', 'tin'): 0.19,\n",
    "            ('cifar10', 'mnist'): 0.19,\n",
    "            ('mnist', 'notmnist'): 0.08,\n",
    "            ('mnist', 'fashionmnist'): 0.1,\n",
    "            ('mnist', 'cifar10'): 0.29}\n",
    "\n",
    "min_tpr = 0.8\n",
    "\n",
    "# Fail fast, before any (slow) plotting, if a dataset pair has no FPR budget\n",
    "missing_pairs = [pair for pair in id_ood_pairs if pair not in max_fprs]\n",
    "assert not missing_pairs, f'No max FPR configured for: {missing_pairs}'\n",
    "\n",
    "os.makedirs('figures', exist_ok=True)\n",
    "# `id_dataset` rather than `id`: a global loop variable named `id` would shadow the builtin for the whole session\n",
    "for id_dataset, ood_dataset in tqdm(id_ood_pairs):\n",
    "    plot_scores(methods=methods,\n",
    "                in_dataset=id_dataset,\n",
    "                out_dataset=ood_dataset,\n",
    "                min_coverage=min_tpr,\n",
    "                max_fpr=max_fprs[(id_dataset, ood_dataset)])\n",
    "\n",
    "    # plot_scores creates a new figure via plt.subplots, so it is the current one; save and close it explicitly\n",
    "    figure = plt.gcf()\n",
    "    figure.tight_layout()\n",
    "    figure.savefig(f'figures/scores_2D_id_{id_dataset}_ood_{ood_dataset}.svg')\n",
    "    plt.close(figure)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.9 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "a235f45b7b77a5b57bb950c18034617d538482d711e762ea37678dbbcb584744"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
