{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from utils.evaluation import *\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Default parameter values for this evaluation run.\n",
    "\n",
    "# Filesystem locations: the aggregation cell joins output_dir with a per-dataset\n",
    "# folder name to form results_dir, and forwards features_dir to\n",
    "# aggregate_results_over_seeds unchanged.\n",
    "features_dir = \"/path/to/precomputed/features\"\n",
    "output_dir = \"./outputs\"\n",
    "\n",
    "# Settings that are interpolated into the results file name.\n",
    "# NOTE(review): presumably these must match the values used when the results were\n",
    "# produced - confirm against the script that writes them.\n",
    "alpha = 0.1\n",
    "ood_score = 'trust'  # also used as the key into ood_scores_dir\n",
    "temp_scaling = True\n",
    "scores_randomize = False\n",
    "use_binning = False\n",
    "\n",
    "# Passed as the last argument to aggregate_results_over_seeds.\n",
    "degree = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset display name -> identifier used in the results folder name and passed\n",
    "# to aggregate_results_over_seeds.\n",
    "datasets = {\n",
    "    'ImageNet': 'imagenet',\n",
    "    'ImageNet-LT': 'imagenet_lt',\n",
    "    'Places365': 'places',\n",
    "    'Places365-LT': 'places_lt',\n",
    "}\n",
    "\n",
    "# Score function display name -> identifier used in the results file name.\n",
    "score_fns = {\n",
    "    'APS': 'aps',\n",
    "}\n",
    "\n",
    "# Model identifier per dataset; keyed by the same display names as `datasets`\n",
    "# because the aggregation cell looks it up with models[dataset].\n",
    "models = {\n",
    "    'ImageNet': 'resnet50',\n",
    "    'ImageNet-LT': 'resnext50_imagenet_lt',\n",
    "    'Places365': 'resnet152_places',\n",
    "    'Places365-LT': 'resnet152_places_lt',\n",
    "}\n",
    "\n",
    "# Method key -> LaTeX math label. Only the keys are passed to the aggregation helper.\n",
    "# Raw strings are required: the backslash in \\mathsf is not a valid Python escape\n",
    "# sequence and triggers a SyntaxWarning on recent Python versions when written in a\n",
    "# normal string literal. The resulting string values are unchanged.\n",
    "methods = {\n",
    "    'split': r'$\\mathsf{split}$',\n",
    "    'naive': r'$\\mathsf{naive}$',\n",
    "    'conditional': r'$\\mathsf{conditional}$',\n",
    "}\n",
    "\n",
    "# Directory holding the scores for each supported value of `ood_score`.\n",
    "ood_scores_dir = {\n",
    "    'trust': '/path/to/trust_scores/',\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate results over all seeds for every (dataset, score function) pair.\n",
    "# Each aggregated frame is stored in horizontal_dict under the key\n",
    "# (dataset, score_fn) and displayed right away, so no frame is overwritten or\n",
    "# skipped when more than one score function is configured.\n",
    "horizontal_dict = {}\n",
    "for dataset, dataset_name in datasets.items():\n",
    "    model_name = models[dataset]\n",
    "    # The results folder depends only on the dataset/model pair, so build it once per dataset.\n",
    "    results_dir = os.path.join(output_dir, f\"{dataset_name}_{model_name}\")\n",
    "    for score_fn, score_fn_name in score_fns.items():\n",
    "        # Results file name encodes the run settings from the parameter cell.\n",
    "        res_fname = (\n",
    "            f\"alpha_{alpha}_score_fn_{score_fn_name}_scores_randomize_{scores_randomize}\"\n",
    "            f\"_temp_scale_{temp_scaling}_ood_{ood_score}_use_binning_{use_binning}\"\n",
    "        )\n",
    "        df = aggregate_results_over_seeds(\n",
    "            dataset_name,\n",
    "            model_name,\n",
    "            features_dir,\n",
    "            results_dir,\n",
    "            res_fname,\n",
    "            ood_scores_dir[ood_score],\n",
    "            methods.keys(),\n",
    "            score_fn_name,\n",
    "            degree,\n",
    "        )\n",
    "        horizontal_dict[(dataset, score_fn)] = df\n",
    "        display(df)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
