{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import glob\n",
    "import os\n",
    "from typing import Dict, List, Tuple\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import numpy.typing as npt\n",
    "import pandas as pd\n",
    "\n",
    "from metrics import Metrics, MetricsDouble\n",
    "from utils import load2\n",
    "\n",
    "plt.style.use('sequential.mplstyle')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_ccr_fpr(in_dataset: str = 'cifar10',\n",
    "                 out_dataset: str = 'cifar100',\n",
    "                 methods: List[str] = ['msp', 'odin']) -> Dict[str, float]:\n",
    "    \"\"\"\n",
    "    Draw the CCR-FPR curve of each method on one new figure.\n",
    "\n",
    "    The figure is neither saved nor closed here; that is left to the caller.\n",
    "\n",
    "    Args:\n",
    "        in_dataset (str): Name of the ID dataset; selects the `data/{in_dataset}_{method}/` folder.\n",
    "        out_dataset (str): Name of the OOD dataset whose `.npz` file is loaded alongside the ID one.\n",
    "        methods (List[str]): Methods to evaluate, one curve each.\n",
    "\n",
    "    Returns:\n",
    "        Dict[str, float]: Dictionary of method -> oscr.\n",
    "    \"\"\"\n",
    "    _, ax = plt.subplots(1, 1)\n",
    "    area_by_method = {}\n",
    "\n",
    "    for name in methods:\n",
    "        # Both .npz files of a method live in the same folder\n",
    "        folder = f'data/{in_dataset}_{name}'\n",
    "        labels, predictions, scores = load2(f'{folder}/{in_dataset}.npz',\n",
    "                                            f'{folder}/{out_dataset}.npz')\n",
    "        evaluation = Metrics(y_true=labels, y_pred=predictions, score=scores)\n",
    "        # The legend entry carries the OSCR (area under the curve), rounded to 3 decimals\n",
    "        ax.plot(evaluation.fpr, evaluation.ccr,\n",
    "                label=f\"{name.upper()}: {np.round(evaluation.OSCR, 3)}\")\n",
    "        area_by_method[name] = evaluation.OSCR\n",
    "\n",
    "    ax.set_xlabel('FPR')\n",
    "    ax.set_ylabel('CCR')\n",
    "    ax.grid('on')\n",
    "    ax.legend()\n",
    "\n",
    "    return area_by_method\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_roc(in_dataset: str = 'cifar10',\n",
    "             out_dataset: str = 'cifar100',\n",
    "             methods: List[str] = ['msp', 'odin'],\n",
    "             coverage_bound: float = 0.8) -> Tuple[Dict[str, float], Dict[str, float]]:\n",
    "    \"\"\"\n",
    "    Plot the ROC (TPR vs. FPR) curve of every method on one new figure.\n",
    "\n",
    "    The figure is neither saved nor closed here; that is left to the caller.\n",
    "\n",
    "    Args:\n",
    "        in_dataset (str): Name of the ID dataset; selects the `data/{in_dataset}_{method}/` folder.\n",
    "        out_dataset (str): Name of the OOD dataset whose `.npz` file is loaded alongside the ID one.\n",
    "        methods (List[str]): Methods to evaluate, one curve each.\n",
    "        coverage_bound (float): Minimal TPR (coverage) an operating point must reach to be admissible.\n",
    "\n",
    "    Returns:\n",
    "        Tuple[Dict[str, float], Dict[str, float]]: Dictionary of method -> auroc, and dictionary of\n",
    "        method -> lowest FPR, s.t. TPR >= coverage bound (NaN if no operating point reaches the bound).\n",
    "    \"\"\"\n",
    "\n",
    "    # Dictionary holding mapping method -> auroc\n",
    "    auroc = {}\n",
    "    # Dictionary holding mapping method -> fpr s.t. TPR >= bound\n",
    "    fpr = {}\n",
    "\n",
    "    _, axis = plt.subplots(1, 1)\n",
    "    axis.set_title(\"ROC\")\n",
    "\n",
    "    for method in methods:\n",
    "        # Load data\n",
    "        y_gt_1, y_pred_1, score_1 = load2(f'data/{in_dataset}_{method}/{in_dataset}.npz',\n",
    "                                          f'data/{in_dataset}_{method}/{out_dataset}.npz')\n",
    "        # Compute metrics\n",
    "        metric = Metrics(y_true=y_gt_1, y_pred=y_pred_1, score=score_1)\n",
    "        # Plot TPR-FPR (ROC) curve\n",
    "        axis.plot(metric.fpr, metric.coverage,\n",
    "                  label=f\"{method.upper()}: {np.round(metric.AUROC, 3)}\")\n",
    "        # Save achieved AUROC (area under the curve)\n",
    "        auroc[method] = metric.AUROC\n",
    "        # Determine best achievable FPR of the method, given some minimal TPR (coverage)\n",
    "        potential_fpr_at_coverage = metric.fpr[metric.coverage >= coverage_bound]\n",
    "        if potential_fpr_at_coverage.size == 0:\n",
    "            # Desired coverage is infeasible; np.nan is used because the np.NaN alias\n",
    "            # does not exist in NumPy 2.0 and later\n",
    "            fpr[method] = np.nan\n",
    "        else:\n",
    "            # Pick lowest FPR from admissible solutions\n",
    "            fpr[method] = np.min(potential_fpr_at_coverage)\n",
    "\n",
    "    axis.set_xlabel('FPR')\n",
    "    axis.set_ylabel('TPR')\n",
    "    # grid() expects a bool for its visibility argument\n",
    "    axis.grid(True)\n",
    "    axis.legend()\n",
    "\n",
    "    return auroc, fpr\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_prc(in_dataset: str = 'cifar10',\n",
    "             out_dataset: str = 'cifar100',\n",
    "             methods: List[str] = ['msp', 'odin'],\n",
    "             coverage_bound: float = 0.8) -> Tuple[Dict[str, float], Dict[str, float]]:\n",
    "    \"\"\"\n",
    "    Plot the Precision-Recall curve of every method on one new figure.\n",
    "\n",
    "    The figure is neither saved nor closed here; that is left to the caller.\n",
    "\n",
    "    Args:\n",
    "        in_dataset (str): Name of the ID dataset; selects the `data/{in_dataset}_{method}/` folder.\n",
    "        out_dataset (str): Name of the OOD dataset whose `.npz` file is loaded alongside the ID one.\n",
    "        methods (List[str]): Methods to evaluate, one curve each.\n",
    "        coverage_bound (float): Minimal TPR (coverage) an operating point must reach to be admissible.\n",
    "\n",
    "    Returns:\n",
    "        Tuple[Dict[str, float], Dict[str, float]]: Dictionary of method -> aupr, and dictionary of\n",
    "        method -> highest precision, s.t. TPR >= coverage bound (NaN if no operating point reaches the bound).\n",
    "    \"\"\"\n",
    "\n",
    "    # Dictionary holding mapping method -> aupr\n",
    "    aupr = {}\n",
    "    # Dictionary holding mapping method -> precision s.t. TPR >= bound\n",
    "    precision = {}\n",
    "\n",
    "    _, axis = plt.subplots(1, 1)\n",
    "    axis.set_title(\"PRC\")\n",
    "\n",
    "    for method in methods:\n",
    "        # Load data\n",
    "        y_gt_1, y_pred_1, score_1 = load2(f'data/{in_dataset}_{method}/{in_dataset}.npz',\n",
    "                                          f'data/{in_dataset}_{method}/{out_dataset}.npz')\n",
    "        # Compute metrics\n",
    "        metric = Metrics(y_true=y_gt_1, y_pred=y_pred_1, score=score_1)\n",
    "        # Plot Precision-Recall curve (recall = coverage = TPR)\n",
    "        axis.plot(metric.coverage, metric.prec,\n",
    "                  label=f\"{method.upper()}: {np.round(metric.AUPR, 3)}\")\n",
    "        # Save achieved AUPR (area under the curve)\n",
    "        aupr[method] = metric.AUPR\n",
    "        # Determine best achievable precision of the method, given some minimal TPR (coverage)\n",
    "        potential_precision_at_coverage = metric.prec[metric.coverage >= coverage_bound]\n",
    "        if potential_precision_at_coverage.size == 0:\n",
    "            # Desired coverage is infeasible; np.nan is used because the np.NaN alias\n",
    "            # does not exist in NumPy 2.0 and later\n",
    "            precision[method] = np.nan\n",
    "        else:\n",
    "            # Pick highest precision from admissible solutions\n",
    "            precision[method] = np.max(potential_precision_at_coverage)\n",
    "\n",
    "    axis.set_xlabel('Recall')\n",
    "    axis.set_ylabel('Precision')\n",
    "    # grid() expects a bool for its visibility argument\n",
    "    axis.grid(True)\n",
    "    axis.legend()\n",
    "\n",
    "    return aupr, precision\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_risk_tpr_fpr_model(in_dataset: str = 'cifar10',\n",
    "                           out_dataset: str = 'cifar100',\n",
    "                           methods: List[str] = ['msp', 'odin'],\n",
    "                           min_coverage: float = 0.2,\n",
    "                           max_fpr: float = 0.2) -> Dict[str, float]:\n",
    "    \"\"\"\n",
    "    Find, per method, the lowest selective risk among operating points that satisfy both bounds.\n",
    "\n",
    "    Args:\n",
    "        in_dataset (str): Name of the ID dataset; selects the `data/{in_dataset}_{method}/` folder.\n",
    "        out_dataset (str): Name of the OOD dataset whose `.npz` file is loaded alongside the ID one.\n",
    "        methods (List[str]): Methods to evaluate.\n",
    "        min_coverage (float): Minimal admissible TPR (coverage).\n",
    "        max_fpr (float): Maximal admissible FPR.\n",
    "\n",
    "    Returns:\n",
    "        Dict[str, float]: Dictionary of method -> risk, s.t. TPR >= min_coverage and FPR <= max_fpr\n",
    "        (NaN if no operating point satisfies both bounds).\n",
    "    \"\"\"\n",
    "\n",
    "    # Dictionary holding mapping method -> best risk s.t. TPR >= bound and FPR <= bound\n",
    "    risk = {}\n",
    "\n",
    "    for method in methods:\n",
    "        # Load data\n",
    "        y_gt_1, y_pred_1, score_1 = load2(f'data/{in_dataset}_{method}/{in_dataset}.npz',\n",
    "                                          f'data/{in_dataset}_{method}/{out_dataset}.npz')\n",
    "        # Compute metrics\n",
    "        metric = Metrics(y_true=y_gt_1, y_pred=y_pred_1, score=score_1)\n",
    "        sel_risk = metric.sel_risk\n",
    "        coverage = metric.coverage\n",
    "        fpr = metric.fpr\n",
    "        # Keep only the operating points that satisfy both constraints\n",
    "        select = (fpr <= max_fpr) & (coverage >= min_coverage)\n",
    "        potential_risks = sel_risk[select]\n",
    "        if potential_risks.size == 0:\n",
    "            # Infeasible; np.nan is used because the np.NaN alias does not exist in NumPy 2.0 and later\n",
    "            risk[method] = np.nan\n",
    "        else:\n",
    "            # Pick lowest risk from admissible solutions\n",
    "            risk[method] = np.min(potential_risks)\n",
    "\n",
    "    return risk\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_risk_tpr_fpr_model_combined(in_dataset: str = 'cifar10',\n",
    "                                    out_dataset: str = 'cifar100',\n",
    "                                    methods: List[str] = ['msp', 'odin'],\n",
    "                                    min_coverage: float = 0.2,\n",
    "                                    max_fpr: float = 0.2) -> Dict[str, float]:\n",
    "    \"\"\"\n",
    "    Find the lowest risk under TPR/FPR bounds when each method is combined with the reference method.\n",
    "\n",
    "    `methods[0]` is the reference. Every other method is paired with it through `MetricsDouble`\n",
    "    (scoreA = the method's score, scoreB = the reference score). Ground truth and predictions are\n",
    "    the ones loaded for the reference method.\n",
    "\n",
    "    Args:\n",
    "        in_dataset (str): Name of the ID dataset; selects the `data/{in_dataset}_{method}/` folder.\n",
    "        out_dataset (str): Name of the OOD dataset whose `.npz` file is loaded alongside the ID one.\n",
    "        methods (List[str]): Reference method followed by the methods to combine with it.\n",
    "        min_coverage (float): Minimal admissible TPR (coverage).\n",
    "        max_fpr (float): Maximal admissible FPR, passed to `risk_coverage_curve_at_fpr`.\n",
    "\n",
    "    Returns:\n",
    "        Dict[str, float]: Dictionary of '<method> <reference>' -> risk for the Double Score setting\n",
    "        (NaN if no operating point reaches `min_coverage`).\n",
    "    \"\"\"\n",
    "\n",
    "    # Dictionary holding mapping method -> best risk s.t. TPR >= bound and FPR <= bound\n",
    "    risk = {}\n",
    "\n",
    "    # Load data of referential method\n",
    "    y_gt_1, y_pred_1, score_1 = load2(f'data/{in_dataset}_{methods[0]}/{in_dataset}.npz',\n",
    "                                      f'data/{in_dataset}_{methods[0]}/{out_dataset}.npz')\n",
    "\n",
    "    for method in methods[1:]:\n",
    "        # Only the score of the second method is needed; its labels and predictions are discarded\n",
    "        _, _, score_2 = load2(f'data/{in_dataset}_{method}/{in_dataset}.npz',\n",
    "                              f'data/{in_dataset}_{method}/{out_dataset}.npz')\n",
    "\n",
    "        # Initialize the Double Score method object\n",
    "        metric2 = MetricsDouble(\n",
    "            y_true=y_gt_1, y_pred=y_pred_1, scoreA=score_2, scoreB=score_1)\n",
    "        # Compute achievable risk and coverage of the method, given some maximal FPR\n",
    "        sel_risk, coverage = metric2.risk_coverage_curve_at_fpr(max_fpr)\n",
    "        # Determine best achievable risk of the method, given some minimal TPR (coverage) and maximal FPR\n",
    "        potential_risks = sel_risk[coverage >= min_coverage]\n",
    "        key = f\"{method} {methods[0]}\"\n",
    "        if potential_risks.size == 0:\n",
    "            # Infeasible; np.nan is used because the np.NaN alias does not exist in NumPy 2.0 and later\n",
    "            risk[key] = np.nan\n",
    "        else:\n",
    "            # Pick lowest risk from admissible solutions\n",
    "            risk[key] = np.min(potential_risks)\n",
    "\n",
    "    return risk\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_prc_combined(in_dataset: str = 'cifar10',\n",
    "                      out_dataset: str = 'cifar100',\n",
    "                      methods: List[str] = ['msp', 'odin']) -> Dict[str, float]:\n",
    "    \"\"\"\n",
    "    Plot Precision-Recall curves of the reference method and of each method combined with it.\n",
    "\n",
    "    `methods[0]` is the reference. Every other method is paired with it through `MetricsDouble`\n",
    "    (scoreA = the method's score, scoreB = the reference score). The figure is neither saved\n",
    "    nor closed here; that is left to the caller.\n",
    "\n",
    "    Args:\n",
    "        in_dataset (str): Name of the ID dataset; selects the `data/{in_dataset}_{method}/` folder.\n",
    "        out_dataset (str): Name of the OOD dataset whose `.npz` file is loaded alongside the ID one.\n",
    "        methods (List[str]): Reference method followed by the methods to combine with it.\n",
    "\n",
    "    Returns:\n",
    "        Dict[str, float]: Dictionary of '<method> <reference>' -> aupr for the Double Score setting.\n",
    "    \"\"\"\n",
    "\n",
    "    # Dictionary holding mapping method -> aupr\n",
    "    auprs = {}\n",
    "\n",
    "    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']\n",
    "\n",
    "    # Load data of referential method\n",
    "    y_gt_1, y_pred_1, score_1 = load2(f'data/{in_dataset}_{methods[0]}/{in_dataset}.npz',\n",
    "                                      f'data/{in_dataset}_{methods[0]}/{out_dataset}.npz')\n",
    "\n",
    "    _, axis = plt.subplots(1, 1)\n",
    "\n",
    "    # Compute metrics of the referential method\n",
    "    metric = Metrics(y_true=y_gt_1, y_pred=y_pred_1, score=score_1)\n",
    "    # Plot the Precision-Recall curve of the referential method\n",
    "    axis.plot(metric.coverage, metric.prec,\n",
    "              label=f\"{methods[0].upper()}\", color=colors[0])\n",
    "\n",
    "    for position, method in enumerate(methods[1:], start=1):\n",
    "        # Wrap around the color cycle so that no method is dropped when there are\n",
    "        # more methods than colors in the active style\n",
    "        color = colors[position % len(colors)]\n",
    "        # Only the score of the second method is needed; its labels and predictions are discarded\n",
    "        _, _, score_2 = load2(f'data/{in_dataset}_{method}/{in_dataset}.npz',\n",
    "                              f'data/{in_dataset}_{method}/{out_dataset}.npz')\n",
    "\n",
    "        # Initialize the Double Score method object\n",
    "        metric2 = MetricsDouble(\n",
    "            y_true=y_gt_1, y_pred=y_pred_1, scoreA=score_2, scoreB=score_1)\n",
    "        # Compute the precision and recall curve\n",
    "        prec, recall, aupr = metric2.prec_vs_recall()\n",
    "        axis.plot(\n",
    "            recall, prec, label=f\"{method.upper()} + {methods[0].upper()}: {np.round(aupr, 3)}\", color=color)\n",
    "        # Save achieved AUPR (area under the curve)\n",
    "        auprs[f\"{method} {methods[0]}\"] = aupr\n",
    "\n",
    "    # Recall is plotted on the x axis and precision on the y axis\n",
    "    axis.set_xlabel('Recall')\n",
    "    axis.set_ylabel('Precision')\n",
    "    # grid() expects a bool for its visibility argument\n",
    "    axis.grid(True)\n",
    "    axis.legend()\n",
    "\n",
    "    return auprs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_risk_prec_rec_model(in_dataset: str = 'cifar10',\n",
    "                            out_dataset: str = 'cifar100',\n",
    "                            methods: List[str] = ['msp', 'odin'],\n",
    "                            min_coverage: float = 0.2,\n",
    "                            min_precision: float = 0.2) -> Dict[str, float]:\n",
    "    \"\"\"\n",
    "    Find, per method, the lowest selective risk among operating points that satisfy both bounds.\n",
    "\n",
    "    Args:\n",
    "        in_dataset (str): Name of the ID dataset; selects the `data/{in_dataset}_{method}/` folder.\n",
    "        out_dataset (str): Name of the OOD dataset whose `.npz` file is loaded alongside the ID one.\n",
    "        methods (List[str]): Methods to evaluate.\n",
    "        min_coverage (float): Minimal admissible TPR (coverage, recall).\n",
    "        min_precision (float): Minimal admissible precision.\n",
    "\n",
    "    Returns:\n",
    "        Dict[str, float]: Dictionary of method -> risk, s.t. TPR >= min_coverage and\n",
    "        precision >= min_precision (NaN if no operating point satisfies both bounds).\n",
    "    \"\"\"\n",
    "\n",
    "    # Dictionary holding mapping method -> best risk s.t. TPR >= bound and precision >= bound\n",
    "    risk = {}\n",
    "\n",
    "    for method in methods:\n",
    "        # Load data\n",
    "        y_gt_1, y_pred_1, score_1 = load2(f'data/{in_dataset}_{method}/{in_dataset}.npz',\n",
    "                                          f'data/{in_dataset}_{method}/{out_dataset}.npz')\n",
    "        # Compute metrics\n",
    "        metric = Metrics(y_true=y_gt_1, y_pred=y_pred_1, score=score_1)\n",
    "        sel_risk = metric.sel_risk\n",
    "        coverage = metric.coverage\n",
    "        precision = metric.prec\n",
    "        # Keep only the operating points that satisfy both constraints\n",
    "        select = (precision >= min_precision) & (coverage >= min_coverage)\n",
    "        potential_risks = sel_risk[select]\n",
    "        if potential_risks.size == 0:\n",
    "            # Infeasible; np.nan is used because the np.NaN alias does not exist in NumPy 2.0 and later\n",
    "            risk[method] = np.nan\n",
    "        else:\n",
    "            # Pick lowest risk from admissible solutions\n",
    "            risk[method] = np.min(potential_risks)\n",
    "\n",
    "    return risk\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_risk_prec_rec_model_combined(in_dataset: str = 'cifar10',\n",
    "                                     out_dataset: str = 'cifar100',\n",
    "                                     methods: List[str] = ['msp', 'odin'],\n",
    "                                     min_coverage: float = 0.2,\n",
    "                                     min_precision: float = 0.2) -> Dict[str, float]:\n",
    "    \"\"\"\n",
    "    Find the lowest risk under recall/precision bounds when each method is combined with the reference method.\n",
    "\n",
    "    `methods[0]` is the reference. Every other method is paired with it through `MetricsDouble`\n",
    "    (scoreA = the method's score, scoreB = the reference score). Ground truth and predictions are\n",
    "    the ones loaded for the reference method.\n",
    "\n",
    "    Args:\n",
    "        in_dataset (str): Name of the ID dataset; selects the `data/{in_dataset}_{method}/` folder.\n",
    "        out_dataset (str): Name of the OOD dataset whose `.npz` file is loaded alongside the ID one.\n",
    "        methods (List[str]): Reference method followed by the methods to combine with it.\n",
    "        min_coverage (float): Minimal admissible TPR (coverage, recall).\n",
    "        min_precision (float): Minimal admissible precision, passed to `risk_coverage_curve_at_prec`.\n",
    "\n",
    "    Returns:\n",
    "        Dict[str, float]: Dictionary of '<method> <reference>' -> risk in the Double Score setting\n",
    "        (NaN if no operating point reaches `min_coverage`).\n",
    "    \"\"\"\n",
    "\n",
    "    # Dictionary holding mapping method -> best risk s.t. TPR >= bound and precision >= bound\n",
    "    risk = {}\n",
    "\n",
    "    # Load data of referential method\n",
    "    y_gt_1, y_pred_1, score_1 = load2(f'data/{in_dataset}_{methods[0]}/{in_dataset}.npz',\n",
    "                                      f'data/{in_dataset}_{methods[0]}/{out_dataset}.npz')\n",
    "\n",
    "    for method in methods[1:]:\n",
    "        # Only the score of the second method is needed; its labels and predictions are discarded\n",
    "        _, _, score_2 = load2(f'data/{in_dataset}_{method}/{in_dataset}.npz',\n",
    "                              f'data/{in_dataset}_{method}/{out_dataset}.npz')\n",
    "        # Initialize the Double Score method object\n",
    "        metric2 = MetricsDouble(\n",
    "            y_true=y_gt_1, y_pred=y_pred_1, scoreA=score_2, scoreB=score_1)\n",
    "        # Compute achievable risk and recall (coverage) of the method, given some minimal precision\n",
    "        sel_risk, coverage = metric2.risk_coverage_curve_at_prec(min_precision)\n",
    "        # Determine best achievable risk of the method, given some minimal recall (coverage) and minimal precision\n",
    "        potential_risks = sel_risk[coverage >= min_coverage]\n",
    "        key = f\"{method} {methods[0]}\"\n",
    "        if potential_risks.size == 0:\n",
    "            # Infeasible; np.nan is used because the np.NaN alias does not exist in NumPy 2.0 and later\n",
    "            risk[key] = np.nan\n",
    "        else:\n",
    "            # Pick lowest risk from admissible solutions\n",
    "            risk[key] = np.min(potential_risks)\n",
    "\n",
    "    return risk\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_roc_combined(in_dataset: str = 'cifar10',\n",
    "                      out_dataset: str = 'cifar100',\n",
    "                      methods: List[str] = ['msp', 'odin']) -> Dict[str, float]:\n",
    "    \"\"\"\n",
    "    Plot ROC curves of the reference method and of each method combined with it.\n",
    "\n",
    "    `methods[0]` is the reference. Every other method is paired with it through `MetricsDouble`\n",
    "    (scoreA = the method's score, scoreB = the reference score). The figure is neither saved\n",
    "    nor closed here; that is left to the caller.\n",
    "\n",
    "    Args:\n",
    "        in_dataset (str): Name of the ID dataset; selects the `data/{in_dataset}_{method}/` folder.\n",
    "        out_dataset (str): Name of the OOD dataset whose `.npz` file is loaded alongside the ID one.\n",
    "        methods (List[str]): Reference method followed by the methods to combine with it.\n",
    "\n",
    "    Returns:\n",
    "        Dict[str, float]: Dictionary of '<method> <reference>' -> AUROC in the Double Score setting.\n",
    "    \"\"\"\n",
    "\n",
    "    # Dictionary holding mapping method -> auroc\n",
    "    aurocs = {}\n",
    "\n",
    "    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']\n",
    "\n",
    "    # Load data of referential method\n",
    "    y_gt_1, y_pred_1, score_1 = load2(f'data/{in_dataset}_{methods[0]}/{in_dataset}.npz',\n",
    "                                      f'data/{in_dataset}_{methods[0]}/{out_dataset}.npz')\n",
    "\n",
    "    _, axis = plt.subplots(1, 1)\n",
    "    # Compute metrics of the referential method\n",
    "    metric = Metrics(y_true=y_gt_1, y_pred=y_pred_1, score=score_1)\n",
    "    # Plot ROC curve of the referential method\n",
    "    axis.plot(metric.fpr, metric.coverage,\n",
    "              label=f\"{methods[0].upper()}\", color=colors[0])\n",
    "\n",
    "    for position, method in enumerate(methods[1:], start=1):\n",
    "        # Wrap around the color cycle so that no method is dropped when there are\n",
    "        # more methods than colors in the active style\n",
    "        color = colors[position % len(colors)]\n",
    "        # Only the score of the second method is needed; its labels and predictions are discarded\n",
    "        _, _, score_2 = load2(f'data/{in_dataset}_{method}/{in_dataset}.npz',\n",
    "                              f'data/{in_dataset}_{method}/{out_dataset}.npz')\n",
    "        # Initialize the Double Score method object\n",
    "        metric2 = MetricsDouble(\n",
    "            y_true=y_gt_1, y_pred=y_pred_1, scoreA=score_2, scoreB=score_1)\n",
    "        # Compute the ROC curve\n",
    "        tpr, fpr, auroc = metric2.tpr_vs_fpr()\n",
    "        axis.plot(\n",
    "            fpr, tpr, label=f\"{method.upper()} + {methods[0].upper()}: {np.round(auroc, 3)}\", color=color)\n",
    "        # Save achieved AUROC (area under the curve)\n",
    "        aurocs[f\"{method} {methods[0]}\"] = auroc\n",
    "\n",
    "    axis.set_xlabel('FPR')\n",
    "    axis.set_ylabel('TPR')\n",
    "    # grid() expects a bool for its visibility argument\n",
    "    axis.grid(True)\n",
    "    axis.legend()\n",
    "\n",
    "    return aurocs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_ccr_fpr_combined(in_dataset: str = 'cifar10',\n",
    "                          out_dataset: str = 'cifar100',\n",
    "                          methods: List[str] = ['msp', 'odin']) -> Dict[str, float]:\n",
    "    \"\"\"\n",
    "    Plot CCR-FPR curves of the reference method and of each method combined with it.\n",
    "\n",
    "    `methods[0]` is the reference. Every other method is paired with it through `MetricsDouble`\n",
    "    (scoreA = the method's score, scoreB = the reference score). The figure is neither saved\n",
    "    nor closed here; that is left to the caller.\n",
    "\n",
    "    Args:\n",
    "        in_dataset (str): Name of the ID dataset; selects the `data/{in_dataset}_{method}/` folder.\n",
    "        out_dataset (str): Name of the OOD dataset whose `.npz` file is loaded alongside the ID one.\n",
    "        methods (List[str]): Reference method followed by the methods to combine with it.\n",
    "\n",
    "    Returns:\n",
    "        Dict[str, float]: Dictionary of '<method> <reference>' -> OSCR in the Double Score setting.\n",
    "    \"\"\"\n",
    "\n",
    "    # Dictionary holding mapping method -> oscr\n",
    "    oscrs = {}\n",
    "\n",
    "    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']\n",
    "\n",
    "    # Load data of referential method\n",
    "    y_gt_1, y_pred_1, score_1 = load2(f'data/{in_dataset}_{methods[0]}/{in_dataset}.npz',\n",
    "                                      f'data/{in_dataset}_{methods[0]}/{out_dataset}.npz')\n",
    "    _, axis = plt.subplots(1, 1)\n",
    "    # Compute metrics of referential method and plot its CCR-FPR curve\n",
    "    metric = Metrics(y_true=y_gt_1, y_pred=y_pred_1, score=score_1)\n",
    "    axis.plot(metric.fpr, metric.ccr,\n",
    "              label=f\"{methods[0].upper()}\", color=colors[0])\n",
    "\n",
    "    for position, method in enumerate(methods[1:], start=1):\n",
    "        # Wrap around the color cycle so that no method is dropped when there are\n",
    "        # more methods than colors in the active style\n",
    "        color = colors[position % len(colors)]\n",
    "        # Only the score of the second method is needed; its labels and predictions are discarded\n",
    "        _, _, score_2 = load2(f'data/{in_dataset}_{method}/{in_dataset}.npz',\n",
    "                              f'data/{in_dataset}_{method}/{out_dataset}.npz')\n",
    "        # Initialize the Double Score method object\n",
    "        metric2 = MetricsDouble(\n",
    "            y_true=y_gt_1, y_pred=y_pred_1, scoreA=score_2, scoreB=score_1)\n",
    "        # Compute the CCR-FPR curve\n",
    "        ccr, fpr, oscr = metric2.ccr_vs_fpr()\n",
    "        axis.plot(\n",
    "            fpr, ccr, label=f\"{method.upper()} + {methods[0].upper()}: {np.round(oscr, 3)}\", color=color)\n",
    "        # Save achieved OSCR (area under the curve)\n",
    "        oscrs[f\"{method} {methods[0]}\"] = oscr\n",
    "\n",
    "    axis.set_xlabel('FPR')\n",
    "    axis.set_ylabel('CCR')\n",
    "    # grid() expects a bool for its visibility argument\n",
    "    axis.grid(True)\n",
    "    axis.legend()\n",
    "\n",
    "    return oscrs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ID dataset -> OOD datasets it is evaluated against\n",
    "# NOTE(review): 'place365' (cifar10) and 'places365' (mnist) are spelled differently; the names are\n",
    "# used verbatim in the data file paths, so confirm that both match the files on disk.\n",
    "ood_datasets_by_id = {\n",
    "    'cifar10': ['cifar100', 'mnist', 'place365', 'svhn', 'texture', 'tin'],\n",
    "    'mnist': ['notmnist', 'fashionmnist', 'cifar10', 'places365', 'texture', 'tin'],\n",
    "}\n",
    "\n",
    "# Flatten into a list of (ID, OOD) tuples, keeping the order given above\n",
    "id_ood_pairs = []\n",
    "for id_dataset, ood_datasets in ood_datasets_by_id.items():\n",
    "    for dataset in ood_datasets:\n",
    "        id_ood_pairs.append((id_dataset, dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Methods to compare; the first entry (index 0) is the reference that every other method is combined with\n",
    "methods = ['msp', 'odin', 'mls', 'react', 'knn', 'vim']\n",
    "\n",
    "# Row labels of the combined results, built as '<method> <reference>'\n",
    "reference_method = methods[0]\n",
    "additional_methods = [f'{other} {reference_method}' for other in methods[1:]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One empty results table (rows: single and combined methods) and one empty table of\n",
    "# per-method bounds for every ID/OOD pair\n",
    "result_rows = methods + additional_methods\n",
    "result_columns = ['AUROC', 'AUPR', 'OSCR', 'TPRFPRRISK', 'PRECTPRRISK']\n",
    "bound_columns = ['PREC', 'FPR']\n",
    "\n",
    "dfs, bounds = {}, {}\n",
    "for pair in id_ood_pairs:\n",
    "    dfs[pair] = pd.DataFrame(index=result_rows, columns=result_columns)\n",
    "    bounds[pair] = pd.DataFrame(index=methods, columns=bound_columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal TPR (coverage) a method must reach; passed as `coverage_bound` / `min_coverage`\n",
    "# to the plotting and risk functions in the cells below\n",
    "MIN_COVERAGE = 0.9"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs('figures/ROC_Single_Score/', exist_ok=True)\n",
    "os.makedirs('figures/ROC_Double_Score/', exist_ok=True)\n",
    "\n",
    "# Plot ROC curves\n",
    "# The worst method's FPR at the TPR condition is taken and used as a maximum bound for FPR in the remainder of the script\n",
    "# The loop variables are not called `id`, so the builtin stays usable in the notebook namespace\n",
    "for id_name, ood_name in id_ood_pairs:\n",
    "    pair = (id_name, ood_name)\n",
    "\n",
    "    aurocs, fpr_bound = plot_roc(methods=methods,\n",
    "                                 in_dataset=id_name,\n",
    "                                 out_dataset=ood_name,\n",
    "                                 coverage_bound=MIN_COVERAGE)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f'figures/ROC_Single_Score/id_{id_name}_ood_{ood_name}.png', dpi=300)\n",
    "    plt.close()\n",
    "\n",
    "    aurocs_combined = plot_roc_combined(methods=methods,\n",
    "                                        in_dataset=id_name,\n",
    "                                        out_dataset=ood_name)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f'figures/ROC_Double_Score/id_{id_name}_ood_{ood_name}.png', dpi=300)\n",
    "    plt.close()\n",
    "\n",
    "    aurocs.update(aurocs_combined)\n",
    "    # The result keys equal the row labels, so a Series assigns each value to its row by label\n",
    "    # instead of relying on the insertion order of the dictionaries\n",
    "    dfs[pair] = dfs[pair].assign(AUROC=pd.Series(aurocs))\n",
    "    bounds[pair] = bounds[pair].assign(FPR=pd.Series(fpr_bound))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs('figures/PRC_Single_Score/', exist_ok=True)\n",
    "os.makedirs('figures/PRC_Double_Score/', exist_ok=True)\n",
    "\n",
    "# Plot PR curves\n",
    "# The worst method's precision at the TPR condition is taken and used as a minimum bound for precision in the remainder of the script\n",
    "# The loop variables are not called `id`, so the builtin stays usable in the notebook namespace\n",
    "for id_name, ood_name in id_ood_pairs:\n",
    "    pair = (id_name, ood_name)\n",
    "\n",
    "    auprs, prec_bound = plot_prc(methods=methods,\n",
    "                                 in_dataset=id_name,\n",
    "                                 out_dataset=ood_name,\n",
    "                                 coverage_bound=MIN_COVERAGE)\n",
    "\n",
    "    # The result keys equal the row labels, so a Series assigns each value to its row by label\n",
    "    # instead of relying on the insertion order of the dictionaries\n",
    "    bounds[pair] = bounds[pair].assign(PREC=pd.Series(prec_bound))\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f'figures/PRC_Single_Score/id_{id_name}_ood_{ood_name}.png', dpi=300)\n",
    "    plt.close()\n",
    "\n",
    "    auprs_combined = plot_prc_combined(methods=methods,\n",
    "                                       in_dataset=id_name,\n",
    "                                       out_dataset=ood_name)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f'figures/PRC_Double_Score/id_{id_name}_ood_{ood_name}.png', dpi=300)\n",
    "    plt.close()\n",
    "\n",
    "    auprs.update(auprs_combined)\n",
    "    dfs[pair] = dfs[pair].assign(AUPR=pd.Series(auprs))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs('figures/CCR_Single_Score/', exist_ok=True)\n",
    "os.makedirs('figures/CCR_Double_Score/', exist_ok=True)\n",
    "\n",
    "# Plot CCR-FPR\n",
    "# The loop variables are not called `id`, so the builtin stays usable in the notebook namespace\n",
    "for id_name, ood_name in id_ood_pairs:\n",
    "    pair = (id_name, ood_name)\n",
    "\n",
    "    oscrs = plot_ccr_fpr(methods=methods,\n",
    "                         in_dataset=id_name,\n",
    "                         out_dataset=ood_name)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f'figures/CCR_Single_Score/id_{id_name}_ood_{ood_name}.png', dpi=300)\n",
    "    plt.close()\n",
    "\n",
    "    oscrs_combined = plot_ccr_fpr_combined(methods=methods,\n",
    "                                           in_dataset=id_name,\n",
    "                                           out_dataset=ood_name)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f'figures/CCR_Double_Score/id_{id_name}_ood_{ood_name}.png', dpi=300)\n",
    "    plt.close()\n",
    "\n",
    "    oscrs.update(oscrs_combined)\n",
    "    # The result keys equal the row labels, so a Series assigns each value to its row by label\n",
    "    # instead of relying on the insertion order of the dictionaries\n",
    "    dfs[pair] = dfs[pair].assign(OSCR=pd.Series(oscrs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get Risks of TPR-FPR models\n",
    "# The loop variables are not called `id`, so the builtin stays usable in the notebook namespace\n",
    "for id_name, ood_name in id_ood_pairs:\n",
    "    pair = (id_name, ood_name)\n",
    "    # Largest per-method FPR bound stored by the ROC cell; NaN entries are ignored\n",
    "    max_fpr = np.nanmax(bounds[pair]['FPR'])\n",
    "\n",
    "    best_risk = get_risk_tpr_fpr_model(methods=methods,\n",
    "                                       in_dataset=id_name,\n",
    "                                       out_dataset=ood_name,\n",
    "                                       min_coverage=MIN_COVERAGE,\n",
    "                                       max_fpr=max_fpr)\n",
    "\n",
    "    best_risk_combined = get_risk_tpr_fpr_model_combined(methods=methods,\n",
    "                                                         in_dataset=id_name,\n",
    "                                                         out_dataset=ood_name,\n",
    "                                                         min_coverage=MIN_COVERAGE,\n",
    "                                                         max_fpr=max_fpr)\n",
    "\n",
    "    best_risk.update(best_risk_combined)\n",
    "    # The result keys equal the row labels, so a Series assigns each value to its row by label\n",
    "    # instead of relying on the insertion order of the dictionaries\n",
    "    dfs[pair] = dfs[pair].assign(TPRFPRRISK=pd.Series(best_risk))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get Risks of Precision-Recall models\n",
    "# The loop variables are not called `id`, so the builtin stays usable in the notebook namespace\n",
    "for id_name, ood_name in id_ood_pairs:\n",
    "    pair = (id_name, ood_name)\n",
    "    # Smallest per-method precision bound stored by the PRC cell; NaN entries are ignored\n",
    "    min_precision = np.nanmin(bounds[pair]['PREC'])\n",
    "\n",
    "    best_risk = get_risk_prec_rec_model(methods=methods,\n",
    "                                        in_dataset=id_name,\n",
    "                                        out_dataset=ood_name,\n",
    "                                        min_coverage=MIN_COVERAGE,\n",
    "                                        min_precision=min_precision)\n",
    "\n",
    "    best_risk_combined = get_risk_prec_rec_model_combined(methods=methods,\n",
    "                                                          in_dataset=id_name,\n",
    "                                                          out_dataset=ood_name,\n",
    "                                                          min_coverage=MIN_COVERAGE,\n",
    "                                                          min_precision=min_precision)\n",
    "\n",
    "    best_risk.update(best_risk_combined)\n",
    "    # The result keys equal the row labels, so a Series assigns each value to its row by label\n",
    "    # instead of relying on the insertion order of the dictionaries\n",
    "    dfs[pair] = dfs[pair].assign(PRECTPRRISK=pd.Series(best_risk))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist every results/bounds table as a pickle and collect all results in one HTML report.\n",
    "# The context manager closes the report even if pickling or rendering raises.\n",
    "with open(\"results.html\", \"w\") as text_file:\n",
    "    for key, df in dfs.items():\n",
    "        bound = bounds[key]\n",
    "        # `key` is an (ID, OOD) tuple, so the file names contain its tuple representation\n",
    "        df.to_pickle(f\"{key}.pkl\")\n",
    "        bound.to_pickle(f\"{key}_bounds.pkl\")\n",
    "\n",
    "        # Bounds that were applied to this pair when computing the risks (NaN entries are ignored)\n",
    "        min_precision = np.round(np.nanmin(bound['PREC']), 3)\n",
    "        max_fpr = np.round(np.nanmax(bound['FPR']), 3)\n",
    "\n",
    "        text_file.write(f\"<h1>{key}</h1>\")\n",
    "        text_file.write(f\"<h2>Coverage >= {MIN_COVERAGE}, Precision >= {min_precision}, FPR <= {max_fpr}</h2>\")\n",
    "        # Only the HTML view is rounded; the pickled table keeps full precision\n",
    "        text_file.write(df.round(decimals=6).to_html())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.16 ('facis_conda_env')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "20826f349a75d78172b81a1f065c11f0e5559647a8d56f3e666354192a06c4c0"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
