{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "sys.path.append(\"../Synaptic-Flow/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib as mpl\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from matplotlib.colors import ListedColormap\n",
    "plt.rcParams[\"figure.figsize\"] = (7.5, 5)\n",
    "from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes\n",
    "from mpl_toolkits.axes_grid1.inset_locator import mark_inset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display names for plot labels, keyed by the identifiers used in result paths.\n",
    "model_dict = {\n",
    "    'resnet20': \"ResNet-20\",\n",
    "    'wide-resnet20': \"WideResNet-20\",\n",
    "    'fc-1000': \"FC-1000\",\n",
    "}\n",
    "pruner_dict = {\n",
    "    'synflow': \"SynFlow\",\n",
    "    'grasp': \"GraSP\",\n",
    "    'snip': \"SNIP\",\n",
    "    'mag': \"Magnitude\",\n",
    "    'rand': \"Random\",\n",
    "    'synflow-dist': \"SynFlow-Dist\",\n",
    "    'synflow-l2': \"SynFlow-L2\",\n",
    "    'synflow-dist-l2': \"SynFlow-L2-Dist\",\n",
    "}\n",
    "dataset_dict = {\n",
    "    'cifar10': \"CIFAR-10\",\n",
    "    'cifar100': \"CIFAR-100\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def get_dataframe(dim='path_kernel',\n",
    "                  model_class_list=['default', 'lottery'],\n",
    "                  model_list=['fc', 'fc-1000', 'resnet20', 'wide-resnet20'],\n",
    "                  pruner_list = ['synflow'],\n",
    "                  dataset_list = ['cifar10', 'cifar100'],\n",
    "                  seed_list = ['83', '1337'], # later on need 23, 923\n",
    "                  comp_list = [\"0.0\", \"1.0\"],\n",
    "                  root_dir=\"../Results/pruned/\",\n",
    "                  n_epochs=100,\n",
    "                  min_loss_change=0.5):\n",
    "    \"\"\"Collect one per-epoch metric from the pruning-result CSVs into a DataFrame.\n",
    "\n",
    "    Looks for ``{root_dir}{model_class}/{model}/{pruner}/{dataset}_{seed}_{comp}.csv``\n",
    "    for every combination of the list arguments. From each run that looks trained,\n",
    "    it keeps the first ``n_epochs`` values of column ``dim``.\n",
    "\n",
    "    A run is skipped when any of these holds:\n",
    "    - its CSV is missing;\n",
    "    - it lacks the ``dim`` or ``train_loss`` column;\n",
    "    - it has too few rows;\n",
    "    - its ``train_loss`` at row 1 and row ``n_epochs`` differ by less than\n",
    "      ``min_loss_change`` (treated as \"didn't train\").\n",
    "\n",
    "    Args:\n",
    "        dim: CSV column to extract, e.g. 'path_kernel' or 'weight_movement_norm'.\n",
    "        model_class_list, model_list, pruner_list, dataset_list, seed_list,\n",
    "            comp_list: path components to search (only iterated, never mutated).\n",
    "        root_dir: results root. It is joined by plain concatenation, so it must\n",
    "            end with '/'.\n",
    "        n_epochs: number of leading rows kept per run. Row ``n_epochs`` is also\n",
    "            the end point of the training check.\n",
    "        min_loss_change: minimum |train_loss[1] - train_loss[n_epochs]| for a\n",
    "            run to count as trained.\n",
    "\n",
    "    Returns:\n",
    "        pd.DataFrame with one column per kept run, named\n",
    "        \"{model}--{dataset}--{pruner}--{comp}--{seed}\" (empty if nothing was\n",
    "        found). model_class is not part of the name, so the same run found under\n",
    "        two model classes collides and the later class wins.\n",
    "    \"\"\"\n",
    "    return_df = {}\n",
    "    for model_class in model_class_list:\n",
    "        for model in model_list:\n",
    "            for pruner in pruner_list:\n",
    "                curr_dir = root_dir + f\"{model_class}/{model}/{pruner}/\"\n",
    "                for dataset in dataset_list:\n",
    "                    for seed in seed_list:\n",
    "                        for comp in comp_list:\n",
    "                            file = curr_dir + f\"{dataset}_{seed}_{comp}.csv\"\n",
    "                            if not os.path.exists(file):\n",
    "                                continue\n",
    "                            df = pd.read_csv(file)\n",
    "                            if dim not in df.columns or 'train_loss' not in df.columns:\n",
    "                                continue\n",
    "                            # Rows 1 and n_epochs must both exist for the training check below.\n",
    "                            if len(df) <= max(n_epochs, 1):\n",
    "                                continue\n",
    "                            train_loss = df['train_loss']\n",
    "                            if np.abs(train_loss.iloc[1] - train_loss.iloc[n_epochs]) < min_loss_change:\n",
    "                                continue  # loss barely moved: the run didn't train\n",
    "                            key = f\"{model}--{dataset}--{pruner}--{comp}--{seed}\"\n",
    "                            return_df[key] = df[dim].iloc[:n_epochs].values\n",
    "    return pd.DataFrame(return_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_dataframes(dims = ['path_kernel', 'weight_movement_norm'],\n",
    "                    model_class_list_list = [[\"lottery\", \"default\"]],\n",
    "                    model_list_list = [['resnet20', 'wide-resnet20']],\n",
    "                    pruner_list_list = [['synflow']],\n",
    "                    dataset_list_list = [['cifar10'], ['cifar100']],\n",
    "                    comp_list_list = [[\"0.0\"], [\"1.0\"]],\n",
    "                    seed_list=['1337', '82' ,'821', '23', '923', '83'],\n",
    "                    root_dir=\"../Results/pruned/\"):\n",
    "    \"\"\"Plot per-epoch trajectories of each metric in `dims`.\n",
    "\n",
    "    One figure is drawn for every combination of the model-class, model,\n",
    "    pruner, dataset and compression lists. Each run returned by `get_dataframe`\n",
    "    becomes one line.\n",
    "\n",
    "    Row 0 of every run is dropped and the next row serves as the baseline:\n",
    "    - 'path_kernel' is plotted as (x_t - x_base) / x_base;\n",
    "    - 'weight_movement_norm' is plotted as x_t / x_base;\n",
    "    - any other metric is plotted unchanged.\n",
    "\n",
    "    Lines are colored (coolwarm) by the run's path-kernel value at that\n",
    "    baseline row, or gray if the run has no path-kernel data. The colorbar\n",
    "    shows the scale. Runs that share a legend label share a marker; the label\n",
    "    is the model, plus the pruner unless comp_list[0] == '0.0'.\n",
    "\n",
    "    Each figure is saved as a PNG under `root_dir`, shown, and closed.\n",
    "    \"\"\"\n",
    "    marker_list = [r'$1$', r'$2$', r'$3$', r'$4$', r'$5$', r'$6$', r'$7$', r'$8$', r'$9$',\n",
    "                   r'$a$', r'$b$', r'$c$', r'$d$', r'$e$', r'$f$', r'$g$', r'$h$']\n",
    "    cmap = plt.cm.coolwarm\n",
    "    for dim in dims:\n",
    "        for model_class_list in model_class_list_list:\n",
    "            for model_list in model_list_list:\n",
    "                for pruner_list in pruner_list_list:\n",
    "                    for dataset_list in dataset_list_list:\n",
    "                        for comp_list in comp_list_list:\n",
    "                            query = dict(model_class_list=model_class_list,\n",
    "                                         model_list=model_list,\n",
    "                                         pruner_list=pruner_list,\n",
    "                                         dataset_list=dataset_list,\n",
    "                                         comp_list=comp_list,\n",
    "                                         seed_list=seed_list,\n",
    "                                         root_dir=root_dir)\n",
    "                            df = get_dataframe(dim=dim, **query)\n",
    "                            # The path kernel drives the color gradient. Reuse df when it is also\n",
    "                            # the plotted metric instead of reading the CSVs a second time.\n",
    "                            if dim == 'path_kernel':\n",
    "                                path_kernel_df = df\n",
    "                            else:\n",
    "                                path_kernel_df = get_dataframe(dim='path_kernel', **query)\n",
    "                            # Either frame may be empty (no runs found); drop(index[0]) would then fail.\n",
    "                            if len(df) == 0 or len(path_kernel_df) == 0:\n",
    "                                continue\n",
    "\n",
    "                            path_kernel_df = path_kernel_df.drop(path_kernel_df.index[0])\n",
    "                            first_epoch_path_kernel_values = path_kernel_df.iloc[0]\n",
    "                            norm_first_epoch = plt.Normalize(vmin=first_epoch_path_kernel_values.min(),\n",
    "                                                             vmax=first_epoch_path_kernel_values.max())\n",
    "\n",
    "                            df = df.drop(df.index[0])\n",
    "                            first_epoch_values = df.iloc[0]\n",
    "                            if dim == \"path_kernel\":\n",
    "                                df = df.sub(first_epoch_values).div(first_epoch_values)\n",
    "                            elif dim == \"weight_movement_norm\":\n",
    "                                df = df.div(first_epoch_values)\n",
    "\n",
    "                            fig, ax = plt.subplots()\n",
    "                            marker_plot_dict = {}\n",
    "                            for name, data in df.items():\n",
    "                                # name looks like: resnet20--cifar100--synflow--0.0--1337\n",
    "                                model_name, _, pruner_name, _, _ = name.split('--')\n",
    "                                marker_plot_label = model_dict.get(model_name, model_name)\n",
    "                                if comp_list[0] != '0.0':\n",
    "                                    marker_plot_label += f\", {pruner_dict.get(pruner_name, pruner_name)}\"\n",
    "                                # Look the color up by run name: df and path_kernel_df need not hold the same runs.\n",
    "                                if name in first_epoch_path_kernel_values.index:\n",
    "                                    rgba = cmap(norm_first_epoch(first_epoch_path_kernel_values[name]))\n",
    "                                    color = mpl.colors.rgb2hex(rgba[:3])\n",
    "                                else:\n",
    "                                    color = 'gray'\n",
    "                                plot_kwargs = dict(color=color, markevery=15, ms=10)\n",
    "                                if marker_plot_label in marker_plot_dict:\n",
    "                                    ax.plot(data, marker=marker_plot_dict[marker_plot_label], **plot_kwargs)\n",
    "                                else:\n",
    "                                    # Cycle through the markers rather than running off the end of the list.\n",
    "                                    marker_plot_dict[marker_plot_label] = marker_list[len(marker_plot_dict) % len(marker_list)]\n",
    "                                    ax.plot(data, marker=marker_plot_dict[marker_plot_label], label=marker_plot_label, **plot_kwargs)\n",
    "                            ax.legend()\n",
    "\n",
    "                            scalar_mappable = plt.cm.ScalarMappable(norm=norm_first_epoch, cmap=cmap)\n",
    "                            scalar_mappable.set_array([])  # matplotlib before 3.1 requires an array before colorbar()\n",
    "                            cbax = fig.add_axes([0.91, 0.15, 0.01, 0.7])\n",
    "                            fig.colorbar(scalar_mappable, cax=cbax)\n",
    "\n",
    "                            pruning_str = \"\"\n",
    "                            if comp_list[0] == '1.0':\n",
    "                                pruning_str = r\"- Pruning 10\\%\"\n",
    "                            if dim == 'path_kernel':\n",
    "                                ax.set_title(label=f\"Relative Change in Path Kernel - {dataset_dict[dataset_list[0]]} {pruning_str}\")\n",
    "                                ax.set_xlim(0, 100)\n",
    "                                ax.set_xlabel(\"Epochs\")\n",
    "                                ax.set_ylabel(r\"$\\text{Tr}(\\bm{\\Pi}_{\\bm{\\theta}_t}) - \\text{Tr}(\\bm{\\Pi}_{\\bm{\\theta}_0})$\")\n",
    "                            if dim == 'weight_movement_norm':\n",
    "                                ax.set_title(label=f\"Relative Change in Parameters - {dataset_dict[dataset_list[0]]} {pruning_str}\")\n",
    "                                ax.set_xlim(0, 100)\n",
    "                                ax.set_xlabel(\"Epochs\")\n",
    "                                ax.set_ylabel(r\"$\\bm{\\omega}_t$\")\n",
    "\n",
    "                            filename = root_dir+f\"{dim}--compression-{','.join(comp_list)}-pruner-{','.join(pruner_list)}--model-{','.join(model_list)}--dataset-{','.join(dataset_list)}.png\"\n",
    "                            fig.savefig(filename, dpi = 250)\n",
    "                            plt.show()\n",
    "                            plt.close(fig)  # figures are created in a loop; release each one\n",
    "                            print(\"---\"*25)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot: Path kernel per dataset (CIFAR-10/100), per compression (0.0/1.0) and per model (FC-1000, ResNet-20, WideResNet-20), comparing all pruners except Magnitude"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "plt.rcParams[\"text.usetex\"] = True\n",
    "# matplotlib expects a single string here; list values were deprecated in 3.3 and are no longer supported.\n",
    "plt.rcParams['text.latex.preamble'] = r\"\\usepackage{bm}\\usepackage{amsmath}\"\n",
    "\n",
    "plot_dataframes(dims=[\"path_kernel\"],\n",
    "                comp_list_list=[[\"0.0\"], [\"1.0\"]],\n",
    "                dataset_list_list=[['cifar100'], [\"cifar10\"]],\n",
    "                model_list_list=[['fc-1000'], ['resnet20'], ['wide-resnet20']],\n",
    "                pruner_list_list=[['synflow', 'grasp', 'rand', 'snip', 'synflow-dist', 'synflow-l2', 'synflow-dist-l2']])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot: Path kernel per dataset (CIFAR-10/100), per compression (0.0/1.0) and per pruner (SynFlow, GraSP, Random, SNIP, SynFlow-L2, SynFlow-Dist, SynFlow-L2-Dist), comparing all models (FC-1000, ResNet-20, WideResNet-20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "plt.rcParams[\"text.usetex\"] = True\n",
    "# matplotlib expects a single string here; list values were deprecated in 3.3 and are no longer supported.\n",
    "plt.rcParams['text.latex.preamble'] = r\"\\usepackage{bm}\\usepackage{amsmath}\"\n",
    "\n",
    "plot_dataframes(dims=[\"path_kernel\"],\n",
    "                comp_list_list=[[\"0.0\"], [\"1.0\"]],\n",
    "                dataset_list_list=[['cifar100'], [\"cifar10\"]],\n",
    "                model_list_list=[['fc-1000', 'resnet20', 'wide-resnet20']],\n",
    "                pruner_list_list=[['synflow'], ['grasp'], ['rand'], ['snip'], ['synflow-dist'], ['synflow-l2'], ['synflow-dist-l2']])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot: Weight movement per dataset (CIFAR-10/100), per compression (0.0/1.0) and per model (FC-1000, ResNet-20, WideResNet-20), comparing all pruners except Magnitude"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "plt.rcParams[\"text.usetex\"] = True\n",
    "# matplotlib expects a single string here; list values were deprecated in 3.3 and are no longer supported.\n",
    "plt.rcParams['text.latex.preamble'] = r\"\\usepackage{bm}\"\n",
    "\n",
    "plot_dataframes(dims=[\"weight_movement_norm\"],\n",
    "                comp_list_list=[[\"0.0\"], [\"1.0\"]],\n",
    "                dataset_list_list=[['cifar100'], [\"cifar10\"]],\n",
    "                model_list_list=[['fc-1000'], ['resnet20'], ['wide-resnet20']],\n",
    "                pruner_list_list=[['synflow', 'grasp', 'rand', 'snip', 'synflow-dist', 'synflow-l2', 'synflow-dist-l2']])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot: Weight movement per dataset (CIFAR-10/100), per compression (0.0/1.0) and per pruner (SynFlow, GraSP, Random, SNIP, SynFlow-L2, SynFlow-Dist, SynFlow-L2-Dist), comparing all models (FC-1000, ResNet-20, WideResNet-20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams[\"text.usetex\"] = True\n",
    "# matplotlib expects a single string here; list values were deprecated in 3.3 and are no longer supported.\n",
    "plt.rcParams['text.latex.preamble'] = r\"\\usepackage{bm}\"\n",
    "\n",
    "plot_dataframes(dims=[\"weight_movement_norm\"],\n",
    "                comp_list_list=[[\"0.0\"], [\"1.0\"]],\n",
    "                dataset_list_list=[['cifar100'], [\"cifar10\"]],\n",
    "                model_list_list=[['fc-1000', 'resnet20', 'wide-resnet20']],\n",
    "                pruner_list_list=[['synflow'], ['grasp'], ['rand'], ['snip'], ['synflow-dist'], ['synflow-l2'], ['synflow-dist-l2']])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Network outputs across models and pruners"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sorted_values(directory):\n",
    "    \"\"\"Summarise every saved output array in `directory`, keyed by iteration.\n",
    "\n",
    "    Each file whose name contains \"_output\" is loaded with `np.load`. The\n",
    "    integer before its first underscore is used as the iteration number.\n",
    "\n",
    "    Returns:\n",
    "        A tuple (means, sums, norms) of dicts mapping iteration -> mean, sum and\n",
    "        `np.linalg.norm` of that array, each ordered by ascending iteration.\n",
    "        All three are empty when `directory` does not exist.\n",
    "    \"\"\"\n",
    "    if not os.path.exists(directory):\n",
    "        return {}, {}, {}\n",
    "    means, sums, norms = {}, {}, {}\n",
    "    for entry in os.listdir(directory):\n",
    "        if \"_output\" not in entry:\n",
    "            continue\n",
    "        step = int(entry.split(\"_\")[0])\n",
    "        arr = np.load(f\"{directory}/{entry}\")\n",
    "        means[step] = arr.mean()\n",
    "        sums[step] = arr.sum()\n",
    "        norms[step] = np.linalg.norm(arr)\n",
    "    return tuple(dict(sorted(stats.items())) for stats in (means, sums, norms))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_logit_dataframe(model_class_list=['lottery'],\n",
    "                        model_list=['resnet20', 'wide-resnet20'],\n",
    "                        pruner_list = ['synflow'],\n",
    "                        dataset_list = ['cifar10', 'cifar100'],\n",
    "                        seed_list = ['83', '1337'],\n",
    "                        comp_list = [\"0.0\", \"1.0\"],\n",
    "                        root_dir=\"../Results/pruned/\", op=\"Mean\"):\n",
    "    \"\"\"Plot per-iteration network-output statistics; returns None.\n",
    "\n",
    "    One figure is drawn per (model_class, model, dataset, comp) combination.\n",
    "    For every pruner and seed, `get_sorted_values` summarises the arrays in\n",
    "    ``{root_dir}{model_class}/{model}/{pruner}/output_{dataset}_{seed}_{comp}``\n",
    "    and the first 100 values of the chosen statistic are plotted. `op` \"Norm\"\n",
    "    selects the L2 norm; any other value selects the mean.\n",
    "\n",
    "    Skipped inputs:\n",
    "    - runs with fewer than 100 iterations;\n",
    "    - combinations with no usable training CSV (per `get_dataframe`), which\n",
    "      print \"Continuing...\".\n",
    "\n",
    "    Each line is colored (coolwarm) by the same run's path-kernel value at the\n",
    "    first retained epoch of its CSV, or gray if that run has no usable CSV.\n",
    "    Runs that share a legend label share a marker. Each figure is saved as a\n",
    "    PNG under `root_dir`, shown, and closed.\n",
    "    \"\"\"\n",
    "    marker_list = [r'$1$', r'$2$', r'$3$', r'$4$', r'$5$', r'$6$', r'$7$', r'$8$', r'$9$',\n",
    "                   r'$a$', r'$b$', r'$c$', r'$d$', r'$e$', r'$f$', r'$g$', r'$h$']\n",
    "    cmap = plt.cm.coolwarm\n",
    "    for model_class in model_class_list:\n",
    "        for model in model_list:\n",
    "            for dataset in dataset_list:\n",
    "                for comp in comp_list:\n",
    "                    path_kernel_df = get_dataframe(dim='path_kernel',\n",
    "                                                   model_class_list=[model_class],\n",
    "                                                   model_list=[model],\n",
    "                                                   pruner_list=pruner_list,\n",
    "                                                   dataset_list=[dataset],\n",
    "                                                   comp_list=[comp],\n",
    "                                                   seed_list=seed_list,\n",
    "                                                   root_dir=root_dir)\n",
    "                    if len(path_kernel_df) == 0:\n",
    "                        print(\"Continuing...\")\n",
    "                        continue\n",
    "                    path_kernel_df = path_kernel_df.drop(path_kernel_df.index[0])\n",
    "                    first_epoch_path_kernel_values = path_kernel_df.iloc[0]\n",
    "                    norm_first_epoch = plt.Normalize(vmin=first_epoch_path_kernel_values.min(),\n",
    "                                                     vmax=first_epoch_path_kernel_values.max())\n",
    "\n",
    "                    # One column per (pruner, seed) run with at least 100 saved outputs.\n",
    "                    curves = {}\n",
    "                    run_colors = {}\n",
    "                    for pruner in pruner_list:\n",
    "                        curr_dir = root_dir + f\"{model_class}/{model}/{pruner}/\"\n",
    "                        for seed in seed_list:\n",
    "                            dir_name = f\"output_{dataset}_{seed}_{comp}\"\n",
    "                            mean_dict, _, norm_dict = get_sorted_values(curr_dir + dir_name)\n",
    "                            stat_dict = norm_dict if op == \"Norm\" else mean_dict\n",
    "                            val = list(stat_dict.values())[:100]\n",
    "                            if len(val) != 100:\n",
    "                                continue\n",
    "                            column = f\"{model}--{pruner}--{dir_name}\"\n",
    "                            curves[column] = val\n",
    "                            # Match the run by get_dataframe's column name, not by position:\n",
    "                            # the CSV runs and the output directories need not line up.\n",
    "                            path_kernel_key = f\"{model}--{dataset}--{pruner}--{comp}--{seed}\"\n",
    "                            if path_kernel_key in first_epoch_path_kernel_values.index:\n",
    "                                rgba = cmap(norm_first_epoch(first_epoch_path_kernel_values[path_kernel_key]))\n",
    "                                run_colors[column] = mpl.colors.rgb2hex(rgba[:3])\n",
    "                            else:\n",
    "                                run_colors[column] = 'gray'\n",
    "                    if not curves:\n",
    "                        continue\n",
    "                    df = pd.DataFrame(curves)\n",
    "\n",
    "                    # Plot inside the comp loop so every compression ratio gets its own figure.\n",
    "                    fig, ax = plt.subplots()\n",
    "                    marker_plot_dict = {}\n",
    "                    for name, data in df.items():\n",
    "                        # name looks like: resnet20--synflow--output_cifar100_1337_0.0\n",
    "                        model_name, pruner_name, _ = name.split('--')\n",
    "                        marker_plot_label = model_dict.get(model_name, model_name)\n",
    "                        if float(comp) != 0.0:\n",
    "                            marker_plot_label += f\", {pruner_dict.get(pruner_name, pruner_name)}\"\n",
    "                        plot_kwargs = dict(color=run_colors[name], markevery=15, ms=10)\n",
    "                        if marker_plot_label in marker_plot_dict:\n",
    "                            ax.plot(data, marker=marker_plot_dict[marker_plot_label], **plot_kwargs)\n",
    "                        else:\n",
    "                            marker_plot_dict[marker_plot_label] = marker_list[len(marker_plot_dict) % len(marker_list)]\n",
    "                            ax.plot(data, marker=marker_plot_dict[marker_plot_label], label=marker_plot_label, **plot_kwargs)\n",
    "                    ax.legend()\n",
    "\n",
    "                    scalar_mappable = plt.cm.ScalarMappable(norm=norm_first_epoch, cmap=cmap)\n",
    "                    scalar_mappable.set_array([])  # matplotlib before 3.1 requires an array before colorbar()\n",
    "                    cbax = fig.add_axes([0.91, 0.2, 0.01, 0.5])\n",
    "                    fig.colorbar(scalar_mappable, cax=cbax)\n",
    "\n",
    "                    pruning_str = \"\"\n",
    "                    if comp == '1.0':\n",
    "                        pruning_str = r\"Pruning 10\\%\"\n",
    "                    # NOTE(review): the title/ylabel describe a change relative to theta_0, but the raw\n",
    "                    # per-iteration statistic is plotted without subtracting a baseline. Confirm intent.\n",
    "                    ax.set_title(label=f\"Relative Change in {op} Outputs - {dataset_dict[dataset]} {pruning_str}\")\n",
    "                    ax.set_xlim(0, 100)\n",
    "                    ax.set_xlabel(\"Epochs\")\n",
    "                    ax.set_ylabel(r\"$f(\\bm{\\mathcal{X}},\\bm{\\theta}_t) - f(\\bm{\\mathcal{X}},\\bm{\\theta}_0)$\")\n",
    "                    # Name the file after the current model/dataset/comp so iterations don't overwrite each other.\n",
    "                    # This is identical to the joined-list name when those lists have one element.\n",
    "                    # NOTE: model_class is not in the name, so different model classes still share a file.\n",
    "                    filename = root_dir + f\"output---{op}---{model}---{dataset}---{'-'.join(pruner_list)}---{comp}.png\"\n",
    "                    fig.savefig(filename, dpi=250)\n",
    "                    plt.show()\n",
    "                    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_logit_dataframes(model_class_list_list = [[\"lottery\"]],\n",
    "                          model_list_list = [['resnet20', 'wide-resnet20']],\n",
    "                          pruner_list_list = [['synflow', 'grasp', 'snip', 'synflow-dist', 'synflow-l2', 'synflow-dist-l2']],\n",
    "                          dataset_list_list = [['cifar10'], ['cifar100']],\n",
    "                          comp_list_list = [[\"0.0\"], [\"1.0\"]],\n",
    "                          seed_list_list = [['1337', '82' ,'821', '23', '923']],\n",
    "                          root_dir=\"../Results/pruned/\", op=\"Mean\"):\n",
    "    \"\"\"Run `get_logit_dataframe` for every combination of the *_list_list arguments.\n",
    "\n",
    "    Each element of a *_list_list is itself a list, passed through unchanged.\n",
    "    Combinations are visited with model classes outermost and seed lists\n",
    "    innermost. `root_dir` and `op` are forwarded as-is.\n",
    "    \"\"\"\n",
    "    combinations = [\n",
    "        (model_classes, models, pruners, datasets, comps, seeds)\n",
    "        for model_classes in model_class_list_list\n",
    "        for models in model_list_list\n",
    "        for pruners in pruner_list_list\n",
    "        for datasets in dataset_list_list\n",
    "        for comps in comp_list_list\n",
    "        for seeds in seed_list_list\n",
    "    ]\n",
    "    for model_classes, models, pruners, datasets, comps, seeds in combinations:\n",
    "        get_logit_dataframe(model_class_list=model_classes,\n",
    "                            model_list=models,\n",
    "                            pruner_list=pruners,\n",
    "                            dataset_list=datasets,\n",
    "                            comp_list=comps,\n",
    "                            seed_list=seeds,\n",
    "                            root_dir=root_dir,\n",
    "                            op=op)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot: Mean output per dataset (CIFAR-10/100), per compression ratio (0.0/1.0) and per model (ResNet-20, WideResNet-20), comparing all pruners except Random and Magnitude"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams[\"text.usetex\"] = True\n",
    "# matplotlib expects a single string here; list values were deprecated in 3.3 and are no longer supported.\n",
    "plt.rcParams['text.latex.preamble'] = r\"\\usepackage{bm}\\usepackage{amsmath}\"\n",
    "\n",
    "plot_logit_dataframes(comp_list_list=[[\"0.0\"], [\"1.0\"]],\n",
    "                      dataset_list_list=[['cifar100'], [\"cifar10\"]],\n",
    "                      model_class_list_list=[[\"lottery\"]],\n",
    "                      model_list_list=[['resnet20'], ['wide-resnet20']], op=\"Mean\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot: Output norm per dataset (CIFAR-10/100), per compression ratio (0.0/1.0) and per model (ResNet-20, WideResNet-20), comparing all pruners except Random and Magnitude"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams[\"text.usetex\"] = True\n",
    "# matplotlib expects a single string here; list values were deprecated in 3.3 and are no longer supported.\n",
    "plt.rcParams['text.latex.preamble'] = r\"\\usepackage{bm}\\usepackage{amsmath}\"\n",
    "\n",
    "plot_logit_dataframes(comp_list_list=[[\"0.0\"], [\"1.0\"]],\n",
    "                      dataset_list_list=[['cifar100'], [\"cifar10\"]],\n",
    "                      model_class_list_list=[[\"lottery\"]],\n",
    "                      model_list_list=[['resnet20'], ['wide-resnet20']], op=\"Norm\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot: Mean output per dataset (CIFAR-10/100), per compression ratio (0.0/1.0) for model FC-1000, comparing all pruners except Random and Magnitude"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "plt.rcParams[\"text.usetex\"] = True\n",
    "# matplotlib expects a single string here; list values were deprecated in 3.3 and are no longer supported.\n",
    "plt.rcParams['text.latex.preamble'] = r\"\\usepackage{bm}\\usepackage{amsmath}\"\n",
    "plot_logit_dataframes(comp_list_list=[[\"0.0\"], [\"1.0\"]],\n",
    "                      dataset_list_list=[['cifar100'], [\"cifar10\"]],\n",
    "                      model_class_list_list=[[\"default\"]],\n",
    "                      model_list_list=[['fc-1000']], op=\"Mean\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot: Output norm per dataset (CIFAR-10/100), per compression ratio (0.0/1.0) for model FC-1000, comparing all pruners except Random and Magnitude"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "plt.rcParams[\"text.usetex\"] = True\n",
    "# matplotlib expects a single string here; list values were deprecated in 3.3 and are no longer supported.\n",
    "plt.rcParams['text.latex.preamble'] = r\"\\usepackage{bm}\\usepackage{amsmath}\"\n",
    "plot_logit_dataframes(comp_list_list=[[\"0.0\"], [\"1.0\"]],\n",
    "                      dataset_list_list=[['cifar100'], [\"cifar10\"]],\n",
    "                      model_class_list_list=[[\"default\"]],\n",
    "                      model_list_list=[['fc-1000']], op=\"Norm\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_logit_dataframe_per_pruner(model_class_list=['lottery'],\n",
    "                        model_list=['resnet20', 'wide-resnet20'],\n",
    "                        pruner_list = ['synflow'],\n",
    "                        dataset_list = ['cifar10', 'cifar100'],\n",
    "                        seed_list = ['83', '1337'],\n",
    "                        comp_list = [\"0.0\", \"1.0\"],\n",
    "                        root_dir=\"../Results/pruned/\", op=\"Mean\"):\n",
    "    \"\"\"Plot and save per-epoch output curves, one figure per (pruner, dataset, comp).\n",
    "\n",
    "    For each combination, every model_class/model/seed series is read with\n",
    "    get_sorted_values from\n",
    "    ``{root_dir}{model_class}/{model}/{pruner}/output_{dataset}_{seed}_{comp}``\n",
    "    and drawn on one axis. Only series with at least 100 values are used\n",
    "    (truncated to 100). Line colours come from the path-kernel frame returned by\n",
    "    get_dataframe(dim='path_kernel', ...) for the same filters.\n",
    "\n",
    "    Args:\n",
    "        model_class_list, model_list, pruner_list, dataset_list, seed_list,\n",
    "        comp_list: values combined to build the result directory paths.\n",
    "        root_dir: root of the results tree; figures are also saved there.\n",
    "        op: \"Norm\" plots the norm values returned by get_sorted_values;\n",
    "            any other value plots the mean values.\n",
    "\n",
    "    Returns:\n",
    "        None. Each figure is shown inline and written to root_dir as a PNG.\n",
    "    \"\"\"\n",
    "    n_epochs = 100  # series are truncated to this length; shorter ones are skipped\n",
    "    marker_list = [f'${c}$' for c in '123456789abcdefgh']\n",
    "\n",
    "    for pruner in pruner_list:\n",
    "        for dataset in dataset_list:\n",
    "            for comp in comp_list:\n",
    "                path_kernel_df = get_dataframe(dim='path_kernel',\n",
    "                                               model_class_list=model_class_list,\n",
    "                                               model_list=model_list,\n",
    "                                               pruner_list=[pruner],\n",
    "                                               dataset_list=[dataset],\n",
    "                                               comp_list=[comp],\n",
    "                                               seed_list=seed_list)\n",
    "                # The first row is dropped and the next one drives the colour map,\n",
    "                # so at least two rows are needed (fewer used to raise on .iloc[0]).\n",
    "                if len(path_kernel_df) < 2:\n",
    "                    print(\"Continuing...\")\n",
    "                    continue\n",
    "                path_kernel_df = path_kernel_df.drop(path_kernel_df.index[0])\n",
    "                first_epoch_path_kernel_values = path_kernel_df.iloc[0]\n",
    "                norm_first_epoch = plt.Normalize(vmin=first_epoch_path_kernel_values.min(),\n",
    "                                                 vmax=first_epoch_path_kernel_values.max())\n",
    "                colors = plt.cm.coolwarm(norm_first_epoch(first_epoch_path_kernel_values.values))\n",
    "\n",
    "                # Collect every complete series, then build the frame once.\n",
    "                series_by_name = {}\n",
    "                for model_class in model_class_list:\n",
    "                    for model in model_list:\n",
    "                        curr_dir = root_dir + f\"{model_class}/{model}/{pruner}/\"\n",
    "                        for seed in seed_list:\n",
    "                            dir_name = f\"output_{dataset}_{seed}_{comp}\"\n",
    "                            mean_dict, _sum_dict, norm_dict = get_sorted_values(curr_dir + dir_name)\n",
    "                            if len(list(mean_dict.values())) == 0:\n",
    "                                continue\n",
    "                            source = norm_dict if op == \"Norm\" else mean_dict\n",
    "                            val = list(source.values())[:n_epochs]\n",
    "                            if len(val) == n_epochs:\n",
    "                                # NOTE(review): model_class is not part of the key, so the same model\n",
    "                                # under two model classes overwrites the earlier series.\n",
    "                                series_by_name[f\"{model}--{pruner}--{dir_name}\"] = val\n",
    "                df = pd.DataFrame(series_by_name)\n",
    "                if len(df) == 0:\n",
    "                    continue\n",
    "\n",
    "                # NOTE(review): colours are assigned by column position, assuming the columns\n",
    "                # of path_kernel_df match the series above one-to-one; confirm in get_dataframe.\n",
    "                color_dict = {name: mpl.colors.rgb2hex(colors[idx][:3])\n",
    "                              for idx, name in enumerate(df.columns)}\n",
    "\n",
    "                # comp 0 curves are labelled by model only; other comps also name the pruner.\n",
    "                # Parsing the comp as a float avoids treating e.g. \"0.5\" as comp 0.\n",
    "                model_only_label = float(comp) == 0.0\n",
    "                fig, ax = plt.subplots()\n",
    "                marker_plot_dict = {}\n",
    "                for name, data in df.items():\n",
    "                    # name looks like: resnet20--synflow--output_cifar100_1337_0.0\n",
    "                    model_name, pruner_name, _output_dir = name.split('--')\n",
    "                    if model_only_label:\n",
    "                        marker_plot_label = f\"{model_dict[model_name]}\"\n",
    "                    else:\n",
    "                        marker_plot_label = f\"{model_dict[model_name]}, {pruner_dict[pruner_name]}\"\n",
    "                    plot_kwargs = dict(color=color_dict[name], markevery=15, ms=10)\n",
    "                    if marker_plot_label in marker_plot_dict:\n",
    "                        # Label already used (e.g. another seed): reuse its own marker, no new legend entry.\n",
    "                        ax.plot(data, marker=marker_plot_dict[marker_plot_label], **plot_kwargs)\n",
    "                    else:\n",
    "                        marker_plot_dict[marker_plot_label] = marker_list[len(marker_plot_dict)]\n",
    "                        ax.plot(data, marker=marker_plot_dict[marker_plot_label],\n",
    "                                label=marker_plot_label, **plot_kwargs)\n",
    "                ax.legend()\n",
    "\n",
    "                cbax = fig.add_axes([0.91, 0.2, 0.01, 0.5])\n",
    "                fig.colorbar(plt.cm.ScalarMappable(norm=norm_first_epoch, cmap=plt.cm.coolwarm), cax=cbax)\n",
    "                pruning_str = r\"Pruning 10\\%\" if comp == '1.0' else \"\"\n",
    "                ax.set_title(label=f\"Relative Change in {op} Outputs - {dataset_dict[dataset]} {pruning_str}\")\n",
    "                ax.set_xlim(0, n_epochs)\n",
    "                ax.set_xlabel(\"Epochs\")\n",
    "                ax.set_ylabel(r\"$f(\\bm{\\mathcal{X}},\\bm{\\theta}_t) - f(\\bm{\\mathcal{X}},\\bm{\\theta}_0)$\")\n",
    "                # Name the file after the current dataset/pruner/comp so figures produced by one\n",
    "                # call no longer overwrite each other.\n",
    "                fig.savefig(root_dir + f\"output---{op}---{'-'.join(model_list)}---{dataset}---{pruner}---{comp}.png\", dpi=250)\n",
    "                plt.show()\n",
    "                plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_logit_dataframes_per_pruner(model_class_list_list = [[\"lottery\"]],\n",
    "                          model_list_list = [['resnet20', 'wide-resnet20']],\n",
    "                          pruner_list_list = [['synflow', 'grasp', 'snip', 'synflow-dist', 'synflow-l2', 'synflow-dist-l2']],\n",
    "                          dataset_list_list = [['cifar10'], ['cifar100']],\n",
    "                          comp_list_list = [[\"0.0\"], [\"1.0\"]],\n",
    "                          seed_list_list = [['1337', '82' ,'821', '23', '923']],\n",
    "                          root_dir=\"../Results/pruned/\", op=\"Mean\"):\n",
    "    \"\"\"Run get_logit_dataframe_per_pruner for every combination of the candidate lists.\n",
    "\n",
    "    Each ``*_list_list`` argument holds alternative lists; their Cartesian product is\n",
    "    visited with model classes outermost and seeds innermost, and every combination\n",
    "    is forwarded together with ``root_dir`` and ``op``. Returns None.\n",
    "    \"\"\"\n",
    "    # NOTE(review): default seeds '82' and '821' differ from the '83' used as a default\n",
    "    # elsewhere in this notebook; confirm they are intended.\n",
    "    combinations = [\n",
    "        (classes, models, pruners, datasets, comps, seeds)\n",
    "        for classes in model_class_list_list\n",
    "        for models in model_list_list\n",
    "        for pruners in pruner_list_list\n",
    "        for datasets in dataset_list_list\n",
    "        for comps in comp_list_list\n",
    "        for seeds in seed_list_list\n",
    "    ]\n",
    "    for classes, models, pruners, datasets, comps, seeds in combinations:\n",
    "        get_logit_dataframe_per_pruner(model_class_list=classes,\n",
    "                                       model_list=models,\n",
    "                                       pruner_list=pruners,\n",
    "                                       dataset_list=datasets,\n",
    "                                       comp_list=comps,\n",
    "                                       seed_list=seeds,\n",
    "                                       root_dir=root_dir,\n",
    "                                       op=op)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot Mean Output per dataset {CIFAR10/100}, per compression ratio {0.0/1.0}, per pruner {SynFlow, GraSP, SNIP, SynFlow-L2, SynFlow-Dist, SynFlow-L2-Dist} for all models {FC-1000/ResNet-20/Wide-ResNet-20}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams[\"text.usetex\"] = True\n",
    "# matplotlib >= 3.3 expects the LaTeX preamble as one string (list values are deprecated).\n",
    "plt.rcParams['text.latex.preamble'] = \"\\n\".join([r\"\\usepackage{bm}\", r'\\usepackage{amsmath}'])\n",
    "\n",
    "plot_logit_dataframes_per_pruner(comp_list_list = [[\"0.0\"], [\"1.0\"]], \n",
    "                      dataset_list_list = [['cifar100'], [\"cifar10\"]], \n",
    "                      model_class_list_list = [[\"default\", \"lottery\"]], \n",
    "                      pruner_list_list = [['synflow'], ['grasp'], [ 'snip'], ['synflow-dist'], ['synflow-l2'], ['synflow-dist-l2']],\n",
    "                      model_list_list = [[\"fc-1000\", 'resnet20', 'wide-resnet20']], op=\"Mean\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot Norm Output per dataset {CIFAR10/100}, per compression ratio {0.0/1.0}, per pruner {SynFlow, GraSP, SNIP, SynFlow-L2, SynFlow-Dist, SynFlow-L2-Dist} for all models {FC-1000/ResNet-20/Wide-ResNet-20}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "plt.rcParams[\"text.usetex\"] = True\n",
    "# matplotlib >= 3.3 expects the LaTeX preamble as one string (list values are deprecated).\n",
    "plt.rcParams['text.latex.preamble'] = \"\\n\".join([r\"\\usepackage{bm}\", r'\\usepackage{amsmath}'])\n",
    "\n",
    "plot_logit_dataframes_per_pruner(comp_list_list = [[\"0.0\"], [\"1.0\"]], \n",
    "                      dataset_list_list = [['cifar100'], [\"cifar10\"]], \n",
    "                      model_class_list_list = [[\"default\", \"lottery\"]], \n",
    "                      pruner_list_list = [['synflow'], ['grasp'], [ 'snip'], ['synflow-dist'], ['synflow-l2'], ['synflow-dist-l2']],\n",
    "                      model_list_list = [[\"fc-1000\", 'resnet20', 'wide-resnet20']], op=\"Norm\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
