{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cd1f7a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Path to aggregated metrics JSON\n",
    "json_path = \"best_models_metrics.json\"\n",
    "with open(json_path, \"r\") as f:\n",
    "    models = json.load(f)\n",
    "\n",
    "# Extract model names and metrics\n",
    "model_names = [m[\"model\"] for m in models]\n",
    "metrics = [\"Acc_energy\",\"F1_energy\",\"Acc_alpha\",\"F1_alpha\",\"Acc_q0\",\"F1_q0\"]\n",
    "labels = [\n",
    "    \"Accuracy Energy\", \"Macro-F1 Energy\",\n",
    "    \"Accuracy αs\", \"Macro-F1 αs\",\n",
    "    \"Accuracy Q0\", \"Macro-F1 Q0\"\n",
    "]\n",
    "\n",
    "# Build matrix [n_models x n_metrics]\n",
    "values = np.array([[m[metric] for metric in metrics] for m in models])\n",
    "\n",
    "# Plot settings\n",
    "x = np.arange(len(model_names))  # models on x-axis\n",
    "width = 0.12                     # width of each bar\n",
    "fig, ax = plt.subplots(figsize=(11, 5))\n",
    "\n",
    "# Create grouped bars\n",
    "for i, label in enumerate(labels):\n",
    "    offsets = x + (i - (len(labels)-1)/2)*width\n",
    "    ax.bar(offsets, values[:, i], width, label=label)\n",
    "\n",
    "# Axis & legend\n",
    "ax.set_xticks(x)\n",
    "ax.set_xticklabels(model_names, rotation=20, ha=\"right\")\n",
    "ax.set_ylabel(\"Score (%)\")\n",
    "ax.set_ylim(0, 105)\n",
    "ax.set_title(\"Accuracy and Macro-F1 across backbones for Energy, αs, and Q0\")\n",
    "ax.legend(ncols=3, fontsize=8)\n",
    "ax.grid(axis=\"y\", linestyle=\"--\", alpha=0.4)\n",
    "\n",
    "# Save outputs\n",
    "os.makedirs(\"figures\", exist_ok=True)\n",
    "pdf_path = \"barplots_placeholder.pdf\"\n",
    "png_path = \"barplots_placeholder.png\"\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(pdf_path, bbox_inches=\"tight\")\n",
    "fig.savefig(png_path, dpi=300, bbox_inches=\"tight\")\n",
    "\n",
    "print(f\"Saved bar plot to {pdf_path} and {png_path}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2f52eb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Path to aggregated metrics JSON\n",
    "json_path = \"best_models_metrics.json\"\n",
    "with open(json_path, \"r\") as f:\n",
    "    models = json.load(f)\n",
    "\n",
    "# Map long model tags to short names\n",
    "name_map = {\n",
    "    \"MambaOut_base_plus_rw\": \"Mamba\",\n",
    "    \"ConvNeXt_Gaussian_g500\": \"ConvNeXt\",\n",
    "    \"EfficientNet_g500\": \"EfficientNet\",\n",
    "    \"Swin_g500\": \"Swin\",\n",
    "    \"ViT_tiny_patch16_224_gaussian_lrp_12_rlrp_4\": \"ViT\"\n",
    "}\n",
    "\n",
    "# Extract model names and metrics\n",
    "model_names = [name_map.get(m[\"model\"], m[\"model\"]) for m in models]\n",
    "metrics = [\"Acc_energy\",\"F1_energy\",\"Acc_alpha\",\"F1_alpha\",\"Acc_q0\",\"F1_q0\"]\n",
    "labels = [\n",
    "    \"Accuracy Energy\", \"Macro-F1 Energy\",\n",
    "    \"Accuracy αs\", \"Macro-F1 αs\",\n",
    "    \"Accuracy Q0\", \"Macro-F1 Q0\"\n",
    "]\n",
    "\n",
    "# Build matrix [n_models x n_metrics]\n",
    "values = np.array([[m[metric] for metric in metrics] for m in models])\n",
    "\n",
    "# Plot settings\n",
    "x = np.arange(len(model_names))  # models on x-axis\n",
    "width = 0.12                     # width of each bar\n",
    "fig, ax = plt.subplots(figsize=(11, 5))\n",
    "\n",
    "# Create grouped bars\n",
    "for i, label in enumerate(labels):\n",
    "    offsets = x + (i - (len(labels)-1)/2)*width\n",
    "    ax.bar(offsets, values[:, i], width, label=label)\n",
    "\n",
    "# Axis & legend\n",
    "ax.set_xticks(x)\n",
    "ax.set_xticklabels(model_names, rotation=20, ha=\"right\")\n",
    "ax.set_ylabel(\"Score (%)\")\n",
    "ax.set_ylim(0, 105)\n",
    "ax.set_title(\"Accuracy and Macro-F1 across backbones for Energy, αs, and Q0\")\n",
    "ax.legend(ncols=3, fontsize=8)\n",
    "ax.grid(axis=\"y\", linestyle=\"--\", alpha=0.4)\n",
    "\n",
    "# Save outputs\n",
    "pdf_path = \"barplots_backbones.pdf\"\n",
    "png_path = \"barplots_backbones.png\"\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(pdf_path, bbox_inches=\"tight\")\n",
    "fig.savefig(png_path, dpi=300, bbox_inches=\"tight\")\n",
    "\n",
    "print(f\"Saved bar plot to {pdf_path} and {png_path}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91b8e70a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json, os, numpy as np, matplotlib.pyplot as plt\n",
    "\n",
    "# Use mathtext (default). Optional: pick a serif font family.\n",
    "plt.rcParams.update({\n",
    "    \"text.usetex\": False,\n",
    "    \"font.family\": \"serif\"\n",
    "})\n",
    "\n",
    "# ---- your original code, unchanged below ----\n",
    "json_path = \"best_models_metrics.json\"\n",
    "with open(json_path, \"r\") as f:\n",
    "    models = json.load(f)\n",
    "\n",
    "name_map = {\n",
    "    \"MambaOut_base_plus_rw\": \"Mamba\",\n",
    "    \"ConvNeXt_Gaussian_g500\": \"ConvNeXt\",\n",
    "    \"EfficientNet_g500\": \"EfficientNet\",\n",
    "    \"Swin_g500\": \"Swin\",\n",
    "    \"ViT_tiny_patch16_224_gaussian_lrp_12_rlrp_4\": \"ViT\"\n",
    "}\n",
    "\n",
    "model_names = [name_map.get(m[\"model\"], m[\"model\"]) for m in models]\n",
    "metrics = [\"Acc_energy\",\"F1_energy\",\"Acc_alpha\",\"F1_alpha\",\"Acc_q0\",\"F1_q0\"]\n",
    "labels = [\n",
    "    r\"Accuracy Energy\", r\"Macro-F1 Energy\",\n",
    "    r\"Accuracy $\\alpha_s$\", r\"Macro-F1 $\\alpha_s$\",\n",
    "    r\"Accuracy $Q_0$\", r\"Macro-F1 $Q_0$\"\n",
    "]\n",
    "\n",
    "values = np.array([[m[metric] for metric in metrics] for m in models])\n",
    "\n",
    "x = np.arange(len(model_names))\n",
    "width = 0.12\n",
    "fig, ax = plt.subplots(figsize=(11, 5))\n",
    "\n",
    "for i, label in enumerate(labels):\n",
    "    offsets = x + (i - (len(labels)-1)/2)*width\n",
    "    ax.bar(offsets, values[:, i], width, label=label)\n",
    "\n",
    "ax.set_xticks(x)\n",
    "ax.set_xticklabels(model_names, rotation=20, ha=\"right\")\n",
    "ax.set_ylabel(r\"Score (\\%)\")  # mathtext handles \\% fine\n",
    "ax.set_ylim(0, 105)\n",
    "ax.set_title(r\"Accuracy and Macro-F1 across backbones for Energy, $\\alpha_s$, and $Q_0$\")\n",
    "ax.legend(ncols=3, fontsize=9, frameon=False)\n",
    "ax.grid(axis=\"y\", linestyle=\"--\", alpha=0.4)\n",
    "\n",
    "pdf_path = \"barplots_backbones.pdf\"\n",
    "png_path = \"barplots_backbones.png\"\n",
    "fig.tight_layout()\n",
    "fig.savefig(pdf_path, bbox_inches=\"tight\")\n",
    "fig.savefig(png_path, dpi=300, bbox_inches=\"tight\")\n",
    "print(f\"Saved bar plot to {pdf_path} and {png_path}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a48d7533",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Path to aggregated metrics JSON\n",
    "json_path = \"best_models_metrics.json\"\n",
    "with open(json_path, \"r\") as f:\n",
    "    models = json.load(f)\n",
    "\n",
    "# Map long model tags to short names\n",
    "name_map = {\n",
    "    \"MambaOut_base_plus_rw\": \"Mamba\",\n",
    "    \"ConvNeXt_Gaussian_g500\": \"ConvNeXt\",\n",
    "    \"EfficientNet_g500\": \"EfficientNet\",\n",
    "    \"Swin_g500\": \"Swin\",\n",
    "    \"ViT_tiny_patch16_224_gaussian_lrp_12_rlrp_4\": \"ViT\"\n",
    "}\n",
    "\n",
    "# Extract model names and metrics\n",
    "model_names = [name_map.get(m[\"model\"], m[\"model\"]) for m in models]\n",
    "metrics = [\"Acc_energy\",\"F1_energy\",\"Acc_alpha\",\"F1_alpha\",\"Acc_q0\",\"F1_q0\"]\n",
    "labels = [\n",
    "    r\"Accuracy Energy\", r\"Macro-F1 Energy\",\n",
    "    r\"Accuracy $\\alpha_s$\", r\"Macro-F1 $\\alpha_s$\",\n",
    "    r\"Accuracy $Q_0$\", r\"Macro-F1 $Q_0$\"\n",
    "]\n",
    "\n",
    "# Build matrix [n_models x n_metrics]\n",
    "values = np.array([[m[metric] for metric in metrics] for m in models])\n",
    "\n",
    "# Plot settings\n",
    "x = np.arange(len(model_names))  # models on x-axis\n",
    "width = 0.12                     # width of each bar\n",
    "fig, ax = plt.subplots(figsize=(10, 4.5))\n",
    "\n",
    "# Create grouped bars\n",
    "for i, label in enumerate(labels):\n",
    "    offsets = x + (i - (len(labels)-1)/2)*width\n",
    "    ax.bar(offsets, values[:, i], width, label=label)\n",
    "\n",
    "# Axis & legend\n",
    "ax.set_xticks(x)\n",
    "ax.set_xticklabels(model_names, rotation=0, ha=\"center\")  # straight labels\n",
    "ax.set_ylabel(r\"Score (\\%)\")\n",
    "ax.set_ylim(0, 105)\n",
    "ax.set_title(r\"Accuracy and Macro-F1 across backbones for Energy, $\\alpha_s$, and $Q_0$\")\n",
    "ax.legend(ncols=3, fontsize=9, frameon=True, framealpha=0.5)  # transparent legend box\n",
    "ax.grid(axis=\"y\", linestyle=\"--\", alpha=0.4)\n",
    "\n",
    "# Save outputs\n",
    "pdf_path = \"barplots_backbones.pdf\"\n",
    "png_path = \"barplots_backbones.png\"\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(pdf_path, bbox_inches=\"tight\")\n",
    "fig.savefig(png_path, dpi=300, bbox_inches=\"tight\")\n",
    "\n",
    "print(f\"Saved bar plot to {pdf_path} and {png_path}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "330d5248",
   "metadata": {},
   "outputs": [],
   "source": [
    "# save as: figures/energy_dotplot.(pdf|png), alpha_dotplot.(pdf|png), q0_dotplot.(pdf|png)\n",
    "import json, numpy as np, matplotlib.pyplot as plt\n",
    "\n",
    "models = json.load(open(\"best_models_metrics.json\"))\n",
    "name_map = {\"MambaOut_base_plus_rw\":\"Mamba\",\"ConvNeXt_Gaussian_g500\":\"ConvNeXt\",\n",
    "            \"EfficientNet_g500\":\"EfficientNet\",\"Swin_g500\":\"Swin\",\n",
    "            \"ViT_tiny_patch16_224_gaussian_lrp_12_rlrp_4\":\"ViT\"}\n",
    "M = [name_map.get(m[\"model\"], m[\"model\"]) for m in models]\n",
    "\n",
    "def dotplot(task, acc_key, f1_key, out_base):\n",
    "    y = np.arange(len(M))\n",
    "    acc = [m[acc_key] for m in models]\n",
    "    f1  = [m[f1_key]  for m in models]\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(6, 3.8))\n",
    "    ax.plot(acc, y, 'o', label=\"Accuracy\")      # marker only, no color specified\n",
    "    ax.plot(f1,  y, 's', label=\"Macro-F1\")\n",
    "    ax.set_yticks(y); ax.set_yticklabels(M)\n",
    "    ax.set_xlim(0, 105); ax.set_xlabel(\"Score (%)\")\n",
    "    ax.set_title(task); ax.grid(axis=\"x\", linestyle=\"--\", alpha=0.4)\n",
    "    ax.legend(loc=\"lower right\", fontsize=8)\n",
    "    fig.tight_layout(); fig.savefig(f\"figures/{out_base}.pdf\", bbox_inches=\"tight\")\n",
    "    fig.savefig(f\"figures/{out_base}.png\", dpi=300, bbox_inches=\"tight\")\n",
    "\n",
    "dotplot(\"Energy loss\", \"Acc_energy\",\"F1_energy\",\"energy_dotplot\")\n",
    "dotplot(\"αs\",          \"Acc_alpha\",\"F1_alpha\",\"alpha_dotplot\")\n",
    "dotplot(\"Q0\",          \"Acc_q0\",\"F1_q0\",\"q0_dotplot\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e93392c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# save as: figures/energy_dumbbell.(pdf|png), alpha_dumbbell.(pdf|png), q0_dumbbell.(pdf|png)\n",
    "import json, numpy as np, matplotlib.pyplot as plt\n",
    "models = json.load(open(\"best_models_metrics.json\"))\n",
    "name_map = {\"MambaOut_base_plus_rw\":\"Mamba\",\"ConvNeXt_Gaussian_g500\":\"ConvNeXt\",\n",
    "            \"EfficientNet_g500\":\"EfficientNet\",\"Swin_g500\":\"Swin\",\n",
    "            \"ViT_tiny_patch16_224_gaussian_lrp_12_rlrp_4\":\"ViT\"}\n",
    "M = [name_map.get(m[\"model\"], m[\"model\"]) for m in models]\n",
    "\n",
    "def dumbbell(task, acc_key, f1_key, out_base):\n",
    "    y = np.arange(len(M))\n",
    "    acc = [m[acc_key] for m in models]\n",
    "    f1  = [m[f1_key]  for m in models]\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(6, 3.8))\n",
    "    for i in range(len(M)):\n",
    "        ax.plot([f1[i], acc[i]], [y[i], y[i]], '-', alpha=0.7)  # line between points\n",
    "        ax.plot(f1[i],  y[i], 's')\n",
    "        ax.plot(acc[i], y[i], 'o')\n",
    "    ax.set_yticks(y); ax.set_yticklabels(M)\n",
    "    ax.set_xlim(0, 105); ax.set_xlabel(\"Score (%)\")\n",
    "    ax.set_title(task); ax.grid(axis=\"x\", linestyle=\"--\", alpha=0.4)\n",
    "    fig.tight_layout(); fig.savefig(f\"figures/{out_base}.pdf\", bbox_inches=\"tight\")\n",
    "    fig.savefig(f\"figures/{out_base}.png\", dpi=300, bbox_inches=\"tight\")\n",
    "\n",
    "dumbbell(\"Energy loss\", \"Acc_energy\",\"F1_energy\",\"energy_dumbbell\")\n",
    "dumbbell(\"αs\",          \"Acc_alpha\",\"F1_alpha\",\"alpha_dumbbell\")\n",
    "dumbbell(\"Q0\",          \"Acc_q0\",\"F1_q0\",\"q0_dumbbell\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9a5c684",
   "metadata": {},
   "outputs": [],
   "source": [
    "# saves: figures/metrics_heatmap.(pdf|png)\n",
    "import json, numpy as np, matplotlib.pyplot as plt\n",
    "models = json.load(open(\"best_models_metrics.json\"))\n",
    "name_map = {\"MambaOut_base_plus_rw\":\"Mamba\",\"ConvNeXt_Gaussian_g500\":\"ConvNeXt\",\n",
    "            \"EfficientNet_g500\":\"EfficientNet\",\"Swin_g500\":\"Swin\",\n",
    "            \"ViT_tiny_patch16_224_gaussian_lrp_12_rlrp_4\":\"ViT\"}\n",
    "M = [name_map.get(m[\"model\"], m[\"model\"]) for m in models]\n",
    "cols = [\"Acc_energy\",\"F1_energy\",\"Acc_alpha\",\"F1_alpha\",\"Acc_q0\",\"F1_q0\",\"Acc_total\"]\n",
    "labels = [\"Acc E\",\"F1 E\",\"Acc αs\",\"F1 αs\",\"Acc Q0\",\"F1 Q0\",\"Acc_total\"]\n",
    "\n",
    "Z = np.array([[m[c] for c in cols] for m in models])\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7, 3.8))\n",
    "im = ax.imshow(Z, aspect=\"auto\",cmap=\"Oranges\", vmin=0, vmax=100)\n",
    "ax.set_yticks(range(len(M))); ax.set_yticklabels(M)\n",
    "ax.set_xticks(range(len(labels))); ax.set_xticklabels(labels, rotation=30, ha=\"right\")\n",
    "\n",
    "# annotate\n",
    "for i in range(Z.shape[0]):\n",
    "    for j in range(Z.shape[1]):\n",
    "        ax.text(j, i, f\"{Z[i,j]:.1f}\", ha=\"center\", va=\"center\", fontsize=8)\n",
    "\n",
    "ax.set_title(\"Per-model metrics overview\")\n",
    "fig.tight_layout(); fig.savefig(\"figures/metrics_heatmap.pdf\", bbox_inches=\"tight\")\n",
    "fig.savefig(\"figures/metrics_heatmap.png\", dpi=300, bbox_inches=\"tight\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4adf2140",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json, numpy as np, matplotlib.pyplot as plt\n",
    "\n",
    "models = json.load(open(\"best_models_metrics.json\"))\n",
    "name_map = {\"MambaOut_base_plus_rw\":\"Mamba\",\"ConvNeXt_Gaussian_g500\":\"ConvNeXt\",\n",
    "            \"EfficientNet_g500\":\"EfficientNet\",\"Swin_g500\":\"Swin\",\n",
    "            \"ViT_tiny_patch16_224_gaussian_lrp_12_rlrp_4\":\"ViT\"}\n",
    "M = [name_map.get(m[\"model\"], m[\"model\"]) for m in models]\n",
    "\n",
    "cols = [\"Acc_energy\",\"F1_energy\",\"Acc_alpha\",\"F1_alpha\",\"Acc_q0\",\"F1_q0\",\"Acc_total\"]\n",
    "labels = [\n",
    "    r\"$\\mathrm{Acc}_{E}$\", \n",
    "    r\"$\\mathrm{F1}_{E}$\", \n",
    "    r\"$\\mathrm{Acc}_{\\alpha_s}$\", \n",
    "    r\"$\\mathrm{F1}_{\\alpha_s}$\", \n",
    "    r\"$\\mathrm{Acc}_{Q_0}$\", \n",
    "    r\"$\\mathrm{F1}_{Q_0}$\", \n",
    "    r\"$\\mathrm{Acc}_{\\mathrm{total}}$\"\n",
    "]\n",
    "\n",
    "Z = np.array([[m[c] for c in cols] for m in models])\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7, 3.8))\n",
    "im = ax.imshow(Z, aspect=\"auto\", cmap=\"Oranges\")   # or cmap=\"jet\"\n",
    "\n",
    "ax.set_yticks(range(len(M)))\n",
    "ax.set_yticklabels(M)\n",
    "ax.set_xticks(range(len(labels)))\n",
    "ax.set_xticklabels(labels, rotation=30, ha=\"right\")\n",
    "\n",
    "# annotate with numeric values\n",
    "for i in range(Z.shape[0]):\n",
    "    for j in range(Z.shape[1]):\n",
    "        ax.text(j, i, f\"{Z[i,j]:.1f}\", ha=\"center\", va=\"center\", fontsize=8)\n",
    "\n",
    "ax.set_title(\"Per-model metrics overview\")\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"figures/metrics_heatmap.pdf\", bbox_inches=\"tight\")\n",
    "fig.savefig(\"figures/metrics_heatmap.png\", dpi=300, bbox_inches=\"tight\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37d43919",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json, numpy as np, matplotlib.pyplot as plt\n",
    "\n",
    "models = json.load(open(\"best_models_metrics.json\"))\n",
    "name_map = {\"MambaOut_base_plus_rw\":\"Mamba\",\"ConvNeXt_Gaussian_g500\":\"ConvNeXt\",\n",
    "            \"EfficientNet_g500\":\"EfficientNet\",\"Swin_g500\":\"Swin\",\n",
    "            \"ViT_tiny_patch16_224_gaussian_lrp_12_rlrp_4\":\"ViT\"}\n",
    "M = [name_map.get(m[\"model\"], m[\"model\"]) for m in models]\n",
    "\n",
    "cols = [\"Acc_energy\",\"F1_energy\",\"Acc_alpha\",\"F1_alpha\",\"Acc_q0\",\"F1_q0\",\"Acc_total\"]\n",
    "labels = [\n",
    "    r\"$\\mathrm{Acc}_{E}$\", \n",
    "    r\"$\\mathrm{F1}_{E}$\", \n",
    "    r\"$\\mathrm{Acc}_{\\alpha_s}$\", \n",
    "    r\"$\\mathrm{F1}_{\\alpha_s}$\", \n",
    "    r\"$\\mathrm{Acc}_{Q_0}$\", \n",
    "    r\"$\\mathrm{F1}_{Q_0}$\", \n",
    "    r\"$\\mathrm{Acc}_{\\mathrm{total}}$\"\n",
    "]\n",
    "\n",
    "Z = np.array([[m[c] for c in cols] for m in models])\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7, 3.8))\n",
    "im = ax.imshow(Z, aspect=\"auto\", cmap=\"Oranges\")\n",
    "\n",
    "ax.set_yticks(range(len(M)))\n",
    "ax.set_yticklabels(M)\n",
    "ax.set_xticks(range(len(labels)))\n",
    "ax.set_xticklabels(labels, rotation=30, ha=\"right\")\n",
    "\n",
    "# draw grid lines (black cell borders)\n",
    "ax.set_xticks(np.arange(-0.5, Z.shape[1], 1), minor=True)\n",
    "ax.set_yticks(np.arange(-0.5, Z.shape[0], 1), minor=True)\n",
    "ax.grid(which=\"minor\", color=\"black\", linestyle=\"-\", linewidth=0.5)\n",
    "ax.tick_params(which=\"minor\", bottom=False, left=False)\n",
    "\n",
    "# annotate with numeric values\n",
    "for i in range(Z.shape[0]):\n",
    "    for j in range(Z.shape[1]):\n",
    "        ax.text(j, i, f\"{Z[i,j]:.1f}\", ha=\"center\", va=\"center\", fontsize=8)\n",
    "\n",
    "# no title\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"figures/metrics_heatmap.pdf\", bbox_inches=\"tight\")\n",
    "fig.savefig(\"figures/metrics_heatmap.png\", dpi=300, bbox_inches=\"tight\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f255eda1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# {\n",
    "#     \"model\": \"ViT_tiny_patch16_224_gaussian_lrp_12_rlrp_4\",\n",
    "#     \"Acc_energy\": 100.0,\n",
    "#     \"F1_energy\": 100.0,\n",
    "#     \"Acc_alpha\": 95.83,\n",
    "#     \"F1_alpha\": 95.83,\n",
    "#     \"Acc_q0\": 78.19,\n",
    "#     \"F1_q0\": 77.57,\n",
    "#     \"Acc_total\": 74.03\n",
    "#   },\n",
    "# {\n",
    "#     \"model\": \"MambaOut_base_plus_rw\",\n",
    "#     \"Acc_energy\": 100.0,\n",
    "#     \"F1_energy\": 100.0,\n",
    "#     \"Acc_alpha\": 93.89,\n",
    "#     \"F1_alpha\": 93.9,\n",
    "#     \"Acc_q0\": 75.21,\n",
    "#     \"F1_q0\": 74.98,\n",
    "#     \"Acc_total\": 69.1\n",
    "#   },\n",
    "# {\n",
    "#     \"model\": \"Swin_g500\",\n",
    "#     \"Acc_energy\": 99.79,\n",
    "#     \"F1_energy\": 99.72,\n",
    "#     \"Acc_alpha\": 88.19,\n",
    "#     \"F1_alpha\": 88.1,\n",
    "#     \"Acc_q0\": 63.06,\n",
    "#     \"F1_q0\": 61.96,\n",
    "#     \"Acc_total\": 51.39\n",
    "#   },\n",
    "# {\n",
    "#     \"model\": \"EfficientNet_g500\",\n",
    "#     \"Acc_energy\": 100.0,\n",
    "#     \"F1_energy\": 100.0,\n",
    "#     \"Acc_alpha\": 94.72,\n",
    "#     \"F1_alpha\": 94.71,\n",
    "#     \"Acc_q0\": 70.21,\n",
    "#     \"F1_q0\": 68.79,\n",
    "#     \"Acc_total\": 64.93\n",
    "#   },\n",
    "# {\n",
    "#     \"model\": \"ConvNeXt_Gaussian_g500\",\n",
    "#     \"Acc_energy\": 100.0,\n",
    "#     \"F1_energy\": 100.0,\n",
    "#     \"Acc_alpha\": 93.54,\n",
    "#     \"F1_alpha\": 93.53,\n",
    "#     \"Acc_q0\": 74.1,\n",
    "#     \"F1_q0\": 73.59,\n",
    "#     \"Acc_total\": 67.64\n",
    "#   }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a37cac8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json, numpy as np, matplotlib.pyplot as plt\n",
    "\n",
    "models = json.load(open(\"best_models_metrics.json\"))\n",
    "name_map = {\"MambaOut_base_plus_rw\":\"Mamba\",\"ConvNeXt_Gaussian_g500\":\"ConvNeXt\",\n",
    "            \"EfficientNet_g500\":\"EfficientNet\",\"Swin_g500\":\"Swin\",\n",
    "            \"ViT_tiny_patch16_224_gaussian_lrp_12_rlrp_4\":\"ViT\"}\n",
    "M = [name_map.get(m[\"model\"], m[\"model\"]) for m in models]\n",
    "\n",
    "cols = [\"Acc_energy\",\"F1_energy\",\"Acc_alpha\",\"F1_alpha\",\"Acc_q0\",\"F1_q0\",\"Acc_total\"]\n",
    "labels = [\n",
    "    r\"$\\mathrm{Acc}_{E}$\", \n",
    "    r\"$\\mathrm{F1}_{E}$\", \n",
    "    r\"$\\mathrm{Acc}_{\\alpha_s}$\", \n",
    "    r\"$\\mathrm{F1}_{\\alpha_s}$\", \n",
    "    r\"$\\mathrm{Acc}_{Q_0}$\", \n",
    "    r\"$\\mathrm{F1}_{Q_0}$\", \n",
    "    r\"$\\mathrm{Acc}_{\\mathrm{total}}$\"\n",
    "]\n",
    "\n",
    "Z = np.array([[m[c] for c in cols] for m in models])\n",
    "\n",
    "# make cells less tall by reducing height\n",
    "fig, ax = plt.subplots(figsize=(7, 2.2))\n",
    "im = ax.imshow(Z, aspect=\"auto\", cmap=\"Oranges\")\n",
    "\n",
    "ax.set_yticks(range(len(M)))\n",
    "ax.set_yticklabels(M,fontsize=12)\n",
    "ax.set_xticks(range(len(labels)))\n",
    "ax.set_xticklabels(labels,fontsize=12 )\n",
    "\n",
    "# black grid lines between cells\n",
    "ax.set_xticks(np.arange(-0.5, Z.shape[1], 1), minor=True)\n",
    "ax.set_yticks(np.arange(-0.5, Z.shape[0], 1), minor=True)\n",
    "ax.grid(which=\"minor\", color=\"black\", linestyle=\"-\", linewidth=0.5)\n",
    "ax.tick_params(which=\"minor\", bottom=False, left=False)\n",
    "\n",
    "# annotate with larger font\n",
    "for i in range(Z.shape[0]):\n",
    "    for j in range(Z.shape[1]):\n",
    "        ax.text(j, i, f\"{Z[i,j]:.1f}\", ha=\"center\", va=\"center\", fontsize=12)\n",
    "\n",
    "# no title\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"figures/metrics_heatmap.pdf\", bbox_inches=\"tight\")\n",
    "fig.savefig(\"figures/metrics_heatmap.png\", dpi=300, bbox_inches=\"tight\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "589b3a2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "models"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
