{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "ICLR.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rc9IXqlF-K6Y"
      },
      "source": [
        "# This notebook contains the code used to generate all of the data for the plots in the paper \"Training Data Size Induced Double Descent for Denoising Neural Networks and the Role of Training Noise Level\".\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "u3OEn4NI3nGf"
      },
      "source": [
        "import torch\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "from PIL import Image\n",
        "import seaborn as sb\n",
        "from tqdm import tqdm\n",
        "\n",
        "import torchvision\n",
        "import torch.nn.functional as F\n",
        "from torch.nn import init\n",
        "import torchvision.transforms as Tranforms\n",
        "import torch.optim as optim\n",
        "from torch.utils.data import DataLoader\n",
        "from torch.utils.data import sampler\n",
        "import torch.nn as nn\n",
        "import matplotlib.gridspec as gridspec\n",
        "from IPython.display import clear_output\n",
        "%matplotlib inline"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mAAiyZot93pb"
      },
      "source": [
        "# Set Up"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "49knHVnu3sA1"
      },
      "source": [
        "def gen_data(m,n,r, S1 = None, S2 = None, device = 'cpu'):\n",
        "  U = torch.svd(torch.randn(m,r)).U[:,:r].to(device)\n",
        "  V1 = torch.svd(torch.randn(r,n)).V[:,:r].to(device)\n",
        "  V2 = torch.svd(torch.randn(r,n)).V[:,:r].to(device)\n",
        "\n",
        "  if S1 == None:\n",
        "    S1 = torch.diag(torch.randn(r).square()).to(device)\n",
        "  if S2 == None:\n",
        "    S2 = torch.diag(torch.randn(r).square()).to(device)\n",
        "\n",
        "  X = U.mm(S1.mm(V1.t()))\n",
        "  Xtst = U.mm(S2.mm(V2.t()))\n",
        "\n",
        "  return X, Xtst, S1, S2\n",
        "\n",
        "def gen_data_test(m, ntrn, ntst, r, S1 = None, S2 = None, device = 'cpu'):\n",
        "  U = torch.svd(torch.randn(m,r)).U[:,:r].to(device)\n",
        "  V1 = torch.svd(torch.randn(r,ntrn)).V[:,:r].to(device)\n",
        "  V2 = torch.svd(torch.randn(r,ntst)).V[:,:r].to(device)\n",
        "\n",
        "  if S1 == None:\n",
        "    S1 = torch.diag(torch.randn(r).square()).to(device)\n",
        "  if S2 == None:\n",
        "    S2 = torch.diag(torch.randn(r).square()).to(device)\n",
        "\n",
        "  S1 = S1.to(device)\n",
        "  S2 = S2.to(device)\n",
        "\n",
        "  X = U.mm(S1.mm(V1.t()))\n",
        "  Xtst = U.mm(S2.mm(V2.t()))\n",
        "\n",
        "  return X, Xtst, S1, S2\n",
        "\n",
        "def cal_term_recon(c,thetatrn,thetatst):\n",
        "  if c <= 1:\n",
        "    return (thetatst**2)/(1+(thetatrn**2)*c)**2\n",
        "  else:\n",
        "    return (thetatst**2)/(1+(thetatrn**2))**2\n",
        "\n",
        "def cal_term(c, theta):\n",
        "  if np.abs(c-1) < 1e-6:\n",
        "    return np.inf\n",
        "  if c < 1:\n",
        "    return (c**2)*((theta**2+theta**4)/((1+(theta**2)*c)**2))/(1-c)\n",
        "  else:\n",
        "    return (c/(c-1))*(theta**2)/(1+(theta**2))\n",
        "\n",
        "def gen_error_pair(ntrn, ntst, m, c, thetatrn, thetatst):\n",
        "  \"\"\"Error contribution of one singular-value pair.\n",
        "\n",
        "  Adds cal_term_recon(...)/ntst to cal_term(...)/m. The ntrn argument is\n",
        "  accepted but not used in the computation.\n",
        "  \"\"\"\n",
        "  return cal_term_recon(c, thetatrn, thetatst)/ntst + cal_term(c,thetatrn)/m\n",
        "\n",
        "def gen_error(ntrn, ntst, m, c, theta, psi, Strn, Stst):\n",
        "  \"\"\"Sum gen_error_pair over the diagonal entries of Strn and Stst.\n",
        "\n",
        "  The k-th pair is (theta*Strn[k,k], psi*Stst[k,k]); the number of pairs is\n",
        "  Strn.shape[0].\n",
        "  \"\"\"\n",
        "  total = 0\n",
        "  for k in range(Strn.shape[0]):\n",
        "    sigma_trn = theta*Strn[k,k]\n",
        "    sigma_tst = psi*Stst[k,k]\n",
        "    total += gen_error_pair(ntrn, ntst, m, c, sigma_trn, sigma_tst)\n",
        "  return total\n",
        "\n",
        "def adjust_s(c,m,s):\n",
        "  if c < 1:\n",
        "    return (s.square()*(1-c/(2-c))-c/(M*(2-c))).relu().sqrt()\n",
        "\n",
        "def calc_opt(c,M,N,thetatst):\n",
        "  if c < 1:\n",
        "    thetatrn = (thetatst.square()*(1-c/(2-c)) - c/(M*(2-c))).relu().sqrt()\n",
        "  else:\n",
        "    thetatrn = (2*thetatst.square()*(c-1) - 1/N).relu().sqrt()\n",
        "  return thetatrn\n",
        "\n",
        "def calc_opt_unnormalized(c, M, Ntrn, Ntst, thetatst):\n",
        "  \"\"\"calc_opt for unnormalized singular values.\n",
        "\n",
        "  thetatst is divided by sqrt(Ntst) before calling calc_opt and the result is\n",
        "  multiplied by sqrt(Ntrn).\n",
        "  \"\"\"\n",
        "  normalized_tst = thetatst/np.sqrt(Ntst)\n",
        "  return calc_opt(c, M, Ntrn, normalized_tst)*np.sqrt(Ntrn)\n",
        "\n",
        "def gen_noise_spherical(m,n, bi=False):\n",
        "  A = torch.nn.functional.normalize(torch.randn(m,n), dim = 0)\n",
        "  if bi:\n",
        "    V = torch.linalg.svd(torch.randn(1,n)).Vh.t()\n",
        "    A = A.mm(V)\n",
        "  return A\n",
        "\n",
        "def gen_noise_rademacher(m,n, bi=False):\n",
        "  U = torch.linalg.svd(torch.randn(m,1)).U\n",
        "  A = U.mm(torch.randn(m,n).sign())/np.sqrt(m)\n",
        "  if bi:\n",
        "    V = torch.linalg.svd(torch.randn(1,n)).Vh.t()\n",
        "    A = A.mm(V)\n",
        "  return A\n",
        "\n",
        "def gen_noise_poisson(m,n, bi=False):\n",
        "  U = torch.linalg.svd(torch.randn(m,1)).U\n",
        "  A = U.mm(torch.poisson(torch.ones(m,n)) - 1)/np.sqrt(m)\n",
        "  if bi:\n",
        "    V = torch.linalg.svd(torch.randn(1,n)).Vh.t()\n",
        "    A = A.mm(V)\n",
        "  return A\n",
        "\n",
        "def gen_noise_bernoulli(m,n, bi=False):\n",
        "  U = torch.linalg.svd(torch.randn(m,1)).U\n",
        "  A = U.mm(torch.bernoulli(torch.ones(m,n)/2) - 0.5)/np.sqrt(m/4)\n",
        "  if bi:\n",
        "    V = torch.linalg.svd(torch.randn(1,n)).Vh.t()\n",
        "    A = A.mm(V)\n",
        "  return A\n"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7mMCSmNu37Tl"
      },
      "source": [
        "# Theoretical Curves for rank 1 data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0Mxt3MmF3-Kr"
      },
      "source": [
        "## Changing $c$ by changing $N$"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eOCd8UCN36jz"
      },
      "source": [
        "M = 1000\n",
        "N = torch.arange(2000,100,-100)\n",
        "Ntst = 1000\n",
        "Err = torch.zeros(N.shape[0])\n",
        "bias = torch.zeros(N.shape[0])\n",
        "var = torch.zeros(N.shape[0])\n",
        "Err_emp = torch.zeros(N.shape[0])\n",
        "Err_emp_bias = torch.zeros(N.shape[0])\n",
        "Err_emp_var = torch.zeros(N.shape[0])\n",
        "\n",
        "# Test singular value used by both the simulation and the theory terms below.\n",
        "thetatst = torch.diag(torch.tensor([1.0]))*np.sqrt(Ntst)\n",
        "T = 50\n",
        "for i in range(N.shape[0]):\n",
        "  c = M/N[i]\n",
        "  print(c, N[i])\n",
        "  thetatrn = torch.diag(torch.tensor([1.0]))*N[i].sqrt() #torch.diag(calc_opt(c,M,N[i],torch.tensor([1.0])))\n",
        "  print(thetatrn)\n",
        "  for k in range(T):\n",
        "    # Generate the test set with thetatst so the empirical errors are measured on\n",
        "    # the same spectrum that cal_term_recon / gen_error are evaluated at.\n",
        "    X, Xtst, S1, S2 = gen_data_test(M,N[i], Ntst,1,S1 = thetatrn, S2 = thetatst)\n",
        "    A = torch.randn_like(X)/np.sqrt(M)\n",
        "    Atst = torch.randn_like(Xtst)/np.sqrt(M)\n",
        "    # Minimum-norm least-squares denoiser mapping noisy X+A back to X.\n",
        "    W = X.mm(torch.pinverse(X+A))\n",
        "    Yp = W.mm(Xtst + Atst)\n",
        "    Err_emp[i] += (Xtst - Yp).square().sum()/(T*Ntst)\n",
        "    Err_emp_bias[i] += (Xtst - W.mm(Xtst)).square().sum()/(T*Ntst)\n",
        "    Err_emp_var[i] += (W).square().sum()/(T*Ntst)\n",
        "  \n",
        "  var[i] = cal_term(c, thetatrn[0,0])\n",
        "  bias[i] = cal_term_recon(c, thetatrn[0,0], thetatst[0,0])\n",
        "  Err[i] = gen_error(N[i],Ntst,M,c,1,1,thetatrn, thetatst)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9hBkSGsX3sRK"
      },
      "source": [
        "v = Err_emp_var*Ntst\n",
        "# cal_term returns inf at c = M/N = 1; mark the matching empirical entry by\n",
        "# locating it from N and M instead of a hard-coded position.\n",
        "v[N == M] = np.inf"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KjiXPZwt4HEB"
      },
      "source": [
        "# The curve needs a label, otherwise plt.legend() has nothing to show.\n",
        "plt.plot(N/M, var, label=\"Theoretical variance\")\n",
        "plt.xlabel(\"1/C = Ntrn/M\")\n",
        "plt.ylabel(\"Variance\")\n",
        "plt.legend()\n",
        "plt.savefig(\"N-var\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "W2q4k7iM4KGU"
      },
      "source": [
        "b = Err_emp_bias*Ntst\n",
        "# Mark the c = M/N = 1 entry by locating it from N and M rather than by position.\n",
        "b[N == M] = np.inf\n",
        "# The curve needs a label, otherwise plt.legend() has nothing to show.\n",
        "plt.plot(N/M, bias, label=\"Theoretical bias\")\n",
        "plt.xlabel(\"1/C = Ntrn/M\")\n",
        "plt.ylabel(\"bias\")\n",
        "plt.legend()\n",
        "plt.savefig(\"N-bias\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KBLVZPTb4UmC"
      },
      "source": [
        "## Changing $c$ by changing $M$"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sDPgzz9q4Kjd"
      },
      "source": [
        "N = 1000\n",
        "M = torch.arange(10,2000,10)\n",
        "Err = torch.zeros(M.shape[0])\n",
        "var = torch.zeros(M.shape[0])\n",
        "bias = torch.zeros(M.shape[0])\n",
        "Err_emp = torch.zeros(M.shape[0])\n",
        "Err_emp_bias = torch.zeros(M.shape[0])\n",
        "Err_emp_var = torch.zeros(M.shape[0])\n",
        "T = 5\n",
        "\n",
        "# Test singular value used by both the simulation and the theory terms below.\n",
        "thetatst = torch.diag(torch.tensor([1.0]))*np.sqrt(N)\n",
        "for i in range(M.shape[0]):\n",
        "  c = M[i]/N\n",
        "  print(c, N)\n",
        "  thetatrn = torch.diag(torch.tensor([1.0]))*np.sqrt(N) #torch.diag(calc_opt(c,M[i],N,torch.tensor([1.0])))\n",
        "  print(thetatrn)\n",
        "  for j in range(T):\n",
        "    # Generate the test set with thetatst so the simulation, the bias term and\n",
        "    # gen_error all refer to the same test spectrum.\n",
        "    X, Xtst, S1, S2 = gen_data(M[i],N,1,S1 = thetatrn, S2 = thetatst)\n",
        "    A = torch.randn_like(X)/np.sqrt(M[i])\n",
        "    Atst = torch.randn_like(Xtst)/np.sqrt(M[i])\n",
        "    # Minimum-norm least-squares denoiser mapping noisy X+A back to X.\n",
        "    W = X.mm(torch.pinverse(X+A))\n",
        "    Yp = W.mm(Xtst + Atst)\n",
        "    Err_emp[i] += (Xtst - Yp).square().sum()/(T*N)\n",
        "    Err_emp_bias[i] += (Xtst - W.mm(Xtst)).square().sum()/(T*N)\n",
        "    Err_emp_var[i] += (W).square().sum()/(T*N)\n",
        "  \n",
        "  var[i] = cal_term(c, thetatrn[0,0])\n",
        "  bias[i] = cal_term_recon(c, thetatrn[0,0], thetatst[0,0])\n",
        "  Err[i] = gen_error(N,N,M[i],c,1,1,thetatrn,thetatst)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3sdR_SD-4ua6"
      },
      "source": [
        "# The curve needs a label, otherwise plt.legend() has nothing to show.\n",
        "plt.plot(M/N, var, label=\"Theoretical variance\")\n",
        "plt.xlabel(\"C = M/Ntrn\")\n",
        "plt.ylabel(\"Variance\")\n",
        "plt.legend()\n",
        "plt.savefig(\"M-var\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AtHfPKMm4xr0"
      },
      "source": [
        "# The curve needs a label, otherwise plt.legend() has nothing to show.\n",
        "plt.plot(M/N, bias, label=\"Theoretical bias\")\n",
        "plt.xlabel(\"C = M/Ntrn\")\n",
        "plt.ylabel(\"bias\")\n",
        "plt.legend()\n",
        "plt.savefig(\"M-bias\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "baiszhmq40nB"
      },
      "source": [
        "# Approximation Error for the Formula"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5veRywk09-Ob"
      },
      "source": [
        "## Low SNR"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gjKbakLI426L"
      },
      "source": [
        "#Generate Strn and Stst\n",
        "C = 10\n",
        "T = 10\n",
        "R = [1,2,3,5,10,20,50,100,150,200,250]\n",
        "\n",
        "# One row per rank in R and one column per value of i below, so the buffers\n",
        "# stay the right size if R is edited.\n",
        "theory_error = torch.zeros(len(R),2*C).to('cuda')\n",
        "emperical_error = torch.zeros(len(R),2*C).to('cuda')\n",
        "\n",
        "# emperical_norm and emperical_recon are allocated but not written in this cell.\n",
        "theory_norm = torch.zeros(len(R),2*C).to('cuda')\n",
        "emperical_norm = torch.zeros(len(R),2*C).to('cuda')\n",
        "\n",
        "theory_recon = torch.zeros(len(R),2*C).to('cuda')\n",
        "emperical_recon = torch.zeros(len(R),2*C).to('cuda')\n",
        "\n",
        "r_idx = 0\n",
        "\n",
        "for r in R:\n",
        "\n",
        "    for i in range(2*C):\n",
        "        Strn = torch.diag(torch.randn(r).square()).to('cuda')\n",
        "        Stst = torch.diag(torch.randn(r).square()).to('cuda')\n",
        "        m = 2500\n",
        "        # n grows with i in steps of m/C but is never smaller than the rank.\n",
        "        n = np.maximum(r,int(((i+1)/C)*m))\n",
        "        c = m/n\n",
        "\n",
        "        print(\"n = \",n,\", m = \",m,\", c = \",c,\", r = \",r)\n",
        "\n",
        "        theory_error[r_idx, i] = gen_error(n,n,m,c,1,1,Strn,Stst)\n",
        "\n",
        "        for k in range(r):\n",
        "            theory_recon[r_idx, i] += cal_term_recon(c, Strn[k,k], Stst[k,k])/n\n",
        "            theory_norm[r_idx, i] += cal_term(c, Strn[k,k])\n",
        "\n",
        "        #generate data\n",
        "        for t in range(T):\n",
        "            X, Xtst, S1, S2 = gen_data(m,n,r,S1 = Strn, S2 = Stst, device = 'cuda')\n",
        "\n",
        "            # Stack T copies of the same clean data; each copy gets independent noise.\n",
        "            X = torch.ones(T,1,1, device = X.device)*X\n",
        "            Xtst = torch.ones(T,1,1, device = Xtst.device)*Xtst\n",
        "            A = torch.randn_like(X)/np.sqrt(m)\n",
        "            Atst = torch.randn_like(Xtst)/np.sqrt(m)\n",
        "\n",
        "            Y = X+A\n",
        "            Ytst = Xtst + Atst\n",
        "\n",
        "            W = torch.matmul(X,Y.pinverse())\n",
        "            # T outer draws times T stacked copies, n test columns each.\n",
        "            emperical_error[r_idx, i] += (Xtst - W.bmm(Ytst)).square().sum()/(T*T*n)\n",
        "\n",
        "        # Relative gap between the formula and the simulation.\n",
        "        print((theory_error[r_idx, i]-emperical_error[r_idx, i])/emperical_error[r_idx, i])\n",
        "\n",
        "    r_idx += 1"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-oh-3kar-A5q"
      },
      "source": [
        "## High SNR"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "I7tnAAqx-D6z"
      },
      "source": [
        "#Generate Strn and Stst\n",
        "C = 10\n",
        "T = 10\n",
        "R = [1,2,3,5,10,20,50,100,150,200,250]\n",
        "\n",
        "# One row per rank in R and one column per value of i below, so the buffers\n",
        "# stay the right size if R is edited.\n",
        "theory_error2 = torch.zeros(len(R),2*C).to('cuda')\n",
        "emperical_error2 = torch.zeros(len(R),2*C).to('cuda')\n",
        "\n",
        "# emperical_norm2 and emperical_recon2 are allocated but not written in this cell.\n",
        "theory_norm2 = torch.zeros(len(R),2*C).to('cuda')\n",
        "emperical_norm2 = torch.zeros(len(R),2*C).to('cuda')\n",
        "\n",
        "theory_recon2 = torch.zeros(len(R),2*C).to('cuda')\n",
        "emperical_recon2 = torch.zeros(len(R),2*C).to('cuda')\n",
        "\n",
        "r_idx = 0\n",
        "\n",
        "for r in R:\n",
        "\n",
        "    for i in range(2*C):\n",
        "        m = 2500\n",
        "        # n grows with i in steps of m/C but is never smaller than the rank.\n",
        "        n = np.maximum(r,int(((i+1)/C)*m))\n",
        "        c = m/n\n",
        "        # Same random spectrum as the low-SNR cell, scaled up by sqrt(n).\n",
        "        Strn = (torch.diag(torch.randn(r)).square()*np.sqrt(n)).to('cuda')\n",
        "        Stst = (torch.diag(torch.randn(r)).square()*np.sqrt(n)).to('cuda')\n",
        "        \n",
        "\n",
        "        print(\"n = \",n,\", m = \",m,\", c = \",c,\", r = \",r)\n",
        "\n",
        "        theory_error2[r_idx, i] = gen_error(n,n,m,c,1,1,Strn,Stst)\n",
        "\n",
        "        for k in range(r):\n",
        "            theory_recon2[r_idx, i] += cal_term_recon(c, Strn[k,k], Stst[k,k])/n\n",
        "            theory_norm2[r_idx, i] += cal_term(c, Strn[k,k])\n",
        "\n",
        "        #generate data\n",
        "        for t in range(T):\n",
        "            X, Xtst, S1, S2 = gen_data(m,n,r,S1 = Strn, S2 = Stst, device = 'cuda')\n",
        "\n",
        "            # Stack T copies of the same clean data; each copy gets independent noise.\n",
        "            X = torch.ones(T,1,1, device = X.device)*X\n",
        "            Xtst = torch.ones(T,1,1, device = Xtst.device)*Xtst\n",
        "            A = torch.randn_like(X)/np.sqrt(m)\n",
        "            Atst = torch.randn_like(Xtst)/np.sqrt(m)\n",
        "\n",
        "            Y = X+A\n",
        "            Ytst = Xtst + Atst\n",
        "\n",
        "            W = torch.matmul(X,Y.pinverse())\n",
        "            # T outer draws times T stacked copies, n test columns each.\n",
        "            emperical_error2[r_idx, i] += (Xtst - W.bmm(Ytst)).square().sum()/(T*T*n)\n",
        "\n",
        "        # Relative gap between the formula and the simulation.\n",
        "        print((theory_error2[r_idx, i]-emperical_error2[r_idx, i])/emperical_error2[r_idx, i])\n",
        "\n",
        "    r_idx += 1"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TAIMPOdc7ljd"
      },
      "source": [
        "# Introduction Double Descent Figure"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Db-bHmUd8OmM"
      },
      "source": [
        "## Deep Network Denoising  "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zG1-EnNQ86Eh"
      },
      "source": [
        "### Linear Rank 1 and Non Linear Synthetic"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "f8vXLMuH7ohs"
      },
      "source": [
        "device = 'cuda'\n",
        "\n",
        "W0 = torch.randn(100,100)\n",
        "W1 = torch.randn(100,100)\n",
        "\n",
        "r = 1 # change rank for non - linear\n",
        "U = torch.svd(torch.randn(100,100)).U[:,:r]\n",
        "Vtrn = torch.svd(torch.randn(250000,r)).U[:,:r]\n",
        "Vtst = torch.svd(torch.randn(1000,r)).U[:,:r]\n",
        "\n",
        "# Layout used by the cells below: one sample per row, i.e. (num_samples, 100).\n",
        "X = U.mm(Vtrn.t()).t() # for linear\n",
        "Xtst = U.mm(Vtst.t()).to(device).t() # for linear\n",
        "\n",
        "# The random ReLU network acts on columns, so transpose in and transpose back out\n",
        "# to keep the samples-as-rows layout; the test set ends up on `device`.\n",
        "X = W1.mm(W0.mm(X.t()).relu()).relu().t() # For non linear synthetic\n",
        "Xtst = W1.mm(W0.mm(Xtst.t().cpu()).relu()).relu().t().to(device) # For non linear synthetic"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1EZpERlq8B76"
      },
      "source": [
        "def do_sim(theta_train, theta_test, Xtrn, Xtst, bias = False):\n",
        "    M = Xtrn.shape[1]\n",
        "    Ntrn = Xtrn.shape[0]\n",
        "\n",
        "    X = theta_train*Xtrn + torch.randn(Ntrn,M, device = device)/np.sqrt(M)\n",
        "    Y = theta_train*Xtrn\n",
        "\n",
        "    X_tst = theta_test*Xtst + torch.randn(Ntst,M, device = device)/np.sqrt(M)\n",
        "    \n",
        "    model = torch.nn.Sequential(nn.Linear(M,M, bias = bias),\n",
        "                            nn.ReLU(),\n",
        "                            nn.Linear(M,M, bias = bias),\n",
        "                            nn.ReLU(),\n",
        "                            nn.Linear(M,M, bias = bias))\n",
        "\n",
        "    model = model.to(device)\n",
        "\n",
        "    optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)\n",
        "    \n",
        "    for i in range(1500):\n",
        "        optimizer.zero_grad()\n",
        "        Y_pred = model(X)\n",
        "        loss = nn.functional.mse_loss(Y_pred,Y)\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "    with torch.no_grad():\n",
        "        Y_tst = model(X_tst)\n",
        "        error = nn.functional.mse_loss(Y_tst,theta_test*Xtst)\n",
        "        w_norms = 0\n",
        "        for param in model.parameters():\n",
        "            w_norms += param.square().sum()\n",
        "\n",
        "    del model\n",
        "    del X,Y,Y_tst,X_tst\n",
        "    return error, w_norms "
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JKJHF84k8GzR"
      },
      "source": [
        "Ndata = torch.tensor([100,125,150,175,200,250,300,500,700,900,\n",
        "         1000,1500,2000,2500,3000,4000,5000, 6000, 7000, 8000, 9000, 10000,15000\n",
        "         ,20000,25000,30000,40000,50000,60000,70000,80000,90000,100000])\n",
        "\n",
        "Ntst = 1000\n",
        "psi_norm = 0.1\n",
        "\n",
        "T = 10    \n",
        "avg_error = torch.zeros((len(Ndata),3,T), device = device)\n",
        "theta = torch.zeros((len(Ndata),3), device = device)\n",
        "w_norm = torch.zeros((len(Ndata),3,T), device = device)\n",
        "\n",
        "# The test set does not depend on Ntrn: place it on the compute device and\n",
        "# rescale it to Frobenius norm sqrt(Ntst) once.\n",
        "Xtst = Xtst.to(device)\n",
        "Xtst = Xtst*(np.sqrt(Ntst)/Xtst.norm())\n",
        "    \n",
        "for k in range(len(Ndata)):\n",
        "    # Plain Python int: a 0-d tensor would be rendered as 'tensor(100)' in the\n",
        "    # file names written at the end of the loop body.\n",
        "    Ntrn = int(Ndata[k])\n",
        "    print(Ntrn)\n",
        "    M = 100\n",
        "\n",
        "    C = M/Ntrn\n",
        "\n",
        "    # Rescale out of place so X is never modified through the slice.\n",
        "    Xtrn = X[:Ntrn,:].to(device)\n",
        "    Xtrn = Xtrn*(np.sqrt(Ntrn)/Xtrn.norm())\n",
        "    \n",
        "    print(Xtrn.shape, Xtst.shape)\n",
        "\n",
        "    for i in range(3):\n",
        "        # Three noise-relative signal levels: half, equal to, and twice psi_norm.\n",
        "        if i == 0:\n",
        "            theta[k,i] = 0.5*psi_norm\n",
        "        elif i ==1:\n",
        "            theta[k,i] = psi_norm\n",
        "        else:\n",
        "            theta[k,i] = 2*psi_norm\n",
        "\n",
        "        for j in tqdm(range(T)):\n",
        "            a, w = do_sim(theta[k,i], theta[k,i], Xtrn, Xtst, bias = False)\n",
        "            avg_error[k,i,j] = a\n",
        "            w_norm[k,i,j] = w\n",
        "\n",
        "    \n",
        "    # Checkpoint the full result tensors after every training-set size.\n",
        "    number = '%s.pt'%Ntrn\n",
        "    torch.save(avg_error,\"rank50nl-3l-error-without-bias-\"+number)\n",
        "    torch.save(theta,\"rank50nl-3l-theta-without-bias-\"+number)\n",
        "    torch.save(w_norm,\"rank50nl-3l-wnorm-without-bias-\"+number)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fE0FDHYx8nOt"
      },
      "source": [
        "### For MNIST"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1YeLQ8aH8mR2"
      },
      "source": [
        "mnist_train = torchvision.datasets.MNIST('./MNIST_data', train=True, download=True,\n",
        "                           transform=Tranforms.ToTensor())\n",
        "mnist_test = torchvision.datasets.MNIST('./MNIST_data', train=False, download=True,\n",
        "                           transform=Tranforms.ToTensor())\n",
        "\n",
        "Ndata = [30,40,50,100,200,300,500,750,1000,1200,1400,1600,1800,2000,2500,3000,4000,5000,7500,10000,15000,20000,30000,40000,50000]\n",
        "\n",
        "# Everything that does not depend on Ntrn is prepared once, before the sweep.\n",
        "# M is set here so that C = M/Ntrn does not pick up a value left by an earlier cell.\n",
        "M = 784\n",
        "\n",
        "X = mnist_train.data\n",
        "y = mnist_train.targets\n",
        "X = X.to(device).to(torch.float32)/256\n",
        "\n",
        "print(X.max())\n",
        "\n",
        "Xtst = mnist_test.data\n",
        "Xtst = Xtst.to(device).to(torch.float32)/256\n",
        "\n",
        "Ntst = Xtst.shape[0]\n",
        "psi_norm = 0.1\n",
        "psi = psi_norm*np.sqrt(Ntst)\n",
        "\n",
        "# Flatten the test images and rescale them to Frobenius norm sqrt(Ntst).\n",
        "Xtst = Xtst.reshape(Ntst,784)\n",
        "Xtst = Xtst*(np.sqrt(Ntst)/Xtst.norm())\n",
        "\n",
        "T = 50\n",
        "N = 40\n",
        "\n",
        "for Ntrn in Ndata:\n",
        "    print(Ntrn)\n",
        "\n",
        "    C = M/Ntrn\n",
        "\n",
        "    # Rescale out of place: X[:Ntrn] is a view, and X must stay intact for the\n",
        "    # next training-set size.\n",
        "    Xtrn = X[:Ntrn,:,:].reshape(Ntrn,784)\n",
        "    Xtrn = Xtrn*(np.sqrt(Ntrn)/Xtrn.norm())\n",
        "    \n",
        "    print(Xtrn.shape, Xtst.shape)\n",
        "\n",
        "    avg_error = torch.zeros((2,T), device = device)\n",
        "    theta = torch.zeros(2, device = device)\n",
        "    w_norm = torch.zeros((2,T), device = device)\n",
        "\n",
        "    for i in range(2):\n",
        "        # Two signal levels: half and twice psi_norm.\n",
        "        if i == 0:\n",
        "            theta[i] = 0.5*psi_norm\n",
        "        else:\n",
        "            theta[i] = 2*psi_norm\n",
        "\n",
        "        for j in range(T):\n",
        "            a, w = do_sim(theta[i], theta[i], Xtrn, Xtst, bias = False)\n",
        "            avg_error[i,j] = a\n",
        "            w_norm[i,j] = w\n",
        "\n",
        "    \n",
        "    number = '%s.pt'%Ntrn\n",
        "    torch.save(avg_error,\"moretheta-3l-error-without-bias-finer-\"+number)\n",
        "    torch.save(theta,\"moretheta-3l-theta-without-bias-finer-\"+number)\n",
        "    torch.save(w_norm,\"moretheta-3l-wnorm-without-bias-finer-\"+number)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HJlws4Ph8-6N"
      },
      "source": [
        "## Linear Network"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FZGXqvf29BCB"
      },
      "source": [
        "import torch\n",
        "import torchvision\n",
        "import torchvision.transforms as Tranforms\n",
        "\n",
        "# Training and test splits of MNIST, downloaded into ./MNIST_data when missing.\n",
        "mnist_train = torchvision.datasets.MNIST(root='./MNIST_data', train=True,\n",
        "                                         transform=Tranforms.ToTensor(), download=True)\n",
        "mnist_test = torchvision.datasets.MNIST(root='./MNIST_data', train=False,\n",
        "                                        transform=Tranforms.ToTensor(), download=True)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5QfwOgd-9HRt"
      },
      "source": [
        "# Training images as float rows of length 784, divided by 256.\n",
        "X = mnist_train.data.to(torch.float32).reshape(-1,784)/256"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xXOkwqd19LK6"
      },
      "source": [
        "# Test images as float columns of length 784 (divided by 256), stored on the GPU.\n",
        "Xtst = (mnist_test.data.to(torch.float32).reshape(-1,784)/256).t().to('cuda')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "f1wa-udE9Lk1"
      },
      "source": [
        "# Weights of a fixed random two-layer ReLU map.\n",
        "W0 = torch.randn(784,784)\n",
        "W1 = torch.randn(784,784)\n",
        "\n",
        "# Rank-r factors: a shared left basis and separate train/test right bases.\n",
        "r = 5\n",
        "U = torch.svd(torch.randn(784,784)).U[:,:r]\n",
        "Vtrn = torch.svd(torch.randn(50000,r)).U[:,:r]\n",
        "Vtst = torch.svd(torch.randn(10000,r)).U[:,:r]\n",
        "\n",
        "# Training samples are stored as rows on the CPU; test samples as columns on the GPU.\n",
        "X = torch.relu(W1.mm(torch.relu(W0.mm(U.mm(Vtrn.t()))))).t()\n",
        "Xtst = torch.relu(W1.mm(torch.relu(W0.mm(U.mm(Vtst.t()))))).to('cuda')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5gJCptLF9Nc3"
      },
      "source": [
        "# Xtst holds one test sample per column here, so read the count from the data.\n",
        "Ntst = Xtst.shape[1]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4S8VX9MW9PmH"
      },
      "source": [
        "from tqdm import tqdm\n",
        "\n",
        "m = 784\n",
        "\n",
        "Ns = torch.tensor([100,200,300,400,500,600,700,725,750,775,800,850,900,1000,1500,2000,2500,3000,5000,7500,10000])\n",
        "theta = [0.5,1,2]\n",
        "T = 100\n",
        "\n",
        "# One row per entry of theta, one column per training-set size.\n",
        "emperical_norm  = torch.zeros(len(theta),Ns.shape[0]).to('cuda')\n",
        "emperical_error = torch.zeros(len(theta),Ns.shape[0]).to('cuda')\n",
        "\n",
        "print(emperical_error)\n",
        "\n",
        "for i in range(Ns.shape[0]):\n",
        "  # Plain Python int so slicing and np.sqrt operate on an ordinary scalar.\n",
        "  Ntrn = int(Ns[i])\n",
        "  c = m/Ntrn\n",
        "\n",
        "  Xtrn = X[:Ntrn,:].t().to('cuda')\n",
        "  \n",
        "  for j in range(len(theta)):\n",
        "\n",
        "    # Dividing by the current norm makes the new Frobenius norm theta[j]*sqrt(count),\n",
        "    # whatever scale the previous j left behind.\n",
        "    Xtrn *= theta[j]*np.sqrt(Ntrn)/Xtrn.norm()\n",
        "    Xtst *= theta[j]*np.sqrt(Ntst)/Xtst.norm()\n",
        "\n",
        "    print(Ntrn)\n",
        "    print(Xtrn.shape)\n",
        "    for k in tqdm(range(T)):\n",
        "      A = torch.randn_like(Xtrn)/np.sqrt(m)\n",
        "      Atst = torch.randn_like(Xtst)/np.sqrt(m)\n",
        "\n",
        "      Y = Xtrn + A\n",
        "      Ytst = Xtst + Atst\n",
        "\n",
        "      # Minimum-norm least-squares denoiser mapping noisy Y back to Xtrn.\n",
        "      W = torch.mm(Xtrn,Y.pinverse())\n",
        "\n",
        "      emperical_norm[j,i] += (W.norm().square())/T\n",
        "      emperical_error[j,i] += (Xtst - W.mm(Ytst)).norm().square()/(T*Ntst)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5J8busDG9XdH"
      },
      "source": [
        "# More In-Depth Figure for MNIST and CIFAR10"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xDoWSQkE9lzw"
      },
      "source": [
        "The example below uses CIFAR10. To produce the MNIST version, replace the CIFAR10 dataset with MNIST."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "M6M4D9Vv9tcj"
      },
      "source": [
        "# Training and test splits of CIFAR10, downloaded into ./CIFAR_data when missing.\n",
        "cifar_train = torchvision.datasets.CIFAR10(root='./CIFAR_data', train=True,\n",
        "                                           transform=Tranforms.ToTensor(), download=True)\n",
        "cifar_test = torchvision.datasets.CIFAR10(root='./CIFAR_data', train=False,\n",
        "                                          transform=Tranforms.ToTensor(), download=True)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Vk3QoMzs9bMu"
      },
      "source": [
        "import os\n",
        "\n",
        "Ndata = [10,20,30,40,50,100,200,300,500,750,1000,1250,1500,1750,2000,2500,3000,4000,5000,7500,10000,12500,15000]\n",
        "\n",
        "# torch.save does not create missing directories, so make the output folder first.\n",
        "os.makedirs(\"Cifar100-2layer-relu\", exist_ok=True)\n",
        "\n",
        "# Everything that does not depend on Ntrn is prepared once, before the sweep.\n",
        "X = torch.tensor(cifar_train.data)\n",
        "y = torch.tensor(cifar_train.targets)\n",
        "X = X.to(device).to(torch.float32)\n",
        "\n",
        "print(X.max())\n",
        "\n",
        "Xtst = torch.tensor(cifar_test.data)\n",
        "Xtst = Xtst.to(device).to(torch.float32)\n",
        "\n",
        "Ntst = Xtst.shape[0]\n",
        "psi_norm = 0.1\n",
        "psi = psi_norm*np.sqrt(Ntst)\n",
        "\n",
        "# Flatten the test images and rescale them to Frobenius norm sqrt(Ntst).\n",
        "Xtst = Xtst.reshape(Ntst,-1)\n",
        "Xtst = Xtst*(np.sqrt(Ntst)/Xtst.norm())\n",
        "\n",
        "T = 5\n",
        "N = 20\n",
        "\n",
        "for Ntrn in Ndata:\n",
        "    # Rescale out of place: X[:Ntrn] is a view, and X must stay intact for the\n",
        "    # next training-set size.\n",
        "    Xtrn = X[:Ntrn,:,:].reshape(Ntrn,-1)\n",
        "    Xtrn = Xtrn*(np.sqrt(Ntrn)/Xtrn.norm())\n",
        "    \n",
        "    print(Xtrn.shape, Xtst.shape)\n",
        "\n",
        "    avg_error = torch.zeros((N,T), device = device)\n",
        "    theta = torch.zeros(N, device = device)\n",
        "    w_norm = torch.zeros((N,T), device = device)\n",
        "\n",
        "    for i in tqdm(range(N)):\n",
        "        # Training levels psi_norm/N, 2*psi_norm/N, ..., psi_norm; the test level stays at psi_norm.\n",
        "        theta[i] = psi_norm*((i+1)/N)\n",
        "  \n",
        "        for j in range(T):\n",
        "            a, w = do_sim(theta[i], psi_norm, Xtrn, Xtst, bias = False)\n",
        "            avg_error[i,j] = a\n",
        "            w_norm[i, j] = w\n",
        "    \n",
        "    number = '%s.pt'%Ntrn\n",
        "    torch.save(avg_error,\"Cifar100-2layer-relu/error-without-bias-\"+number)\n",
        "    torch.save(theta,\"Cifar100-2layer-relu/theta-without-bias-\"+number)\n",
        "    torch.save(w_norm,\"Cifar100-2layer-relu/wnorm-without-bias-\"+number)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}