{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec144450",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from tqdm import tqdm\n",
    "\n",
    "import torch\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set_theme(style=\"whitegrid\")\n",
    "\n",
    "from source.utils.metrics import mse, nll, crps\n",
    "from source.constants import RESULTS_PATH, PLOTS_PATH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "d3e53542",
   "metadata": {},
   "outputs": [],
   "source": [
    "task = \"city_street\"  # one of: dots, arrow, city_street, city_building, city_sky, city_car, city_vegetation\n",
    "network = \"cnn\" if task in (\"dots\", \"arrow\") else \"resnet\"  # dots/arrow use the CNN, every city_* task the ResNet\n",
    "method = \"ensemble\"  # or \"mc_dropout\"\n",
    "arw = 0.0  # arw tag that appears in the results directory name\n",
    "\n",
    "scoring_rule = \"crps\"  # one of: crps, log, mse, spherical\n",
    "\n",
    "method_seeds = [42]  # seeds whose saved runs are loaded and stacked in the analysis cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "3c6b3251",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Analyzing city_street with resnet using ensemble and natural=True\n",
      "--------------------------------------------------\n",
      "Member MSE: tensor([0.0173, 0.0166, 0.0179, 0.0153, 0.0170, 0.0141, 0.0151, 0.0182, 0.0171,\n",
      "        0.0167])\n",
      "Ensemble MSE: 0.013427920639514923\n",
      "Member NLL: tensor([-1.0758, -1.0827, -1.0850, -1.1210, -1.0621, -1.1905, -1.1609, -1.0981,\n",
      "        -1.1075, -1.0947])\n",
      "Ensemble NLL: -1.1823147535324097\n",
      "Member CRPS: tensor([0.0552, 0.0559, 0.0562, 0.0520, 0.0549, 0.0492, 0.0510, 0.0559, 0.0543,\n",
      "        0.0545])\n",
      "Ensemble CRPS: 0.04889557510614395\n",
      "--------------------------------------------------\n",
      "Analyzing city_street with resnet using ensemble and natural=False\n",
      "--------------------------------------------------\n",
      "Member MSE: tensor([0.0266, 0.0259, 0.0301, 0.0512, 0.0278, 0.0303, 0.0269, 0.0258, 0.0426,\n",
      "        0.0257])\n",
      "Ensemble MSE: 0.025413567200303078\n",
      "Member NLL: tensor([-0.8460, -1.0640, -1.0606, -0.9449, -1.1615, -1.0331, -0.8883, -1.0252,\n",
      "        -0.9232, -1.0879])\n",
      "Ensemble NLL: -1.1176490783691406\n",
      "Member CRPS: tensor([0.0741, 0.0690, 0.0726, 0.0969, 0.0684, 0.0681, 0.0732, 0.0670, 0.0862,\n",
      "        0.0668])\n",
      "Ensemble CRPS: 0.06782465428113937\n",
      "--------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "def results_dir(use_natural, seed):\n",
    "    \"\"\"Return the saved-run directory for the configured task/network/arw at one natural flag and seed.\"\"\"\n",
    "    # NOTE(review): `method` is not part of this path, so it only changes the printed header below.\n",
    "    return os.path.join(RESULTS_PATH, f\"{task}_{network}_natural{use_natural}_mseed{seed}_arw{arw}\")\n",
    "\n",
    "\n",
    "for use_natural in [True, False]:\n",
    "\n",
    "    print(f\"Analyzing {task} with {network} using {method} and natural={use_natural}\")\n",
    "    print(\"-\"*50)\n",
    "\n",
    "    # Test targets are read from the first seed's run directory only.\n",
    "    # weights_only=False unpickles arbitrary objects -- only load trusted result files.\n",
    "    y_test = torch.load(os.path.join(results_dir(use_natural, method_seeds[0]), \"y_test.pt\"), weights_only=False)\n",
    "\n",
    "    # preds[..., 0] is used as the predicted mean and preds[..., 1] as the predicted variance.\n",
    "    means, variances = [], []\n",
    "    for method_seed in method_seeds:\n",
    "        preds = torch.load(os.path.join(results_dir(use_natural, method_seed), \"preds.pt\"), weights_only=False)\n",
    "        means.append(preds[..., 0])\n",
    "        variances.append(preds[..., 1])\n",
    "\n",
    "    # Seeds are stacked on dim -2 so ensemble members stay on the last dim;\n",
    "    # presumably (n_test, n_seeds, n_members) -- TODO confirm against the saved preds.\n",
    "    means = torch.stack(means, dim=-2)\n",
    "    variances = torch.stack(variances, dim=-2)\n",
    "    assert means.dim() == 3 and means.shape[0] == y_test.numel(), \"preds and y_test disagree on the test-set layout\"\n",
    "\n",
    "    targets_member = y_test.reshape(-1, 1, 1)  # broadcasts over (seed, member)\n",
    "    targets_ensemble = y_test.reshape(-1, 1)   # broadcasts over seed\n",
    "\n",
    "    member_mse = mse(means, targets_member)\n",
    "    member_nll = nll(means, variances, targets_member)\n",
    "    member_crps = crps(means, variances, targets_member)\n",
    "\n",
    "    # The ensemble is a uniform mixture of the members, moment-matched to one Gaussian.\n",
    "    # By the law of total variance its variance is the mean member variance PLUS the\n",
    "    # spread of the member means; averaging the variances alone drops the members'\n",
    "    # disagreement and makes the ensemble overconfident (MSE is unaffected).\n",
    "    ensemble_mean = means.mean(dim=-1)\n",
    "    ensemble_var = variances.mean(dim=-1) + means.var(dim=-1, unbiased=False)\n",
    "\n",
    "    ensemble_mse = mse(ensemble_mean, targets_ensemble)\n",
    "    ensemble_nll = nll(ensemble_mean, ensemble_var, targets_ensemble)\n",
    "    ensemble_crps = crps(ensemble_mean, ensemble_var, targets_ensemble)\n",
    "\n",
    "    # Average over test points (dim 0) and seeds (dim 1): one value per member, one scalar for the ensemble.\n",
    "    print(f\"Member MSE: {member_mse.mean(dim=(0, 1))}\")\n",
    "    print(f\"Ensemble MSE: {ensemble_mse.mean(dim=(0, 1))}\")\n",
    "\n",
    "    print(f\"Member NLL: {member_nll.mean(dim=(0, 1))}\")\n",
    "    print(f\"Ensemble NLL: {ensemble_nll.mean(dim=(0, 1))}\")\n",
    "\n",
    "    print(f\"Member CRPS: {member_crps.mean(dim=(0, 1))}\")\n",
    "    print(f\"Ensemble CRPS: {ensemble_crps.mean(dim=(0, 1))}\")\n",
    "\n",
    "    print(\"-\"*50)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pflow",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
