{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b10076ec-22bc-4d8a-bb6d-b408af13795f",
   "metadata": {},
   "source": [
    "This notebook generates the Pareto plots in the paper using the saved results obtained by running `run_get_results.sh`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5b18fc4-d576-4149-94c0-559ac84aba25",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys; sys.path.append(\"../\") # For relative imports\n",
    "\n",
    "import glob\n",
    "import os\n",
    "import pandas as pd\n",
    "import pickle\n",
    "\n",
    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset\n",
    "from matplotlib.ticker import LogFormatter, LogLocator, MultipleLocator\n",
    "\n",
    "from utils.conformal_utils import *\n",
    "from utils.experiment_utils import get_inputs_folder, get_outputs_folder, get_figs_folder\n",
    "\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "plt.rcParams.update({\n",
    "    'font.size': 16,        # base font size\n",
    "    'axes.titlesize': 18,   # subplot titles\n",
    "    'axes.labelsize': 16,   # x/y labels\n",
    "    'legend.fontsize': 16,  # legend text\n",
    "    'xtick.labelsize': 16,  # tick labels\n",
    "    'ytick.labelsize': 16,\n",
    "\n",
    "})\n",
    "# use tex with matplotlib\n",
    "plt.rc('text', usetex=True)\n",
    "plt.rc('font', family='serif')\n",
    "plt.rcParams['text.latex.preamble'] = r'\\usepackage{amsmath}'\n",
    "\n",
    "dataset_names = {\n",
    "    \"plantnet\": \"Pl@ntNet-300K\",\n",
    "    \"plantnet-trunc\": \"Pl@ntNet-300K (truncated)\",\n",
    "    \"inaturalist\": \"iNaturalist-2018\",\n",
    "    \"inaturalist-trunc\": \"iNaturalist-2018 (truncated)\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8f3f8e5-6296-4749-be6f-fd9b08f98cef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load in paths from folders.json\n",
    "inputs_folder = get_inputs_folder()\n",
    "results_folder = get_outputs_folder()\n",
    "fig_folder = get_figs_folder()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d92c49a",
   "metadata": {},
   "outputs": [],
   "source": [
    "    cw_weights = [0, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975, 0.99 , 0.999, 1]\n",
    "    rarity_bandwidths = [1e-30, 1e-15, 1e-10, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 10, 1000]\n",
    "    random_bandwidths = [1e-30, 1e-15, 1e-10, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 10, 1000]\n",
    "    quantile_bandwidths = [1e-30, 1e-15, 1e-10, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 10, 1000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba858d8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "rarity_bandwidths = [1e-30, 1e-15, 1e-10, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 10, 1000]\n",
    "random_bandwidths = [1e-30, 1e-15, 1e-10, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 10, 1000]\n",
    "quantile_bandwidths = [1e-30, 1e-15, 1e-10, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 10, 1000]\n",
    "cw_weights = [0, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975, 0.99 , 0.999, 1] # CHANGED!\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b61581df-9429-4472-a437-4054302f9c0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_train_weighted_average_set_size(dataset, metrics, train_class_distr, test_labels):\n",
    "    num_classes = np.max(test_labels) + 1\n",
    "    \n",
    "    # Get average set size by class\n",
    "    set_sizes = metrics['coverage_metrics']['raw_set_sizes']\n",
    "    avg_size_by_class = np.array([np.mean(set_sizes[test_labels == k]) for k in range(num_classes)])\n",
    "\n",
    "    return np.sum(train_class_distr * avg_size_by_class)\n",
    "\n",
    "def load_one_result(dataset, alpha, method_name, score='softmax',\n",
    "                train_class_distr=None, test_labels=None):\n",
    "    \n",
    "    with open(f'{results_folder}/{dataset}_{score}_alpha={alpha}_{method_name}.pkl', 'rb') as f:\n",
    "        metrics = pickle.load(f)\n",
    "\n",
    "    # Compute train-weighted average set size\n",
    "    # Compute average set size by class, then weight\n",
    "    if (train_class_distr is not None) and (test_labels is not None):\n",
    "        metrics['set_size_metrics']['train_mean'] = compute_train_weighted_average_set_size(dataset, \n",
    "                                                                                            metrics, \n",
    "                                                                                            train_class_distr, \n",
    "                                                                                            test_labels)\n",
    "    \n",
    "    return metrics\n",
    "\n",
    "def load_all_results(dataset, alphas, methods, score='softmax'):\n",
    "    # For truncated datasets, we need to load these in to compute train-weighted average set size\n",
    "    if dataset.endswith('-trunc'): \n",
    "        train_labels_path = f'{inputs_folder}/{dataset}_train_labels.npy'\n",
    "        train_labels = np.load(train_labels_path)\n",
    "        num_classes = np.max(train_labels) + 1\n",
    "        train_class_distr = np.array([np.sum(train_labels == k) for k in range(num_classes)]) / len(train_labels) \n",
    "\n",
    "        test_labels = test_labels = np.load(f'{inputs_folder}/best-{dataset}-model_test_labels.npy')\n",
    "        \n",
    "    all_res = {}\n",
    "    for alpha in alphas:\n",
    "        res = {}\n",
    "        for method in methods:\n",
    "            if dataset.endswith('-trunc'): # Compute train-weighted average set size\n",
    "                res[method] = load_one_result(dataset, alpha, method, score=score,\n",
    "                                           train_class_distr=train_class_distr, test_labels=test_labels)\n",
    "            else:\n",
    "                res[method] = load_one_result(dataset, alpha, method, score=score)\n",
    "        all_res[f'alpha={alpha}'] = res\n",
    "\n",
    "    return all_res\n",
    "\n",
    "def plot_set_size_vs_cov_metric(\n",
    "    all_res,\n",
    "    coverage_metric,\n",
    "    alphas=None,\n",
    "    set_size_metric='mean',\n",
    "    ax=None,\n",
    "    show_legend=True,\n",
    "    markersizes=None,\n",
    "    label_prefix='',\n",
    "    add_inset=True,\n",
    "    inset_loc='upper right',\n",
    "    inset_lims=(0.1, 0.3, 1, 15),\n",
    "    inset_width=\"50%\",\n",
    "    inset_pad=0\n",
    "):\n",
    "    # --- prepare main axis ---\n",
    "    if ax is None:\n",
    "        fig, ax = plt.subplots(figsize=(3.9, 2.3))\n",
    "\n",
    "    # --- determine alphas ---\n",
    "    if alphas is None:\n",
    "        alphas = sorted(\n",
    "            float(k.split('=')[1]) for k in all_res if k.startswith('alpha=')\n",
    "        )\n",
    "\n",
    "    # --- markersizes ---\n",
    "    n = len(alphas)\n",
    "    if markersizes is None:\n",
    "        markersizes = [5, 6, 7, 8]\n",
    "\n",
    "    # --- fixed definitions ---\n",
    "    # cw_weights = 1 - np.array([0, .001, .01, .025, .05, .1, .15, .2, .4, .6, .8, 1])\n",
    "\n",
    "    core_methods = [\n",
    "        ('standard',             'Standard',       'blue',      'X'),\n",
    "        ('classwise',            'Classwise',      'red',       'X'),\n",
    "        ('classwise-exact',      'Exact Classwise','magenta',  'x'),\n",
    "        ('clustered',            'Clustered',      'limegreen', 'x'),\n",
    "        ('prevalence-adjusted',  'Standard w. PAS','green',    '^'),\n",
    "        ('standard-softmax',     'Standard w. softmax', 'blue',  'x'),\n",
    "    ]\n",
    "\n",
    "    # --- plot main curves ---\n",
    "    for alpha, ms in zip(alphas, markersizes):\n",
    "        res = all_res.get(f'alpha={alpha}', {})\n",
    "        suffix = f'{label_prefix}, $\\\\alpha={alpha}$'\n",
    "\n",
    "        # core scatter methods (only if present)\n",
    "        for key, label, color, mk in core_methods:\n",
    "            if key in res:\n",
    "                ax.plot(\n",
    "                    res[key]['coverage_metrics'][coverage_metric],\n",
    "                    res[key]['set_size_metrics'][set_size_metric],\n",
    "                    linestyle='', marker=mk, color=color,\n",
    "                    markersize=ms, alpha=0.8,\n",
    "                    label=(label + suffix),\n",
    "                    zorder=10\n",
    "                )\n",
    "\n",
    "        # convex interpolation (only weights that exist)\n",
    "        valid_w = [w for w in cw_weights if f'cvx-cw_weight={w}' in res]\n",
    "        if valid_w:\n",
    "            x_cvx = [res[f'cvx-cw_weight={w}']['coverage_metrics'][coverage_metric] for w in valid_w]\n",
    "            y_cvx = [res[f'cvx-cw_weight={w}']['set_size_metrics'][set_size_metric]     for w in valid_w]\n",
    "            ax.plot(\n",
    "                x_cvx, y_cvx,\n",
    "                '-o', color='dodgerblue', markersize=ms,\n",
    "                alpha=0.5,\n",
    "                label=('Interp-Q' + suffix)\n",
    "            )\n",
    "\n",
    "        # fuzzy rarity projection \n",
    "        for tag, mk_sym, alpha_val, label in [\n",
    "            ## Uncomment to use short name for Fuzzy\n",
    "            ('fuzzy-rarity',   'o', 0.3, 'Raw Fuzzy'),\n",
    "            ('fuzzy-RErarity', '^', 0.5, 'Fuzzy'),\n",
    "            # # Uncomment to use long name for Fuzzy\n",
    "            # ('fuzzy-rarity',   'o', 0.3, 'Raw Fuzzy-$\\\\Pi_{\\\\mathrm{prevalence}}$'),\n",
    "            # ('fuzzy-RErarity', '^', 0.5, 'Fuzzy-$\\\\Pi_{\\\\mathrm{prevalence}}$'),\n",
    "        ]:\n",
    "            valid_b = [b for b in rarity_bandwidths if f'{tag}-{b}' in res]\n",
    "            if valid_b:\n",
    "                xs = [res[f'{tag}-{b}']['coverage_metrics'][coverage_metric] for b in valid_b]\n",
    "                ys = [res[f'{tag}-{b}']['set_size_metrics'][set_size_metric]    for b in valid_b]\n",
    "                ax.plot(\n",
    "                    xs, ys,\n",
    "                    '-' + mk_sym, color='salmon', markersize=ms,\n",
    "                    alpha=alpha_val,\n",
    "                    label=(label + suffix)\n",
    "                )\n",
    "\n",
    "        # fuzzy random projection \n",
    "        for tag, mk_sym, alpha_val, label in [\n",
    "            ('fuzzy-random',   'o', 0.3, 'Raw Fuzzy-$\\\\Pi_{\\\\mathrm{random}}$'),\n",
    "            ('fuzzy-RErandom', '^', 0.5, 'Fuzzy-$\\\\Pi_{\\\\mathrm{random}}$'),\n",
    "        ]:\n",
    "            valid_b = [b for b in rarity_bandwidths if f'{tag}-{b}' in res]\n",
    "            if valid_b:\n",
    "                xs = [res[f'{tag}-{b}']['coverage_metrics'][coverage_metric] for b in valid_b]\n",
    "                ys = [res[f'{tag}-{b}']['set_size_metrics'][set_size_metric]    for b in valid_b]\n",
    "                ax.plot(\n",
    "                    xs, ys,\n",
    "                    '-' + mk_sym, color='tab:purple', markersize=ms,\n",
    "                    alpha=alpha_val,\n",
    "                    label=(label + suffix)\n",
    "                )\n",
    "                \n",
    "        # fuzzy quantile projection \n",
    "        for tag, mk_sym, alpha_val, label in [\n",
    "            ('fuzzy-quantile',   'o', 0.3, 'Raw Fuzzy-$\\\\Pi_{\\\\mathrm{quantile}}$'),\n",
    "            ('fuzzy-REquantile', '^', 0.5, 'Fuzzy-$\\\\Pi_{\\\\mathrm{quantile}}$'),\n",
    "        ]:\n",
    "            valid_b = [b for b in rarity_bandwidths if f'{tag}-{b}' in res]\n",
    "            if valid_b:\n",
    "                xs = [res[f'{tag}-{b}']['coverage_metrics'][coverage_metric] for b in valid_b]\n",
    "                ys = [res[f'{tag}-{b}']['set_size_metrics'][set_size_metric]    for b in valid_b]\n",
    "                ax.plot(\n",
    "                    xs, ys,\n",
    "                    '-' + mk_sym, color='gold', markersize=ms,\n",
    "                    alpha=alpha_val,\n",
    "                    label=(label + suffix)\n",
    "                )\n",
    "\n",
    "\n",
    "    # --- style main axis ---\n",
    "    ax.set_yscale('log')\n",
    "    ax.spines[['right', 'top']].set_visible(False)\n",
    "\n",
    "    # --- inset ---\n",
    "    if add_inset:\n",
    "        xmin, xmax, ymin, ymax = inset_lims\n",
    "        axins = inset_axes(\n",
    "            ax, width=inset_width, height=\"40%\", loc=inset_loc, borderpad=inset_pad\n",
    "        )\n",
    "        axins.patch.set_alpha(0.7)\n",
    "\n",
    "        # replot all alphas for all methods, with the same presence checks\n",
    "        for alpha, ms in zip(alphas, 0.5 * np.array(markersizes)):\n",
    "            res = all_res.get(f'alpha={alpha}', {})\n",
    "\n",
    "            # core\n",
    "            for key, _, color, mk in core_methods:\n",
    "                if key in res:\n",
    "                    axins.plot(\n",
    "                        res[key]['coverage_metrics'][coverage_metric],\n",
    "                        res[key]['set_size_metrics'][set_size_metric],\n",
    "                        linestyle='', marker=mk, color=color,\n",
    "                        markersize=ms, alpha=0.8, zorder=10\n",
    "                    )\n",
    "\n",
    "            # convex\n",
    "            valid_w = [w for w in cw_weights if f'cvx-cw_weight={w}' in res]\n",
    "            if valid_w:\n",
    "                axins.plot(\n",
    "                    [res[f'cvx-cw_weight={w}']['coverage_metrics'][coverage_metric] for w in valid_w],\n",
    "                    [res[f'cvx-cw_weight={w}']['set_size_metrics'][set_size_metric]    for w in valid_w],\n",
    "                    '-o', color='dodgerblue', markersize=ms, alpha=0.5\n",
    "                )\n",
    "\n",
    "            # fuzzy\n",
    "            for tag, mk_sym, alpha_val in [('fuzzy-rarity','o',0.3),('fuzzy-RErarity','^',0.5)]:\n",
    "                valid_b = [b for b in rarity_bandwidths if f'{tag}-{b}' in res]\n",
    "                if valid_b:\n",
    "                    axins.plot(\n",
    "                        [res[f'{tag}-{b}']['coverage_metrics'][coverage_metric] for b in valid_b],\n",
    "                        [res[f'{tag}-{b}']['set_size_metrics'][set_size_metric]    for b in valid_b],\n",
    "                        '-' + mk_sym, color='salmon', markersize=ms, alpha=alpha_val\n",
    "                    )\n",
    "\n",
    "        # zoom & ticks\n",
    "        axins.set_xlim(xmin, xmax)\n",
    "        axins.set_ylim(ymin, ymax)\n",
    "        yticks = np.arange(int(np.ceil(ymin)), int(ymax) + 1)\n",
    "        axins.set_yticks(yticks)\n",
    "        axins.set_yticklabels([str(int(y)) if (int(y) % 2 == 0) else '' for y in yticks])\n",
    "        for spine in axins.spines.values():\n",
    "            spine.set_edgecolor('grey')\n",
    "        axins.tick_params(axis='both', colors='grey', labelsize=8, pad=1)\n",
    "\n",
    "        if inset_loc == 'upper left':\n",
    "            mark_inset(ax, axins, loc1=1, loc2=3, fc=\"none\", ec=(0.1,0.1,0.1,0.2))\n",
    "        else:\n",
    "            mark_inset(ax, axins, loc1=2, loc2=4, fc=\"none\", ec=(0.1,0.1,0.1,0.2))\n",
    "\n",
    "    # --- legend & layout ---\n",
    "    if show_legend:\n",
    "        ax.legend(fontsize=6, bbox_to_anchor=(1, 1))\n",
    "    plt.tight_layout(pad=1.3)\n",
    "\n",
    "# Specify where to zoom in \n",
    "def get_inset_lims(dataset, coverage_metric):\n",
    "    xlims = {'plantnet': {'cov_below50': (0.1, 0.3), \n",
    "                          'undercov_gap': (0.1, 0.38),\n",
    "                          'macro_cov': (0.6, 0.9)},\n",
    "             'plantnet-trunc': {'cov_below50': (0.01, 0.17), \n",
    "                                'undercov_gap': (0.02, 0.2),\n",
    "                                'macro_cov': (0.77, 0.87)},\n",
    "             'inaturalist': {'cov_below50': (0.025, 0.16), \n",
    "                            'undercov_gap': (0.05, 0.175),\n",
    "                            'macro_cov': (0.71, 0.9)},\n",
    "             'inaturalist-trunc': {'cov_below50': (0.01, 0.07), \n",
    "                                   'undercov_gap': (0.03, 0.11),\n",
    "                                   'macro_cov': (0.75, 0.92)}\n",
    "            }\n",
    "    ylims = {'plantnet': (0.9, 8), \n",
    "             'plantnet-trunc': (0.9, 5),\n",
    "             'inaturalist': (2, 15),\n",
    "             'inaturalist-trunc': (1.1, 13)}\n",
    "\n",
    "    return *xlims[dataset][coverage_metric], *ylims[dataset]\n",
    "\n",
    "\n",
    "def generate_all_pareto_plots(dataset, score, alphas, methods, show_grid=False,\n",
    "                              save_suffix='', show_inset=None, legendfontsize=14):\n",
    "    \n",
    "    # Load pre-computed metrics\n",
    "    all_res = load_all_results(dataset, alphas, methods, score=score)\n",
    "\n",
    "    if score == 'PAS':\n",
    "        std_with_softmax = load_all_results(dataset, alphas, ['standard'], score='softmax')\n",
    "        for alpha in alphas:\n",
    "            all_res[f'{alpha=}']['standard-softmax'] = std_with_softmax[f'{alpha=}']['standard']\n",
    "\n",
    "    # Make plots\n",
    "    set_size_metric = 'train_mean' if dataset.endswith('-trunc') else 'mean'\n",
    "    print(f'{set_size_metric=}')\n",
    "\n",
    "    cov_metrics = ['cov_below50', 'undercov_gap', 'macro_cov',\n",
    "               'train_marginal_cov']\n",
    "    cov_metric_names = ['FracBelow50$\\\\%$', \n",
    "                       'UnderCovGap',\n",
    "                        'MacroCov',\n",
    "                       'MarginalCov']\n",
    "    fig, axes = plt.subplots(1, len(cov_metrics), figsize=(13, 3), sharey=True)\n",
    "    axes = axes.flatten()\n",
    "\n",
    "    for i, (coverage_metric, xlabel) in enumerate(zip(cov_metrics, cov_metric_names)):\n",
    "        ax = axes[i]\n",
    "\n",
    "        if (not show_inset) or (coverage_metric == 'train_marginal_cov') or (score == 'PAS'):\n",
    "            add_inset = False\n",
    "        else:\n",
    "            add_inset = True \n",
    "   \n",
    "        if add_inset:\n",
    "            inset_lims = get_inset_lims(dataset, coverage_metric)\n",
    "        else:\n",
    "            inset_lims = None\n",
    "    \n",
    "        if coverage_metric == 'macro_cov':\n",
    "            inset_loc = 'upper left'\n",
    "            inset_pad = 1\n",
    "        else:\n",
    "            inset_loc = 'upper right'\n",
    "            inset_pad = 1\n",
    "    \n",
    "        if coverage_metric == 'train_marginal_cov':\n",
    "            for a in alphas:\n",
    "                ax.axvline(1-a, linestyle='--', color='grey')\n",
    "                \n",
    "        plot_set_size_vs_cov_metric(\n",
    "            all_res,\n",
    "            coverage_metric=coverage_metric,\n",
    "            set_size_metric=set_size_metric,\n",
    "            alphas=alphas,\n",
    "            ax=ax,\n",
    "            add_inset=add_inset,\n",
    "            inset_loc=inset_loc,\n",
    "            inset_lims=inset_lims,\n",
    "            inset_width=\"50%\",\n",
    "            inset_pad=inset_pad,\n",
    "            show_legend=False\n",
    "        )\n",
    "        \n",
    "        ax.set_xlabel(xlabel)\n",
    "\n",
    "    # if dataset == 'plantnet':\n",
    "    #     axes[0].set_xlim(-.02, .61)\n",
    "    #     axes[1].set_xlim(-0.1, 0.5)\n",
    "    #     axes[2].set_xlim(.37, 1.1)\n",
    "    #     axes[3].set_xlim(.77, 1.01)\n",
    "\n",
    "    # Optionally plot transparent grid\n",
    "    if show_grid:\n",
    "        for ax in axes:\n",
    "            ax.grid(True, alpha=0.2)\n",
    "\n",
    "    if len(alphas) == 1:\n",
    "        axes[3].set_xlim(1-alphas[0]-.05, 1-alphas[0]+.05)\n",
    "    \n",
    "    axes[0].set_ylabel('Average set size')\n",
    "    plt.suptitle(dataset_names[dataset], y=1)\n",
    "    \n",
    "    fig_path = f'{fig_folder}/{dataset}/ALL_metrics_{dataset}_{score}{save_suffix}_pareto_NO_LEGEND.pdf'\n",
    "    os.makedirs(f'{fig_folder}/{dataset}', exist_ok=True)\n",
    "    \n",
    "    # Save version without legend\n",
    "    plt.savefig(fig_path)\n",
    "    \n",
    "    # # Save version with legend\n",
    "    if len(alphas) > 1:\n",
    "        axes[0].legend(ncols=len(alphas), loc='upper left', bbox_to_anchor=(-0.25,-0.35), fontsize=legendfontsize)\n",
    "    else:\n",
    "        axes[0].legend(ncols=3, loc='upper left', bbox_to_anchor=(0.5,-0.35), fontsize=legendfontsize)\n",
    "\n",
    "    plt.savefig(fig_path.replace('NO_LEGEND.pdf', 'WITH_LEGEND.pdf'), bbox_inches='tight')\n",
    "    plt.savefig(fig_path.replace('NO_LEGEND.pdf', 'WITH_LEGEND.jpg'), bbox_inches='tight') # Also save as jpg\n",
    "    \n",
    "    print('Saved no-legend version to', fig_path, 'and version with legend to [...]WITH_LEGEND.pdf and .jpg' )\n",
    "    \n",
    "    plt.show()\n",
    "\n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2413a580-64cc-4216-afe9-7943090b65a3",
   "metadata": {},
   "source": [
    "## Main Pareto plots\n",
    "\n",
    "### (a) Simple version for main paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6482131-eb0c-4191-bd1c-1510138352bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "alphas = [0.1] # Use alpha = 0.1 only\n",
    "methods = ['standard', 'classwise', 'prevalence-adjusted'] + \\\n",
    "            [f'fuzzy-rarity-{bw}' for bw in rarity_bandwidths] +\\\n",
    "            [f'fuzzy-RErarity-{bw}' for bw in rarity_bandwidths] +\\\n",
    "            [f'cvx-cw_weight={w}' for w in cw_weights] \n",
    "score = 'softmax'\n",
    "\n",
    "for dataset in ['plantnet', 'inaturalist']:\n",
    "    generate_all_pareto_plots(dataset, score, alphas, methods, save_suffix='_alpha=0.1', legendfontsize=14, show_grid=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1ac23ee-83d6-4ab1-893e-598fac1eac79",
   "metadata": {},
   "source": [
    "### (b) Full version for Appendix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c719bcf-8f49-4a0c-b350-8ee2201ad7d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "alphas = [0.2, 0.1, 0.05, 0.01]\n",
    "score = 'softmax'\n",
    "\n",
    "# rarity_bandwidths = [1e-15, 1e-10, 1e-5, 0.0001, 0.001, 0.01, .1 , 10, 1000]\n",
    "cw_weights = [0, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975, 0.99 , 0.999, 1] # CHANGED!\n",
    "# cw_weights = 1 - np.array([0, .001, .01, .025, .05, .1, .2, .4, .8, 1])\n",
    "methods = ['standard', 'classwise', 'classwise-exact', 'clustered', 'prevalence-adjusted'] + \\\n",
    "            [f'fuzzy-rarity-{bw}' for bw in rarity_bandwidths] +\\\n",
    "            [f'fuzzy-RErarity-{bw}' for bw in rarity_bandwidths] +\\\n",
    "            [f'cvx-cw_weight={w}' for w in cw_weights] \n",
    "\n",
    "for dataset in dataset_names.keys():\n",
    "# for dataset in ['plantnet-trunc']:\n",
    "    generate_all_pareto_plots(dataset, score, alphas, methods, show_inset=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1dc91b4c-1821-4e43-904a-0c25043843f3",
   "metadata": {},
   "source": [
    "## Generate csv's of metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26b3ec6c-c0d7-4864-b171-035f08a7ddbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract metrics into csv\n",
    "\n",
    "def extract_metric_table(all_res, coverage_metrics, set_size_metric='mean'):\n",
    "    records = []\n",
    "    for alpha_key, methods in all_res.items():\n",
    "        alpha_val = float(alpha_key.split('=')[1])\n",
    "        for method, result in methods.items():\n",
    "            row = {'alpha': alpha_val, 'method': method}\n",
    "            for cov_metric in coverage_metrics:\n",
    "                cov_val = result['coverage_metrics'].get(cov_metric, None)\n",
    "                row[f'{cov_metric}'] = cov_val\n",
    "            size_val = result['set_size_metrics'].get(set_size_metric, None)\n",
    "            row[f'set_size_{set_size_metric}'] = size_val\n",
    "            records.append(row)\n",
    "    return pd.DataFrame(records)\n",
    "\n",
    "\n",
    "score = 'softmax'\n",
    "alphas = [0.2, 0.1, 0.05, 0.01]\n",
    "\n",
    "for dataset in dataset_names.keys():\n",
    "    set_size_metric = 'train_mean' if dataset.endswith('-trunc') else 'mean'\n",
    "    \n",
    "    # Load the results\n",
    "    all_res = load_all_results(dataset, alphas, methods, score=score)\n",
    "    \n",
    "    # Extract into DataFrame\n",
    "    df = extract_metric_table(\n",
    "        all_res,\n",
    "        coverage_metrics=['cov_below50', 'undercov_gap', 'macro_cov', 'train_marginal_cov'],\n",
    "        set_size_metric=set_size_metric\n",
    "    )\n",
    "    \n",
    "    # Save full CSV\n",
    "    pth = f'{fig_folder}/{dataset}/{dataset}_{score}_metrics.csv'\n",
    "    df.to_csv(pth, index=False)\n",
    "    print(f'Saved csv of metrics to {pth}')\n",
    "\n",
    "    ## --- Clean csv and save that ---\n",
    "    selected_alpha = 0.1\n",
    "    selected_methods = {'standard': 'Standard',\n",
    "                       'classwise': 'Classwise',\n",
    "                       'clustered': 'Clustered',\n",
    "                       'prevalence-adjusted': 'Standard w. PAS',\n",
    "                       'cvx-cw_weight=0.999': 'Interp-Q ($\\\\tau=0.999$)',\n",
    "                       'cvx-cw_weight=0.99': 'Interp-Q ($\\\\tau=0.99$)',\n",
    "                       'cvx-cw_weight=0.9': 'Interp-Q ($\\\\tau=0.9$)',\n",
    "                       'fuzzy-rarity-1e-05': 'Raw Fuzzy ($\\\\sigma=0.00001$)',\n",
    "                       'fuzzy-rarity-0.01': 'Raw Fuzzy ($\\\\sigma=0.01$)',\n",
    "                       'fuzzy-rarity-0.1': 'Raw Fuzzy ($\\\\sigma=0.1$)',\n",
    "                       'fuzzy-RErarity-1e-05': 'Fuzzy ($\\\\sigma=0.00001$)',\n",
    "                       'fuzzy-RErarity-0.01': 'Fuzzy ($\\\\sigma=0.01$)',\n",
    "                       'fuzzy-RErarity-0.1': 'Fuzzy ($\\\\sigma=0.1$)'}\n",
    "    # Filter and rename methods\n",
    "    df = df[df['alpha'] == selected_alpha] # Select alpha = 0.1 only\n",
    "    df = df[df[\"method\"].isin(selected_methods.keys())].copy()\n",
    "    df[\"method\"] = df[\"method\"].map(selected_methods)\n",
    "\n",
    "    # Remove alpha column\n",
    "    df = df.drop('alpha', axis=1)\n",
    "\n",
    "    # Rename columns\n",
    "    df.rename(columns={'method': 'Method', \n",
    "                       'cov_below50': 'FracBelow50% $\\\\downarrow$',\n",
    "                       'undercov_gap': 'UnderCovGap $\\\\downarrow$', \n",
    "                       'macro_cov': 'MacroCov $\\\\uparrow$',\n",
    "                       'train_marginal_cov': 'MarginalCov (desired = 0.9)', \n",
    "                       'set_size_train_mean': 'Avg. set size $\\\\downarrow$',\n",
    "                       'set_size_mean': 'Avg. set size'}, inplace=True)\n",
    "\n",
    "    \n",
    "    # Round numeric values to 3 significant digits\n",
    "    def round_sig(x, sig=3):\n",
    "        if isinstance(x, (int, float, np.float64)):\n",
    "            return float(f\"{x:.{sig}g}\")\n",
    "        return x\n",
    "    df = df.map(round_sig)\n",
    "\n",
    "    # Save\n",
    "    pth = f'{fig_folder}/{dataset}/{dataset}_{score}_alpha={selected_alpha}_metrics_CLEANED.csv'\n",
    "    df.to_csv(pth, index=False)\n",
    "    print(f'Saved cleaned csv of metrics for alpha={selected_alpha} to {pth}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9618348f-55b0-4929-82c6-328096d7235e",
   "metadata": {},
   "outputs": [],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db85e3e9-e92a-476f-bda0-876fa9130a1c",
   "metadata": {},
   "source": [
    "## Appendix plots\n",
    "\n",
    "### Fuzzy and Interp-Q with PAS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1736a72-5672-46fe-9035-a94ce0bc4278",
   "metadata": {},
   "outputs": [],
   "source": [
    "rarity_bandwidths = ['1e-15', '1e-10', '1e-05', '0.0001']\n",
    "\n",
    "alphas = [0.1, 0.2]\n",
    "score = 'PAS'\n",
    "methods = ['standard', 'classwise', 'classwise-exact', 'clustered'] + \\\n",
    "            [f'fuzzy-rarity-{bw}' for bw in rarity_bandwidths] +\\\n",
    "            [f'fuzzy-RErarity-{bw}' for bw in rarity_bandwidths] +\\\n",
    "            [f'cvx-cw_weight={w}' for w in cw_weights] \n",
    "\n",
    "\n",
    "for dataset in dataset_names.keys():\n",
    "    generate_all_pareto_plots(dataset, score, alphas, methods, legendfontsize=13)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01eee432",
   "metadata": {},
   "outputs": [],
   "source": [
    "rarity_bandwidths\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26562edf-2975-45eb-bf98-806f7b1dfe0f",
   "metadata": {},
   "source": [
    "### Fuzzy with Random and Quantile projections"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53b228b8-f5c6-4b42-8b33-aeaf7763319e",
   "metadata": {},
   "outputs": [],
   "source": [
    "bandwidths = rarity_bandwidths\n",
    "score = 'softmax'\n",
    "methods = ['standard', 'classwise'] + \\\n",
    "        [f'fuzzy-rarity-{bw}' for bw in bandwidths] + \\\n",
    "        [f'fuzzy-RErarity-{bw}' for bw in bandwidths] + \\\n",
    "        [f'fuzzy-random-{bw}' for bw in bandwidths] + \\\n",
    "        [f'fuzzy-RErandom-{bw}' for bw in bandwidths] + \\\n",
    "        [f'fuzzy-quantile-{bw}' for bw in bandwidths] + \\\n",
    "        [f'fuzzy-REquantile-{bw}' for bw in bandwidths] \n",
    "\n",
    "for dataset in dataset_names.keys():\n",
    "    generate_all_pareto_plots(dataset, score, alphas, methods, save_suffix='_extra_projections', show_inset=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e0dc2af-88a6-4269-9de4-850f0a187add",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "check_file = '/home-warm/plantnet/conformal_cache/metrics/inaturalist_PAS_alpha=0.1_fuzzy-RErarity-0.001.pkl'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a3862dd",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa487a1d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch-oban",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
