{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0e09ce89",
   "metadata": {
    "id": "0e09ce89"
   },
   "source": [
    " # Dynamics of a Single Neuron\n",
    " This notebook generates the figures 2, 3, 4 in the main and figures in the appendix for the section studying the dynamics of a single neuron."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58f2a09f-9597-467a-bb67-c8dac87105fb",
   "metadata": {
    "id": "58f2a09f-9597-467a-bb67-c8dac87105fb"
   },
   "source": [
    "## Setup Functions\n",
    "The following cells define plotting functions, setup the loss/gradient, and generate theory predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e1a3db8-a937-40d0-ad5e-d62090c5c42c",
   "metadata": {
    "id": "9e1a3db8-a937-40d0-ad5e-d62090c5c42c"
   },
   "outputs": [],
   "source": [
    "# Load Libraries\n",
    "import os\n",
    "import numpy as np\n",
    "from tqdm import tqdm_notebook as tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "from sklearn import svm\n",
    "from scipy.integrate import solve_ivp\n",
    "import matplotlib.colors as colors\n",
    "from matplotlib import cm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d8fe421-8985-4e48-ac09-196117f132e1",
   "metadata": {
    "id": "8d8fe421-8985-4e48-ac09-196117f132e1"
   },
   "source": [
    "### Plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a38f8070",
   "metadata": {
    "id": "a38f8070"
   },
   "outputs": [],
   "source": [
    "# Plot styles\n",
    "def style_3D(ax, lims=None, no_axes=False, nogrid=False, nofill=False):\n",
    "\n",
    "    # Boundary\n",
    "    for axis in [ax.xaxis, ax.yaxis, ax.zaxis]:\n",
    "        axis.line.set_linewidth(3)\n",
    "\n",
    "    # Customize ticks and axes\n",
    "    ax.tick_params(axis='both', which='major', length=10, width=1.5, pad=10, labelsize=15)\n",
    "    ax.tick_params(axis='both', which='minor', width=1, length=2)\n",
    "    ax.xaxis.set_major_locator(plt.MaxNLocator(4))\n",
    "    ax.yaxis.set_major_locator(plt.MaxNLocator(4))\n",
    "    ax.zaxis.set_major_locator(plt.MaxNLocator(4))\n",
    "    ax.zaxis._axinfo['juggled'] = (1,2,0)\n",
    "\n",
    "    if lims is not None:\n",
    "        ax.set_xlim(lims)\n",
    "        ax.set_ylim(lims)\n",
    "        ax.set_zlim(lims)\n",
    "\n",
    "    if no_axes:\n",
    "        ax.set_axis_off()\n",
    "\n",
    "    if nogrid:\n",
    "        ax.grid(False)\n",
    "\n",
    "    if nofill:\n",
    "        ax.xaxis.pane.fill = False\n",
    "        ax.yaxis.pane.fill = False\n",
    "        ax.zaxis.pane.fill = False\n",
    "\n",
    "\n",
    "# Plot styles\n",
    "def style_heatmaps(ax, xlabels=True, ylabels=True, xlim=None, ylim=None):\n",
    "    if xlabels:\n",
    "        ax.tick_params(axis=\"x\", which=\"both\", bottom=True, top=False,\n",
    "                       labelbottom=True, left=True, right=False,\n",
    "                       labelleft=True, direction='out',length=7,width=1.5,pad=0,\n",
    "                       labelsize=24,labelrotation=45)\n",
    "        ax.xaxis.set_major_locator(plt.MaxNLocator(6))\n",
    "    else:\n",
    "        ax.tick_params(axis=\"x\", which=\"both\", bottom=True, top=False,\n",
    "                       labelbottom=True, left=True, right=False,\n",
    "                       labelleft=True, direction='out',length=7,width=1.5,pad=0,\n",
    "                       labelsize=24,labelrotation=45)\n",
    "        ax.set_xlabel(\"\")\n",
    "    if ylabels:\n",
    "        ax.tick_params(axis=\"y\", which=\"both\", bottom=True, top=False,\n",
    "                   labelbottom=True, left=True, right=False,\n",
    "                   labelleft=True, direction='out',length=7,width=1.5,pad=4,\n",
    "                   labelsize=24)\n",
    "        ax.yaxis.set_major_locator(plt.MaxNLocator(6))\n",
    "    else:\n",
    "        ax.tick_params(axis=\"y\", which=\"both\", bottom=True, top=False,\n",
    "                   labelbottom=False, left=True, right=False,\n",
    "                   labelleft=False, direction='out',length=7,width=1.5,pad=4,\n",
    "                   labelsize=24)\n",
    "        ax.set_ylabel(\"\")\n",
    "    ax.xaxis.offsetText.set_fontsize(20)\n",
    "\n",
    "    # Boundary\n",
    "    for dir in [\"top\", \"bottom\", \"right\", \"left\"]:\n",
    "        ax.spines[dir].set_linewidth(3)\n",
    "\n",
    "    # Limits\n",
    "    if xlim is not None:\n",
    "        ax.set_xlim(xlim)\n",
    "    if ylim is not None:\n",
    "        ax.set_ylim(ylim)\n",
    "\n",
    "def style_axes(ax, numyticks=5, numxticks=5):\n",
    "    ax.tick_params(axis=\"y\", which=\"both\", bottom=True, top=False,\n",
    "                   labelbottom=True, left=True, right=False,\n",
    "                   labelleft=True,direction='out',length=7,width=1.5,pad=8,labelsize=24)\n",
    "    ax.yaxis.set_major_locator(plt.MaxNLocator(numyticks))\n",
    "\n",
    "    ax.tick_params(axis=\"x\", which=\"both\", bottom=True, top=False,\n",
    "                   labelbottom=True, left=True, right=False,\n",
    "                   labelleft=True,direction='out',length=7,width=1.5,pad=8,\n",
    "                   labelsize=24)\n",
    "    ax.xaxis.set_major_locator(plt.MaxNLocator(numxticks))\n",
    "    #ax.ticklabel_format(axis=\"x\", style=\"scientific\", scilimits=(0,0), useMathText=True)\n",
    "    ax.ticklabel_format(axis=\"x\", useMathText=True)\n",
    "    ax.xaxis.offsetText.set_fontsize(20)\n",
    "    ax.grid()\n",
    "\n",
    "    # boundary\n",
    "    ax.spines['right'].set_visible(False)\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    for dir in [\"top\", \"bottom\", \"right\", \"left\"]:\n",
    "        ax.spines[dir].set_linewidth(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d19e6aa2-aded-46ae-9460-ba2b449a0698",
   "metadata": {
    "id": "d19e6aa2-aded-46ae-9460-ba2b449a0698"
   },
   "outputs": [],
   "source": [
    "def plot_loss(ax, lim1, lim2, XX, XY):\n",
    "\n",
    "    b1 = np.linspace(1.1 * lim1[0], 1.1 * lim1[1], 100)\n",
    "    b2 = np.linspace(1.1 * lim2[0], 1.1 * lim2[1], 100)\n",
    "\n",
    "    b1, b2 = np.meshgrid(b1, b2)\n",
    "    loss = XX[0,0]*b1**2 + XX[1,1]*b2**2 + 2 * XX[0,1]*b1*b2 - 2*XY[0]*b1 - 2*XY[1]*b2\n",
    "    loss += np.abs(loss.min())\n",
    "\n",
    "    zmin, zmax = loss.min(), loss.max()\n",
    "    cnorm = colors.Normalize(vmin=zmin, vmax=zmax)\n",
    "\n",
    "    mesh = ax.contourf(b1, b2, loss, levels=np.linspace(zmin, zmax, 20), cmap='RdBu', norm=cnorm)\n",
    "    ax.contour(b1, b2, loss, levels=np.linspace(zmin, zmax, 20), colors='k')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2176de41-af39-4e96-af6d-2f3ee8f1b054",
   "metadata": {
    "id": "2176de41-af39-4e96-af6d-2f3ee8f1b054"
   },
   "outputs": [],
   "source": [
    "def plot_minima_surface(ax, OLS, lim, color=\"red\", opacity=0.9):\n",
    "\n",
    "    null_direction = np.array([-OLS[1], OLS[0]]) / np.linalg.norm(OLS)\n",
    "\n",
    "    xx = np.linspace(1e-1, 10, 500)\n",
    "    yy = np.linspace(-10, 10, 500)\n",
    "    xx, yy = np.meshgrid(xx, yy)\n",
    "    w1 = xx * OLS[0] + yy * null_direction[0]\n",
    "    w2 = xx * OLS[1] + yy * null_direction[1]\n",
    "    a = 1 / xx\n",
    "\n",
    "    # Filter data within plot limits\n",
    "    mask = ((w1 >= lim[0]) & (w1 <= lim[1]) &\n",
    "            (w2 >= lim[0]) & (w2 <= lim[1]) &\n",
    "            (a >= lim[0]) & (a <= lim[1]))\n",
    "    w1_trim = np.ma.masked_where(~mask, w1)\n",
    "    w2_trim = np.ma.masked_where(~mask, w2)\n",
    "    a_trim = np.ma.masked_where(~mask, a)\n",
    "\n",
    "    # Plot the hyperbolic surface\n",
    "    ax.plot_surface(w1_trim, w2_trim, a_trim, color=color, alpha=opacity, zorder=0)\n",
    "    ax.plot_surface(-w1_trim, -w2_trim, -a_trim, color=color, alpha=opacity, zorder=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7254fcac-e513-4a9a-90db-cd145a22f618",
   "metadata": {
    "id": "7254fcac-e513-4a9a-90db-cd145a22f618"
   },
   "outputs": [],
   "source": [
    "def plot_task_aligned_surface(R, ax, lim, XX, XY, opacity=1.0):\n",
    "    OLS = np.linalg.pinv(XX) @ XY\n",
    "    OLS /= np.linalg.norm(OLS)\n",
    "\n",
    "    a = np.linspace(-lim, lim, 500)\n",
    "    w = np.linspace(-lim, lim, 500)\n",
    "    a, w = np.meshgrid(a, w)\n",
    "\n",
    "    x = w * OLS[0]\n",
    "    y = w * OLS[1]\n",
    "    z = a\n",
    "\n",
    "    loss = MSE(z * x, z * y, XX, XY)\n",
    "    cnorm = colors.LogNorm(vmin=np.min(loss), vmax=np.max(loss))\n",
    "    ax.plot_surface(x, y, z, facecolors=cm.RdBu(cnorm(loss)), alpha=opacity, shade=False, edgecolor='none', zorder=0)\n",
    "    surf = cm.ScalarMappable(norm=cnorm, cmap=\"RdBu\")\n",
    "    return surf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "484df37d-8adc-43fb-8afe-4b81e5660194",
   "metadata": {
    "id": "484df37d-8adc-43fb-8afe-4b81e5660194"
   },
   "outputs": [],
   "source": [
    "def plot_conserved_surface(R, ax, lim, XX, XY, opacity=1.0):\n",
    "    # R = a^2 - |w|^2\n",
    "    if R > 0:\n",
    "\n",
    "        alpha = np.linspace(0, lim, 500)\n",
    "        theta = np.linspace(0, 2*np.pi, 500)\n",
    "        alpha, theta = np.meshgrid(alpha, theta)\n",
    "\n",
    "        x = alpha * np.cos(theta)\n",
    "        y = alpha * np.sin(theta)\n",
    "        z = np.sqrt(alpha**2 + R)\n",
    "\n",
    "        loss = MSE(z * x, z * y, XX, XY)\n",
    "        cnorm = colors.LogNorm(vmin=np.min(loss), vmax=np.max(loss))\n",
    "        ax.plot_surface(x, y, z, facecolors=cm.RdBu(cnorm(loss)), alpha=opacity, shade=False, edgecolor='none', zorder=0)\n",
    "\n",
    "        loss = MSE(-z * x, -z * y, XX, XY)\n",
    "        cnorm = colors.LogNorm(vmin=np.min(loss), vmax=np.max(loss))\n",
    "        ax.plot_surface(x, y, -z, facecolors=cm.RdBu(cnorm(loss)), alpha=opacity, shade=False, edgecolor='none', zorder=0)\n",
    "        surf = cm.ScalarMappable(norm=cnorm, cmap=\"RdBu\")\n",
    "    else:\n",
    "        alpha = np.linspace(-lim, lim, 500)\n",
    "        theta = np.linspace(0, 2*np.pi, 500)\n",
    "        alpha, theta = np.meshgrid(alpha, theta)\n",
    "\n",
    "        x = np.sqrt(alpha**2 - R) * np.cos(theta)\n",
    "        y = np.sqrt(alpha**2 - R) * np.sin(theta)\n",
    "        z = alpha\n",
    "\n",
    "        loss = MSE(z * x, z * y, XX, XY)\n",
    "        cnorm = colors.LogNorm(vmin=np.min(loss), vmax=np.max(loss))\n",
    "        ax.plot_surface(x, y, z, facecolors=cm.RdBu(cnorm(loss)), alpha=opacity, shade=False, edgecolor='none', zorder=0)\n",
    "        surf = cm.ScalarMappable(norm=cnorm, cmap=\"RdBu\")\n",
    "    return surf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7561d430-fca5-4068-a0bf-6e8bc3da876a",
   "metadata": {
    "id": "7561d430-fca5-4068-a0bf-6e8bc3da876a"
   },
   "outputs": [],
   "source": [
    "def plot_basin_seperating_surface(ax, lims, XX, XY, opacity=1.0):\n",
    "    delta = np.linspace(-2 * lims[1]**2, -1e-8, 500)\n",
    "    alpha = np.linspace(lims[0], lims[1], 500)\n",
    "    alpha, delta = np.meshgrid(alpha, delta)\n",
    "\n",
    "    OLS = np.linalg.pinv(XX) @ XY\n",
    "\n",
    "    k = alpha * (-delta / 2 - np.sqrt(delta**2 + 4 * np.linalg.norm(OLS)**2) / 2)\n",
    "\n",
    "    a = OLS[0]**2 / OLS[1]**2 + 1\n",
    "    b = -2*k*OLS[0]/OLS[1]**2\n",
    "    c = k**2 / OLS[1]**2 + delta - alpha**2\n",
    "\n",
    "    x = (-b + np.sqrt(b**2 - 4 * a * c)) / (2 * a)\n",
    "    y = (k - OLS[0]*x) / OLS[1]\n",
    "    z = alpha\n",
    "\n",
    "    # Filter data within plot limits\n",
    "    mask = ((x >= lims[0]) & (x <= lims[1]) &\n",
    "            (y >= lims[0]) & (y <= lims[1]) &\n",
    "            (z >= lims[0]) & (z <= lims[1]))\n",
    "\n",
    "    x_trimmed = np.ma.masked_where(~mask, x)\n",
    "    y_trimmed = np.ma.masked_where(~mask, y)\n",
    "    z_trimmed = np.ma.masked_where(~mask, z)\n",
    "\n",
    "\n",
    "    loss = MSE(z * x, z * y, XX, XY)\n",
    "    cnorm = colors.LogNorm(vmin=np.min(loss), vmax=np.max(loss))\n",
    "    ax.plot_surface(x_trimmed, y_trimmed, z_trimmed, facecolors=cm.RdBu(cnorm(loss)), alpha=opacity, shade=False, edgecolor='none', zorder=0)\n",
    "    surf = cm.ScalarMappable(norm=cnorm, cmap=\"RdBu\")\n",
    "\n",
    "\n",
    "    x = (-b - np.sqrt(b**2 - 4 * a * c)) / (2 * a)\n",
    "    y = (k - OLS[0]*x) / OLS[1]\n",
    "    z = alpha\n",
    "\n",
    "    # Filter data within plot limits\n",
    "    mask = ((x >= lims[0]) & (x <= lims[1]) &\n",
    "            (y >= lims[0]) & (y <= lims[1]) &\n",
    "            (z >= lims[0]) & (z <= lims[1]))\n",
    "\n",
    "    x_trimmed = np.ma.masked_where(~mask, x)\n",
    "    y_trimmed = np.ma.masked_where(~mask, y)\n",
    "    z_trimmed = np.ma.masked_where(~mask, z)\n",
    "\n",
    "    loss = MSE(z * x, z * y, XX, XY)\n",
    "    cnorm = colors.LogNorm(vmin=np.min(loss), vmax=np.max(loss))\n",
    "    ax.plot_surface(x_trimmed, y_trimmed, z_trimmed, facecolors=cm.RdBu(cnorm(loss)), alpha=opacity, shade=False, edgecolor='none', zorder=0)\n",
    "    surf = cm.ScalarMappable(norm=cnorm, cmap=\"RdBu\")\n",
    "\n",
    "    return surf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbd277a6-6111-4389-b955-39e9b4bd8aa3",
   "metadata": {
    "id": "cbd277a6-6111-4389-b955-39e9b4bd8aa3"
   },
   "outputs": [],
   "source": [
    "def plot_trimmed(ax, x, y, z, lim, lw, color, zorder, ls='-'):\n",
    "\n",
    "    if len(lim) == 2:\n",
    "        xlim = lim\n",
    "        ylim = lim\n",
    "        zlim = lim\n",
    "    else:\n",
    "        xlim = lim[0]\n",
    "        ylim = lim[1]\n",
    "        zlim = lim[2]\n",
    "\n",
    "\n",
    "    if z is not None:\n",
    "        # Filter data within plot limits\n",
    "        mask = ((x >= xlim[0]) & (x <= xlim[1]) &\n",
    "                (y >= ylim[0]) & (y <= ylim[1]) &\n",
    "                (z >= zlim[0]) & (z <= zlim[1]))\n",
    "        x_trimmed = x[mask]\n",
    "        y_trimmed = y[mask]\n",
    "        z_trimmed = z[mask]\n",
    "\n",
    "        ax.plot(x_trimmed, y_trimmed, z_trimmed, lw=lw, ls=ls, color=color, zorder=zorder)\n",
    "    else:\n",
    "        # Filter data within plot limits\n",
    "        mask = ((x >= xlim[0]) & (x <= xlim[1]) &\n",
    "                (y >= ylim[0]) & (y <= ylim[1]))\n",
    "        x_trimmed = x[mask]\n",
    "        y_trimmed = y[mask]\n",
    "\n",
    "        ax.plot(x_trimmed, y_trimmed, lw=lw, ls=ls, color=color, zorder=zorder)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c72b27e-1f6c-408b-bf72-27c7ba6cc5e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_diverging_colors_7(index):\n",
    "    colors = [\n",
    "        \"#964b00\",\n",
    "        \"#a47e1d\",\n",
    "        \"#acaf4d\",\n",
    "        \"#b0df89\",\n",
    "        \"#71c28b\",\n",
    "        \"#39a189\",\n",
    "        \"#008080\"\n",
    "    ]\n",
    "    return colors[index]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11847417-2899-4744-a9d9-b63332e82fb9",
   "metadata": {
    "id": "11847417-2899-4744-a9d9-b63332e82fb9"
   },
   "source": [
    "### Loss and Gradient"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "523329af-2b73-4674-b996-8451912719c6",
   "metadata": {
    "id": "523329af-2b73-4674-b996-8451912719c6"
   },
   "outputs": [],
   "source": [
    "def MSE(b1, b2, XX, XY, eps=1e-2):\n",
    "    loss = XX[0,0]*b1**2 + XX[1,1]*b2**2 + 2 * XX[0,1]*b1*b2 - 2*XY[0]*b1 - 2*XY[1]*b2\n",
    "    return loss - np.min(loss) + eps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0582cdfd-8554-49a1-9885-7c3989781a81",
   "metadata": {
    "id": "0582cdfd-8554-49a1-9885-7c3989781a81"
   },
   "outputs": [],
   "source": [
    "def get_gradient(XX, XY):\n",
    "    def grad(theta):\n",
    "        w = theta[:-1]\n",
    "        a = theta[-1]\n",
    "        beta = a * w\n",
    "        residual = XX.dot(beta) - XY\n",
    "        dw = a * residual\n",
    "        da = w.dot(residual)\n",
    "        dtheta = np.concatenate((dw, [da]))\n",
    "        return dtheta\n",
    "    return lambda t, theta : -grad(theta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sEtpg2dr3GiS",
   "metadata": {
    "id": "sEtpg2dr3GiS"
   },
   "outputs": [],
   "source": [
    "def kernel_distance(a0, w0, a_t, w1_t, w2_t):\n",
    "    numer = 0\n",
    "    numer += 2*a0**2 * a_t**2\n",
    "    numer += a0**2 * (w1_t**2 + w2_t**2)\n",
    "    numer += a_t**2 * np.linalg.norm(w0)**2\n",
    "    numer += (w0[0]*w1_t + w0[1]*w2_t)**2\n",
    "    denom1 = np.sqrt(2*a0**4 + 2*a0**2*np.linalg.norm(w0)**2 + np.linalg.norm(w0)**4)\n",
    "    denom2 = np.sqrt(2*a_t**4 + 2*a0**2*(w1_t**2 + w2_t**2) + (w1_t**2 + w2_t**2)**2)\n",
    "    denom = denom1 * denom2\n",
    "    S = 1 - numer / denom\n",
    "    return S"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a19977e-99b8-4d03-ba6b-be59ea7c13e3",
   "metadata": {
    "id": "9a19977e-99b8-4d03-ba6b-be59ea7c13e3"
   },
   "source": [
    "### Theory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7273c337-c774-4bcf-85f9-75c7e392b7bc",
   "metadata": {
    "id": "7273c337-c774-4bcf-85f9-75c7e392b7bc"
   },
   "outputs": [],
   "source": [
    "def coordinates(R, beta):\n",
    "    \"\"\"\n",
    "    Takes in an R = a^2 - |w|^2 and a beta = aw and returns a,w\n",
    "    Note: This only works for positive a.\n",
    "    \"\"\"\n",
    "    a = np.array([np.sqrt((R + np.sqrt(R**2 + 4*np.linalg.norm(beta)**2))/2)])\n",
    "    w = beta / a\n",
    "    return w, a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d047dbf-a57e-47d9-b71f-2408ffec5dc3",
   "metadata": {
    "id": "6d047dbf-a57e-47d9-b71f-2408ffec5dc3"
   },
   "outputs": [],
   "source": [
    "def balanced_theory(R, time, w0, a0, XY):\n",
    "\n",
    "    nu0 = a0 * np.linalg.norm(w0)\n",
    "    mu0 = np.dot(w0, XY) / (np.linalg.norm(w0) * np.linalg.norm(XY))\n",
    "    beta0 = a0 * w0\n",
    "\n",
    "    r = np.linalg.norm(XY)\n",
    "    y0 = mu0 * r\n",
    "    c1 = np.arctanh(y0 / r) / r\n",
    "    scale = c1*r + r * time\n",
    "    alignment_theory = r * np.tanh(scale) / r\n",
    "\n",
    "    norm0 = nu0\n",
    "    numer = r * (np.cosh(2 * scale) + 1)\n",
    "    denom = 2 * scale + np.sinh(2 * scale)\n",
    "    c2 = (numer[0] - norm0 * denom[0]) / (norm0 * r)\n",
    "    norm_theory = numer / (denom + c2 * r)\n",
    "\n",
    "    dir1_0 = beta0[0] / norm0\n",
    "    dir2_0 = beta0[1] / norm0\n",
    "    c3 = (dir1_0 - XY[0] * np.tanh(c1 * r) / r) * np.cosh(c1 * r)\n",
    "    c4 = (dir2_0 - XY[1] * np.tanh(c1 * r) / r) * np.cosh(c1 * r)\n",
    "    beta1_theory = (c3/np.cosh(scale) + XY[0] * np.tanh(scale) / r) * norm_theory\n",
    "    beta2_theory = (c4/np.cosh(scale) + XY[1] * np.tanh(scale) / r) * norm_theory\n",
    "\n",
    "    abs_a = (norm_theory**2)**(1/4)\n",
    "    sign_a = np.sign(a0)\n",
    "    a_theory = sign_a * abs_a\n",
    "    w1_theory = beta1_theory / a_theory\n",
    "    w2_theory = beta2_theory / a_theory\n",
    "\n",
    "    return w1_theory, w2_theory, a_theory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "zB7AFCfZw__V",
   "metadata": {
    "id": "zB7AFCfZw__V"
   },
   "outputs": [],
   "source": [
    "def theory(R, time, w, a, OLS):\n",
    "    # Being anal about some things\n",
    "    assert a > 0, \"For initialization always choosing positive a\"\n",
    "    if R > 0:\n",
    "        assert (w**2).sum() < a**2\n",
    "        # tanh_theta0 = np.linalg.norm(w) / a\n",
    "        # # nice expression for cosh(arctanh(_))\n",
    "        # phi0 = 1 / np.sqrt(1 - tanh_theta0**2)\n",
    "        phi0 = np.cosh(np.arctanh(np.linalg.norm(w) / a))\n",
    "    elif R < 0:\n",
    "        assert (w**2).sum() > a**2\n",
    "        # tanh_theta0 = a / np.linalg.norm(w)\n",
    "        # phi0 = tanh_theta0 / np.sqrt(1 - tanh_theta0**2)\n",
    "        phi0 = np.sinh(np.arctanh(a / np.linalg.norm(w)))\n",
    "    else:\n",
    "        # FIXME: handle R=0\n",
    "        raise ValueError(f\"Not supporting R={R}\")\n",
    "        assert (w**2).sum() == a**2\n",
    "        tanh_theta0 = 1\n",
    "\n",
    "    s = np.linalg.norm(OLS)\n",
    "\n",
    "    # Initial values\n",
    "    nu0 = (w / a).dot(OLS)\n",
    "\n",
    "    print(f\"R: {R}, s: {s}, nu0: {nu0}, phi0: {phi0}\")\n",
    "\n",
    "    # Quantity that comes up a lot, complex-valued\n",
    "    Q = 1/2 * time * np.sqrt(4*s**2 + R**2) + np.arctanh((R + 2*nu0) / np.sqrt(4*s**2 + R**2) + 0j)\n",
    "\n",
    "    # Compute nu(t)\n",
    "    tanh_Q = np.tanh(Q)\n",
    "    assert all(abs(tanh_Q.imag) < 1e-7), \"Imaginary part should be vanishing for nu quantity\"\n",
    "    nu = 1/2 * (-R + np.sqrt(4*s**2 + R**2) * tanh_Q.real)\n",
    "\n",
    "    # This expression comes from using F[_] for nu. One expression for positive R, the other one\n",
    "    # for negative. Choosing positive sign for each (which might just be working because with our\n",
    "    # setup a_0 is always positive so that \\phi_0 is always positive)\n",
    "    if R < 0:\n",
    "        term_in_sqrt = \\\n",
    "     (1 / np.cosh(Q))**2 * \\\n",
    "      (4*s**2 * (s**2 - nu0*(R+nu0)) * phi0**2 - \\\n",
    "       (4*s**2 + R**2) * (s**2 + (s-nu0)*(s+nu0)*phi0**2) * np.exp(-time * R) + \\\n",
    "       R * (-s**2 + nu0*(R+nu0)) * phi0**2 * \\\n",
    "            (R * np.cosh(2 * Q) - np.sqrt(4*s**2 + R**2) * np.sinh(2 * Q))) / ((s**2 - nu0*(R+nu0)) * phi0**2)\n",
    "        phi = 2 * s / np.sqrt(term_in_sqrt) * 1j\n",
    "        assert all(abs(phi.imag) < 1e-7), \"Imaginary part should be vanishing\"\n",
    "    else:\n",
    "        term_in_sqrt = \\\n",
    "     (1 / np.cosh(Q))**2 * \\\n",
    "      (4*s**2 * (s**2 - nu0*(R+nu0)) * phi0**2 - \\\n",
    "       (4*s**2 + R**2) * (-nu0**2*phi0**2 + s**2*(-1+phi0**2)) * np.exp(-time * R) + \\\n",
    "       R * (-s**2 + nu0*(R+nu0)) * phi0**2 * \\\n",
    "            (R * np.cosh(2 * Q) - np.sqrt(4*s**2 + R**2) * np.sinh(2 * Q))) / ((s**2 - nu0*(R+nu0)) * phi0**2)\n",
    "        assert all(abs(term_in_sqrt.imag) < 1e-7), \"Imaginary part should be vanishing\"\n",
    "        phi = 2 * s / np.sqrt(term_in_sqrt)\n",
    "\n",
    "    # assert all(abs(phi.imag) < 1e-7), f\"Imaginary part should be vanishing, it's: {phi}\"\n",
    "\n",
    "    # Can now recover theta\n",
    "    if R > 0:\n",
    "        # When R > 0, we know that a does not change sign. This means that theta does not change\n",
    "        # sign, and in particular, since we're always starting with positive a, theta starts\n",
    "        # positive and remains positive. This is fine here because arccosh gives the inverse as\n",
    "        # positive (cosh is technically not invertible)\n",
    "        theta = np.arccosh(phi)\n",
    "\n",
    "        # FIXME: technically assumes $\\|beta_*\\|=0$\n",
    "        mu = 1 / np.tanh(theta) * nu\n",
    "    elif R < 0:\n",
    "        # This theta should be able to change sign\n",
    "        theta = np.arcsinh(phi)\n",
    "        mu = np.tanh(theta) * nu\n",
    "\n",
    "    # This is \\nu in paper\n",
    "    signed_beta_norm = np.abs(R) * np.sinh(2 * theta) / 2\n",
    "\n",
    "    return nu, phi, theta, signed_beta_norm, mu\n",
    "\n",
    "\n",
    "def a_w_norm(R, theta):\n",
    "    if R > 0:\n",
    "        # |a| is larger than ||w||, \\tan(\\theta)=||w||/a\n",
    "        a = np.sqrt(R) * np.cosh(theta)\n",
    "        w_norm = np.sqrt(R) * np.sinh(theta)\n",
    "    elif R < 0:\n",
    "        a = np.sqrt(-R) * np.sinh(theta)\n",
    "        w_norm = np.sqrt(-R) * np.cosh(theta)\n",
    "    return a, w_norm\n",
    "\n",
    "\n",
    "def get_w_theory(w0, w_norm, beta_star, mu, R, theta):\n",
    "\n",
    "    d = len(w0)\n",
    "    beta_hat = beta_star / np.linalg.norm(beta_star)\n",
    "    w0_hat = w0 / np.linalg.norm(w0)\n",
    "\n",
    "    # We get the components of w in the basis (\\beta_*, (I-\\beta_*\\beta_*^T)w_0)\n",
    "\n",
    "    # this `mu` is <w^, b> (has not normalized by beta norm)\n",
    "    c2 = mu / np.linalg.norm(beta_star)\n",
    "\n",
    "    if (w0_hat.dot(beta_hat)**2 >= (1 - 1e-9)) and (w0_hat.dot(beta_hat)**2 <= (1 + 1e-9)):\n",
    "        return (c2[None, :] * beta_hat[:, None] * w_norm).T\n",
    "\n",
    "    # Always taking positive square root. This is maybe wrong.\n",
    "    c1 = np.sqrt(1 - c2**2)\n",
    "\n",
    "    proj = (np.eye(d) - (beta_hat[:, None] @ beta_hat[:, None].T)) / np.sqrt(1 - (beta_hat.dot(w0_hat))**2)\n",
    "\n",
    "    w_hat =  c1[None, :] * (proj @ w0_hat)[:, None] + c2[None, :] * beta_hat[:, None]\n",
    "\n",
    "    # (d, T), (T,)\n",
    "    print(w_hat.shape, w_norm.shape)\n",
    "\n",
    "    return (w_hat * w_norm).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85cabc2a-0839-4cf1-991c-adade2d5f44b",
   "metadata": {
    "id": "85cabc2a-0839-4cf1-991c-adade2d5f44b"
   },
   "outputs": [],
   "source": [
    "# Gradient of q(x) assuming x is positive\n",
    "def grad_q(delta, x):\n",
    "    return 3/2 * np.sqrt(np.sqrt(x**2 + delta**2 / 4) - delta/2)\n",
    "\n",
    "# Define the vector z\n",
    "def z_vector(delta, beta0):\n",
    "    norm = np.linalg.norm(beta0)\n",
    "    return -3/2 * np.sqrt(np.sqrt(norm**2 + delta**2/4) - delta/2) * beta0 / norm\n",
    "\n",
    "# Return alpha that determines interpolating solution\n",
    "def implicit_bias_theory(delta, beta0, OLS):\n",
    "\n",
    "    # Define null direction\n",
    "    null = np.array([-OLS[1], OLS[0]])\n",
    "    null /= np.linalg.norm(null)\n",
    "\n",
    "    # Define c and k\n",
    "    c = np.linalg.norm(OLS)\n",
    "    k = -2 * z_vector(delta, beta0).dot(null) / 3\n",
    "\n",
    "    # Define alpha\n",
    "    kappa = (k**2 + delta) / 2\n",
    "    alpha = k * np.sqrt(kappa + np.sqrt(kappa**2 + c**2))\n",
    "\n",
    "    return alpha"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "i4k2tz8Hh6pf",
   "metadata": {
    "id": "i4k2tz8Hh6pf"
   },
   "source": [
    "Sort out what's going wrong in edge cases (e.g. other OLS soln)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "p40cJGU1tisk",
   "metadata": {
    "id": "p40cJGU1tisk"
   },
   "source": [
    "## Fig. 2 - 3: Whitened Input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ce4f8be-f842-445a-b26f-6c8b8c90705c",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "0ce4f8be-f842-445a-b26f-6c8b8c90705c",
    "outputId": "c15ffecf-9b31-45f5-cbb4-e7a6badbffca",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Create figures\n",
    "fig1 = plt.figure(figsize=(8, 8))\n",
    "fig2 = plt.figure(figsize=(8, 8))\n",
    "fig3 = plt.figure(figsize=(8, 8))\n",
    "fig4 = plt.figure(figsize=(8, 8))\n",
    "fig5 = plt.figure(figsize=(8, 8))\n",
    "fig6 = plt.figure(figsize=(8, 8))\n",
    "fig7 = plt.figure(figsize=(8, 8))\n",
    "\n",
    "# Create axes\n",
    "ax1 = fig1.add_subplot(111, projection='3d', computed_zorder=False)  # First plot is a 3D plot\n",
    "ax2 = fig2.add_subplot(111, projection='3d', computed_zorder=False)  # Second plot is a 3D plot\n",
    "ax3 = fig3.add_subplot(111, projection='3d', computed_zorder=False)  # Third plot is a 3D plot\n",
    "ax4 = fig4.add_subplot(111)  # Fourth plot is a 2D plot\n",
    "ax5 = fig5.add_subplot(111)  # Fifth plot is a 2D plot\n",
    "ax6 = fig6.add_subplot(111)  # Sixth plot is a 2D plot\n",
    "ax7 = fig7.add_subplot(111)  # Seventh plot is a 2D plot\n",
    "axes = [ax1, ax2, ax3]\n",
    "\n",
    "# Hyperparameters\n",
    "lims = (-3, 3)\n",
    "R_values = [-4, -2, -1, 0, 1, 2, 4] # a^2 - |w|^2\n",
    "R_3D_plot = [-2, 0, 2]\n",
    "\n",
    "# Setup Data\n",
    "XX = np.eye(2)\n",
    "XY = np.array([0.0, 1.0])\n",
    "\n",
    "# OLS solution\n",
    "# (NOTE: this is the least-norm solution, but to be specific it is the only solution\n",
    "# since we are taking X^TX to be full-rank, i.e. X to be full-rank)\n",
    "OLS = np.linalg.pinv(XX) @ XY\n",
    "ax4.scatter(OLS[0], OLS[1], s=75, c='red', zorder=3)\n",
    "\n",
    "# Initialization\n",
    "alpha=1.0\n",
    "w0 = np.array([-1.0, 0.0])\n",
    "w0 /= np.linalg.norm(w0)\n",
    "a0 = np.array([1.0])\n",
    "w0 *= alpha\n",
    "a0 *= alpha\n",
    "beta0 = a0 * w0\n",
    "ax4.scatter(beta0[0], beta0[1], s=75, c='grey', zorder=3)\n",
    "\n",
    "# Plot Function Space Surface (make `eps` of room along each axis past OLS or beta0)\n",
    "eps = 0.2\n",
    "lim1 = np.array([min(OLS[0], beta0[0]) - eps, max(OLS[0], beta0[0]) + eps])\n",
    "lim2 = np.array([min(OLS[1], beta0[1]) - eps, max(OLS[1], beta0[1]) + eps])\n",
    "plot_loss(ax4, lim1, lim2, XX, XY)\n",
    "\n",
    "i = 0\n",
    "for j, R in enumerate(R_values):\n",
    "\n",
    "    # Get color\n",
    "    c = get_diverging_colors_7(j) \n",
    "    # c = cm.BrBG(j/6) #c = cm.tab10(j)\n",
    "\n",
    "    # Initialization (NOTE: this is always setting a to be positive, adding assert)\n",
    "    w, a = coordinates(R, beta0)\n",
    "    assert a > 0\n",
    "    theta0 = np.concatenate((w, a))\n",
    "\n",
    "    # Gradient Flow\n",
    "    T = 20\n",
    "    grad = get_gradient(XX, XY)\n",
    "    theta = solve_ivp(grad, [0, T], theta0, rtol=1e-6)\n",
    "    time = theta.t\n",
    "    w1_t = theta.y[0, :]\n",
    "    w2_t = theta.y[1, :]\n",
    "    a_t = theta.y[2, :]\n",
    "    beta1 = a_t * w1_t\n",
    "    beta2 = a_t * w2_t\n",
    "\n",
    "    # Plot function space trajectories\n",
    "    ax4.plot(beta1, beta2, c=c, lw=5, label=r'$\\delta = {}$'.format(R))\n",
    "\n",
    "    # Plot nu, mu, kernel distance\n",
    "    nu = a_t * np.sqrt(w1_t**2 + w2_t**2)\n",
    "    mu = (w1_t * OLS[0] + w2_t * OLS[1]) / (np.sqrt(w1_t**2 + w2_t**2) * np.linalg.norm(OLS))\n",
    "    S = kernel_distance(a, w, a_t, w1_t, w2_t)\n",
    "    ax5.plot(time, nu, c=c, lw=7, label=r'$\\delta = {}$'.format(R))\n",
    "    ax6.plot(time, mu, c=c, lw=7, label=r'$\\delta = {}$'.format(R))\n",
    "    ax7.plot(time, S, c=c, lw=7, label=r'$\\delta = {}$'.format(R))\n",
    "\n",
    "    # Plot theory\n",
    "    if R == 0:\n",
    "        w1_theory, w2_theory, a_theory = balanced_theory(R, time, w, a, OLS)\n",
    "        signed_beta_norm_theory = a_theory * np.sqrt(w1_theory**2 + w2_theory**2)\n",
    "        mu_theory = (w1_theory * OLS[0] + w2_theory * OLS[1]) / np.sqrt(w1_theory**2 + w2_theory**2)\n",
    "    else:\n",
    "        nu_theory, phi_theory, theta_theory, signed_beta_norm_theory, mu_theory = theory(R, time, w, a, OLS)\n",
    "        a_theory, w_norm_theory = a_w_norm(R, theta_theory)\n",
    "        w_theory = get_w_theory(w, w_norm_theory, OLS, mu_theory, R, None)\n",
    "        w1_theory, w2_theory = w_theory[:, 0], w_theory[:, 1]\n",
    "    if R in R_3D_plot:\n",
    "      axes[i].plot(w1_theory, w2_theory, a_theory, lw=2, ls='--', color='k', zorder=1.5)\n",
    "    beta1_theory = a_theory * w1_theory\n",
    "    beta2_theory = a_theory * w2_theory\n",
    "    ax4.plot(beta1_theory, beta2_theory, lw=2, ls='--', color='k')\n",
    "\n",
    "    mu_theory /= np.linalg.norm(OLS)\n",
    "    S_theory = kernel_distance(a, w, a_theory, w1_theory, w2_theory)\n",
    "    ax5.plot(time, signed_beta_norm_theory, lw=2, ls='--', color='k')\n",
    "    ax6.plot(time, mu_theory, lw=2, ls='--', color='k')\n",
    "    ax7.plot(time, S_theory, lw=2, ls='--', color='k')\n",
    "\n",
    "    if R in R_3D_plot:\n",
    "\n",
    "        # Plot Conserved Quanitity over Parameter space\n",
    "        surf = plot_conserved_surface(R, axes[i], lims[-1], XX, XY, opacity=0.75)\n",
    "        # fig.colorbar(surf, ax=axes[i], pad=0.05, shrink=0.65, aspect=15)\n",
    "\n",
    "        # Plot hyperbola of equivalent OLS solutions\n",
    "        w_opt, a_opt = coordinates(R, OLS)\n",
    "        t = np.linspace(1, 10, 100)\n",
    "        plot_trimmed(axes[i], w_opt[0]/t, w_opt[1]/t, a_opt*t, lims, lw=5, color='red', zorder=2)\n",
    "        plot_trimmed(axes[i], w_opt[0]*t, w_opt[1]*t, a_opt/t, lims, lw=5, color='red', zorder=-2)\n",
    "        plot_trimmed(axes[i], -w_opt[0]*t, -w_opt[1]*t, -a_opt/t, lims, lw=5, color='red', zorder=2)\n",
    "        plot_trimmed(axes[i], -w_opt[0]/t, -w_opt[1]/t, -a_opt*t, lims, lw=5, color='red', zorder=-2)\n",
    "\n",
    "        # Initialization\n",
    "        axes[i].scatter(w[0], w[1], a, color='k', s=50, edgecolor=c, linewidths=2, zorder=3)\n",
    "\n",
    "        # Plot hyperbola of equivalent initializations\n",
    "        t = np.linspace(1, 10, 100)\n",
    "        plot_trimmed(axes[i], w[0]/t, w[1]/t, a*t, lims, lw=5, color='grey', zorder=2)\n",
    "        plot_trimmed(axes[i], w[0]*t, w[1]*t, a/t, lims, lw=5, color='grey', zorder=-2)\n",
    "        plot_trimmed(axes[i], -w[0]*t, -w[1]*t, -a/t, lims, lw=5, color='grey', zorder=2)\n",
    "        plot_trimmed(axes[i], -w[0]/t, -w[1]/t, -a*t, lims, lw=5, color='grey', zorder=-2)\n",
    "\n",
    "        # Plot parameter trajectories\n",
    "        axes[i].plot(w1_t, w2_t, a_t, color=c, lw=5, zorder=1)\n",
    "        axes[i].scatter(w1_t[-1], w2_t[-1], a_t[-1], color='k', s=50, edgecolor=c, linewidths=2, zorder=3)\n",
    "\n",
    "        # Axes index\n",
    "        i += 1\n",
    "\n",
    "\n",
    "# Style plots\n",
    "for ax in axes:\n",
    "    style_3D(ax, lims=lims, nofill=True)\n",
    "    ax.set_xlabel(r\"$w_1$\", labelpad=15, fontsize=30)\n",
    "    ax.set_ylabel(r\"$w_2$\", labelpad=15, fontsize=30)\n",
    "    ax.set_zlabel(r\"$a$\", labelpad=15, fontsize=30)\n",
    "    ax.set_box_aspect(aspect=None, zoom=0.85)\n",
    "    ax.view_init(elev=29) #azim=-60\n",
    "style_heatmaps(ax4)\n",
    "ax4.set_xlabel(r\"$\\beta_1$\", labelpad=-10, fontsize=30)\n",
    "ax4.set_ylabel(r\"$\\beta_2$\", labelpad=0, fontsize=30)\n",
    "ax4.legend(loc=\"best\", fontsize=20, reverse=True)\n",
    "\n",
    "style_axes(ax5)\n",
    "style_axes(ax6)\n",
    "style_axes(ax7)\n",
    "ax5.set_xlabel(r\"Time $t$\", fontsize=30)\n",
    "ax5.set_ylabel(r\"Signed Magnitude $\\mu$\", fontsize=30)\n",
    "ax6.set_xlabel(r\"Time $t$\", fontsize=30)\n",
    "ax6.set_ylabel(r\"Alignment $\\phi$\", fontsize=30)\n",
    "ax7.set_xlabel(r\"Time $t$\", fontsize=30)\n",
    "ax7.set_ylabel(r\"Kernel Distance $S(0,t)$\", fontsize=30)\n",
    "ax6.legend(loc=\"best\", fontsize=20, reverse=True)\n",
    "\n",
    "# Show the plot\n",
    "plt.show()\n",
    "\n",
    "# Save figures\n",
    "if not os.path.exists(\"unstructured\"):\n",
    "    os.makedirs(\"unstructured\")\n",
    "fig1.savefig('unstructured/negative.pdf', bbox_inches='tight')\n",
    "fig2.savefig('unstructured/zero.pdf', bbox_inches='tight')\n",
    "fig3.savefig('unstructured/positive.pdf', bbox_inches='tight')\n",
    "fig4.savefig('unstructured/function-space.pdf', bbox_inches='tight')\n",
    "fig5.savefig('unstructured/nu.pdf', bbox_inches='tight')\n",
    "fig6.savefig('unstructured/mu.pdf', bbox_inches='tight')\n",
    "fig7.savefig('unstructured/kernel-distance.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f34dda5-ed83-4c14-b408-03cb862ba32c",
   "metadata": {
    "id": "7f34dda5-ed83-4c14-b408-03cb862ba32c"
   },
   "source": [
    "## Fig. 4: Interpolating Solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fb16863-29ec-4c76-bed4-7830af4ddd66",
   "metadata": {
    "id": "5fb16863-29ec-4c76-bed4-7830af4ddd66",
    "outputId": "e2e9abc7-a592-4c15-d512-10beee2fad31"
   },
   "outputs": [],
   "source": [
    "# Create 3D plot\n",
    "fig1 = plt.figure(figsize=(8, 8))\n",
    "fig2 = plt.figure(figsize=(8, 8))\n",
    "ax1 = fig1.add_subplot(111, projection='3d', computed_zorder=False)  # First subplot is a 3D plot\n",
    "ax2 = fig2.add_subplot(111)  # Third subplot is a 2D plot\n",
    "\n",
    "# Hyperparameters\n",
    "lims = (-3, 3)\n",
    "R_values = [-4, -2, -1, 0, 1, 2, 4] # a^2 - |w|^2\n",
    "\n",
    "# Setup Data\n",
    "X = np.array([0.5, 1])\n",
    "XX = np.outer(X,X)\n",
    "XY = 1.1 * X\n",
    "\n",
    "# OLS solution\n",
    "OLS = np.linalg.pinv(XX) @ XY\n",
    "\n",
    "# Initialization\n",
    "eps = 0.5\n",
    "alpha = 0.6\n",
    "beta0 = eps * (alpha * np.array([X[1], -X[0]]) + (1 - alpha) * X)\n",
    "ax2.scatter(beta0[0], beta0[1], s=75, c='grey', zorder=5)\n",
    "\n",
    "# Plot hyperbola of equivelent initializations\n",
    "t = np.linspace(1e-5, 10, 1000)\n",
    "plot_trimmed(ax1, beta0[0]/t, beta0[1]/t, t, lims, lw=5, color='grey', zorder=1)\n",
    "plot_trimmed(ax1, -beta0[0]/t, -beta0[1]/t, -t, lims, lw=5, color='grey', zorder=-1)\n",
    "\n",
    "# Plot hyperbola of equivelent min norm OLS solutions\n",
    "a_opt = 1\n",
    "w_opt = OLS\n",
    "t = np.linspace(0.1, 10, 100)\n",
    "plot_trimmed(ax1, w_opt[0]/t, w_opt[1]/t, a_opt*t, lims, lw=5, color='k', zorder=-1)\n",
    "plot_trimmed(ax1, -w_opt[0]*t, -w_opt[1]*t, -a_opt/t, lims, lw=5, color='k', zorder=-1)\n",
    "\n",
    "# Plot Function Space Surface\n",
    "eps = 0.1\n",
    "lim1 = np.array([0.25, 1.75])# np.array([min(OLS[0], beta0[0]) - eps, max(OLS[0], beta0[0]) + eps])\n",
    "lim2 = np.array([-0.25, 1.25])# np.array([min(OLS[1], beta0[1]) - eps, max(OLS[1], beta0[1]) + eps])\n",
    "plot_loss(ax2, lim1, lim2, XX, XY)\n",
    "ax2.set_xlim(lim1[0] + eps, lim1[1] + eps)#ax2.set_xlim(1.1 * lim1)\n",
    "ax2.set_ylim(lim2[0] + eps, lim2[1] + eps)#ax2.set_ylim(1.1 * lim2)\n",
    "\n",
    "# Plot hyperbola sheet of OLS solutions\n",
    "plot_minima_surface(ax1, OLS, lims, color=\"red\", opacity=0.7)\n",
    "\n",
    "# Null space\n",
    "alpha = np.linspace(-2, 2, 100)\n",
    "v = np.array([-OLS[1], OLS[0]])\n",
    "plot_trimmed(ax2, OLS[0] + alpha*v[0], OLS[1] + alpha*v[1], None, [-4,4], lw=5, color='red', zorder=2)\n",
    "\n",
    "# Plot \\delta = -infty solution\n",
    "tau = np.linalg.norm(OLS)**2 / np.dot(OLS, beta0)\n",
    "ax2.plot([beta0[0], tau * beta0[0]], [beta0[1], tau * beta0[1]], label=r\"$\\delta = -\\infty$\", ls=\"-\", lw=5, c='#5C4033')\n",
    "ax2.plot([beta0[0], tau * beta0[0]], [beta0[1], tau * beta0[1]], ls=\"--\", lw=2, c='k')\n",
    "ax2.scatter(tau * beta0[0], tau * beta0[1], color='k', s=100, marker='o', zorder=4, edgecolor=\"#5C4033\", linewidths=3)\n",
    "\n",
    "for i, R in enumerate(R_values):\n",
    "\n",
    "    # Get color\n",
    "    c = get_diverging_colors_7(i) #cm.tab10(i)\n",
    "\n",
    "    # Initialization\n",
    "    w, a = coordinates(R, beta0)\n",
    "    theta0 = np.concatenate((w, a))\n",
    "    ax1.scatter(w[0], w[1], a, color='k', s=50, edgecolor=c, linewidths=2, zorder=3)\n",
    "\n",
    "    # Gradient Flow\n",
    "    T = 20\n",
    "    grad = get_gradient(XX, XY)\n",
    "    theta = solve_ivp(grad, [0, T], theta0, rtol=1e-6)\n",
    "    time = theta.t\n",
    "    w1_t = theta.y[0,:]\n",
    "    w2_t = theta.y[1,:]\n",
    "    a_t = theta.y[2,:]\n",
    "    beta1 = a_t * w1_t\n",
    "    beta2 = a_t * w2_t\n",
    "\n",
    "    # Plot parameter and function space trajectories\n",
    "    ax1.plot(w1_t, w2_t, a_t, color=c, lw=5, zorder=2)\n",
    "    ax2.plot(beta1, beta2, color=c, lw=5)#, label=r'$\\delta = {}$'.format(R))\n",
    "\n",
    "    # Plot theoretical interpolating solution\n",
    "    null = np.array([-OLS[1], OLS[0]]) / np.linalg.norm(OLS)\n",
    "    alpha = implicit_bias_theory(R, beta0, OLS)\n",
    "    optimal = OLS + alpha * null\n",
    "    w_opt, a_opt = coordinates(R, optimal)\n",
    "    ax1.scatter(w_opt[0], w_opt[1], a_opt, color='k', s=50, edgecolor=c, linewidths=2, zorder=3)\n",
    "    ax2.scatter(optimal[0], optimal[1], color='k', s=100, marker='o', zorder=3, edgecolor=c, linewidths=3)#, label=r'$\\delta = {}$'.format(R))\n",
    "\n",
    "# Plot \\delta = infty solution\n",
    "minnorm_sol = OLS + beta0 - np.dot(beta0, OLS) / np.linalg.norm(OLS)**2 * OLS\n",
    "ax2.plot([beta0[0], minnorm_sol[0]], [beta0[1], minnorm_sol[1]], label=r\"$\\delta = \\infty$\", ls=\"solid\", lw=5, c='#014D4E', zorder=4)\n",
    "ax2.plot([beta0[0], minnorm_sol[0]], [beta0[1], minnorm_sol[1]], ls=\"--\", lw=2, c='k', zorder=4)\n",
    "ax2.scatter(minnorm_sol[0], minnorm_sol[1], color='k', s=100, marker='o', zorder=4, edgecolor=\"#014D4E\", linewidths=3)\n",
    "\n",
    "# Style plots\n",
    "style_3D(ax1, lims=lims, nofill=True)\n",
    "ax1.set_xlabel(r\"$w_1$\", labelpad=15, fontsize=20)\n",
    "ax1.set_ylabel(r\"$w_2$\", labelpad=15, fontsize=20)\n",
    "ax1.set_zlabel(r\"$a$\", labelpad=15, fontsize=20)\n",
    "ax1.set_box_aspect(aspect=None, zoom=0.9)\n",
    "style_heatmaps(ax2)\n",
    "ax2.set_xlabel(r\"$\\beta_1$\", labelpad=-10, fontsize=30)\n",
    "ax2.set_ylabel(r\"$\\beta_2$\", labelpad=-10, fontsize=30)\n",
    "ax2.legend(loc=\"best\", fontsize=20, reverse=True)\n",
    "\n",
    "# Show the plot\n",
    "plt.show()\n",
    "\n",
    "if not os.path.exists(\"null-space\"):\n",
    "    os.makedirs(\"null-space\")\n",
    "fig1.savefig('null-space/parameter-space.pdf', bbox_inches='tight')\n",
    "fig2.savefig('null-space/function-space.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "IB9zoplatKrd",
   "metadata": {
    "id": "IB9zoplatKrd"
   },
   "source": [
    "## Appendix Figure: Verify that $\\nu$ and $\\theta$ match theory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62HTYyk9tJ1A",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "62HTYyk9tJ1A",
    "outputId": "4712db34-ea41-45fd-a6a3-00e3092f70e2"
   },
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "# R_values = [-2, -1, 0, 1, 2]  # a^2 - |w|^2\n",
    "R_values = [-2, -1, -0.2, 0.02, 0.2, 1, 2]    # avoiding 0 for time being\n",
    "# R_values = [-10, -30]\n",
    "\n",
    "# Setup Data\n",
    "XX = np.eye(2)\n",
    "# XY = np.array([0.0, 1.0])\n",
    "# XY = np.array([1/np.sqrt(2), 1/np.sqrt(2)])\n",
    "\n",
    "# Make not exactly anti-aligned to (0, -1)\n",
    "# XY = np.array([0.1, 1.1])\n",
    "XY = np.array([0.5, 0.5])\n",
    "# XY = np.array([0.219, -4.1])  # FIXME\n",
    "# Mainly, aligned, positive R, negative R\n",
    "\n",
    "# XY /= np.linalg.norm(XY)\n",
    "\n",
    "# OLS solution\n",
    "# (NOTE: this is the least-norm solution, but to be specific it is the only solution\n",
    "# since we are taking X^TX to be full-rank, i.e. X to be full-rank)\n",
    "OLS = np.linalg.pinv(XX) @ XY\n",
    "\n",
    "# Initialization\n",
    "# for _j, _beta0 in enumerate((np.array([-1., 0.]),  # orthogonal\n",
    "#                              np.array([0., -1.]),  # anti-aligned\n",
    "#                              np.array([0.5, -0.5]),  # neither but negative overlap\n",
    "#                              np.array([0.5, 0.5]))):  # neither\n",
    "\n",
    "# Is it only when beta is completely anti-aligned with OLS that w hits 0?\n",
    "\n",
    "# Bug is coming up with beta0 starting at the solution (exactly, i.e. no norm difference). Note that\n",
    "# we are always choosing a(0) > 0\n",
    "\n",
    "for _j, _beta0 in enumerate((np.array([-1., 0.]), np.array([0., -1.]), np.array([0.5, -0.5]), np.array([0.5, 0.5]))):  # neither\n",
    "\n",
    "    beta0 = _beta0 / np.linalg.norm(_beta0)\n",
    "\n",
    "    # This will have \\nu\n",
    "    fig, axs = plt.subplots(1, 9, figsize=(20, 3))\n",
    "\n",
    "    i = 0\n",
    "    for R in R_values:\n",
    "\n",
    "        # Initialization (NOTE: this is always setting a to be positive, which will mean that\n",
    "        # phi0 is always positive. Just note this because I think it simplifies things  )\n",
    "        w, a = coordinates(R, beta0)\n",
    "        assert a > 0\n",
    "        theta0 = np.concatenate((w, a))\n",
    "\n",
    "        # Gradient Flow\n",
    "        T = 10\n",
    "        grad = get_gradient(XX, XY)\n",
    "        # theta = solve_ivp(grad, [0, T], theta0, rtol=1e-6)\n",
    "        theta = solve_ivp(grad, [0, T], theta0, atol=1e-11, rtol=1e-11)\n",
    "        time = theta.t\n",
    "        w1_t = theta.y[0, :]\n",
    "        w2_t = theta.y[1, :]\n",
    "        a_t = theta.y[2, :]\n",
    "\n",
    "        # Can certainly blow up\n",
    "        w_t = np.stack([w1_t, w2_t], axis=1)  # (n_timesteps, 2)\n",
    "        # print(w_t[0, :], w)  # These do match\n",
    "        nu = (w_t / a_t[:, None]) @ OLS\n",
    "        if R > 0:\n",
    "            theta = np.arctanh(np.linalg.norm(w_t, axis=1) / a_t)\n",
    "            phi = np.cosh(theta)\n",
    "        elif R < 0:\n",
    "            theta = np.arctanh(a_t / np.linalg.norm(w_t, axis=1))\n",
    "            phi = np.sinh(theta)\n",
    "\n",
    "        # Compute using a and w since that's maybe a better test\n",
    "        signed_beta_norm = a_t * np.linalg.norm(w_t, axis=1)\n",
    "        mu = (w_t / np.linalg.norm(w_t, axis=1, keepdims=True)) @ OLS\n",
    "\n",
    "        nu_theory, phi_theory, theta_theory, signed_beta_norm_theory, mu_theory = \\\n",
    "            theory(R, time, w, a, OLS)\n",
    "\n",
    "        # Solve for w(t) and a(t)\n",
    "        a_theory, w_norm_theory = a_w_norm(R, theta_theory)\n",
    "        w_theory = get_w_theory(w, w_norm_theory, OLS, mu_theory, R, theta)\n",
    "\n",
    "        axs[0].plot(time, nu, lw=3, label=r'$a^2 - \\|w\\|^2 = {}$'.format(R))\n",
    "        axs[0].plot(time, nu_theory, lw=1, ls='--', color='k')\n",
    "\n",
    "        axs[1].plot(time, phi, lw=3, label=r'$a^2 - \\|w\\|^2 = {}$'.format(R))\n",
    "        axs[1].plot(time, phi_theory, lw=1, ls='--', color='k')\n",
    "        print(f\"phi[0]: {phi[0]}, nu[0]: {nu[0]}\")\n",
    "\n",
    "        axs[2].plot(time, theta, lw=3, label=r'$a^2 - \\|w\\|^2 = {}$'.format(R))\n",
    "        axs[2].plot(time, theta_theory, lw=1, ls='--', color='k')\n",
    "\n",
    "        axs[3].plot(time, signed_beta_norm, lw=3, label=r'$a^2 - \\|w\\|^2 = {}$'.format(R))\n",
    "        axs[3].plot(time, signed_beta_norm_theory, lw=1, ls='--', color='k')\n",
    "\n",
    "        axs[4].plot(time, mu, lw=3, label=r'$a^2 - \\|w\\|^2 = {}$'.format(R))\n",
    "        axs[4].plot(time, mu_theory, lw=1, ls='--', color='k')\n",
    "\n",
    "        axs[5].plot(time, a_t, lw=3, label=r'$a^2 - \\|w\\|^2 = {}$'.format(R))\n",
    "        axs[5].plot(time, a_theory, lw=1, ls='--', color='k')\n",
    "        axs[6].plot(time, np.linalg.norm(w_t, axis=1), lw=3)\n",
    "        axs[6].plot(time, w_norm_theory, lw=1, ls='--', color='k')\n",
    "\n",
    "        axs[7].plot(time, w_t[:, 0], lw=3)\n",
    "        axs[7].plot(time, w_theory[:, 0], lw=1, ls='--', color='k')\n",
    "        axs[8].plot(time, w_t[:, 1], lw=3)\n",
    "        axs[8].plot(time, w_theory[:, 1], lw=1, ls='--', color='k')\n",
    "\n",
    "    for ax in axs:\n",
    "        ax.grid()\n",
    "    axs[1].legend()\n",
    "\n",
    "    axs[0].set_title(\"$\\\\frac{w(t)}{a(t)}^\\intercal \\\\beta_*$\")\n",
    "    axs[1].set_title(\"$\\phi(t)$\")\n",
    "    axs[2].set_title(\"$\\\\theta(t)$\")\n",
    "    axs[3].set_title(\"$\\\\nu(t)=a(t)\\|w(t)\\|$\")\n",
    "    axs[4].set_title(\"$\\\\frac{w(t)}{\\|w(t)\\|}^\\intercal \\\\beta_*$\")\n",
    "\n",
    "    axs[5].set_title(\"$a(t)$\")\n",
    "    axs[6].set_title(\"$\\|w(t)\\|$\")\n",
    "    axs[7].set_title(\"$w_1(t)$\")\n",
    "    axs[8].set_title(\"$w_2(t)$\")\n",
    "\n",
    "    if _j == 1 or _j == 2:\n",
    "        axs[0].set_ylim(-3, 3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ad47614-ab75-4add-8982-5e25e6ed0426",
   "metadata": {
    "id": "4ad47614-ab75-4add-8982-5e25e6ed0426"
   },
   "source": [
    "## Appendix Figure: Basin of Attraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bab63d82-0588-4e3f-aa37-6da0af02686b",
   "metadata": {
    "id": "bab63d82-0588-4e3f-aa37-6da0af02686b",
    "outputId": "4fea67f4-1aff-4698-8831-681c83ef28a8"
   },
   "outputs": [],
   "source": [
    "# Create figures\n",
    "fig1 = plt.figure(figsize=(8, 8))\n",
    "fig2 = plt.figure(figsize=(8, 8))\n",
    "fig3 = plt.figure(figsize=(8, 8))\n",
    "\n",
    "# Create axes\n",
    "ax1 = fig1.add_subplot(111, projection='3d', computed_zorder=False)  # First plot is a 3D plot\n",
    "ax2 = fig2.add_subplot(111)  # Third subplot is a 2D plot\n",
    "ax3 = fig3.add_subplot(111)  # Third subplot is a 2D plot\n",
    "\n",
    "# Setup Data\n",
    "XX = np.eye(2)\n",
    "XY = np.array([1, 0.25])\n",
    "\n",
    "# OLS solution\n",
    "OLS = np.linalg.pinv(XX) @ XY\n",
    "\n",
    "# Seperating Surface\n",
    "lims = (-5, 5)\n",
    "# lim1, lim2 = 40, 5\n",
    "plot_basin_seperating_surface(ax1, lims, XX, XY, opacity=0.75)\n",
    "\n",
    "# Plot hyperbola of equivelent OLS solutions\n",
    "w_opt, a_opt = coordinates(0, OLS)\n",
    "t = np.linspace(1, 10, 1000)\n",
    "plot_trimmed(ax1, w_opt[0]/t, w_opt[1]/t, a_opt*t, lims, lw=3, color='red', zorder=100)\n",
    "plot_trimmed(ax1, w_opt[0]*t, w_opt[1]*t, a_opt/t, lims, lw=3, color='red', zorder=100)\n",
    "plot_trimmed(ax1, -w_opt[0]*t, -w_opt[1]*t, -a_opt/t, lims, lw=3, color='red', zorder=-100)\n",
    "plot_trimmed(ax1, -w_opt[0]/t, -w_opt[1]/t, -a_opt*t, lims, lw=3, color='red', zorder=-100)\n",
    "\n",
    "t = np.linspace(-10, 10, 1000)\n",
    "plot_trimmed(ax1, -OLS[1]*t, OLS[0]*t, np.zeros_like(t), lims, lw=2, color='k', zorder=90)\n",
    "\n",
    "deltas = [0, -2, -10]\n",
    "epsilon = [1e-3, -1e-3]\n",
    "\n",
    "for i, delta in enumerate(deltas):\n",
    "    for j, eps in enumerate(epsilon):\n",
    "        # Initialization on the seperating surface (can do by choosing w instead)\n",
    "        alpha = -4\n",
    "        k = alpha * (-delta / 2 - np.sqrt(delta**2 + 4 * np.linalg.norm(OLS)**2) / 2)\n",
    "        a = OLS[0]**2 / OLS[1]**2 + 1\n",
    "        b = -2*k*OLS[0]/OLS[1]**2\n",
    "        c = k**2 / OLS[1]**2 + delta - alpha**2\n",
    "        w1 = (-b + np.sqrt(b**2 - 4 * a * c)) / (2 * a)\n",
    "        w2 = (k - OLS[0]*w1) / OLS[1]\n",
    "        a = alpha\n",
    "\n",
    "        w, a = np.array([w1, w2]), np.array([a + eps])\n",
    "        theta0 = np.concatenate((w, a))\n",
    "        ax1.scatter(w[0], w[1], a, color='k', s=50, edgecolor=cm.tab10(i), linewidths=2, zorder=2 * np.sign(eps))\n",
    "\n",
    "        # Gradient Flow\n",
    "        T = 200\n",
    "        grad = get_gradient(XX, XY)\n",
    "        theta = solve_ivp(grad, [0, T], theta0, rtol=1e-6)\n",
    "        time = theta.t\n",
    "        w1_t = theta.y[0,:]\n",
    "        w2_t = theta.y[1,:]\n",
    "        a_t = theta.y[2,:]\n",
    "        beta1 = a_t * w1_t\n",
    "        beta2 = a_t * w2_t\n",
    "        norm = np.sqrt(beta1**2 + beta2**2)\n",
    "        mu = (OLS[0]*beta1 + OLS[1]*beta2) / (norm * np.linalg.norm(OLS))\n",
    "\n",
    "        # Plot parameter and function space trajectories\n",
    "        if eps > 0:\n",
    "            ax1.plot(w1_t, w2_t, a_t, color=cm.tab10(i), lw=3, ls='-', label=r'$a^2 - \\|w\\|^2 = {}$'.format(delta), zorder=1)\n",
    "            ax2.plot(time, mu, color=cm.tab10(i), lw=3, ls='-', label=r'$a^2 - \\|w\\|^2 = {}$'.format(delta))\n",
    "            ax3.plot(time, norm, color=cm.tab10(i), lw=3, ls='-', label=r'$a^2 - \\|w\\|^2 = {}$'.format(delta))\n",
    "        else:\n",
    "            ax1.plot(w1_t, w2_t, a_t, color=cm.tab10(i), lw=3, ls='--', zorder=-1)\n",
    "            ax2.plot(time, mu, color=cm.tab10(i), lw=3, ls='--')\n",
    "            ax3.plot(time, norm, color=cm.tab10(i), lw=3, ls='--')\n",
    "\n",
    "# Show the plot\n",
    "style_3D(ax1, lims=lims, nofill=True)\n",
    "ax1.set_xlabel(r\"$w_1$\", labelpad=15, fontsize=20)\n",
    "ax1.set_ylabel(r\"$w_2$\", labelpad=15, fontsize=20)\n",
    "ax1.set_zlabel(r\"$a$\", labelpad=15, fontsize=20)\n",
    "ax1.set_box_aspect(aspect=None, zoom=0.9)\n",
    "\n",
    "style_axes(ax2)\n",
    "ax2.set_xlabel(r\"time ($t$)\", fontsize=24)\n",
    "ax2.set_ylabel(r\"$\\frac{\\beta^\\intercal \\beta_*}{\\|\\beta\\|\\|\\beta_*\\|}$\", fontsize=24)\n",
    "ax2.legend(loc=\"best\")\n",
    "\n",
    "style_axes(ax3)\n",
    "ax3.set_xlabel(r\"time ($t$)\", fontsize=24)\n",
    "ax3.set_ylabel(r\"$\\|\\beta\\|$\", fontsize=24)\n",
    "ax3.legend(loc=\"best\")\n",
    "\n",
    "fig1.tight_layout()\n",
    "fig2.tight_layout()\n",
    "fig3.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Save figure\n",
    "if not os.path.exists(\"basins\"):\n",
    "    os.makedirs(\"basins\")\n",
    "fig1.savefig('basins/seperating-surface.pdf')\n",
    "fig2.savefig('basins/alignment.pdf')\n",
    "fig3.savefig('basins/norm.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33b3b5fb-9052-410f-9cc7-44fa2c34a3d3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
