{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf396b00",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import random\n",
    "\n",
    "sing_st = torch.zeros(30,3,20)     ##### stores singular values\n",
    "eng_st1 = torch.zeros(30,1)        ##### stores magnitude of positive entries \n",
    "eng_st2 = torch.zeros(30,1)\n",
    "ct = 0\n",
    "torch.set_printoptions(precision=10)\n",
    "torch.set_default_tensor_type(torch.DoubleTensor)\n",
    "##### Problem sizes and hyperparameters (identical for every seed)\n",
    "dx = 10      ### Input data dimension\n",
    "dy = 1       ### Scalar valued output\n",
    "N = 100      ### Number of training datapoints\n",
    "N_hid1 = 20      ### Number of hidden neurons\n",
    "N_hid2 = 20\n",
    "learning_rate = 0.05\n",
    "num_epochs = 200000   ### hard cap on projected-gradient steps per seed\n",
    "\n",
    "\n",
    "class sq_relu(nn.Module):\n",
    "    ### Elementwise square activation x -> x*x.\n",
    "    ### NOTE(review): despite the name, no ReLU is applied (negative inputs are squared too).\n",
    "    ### x*x is 2-homogeneous, which is what gives the factor 7 used in the training loop.\n",
    "    def __init__(self):\n",
    "        super(sq_relu, self).__init__()\n",
    "\n",
    "    def forward(self, x):\n",
    "        return x*x\n",
    "\n",
    "\n",
    "# Define the neural network \n",
    "class Net(nn.Module):\n",
    "    ### Bias-free network dx -> H1 -> H2 -> dy computing fc3(sq(fc2(sq(fc1(x))))).\n",
    "    ### Scaling all weights by c scales the output by c^7 (c -> c^2 -> c^3 -> c^6 -> c^7),\n",
    "    ### i.e. the output is 7-homogeneous in the weights.\n",
    "    ### Defined once here instead of inside the seed loop; dx/dy are read at construction time.\n",
    "    def __init__(self, H1,H2):\n",
    "        super(Net, self).__init__()\n",
    "        self.fc1 = nn.Linear(dx,H1,bias=False)\n",
    "        self.fc2 = nn.Linear(H1,H2,bias=False)\n",
    "        self.fc3 = nn.Linear(H2,dy,bias=False)\n",
    "        self.activation = sq_relu()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc1(x)\n",
    "        x = self.activation(x)\n",
    "        x = self.fc2(x)\n",
    "        x = self.activation(x)\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "def _joint_norm(*mats):\n",
    "    ### Joint Frobenius norm sqrt(sum_i ||M_i||_F^2): the norm of all matrices viewed as one vector.\n",
    "    return torch.sqrt(sum(torch.linalg.matrix_norm(M)**2 for M in mats))\n",
    "\n",
    "\n",
    "for kk in range(100):\n",
    "    ##### Seed every RNG so each run is reproducible. The order of random draws below\n",
    "    ##### (X, y, Net init, init_u*) is unchanged, so runs reproduce the original ones.\n",
    "    random.seed(kk)\n",
    "    torch.manual_seed(kk)\n",
    "    np.random.seed(kk)\n",
    "\n",
    "    X = torch.randn(dx,N)\n",
    "    X = torch.nn.functional.normalize(X,p=2,dim=0)    ### Input data sampled from unit-norm sphere \n",
    "    y = torch.randn(N,dy)       ### Randomly generated output\n",
    "    Xt = X.T                    ### one datapoint per row, the layout nn.Linear expects\n",
    "\n",
    "    # # ############# NCF gradient training\n",
    "    model = Net(N_hid1,N_hid2)   ### Initializing neural network\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)\n",
    "\n",
    "    ##### Unit-norm initialization: Gaussian weights scaled jointly to norm 1\n",
    "    init_u1 = torch.randn(dx,N_hid1)\n",
    "    init_u2 = torch.randn(N_hid1,N_hid2)\n",
    "    init_u3 = torch.randn(N_hid2,dy)\n",
    "    nm_tot = _joint_norm(init_u1, init_u2, init_u3)\n",
    "    with torch.no_grad():\n",
    "        model.fc1.weight.data = (init_u1/nm_tot).T\n",
    "        model.fc2.weight.data = (init_u2/nm_tot).T\n",
    "        model.fc3.weight.data = (init_u3/nm_tot).T\n",
    "\n",
    "    ##### Skip seeds whose NCF, sum_i f(x_i)*y_i, is not positive at initialization\n",
    "    with torch.no_grad():\n",
    "        init_ncf = torch.sum(model(Xt)*y)\n",
    "    if init_ncf.item() < 1e-8:\n",
    "        continue\n",
    "\n",
    "    ######### Training: gradient ascent on the NCF, projected back onto the unit sphere\n",
    "    for epoch in range(num_epochs):\n",
    "        loss = -torch.sum(model(Xt)*y)      ### minimizing -NCF maximizes the NCF\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()                      ### graph is rebuilt every step, so no retain_graph\n",
    "\n",
    "        ###### Alignment of the weights with the NCF gradient. With unit-norm weights and a\n",
    "        ###### 7-homogeneous NCF, Euler's identity gives cos(w, grad) = 7*NCF/||grad||.\n",
    "        nm_gd = _joint_norm(model.fc1.weight.grad, model.fc2.weight.grad, model.fc3.weight.grad)\n",
    "        dc_mt = -(7*loss.item()/(nm_gd))\n",
    "        if dc_mt>1-1e-10:                    ### weights (numerically) parallel to the gradient: stop\n",
    "            break\n",
    "\n",
    "        optimizer.step()\n",
    "        ###### Normalizing weights (joint Frobenius norm back to 1)\n",
    "        with torch.no_grad():\n",
    "            nm_tot = _joint_norm(model.fc1.weight, model.fc2.weight, model.fc3.weight)\n",
    "            for layer in (model.fc1, model.fc2, model.fc3):\n",
    "                layer.weight.div_(nm_tot)\n",
    "\n",
    "    ##### Singular values of the hidden weight matrices, read from the model itself so they\n",
    "    ##### are correct even if the loop stopped before any optimizer step. Previously U1/U2\n",
    "    ##### were only assigned after a step, which left them stale or undefined in that case.\n",
    "    with torch.no_grad():\n",
    "        S1 = torch.linalg.svdvals(model.fc1.weight)\n",
    "        S2 = torch.linalg.svdvals(model.fc2.weight)\n",
    "    sing_st[ct,0,:S1.numel()] = S1\n",
    "    sing_st[ct,1,:S2.numel()] = S2   ##### storing singular values of the hidden weights \n",
    "\n",
    "    ct = ct + 1\n",
    "    if ct >= 30:          ### stop after 30 accepted seeds (rows of sing_st)\n",
    "        break\n",
    "\n",
    "\n",
    "##### kappa = 1 - s_max / ||s||_2 for each recorded weight matrix, i.e. 1 - 1/sqrt(stable rank);\n",
    "##### it is 0 exactly when the matrix has rank one. Only the ct rows actually filled are used:\n",
    "##### unfilled rows are all zero and would give 0/0 = nan, poisoning the maximum.\n",
    "sing_st = np.asarray(sing_st, dtype=np.float64)\n",
    "if ct == 0:\n",
    "    print(\"No seed passed the positive-NCF check; nothing to report\")\n",
    "else:\n",
    "    sv = sing_st[:ct]       ### (run, layer, index); index 0 is the largest (torch returns them descending)\n",
    "    k1 = 1 - sv[:,0,0]/np.linalg.norm(sv[:,0,:], axis=1)\n",
    "    k2 = 1 - sv[:,1,0]/np.linalg.norm(sv[:,1,:], axis=1)\n",
    "    print(\"Maximum value of kappa\")    ### printing maximum of 1 - s_max/||s||_2 over both layers\n",
    "    print(np.max([np.max(k1),np.max(k2)]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
