{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "20eb9fb8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum value of rho\n",
      "2.771707283575233e-07\n",
      "Maximum value of kappa\n",
      "1.222033585435156e-11\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import random\n",
    "import torch.optim.lr_scheduler as lr_scheduler\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "##### Experiment: train a bias-free 2-layer LeakyReLU network by mini-batch gradient\n",
    "##### ascent on the NCF sum_i y_i * f(x_i), rescaling the weights to unit joint norm\n",
    "##### after every step, and report how close the hidden weights are to rank one.\n",
    "##### NOTE: earlier versions of this cell sampled mini-batch indices from [1, 99] only and\n",
    "##### stored results in float32 on a fresh kernel; outputs saved from them will differ.\n",
    "\n",
    "##### Configuration\n",
    "NUM_RUNS = 30          ### Runs with positive NCF at initialization to collect\n",
    "MAX_SEEDS = 100        ### Seeds tried before giving up\n",
    "dx = 10                ### Input data dimension\n",
    "dy = 1                 ### Scalar valued output\n",
    "N = 100                ### Number of training datapoints\n",
    "N_hid1 = 20            ### Number of hidden neurons\n",
    "learning_rate = 0.01\n",
    "num_epochs = 150000\n",
    "batch_size = 16        ### Indices drawn (with replacement) per step; duplicates count once\n",
    "ncf_min = 1e-6         ### Seeds whose NCF at initialization is below this are skipped\n",
    "\n",
    "torch.set_printoptions(precision=10)\n",
    "### Set the default dtype BEFORE allocating any tensor; otherwise the result buffers\n",
    "### below are float32 on a fresh kernel but float64 when the cell is re-run.\n",
    "torch.set_default_dtype(torch.float64)\n",
    "\n",
    "sing_st = torch.zeros(NUM_RUNS, 3, max(dx, N_hid1))   ##### stores singular values\n",
    "eng_st1 = torch.zeros(NUM_RUNS, 1)                     ##### stores rho for each run\n",
    "ct = 0                                                 ##### number of runs collected so far\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "    \"\"\"Bias-free 2-layer network: x -> fc2(LeakyReLU(0.1)(fc1(x))).\"\"\"\n",
    "\n",
    "    def __init__(self, H1, d_in=dx, d_out=dy):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(d_in, H1, bias=False)\n",
    "        self.fc2 = nn.Linear(H1, d_out, bias=False)\n",
    "        self.activation = nn.LeakyReLU(0.1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.fc2(self.activation(self.fc1(x)))\n",
    "\n",
    "\n",
    "def normalize_weights_(net):\n",
    "    \"\"\"Rescale net.fc1 and net.fc2 weights in place so their joint Frobenius norm is 1.\"\"\"\n",
    "    with torch.no_grad():\n",
    "        nm_tot = torch.sqrt(torch.linalg.matrix_norm(net.fc1.weight)**2\n",
    "                            + torch.linalg.matrix_norm(net.fc2.weight)**2)\n",
    "        net.fc1.weight.div_(nm_tot)\n",
    "        net.fc2.weight.div_(nm_tot)\n",
    "\n",
    "\n",
    "for kk in range(MAX_SEEDS):\n",
    "    random.seed(kk)\n",
    "    torch.manual_seed(kk)\n",
    "    np.random.seed(kk)\n",
    "\n",
    "    X = F.normalize(torch.randn(dx, N), p=2, dim=0)   ### Columns are inputs on the unit sphere\n",
    "    y = torch.randn(N, dy)                            ### Randomly generated output\n",
    "    X_rows = X.T                                      ### One datapoint per row, as the model expects\n",
    "\n",
    "    ### Net() draws from the torch RNG, so it must stay before init_u1/init_u2\n",
    "    ### to keep the per-seed initialization unchanged.\n",
    "    model = Net(N_hid1)\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)\n",
    "\n",
    "    init_u1 = torch.randn(dx, N_hid1)\n",
    "    init_u2 = torch.randn(N_hid1, dy)\n",
    "    nm_tot = torch.sqrt(torch.linalg.matrix_norm(init_u1)**2 + torch.linalg.matrix_norm(init_u2)**2)\n",
    "    with torch.no_grad():             ### Unit-norm initialization of the weights\n",
    "        model.fc1.weight.data = (init_u1 / nm_tot).T\n",
    "        model.fc2.weight.data = (init_u2 / nm_tot).T\n",
    "\n",
    "    ##### Skip seeds whose NCF at initialization is not positive.\n",
    "    with torch.no_grad():\n",
    "        ncf_init = torch.sum(model(X_rows) * y)\n",
    "    if ncf_init.item() < ncf_min:\n",
    "        continue\n",
    "\n",
    "    ##### Mini-batch indices for every step. randint's upper bound is exclusive, so\n",
    "    ##### (0, N) covers all datapoints (the old (1, 100) never sampled datapoint 0).\n",
    "    ind_all = np.random.randint(0, N, size=(num_epochs, batch_size), dtype=np.int64)\n",
    "\n",
    "    ######### Training: SGD on the negated mini-batch NCF, then rescale back to unit norm.\n",
    "    ######### Only the unique sampled rows are evaluated, matching the old 0/1 diagonal mask.\n",
    "    for epoch in range(num_epochs):\n",
    "        batch = torch.unique(torch.from_numpy(ind_all[epoch]))\n",
    "        loss = -torch.sum(model(X_rows[batch]) * y[batch])\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        normalize_weights_(model)\n",
    "\n",
    "    U1 = model.fc1.weight.detach()    ### Hidden weights, shape (N_hid1, dx)\n",
    "    P1, S1, _ = torch.linalg.svd(U1)\n",
    "    sing_st[ct, 0, :S1.numel()] = S1  ##### storing singular values of the hidden weights\n",
    "\n",
    "    ##### rho: relative magnitude of the entries of the leading left singular vector\n",
    "    ##### whose sign is opposite to that of its largest-magnitude entry.\n",
    "    v = P1[:, 0]\n",
    "    q = torch.sign(v[torch.argmax(torch.abs(v))])\n",
    "    eng_st1[ct, 0] = torch.norm(torch.relu(-q * v)) / torch.norm(v)\n",
    "\n",
    "    ct = ct + 1\n",
    "    if ct >= NUM_RUNS:\n",
    "        break\n",
    "\n",
    "if ct == 0:\n",
    "    raise RuntimeError(\"No seed gave a positive NCF at initialization\")\n",
    "if ct < NUM_RUNS:\n",
    "    print(f\"Warning: only {ct} of {NUM_RUNS} runs were collected\")\n",
    "\n",
    "##### Only the first ct rows are filled; unused zero rows would give 0/0 in kappa.\n",
    "eng_st1 = eng_st1.numpy()\n",
    "print(\"Maximum value of rho\")\n",
    "print(np.max(eng_st1[:ct]))    ### maximum relative magnitude of negative entries\n",
    "\n",
    "sing_st = sing_st.numpy()\n",
    "sv = sing_st[:ct, 0, :]\n",
    "fro = np.sqrt(np.sum(sv**2, 1))    ### ||U1||_F for each run\n",
    "##### kappa = 1 - sigma_1 / ||U1||_F  (= 1 - 1/sqrt(stable rank)), evaluated as\n",
    "##### sum_{i>1} sigma_i^2 / (||U1||_F * (||U1||_F + sigma_1)) to avoid cancellation.\n",
    "k1 = np.sum(sv[:, 1:]**2, 1) / (fro * (fro + sv[:, 0]))\n",
    "print(\"Maximum value of kappa\")\n",
    "print(np.max(k1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed0a337c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
