{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fee44565",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e63d66bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select exactly one training regime to analyse.\n",
    "NON_OVERFITTING = True\n",
    "MILD_OVERFITTING = False\n",
    "OVERFITTING = False\n",
    "# Explicit check rather than `assert`, which is stripped under `python -O`.\n",
    "if sum([NON_OVERFITTING, MILD_OVERFITTING, OVERFITTING]) != 1:\n",
    "    raise ValueError(\"Exactly one of NON_OVERFITTING, MILD_OVERFITTING, OVERFITTING must be True\")\n",
    "\n",
    "# Experiment identifiers; used to build the input directory and output file names.\n",
    "DATASET_NAME = \"cifar10\"\n",
    "MODEL_NAME = \"resnet18\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d0e1361",
   "metadata": {},
   "outputs": [],
   "source": [
    "# -----------------------------\n",
    "# File paths\n",
    "# -----------------------------\n",
    "# Map the active mode to its run directory and to the suffix of the output PDF.\n",
    "# Note: the over-fitting run lives in \"overfit/\" but its figure name uses \"overfitting\".\n",
    "if NON_OVERFITTING:\n",
    "    run_dir, mode_tag = \"non_overfitting\", \"non_overfitting\"\n",
    "elif MILD_OVERFITTING:\n",
    "    run_dir, mode_tag = \"mild_overfitting\", \"mild_overfitting\"\n",
    "elif OVERFITTING:\n",
    "    run_dir, mode_tag = \"overfit\", \"overfitting\"\n",
    "else:\n",
    "    raise ValueError(\"Exactly one of NON_OVERFITTING, MILD_OVERFITTING, OVERFITTING must be True\")\n",
    "BASE = Path(f\"{DATASET_NAME}--{MODEL_NAME}\") / run_dir\n",
    "METRICS_CSV = BASE / \"metrics.csv\"\n",
    "BOUNDS_CSV  = BASE / \"bound_val_01.csv\"\n",
    "\n",
    "# Seeds whose A3 columns (A3_seed_<seed>) are averaged and subtracted from Bound5.\n",
    "A3_SEEDS = [42, 43, 44, 45, 46]\n",
    "\n",
    "# -----------------------------\n",
    "# Load data\n",
    "# -----------------------------\n",
    "metrics = pd.read_csv(METRICS_CSV)\n",
    "bounds  = pd.read_csv(BOUNDS_CSV)\n",
    "# g = Bound5 - A1 - (mean of A3 over seeds); same as the former 0.2 * (sum of the five seed columns).\n",
    "# skipna=False keeps NaN propagation identical to the explicit sum.\n",
    "a3_mean = bounds[[f\"A3_seed_{s}\" for s in A3_SEEDS]].mean(axis=1, skipna=False)\n",
    "bounds[\"Bound5_mean\"] = bounds[\"Bound5_mean\"] - bounds[\"A1\"] - a3_mean\n",
    "bounds[\"Bound3_mean\"] = bounds[\"Bound3_mean\"] - bounds[\"A1\"]\n",
    "\n",
    "# Keep only validation rows if the file contains both val/test\n",
    "if \"bound_set\" in bounds.columns:\n",
    "    bounds = bounds[bounds[\"bound_set\"].astype(str).str.lower() == \"val\"]\n",
    "bounds_use = bounds.loc[:, [\"epoch\", \"Bound5_mean\", \"Bound3_mean\"]].rename(columns={\"Bound5_mean\": \"g\", \"Bound3_mean\": \"Unc_TP\"})\n",
    "\n",
    "# Inner join: only epochs present in both files are plotted.\n",
    "df = pd.merge(\n",
    "    metrics.loc[:, [\"epoch\", \"test_acc\"]],\n",
    "    bounds_use,\n",
    "    on=\"epoch\",\n",
    "    how=\"inner\"\n",
    ").sort_values(\"epoch\")\n",
    "if df.empty:\n",
    "    raise ValueError(f\"No common epochs between {METRICS_CSV} and {BOUNDS_CSV}\")\n",
    "\n",
    "# -----------------------------\n",
    "# Line plot\n",
    "# -----------------------------\n",
    "epochs = df[\"epoch\"].to_numpy()\n",
    "g_raw  = df[\"g\"].to_numpy()\n",
    "g_min, g_max = g_raw.min(), g_raw.max()\n",
    "# Min-max scale g to [0, 1] so it shares the y-axis with accuracy; 1e-12 avoids 0/0 when g is constant.\n",
    "g_norm = (g_raw - g_min) / (g_max - g_min + 1e-12)\n",
    "acc = df[\"test_acc\"].to_numpy()\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7.2, 4.6))\n",
    "ax.plot(epochs, g_norm, linewidth=2, label=r\"$mac_h$\")\n",
    "ax.plot(epochs, acc, linewidth=2, label=\"Test accuracy\")\n",
    "\n",
    "ax.set_xlabel(\"Epoch\", fontsize=32)\n",
    "ax.set_ylim(0.0, 1.0)  # shared scale; assumes test_acc is a fraction in [0, 1] - TODO confirm\n",
    "ax.tick_params(axis=\"both\", labelsize=28)\n",
    "ax.grid(True, linestyle=\"--\", alpha=0.35)\n",
    "ax.legend(fontsize=24)\n",
    "\n",
    "fig.tight_layout()\n",
    "path = f\"{DATASET_NAME}_{MODEL_NAME}_gnorm_acc_vs_epoch_{mode_tag}.pdf\"\n",
    "fig.savefig(path, bbox_inches=\"tight\", dpi=200)\n",
    "plt.show()\n",
    "\n",
    "# Pearson correlation; min-max scaling is a positive affine map, so this matches corr(raw g, acc).\n",
    "corr = np.corrcoef(g_norm, acc)[0, 1]\n",
    "print(f\"Correlation between normalized g and accuracy: {corr:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "851b5aaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# -----------------------------\n",
    "# File paths\n",
    "# -----------------------------\n",
    "# Map the active mode to its run directory and to the suffix of the output PDF.\n",
    "# Note: the over-fitting run lives in \"overfit/\" but its figure name uses \"overfitting\".\n",
    "if NON_OVERFITTING:\n",
    "    run_dir, mode_tag = \"non_overfitting\", \"non_overfitting\"\n",
    "elif MILD_OVERFITTING:\n",
    "    run_dir, mode_tag = \"mild_overfitting\", \"mild_overfitting\"\n",
    "elif OVERFITTING:\n",
    "    run_dir, mode_tag = \"overfit\", \"overfitting\"\n",
    "else:\n",
    "    raise ValueError(\"Exactly one of NON_OVERFITTING, MILD_OVERFITTING, OVERFITTING must be True\")\n",
    "BASE = Path(f\"{DATASET_NAME}--{MODEL_NAME}\") / run_dir\n",
    "METRICS_CSV = BASE / \"metrics.csv\"\n",
    "BOUNDS_CSV  = BASE / \"bound_val_01.csv\"\n",
    "\n",
    "# -----------------------------\n",
    "# Load data\n",
    "# -----------------------------\n",
    "metrics = pd.read_csv(METRICS_CSV)\n",
    "bounds  = pd.read_csv(BOUNDS_CSV)\n",
    "\n",
    "# Keep only validation rows if the file contains both val/test\n",
    "if \"bound_set\" in bounds.columns:\n",
    "    bounds = bounds[bounds[\"bound_set\"].astype(str).str.lower() == \"val\"]\n",
    "\n",
    "# Bound values are not plotted here: the bounds file is only joined on \"epoch\"\n",
    "# so that the figure is restricted to epochs that have bound values.\n",
    "df = pd.merge(\n",
    "    metrics.loc[:, [\"epoch\", \"test_acc\", \"train_loss\"]],\n",
    "    bounds.loc[:, [\"epoch\"]],\n",
    "    on=\"epoch\",\n",
    "    how=\"inner\"\n",
    ").sort_values(\"epoch\")\n",
    "if df.empty:\n",
    "    raise ValueError(f\"No common epochs between {METRICS_CSV} and {BOUNDS_CSV}\")\n",
    "\n",
    "# -----------------------------\n",
    "# Line plot\n",
    "# -----------------------------\n",
    "epochs = df[\"epoch\"].to_numpy()\n",
    "train_errs = df[\"train_loss\"].to_numpy()\n",
    "# Min-max scale to [0, 1] so the loss shares the y-axis with accuracy; 1e-12 avoids 0/0 for a constant series.\n",
    "train_errs_norm = (train_errs - train_errs.min()) / (train_errs.max() - train_errs.min() + 1e-12)\n",
    "acc = df[\"test_acc\"].to_numpy()\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7.2, 4.6))\n",
    "# NOTE(review): this curve is the normalized training *loss* but is labelled \"Training error\" - confirm wording.\n",
    "ax.plot(epochs, train_errs_norm, linewidth=2, label=\"Training error\")\n",
    "ax.plot(epochs, acc, linewidth=2, label=\"Test accuracy\")\n",
    "\n",
    "ax.set_xlabel(\"Epoch\", fontsize=32)\n",
    "ax.set_ylim(0.0, 1.0)  # shared scale; assumes test_acc is a fraction in [0, 1] - TODO confirm\n",
    "ax.tick_params(axis=\"both\", labelsize=28)\n",
    "ax.grid(True, linestyle=\"--\", alpha=0.35)\n",
    "ax.legend(fontsize=24)\n",
    "\n",
    "fig.tight_layout()\n",
    "path = f\"{DATASET_NAME}_{MODEL_NAME}_trainerr_acc_vs_epoch_{mode_tag}.pdf\"\n",
    "fig.savefig(path, bbox_inches=\"tight\", dpi=200)\n",
    "plt.show()\n",
    "\n",
    "# Pearson correlation between test accuracy and the normalized training loss.\n",
    "corr = np.corrcoef(acc, train_errs_norm)[0, 1]\n",
    "print(f\"Correlation between normalized training error and accuracy: {corr:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00520f97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# -----------------------------\n",
    "# File paths\n",
    "# -----------------------------\n",
    "# Map the active mode to its run directory and to the suffix of the output PDF.\n",
    "# Note: the over-fitting run lives in \"overfit/\" but its figure name uses \"overfitting\".\n",
    "if NON_OVERFITTING:\n",
    "    run_dir, mode_tag = \"non_overfitting\", \"non_overfitting\"\n",
    "elif MILD_OVERFITTING:\n",
    "    run_dir, mode_tag = \"mild_overfitting\", \"mild_overfitting\"\n",
    "elif OVERFITTING:\n",
    "    run_dir, mode_tag = \"overfit\", \"overfitting\"\n",
    "else:\n",
    "    raise ValueError(\"Exactly one of NON_OVERFITTING, MILD_OVERFITTING, OVERFITTING must be True\")\n",
    "BASE = Path(f\"{DATASET_NAME}--{MODEL_NAME}\") / run_dir\n",
    "METRICS_CSV = BASE / \"metrics.csv\"\n",
    "BOUNDS_CSV  = BASE / \"bound_val_01.csv\"\n",
    "\n",
    "# -----------------------------\n",
    "# Load data\n",
    "# -----------------------------\n",
    "metrics = pd.read_csv(METRICS_CSV)\n",
    "bounds  = pd.read_csv(BOUNDS_CSV)\n",
    "\n",
    "# Keep only validation rows if the file contains both val/test\n",
    "if \"bound_set\" in bounds.columns:\n",
    "    bounds = bounds[bounds[\"bound_set\"].astype(str).str.lower() == \"val\"]\n",
    "\n",
    "# Bound values are not plotted here: the bounds file is only joined on \"epoch\"\n",
    "# so that the figure is restricted to epochs that have bound values.\n",
    "df = pd.merge(\n",
    "    metrics.loc[:, [\"epoch\", \"test_acc\", \"val_loss\"]],\n",
    "    bounds.loc[:, [\"epoch\"]],\n",
    "    on=\"epoch\",\n",
    "    how=\"inner\"\n",
    ").sort_values(\"epoch\")\n",
    "if df.empty:\n",
    "    raise ValueError(f\"No common epochs between {METRICS_CSV} and {BOUNDS_CSV}\")\n",
    "\n",
    "# -----------------------------\n",
    "# Line plot\n",
    "# -----------------------------\n",
    "epochs = df[\"epoch\"].to_numpy()\n",
    "val_errs = df[\"val_loss\"].to_numpy()\n",
    "# Min-max scale to [0, 1] so the loss shares the y-axis with accuracy; 1e-12 avoids 0/0 for a constant series.\n",
    "val_errs_norm = (val_errs - val_errs.min()) / (val_errs.max() - val_errs.min() + 1e-12)\n",
    "acc = df[\"test_acc\"].to_numpy()\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7.2, 4.6))\n",
    "# NOTE(review): this curve is the normalized validation *loss* but is labelled \"Validation error\" - confirm wording.\n",
    "ax.plot(epochs, val_errs_norm, linewidth=2, label=\"Validation error\")\n",
    "ax.plot(epochs, acc, linewidth=2, label=\"Test accuracy\")\n",
    "\n",
    "ax.set_xlabel(\"Epoch\", fontsize=32)\n",
    "ax.set_ylim(0.0, 1.0)  # shared scale; assumes test_acc is a fraction in [0, 1] - TODO confirm\n",
    "ax.tick_params(axis=\"both\", labelsize=28)\n",
    "ax.grid(True, linestyle=\"--\", alpha=0.35)\n",
    "ax.legend(fontsize=24)\n",
    "\n",
    "fig.tight_layout()\n",
    "path = f\"{DATASET_NAME}_{MODEL_NAME}_valerr_acc_vs_epoch_{mode_tag}.pdf\"\n",
    "fig.savefig(path, bbox_inches=\"tight\", dpi=200)\n",
    "plt.show()\n",
    "\n",
    "# Pearson correlation between test accuracy and the normalized validation loss.\n",
    "corr = np.corrcoef(acc, val_errs_norm)[0, 1]\n",
    "print(f\"Correlation between normalized validation error and accuracy: {corr:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1280e9d9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dgm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
