{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from random import choices\n",
    "import pickle\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from pathlib import Path\n",
    "from typing import List, Tuple\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\") # Ignore all warnings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_Kronecker_Factors(N:int, layer_index:int, max_steps:int, choices_list:bool = None) -> Tuple:\n",
    "    \"\"\"\n",
    "    Retrieves Kronecker factors for specified layers from serialized files.\n",
    "\n",
    "    Parameters:\n",
    "        N (int): Number of layers to retrieve.\n",
    "        layer_index (int): Index of the layer to retrieve factors for.\n",
    "        max_steps (int): Maximum step interval to sample from.\n",
    "        choices_list (bool, optional): If provided, uses this list to sample steps. Otherwise, generates a new list.\n",
    "\n",
    "    Returns:\n",
    "        tuple: A tuple containing lists of H-bar and S matrices, the list of chosen steps, and the class name of the layer.\n",
    "    \n",
    "    Raises:\n",
    "        FileNotFoundError: If the directories for H-bar or S are not found.\n",
    "    \"\"\"\n",
    "    H_bar_dir = Path(\"H_bar_resnet\").expanduser()\n",
    "    S_dir = Path(\"S_resnet\").expanduser()\n",
    "    if not H_bar_dir.exists() or not S_dir.exists:\n",
    "        raise FileNotFoundError(f'H_bar_resnet and S_resnet directories not found')\n",
    "    \n",
    "    if choices_list is None:\n",
    "        choices_list = choices(range(0, max_steps + 100, 100), k=N)\n",
    "    H_bar, S = [], []\n",
    "    for i in choices_list:\n",
    "        with open(f'H_bar_resnet/H_bar_{i}.pkl', 'rb') as f:\n",
    "            dict_H_bar = pickle.load(f)\n",
    "        H_bar_l = list(dict_H_bar.values())[layer_index]\n",
    "        H_bar.append(H_bar_l.cpu().numpy())\n",
    "        with open(f'S_resnet/S_{i}.pkl', 'rb') as f:\n",
    "            dict_S = pickle.load(f)\n",
    "        S_l = list(dict_S.values())[layer_index]\n",
    "        layer_name = list(dict_S.keys())[layer_index].__class__.__name__\n",
    "        S.append(S_l.cpu().numpy())\n",
    "    return H_bar, S, choices_list, layer_name\n",
    "\n",
    "\n",
    "def plot_Kronecker_Factors(H_bar:List, S:List, choices_list:List, N:int, layer_name:str, \n",
    "                           layer_index:int, verbose:bool = False):\n",
    "    \"\"\"\n",
    "    Plots the Kronecker factors (H_bar and S matrices) for a specified layer.\n",
    "\n",
    "    Parameters:\n",
    "        H_bar (List): List of H_bar matrices to plot.\n",
    "        S (List): List of S matrices to plot.\n",
    "        choices_list (List): List of steps at which matrices were sampled.\n",
    "        N (int): Number of layers to plot.\n",
    "        layer_name (str): Name of the neural network layer.\n",
    "        layer_index (int): Index of the layer in the model.\n",
    "        verbose (bool, optional): If True, displays additional titles and information on the plots.\n",
    "\n",
    "    This function creates a subplot with 2*N elements, displaying each matrix with global min and max values for better visualization.\n",
    "    \"\"\"\n",
    "    # Compute global min and max for H_bar\n",
    "    all_values_H_bar = np.concatenate([matrix.ravel() for matrix in H_bar])\n",
    "    global_min_H_bar, global_max_H_bar = all_values_H_bar.min(), all_values_H_bar.max()\n",
    "    \n",
    "    # Compute global min and max for S\n",
    "    all_values_S = np.concatenate([matrix.ravel() for matrix in S])\n",
    "    global_min_S, global_max_S = all_values_S.min(), all_values_S.max()\n",
    "    \n",
    "    fig, axs = plt.subplots(1, 2*N, figsize=(15, 5*N), constrained_layout=True)\n",
    "\n",
    "    def plot_single_matrix(matrix, ax, vmin, vmax):\n",
    "        im = ax.imshow(matrix, cmap=\"gray\", interpolation='nearest', vmin=vmin, vmax=vmax)\n",
    "        ax.set_aspect('equal')\n",
    "        ax.grid(False)  \n",
    "        return im\n",
    "\n",
    "    # Plot H_bar matrices\n",
    "    for i, matrix in enumerate(H_bar):\n",
    "        axs[i].set_title(f'Matrix $\\\\mathcal{{H}}$ | Step {choices_list[i]}', fontsize=16, fontweight='bold')\n",
    "        im_A = plot_single_matrix(matrix, axs[i], global_min_H_bar, global_max_H_bar)\n",
    "\n",
    "    # Plot S matrices\n",
    "    for i, matrix in enumerate(S):\n",
    "        axs[i+N].set_title(f'Matrix $\\\\mathcal{{S}}$ | Step {choices_list[i]}', fontsize=16, fontweight='bold')\n",
    "        im_G = plot_single_matrix(matrix, axs[i+N], global_min_S, global_max_S)\n",
    "    if verbose:\n",
    "        plt.suptitle(f'Matrices $\\\\mathcal{{H}}$ and $\\\\mathcal{{S}}$ of {layer_name} layer at position {layer_index+1} of Resnet18', fontsize=20, fontweight='bold')\n",
    "\n",
    "    # Add colorbars\n",
    "    fig.colorbar(im_A, orientation='vertical', shrink=0.28)\n",
    "    fig.colorbar(im_G, orientation='vertical', shrink=0.28)\n",
    "    plt.subplots_adjust(wspace=0.1, hspace=0.3)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Settings from Paper ----\n",
    "N = 2  # Number of matrices for H_bar and for S\n",
    "layer_index = 40 # 36 or 40\n",
    "max_steps = 9800\n",
    "choice_list = [5200, 9800]\n",
    "# ---- Settings from Paper ----\n",
    "H_bar, S, choices_list, layer_name = get_Kronecker_Factors(N, layer_index, max_steps, choice_list)\n",
    "plot_Kronecker_Factors(H_bar, S, choices_list, N, layer_name, layer_index)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
