{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook reloads the pickled outputs of a training run and re-creates the error heat map and the plots of the saved $A$ matrices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.colors import LogNorm\n",
    "\n",
    "import pickle\n",
    "\n",
    "def plot_error_heatmap(\n",
    "    test_results: dict[str, dict[str, float]],\n",
    "    group_boundaries: tuple[float, ...] = (0.5, 8.5, 16.5),\n",
    ") -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Visualize the test results as a heatmap with logarithmic scaling.\n",
    "\n",
    "    Rows are training datasets and columns are testing datasets, both in the\n",
    "    insertion order of ``test_results``. The default separator lines assume the\n",
    "    sorting 'fem', 'poly 1-8', 'sin 1-8', 'cos 1-8' (25 keys in total).\n",
    "\n",
    "    Args:\n",
    "        test_results: Dictionary of dictionaries containing test MSE values;\n",
    "            ``test_results[train_key][test_key]`` is the entry for row\n",
    "            ``train_key`` and column ``test_key``. Every inner dictionary is\n",
    "            expected to contain every outer key.\n",
    "        group_boundaries: Positions (in cell units) at which dashed separator\n",
    "            lines are drawn on both axes. Pass ``()`` to draw none.\n",
    "\n",
    "    Returns:\n",
    "        The (num_keys, num_keys) float array that was plotted.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``test_results`` is empty.\n",
    "    \"\"\"\n",
    "    keys = list(test_results)\n",
    "    if not keys:\n",
    "        raise ValueError(\"test_results is empty; nothing to plot\")\n",
    "    num_keys = len(keys)\n",
    "\n",
    "    results = np.zeros((num_keys, num_keys))\n",
    "    for i, train_key in enumerate(keys):\n",
    "        for j, test_key in enumerate(keys):\n",
    "            results[i, j] = test_results[train_key][test_key]\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(6, 6))\n",
    "    # LogNorm masks entries <= 0, so an exact-zero error shows up as a blank cell.\n",
    "    image = ax.imshow(results, cmap='viridis', norm=LogNorm())\n",
    "\n",
    "    # Dashed separators between dataset families (default: fem | poly | sin | cos).\n",
    "    for boundary in group_boundaries:\n",
    "        ax.axhline(y=boundary, color='white', linestyle='--', linewidth=1)\n",
    "        ax.axvline(x=boundary, color='white', linestyle='--', linewidth=1)\n",
    "\n",
    "    ticks = np.arange(num_keys)\n",
    "    ax.set_xticks(ticks)\n",
    "    ax.set_xticklabels(keys, rotation=90)  # rotate column labels for readability\n",
    "    ax.set_yticks(ticks)\n",
    "    ax.set_yticklabels(keys)\n",
    "\n",
    "    fig.colorbar(image, ax=ax)  # colour scale for the log-normalised errors\n",
    "    ax.set_title(\"Error Heatmap (Log Scale)\")\n",
    "    ax.set_xlabel(\"Testing data\")\n",
    "    ax.set_ylabel(\"Training data\")\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "    return results\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This pickle holds the matrix of evaluation errors (test MSE for every train/test dataset pair). Every run writes its results file in the same format, so any of them can be plotted here: change the path below to the file your run wrote."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pickles can run arbitrary code when loaded: only open files written by your own runs.\n",
    "RUNS_PKL = './tiny_linear_poisson_runs_20250521_001430.pkl'  # edit to match your run's output\n",
    "with open(RUNS_PKL, 'rb') as f:\n",
    "    results = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows: training set, columns: testing set. The trailing ';' hides the returned array.\n",
    "plot_error_heatmap(test_results=dict(results));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The tiny neural-network models also save their $A$ matrices directly to a separate file, so they can be inspected visually:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pickles can run arbitrary code when loaded: only open files written by your own runs.\n",
    "# Presumably a sequence of (p_train, A_matrix) pairs, as the plotting cell below unpacks it.\n",
    "A_MATRICES_PKL = './tiny_linear_poisson_A_matrices_20250521_001430.pkl'  # edit to match your run's output\n",
    "with open(A_MATRICES_PKL, 'rb') as f:\n",
    "    A_matrices = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One panel per (p_train, A_matrix) pair, all on a shared colour scale so panels are comparable.\n",
    "A_CLIM = (0, 0.012)  # fixed colour limits; entries outside this range saturate\n",
    "n_panels = len(A_matrices)\n",
    "n_cols = max(1, int(np.ceil(np.sqrt(n_panels))))  # e.g. 25 matrices -> 5 x 5 grid\n",
    "n_rows = max(1, int(np.ceil(n_panels / n_cols)))\n",
    "fig, axes = plt.subplots(n_rows, n_cols, figsize=(8, 8), squeeze=False, constrained_layout=True)\n",
    "for ax, (p_train, A_matrix) in zip(axes.flat, A_matrices):\n",
    "    im = ax.imshow(A_matrix, cmap='viridis', vmin=A_CLIM[0], vmax=A_CLIM[1])\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    ax.set_title(str(p_train))\n",
    "for ax in axes.flat[n_panels:]:\n",
    "    ax.axis('off')  # hide grid cells left over when the panels do not fill the grid\n",
    "if n_panels:\n",
    "    fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
