{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make the project sources importable and auto-reload edited modules.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "import pickle  # used by the results-loading cell below\n",
    "import sys\n",
    "\n",
    "# Drop every sys.path entry that mentions 'bayes_gsl' (presumably a stale checkout -- confirm).\n",
    "sys.path = [entry for entry in sys.path if 'bayes_gsl' not in entry]\n",
    "\n",
    "# Root of the source tree. Set the GSL_BNN_SRC environment variable to point at your\n",
    "# checkout instead of editing this cell; the fallback is the original hard-coded path.\n",
    "new_path = os.environ.get('GSL_BNN_SRC', '/Users/maxw/projects/gsl-bnn/')\n",
    "if new_path not in sys.path:\n",
    "    sys.path.append(new_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the metrics saved in the results/label_mismatch directory so the cells below can print them.\n",
    "import os\n",
    "\n",
    "from experiments.label_mismatch.config import experiment_settings\n",
    "\n",
    "results_path = experiment_settings['results_path']\n",
    "# os.path.join only inserts a separator when results_path lacks a trailing one,\n",
    "# so the file is located whether or not the configured directory ends with '/'.\n",
    "results_file = os.path.join(results_path, 'model_data_metrics_dicts.pkl')\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code -- only open results files produced by this project.\n",
    "with open(results_file, 'rb') as f:\n",
    "    metrics_dicts = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def display_metrics(metrics_dict, precision=8):\n",
    "    \"\"\"Print test error, NLL, Brier score and ECE for one metrics dict.\n",
    "\n",
    "    Args:\n",
    "        metrics_dict: Mapping with the keys 'accuracies', 'log_likelihoods' and\n",
    "            'brier_scores' (each value must provide .mean() and .std(), e.g. a\n",
    "            NumPy array) and 'calibration_dict', a mapping with a number under 'ece'.\n",
    "        precision: Decimal places used for the error, Brier score and ECE lines.\n",
    "            The NLL line is always printed with 3 decimals.\n",
    "\n",
    "    Returns:\n",
    "        None. Four summary lines are written to stdout.\n",
    "    \"\"\"\n",
    "    fmt = f'.{precision}f'\n",
    "    accuracies = metrics_dict['accuracies']\n",
    "    log_likelihoods = metrics_dict['log_likelihoods']\n",
    "    brier_scores = metrics_dict['brier_scores']\n",
    "    ece = metrics_dict['calibration_dict']['ece']\n",
    "\n",
    "    # The backslash is escaped so the output keeps the literal LaTeX \\pm\n",
    "    # without triggering an invalid-escape-sequence warning.\n",
    "    print(f'Test Error: {1 - accuracies.mean():{fmt}} \\\\pm {accuracies.std():{fmt}}')\n",
    "    # Mean log-likelihood is negated so the reported value is the negative log-likelihood.\n",
    "    print(f'Test NLL: {-1 * log_likelihoods.mean():.3f} \\\\pm {log_likelihoods.std():.3f}')\n",
    "    print(f'Test BS: {brier_scores.mean():{fmt}} \\\\pm {brier_scores.std():{fmt}}')\n",
    "    print(f'Test ECE:{ece:{fmt}}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Model: DPG\n",
      "\n",
      "Data Set: RG_.32\n",
      "Test Error: 0.01573682 \\pm 0.01008262\n",
      "Test NLL: 10.333 \\pm 5.378\n",
      "Test BS: 0.01286086 \\pm 0.00716473\n",
      "Test ECE:0.00341998\n",
      "\n",
      "Data Set: RG_.5\n",
      "Test Error: 0.01899993 \\pm 0.01078356\n",
      "Test NLL: 17.532 \\pm 5.946\n",
      "Test BS: 0.02287454 \\pm 0.00920260\n",
      "Test ECE:0.05208334\n",
      "\n",
      "Data Set: ER_.5\n",
      "Test Error: 0.00505251 \\pm 0.00561556\n",
      "Test NLL: 33.009 \\pm 5.315\n",
      "Test BS: 0.04655694 \\pm 0.01235685\n",
      "Test ECE:0.17155522\n",
      "\n",
      "Data Set: BA_1\n",
      "Test Error: 0.02221048 \\pm 0.00634117\n",
      "Test NLL: 20.911 \\pm 6.205\n",
      "Test BS: 0.02103098 \\pm 0.00607401\n",
      "Test ECE:0.01663103\n",
      "\n",
      "\n",
      "Model: PDS\n",
      "\n",
      "Data Set: RG_.32\n",
      "Test Error: 0.01589471 \\pm 0.01025924\n",
      "Test NLL: 10.317 \\pm 5.391\n",
      "Test BS: 0.01285023 \\pm 0.00717723\n",
      "Test ECE:0.00374968\n",
      "\n",
      "Data Set: RG_.5\n",
      "Test Error: 0.01931572 \\pm 0.01114136\n",
      "Test NLL: 17.704 \\pm 5.945\n",
      "Test BS: 0.02331776 \\pm 0.00944993\n",
      "Test ECE:0.05286844\n",
      "\n",
      "Data Set: ER_.5\n",
      "Test Error: 0.00557888 \\pm 0.00608435\n",
      "Test NLL: 33.432 \\pm 5.275\n",
      "Test BS: 0.04801297 \\pm 0.01273474\n",
      "Test ECE:0.17353761\n",
      "\n",
      "Data Set: BA_1\n",
      "Test Error: 0.02215785 \\pm 0.00628479\n",
      "Test NLL: 20.802 \\pm 6.143\n",
      "Test BS: 0.02096917 \\pm 0.00604504\n",
      "Test ECE:0.01659561\n",
      "\n",
      "\n",
      "Model: DPG-MIMO\n",
      "\n",
      "Data Set: RG_.32\n",
      "Test Error: 0.01415789 \\pm 0.00940604\n",
      "Test NLL: 9.191 \\pm 5.155\n",
      "Test BS: 0.01188410 \\pm 0.00664590\n",
      "Test ECE:0.00350250\n",
      "\n",
      "Data Set: RG_.5\n",
      "Test Error: 0.01326311 \\pm 0.01149442\n",
      "Test NLL: 17.125 \\pm 5.178\n",
      "Test BS: 0.01737878 \\pm 0.00652344\n",
      "Test ECE:0.05460664\n",
      "\n",
      "Data Set: ER_.5\n",
      "Test Error: 0.01878953 \\pm 0.01189182\n",
      "Test NLL: 40.487 \\pm 5.515\n",
      "Test BS: 0.04619085 \\pm 0.00956773\n",
      "Test ECE:0.16170704\n",
      "\n",
      "Data Set: BA_1\n",
      "Test Error: 0.02710527 \\pm 0.01173046\n",
      "Test NLL: 13.371 \\pm 3.028\n",
      "Test BS: 0.02060318 \\pm 0.00600022\n",
      "Test ECE:0.01124196\n",
      "\n",
      "\n",
      "Model: DPG-MIMO-E\n",
      "\n",
      "Data Set: RG_.32\n",
      "Test Error: 0.01289469 \\pm 0.00823394\n",
      "Test NLL: 8.354 \\pm 4.738\n",
      "Test BS: 0.01065306 \\pm 0.00604091\n",
      "Test ECE:0.00240069\n",
      "\n",
      "Data Set: RG_.5\n",
      "Test Error: 0.01568413 \\pm 0.01414174\n",
      "Test NLL: 12.656 \\pm 4.646\n",
      "Test BS: 0.01516709 \\pm 0.00870954\n",
      "Test ECE:0.02916749\n",
      "\n",
      "Data Set: ER_.5\n",
      "Test Error: 0.11957896 \\pm 0.04219780\n",
      "Test NLL: 37.770 \\pm 5.657\n",
      "Test BS: 0.07818024 \\pm 0.01967286\n",
      "Test ECE:0.02907484\n",
      "\n",
      "Data Set: BA_1\n",
      "Test Error: 0.01610523 \\pm 0.00739919\n",
      "Test NLL: 8.349 \\pm 2.919\n",
      "Test BS: 0.01318497 \\pm 0.00479427\n",
      "Test ECE:0.00283587\n"
     ]
    }
   ],
   "source": [
    "# metrics_dicts is nested three deep:\n",
    "#   model name -> data set name -> metrics dict (the argument of display_metrics)\n",
    "\n",
    "for model_name, per_data_set in metrics_dicts.items():\n",
    "    print(f'\\n\\nModel: {model_name}')\n",
    "    for data_set_name, data_set_metrics in per_data_set.items():\n",
    "        print(f'\\nData Set: {data_set_name}')\n",
    "        display_metrics(data_set_metrics)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gsl-bnn-mac-m2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
