{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb454aa2-a9f5-4988-9d07-8e3950a00cf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys; sys.path.append(\"../\") # For relative imports\n",
    "\n",
    "import glob\n",
    "import os\n",
    "import pickle\n",
    "\n",
    "from utils.conformal_utils import *\n",
    "from utils.experiment_utils import get_inputs_folder, get_outputs_folder, get_figs_folder\n",
    "\n",
    "# Figure typography: 16pt everywhere, 18pt for subplot titles\n",
    "plt.rcParams.update(dict.fromkeys(\n",
    "    ['font.size', 'axes.labelsize', 'legend.fontsize', 'xtick.labelsize', 'ytick.labelsize'],\n",
    "    16,\n",
    "))\n",
    "plt.rcParams['axes.titlesize'] = 18\n",
    "\n",
    "# Render all text through LaTeX (serif font, amsmath available in labels)\n",
    "plt.rcParams.update({\n",
    "    'text.usetex': True,\n",
    "    'font.family': 'serif',\n",
    "    'text.latex.preamble': r'\\usepackage{amsmath}',\n",
    "})\n",
    "\n",
    "# Dataset key -> display name used in figure titles\n",
    "dataset_names = dict([\n",
    "    ('plantnet', 'Pl@ntNet-300K'),\n",
    "    ('plantnet-trunc', 'Pl@ntNet-300K (truncated)'),\n",
    "    ('inaturalist', 'iNaturalist-2018'),\n",
    "    ('inaturalist-trunc', 'iNaturalist-2018 (truncated)'),\n",
    "])\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29abda10-239f-4f0c-ad62-48de44bb2458",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Choose dataset to create figures for\n",
    "# dataset = 'plantnet'\n",
    "# dataset = 'plantnet-trunc'\n",
    "# dataset = 'inaturalist'\n",
    "dataset = 'inaturalist-trunc'\n",
    "\n",
    "# Hyperparameter grids; each value is formatted verbatim into a method identifier\n",
    "fuzzy_bandwidths = [1e-16, 1e-12, 1e-8, 1e-6, 0.0001, 0.001, 0.01, .1 , 10, 1000]\n",
    "cvx_weights = [0, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975, 0.99 , 0.999, 1]\n",
    "# Kept as `1 - array` so the formatted float values stay exactly as before\n",
    "monotonic_cvx_weights = 1 - np.array([0, .001, .01, .025, .05, .1, .15, .2, .4, .6, .8, 1])\n",
    "\n",
    "methods = ['standard', 'classwise', 'classwise-exact', 'clustered', 'prevalence-adjusted']\n",
    "methods += [f'fuzzy-{variant}-{bw}'\n",
    "            for variant in ['rarity', 'RErarity', 'READDrarity']\n",
    "            for bw in fuzzy_bandwidths]\n",
    "methods += [f'cvx-cw_weight={w}' for w in cvx_weights]\n",
    "methods += [f'monotonic-cvx-cw_weight={w}' for w in monotonic_cvx_weights]\n",
    "\n",
    "# Miscoverage levels for which results are loaded\n",
    "alphas = [0.2, 0.1, 0.05, 0.01]\n",
    "\n",
    "# Conformal score; appears in result and figure filenames\n",
    "score = 'softmax'\n",
    "\n",
    "# Load in paths from folders.json\n",
    "results_folder = get_outputs_folder()\n",
    "fig_folder = get_figs_folder()\n",
    "\n",
    "# Make sure the per-dataset figure directory exists before anything is saved\n",
    "os.makedirs(f'{fig_folder}/{dataset}', exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0642256",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the directory that figures for the selected dataset are written to\n",
    "f'{fig_folder}/{dataset}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "408730be-6bc9-4e8b-a116-7ed6ca79eae3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load test labels\n",
    "# NOTE(review): absolute, machine-specific path. `get_inputs_folder` is imported\n",
    "# but never used -- confirm whether it should supply this directory instead.\n",
    "test_labels_path = f'/home-warm/plantnet/conformal_cache/train_models/best-{dataset}-model_test_labels.npy'\n",
    "test_labels = np.load(test_labels_path)\n",
    "\n",
    "# Class count is inferred from the largest label id present in the test set\n",
    "num_classes = test_labels.max() + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfb54053-bc3b-4120-a08a-43a17bfea021",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# Load metrics\n",
    "\n",
    "def load_metrics(dataset, alpha, method_name, score='softmax'):\n",
    "    \"\"\"Load the pickled metrics dict for one (dataset, score, alpha, method) run.\n",
    "\n",
    "    The file is read from the notebook-level `results_folder`. Two shortcut\n",
    "    entries, 'median' and 'quantile90', are added to `metrics['set_size_metrics']`\n",
    "    from its '[.25, .5, .75, .9] quantiles' entry.\n",
    "\n",
    "    Returns:\n",
    "        The unpickled metrics dict, including the two added shortcut entries.\n",
    "    \"\"\"\n",
    "    pkl_path = f'{results_folder}/{dataset}_{score}_alpha={alpha}_{method_name}.pkl'\n",
    "    # NOTE: unpickling can execute arbitrary code -- only load result files you produced\n",
    "    with open(pkl_path, 'rb') as f:\n",
    "        metrics = pickle.load(f)\n",
    "\n",
    "    # Extract set size quantiles for easy access later:\n",
    "    # position 1 of the quantile list is the .5 quantile, position 3 the .9 quantile\n",
    "    size_metrics = metrics['set_size_metrics']\n",
    "    quantiles = size_metrics['[.25, .5, .75, .9] quantiles']\n",
    "    size_metrics['median'] = quantiles[1]\n",
    "    size_metrics['quantile90'] = quantiles[3]\n",
    "    return metrics\n",
    "\n",
    "\n",
    "# Nested results: all_res['alpha=<alpha>'][<method name>] -> metrics dict\n",
    "all_res = {}\n",
    "\n",
    "for alpha in alphas:\n",
    "    res = {}\n",
    "    for method in methods:\n",
    "        # Forward the configured `score` explicitly; relying on the default of\n",
    "        # load_metrics would silently load softmax results if `score` is changed.\n",
    "        res[method] = load_metrics(dataset, alpha, method, score=score)\n",
    "    all_res[f'alpha={alpha}'] = res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b57ff38b-6309-49ea-a416-806be27dbf5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_class_cond_decision_accuracy(labels, is_covered, raw_set_sizes):\n",
    "    \"\"\"Per-class accuracy of a decision maker who guesses uniformly from each prediction set.\n",
    "\n",
    "    For a single example, P(choose correct label) is 1/(set size) when the true\n",
    "    label is in the set and 0 otherwise (including when the set is empty).\n",
    "\n",
    "    Args:\n",
    "        labels: non-negative integer class label of each example.\n",
    "        is_covered: per-example indicator that the true label is in the set.\n",
    "        raw_set_sizes: per-example prediction set size.\n",
    "\n",
    "    Returns:\n",
    "        Array of length max(labels) + 1 holding, for each class, the mean\n",
    "        P(choose correct label) over that class's examples. Classes with no\n",
    "        examples in `labels` are NaN.\n",
    "    \"\"\"\n",
    "    labels = np.asarray(labels)\n",
    "    covered = np.asarray(is_covered, dtype=float)\n",
    "    set_sizes = np.asarray(raw_set_sizes, dtype=float)\n",
    "    num_classes = np.max(labels) + 1\n",
    "\n",
    "    # Divide only where the set is non-empty; empty sets keep probability 0.\n",
    "    # (Dividing everywhere and patching NaNs afterwards emits RuntimeWarnings.)\n",
    "    p_correct = np.zeros(labels.shape, dtype=float)\n",
    "    nonempty = set_sizes > 0\n",
    "    p_correct[nonempty] = covered[nonempty] / set_sizes[nonempty]\n",
    "\n",
    "    # Per-class sums and counts in one pass instead of one boolean mask per class\n",
    "    class_counts = np.bincount(labels, minlength=num_classes)\n",
    "    class_totals = np.bincount(labels, weights=p_correct, minlength=num_classes)\n",
    "\n",
    "    # A class absent from `labels` has no defined accuracy: report NaN instead\n",
    "    # of dropping into the debugger, which would stall a Run All.\n",
    "    decision_acc = np.full((num_classes,), np.nan)\n",
    "    has_examples = class_counts > 0\n",
    "    decision_acc[has_examples] = class_totals[has_examples] / class_counts[has_examples]\n",
    "\n",
    "    return decision_acc\n",
    "\n",
    "def compute_class_cond_decision_accuracy_for_method(res, method, labels):\n",
    "    \"\"\"Compute class-conditional decision accuracy from one method's stored coverage metrics.\n",
    "\n",
    "    Pulls 'is_covered' and 'raw_set_sizes' out of `res[method]['coverage_metrics']`\n",
    "    and forwards them, together with `labels`, to\n",
    "    `compute_class_cond_decision_accuracy`.\n",
    "    \"\"\"\n",
    "    coverage_metrics = res[method]['coverage_metrics']\n",
    "    return compute_class_cond_decision_accuracy(\n",
    "        labels,\n",
    "        coverage_metrics['is_covered'],\n",
    "        coverage_metrics['raw_set_sizes'],\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a5b0e0d-bb12-4953-9591-fcad19578b40",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the class-conditional decision accuracy to every loaded metrics dict.\n",
    "# The dicts inside `all_res` are updated in place, one entry per (alpha, method).\n",
    "for res in all_res.values():\n",
    "    for method in methods:\n",
    "        res[method]['class-cond-decision-accuracy'] = compute_class_cond_decision_accuracy_for_method(\n",
    "            res, method, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "885fba4a-09b7-439f-bba6-603d4c09956b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_decision_acc_plot(res, method, color):\n",
    "    \"\"\"Plot per-class decision accuracy for Standard, Classwise and `method`, and save it as PDF.\n",
    "\n",
    "    For each of the three methods this draws the random-guesser accuracy\n",
    "    ('class-cond-decision-accuracy', solid line), the class-conditional\n",
    "    coverage ('raw_class_coverages', dotted line, labelled 'expert verifier'),\n",
    "    and a shaded band between the two. Classes are ordered by decreasing\n",
    "    class-conditional coverage of Standard CP.\n",
    "\n",
    "    The figure is written twice: once without a legend and once, with the\n",
    "    suffix '_WITH_LEGEND', with the legend placed outside the axes.\n",
    "\n",
    "    Reads the notebook globals `fig_folder`, `dataset`, `score` and `dataset_names`.\n",
    "\n",
    "    Args:\n",
    "        res: dict mapping method name to its metrics dict; must contain\n",
    "            'standard', 'classwise' and `method`.\n",
    "        method: name of the method to highlight; also used in the output filename.\n",
    "        color: matplotlib color used for `method`.\n",
    "    \"\"\"\n",
    "    # Sort classes by class cond coverage of Standard CP (highest first)\n",
    "    idx = np.argsort(res['standard']['coverage_metrics']['raw_class_coverages'])[::-1]\n",
    "    # x positions come from the plotted data itself rather than a notebook global\n",
    "    class_positions = np.arange(len(idx))\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(5,2.5))\n",
    "\n",
    "    method_to_name = {'standard': 'Standard',\n",
    "                      'classwise': 'Classwise',\n",
    "                      'fuzzy-RErarity-0.0001': 'Fuzzy',\n",
    "                      'prevalence-adjusted': 'Standard with PAS',\n",
    "                      'cvx-cw_weight=0.99': 'Interp-Q'}\n",
    "\n",
    "    def legend_name(name):\n",
    "        # Methods without a display name fall back to their raw identifier.\n",
    "        # Underscores are escaped when usetex is on, since a bare '_' is an\n",
    "        # error in LaTeX text mode.\n",
    "        if name in method_to_name:\n",
    "            return method_to_name[name]\n",
    "        return name.replace('_', r'\\_') if plt.rcParams['text.usetex'] else name\n",
    "\n",
    "    # Distinct loop names: the `method` argument must stay intact because it\n",
    "    # is used for the output filename after the loop.\n",
    "    for plot_method, plot_color in zip(['standard', 'classwise', method],\n",
    "                                       ['tab:blue', 'tab:red', color]):\n",
    "        # Classwise is drawn underneath the others and with a fainter band\n",
    "        if plot_method == 'classwise':\n",
    "            zorder, band_alpha = 0, 0.1\n",
    "        else:\n",
    "            zorder, band_alpha = 2, 0.5\n",
    "\n",
    "        name = legend_name(plot_method)\n",
    "        random_acc = res[plot_method]['class-cond-decision-accuracy'][idx]\n",
    "        expert_acc = res[plot_method]['coverage_metrics']['raw_class_coverages'][idx]\n",
    "\n",
    "        ax.plot(random_acc, color=plot_color, alpha=0.7, zorder=zorder,\n",
    "                label=f'{name}, random guesser')\n",
    "        ax.plot(expert_acc, color=plot_color, alpha=0.7, linestyle=':',\n",
    "                zorder=zorder, label=f'{name}, expert verifier')\n",
    "        ax.fill_between(class_positions, y1=random_acc, y2=expert_acc,\n",
    "                        label=f'{name}, verifier-guesser mix',\n",
    "                        color=plot_color, alpha=band_alpha)\n",
    "\n",
    "    # Axis decoration does not depend on the method, so it is set once\n",
    "    ax.set_xlim(0, len(idx)-1)\n",
    "    ax.set_ylabel('Decision accuracy')\n",
    "    ax.set_xlabel('Class (sorted by $\\\\hat{c}_y$ of Standard CP)')\n",
    "    ax.set_title(dataset_names[dataset])\n",
    "\n",
    "    fig_path = f'{fig_folder}/{dataset}/{dataset}_{score}_{method}_decision-acc.pdf'\n",
    "\n",
    "    # Save without the legend first, then add the legend and save a second copy\n",
    "    fig.savefig(fig_path, bbox_inches='tight')\n",
    "    print('Saved plot to', fig_path)\n",
    "    ax.legend(loc='upper left', bbox_to_anchor=(1,1.05), fontsize=11)\n",
    "    new_path = fig_path.replace('.pdf', '_WITH_LEGEND.pdf')\n",
    "    fig.savefig(new_path, bbox_inches='tight')\n",
    "    print('Saved plot to', new_path)\n",
    "\n",
    "# Plot for alpha = 0.1\n",
    "res = all_res['alpha=0.1']\n",
    "\n",
    "# One figure per highlighted method, each with its own color\n",
    "for method, color in [('fuzzy-RErarity-0.0001', 'tab:purple'),\n",
    "                      ('prevalence-adjusted', 'tab:green'),\n",
    "                      ('cvx-cw_weight=0.99', 'dodgerblue')]:\n",
    "    make_decision_acc_plot(res, method, color)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2fd1b7c-39a8-4797-84ee-d28bfe25d603",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot for alpha = 0.1\n",
    "res = all_res['alpha=0.1']\n",
    "\n",
    "# Sort classes by decision accuracy of Standard CP\n",
    "idx = np.argsort(res['standard']['class-cond-decision-accuracy'])\n",
    "\n",
    "# Draw on a dedicated figure: bare plt.plot() targets whatever figure pyplot\n",
    "# considers current, which may be a figure left open by an earlier cell.\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "for method in ['standard', 'classwise', 'fuzzy-RErarity-0.0001', 'prevalence-adjusted']:\n",
    "    ax.plot(res[method]['class-cond-decision-accuracy'][idx], 'o', markersize=2, label=method, alpha=0.6)\n",
    "\n",
    "ax.set_ylabel('Class-conditional decision accuracy')\n",
    "ax.set_xlabel('Class (sorted by decision accuracy of Standard CP)')\n",
    "ax.set_title('Random decision maker')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb547ee7-e11e-41ce-993d-8b330b9531d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot for alpha = 0.1\n",
    "res = all_res['alpha=0.1']\n",
    "\n",
    "# Sort classes by class cond coverage of Standard CP\n",
    "idx = np.argsort(res['standard']['coverage_metrics']['raw_class_coverages'])\n",
    "\n",
    "# Draw on a dedicated figure: bare plt.plot() targets whatever figure pyplot\n",
    "# considers current, which may be a figure left open by an earlier cell.\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "for method in ['standard', 'classwise', 'fuzzy-RErarity-0.0001', 'prevalence-adjusted']:\n",
    "    ax.plot(res[method]['coverage_metrics']['raw_class_coverages'][idx], 'o', markersize=2,\n",
    "            label=method, alpha=0.6)\n",
    "\n",
    "ax.set_ylabel('Class-conditional coverage')\n",
    "ax.set_xlabel('Class (sorted by class-cond cov of Standard CP)')\n",
    "# A title lets the figure stand alone when the notebook is skimmed\n",
    "ax.set_title('Class-conditional coverage by method')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b11775e-c23a-47d9-8184-b1d04dc35922",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Macro decision accuracy\n",
    "\n",
    "# Plot for alpha = 0.1\n",
    "res = all_res['alpha=0.1']\n",
    "\n",
    "p_rands = np.linspace(0,1,11) # P(random decision maker)\n",
    "\n",
    "# Draw on a dedicated figure: bare plt.plot() targets whatever figure pyplot\n",
    "# considers current, which may be a figure left open by an earlier cell.\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "for method in ['standard', 'classwise', 'fuzzy-RErarity-0.0001', 'prevalence-adjusted']:\n",
    "    # The per-class accuracies do not depend on p_rand, so look them up once per method\n",
    "    acc_random = res[method]['class-cond-decision-accuracy']\n",
    "    acc_discerning = res[method]['coverage_metrics']['raw_class_coverages']\n",
    "    # Average over classes of the p_rand-weighted mix of the two decision makers\n",
    "    decision_accs = np.array([np.mean(p_rand * acc_random + (1-p_rand) * acc_discerning)\n",
    "                              for p_rand in p_rands])\n",
    "    ax.plot(p_rands, decision_accs, label=method)\n",
    "\n",
    "ax.set_xlabel('P(random)')\n",
    "ax.set_ylabel('Macro decision accuracy')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47d6c821-bdf6-4dfe-a0fc-d60b03a59023",
   "metadata": {},
   "outputs": [],
   "source": [
    "# res.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6dec2cc-5eaa-4a88-843c-e4570a4ea813",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch-oban",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
