{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "sizes=[16,32,64,128,256,512, 1024, 2048,4096]\n",
    "import argparse\n",
    "import copy\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import nn_utils\n",
    "from nn_utils import train, test\n",
    "from fc import FC2, FC1\n",
    "from id_utils import pruneID, compare_prune, getID\n",
    "from data_utils import fashionmnist_loader\n",
    "from plt_utils import plt_IDerr\n",
    "import torch.nn.utils.prune as prune\n",
    "\n",
    "\n",
    "\n",
    "import scipy\n",
    "import matplotlib.pyplot as plt \n",
    "%matplotlib inline\n",
    "import torch.nn.functional as F\n",
    "class arguments():\n",
    "    def __init__(self):\n",
    "        self.batch_size=64\n",
    "        self.test_batch_size=1000\n",
    "        self.lr=.2\n",
    "        self.epochs=200\n",
    "        self.log_interval=10\n",
    "        self.verbose=False\n",
    "        self.prune_batch_size=10000\n",
    "        self.dataset=\"fashion-mnist\"\n",
    "        self.use_valid=True\n",
    "        self.ft_proportion=0\n",
    "        self.fine_tune=False\n",
    "# One settings object feeds the loader, which returns three data loaders.\n",
    "args = arguments()\n",
    "ptrain_loader, ptest_loader, pruneLoader = fashionmnist_loader(args)\n",
    "\n",
    "# Layer sizes: 28 * 28 input features and 10 outputs.\n",
    "hidden1 = hidden2 = 1024\n",
    "data_feats = 28 * 28\n",
    "out_feats = 10\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "#model_full = FC2([data_feats, hidden1,hidden2, out_feats], [True, True, False])\n",
    "#optimizer = optim.SGD(model_full.parameters(), lr=args.lr, weight_decay=.0001)\n",
    "\n",
    "# Use the GPU when one is available, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "\n",
    "\n",
    "def pruneSimpleBias(model, k):\n",
    "    \"\"\"Prune a deep copy of ``model`` using a column-pivoted QR of each layer's ``getZ`` output.\n",
    "\n",
    "    For every layer but the last, ``Z = getZ(X, layer)`` is factored with column\n",
    "    pivoting. The first ``k[layer + 1]`` pivots pick the rows of the layer's weight\n",
    "    and the entries of its bias that are kept, and the next layer's weight is\n",
    "    right-multiplied by ``T.T`` where ``T = [I, pinv(R11) @ R12]`` (columns put back\n",
    "    in the original order). The data comes from one batch of the global\n",
    "    ``pruneLoader``.\n",
    "\n",
    "    Note: a later cell of this notebook defines a function with the same name that\n",
    "    additionally returns ``R``; that definition replaces this one once it has run.\n",
    "\n",
    "    Args:\n",
    "        model: object exposing ``layers`` (modules with ``weight`` and ``bias``) and\n",
    "            ``getZ(X, layer)``. It is deep-copied and left untouched.\n",
    "        k: indexable of widths; ``k[layer + 1]`` units are kept in ``layers[layer]``.\n",
    "\n",
    "    Returns:\n",
    "        The pruned copy of ``model``.\n",
    "    \"\"\"\n",
    "    X_prune, _ = next(iter(pruneLoader))\n",
    "    newModel = copy.deepcopy(model)\n",
    "    layers = newModel.layers\n",
    "    X = X_prune\n",
    "    for layer in range(len(layers) - 1):\n",
    "        keep = k[layer + 1]\n",
    "        Z = newModel.getZ(torch.Tensor(X), layer).detach().numpy()\n",
    "        # Only R and the pivot order are used below, so mode='r' avoids building\n",
    "        # the square Q factor that the default mode='full' would allocate.\n",
    "        R, P = scipy.linalg.qr(Z, pivoting=True, mode='r')\n",
    "        kept = P[:keep]\n",
    "        Wk = layers[layer].weight.T[:, kept]\n",
    "        bk = layers[layer].bias[kept]\n",
    "        T = np.concatenate(\n",
    "            (np.identity(keep), np.linalg.pinv(R[:keep, :keep]) @ R[:keep, keep:]),\n",
    "            axis=1,\n",
    "        )\n",
    "        # Undo the column pivoting so T lines up with the original unit order.\n",
    "        T = torch.Tensor(T[:, np.argsort(P)])\n",
    "        layers[layer].weight = nn.Parameter(Wk.T, requires_grad=True)\n",
    "        layers[layer].bias = nn.Parameter(bk, requires_grad=True)\n",
    "        layers[layer + 1].weight = nn.Parameter(layers[layer + 1].weight @ T.T, requires_grad=True)\n",
    "        # NOTE(review): the next iteration feeds this layer's Z back into getZ;\n",
    "        # confirm that is what getZ expects for layer > 0.\n",
    "        X = Z\n",
    "        #X=newModel.getZ(torch.Tensor(X),layer).detach().numpy()\n",
    "    return newModel\n",
    "\n",
    "epochs=50\n",
    "\n",
    "\n",
    "import pickle\n",
    "\n",
    "\n",
    "import torch.nn.utils.prune as prune\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Train the full FC1 model (4096 is presumably the hidden width) with plain SGD,\n",
    "# multiplying the learning rate by 0.9 after every epoch.\n",
    "model_full = FC1(data_feats, 4096, out_feats, True, False)\n",
    "lr = 0.3\n",
    "epochs = 30\n",
    "for epoch in range(1, epochs + 1):\n",
    "    # A fresh optimizer per epoch is how the decayed learning rate takes effect.\n",
    "    optimizer = optim.SGD(model_full.parameters(), lr=lr)\n",
    "    train(args, model_full, device, ptrain_loader, criterion, optimizer, epoch)\n",
    "    lr *= 0.9\n",
    "    # Report on the training, pruning and test loaders, in that order.\n",
    "    for loader in (ptrain_loader, pruneLoader, ptest_loader):\n",
    "        print(test(args, model_full, device, loader, criterion, epoch, returnAcc=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the trained full model on the test loader.\n",
    "print(test(args, model_full, device, ptest_loader, criterion, epoch, returnAcc=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First batch from the pruning loader; only the first element of the pair is kept.\n",
    "X_prune, _ = next(iter(pruneLoader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alias for the pruning batch.\n",
    "X = X_prune"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Not referenced by any other visible cell.\n",
    "RT = 0\n",
    "def pruneSimpleBias(model, k):\n",
    "    \"\"\"Prune a deep copy of ``model`` and also return the triangular factor of the QR.\n",
    "\n",
    "    For every layer but the last, ``Z = getZ(X, layer)`` is factored with a\n",
    "    column-pivoted QR (R and pivots only). The first ``k[layer + 1]`` pivots pick\n",
    "    the rows of the layer's weight and the entries of its bias that are kept, and\n",
    "    the next layer's weight is right-multiplied by ``T.T`` where\n",
    "    ``T = [I, pinv(R11) @ R12]`` (columns put back in the original order). The data\n",
    "    comes from one batch of the global ``pruneLoader``.\n",
    "\n",
    "    Args:\n",
    "        model: object exposing ``layers`` (modules with ``weight`` and ``bias``) and\n",
    "            ``getZ(X, layer)``. It is deep-copied and left untouched.\n",
    "        k: indexable of widths; ``k[layer + 1]`` units are kept in ``layers[layer]``.\n",
    "\n",
    "    Returns:\n",
    "        ``(pruned_model, R)`` where ``R`` is the triangular factor computed for the\n",
    "        last layer that was processed, or ``None`` when no layer was processed.\n",
    "    \"\"\"\n",
    "    X_prune, _ = next(iter(pruneLoader))\n",
    "    newModel = copy.deepcopy(model)\n",
    "    layers = newModel.layers\n",
    "    X = X_prune\n",
    "    # Stays None when the loop body never runs, instead of failing with\n",
    "    # UnboundLocalError at the return statement.\n",
    "    R = None\n",
    "    for layer in range(len(layers) - 1):\n",
    "        keep = k[layer + 1]\n",
    "        Z = newModel.getZ(torch.Tensor(X), layer).detach().numpy()\n",
    "        R, P = scipy.linalg.qr(Z, pivoting=True, mode='r')\n",
    "\n",
    "        kept = P[:keep]\n",
    "        Wk = layers[layer].weight.T[:, kept]\n",
    "        bk = layers[layer].bias[kept]\n",
    "        T = np.concatenate(\n",
    "            (np.identity(keep), np.linalg.pinv(R[:keep, :keep]) @ R[:keep, keep:]),\n",
    "            axis=1,\n",
    "        )\n",
    "        # Undo the column pivoting so T lines up with the original unit order.\n",
    "        T = torch.Tensor(T[:, np.argsort(P)])\n",
    "        layers[layer].weight = nn.Parameter(Wk.T, requires_grad=True)\n",
    "        layers[layer].bias = nn.Parameter(bk, requires_grad=True)\n",
    "        layers[layer + 1].weight = nn.Parameter(layers[layer + 1].weight @ T.T, requires_grad=True)\n",
    "\n",
    "        # NOTE(review): the next iteration feeds this layer's Z back into getZ;\n",
    "        # confirm that is what getZ expects for layer > 0.\n",
    "        X = Z\n",
    "        #X=newModel.getZ(torch.Tensor(X),layer).detach().numpy()\n",
    "    return newModel, R\n",
    "# Widths to keep in the pruned layer, from 4 units up to 4096.\n",
    "sizes = [\n",
    "    4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32, 48, 64, 96, 128, 192, 256, 384,\n",
    "    512, 768, 1024, 1536, 2048, 2560, 3072, 3584, 3840, 4096,\n",
    "]\n",
    "pruned, tested, trained = [], [], []\n",
    "for size in sizes:\n",
    "    print(size)\n",
    "    model, R = pruneSimpleBias(model_full, [data_feats, size, out_feats])\n",
    "    # Evaluate the pruned model on the pruning, test and training loaders, in that order.\n",
    "    for results, loader in ((pruned, pruneLoader), (tested, ptest_loader), (trained, ptrain_loader)):\n",
    "        results.append(test(args, model, device, loader, criterion, epoch, returnAcc=True))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spectral norm of the trailing block R[k:, k:] for every width except the last.\n",
    "rnorm = []\n",
    "for size in sizes[:-1]:\n",
    "    trailing_block = R[size:, size:]\n",
    "    rnorm.append(scipy.linalg.norm(trailing_block, ord=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The last width gets a trailing-block norm of 0. The guard keeps the cell\n",
    "# re-runnable: without it a second run appends another entry and the table\n",
    "# below runs past the end of `sizes`.\n",
    "if len(rnorm) < len(sizes):\n",
    "    rnorm.append(0)\n",
    "# One 'k,norm' row per width under a header line.\n",
    "rows = [str(width) + ',' + str(norm) for width, norm in zip(sizes, rnorm)]\n",
    "text = '\\n'.join(['k, ||r22||_2'] + rows)\n",
    "print(text)\n",
    "# The context manager closes the file even if the write fails.\n",
    "with open('rnorm.txt', 'w') as file:\n",
    "    file.write(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Singular values of R only; the singular vectors are not computed.\n",
    "svd = scipy.linalg.svd(R, compute_uv=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Singular values of R on a logarithmic y-axis.\n",
    "plt.semilogy(svd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,3))\n",
    "plt.xlabel(\"k\", fontsize=12)\n",
    "# Only the diagonal of R is rescaled (not the whole matrix), and the 1-based\n",
    "# x positions come from the data, so the cell no longer assumes a width of 4096.\n",
    "r_diag_ratio = np.abs(np.diag(R) / R[0, 0])\n",
    "plt.loglog(np.arange(1, len(svd) + 1), svd / svd[0], label=\"Normalized SVD\")\n",
    "plt.loglog(sizes, np.asarray(rnorm) / svd[0], label=r\"$||R_{22}||/||R||$\")\n",
    "plt.plot(np.arange(1, len(r_diag_ratio) + 1), r_diag_ratio, label=r\"$|r_{k+1,k+1}/r_{11}|$\")\n",
    "plt.legend(fontsize=12)\n",
    "plt.xticks(fontsize=12)\n",
    "plt.yticks(fontsize=12)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"metrics.pdf\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the values 1, 2, ..., 4096 used as x positions in the metrics plot.\n",
    "np.linspace(1, 4096, num=4096)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trailing-block norms against k, overlaid with the singular values of R.\n",
    "fig = plt.figure()\n",
    "ax = fig.gca()\n",
    "ax.loglog(sizes, rnorm)\n",
    "ax.semilogy(svd)\n",
    "\n",
    "#plt.loglog(sizes, pruned)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "# Stack the per-width results into arrays so that a single column can be sliced out.\n",
    "pruned, tested, trained = (np.array(results) for results in (pruned, tested, trained))\n",
    "# Column 0 of every result (presumably the loss) against the kept width, log-log.\n",
    "plt.loglog(sizes, pruned[:, 0], '.')\n",
    "plt.loglog(sizes, tested[:, 0], '.')\n",
    "plt.loglog(sizes, trained[:, 0], '.')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "# Stack the per-width results into arrays so that a single column can be sliced out.\n",
    "pruned, tested, trained = (np.array(results) for results in (pruned, tested, trained))\n",
    "# Column 1 of every result (plotted as accuracy in the final figure) against the kept width.\n",
    "plt.semilogx(sizes, pruned[:, 1], '.')\n",
    "plt.semilogx(sizes, tested[:, 1], '.')\n",
    "plt.semilogx(sizes, trained[:, 1], '.')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(epoch)\n",
    "# Output pasted from an earlier run, kept as comments so the cell no longer tries\n",
    "# to call the return value of print(). The second number of each pair is the\n",
    "# accuracy drawn as a reference line in the accuracy plot; the first is\n",
    "# presumably the loss.\n",
    "#   (0.13046107614529617, 0.95646)   training set\n",
    "#   (0.2667423486709595, 0.9058)     pruning set\n",
    "#   (0.3007843181490898, 0.8954)     test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,3))\n",
    "rdiag = np.diag(R)\n",
    "# |r_{k+1,k+1} / r_11| at every pruned width k (the last width is left out).\n",
    "# Computed once from the diagonal instead of rescaling all of R for each call.\n",
    "rkk = np.abs(rdiag / R[0, 0])[sizes[:-1]]\n",
    "rkk_min, rkk_max = np.min(rkk), np.max(rkk)\n",
    "# np.asarray makes this cell work whether or not an earlier plotting cell has\n",
    "# already converted the result lists to arrays.\n",
    "trained_acc = np.asarray(trained)[:-1, 1] * 100\n",
    "tested_acc = np.asarray(tested)[:-1, 1] * 100\n",
    "pruned_acc = np.asarray(pruned)[:-1, 1] * 100\n",
    "plt.semilogx(rkk, trained_acc, 'b.', label=\"Training set\")\n",
    "plt.semilogx(rkk, tested_acc, 'r.', label=\"Test set\")\n",
    "plt.semilogx(rkk, pruned_acc, 'k.', label=\"Pruning set\")\n",
    "# Reference lines: accuracies copied from an earlier run, colour-matched to the\n",
    "# series above.\n",
    "plt.hlines(0.95646 * 100, rkk_min, rkk_max, colors='b')\n",
    "plt.hlines(0.8954 * 100, rkk_min, rkk_max, colors='r')\n",
    "plt.hlines(0.9058 * 100, rkk_min, rkk_max, colors='k')\n",
    "\n",
    "#plt.title(\"accuracy loss from pruning \")\n",
    "plt.xlabel(r\"$|r_{k+1,k+1}/r_{11}|$\",fontsize=12)\n",
    "plt.ylabel(\"Accuracy (%)\",fontsize=12)\n",
    "plt.xlim(10**-2, 10**0)\n",
    "plt.legend(fontsize=12)\n",
    "plt.xticks(fontsize=12)\n",
    "plt.yticks(fontsize=12)\n",
    "#plt.grid(color = 'black', linestyle = '--', linewidth = 0.3)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"accuracyvsrkk.pdf\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
