{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "f01c0bc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import tensorflow as tf\n",
    "from sklearn.metrics import accuracy_score, f1_score, mean_absolute_error\n",
    "import bayesflow as bf\n",
    "from bayesflow.computational_utilities import expected_calibration_error\n",
    "\n",
    "from train_model_comparison import get_trainer\n",
    "from models import get_model\n",
    "from settings import MODEL_NAMES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "68f101a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training regimes to compare; each entry is passed to get_trainer() as `mode`\n",
    "MODES = [\n",
    "    'expert',\n",
    "    'learner',\n",
    "    'hybrid',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c2ccdf16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: pickle.load can execute arbitrary code -- only load simulation files you generated/trust.\n",
    "# Open read-only ('rb', not 'rb+') and close the handle deterministically via `with`.\n",
    "with open('./simulations/test_model_comparison.pkl', 'rb') as sims_file:\n",
    "    test_sims = pickle.load(sims_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "247c04f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(mode, sims=None):\n",
    "    \"\"\"Score one trained model-comparison network on held-out simulations.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    mode : str\n",
    "        Training mode forwarded to ``get_trainer`` (e.g. an entry of ``MODES``).\n",
    "    sims : optional\n",
    "        Raw simulations passed to the trainer's configurator. Defaults to the\n",
    "        notebook-level ``test_sims`` so existing calls ``evaluate(mode)`` still work.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray of shape (4,)\n",
    "        ``[accuracy, weighted F1, mean ECE, MAE]``, in that order.\n",
    "    \"\"\"\n",
    "    if sims is None:\n",
    "        # Fall back to the global test set instead of relying on it implicitly\n",
    "        sims = test_sims\n",
    "\n",
    "    # Obtain predictions\n",
    "    trainer = get_trainer(mode, len(MODEL_NAMES))\n",
    "    conf = trainer.configurator(sims)\n",
    "    pmps = trainer.amortizer.posterior_probs(conf)\n",
    "    # presumably one-hot indicators of the true model (argmax over axis 1 below) -- TODO confirm\n",
    "    true_models = conf['model_indices']\n",
    "\n",
    "    # Hard labels for the classification metrics, computed once\n",
    "    y_true = true_models.argmax(1)\n",
    "    y_pred = pmps.argmax(1)\n",
    "\n",
    "    # Compute metrics\n",
    "    acc = accuracy_score(y_true, y_pred)\n",
    "    f1 = f1_score(y_true, y_pred, average='weighted')\n",
    "    # [0] selects the calibration errors from the returned value (presumably one per model); average them\n",
    "    cal_err = np.mean(expected_calibration_error(true_models, pmps)[0])\n",
    "    # Mean absolute deviation between true indicators and posterior model probabilities\n",
    "    mae = mean_absolute_error(true_models, pmps)\n",
    "    return np.array([acc, f1, cal_err, mae])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "f553483c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:Trainer initialization: No generative model provided. Only offline learning mode is available!\n",
      "INFO:root:Loaded loss history from checkpoints/model_comparison/expert_50000/history_250.pkl.\n",
      "INFO:root:Networks loaded from checkpoints/model_comparison/expert_50000/ckpt-250\n",
      "INFO:root:Trainer initialization: No generative model provided. Only offline learning mode is available!\n",
      "INFO:root:Loaded loss history from checkpoints/model_comparison/learner_50000/history_250.pkl.\n",
      "INFO:root:Networks loaded from checkpoints/model_comparison/learner_50000/ckpt-250\n",
      "INFO:root:Trainer initialization: No generative model provided. Only offline learning mode is available!\n",
      "INFO:root:Loaded loss history from checkpoints/model_comparison/hybrid_50000/history_250.pkl.\n",
      "INFO:root:Networks loaded from checkpoints/model_comparison/hybrid_50000/ckpt-250\n"
     ]
    }
   ],
   "source": [
    "# One row per training mode, one column per metric returned by evaluate().\n",
    "# Stacking avoids a hard-coded row count that would break if MODES changes.\n",
    "results = np.stack([evaluate(mode) for mode in MODES])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "2c62b435",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary table: rows are training modes, columns follow evaluate()'s metric order\n",
    "df = pd.DataFrame(\n",
    "    results.round(3),\n",
    "    index=[mode.capitalize() for mode in MODES],\n",
    "    columns=['Accuracy', 'F1-Score', 'ECE', 'MAE'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "c7648525",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>F1-Score</th>\n",
       "      <th>ECE</th>\n",
       "      <th>MAE</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Expert</th>\n",
       "      <td>0.612</td>\n",
       "      <td>0.608</td>\n",
       "      <td>0.012</td>\n",
       "      <td>0.061</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Learner</th>\n",
       "      <td>0.686</td>\n",
       "      <td>0.687</td>\n",
       "      <td>0.020</td>\n",
       "      <td>0.048</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Hybrid</th>\n",
       "      <td>0.716</td>\n",
       "      <td>0.714</td>\n",
       "      <td>0.013</td>\n",
       "      <td>0.046</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         Accuracy  F1-Score    ECE    MAE\n",
       "Expert      0.612     0.608  0.012  0.061\n",
       "Learner     0.686     0.687  0.020  0.048\n",
       "Hybrid      0.716     0.714  0.013  0.046"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the summary table (rendered as HTML by Jupyter)\n",
    "df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
