{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.ticker as mtick\n",
    "from scipy.spatial import distance_matrix\n",
    "from neurips_fc_nets import *\n",
    "from neurips_utils import *\n",
    "\n",
    "# Figure styling: apply the seaborn theme first so the explicit rcParams below take precedence.\n",
    "sns.set(font_scale=1.1)\n",
    "plt.rcParams.update({\n",
    "    'figure.dpi': 100,\n",
    "    'savefig.dpi': 300,\n",
    "    'text.usetex': True,  # render all text with LaTeX (needs a LaTeX install)\n",
    "    'text.latex.preamble': r'\\usepackage{dsfont}',\n",
    "})\n",
    "\n",
    "from IPython.display import set_matplotlib_formats\n",
    "set_matplotlib_formats('svg', 'pdf')\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')  # NOTE(review): hides every warning, including ones that flag plotting mistakes\n",
    "# colors = ['blue', 'orange', 'olive', 'brown', 'cornflowerblue', 'red', 'k', 'purple']\n",
    "colors = ['lightsteelblue', 'cornflowerblue', 'blue', 'midnightblue']\n",
    "label_size = 16  # font size for axis labels\n",
    "\n",
    "\n",
    "# Build one experiment config. All experiments share the problem sizes (n, d, widths),\n",
    "# the teacher init, gamma and warmup; they differ only in the student init scale and seed.\n",
    "# A fresh dict (with fresh lists) is returned on every call.\n",
    "def _student_teacher_hp(student_init_scale, seed):\n",
    "    return {\n",
    "        'n': 10,\n",
    "        'd': 5,\n",
    "        'm_teacher': 2,\n",
    "        'm_student': 20,\n",
    "        'init_scales_teacher': [1.0, 1.0],\n",
    "        'init_scales_student': [student_init_scale, 0],\n",
    "        'seed': seed,\n",
    "        'gamma': 0.001,\n",
    "        'warmup_mult_factor': 1.0,\n",
    "    }\n",
    "\n",
    "\n",
    "hps = {\n",
    "    'exp_1': _student_teacher_hp(0.0001, seed=0),  # good stabilization and decrease in the clustering coefficient\n",
    "    'exp_2': _student_teacher_hp(0.01, seed=1),\n",
    "    'exp_3': _student_teacher_hp(0.01, seed=1),  # NOTE(review): identical to exp_2 -- confirm the intended values\n",
    "    'exp_4': _student_teacher_hp(0.000001, seed=1),\n",
    "    # 'lenaic_3': _student_teacher_hp(0.000001, seed=0),\n",
    "}\n",
    "del _student_teacher_hp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train one student network per experiment and record, for every network in `nets`,\n",
    "# the singular values of its first-layer weight matrix.\n",
    "list_of_experiments = ['exp_2', 'exp_4']\n",
    "results = []  # one array per (experiment, step size) run, in loop order; with a single step size, results[i] <-> list_of_experiments[i]\n",
    "nets_all = []  # `nets` of every run, aligned with `results`\n",
    "\n",
    "# Alternative step-size schedule sweep (disabled):\n",
    "# gammas = [0.13, 0.455, 0.455, 0.455]\n",
    "# schedule_lengths = [0.0, 0.1, 0.3, 0.5]\n",
    "# decay_factors = [1.0, 2.0, 4.0, 5.0]\n",
    "# NOTE(review): hps[exp]['gamma'] is not used; the step size comes from `gammas`.\n",
    "gammas = [0.01]\n",
    "schedule_lengths = [0.0]\n",
    "decay_factors = [1.0]\n",
    "x_log_scale = False\n",
    "\n",
    "for exp in list_of_experiments:\n",
    "    hp = hps[exp]\n",
    "    X, y, X_test, y_test, net_teacher = get_data_two_layer_relu_net(hp['n'], hp['d'], hp['m_teacher'], hp['init_scales_teacher'], hp['seed'])\n",
    "    num_iter, iters_loss = get_iters_eval(n_iter_power=5, x_log_scale=x_log_scale)  # 4.15\n",
    "    batch_sizes = [hp['n']] * len(schedule_lengths)  # batch size n for every run (presumably full batch)\n",
    "\n",
    "    for gamma, batch_size, schedule_length, decay_factor in zip(gammas, batch_sizes, schedule_lengths, decay_factors):\n",
    "        # Re-seed so each run is reproducible from the experiment's seed alone.\n",
    "        np.random.seed(hp['seed'])\n",
    "        torch.manual_seed(hp['seed'])\n",
    "\n",
    "        net_init = FCNet2Layers(n_feature=hp['d'], n_hidden=hp['m_student'])\n",
    "        net_init.init_gaussian(init_scales=hp['init_scales_student'])\n",
    "        train_losses, test_losses, nets = train_fc_net(X, y, X_test, y_test, gamma, batch_size, net_init, iters_loss, num_iter, thresholds=[int(schedule_length * num_iter)], decays=[decay_factor])\n",
    "        nets_all.append(nets)\n",
    "\n",
    "        # compute_uv=False: only the singular values are needed, not U and V.\n",
    "        sing_array = np.array([np.linalg.svd(net.layer1.weight.data.numpy(), compute_uv=False) for net in nets])\n",
    "        results.append(sing_array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import matplotlib.patches as mpatches\n",
    "import matplotlib.lines as mlines\n",
    "\n",
    "\n",
    "# Number of leading singular values to plot (the first-layer weight has at most min(m_student, d) of them).\n",
    "directions = 5\n",
    "\n",
    "# Palette repeated twice so it can be indexed past its 9 distinct entries.\n",
    "colors = ['#377eb8', '#ff7f00', '#4daf4a', '#f781bf', '#a65628', '#984ea3', '#999999', '#e41a1c', '#dede00'] * 2\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "# Legend 1 (below the axes): line style identifies the run.\n",
    "# NOTE(review): results[0]/results[1] come from exp_2/exp_4, which differ in init_scales_student\n",
    "# (1e-2 vs 1e-6), not in gamma -- confirm these labels.\n",
    "patch1 = mlines.Line2D([], [], linestyle='dashed', color='#984ea3', label=r'$\\gamma = 10^{-4}$')\n",
    "patch2 = mlines.Line2D([], [], linestyle='solid', color='#984ea3', label=r'$\\gamma = 10^{-2}$')\n",
    "first_legend = ax.legend(handles=[patch1, patch2], loc='lower center', ncol=3, bbox_to_anchor=(0., -.35, 1, .1))\n",
    "# A later ax.legend() call replaces the Axes legend, so keep this one as a standalone artist.\n",
    "ax.add_artist(first_legend)\n",
    "\n",
    "# Dashed: results[0], solid: results[1]. Same color = same singular-value index.\n",
    "# Only the first run is labeled so legend 2 lists each index once.\n",
    "for run_idx, style in ((0, '--'), (1, 'solid')):\n",
    "    sing_array = results[run_idx]\n",
    "    for direction in range(directions):\n",
    "        label = rf'$\\sigma_{{{direction}}}$' if run_idx == 0 else None\n",
    "        ax.loglog(iters_loss, sing_array[:, direction], color=colors[direction], label=label, linestyle=style)\n",
    "\n",
    "# Legend 2: one entry per singular value, explicitly placed. Do not call legend() again\n",
    "# afterwards -- it would replace this one and lose the placement.\n",
    "ax.legend(loc='lower left', bbox_to_anchor=(0., .15, 1, .102))\n",
    "\n",
    "ax.set_xlabel('Iterations', fontsize=label_size)\n",
    "ax.set_ylabel('Singular Values', fontsize=label_size)\n",
    "ax.set_title('Evolution of Singular Values of ReLU network')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loss of the last run trained above, with the two saddle intervals marked.\n",
    "# NOTE: `train_losses` is left over from the training loop, i.e. the last experiment\n",
    "# in `list_of_experiments` (and the last step size in `gammas`).\n",
    "linestyle = 'dotted'\n",
    "\n",
    "color_1 = '#4daf4a'  # markers bounding the first-saddle interval\n",
    "color_2 = '#ff7f00'  # markers bounding the second-saddle interval\n",
    "\n",
    "start = 0\n",
    "end = 115  # plot only the first `end` recorded losses\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "# Vertical markers spanning y in [1e-4, 1]. Use 1e-4, not 10^(-4): `^` is bitwise XOR in Python.\n",
    "for x_marker, marker_color in ((0, color_1), (4000, color_1), (5000, color_2), (8500, color_2)):\n",
    "    ax.vlines(x=x_marker, ymin=1e-4, ymax=1, colors=marker_color, linestyles=linestyle, linewidth=2)\n",
    "\n",
    "# NOTE(review): assumes every entry of train_losses is a torch tensor.\n",
    "train_loss_list = np.array([loss.data.numpy() for loss in train_losses])\n",
    "\n",
    "ax.semilogy(iters_loss[start:end], train_loss_list[start:end])\n",
    "\n",
    "ax.annotate('--First Saddle--', (600, .25), color='#a65628', fontsize=12, horizontalalignment='left', verticalalignment='top')\n",
    "ax.annotate('--Second Saddle--', (5250, .05), color='#a65628', fontsize=12, horizontalalignment='left', verticalalignment='bottom')\n",
    "\n",
    "ax.set_xlabel('Iterations', fontsize=label_size)\n",
    "ax.set_ylabel('Training loss', fontsize=label_size)\n",
    "\n",
    "# No artist carries a label here, so no legend is drawn.\n",
    "fig.savefig('relu-saddle.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Phase-1 zoom: the first `end` checkpoints of the singular values stored in results[1].\n",
    "directions = 5\n",
    "\n",
    "start = 0\n",
    "end = 105\n",
    "\n",
    "colors = ['#377eb8', '#ff7f00', '#4daf4a', '#f781bf', '#a65628', '#984ea3', '#999999', '#e41a1c', '#dede00'] * 2\n",
    "\n",
    "sing_array = results[1]\n",
    "window = slice(start, end)\n",
    "\n",
    "ax = plt.gca()\n",
    "for direction, color in zip(range(directions), colors):\n",
    "    ax.loglog(iters_loss[window], sing_array[window, direction], color, label=rf'$\\sigma_{direction:2d}$ ', linestyle='--')\n",
    "\n",
    "ax.set_xlabel('Iterations', fontsize=label_size)\n",
    "ax.set_ylabel('Singular Values', fontsize=label_size)\n",
    "ax.set_title('Evolution of Singular Values of ReLU network - Phase 1')\n",
    "ax.legend()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.14 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  },
  "vscode": {
   "interpreter": {
    "hash": "397704579725e15f5c7cb49fe5f0341eb7531c82d19f2c29d197e8b64ab5776b"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
