{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-19T17:05:40.236778Z",
     "start_time": "2024-04-19T17:05:39.885583Z"
    }
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.patches import Circle\n",
    "from typing import List, Tuple\n",
    "from matplotlib.colors import Normalize\n",
    "from matplotlib.cm import ScalarMappable\n",
    "from pathlib import Path\n",
    "from random import choices\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\") # Ignore all warnings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gershgorin circle theorem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-19T17:05:43.899380Z",
     "start_time": "2024-04-19T17:05:43.891549Z"
    }
   },
   "outputs": [],
   "source": [
    "def add_gaussian_noise(matrix:np.array, mean:float = 0, std_dev:float = 10**-3) -> np.array:\n",
    "    \"\"\"\n",
    "    Adds Gaussian noise to the off-diagonal elements of a given matrix.\n",
    "\n",
    "    Parameters:\n",
    "        matrix (np.array): The matrix to which noise will be added.\n",
    "        mean (float, optional): The mean of the Gaussian noise. Default is 0.\n",
    "        std_dev (float, optional): The standard deviation of the Gaussian noise. Default is 0.001.\n",
    "\n",
    "    Returns:\n",
    "        np.array: The noisy matrix with updated off-diagonal elements.\n",
    "\n",
    "    This function iterates over the off-diagonal elements of the input matrix and adds Gaussian noise to each element.\n",
    "    \"\"\"\n",
    "    noisy_matrix = np.copy(matrix)\n",
    "    rows, cols = matrix.shape\n",
    "    for i in range(rows):\n",
    "        for j in range(cols):\n",
    "            if i != j:  # Only update off-diagonal elements\n",
    "                noisy_matrix[i, j] += np.random.normal(mean, std_dev)\n",
    "    return noisy_matrix\n",
    "\n",
    "def get_Kronecker_Factors(N:int, layer_index:int, max_steps:int, choices_list:bool = None, \n",
    "                          noise:bool = False, std:float = 10**-2) -> Tuple:\n",
    "    \"\"\"\n",
    "    Retrieves Kronecker factors for specified layers and optionally adds Gaussian noise.\n",
    "\n",
    "    Parameters:\n",
    "        N (int): Number of layer to retrieve for.\n",
    "        layer_index (int): Index of the layer to retrieve factors for.\n",
    "        max_steps (int): Maximum step interval to sample from.\n",
    "        choices_list (bool, optional): If provided, uses this list to sample steps. Otherwise, generates a new list.\n",
    "        noise (bool, optional): If True, adds Gaussian noise to the matrices. Default is False.\n",
    "        std (float, optional): Standard deviation of the Gaussian noise. Default is 0.01.\n",
    "\n",
    "    Returns:\n",
    "        tuple: A tuple containing lists of H_bar and S matrices, the list of chosen steps, and the class name of the layer.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: If the directories for H-bar or S are not found.\n",
    "\n",
    "    This function can apply Gaussian noise to both H_bar and S matrices if specified, affecting only off-diagonal elements.\n",
    "    \"\"\"\n",
    "    H_bar_dir = Path(\"H_bar_resnet\").expanduser()\n",
    "    S_dir = Path(\"S_resnet\").expanduser()\n",
    "    if not H_bar_dir.exists() or not S_dir.exists:\n",
    "        raise FileNotFoundError(f'H_bar_resnet and S_resnet directories not found')\n",
    "    if choices_list is None:\n",
    "        choices_list = choices(range(50, max_steps + 50, 50), k=N)\n",
    "    H_bar, S = [], []\n",
    "    for i in choices_list:\n",
    "        with open(f'H_bar_resnet/H_bar_{i}.pkl', 'rb') as f:\n",
    "            dict_H_bar = pickle.load(f)\n",
    "        H_bar_l = list(dict_H_bar.values())[layer_index]\n",
    "        if noise:\n",
    "            H_bar.append(add_gaussian_noise(H_bar_l.cpu().numpy(), std_dev=std))\n",
    "        else:\n",
    "            H_bar.append(H_bar_l.cpu().numpy())\n",
    "        with open(f'S_resnet/S_{i}.pkl', 'rb') as f:\n",
    "            dict_S = pickle.load(f)\n",
    "        S_l = list(dict_S.values())[layer_index]\n",
    "        layer_name = list(dict_S.keys())[layer_index].__class__.__name__\n",
    "        if noise:\n",
    "            S.append(add_gaussian_noise(S_l.cpu().numpy(), std_dev=std))\n",
    "        else:\n",
    "            S.append(S_l.cpu().numpy())\n",
    "    return H_bar, S, choices_list, layer_name\n",
    "\n",
    "def plot_gershgorin_disks(H_bar:List, S:List, choices_list:List, N:int, \n",
    "                          layer_name:str, layer_index:int, verbose:bool = False):\n",
    "    \"\"\"\n",
    "    Plots Gershgorin disks for H-bar and S matrices of a specified layer to analyze their eigenvalues.\n",
    "\n",
    "    Parameters:\n",
    "        H_bar (List): List of H_bar matrices.\n",
    "        S (List): List of S matrices.\n",
    "        choices_list (List): Steps at which the matrices were sampled.\n",
    "        N (int): Number of layers to plot.\n",
    "        layer_name (str): Name of the layer.\n",
    "        layer_index (int): Index of the layer in the model.\n",
    "        verbose (bool, optional): If True, adds a detailed title to the plots.\n",
    "\n",
    "    This function visualizes the eigenvalue distribution and the Gershgorin disks, which estimate the eigenvalue bounds, for each matrix.\n",
    "    Eigenvalues are plotted as red 'x' marks, and Gershgorin disks are shown as circles on the complex plane.\n",
    "    \"\"\"\n",
    "    fig, axs = plt.subplots(2, N, figsize=(2*N, 2*N))\n",
    "    colormap = plt.cm.viridis\n",
    "\n",
    "    # Calculate normalization factors separately for H and S matrices\n",
    "    norm_A = Normalize(vmin=min([np.sum(np.abs(m), axis=1).min() for m in H_bar]), vmax=max([np.sum(np.abs(m), axis=1).max() for m in H_bar]))\n",
    "    norm_G = Normalize(vmin=min([np.sum(np.abs(m), axis=1).min() for m in S]), vmax=max([np.sum(np.abs(m), axis=1).max() for m in S]))\n",
    "\n",
    "    # ScalarMappables for creating the color bars\n",
    "    sm_A = ScalarMappable(norm=norm_A, cmap=colormap)\n",
    "    sm_G = ScalarMappable(norm=norm_G, cmap=colormap)\n",
    "\n",
    "    def plot_single_matrix(matrix, ax, index, type, norm):\n",
    "        eigenvalues = np.linalg.eigvals(matrix)\n",
    "        radius_values = [np.sum(np.abs(matrix[i, :])) - np.abs(matrix[i, i]) for i in range(len(matrix))]\n",
    "        for i, radius in enumerate(radius_values):\n",
    "            center = matrix[i, i]\n",
    "            color = colormap(norm(radius))\n",
    "            circle = Circle((center, 0), radius, color=color, fill=False, alpha=0.5)\n",
    "            ax.add_artist(circle)\n",
    "            ax.plot(center, 0, 'ko')\n",
    "        ax.plot(eigenvalues.real, eigenvalues.imag, 'rx')\n",
    "        if type == \"$\\\\mathcal{{S}}$\":\n",
    "            ax.set_xlabel('Real Part', fontsize=10)\n",
    "        if index == 0:\n",
    "            ax.set_ylabel('Imaginary Part', fontsize=10)\n",
    "        ax.set_title(f'Matrix {type} | Step {choices_list[index]}', fontsize=10, fontweight='bold')\n",
    "        ax.set_aspect('equal')\n",
    "        ax.grid(True)\n",
    "        ax.set_xlim([np.min(matrix.diagonal() - np.sum(np.abs(matrix), axis=1)), np.max(matrix.diagonal() + np.sum(np.abs(matrix), axis=1))])\n",
    "        ax.set_ylim([-np.max(np.sum(np.abs(matrix), axis=1)), np.max(np.sum(np.abs(matrix), axis=1))])\n",
    "        asp = np.diff(ax.get_xlim())[0] / np.diff(ax.get_ylim())[0]\n",
    "        ax.set_aspect(asp)\n",
    "\n",
    "    # Plot Gershgorin disks for H and S matrices\n",
    "    for i, matrix in enumerate(H_bar):\n",
    "        plot_single_matrix(matrix, axs[0, i], i, \"$\\\\mathcal{{H}}$\", norm_A)\n",
    "    for i, matrix in enumerate(S):\n",
    "        plot_single_matrix(matrix, axs[1, i], i, \"$\\\\mathcal{{S}}$\", norm_G)\n",
    "\n",
    "    if verbose:\n",
    "        plt.suptitle(f'Gershgorin Disks for Matrices $\\\\mathcal{{H}}$ and $\\\\mathcal{{S}}$ of {layer_name} layer at position {layer_index+1} of Resnet18',fontsize=10, fontweight='bold')\n",
    "\n",
    "    # Add color bars for H and S\n",
    "    cbar_ax_A = fig.add_axes([0.91, 0.56, 0.02, 0.30])\n",
    "    cbar_ax_G = fig.add_axes([0.91, 0.12, 0.02, 0.30])\n",
    "    fig.colorbar(sm_A,cax=cbar_ax_A)\n",
    "    fig.colorbar(sm_G,cax=cbar_ax_G)\n",
    "    plt.subplots_adjust(wspace=0.4, hspace=0.3)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-19T17:05:49.702468Z",
     "start_time": "2024-04-19T17:05:48.735640Z"
    }
   },
   "outputs": [],
   "source": [
    "# ---- Settings from Paper ----\n",
    "N = 2  # Number of matrices for H_bar and for S\n",
    "layer_index = 40 # 36 or 40\n",
    "max_steps = 9800\n",
    "choice_list = [5200, 9800]\n",
    "# ---- Settings from Paper ----\n",
    "H_bar, S, choices_list, layer_name = get_Kronecker_Factors(N, layer_index, max_steps, choice_list, noise=False)\n",
    "plot_gershgorin_disks(H_bar, S, choices_list, N, layer_name, layer_index)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Perturbation Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_eigenvalues(H_bar:List, H_bar_noise:List, S:List, S_noise:List, choices_list:List, \n",
    "                     N:int, layer_name:str, layer_index:int, verbose:bool = False):\n",
    "    \"\"\"\n",
    "    Plots the eigenvalues of matrices with and without added noise, alongside a reference line based on the Kaiser Rule.\n",
    "\n",
    "    Parameters:\n",
    "        H_bar (List): List of original H-bar matrices.\n",
    "        H_bar_noise (List): List of H-bar matrices with noise added.\n",
    "        S (List): List of original S matrices.\n",
    "        S_noise (List): List of S matrices with noise added.\n",
    "        choices_list (List): Steps at which the matrices were sampled.\n",
    "        N (int): Number of layers to plot.\n",
    "        layer_name (str): Name of the layer.\n",
    "        layer_index (int): Index of the layer in the model.\n",
    "        verbose (bool, optional): If True, adds a detailed title to the plots.\n",
    "\n",
    "    This function visualizes the eigenvalues of the matrices on logarithmic scales for both axes, allowing for detailed \n",
    "    comparison between original and noisy conditions. It highlights the stability or instability introduced by noise.\n",
    "    \"\"\"\n",
    "    fig, axs = plt.subplots(2, N, figsize=(2*N, 2*N))\n",
    "    def plot_single_matrix(matrix,matrix_noise, ax, index, type):\n",
    "        # Calculate the eigenvalues\n",
    "        eigenvalues = np.linalg.eigh(matrix)[0]  # Use only the eigenvalues\n",
    "        eigenvalues_noise = np.linalg.eigh(matrix_noise)[0]\n",
    "        ax.plot(eigenvalues, '-', color='blue', markersize=6, linewidth=2, label=\"Without Noise\")  \n",
    "        ax.plot(eigenvalues_noise, '--', color='red', markersize=6, linewidth=2, label=\"With Noise\")  \n",
    "        ax.axhline(y=1, color='green', linestyle='--', linewidth=1.5, label=\"Kaiser Rule\")  \n",
    "        ax.set_title(f'Matrix {type} | Step {choices_list[index]}', fontsize=10, fontweight='bold')\n",
    "        ax.set_yscale('log')\n",
    "        ax.set_xscale('log')\n",
    "        if type == \"$\\\\mathcal{{S}}$\":\n",
    "            ax.set_xlabel('$\\\\lambda$', fontsize=10)\n",
    "        ax.grid(True, which='both', linestyle='--', linewidth=0.5) \n",
    " \n",
    "    for i, (matrix, matrix_noise) in enumerate(zip(H_bar, H_bar_noise)):\n",
    "        plot_single_matrix(matrix,matrix_noise, axs[0, i], i, \"$\\\\mathcal{{H}}$\")\n",
    "    for i, (matrix, matrix_noise) in enumerate(zip(S, S_noise)):\n",
    "        plot_single_matrix(matrix, matrix_noise, axs[1, i], i, \"$\\\\mathcal{{S}}$\")\n",
    "    if verbose:\n",
    "        plt.suptitle(f'Comparison of Eigenvalues: Original vs. Noisy $\\\\mathcal{{H}}$ and $\\\\mathcal{{S}}$ Matrices of {layer_name} layer at position {layer_index+1} of Resnet18', fontsize=16, fontweight='bold')\n",
    "    plt.subplots_adjust(wspace=0.4, hspace=0.4)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Settings from Paper ----\n",
    "N = 2  # Number of matrices for H_bar and for S\n",
    "layer_index = 40 # 36 or 40\n",
    "max_steps = 9800\n",
    "choice_list = [5200, 9800]\n",
    "# ---- Settings from Paper ----\n",
    "A, G, choices_list, layer_name = get_Kronecker_Factors(N, layer_index, max_steps, choice_list)\n",
    "A_noise, G_noise, choices_list, layer_name = get_Kronecker_Factors(N, layer_index, max_steps, choice_list, noise=True, std=10**-3)\n",
    "plot_eigenvalues(A,A_noise, G, G_noise, choices_list, N, layer_name, layer_index)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# FFT and SNR analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_snr(signal:np.array, noise:np.array) -> np.array:\n",
    "    \"\"\"\n",
    "    Calculates the Signal-to-Noise Ratio (SNR) in decibels (dB) for given signal and noise arrays.\n",
    "\n",
    "    Parameters:\n",
    "        signal (np.array): The main signal array.\n",
    "        noise (np.array): The noise array mixed with the signal.\n",
    "\n",
    "    Returns:\n",
    "        np.array: The SNR of the signal in decibels.\n",
    "\n",
    "    This function computes the SNR by first calculating the power of the signal and noise, then using these \n",
    "    to compute the logarithmic ratio of signal power to noise power.\n",
    "    \"\"\"\n",
    "    power_signal = np.mean(signal ** 2)\n",
    "    power_noise = np.mean(noise ** 2)\n",
    "    return 10 * np.log10(power_signal / power_noise)\n",
    "\n",
    "def plot_fft(H_bar:List, H_bar_noise:List, S:List, S_noise:List, choices_list:List, N:int, \n",
    "             layer_name:str, layer_index:int, verbose:bool = False):\n",
    "    \"\"\"\n",
    "    Plots the Fast Fourier Transform (FFT) of matrices and their noisy counterparts, and calculates the SNR.\n",
    "\n",
    "    Parameters:\n",
    "        H_bar (List): List of original H-bar matrices.\n",
    "        H_bar_noise (List): List of H-bar matrices with added noise.\n",
    "        S (List): List of original S matrices.\n",
    "        S_noise (List): List of S matrices with added noise.\n",
    "        choices_list (List): List of steps at which the matrices were sampled.\n",
    "        N (int): Number of layers to plot.\n",
    "        layer_name (str): Name of the layer.\n",
    "        layer_index (int): Index of the layer in the model.\n",
    "        verbose (bool, optional): If True, adds a detailed title to the plots.\n",
    "\n",
    "    This function visualizes the magnitude spectrum of the FFT for each matrix, providing insights into \n",
    "    the frequency components of both original and noisy matrices. It also annotates each plot with the calculated SNR values.\n",
    "    \"\"\"\n",
    "    fig, axs = plt.subplots(4, N, figsize=(4*N, 6*N)) \n",
    "\n",
    "    def plot_single_matrix(matrix, ax, index, type):\n",
    "        fft_matrix = np.fft.fftshift(np.fft.fft2(matrix))\n",
    "        magnitude_spectrum = 20*np.log(np.abs(fft_matrix))\n",
    "        img = ax.imshow(magnitude_spectrum, cmap='hot', extent=[-np.pi, np.pi, -np.pi, np.pi])\n",
    "        ax.set_title(f'Matrix {type} | Step {choices_list[index]}', fontweight='bold')\n",
    "        asp = np.diff(ax.get_xlim())[0] / np.diff(ax.get_ylim())[0]\n",
    "        ax.set_aspect(asp)\n",
    "        return img\n",
    "\n",
    "    # Plot H_bar matrices and calculate SNR\n",
    "    snrs_A = []\n",
    "    for i, (matrix, matrix_noise) in enumerate(zip(H_bar, H_bar_noise)):\n",
    "        imgs_A = plot_single_matrix(matrix, axs[0, i], i, \"$\\\\mathcal{{H}}$\")\n",
    "        _ = plot_single_matrix(matrix_noise, axs[1, i], i, \"$\\\\hat{{\\\\mathcal{{H}}}}$\")\n",
    "        snr_value = calculate_snr(np.diag(matrix), np.triu(matrix_noise,1))\n",
    "        snrs_A.append(snr_value)\n",
    "        axs[1, i].annotate(f'SNR: {snr_value:.2f} dB', xy=(0.5, -0.25), xycoords='axes fraction', ha='center', va='center', fontweight='bold', fontsize=16)\n",
    "\n",
    "    # Plot S matrices and calculate SNR\n",
    "    snrs_G = []\n",
    "    for i, (matrix, matrix_noise) in enumerate(zip(S, S_noise)):\n",
    "        imgs_G = plot_single_matrix(matrix, axs[2, i], i, \"$\\\\mathcal{{S}}$\")\n",
    "        _ = plot_single_matrix(matrix_noise, axs[3, i], i, \"$\\\\hat{{\\\\mathcal{{S}}}}$\")\n",
    "        snr_value = calculate_snr(np.diag(matrix), np.triu(matrix_noise,1))\n",
    "        snrs_G.append(snr_value)\n",
    "        axs[3, i].annotate(f'SNR: {snr_value:.2f} dB', xy=(0.6, -0.25), xycoords='axes fraction', ha='center', va='center', fontweight='bold', fontsize=16)\n",
    "\n",
    "    # Create a color bar for the H_bar matrices\n",
    "    cax_A = fig.add_axes([0.78, 0.55, 0.02, 0.30])  \n",
    "    fig.colorbar(imgs_A, cax=cax_A)\n",
    "    # Create a color bar for the S matrices\n",
    "    cax_G = fig.add_axes([0.78, 0.14, 0.02, 0.30]) \n",
    "    fig.colorbar(imgs_G, cax=cax_G)\n",
    "\n",
    "    if verbose:\n",
    "        plt.suptitle(f'Comparison of FFT: Original vs. Noisy $\\\\mathcal{{H}}$ and $\\\\mathcal{{S}}$ Matrices of {layer_name} layer at position {layer_index+1}', fontsize=16, fontweight='bold')\n",
    "    plt.subplots_adjust(wspace=-0.4, hspace=0.5)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Settings from Paper ----\n",
    "N = 2  # Number of matrices for H_bar and for S\n",
    "layer_index = 40 # 36 or 40\n",
    "max_steps = 9800\n",
    "choice_list = [5200, 9800]\n",
    "# ---- Settings from Paper ----\n",
    "A, G, choices_list, layer_name = get_Kronecker_Factors(N, layer_index, max_steps, choice_list)\n",
    "A_noise, G_noise, choices_list, layer_name = get_Kronecker_Factors(N, layer_index, max_steps, choice_list, noise=True, std=10**-3)\n",
    "plot_fft(A,A_noise, G, G_noise, choices_list, N, layer_name, layer_index)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
