{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Preamble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.autograd import Variable\n",
    "from sklearn.decomposition import PCA\n",
    "import torch.nn.functional as F\n",
    "import torch.utils.data as Data\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.animation as animation\n",
    "\n",
    "import numpy as np\n",
    "import imageio\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global matplotlib configuration: serif fonts, LaTeX-rendered text, ffmpeg as animation writer\n",
    "plt.rcParams.update({\n",
    "    \"animation.writer\": \"ffmpeg\",\n",
    "    \"font.family\": \"serif\",  # use serif/main font for text elements\n",
    "    \"font.size\": 12,\n",
    "    \"text.usetex\": True,     # use inline math for ticks\n",
    "    \"pgf.rcfonts\": False,    # don't setup fonts from rc parameters\n",
    "    \"hist.bins\": 20, # default number of bins in histograms\n",
    "    # The preamble has to be ONE string: recent matplotlib releases no longer accept a list here\n",
    "    \"pgf.preamble\": \"\\n\".join([\n",
    "         r\"\\usepackage{units}\",          # load additional packages\n",
    "         r\"\\usepackage{metalogo}\",\n",
    "         r\"\\usepackage{unicode-math}\",   # unicode math setup\n",
    "         r\"\\setmathfont{xits-math.otf}\",\n",
    "         r\"\\setmainfont{DejaVu Serif}\",  # serif font via preamble\n",
    "         r\"\\usepackage{color}\",\n",
    "    ]),\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Neural network architecture and initialisation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    1 hidden layer Relu network\n",
    "    \"\"\"\n",
    "    def __init__(self, n_feature, n_hidden, n_output, init_scale=1, bias_hidden=True, bias_output=False, balanced=True, clipping=False, **kwargs):\n",
    "        \"\"\"\n",
    "        n_feature: dimension of input\n",
    "        n_hidden: number of hidden neurons\n",
    "        n_output: dimension of output\n",
    "        init_scale: weights are initialized ~ N(0, init_scale^2/(md)) where d is the input dimension of this layer and m the width\n",
    "        bias_hidden: if True, use bias parameters in hidden layer. Use no bias otherwise\n",
    "        bias_output: if True, use bias parameters in output layer. Use no bias otherwise\n",
    "        balanced: if True, use a balanced initialisation\n",
    "        clipping: if True, ensure that ||(w_j,a_j)|| \\in [minclip, maxclip]/sqrt(n_hidden) for any j\n",
    "        \"\"\"\n",
    "        super(Net, self).__init__()\n",
    "        self.init_scale = init_scale\n",
    "        \n",
    "        self.hidden = torch.nn.Linear(n_feature, n_hidden, bias=bias_hidden)   # hidden layer with rescaled init\n",
    "        torch.nn.init.normal_(self.hidden.weight.data, std=(init_scale/np.sqrt(n_hidden*n_feature)))\n",
    "        if bias_hidden:\n",
    "            torch.nn.init.normal_(self.hidden.bias.data, std=(init_scale/np.sqrt(n_hidden*n_feature)))\n",
    "            \n",
    "        self.predict = torch.nn.Linear(n_hidden, n_output, bias=bias_output)   # output layer with rescaled init\n",
    "        if balanced: # balanced initialisation\n",
    "            if bias_hidden:\n",
    "                neuron_norms = (self.hidden.weight.data.norm(dim=1).square()+self.hidden.bias.data.square()).sqrt()\n",
    "            else:\n",
    "                neuron_norms = (self.hidden.weight.data.norm(dim=1).square()).sqrt()\n",
    "            self.predict.weight.data = 2*torch.bernoulli(0.5*torch.ones_like(self.predict.weight.data)) -1\n",
    "            self.predict.weight.data *= neuron_norms\n",
    "        else:\n",
    "            torch.nn.init.normal_(self.predict.weight.data, std=(init_scale/np.sqrt(n_hidden)))\n",
    "        if bias_output:\n",
    "            torch.nn.init_normal_(self.predict.bias.data, std=(init_scale/np.sqrt(n_hidden)))\n",
    "            \n",
    "        if clipping:\n",
    "            neuron_norms = self.hidden.weight.data.norm(dim=1).square() + self.predict.weight.data.norm().square()\n",
    "            if bias_hidden:\n",
    "                neuron_norms += self.hidden.bias.data.square()\n",
    "            if bias_output:\n",
    "                neuron_norms += self.predict.bias.data.square()\n",
    "            neuron_norms = neuron_norms.sqrt()\n",
    "            ra = kwargs.get('minclip', init_scale/10)/np.sqrt(n_hidden)\n",
    "            rb = kwargs.get('maxclip', 10*init_scale)/np.sqrt(n_hidden)\n",
    "            m_weights = torch.clip(neuron_norms, min=ra, max=rb)/neuron_norms\n",
    "            self.hidden.weight.data *= m_weights.unsqueeze(1)\n",
    "            self.predict.weight.data *= m_weights\n",
    "            if bias_hidden:\n",
    "                self.hidden.bias.data *= m_weights\n",
    "            if bias_output:\n",
    "                self.predict.bias.data *= m_weights\n",
    "            \n",
    "        self.activation = kwargs.get('activation', torch.nn.ReLU()) # activation of hidden layer\n",
    "\n",
    "    def forward(self, z):\n",
    "        z = self.activation(self.hidden(z))     \n",
    "        z = self.predict(z)             # linear output\n",
    "        return z"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The seed fixes the teacher weights and the inputs drawn below\n",
    "torch.manual_seed(4)\n",
    "\n",
    "d = 150        # input dimension\n",
    "n = 75         # number of samples\n",
    "m_teacher = 6  # number of neurons of the teacher network\n",
    "\n",
    "# Teacher network producing the labels: Gaussian (non balanced) init, no hidden bias, no clipping\n",
    "teacher = Net(\n",
    "    n_feature=d,\n",
    "    n_hidden=m_teacher,\n",
    "    n_output=1,\n",
    "    init_scale=1,\n",
    "    balanced=False,\n",
    "    clipping=False,\n",
    "    bias_hidden=False,\n",
    ")\n",
    "\n",
    "# Standard Gaussian inputs and the corresponding teacher outputs (kept out of the autograd graph)\n",
    "x = torch.randn(n, d)\n",
    "with torch.no_grad():\n",
    "    y = teacher(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Teacher output weights, a separator line, then the labels as a flat vector\n",
    "print(teacher.predict.weight, '-'*20, y.reshape(-1), sep='\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Visualisation functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def multicolor_label(ax,list_of_strings,list_of_colors,axis='x',anchorpad=0,**kw):\n",
    "    \"\"\"Draw an axis label made of several text items, each one with its own color.\n",
    "\n",
    "    ax: the axes object where the labels should be drawn\n",
    "    list_of_strings: a list of all of the text items\n",
    "    list_of_colors: a corresponding list of colors (list_of_colors[i] is the color of list_of_strings[i])\n",
    "    axis: 'x', 'y', or 'both'; specifies which label(s) should be drawn\n",
    "    anchorpad: padding given to the AnchoredOffsetbox that holds the label\n",
    "    **kw: extra text properties forwarded to every TextArea\n",
    "    \"\"\"\n",
    "    from matplotlib.offsetbox import AnchoredOffsetbox, TextArea, HPacker, VPacker\n",
    "\n",
    "    # pair every text item with its color once, so that both branches use the same pairing\n",
    "    colored_items = list(zip(list_of_strings, list_of_colors))\n",
    "\n",
    "    # x-axis label\n",
    "    if axis in ('x', 'both'):\n",
    "        boxes = [TextArea(text, textprops=dict(color=color, ha='left',va='bottom',**kw))\n",
    "                    for text,color in colored_items ]\n",
    "        xbox = HPacker(children=boxes,align=\"center\",pad=0, sep=60)\n",
    "        anchored_xbox = AnchoredOffsetbox(loc=3, child=xbox, pad=anchorpad,frameon=False,bbox_to_anchor=(0.27, -0.18),\n",
    "                                          bbox_transform=ax.transAxes, borderpad=0.)\n",
    "        ax.add_artist(anchored_xbox)\n",
    "\n",
    "    # y-axis label: items are packed in reverse order. The (text, color) PAIRS are reversed;\n",
    "    # reversing the texts alone would give each text the color of another item.\n",
    "    if axis in ('y', 'both'):\n",
    "        boxes = [TextArea(text, textprops=dict(color=color, ha='left',va='bottom',rotation=90,**kw))\n",
    "                     for text,color in reversed(colored_items) ]\n",
    "        ybox = VPacker(children=boxes,align=\"center\", pad=0, sep=5)\n",
    "        anchored_ybox = AnchoredOffsetbox(loc=3, child=ybox, pad=anchorpad, frameon=False, bbox_to_anchor=(-0.10, 0.2),\n",
    "                                          bbox_transform=ax.transAxes, borderpad=0.)\n",
    "        ax.add_artist(anchored_ybox)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_single_frame(fig, arts, frame_number):\n",
    "    \"\"\"save as a pdf a single frame of an animation\n",
    "    fig: the figure to save\n",
    "    arts: list of images resulting in the animation\n",
    "    frame_number: the specific frame to save as a pdf\n",
    "    \"\"\"\n",
    "    # make sure everything is hidden\n",
    "    for frame_arts in arts:\n",
    "        for art in frame_arts:\n",
    "            art.set_visible(False)\n",
    "    # make the one artist we want visible\n",
    "    for i in range(len(arts[frame_number])):\n",
    "        arts[frame_number][i].set_visible(True)\n",
    "    fig.savefig(\"frame_{}.pdf\".format(frame_number))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Custom PCA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def svd_flip(u, v, u_based_decision=True):\n",
    "    \"\"\"Sign correction to ensure deterministic output from SVD.\n",
    "    Adjusts the columns of u and the rows of v such that the loadings in the\n",
    "    columns in u that are largest in absolute value are always positive.\n",
    "    Parameters\n",
    "    ----------\n",
    "    u : ndarray\n",
    "        u and v are the output of `linalg.svd` or\n",
    "        :func:`~sklearn.utils.extmath.randomized_svd`, with matching inner\n",
    "        dimensions so one can compute `np.dot(u * s, v)`.\n",
    "    v : ndarray\n",
    "        u and v are the output of `linalg.svd` or\n",
    "        :func:`~sklearn.utils.extmath.randomized_svd`, with matching inner\n",
    "        dimensions so one can compute `np.dot(u * s, v)`.\n",
    "        The input v should really be called vt to be consistent with scipy's\n",
    "        output.\n",
    "    u_based_decision : bool, default=True\n",
    "        If True, use the columns of u as the basis for sign flipping.\n",
    "        Otherwise, use the rows of v. The choice of which variable to base the\n",
    "        decision on is generally algorithm dependent.\n",
    "    Returns\n",
    "    -------\n",
    "    u_adjusted, v_adjusted : arrays with the same dimensions as the input.\n",
    "    \"\"\"\n",
    "    if u_based_decision:\n",
    "        # columns of u, rows of v\n",
    "        max_abs_cols = np.argmax(np.abs(u), axis=0)\n",
    "        signs = np.sign(u[max_abs_cols, range(u.shape[1])])\n",
    "        u *= signs\n",
    "        v *= signs[:, np.newaxis]\n",
    "    else:\n",
    "        # rows of v, columns of u\n",
    "        max_abs_rows = np.argmax(np.abs(v), axis=1)\n",
    "        signs = np.sign(v[range(v.shape[0]), max_abs_rows])\n",
    "        u *= signs\n",
    "        v *= signs[:, np.newaxis]\n",
    "    return u, v"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class MyPCA:\n",
    "    \"\"\"\n",
    "    Same implementation as sklearn but allows to not center data before processing it.\n",
    "    \"\"\"\n",
    "    def __init__(self, n_components, centered=False):\n",
    "        self.n_components = n_components\n",
    "        self.centered_ = centered\n",
    "\n",
    "    def fit(self, X):\n",
    "        \"\"\"\n",
    "        Assumes observations in X are passed as rows of a numpy array.\n",
    "        \"\"\"\n",
    "        n_samples, n_features = X.shape\n",
    "\n",
    "        # Center data\n",
    "        if self.centered_:\n",
    "            self.mean_ = np.mean(X, axis=0)\n",
    "            X -= self.mean_\n",
    "\n",
    "        U, S, Vt = np.linalg.svd(X, full_matrices=False)\n",
    "        # flip eigenvectors' sign to enforce deterministic output\n",
    "        U, Vt = svd_flip(U, Vt)\n",
    "\n",
    "        components_ = Vt\n",
    "\n",
    "        # Get variance explained by singular values\n",
    "        explained_variance_ = (S**2) / (n_samples - 1)\n",
    "        total_var = explained_variance_.sum()\n",
    "        explained_variance_ratio_ = explained_variance_ / total_var\n",
    "        singular_values_ = S.copy()  # Store the singular values.\n",
    "\n",
    "        self.noise_variance_ = explained_variance_[self.n_components:].mean()\n",
    "\n",
    "        self.n_samples_, self.n_features_ = n_samples, n_features\n",
    "        self.components_ = components_[:self.n_components]\n",
    "        self.explained_variance_ = explained_variance_[:self.n_components]\n",
    "        self.explained_variance_ratio_ = explained_variance_ratio_[:self.n_components]\n",
    "        self.singular_values_ = singular_values_[:self.n_components]\n",
    "        \n",
    "    def transform(self, X):\n",
    "        if self.centered_:\n",
    "            X = X - self.mean_\n",
    "        X_transformed = np.dot(X, self.components_.T)\n",
    "        return X_transformed"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# init network: 200 hidden neurons, balanced initialisation at scale 1e-20, no hidden bias\n",
    "net = Net(n_feature=d, n_hidden=200, n_output=1, init_scale=1e-20, balanced=True, bias_hidden=False)     # define the network\n",
    "\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr=0.001) # gradient descent (every step uses the whole data set x, y)\n",
    "loss_func = torch.nn.MSELoss(reduction='mean')  # mean squared error\n",
    "\n",
    "n_samples = x.shape[0]\n",
    "n_iterations = 200000 # number of gradient descent steps\n",
    "\n",
    "loss = torch.Tensor(np.array([0]))\n",
    "previous_loss = torch.Tensor(np.array([np.inf]))  # np.inf: the np.infty alias was removed in NumPy 2.0\n",
    "\n",
    "losses = []  # loss at every step, stored as Python floats\n",
    "\n",
    "# plot parameters\n",
    "iter_geom = 1.1 # frames are saved at steps 0 and 1, then at step int(iter_geom*t)+1 where t is the last saved step\n",
    "last_iter = 0\n",
    "frame = 0\n",
    "weights = []  # hidden-layer weight matrix at every saved step\n",
    "signs = []    # per neuron: 0, 0.5 or 1 for a negative, zero or positive output weight\n",
    "iters = []    # step index of every saved frame\n",
    "\n",
    "# train the network\n",
    "for it in tqdm(range(n_iterations)):\n",
    "    prediction = net(x)\n",
    "    loss = loss_func(prediction, y)\n",
    "    if (it<2 or it==int(last_iter*iter_geom)+1): # save net weights\n",
    "        weights.append(net.hidden.weight.data.detach().numpy().copy())\n",
    "        signs.append(net.predict.weight.data.heaviside(torch.as_tensor(float(0.5))).reshape(-1).numpy().copy())\n",
    "        iters.append(it)\n",
    "        last_iter = it\n",
    "    losses.append(loss.item())  # plain float instead of a 0-d numpy array\n",
    "    optimizer.zero_grad()   # clear gradients for next train\n",
    "    loss.backward()         # backpropagation, compute gradients\n",
    "    optimizer.step()        # descent step"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Loss profile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loss as a function of the gradient descent step\n",
    "loss_fig, loss_ax = plt.subplots()\n",
    "loss_ax.plot(losses, lw=3)\n",
    "loss_ax.set_ylim(bottom=0)\n",
    "loss_ax.set_ylabel(r'$L(\\theta)$', fontsize=20)\n",
    "loss_ax.set_xlabel('Iterations', fontsize=20)\n",
    "loss_ax.grid(alpha=0.2)\n",
    "loss_fig.tight_layout()\n",
    "loss_fig.savefig('loss_profile_n{}_d{}.pdf'.format(n,d))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `loss` is the value computed in the last training iteration, i.e. before the last descent step was applied\n",
    "print(loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Neuron alignment visualisation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## PCA to represent the neurons in 2D\n",
    "# The basis is fitted on the last saved weight snapshot only, without centering.\n",
    "# All d components are kept: the animation reads the first two, the explained variance plot reads all of them.\n",
    "\n",
    "pca = MyPCA(n_components=d, centered=False)\n",
    "pca.fit(weights[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ims = []  # one list of artists per frame, consumed by ArtistAnimation\n",
    "fig = plt.figure(\"Neuron alignment\")\n",
    "plt.ioff()\n",
    "\n",
    "# Cosmetics\n",
    "c1 = 'tab:green' # color of left axis\n",
    "c2 = 'tab:blue' # color of right axis\n",
    "# marker color for each sign code stored in `signs`\n",
    "color_map = {0 : 'firebrick',\n",
    "             0.5 : 'black',\n",
    "             1 : 'darkviolet'}\n",
    "\n",
    "ax = fig.add_subplot(111, projection='polar') # polar coordinates\n",
    "ax.set_rorigin(-5e-2) # set inner circle for 0 norm vectors\n",
    "ax.set_theta_zero_location(\"E\")\n",
    "ax.yaxis.set_ticklabels([])\n",
    "\n",
    "\n",
    "#######\n",
    "for i,(w,s,it) in enumerate(zip(weights, signs, iters)):\n",
    "    # color of stars given their output layer sign\n",
    "    # (the loop variable is not called d, to avoid confusion with the input dimension d)\n",
    "    c = [color_map[sign_code] for sign_code in s]\n",
    "    w1 = pca.transform(w) # coordinates of the neurons in the PCA basis\n",
    "    # Polar angle of the first two PCA coordinates, shifted by pi: the same angles (mod 2*pi) as the former\n",
    "    # arctan(y/x) + pi*[x > 0] expression, but without dividing by zero when x == 0.\n",
    "    # NOTE(review): the pi shift turns the picture by half a turn w.r.t. the usual convention - confirm it is intended\n",
    "    theta = np.arctan2(w1[:,1], w1[:,0]) + np.pi\n",
    "    radius = np.linalg.norm(w1[:,:2],axis=1)\n",
    "    im = ax.scatter(theta, radius, animated=True, c=c, marker='*')\n",
    "    t1 = ax.annotate(\"iteration: \"+str(it),(0.1,0.95),xycoords='figure fraction',annotation_clip=False) # add text\n",
    "    t2 = ax.annotate(\"frame: \"+str(i),(0.8,0.95),xycoords='figure fraction',annotation_clip=False) # add text\n",
    "    ims.append([im,t1,t2])\n",
    "\n",
    "ani = animation.ArtistAnimation(fig, ims, interval=100, repeat=False)\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# No writer is passed here. NOTE(review): matplotlib should then fall back on rcParams['animation.writer'],\n",
    "# set to ffmpeg in the preamble, so ffmpeg has to be installed - confirm\n",
    "ani.save('highdim_alignment.mp4', fps=10, dpi=120) # save animation as .mp4 "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save specific frames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): presumably the animation object is dropped so that it no longer drives `fig`\n",
    "# before static frames are saved - confirm\n",
    "del ani"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hides the artists of all frames except frame 108 and writes the figure to frame_108.pdf\n",
    "save_single_frame(fig, ims, 108) # save specific frame of animation as .pdf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Distribution of the PCA explained variance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Share of the variance carried by each principal component\n",
    "var_fig, var_ax = plt.subplots()\n",
    "var_ax.plot(pca.explained_variance_ratio_, '+')\n",
    "var_ax.set_xlabel(\"Component number\", fontsize=20)\n",
    "var_ax.set_ylabel(\"Explained variance ratio\", fontsize=20)\n",
    "var_ax.grid(alpha=0.2)\n",
    "var_fig.tight_layout()\n",
    "var_fig.savefig('explained_variance_n{}_d{}.pdf'.format(n,d))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
