{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb77e675-6adb-4e84-8ac2-b9af727f698f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys; sys.path.append(\"../\") # For relative imports\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "from utils.conformal_utils import *\n",
    "from utils.experiment_utils import get_inputs_folder, get_outputs_folder, get_figs_folder\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "plt.rcParams.update({\n",
    "    'font.size': 16,        # base font size\n",
    "    'axes.titlesize': 18,   # subplot titles\n",
    "    'axes.labelsize': 16,   # x/y labels\n",
    "    'legend.fontsize': 16,  # legend text\n",
    "    'xtick.labelsize': 16,  # tick labels\n",
    "    'ytick.labelsize': 16,\n",
    "\n",
    "})\n",
    "# use tex with matplotlib\n",
    "plt.rc('text', usetex=True)\n",
    "plt.rc('font', family='serif')\n",
    "plt.rcParams['text.latex.preamble'] = r'\\usepackage{amsmath}'\n",
    "\n",
    "dataset_names = {\n",
    "    \"plantnet\": \"Pl@ntNet-300K\",\n",
    "    \"plantnet-trunc\": \"Pl@ntNet-300K (truncated)\",\n",
    "    \"inaturalist\": \"iNaturalist-2018\",\n",
    "    \"inaturalist-trunc\": \"iNaturalist-2018 (truncated)\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "646af825-9d55-4467-b8ad-2750a528346b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load in paths from folders.json\n",
    "inputs_folder = get_inputs_folder()\n",
    "results_folder = get_outputs_folder()\n",
    "fig_folder = get_figs_folder()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed2fbf07-57f7-4d30-b98c-5da79e1de8e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = 'plantnet'\n",
    "\n",
    "# ------- Get data --------\n",
    "cal_softmax = np.load(f'{inputs_folder}/best-{dataset}-model_cal_softmax.npy')\n",
    "cal_labels = np.load(f'{inputs_folder}/best-{dataset}-model_cal_labels.npy')\n",
    "test_softmax = np.load(f'{inputs_folder}/best-{dataset}-model_test_softmax.npy')\n",
    "test_labels = np.load(f'{inputs_folder}/best-{dataset}-model_test_labels.npy')\n",
    "print('Loaded pre-computed softmax scores')\n",
    "\n",
    "num_classes = cal_softmax.shape[1]\n",
    "print('Num classes:', num_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "649139eb-28c0-4a04-9c7b-8160f998a88b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_plantnet_at_risk_species():\n",
    "    ## Identify indices of at-risk species in PlantNet-300K\n",
    "\n",
    "    names_as_numbers_files = \"../data/plantnet300K_class_idx_to_species_id.json\"\n",
    "    names_files = \"../data/plantnet300K_species_id_2_name.json\"\n",
    "    status_iucn = \"../data/plantnet300K_iucn_status_dict.json\"\n",
    "    \n",
    "    names_as_numbers = json.load(open(names_as_numbers_files, \"r\"))\n",
    "    new_names = json.load(open(names_files, \"r\"))\n",
    "    status_iucn = json.load(open(status_iucn, \"r\"))\n",
    "    \n",
    "    df = pd.DataFrame.from_dict(names_as_numbers, orient=\"index\", columns=[\"species_id\"])\n",
    "    df = df.reset_index()\n",
    "    df = df.rename(columns={\"index\": \"class_id\"})\n",
    "    \n",
    "    # df[\"species_id\"] = df[\"species_id\"].astype(str)\n",
    "    df[\"class_id\"] = df[\"class_id\"].astype(int)\n",
    "    df[\"species_name\"] = df[\"species_id\"].map(new_names)\n",
    "    # df = df.set_index(\"class_id\")\n",
    "    \n",
    "    \n",
    "    # create a new dataframe with the iucn status with the species_id and the iucn status\n",
    "    df_iucn = pd.DataFrame.from_dict(status_iucn, orient=\"index\", columns=[\"iucn_status\"])\n",
    "    df[\"iucn_status\"] = \"Not Evaluated\"\n",
    "    for idx, specie in enumerate(df[\"species_name\"].values):\n",
    "        if specie in df_iucn.index:\n",
    "            df.loc[idx, \"iucn_status\"] = df_iucn.loc[specie, \"iucn_status\"]\n",
    "\n",
    "    print('Number of each IUCN category:', df['iucn_status'].value_counts())\n",
    "    at_risk_codes = ['EN', 'VU', 'NT', 'CR', 'LR/nt', 'LR/lc', 'LR/cd']\n",
    "    print(f'We consider {at_risk_codes} as at-risk')\n",
    "    at_risk_class_ids = np.array(df['class_id'][df['iucn_status'].isin(at_risk_codes)])\n",
    "\n",
    "    print('At-risk species:', at_risk_class_ids, f'({len(at_risk_class_ids)} total)')\n",
    "    return at_risk_class_ids\n",
    "\n",
    "at_risk_species = get_plantnet_at_risk_species()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aec64d35-89ba-4cf2-9e3e-598d02ec5fbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "train_labels_path = f'{inputs_folder}/{dataset}_train_labels.npy'\n",
    "alphas = [.2, .1, .05, .01]\n",
    "gammas = [2, 10, 100, 500] # scalar for upweighting at risk species\n",
    "\n",
    "all_res = {f'{alpha=}': {} for alpha in alphas}\n",
    "\n",
    "for score, b in zip(['softmax', 'PAS', 'WPAS', 'WPAS', 'WPAS', 'WPAS'], \n",
    "                    [None, None] + gammas):\n",
    "\n",
    "    ## Get conformal scores\n",
    "    if score == 'WPAS':\n",
    "        weights = np.ones((num_classes,))\n",
    "        weights[at_risk_species] = b\n",
    "        weights = weights / np.sum(weights)\n",
    "    else:\n",
    "        weights = None\n",
    "    cal_scores = get_conformal_scores(cal_softmax, score, \n",
    "                                      train_labels_path=train_labels_path, weights=weights)\n",
    "    test_scores = get_conformal_scores(test_softmax, score, \n",
    "                                      train_labels_path=train_labels_path, weights=weights)\n",
    "\n",
    "    # Run Standard CP for different alphas\n",
    "    for alpha in alphas:\n",
    "        standard_qhat, pred_sets, coverage_metrics, set_size_metrics = standard_conformal(cal_scores, cal_labels, \n",
    "                                                           test_scores, test_labels, alpha)\n",
    "        res = {'pred_sets': pred_sets, \n",
    "               'qhat': standard_qhat,\n",
    "               'coverage_metrics': coverage_metrics,\n",
    "               'set_size_metrics': set_size_metrics}\n",
    "\n",
    "        if score == 'WPAS':\n",
    "            score_name = f'WPAS ($\\\\gamma=$ {b})'\n",
    "        else:\n",
    "            score_name = score\n",
    "            \n",
    "        all_res[f'{alpha=}'][score_name] = res\n",
    "        \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "461af5ca-8c8e-4588-80e0-83123b8d8bab",
   "metadata": {},
   "outputs": [],
   "source": [
    "for alpha in alphas:\n",
    "    print(f'----- alpha = {alpha} -----')\n",
    "    for score in all_res[f'alpha={alphas[0]}'].keys():\n",
    "        res = all_res[f'{alpha=}'][score]\n",
    "        plt.plot(res['coverage_metrics']['raw_class_coverages'][at_risk_species], 'o', alpha=0.3, \n",
    "                 label=f'{score}, avg size={res['set_size_metrics']['mean']:.2f}')\n",
    "        other_species = np.setdiff1d(np.arange(num_classes), at_risk_species)\n",
    "        print(f'[{score}] avg class-cond cov for at risk species: {np.mean(res['coverage_metrics']['raw_class_coverages'][at_risk_species]):.3f}',\n",
    "             f', for other species: {np.mean(res['coverage_metrics']['raw_class_coverages'][other_species]):.3f}')\n",
    "    plt.legend()\n",
    "    # plt.title(f'Average set size: {res['set_size_metrics']['mean']}')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03c034f4-5855-4545-91cd-e55147d94798",
   "metadata": {},
   "outputs": [],
   "source": [
    "score_to_color = {'softmax': 'blue',\n",
    "                  'PAS': 'green',\n",
    "                   f'WPAS ($\\\\gamma=$ {gammas[0]})': (0.3, 0.13, 0.7),\n",
    "                   f'WPAS ($\\\\gamma=$ {gammas[1]})': (0.5, 0.13, 0.7),\n",
    "                   f'WPAS ($\\\\gamma=$ {gammas[2]})': (0.7, 0.13, 0.7),\n",
    "                   f'WPAS ($\\\\gamma=$ {gammas[3]})': (0.9, 0.13, 0.7),\n",
    "                  }\n",
    "\n",
    "alphas = [.2, .1, .05, .01]\n",
    "\n",
    "\n",
    "\n",
    "markersizes = [4,5,6,7]\n",
    "\n",
    "metric_names = ['At-risk average $\\\\hat{c}_y$',\n",
    "                'Not-at-risk average $\\\\hat{c}_y$',\n",
    "                'MacroCov',\n",
    "                'MarginalCov']\n",
    "             \n",
    "fig, axes = plt.subplots(1, len(metric_names), figsize=(13, 2.2), sharey=True)\n",
    "for i in range(len(metric_names)):\n",
    "    ax = axes[i]\n",
    "    if i == 3:\n",
    "        for a in alphas:\n",
    "                ax.axvline(1-a, linestyle='--', color='grey')\n",
    "            \n",
    "    for j, alpha in enumerate(alphas):\n",
    "        for score in all_res[f'alpha={alphas[0]}'].keys():\n",
    "            res = all_res[f'{alpha=}'][score]\n",
    "\n",
    "            if score == 'softmax':\n",
    "                marker = 'X'\n",
    "            elif score == 'PAS':\n",
    "                marker = '^'\n",
    "            else:\n",
    "                marker = 'o'\n",
    "            \n",
    "            if i == 0: # Avg of at risk\n",
    "                x = np.mean(res['coverage_metrics']['raw_class_coverages'][at_risk_species])\n",
    "            elif i == 1: # Avg of not at risk species\n",
    "                other_species = np.setdiff1d(np.arange(num_classes), at_risk_species)\n",
    "                x = np.mean(res['coverage_metrics']['raw_class_coverages'][other_species])\n",
    "            elif i == 2: # Macro-coverage\n",
    "                x = np.mean(res['coverage_metrics']['raw_class_coverages'])\n",
    "            elif i == 3: # Marginal coverage\n",
    "                x = res['coverage_metrics']['marginal_cov']\n",
    "                \n",
    "            y = res['set_size_metrics']['mean']\n",
    "           \n",
    "            ax.plot(x, y, marker, alpha=0.6, markersize=markersizes[j],\n",
    "                    color=score_to_color[score], label=f'{score}, $\\\\alpha=$ {alpha}')\n",
    "            ax.spines[['right', 'top']].set_visible(False)\n",
    "            # other_species = np.setdiff1d(np.arange(num_classes), at_risk_species)\n",
    "            # print(f'[{score}] avg class-cond cov for at risk species: {np.mean(res['coverage_metrics']['raw_class_coverages'][at_risk_species]):.3f}',\n",
    "            #      f', for other species: {np.mean(res['coverage_metrics']['raw_class_coverages'][other_species]):.3f}')\n",
    "    ax.set_xlabel(metric_names[i])\n",
    "    ax.set_ylim(bottom=0)\n",
    "    \n",
    "axes[0].set_ylabel('Average set size')\n",
    "plt.legend(ncols = len(alphas), loc='upper left', bbox_to_anchor=(-3.85,-0.35), fontsize=12)\n",
    "plt.tight_layout()\n",
    "plt.suptitle(dataset_names[dataset], y=1.02)\n",
    "\n",
    "os.makedirs(f'{fig_folder}/weighted_macro_coverage', exist_ok=True)\n",
    "pth = f'{fig_folder}/weighted_macro_coverage/plantnet_weighted_macro_coverage_results.pdf'\n",
    "plt.savefig(pth, bbox_inches='tight')\n",
    "print('Saved plot to', pth)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3fd48c1-6ac2-482a-8186-489ffb4a4bb5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conformal_env",
   "language": "python",
   "name": "conformal_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
