{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3ed44999",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Restrict CUDA to device 0; set here, before torch is imported in the next cell.\n",
    "os.environ.update({\"CUDA_VISIBLE_DEVICES\": \"0\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e57ab430",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from datautils import *\n",
    "from database import *\n",
    "from modelutils import *\n",
    "from quant import *\n",
    "import time\n",
    "from timm.optim import Lamb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e8fe2fac",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
     ]
    }
   ],
   "source": [
    "# `dataloader` (nsamples=1024) feeds the XX statistics collection below;\n",
    "# `testloader` is used for evaluation.\n",
    "dataloader, testloader = get_loaders(\n",
    "    \"imagenet\",\n",
    "    path=\"\",\n",
    "    batchsize=-1,\n",
    "    workers=8,\n",
    "    nsamples=1024,\n",
    "    seed=0,\n",
    "    noaug=False,\n",
    ")\n",
    "\n",
    "get_model, test, run = get_functions(\"rn50\")\n",
    "# Two independent copies: `modelp` gets the stitched sparse weights, `model_orig` stays dense.\n",
    "modelp, model_orig = get_model(), get_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fd1c695e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the layer -> sparsity-level assignment from the config file (one \"<level> <layer_name>\"\n",
    "# pair per line) and hand it to `db.stitch`, which (opaque here) is expected to overwrite the\n",
    "# layers in `layersp` with the stored weights for each level.\n",
    "db = SparsityDatabase(\"unstr\", \"rn50\", prefix='', dev='cpu')\n",
    "modelp = modelp.to('cpu')\n",
    "layersp = find_layers(modelp)\n",
    "\n",
    "config = {}\n",
    "with open(\"rn50_unstr_400x_dp.txt\", 'r') as f:\n",
    "    for line in f:\n",
    "        fields = line.split()  # tolerant of repeated spaces / tabs\n",
    "        if not fields:\n",
    "            continue  # skip blank or trailing empty lines instead of crashing on unpack\n",
    "        level, name = fields\n",
    "        config[name] = level\n",
    "db.stitch(layersp, config)\n",
    "\n",
    "# Move to the compute device and re-collect the layers of the moved model.\n",
    "modelp = modelp.to(DEV)\n",
    "layersp = find_layers(modelp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "99802eb5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "conv1.weight 0.28241921768707484\n",
      "bn1.weight 1.0\n",
      "layer1.0.conv1.weight 0.59033203125\n",
      "layer1.0.bn1.weight 1.0\n",
      "layer1.0.conv2.weight 0.1500922309027778\n",
      "layer1.0.bn2.weight 1.0\n",
      "layer1.0.conv3.weight 0.13507080078125\n",
      "layer1.0.bn3.weight 1.0\n",
      "layer1.0.downsample.0.weight 0.38739013671875\n",
      "layer1.0.downsample.1.weight 1.0\n",
      "layer1.1.conv1.weight 0.228759765625\n",
      "layer1.1.bn1.weight 1.0\n",
      "layer1.1.conv2.weight 0.1350640190972222\n",
      "layer1.1.bn2.weight 1.0\n",
      "layer1.1.conv3.weight 0.28240966796875\n",
      "layer1.1.bn3.weight 1.0\n",
      "layer1.2.conv1.weight 0.12152099609375\n",
      "layer1.2.bn1.weight 1.0\n",
      "layer1.2.conv2.weight 0.1852756076388889\n",
      "layer1.2.bn2.weight 1.0\n",
      "layer1.2.conv3.weight 0.28240966796875\n",
      "layer1.2.bn3.weight 1.0\n",
      "layer2.0.conv1.weight 0.166748046875\n",
      "layer2.0.bn1.weight 1.0\n",
      "layer2.0.conv2.weight 0.1500922309027778\n",
      "layer2.0.bn2.weight 1.0\n",
      "layer2.0.conv3.weight 0.254180908203125\n",
      "layer2.0.bn3.weight 1.0\n",
      "layer2.0.downsample.0.weight 0.15009307861328125\n",
      "layer2.0.downsample.1.weight 1.0\n",
      "layer2.1.conv1.weight 0.15008544921875\n",
      "layer2.1.bn1.weight 1.0\n",
      "layer2.1.conv2.weight 0.07976616753472222\n",
      "layer2.1.bn2.weight 1.0\n",
      "layer2.1.conv3.weight 0.0984649658203125\n",
      "layer2.1.bn3.weight 1.0\n",
      "layer2.2.conv1.weight 0.1852874755859375\n",
      "layer2.2.bn1.weight 1.0\n",
      "layer2.2.conv2.weight 0.0984768337673611\n",
      "layer2.2.bn2.weight 1.0\n",
      "layer2.2.conv3.weight 0.348663330078125\n",
      "layer2.2.bn3.weight 1.0\n",
      "layer2.3.conv1.weight 0.07177734375\n",
      "layer2.3.bn1.weight 1.0\n",
      "layer2.3.conv2.weight 0.13508436414930555\n",
      "layer2.3.bn2.weight 1.0\n",
      "layer2.3.conv3.weight 0.3874053955078125\n",
      "layer2.3.bn3.weight 1.0\n",
      "layer3.0.conv1.weight 0.2824249267578125\n",
      "layer3.0.bn1.weight 1.0\n",
      "layer3.0.conv2.weight 0.2541859944661458\n",
      "layer3.0.bn2.weight 1.0\n",
      "layer3.0.conv3.weight 0.3486747741699219\n",
      "layer3.0.bn3.weight 1.0\n",
      "layer3.0.downsample.0.weight 0.12157630920410156\n",
      "layer3.0.downsample.1.weight 1.0\n",
      "layer3.1.conv1.weight 0.12157440185546875\n",
      "layer3.1.bn1.weight 1.0\n",
      "layer3.1.conv2.weight 0.1500939263237847\n",
      "layer3.1.bn2.weight 1.0\n",
      "layer3.1.conv3.weight 0.3486747741699219\n",
      "layer3.1.bn3.weight 1.0\n",
      "layer3.2.conv1.weight 0.3138084411621094\n",
      "layer3.2.bn1.weight 1.0\n",
      "layer3.2.conv2.weight 0.0984768337673611\n",
      "layer3.2.bn2.weight 1.0\n",
      "layer3.2.conv3.weight 0.3138084411621094\n",
      "layer3.2.bn3.weight 1.0\n",
      "layer3.3.conv1.weight 0.3138084411621094\n",
      "layer3.3.bn1.weight 1.0\n",
      "layer3.3.conv2.weight 0.1500939263237847\n",
      "layer3.3.bn2.weight 1.0\n",
      "layer3.3.conv3.weight 0.47829437255859375\n",
      "layer3.3.bn3.weight 1.0\n",
      "layer3.4.conv1.weight 0.07178878784179688\n",
      "layer3.4.bn1.weight 1.0\n",
      "layer3.4.conv2.weight 0.18530103895399305\n",
      "layer3.4.bn2.weight 1.0\n",
      "layer3.4.conv3.weight 0.4304656982421875\n",
      "layer3.4.bn3.weight 1.0\n",
      "layer3.5.conv1.weight 0.3874168395996094\n",
      "layer3.5.bn1.weight 1.0\n",
      "layer3.5.conv2.weight 0.1500939263237847\n",
      "layer3.5.bn2.weight 1.0\n",
      "layer3.5.conv3.weight 0.4304656982421875\n",
      "layer3.5.bn3.weight 1.0\n",
      "layer4.0.conv1.weight 0.4304656982421875\n",
      "layer4.0.bn1.weight 1.0\n",
      "layer4.0.conv2.weight 0.34867816501193577\n",
      "layer4.0.bn2.weight 1.0\n",
      "layer4.0.conv3.weight 0.47829627990722656\n",
      "layer4.0.bn3.weight 1.0\n",
      "layer4.0.downsample.0.weight 0.3486781120300293\n",
      "layer4.0.downsample.1.weight 1.0\n",
      "layer4.1.conv1.weight 0.3874197006225586\n",
      "layer4.1.bn1.weight 1.0\n",
      "layer4.1.conv2.weight 0.4304669698079427\n",
      "layer4.1.bn2.weight 1.0\n",
      "layer4.1.conv3.weight 0.5314407348632812\n",
      "layer4.1.bn3.weight 1.0\n",
      "layer4.2.conv1.weight 0.5314407348632812\n",
      "layer4.2.bn1.weight 1.0\n",
      "layer4.2.conv2.weight 0.38742023044162327\n",
      "layer4.2.bn2.weight 1.0\n",
      "layer4.2.conv3.weight 0.5314407348632812\n",
      "layer4.2.bn3.weight 1.0\n",
      "fc.weight 1.0\n"
     ]
    }
   ],
   "source": [
    "# Density (fraction of non-zero entries) of every parameter whose name contains \"weight\",\n",
    "# plus the running total of non-zeros across those tensors.\n",
    "total_nz = 0\n",
    "\n",
    "for name, param in modelp.named_parameters():\n",
    "    if \"weight\" in name:\n",
    "        nonzeros = (param != 0).sum().item()\n",
    "        print(name, nonzeros / param.numel())\n",
    "        total_nz += nonzeros"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1afe1ed7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating ...\n",
      "72.63\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the stitched sparse model; `test` prints its own metric (presumably top-1\n",
    "# accuracy in percent -- confirm against the helper that defines `test`).\n",
    "test(modelp, testloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bb66c0dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0\n",
      "0 1\n",
      "0 2\n",
      "0 3\n",
      "0 4\n",
      "0 5\n",
      "0 6\n",
      "0 7\n",
      "1 0\n",
      "1 1\n",
      "1 2\n",
      "1 3\n",
      "1 4\n",
      "1 5\n",
      "1 6\n",
      "1 7\n",
      "2 0\n",
      "2 1\n",
      "2 2\n",
      "2 3\n",
      "2 4\n",
      "2 5\n",
      "2 6\n",
      "2 7\n",
      "3 0\n",
      "3 1\n",
      "3 2\n",
      "3 3\n",
      "3 4\n",
      "3 5\n",
      "3 6\n",
      "3 7\n",
      "4 0\n",
      "4 1\n",
      "4 2\n",
      "4 3\n",
      "4 4\n",
      "4 5\n",
      "4 6\n",
      "4 7\n",
      "5 0\n",
      "5 1\n",
      "5 2\n",
      "5 3\n",
      "5 4\n",
      "5 5\n",
      "5 6\n",
      "5 7\n",
      "6 0\n",
      "6 1\n",
      "6 2\n",
      "6 3\n",
      "6 4\n",
      "6 5\n",
      "6 6\n",
      "6 7\n",
      "7 0\n",
      "7 1\n",
      "7 2\n",
      "7 3\n",
      "7 4\n",
      "7 5\n",
      "7 6\n",
      "7 7\n",
      "8 0\n",
      "8 1\n",
      "8 2\n",
      "8 3\n",
      "8 4\n",
      "8 5\n",
      "8 6\n",
      "8 7\n",
      "9 0\n",
      "9 1\n",
      "9 2\n",
      "9 3\n",
      "9 4\n",
      "9 5\n",
      "9 6\n",
      "9 7\n"
     ]
    }
   ],
   "source": [
    "# For every Conv2d of the dense model, accumulate XX = sum of X @ X.T over its unfolded input\n",
    "# while running `dataloader`. XX is the weighting of the layer-wise reconstruction error\n",
    "# tr((W - W') XX (W - W')^T) used further down.\n",
    "NUM_CALIB_PASSES = 10  # passes over `dataloader`\n",
    "handles = []\n",
    "\n",
    "def add_batch(layer, inp, out):\n",
    "    \"\"\"Forward hook: add X @ X.T of the layer input to `layer.XX`.\n",
    "\n",
    "    For a Conv2d, nn.Unfold turns every receptive-field patch into a column, so X has shape\n",
    "    (in_channels * kh * kw, batch * n_patches), matching `weight.flatten(1)`.\n",
    "    The latest (input, output) pair is also stored in `layer.batches`.\n",
    "    \"\"\"\n",
    "    layer.batches = [(inp[0].detach(), out.detach())]\n",
    "    X = inp[0].detach().float()\n",
    "    if isinstance(layer, nn.Conv2d):\n",
    "        # Unfold's row layout only matches weight.flatten(1) for ungrouped convolutions.\n",
    "        assert layer.groups == 1, \"grouped Conv2d is not supported\"\n",
    "        unfold = nn.Unfold(\n",
    "            layer.kernel_size,\n",
    "            dilation=layer.dilation,\n",
    "            padding=layer.padding,\n",
    "            stride=layer.stride\n",
    "        )\n",
    "        X = unfold(X)              # (batch, C*kh*kw, n_patches)\n",
    "        X = X.permute([1, 0, 2])   # (C*kh*kw, batch, n_patches)\n",
    "        X = X.flatten(1)           # (C*kh*kw, batch * n_patches)\n",
    "    layer.XX += X.matmul(X.T)\n",
    "\n",
    "for n, m in model_orig.named_modules():\n",
    "    if type(m) == nn.Conv2d:\n",
    "        Wf = m.weight.flatten(1)\n",
    "        # (Re)initialised here, so re-running the cell starts from zero.\n",
    "        m.XX = torch.zeros(Wf.shape[1], Wf.shape[1], device=m.weight.device)\n",
    "        handles.append(m.register_forward_hook(add_batch))\n",
    "\n",
    "try:\n",
    "    for i in range(NUM_CALIB_PASSES):\n",
    "        for j, batch in enumerate(dataloader):\n",
    "            print(i, j)\n",
    "            with torch.no_grad():\n",
    "                run(model_orig, batch)\n",
    "finally:\n",
    "    # Always detach the hooks: a leftover hook would keep adding to XX on every later forward\n",
    "    # pass of model_orig, and would be duplicated if this cell is re-run.\n",
    "    for h in handles:\n",
    "        h.remove()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "17c99a91",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_other2(A, W, nnz, Z, U, print_sc=None, debug=False, reg=0, rho_start=0.03, iters=5, prune_iters=2):\n",
    "    \"\"\"Solve A @ B ~ W for a sparse B (about `nnz` non-zeros) with a few ADMM steps.\n",
    "\n",
    "    Minimises ||An @ B - W||_F^2 (plus an optional ridge term) in the basis An = A / col_norms(A).\n",
    "    The support is chosen by magnitude and re-selected during the first `prune_iters` iterations,\n",
    "    then frozen. `Z`/`U` warm-start the sparse iterate and the scaled dual variable.\n",
    "\n",
    "    Args:\n",
    "        A: left factor; its columns are normalised internally.\n",
    "        W: target matrix.\n",
    "        nnz: non-zero budget for the result (target sparsity is clamped to [0, 0.99]).\n",
    "        Z, U: warm starts for the sparse iterate and the dual, in the scale of B.\n",
    "        print_sc: callback on reconstructions, only used when `debug` is set.\n",
    "        debug: print per-iteration diagnostics.\n",
    "        reg: ridge strength, relative to the mean diagonal of the normalised Gram matrix.\n",
    "        rho_start: ADMM penalty of the initial warm-start solve.\n",
    "        iters: number of ADMM iterations.\n",
    "        prune_iters: iterations in which the support is re-selected; if < 1 it is still\n",
    "            selected once so that a mask always exists.\n",
    "\n",
    "    Returns:\n",
    "        (Z, U) rescaled back from the normalised basis.\n",
    "    \"\"\"\n",
    "    XX = A.T.matmul(A)\n",
    "    norm2 = torch.diag(XX).sqrt() + 1e-8  # column norms of A\n",
    "    An = A / norm2\n",
    "    XX = An.T.matmul(An)\n",
    "    XX += torch.diag(torch.ones_like(XX.diag())) * XX.diag().mean() * reg\n",
    "\n",
    "    rho = 1\n",
    "    eye = torch.eye(XX.shape[1], device=XX.device)\n",
    "    XY = An.T.matmul(W)\n",
    "    # Both system matrices are loop-invariant: invert them once.\n",
    "    XXinv = torch.inverse(XX + eye * rho)\n",
    "    XXinv2 = torch.inverse(XX + eye * rho_start)\n",
    "    # Bring the warm starts into the normalised basis.\n",
    "    U = U * norm2.unsqueeze(1)\n",
    "    Z = Z * norm2.unsqueeze(1)\n",
    "\n",
    "    B = XXinv2.matmul(XY + rho_start * (Z - U))\n",
    "\n",
    "    # Fraction of entries to drop; the lower clamp keeps the sort index non-negative.\n",
    "    bsparsity = min(0.99, max(0.0, 1 - nnz / B.numel()))\n",
    "    cur_sparsity = bsparsity\n",
    "    mask = None\n",
    "\n",
    "    for itt in range(iters):\n",
    "        if itt < prune_iters or mask is None:\n",
    "            # Keep entries of B + U whose magnitude exceeds the cur_sparsity quantile.\n",
    "            thres = (B + U).abs().flatten().sort()[0][int(B.numel() * cur_sparsity)]\n",
    "            mask = (B + U).abs() > thres\n",
    "            del thres\n",
    "\n",
    "        Z = (B + U) * mask   # projection onto the support\n",
    "        U = U + (B - Z)      # scaled dual update\n",
    "        B = XXinv.matmul(XY + rho * (Z - U))\n",
    "        if debug:\n",
    "            print(itt, cur_sparsity, (Z != 0).sum().item() / Z.numel())\n",
    "            print_sc(A.matmul(B / norm2.unsqueeze(1)))\n",
    "            print_sc(A.matmul(Z / norm2.unsqueeze(1)))\n",
    "            print(((An != 0).sum() + (Z != 0).sum()) / W.numel())\n",
    "            print(\"-------\")\n",
    "    if debug:\n",
    "        print(\"opt end\")\n",
    "\n",
    "    return Z / norm2.unsqueeze(1), U / norm2.unsqueeze(1)\n",
    "    \n",
    "def mag_prune(W, sp=0.6):\n",
    "    \"\"\"Return a copy of W with (about) the `sp` fraction of smallest-magnitude entries zeroed.\n",
    "\n",
    "    Entries whose magnitude is strictly greater than the sorted magnitude at index\n",
    "    k = int(numel * sp) are kept. k is clamped to a valid index, so sp >= 1 zeroes everything\n",
    "    instead of raising IndexError.\n",
    "    \"\"\"\n",
    "    if W.numel() == 0:\n",
    "        return W.clone()\n",
    "    mags = W.abs()\n",
    "    k = min(max(int(W.numel() * sp), 0), W.numel() - 1)\n",
    "    thres = mags.flatten().sort()[0][k]\n",
    "    return W * (mags > thres)\n",
    "\n",
    "def ent(p):\n",
    "    \"\"\"Binary entropy H(p) in bits, using the limits H(0) = H(1) = 0.\n",
    "\n",
    "    Accepts a scalar or array; returns a float for scalar input, an ndarray otherwise.\n",
    "    \"\"\"\n",
    "    import numpy as np\n",
    "\n",
    "    p = np.asarray(p, dtype=float)\n",
    "    with np.errstate(divide=\"ignore\", invalid=\"ignore\"):\n",
    "        h = -(p * np.log2(p) + (1 - p) * np.log2(1 - p))\n",
    "    # 0 * log2(0) evaluates to nan; replace it by its limit 0.\n",
    "    h = np.where((p == 0) | (p == 1), 0.0, h)\n",
    "    return float(h) if h.ndim == 0 else h\n",
    "\n",
    "def factorizeT(W, XX, asp=0.16, sp=0.4, iters=40):\n",
    "    \"\"\"Alternating sparse factorisation W ~ A @ B with a square A of side W.shape[0].\n",
    "\n",
    "    `factorizef` calls this with the transposed weight when that weight has at least as many\n",
    "    rows as columns; the factors are returned transposed so that W_approx.T == B.T @ A.T\n",
    "    in the caller's orientation.\n",
    "\n",
    "    Args:\n",
    "        W: (n, m) matrix to factorise.\n",
    "        XX: (n, n) input Gram matrix; only its dtype/device are used while the data-aware\n",
    "            row scaling is disabled (norm == 1).\n",
    "        asp: non-zero budget of A as a fraction of n * n.\n",
    "        sp: total non-zero budget of A and B as a fraction of W.numel().\n",
    "        iters: number of alternating A/B updates.\n",
    "\n",
    "    Returns:\n",
    "        (W_approx.T, B.T, A.T) with W_approx = A @ B.\n",
    "    \"\"\"\n",
    "    nza = int(W.shape[0] ** 2 * asp)\n",
    "    nzb = int(W.numel() * sp - nza)\n",
    "\n",
    "    Az = torch.eye(W.shape[0], device=W.device)\n",
    "    Au = torch.zeros_like(Az)\n",
    "    # Row scaling disabled: all ones (data-aware alternative: XX.diag().sqrt().unsqueeze(1) + 1e-8).\n",
    "    norm = torch.ones_like(XX.diag()).unsqueeze(1)\n",
    "    Wn = W * norm\n",
    "\n",
    "    # Warm-start B by magnitude with half of its non-zero budget.\n",
    "    Bz = mag_prune(Wn, (1 - nzb / 2 / W.numel()))\n",
    "    Bu = torch.zeros_like(Bz)\n",
    "\n",
    "    # rho_start ramps cubically from 0 to 1. Clamping the ramp length to >= 1 avoids a\n",
    "    # ZeroDivisionError for iters == 3 and negative penalties for iters < 3.\n",
    "    ramp_len = max(iters - 3, 1)\n",
    "    for itt in range(iters):\n",
    "        rho_start = min(1.0, itt / ramp_len) ** 3\n",
    "        Az, Au = (x.T for x in find_other2(Bz.T, Wn.T, nza, Az.T, Au.T, reg=1e-2, debug=False, rho_start=rho_start))\n",
    "        Bz, Bu = find_other2(Az, Wn, nzb, Bz, Bu, reg=1e-2, debug=False, rho_start=rho_start)\n",
    "\n",
    "    A = Az / norm\n",
    "    return A.matmul(Bz).T, Bz.T, A.T\n",
    "\n",
    "\n",
    "def factorizef(W, XX, asp=0.16, sp=0.4, iters=200, l_prev=None):\n",
    "    \"\"\"Alternating sparse factorisation W ~ L @ R where one factor is square of side min(W.shape).\n",
    "\n",
    "    If W has at least as many rows as columns, the problem is delegated to `factorizeT` on W.T.\n",
    "\n",
    "    Args:\n",
    "        W: (r, c) weight matrix.\n",
    "        XX: input Gram matrix; only its dtype/device are used while the data-aware column\n",
    "            scaling is disabled (norm == 1).\n",
    "        asp: non-zero budget of the square factor as a fraction of its numel.\n",
    "        sp: total non-zero budget of both factors as a fraction of W.numel().\n",
    "        iters: number of alternating updates.\n",
    "        l_prev: unused; kept for call compatibility.\n",
    "\n",
    "    Returns:\n",
    "        (W_approx, L, R) with W_approx == L @ R, shape (r, c).\n",
    "    \"\"\"\n",
    "    if W.shape[0] >= W.shape[1]:\n",
    "        return factorizeT(W.T, XX, sp=sp, asp=asp, iters=iters)\n",
    "\n",
    "    nza = int(W.shape[0] ** 2 * asp)\n",
    "    nzb = int(W.numel() * sp - nza)\n",
    "    # Column scaling disabled: all ones (data-aware alternative: XX.diag().sqrt() + 1e-8).\n",
    "    norm = torch.ones_like(XX.diag())\n",
    "    Wn = W * norm\n",
    "\n",
    "    Az = torch.eye(W.shape[0], device=W.device)\n",
    "    Au = torch.zeros_like(Az)\n",
    "    # Warm-start B by magnitude with half of its non-zero budget.\n",
    "    Bz = mag_prune(Wn, (1 - nzb / 2 / W.numel()))\n",
    "    Bu = torch.zeros_like(Bz)\n",
    "\n",
    "    # rho_start ramps cubically from 0 to 1. Clamping the ramp length to >= 1 avoids a\n",
    "    # ZeroDivisionError for iters == 3 and negative penalties for iters < 3.\n",
    "    ramp_len = max(iters - 3, 1)\n",
    "    for itt in range(iters):\n",
    "        rho_start = min(1.0, itt / ramp_len) ** 3\n",
    "        Az, Au = (x.T for x in find_other2(Bz.T, Wn.T, nza, Az.T, Au.T, reg=1e-2, debug=False, rho_start=rho_start))\n",
    "        Bz, Bu = find_other2(Az, Wn, nzb, Bz, Bu, reg=1e-2, debug=False, rho_start=rho_start)\n",
    "\n",
    "    B = Bz / norm\n",
    "    return Az.matmul(B), Az, B\n",
    "\n",
    "def finalize(XXb, W, Ab, Bb, iters=200, rho=1):\n",
    "    \"\"\"Refit the left factor on its fixed sparsity pattern, keeping the right factor Bb fixed.\n",
    "\n",
    "    Runs `iters` ADMM steps on min_A tr((A Bb - W) XXb (A Bb - W)^T), with A restricted to the\n",
    "    non-zero pattern of `Ab`, in a basis where Bb XXb Bb^T has unit diagonal.\n",
    "\n",
    "    Args:\n",
    "        XXb: (in, in) Gram matrix of the layer input.\n",
    "        W: target weight (out, in).\n",
    "        Ab: current left factor (out, k); its non-zeros define the support and it is the warm start.\n",
    "        Bb: fixed right factor (k, in).\n",
    "        iters: number of ADMM iterations (default 200, as before).\n",
    "        rho: ADMM penalty.\n",
    "\n",
    "    Returns:\n",
    "        Refitted left factor (out, k), zero outside the support of `Ab`.\n",
    "    \"\"\"\n",
    "    mask = (Ab != 0).T\n",
    "\n",
    "    XX = Bb.matmul(XXb).matmul(Bb.T)\n",
    "    XY = Bb.matmul(XXb).matmul(W.detach().float().T)\n",
    "\n",
    "    # Normalise to unit diagonal; the iterate Ax = A^T is scaled accordingly.\n",
    "    norm2 = torch.diag(XX).sqrt() + 1e-8\n",
    "    XX = XX / norm2 / norm2.unsqueeze(1)\n",
    "    XY = XY / norm2.unsqueeze(1)\n",
    "    Ax = (Ab * norm2).T.clone()\n",
    "\n",
    "    XXinv = torch.inverse(XX + torch.eye(XX.shape[1], device=XX.device) * rho)\n",
    "    U = torch.zeros_like(Ax)\n",
    "    Z = Ax * mask  # defined even when iters == 0\n",
    "    for _ in range(iters):\n",
    "        Z = (Ax + U) * mask\n",
    "        U = U + (Ax - Z)\n",
    "        Ax = XXinv.matmul(XY + rho * (Z - U))\n",
    "\n",
    "    return Z.T / norm2\n",
    "\n",
    "def find_a(B, Za, Ua, rho, D, Q, E, R, XX, W):\n",
    "    \"\"\"Closed-form ADMM step: solve XX @ A @ (B B^T) + rho * A = F for A.\n",
    "\n",
    "    F = rho * (Za - Ua) + XX @ W @ B^T. Assumes (D, Q) and (E, R) are eigendecompositions of XX\n",
    "    and B B^T (as `get_at` passes them), so that A = Q [(Q^T F R) / (D E^T + rho)] R^T.\n",
    "    \"\"\"\n",
    "    rhs = rho * (Za - Ua) + XX.matmul(W).matmul(B.T)\n",
    "    rotated = Q.T.matmul(rhs).matmul(R)\n",
    "    denom = D[:, None] * E[None, :] + rho\n",
    "    return Q.matmul(rotated / denom).matmul(R.T)\n",
    "\n",
    "\n",
    "def get_at(XX, W, A, B, iters=20, rho=1):\n",
    "    \"\"\"Refit A on its fixed sparsity pattern so that A @ B ~ W in the XX-weighted norm.\n",
    "\n",
    "    ADMM on min_A tr((A B - W)^T XX (A B - W)) with A restricted to the support of the input A.\n",
    "    Each unconstrained step is solved in closed form by `find_a`, using eigendecompositions of\n",
    "    the unit-diagonal XX and of B B^T with unit-norm rows of B.\n",
    "\n",
    "    Args:\n",
    "        XX: (n, n) Gram matrix.\n",
    "        W: (n, m) target.\n",
    "        A: (n, k) warm start; its non-zeros define the support.\n",
    "        B: (k, m) fixed right factor.\n",
    "        iters: number of ADMM iterations (default 20, as before).\n",
    "        rho: ADMM penalty.\n",
    "\n",
    "    Returns:\n",
    "        Refitted A (n, k), zero outside the support of the input A.\n",
    "    \"\"\"\n",
    "    mask = A != 0\n",
    "\n",
    "    # Normalise XX to unit diagonal and B to unit-norm rows; A and W are rescaled to match.\n",
    "    norm2 = torch.diag(XX).sqrt() + 1e-8\n",
    "    XXn = XX / norm2 / norm2.unsqueeze(1)\n",
    "    Wn = W * norm2.unsqueeze(1)\n",
    "\n",
    "    normB = torch.norm(B, dim=1) + 1e-8\n",
    "    Bn = B / normB.unsqueeze(1)\n",
    "    BBn = Bn.matmul(Bn.T)\n",
    "\n",
    "    # Loop-invariant eigendecompositions used by find_a's closed-form solve.\n",
    "    D, Q = torch.linalg.eigh(XXn)\n",
    "    E, R = torch.linalg.eigh(BBn)\n",
    "\n",
    "    Za = A * norm2.unsqueeze(1) * normB\n",
    "    Ua = torch.zeros_like(Za)\n",
    "    for _ in range(iters):\n",
    "        A2 = find_a(Bn, Za, Ua, rho, D, Q, E, R, XXn, Wn)\n",
    "        Za = (A2 + Ua) * mask   # projection onto the fixed support\n",
    "        Ua = Ua + (A2 - Za)     # scaled dual update\n",
    "    return Za / norm2.unsqueeze(1) / normB\n",
    "\n",
    "def factorize(XX, W, sp, l_prev=None):\n",
    "    \"\"\"Approximate W by a product Ac @ Bc of two sparse factors with total density ~ sp.\n",
    "\n",
    "    Stages: alternating ADMM factorisation (`factorizef`), refit of the left factor\n",
    "    (`finalize`), then refit of the right factor (`get_at`). After each stage the XX-weighted\n",
    "    error tr((W' - W) XX (W' - W)^T) is printed.\n",
    "\n",
    "    Args:\n",
    "        XX: (in, in) Gram matrix of the layer input.\n",
    "        W: weight matrix (out, in); it is detached and cast to float.\n",
    "        sp: target overall density of both factors relative to W.numel().\n",
    "        l_prev: forwarded to `factorizef`.\n",
    "\n",
    "    Returns:\n",
    "        (W4, (Ac, Bc)): the reconstructed weight and both factors (moved to CPU).\n",
    "    \"\"\"\n",
    "    W = W.detach().float()\n",
    "\n",
    "    def err(Wx):\n",
    "        # XX-weighted reconstruction error tr((Wx - W) XX (Wx - W)^T).\n",
    "        diff = Wx - W\n",
    "        return diff.matmul(XX).matmul(diff.T).diag().sum().item()\n",
    "\n",
    "    # Budget of the square factor: at least 16%, or half of the overall budget.\n",
    "    asp = max(0.16, sp / 2)\n",
    "    W2, Ab, Bb = factorizef(W, XX, sp=sp, asp=asp, l_prev=l_prev)\n",
    "    print(\"err_prefin\", err(W2))\n",
    "\n",
    "    Ac = finalize(XX, W, Ab, Bb)\n",
    "    W3 = Ac.matmul(Bb)\n",
    "    assert W3.shape == W.shape\n",
    "    print(\"err_fin   \", err(W3))\n",
    "\n",
    "    Bc = get_at(XX, W.T, Bb.T, Ac.T).T\n",
    "    W4 = Ac.matmul(Bc)\n",
    "    assert W4.shape == W.shape  # previously re-checked W3 by mistake\n",
    "    print(\"err_fin2   \", err(W4))\n",
    "\n",
    "    print(\"sparsity check\", ((Ac != 0).sum() + (Bc != 0).sum()).item() / W.numel())\n",
    "    return W4, (Ac.cpu(), Bc.cpu())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3cd37466",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "err_prefin 216261.34375\n",
      "err_fin    18613.828125\n",
      "err_fin2    15016.3046875\n",
      "sparsity check 0.58984375\n",
      "layer1.0.conv1 0.59033203125 torch.Size([64, 64, 1, 1]) 27640.7734375 15016.3046875 \n",
      "err_prefin 5138185.5\n",
      "err_fin    833410.3125\n",
      "err_fin2    566937.25\n",
      "sparsity check 0.15003797743055555\n",
      "layer1.0.conv2 0.1500922309027778 torch.Size([64, 64, 3, 3]) 326139.6875 566937.25 bad\n",
      "err_prefin 4077767.0\n",
      "err_fin    1004332.4375\n",
      "err_fin2    930536.375\n",
      "sparsity check 0.13494873046875\n",
      "layer1.0.conv3 0.13507080078125 torch.Size([256, 64, 1, 1]) 683674.4375 930536.375 bad\n",
      "err_prefin 3806934.75\n",
      "err_fin    424953.0\n",
      "err_fin2    418320.1875\n",
      "sparsity check 0.38726806640625\n",
      "layer1.0.downsample.0 0.38739013671875 torch.Size([256, 64, 1, 1]) 556299.25 418320.1875 \n",
      "err_prefin 1717065.25\n",
      "err_fin    506788.0625\n",
      "err_fin2    268020.375\n",
      "sparsity check 0.2286376953125\n",
      "layer1.1.conv1 0.228759765625 torch.Size([64, 256, 1, 1]) 230354.375 268020.375 bad\n",
      "err_prefin 6621598.0\n",
      "err_fin    2514085.0\n",
      "err_fin2    1669178.5\n",
      "sparsity check 0.135009765625\n",
      "layer1.1.conv2 0.1350640190972222 torch.Size([64, 64, 3, 3]) 913964.25 1669178.5 bad\n",
      "err_prefin 479083.3125\n",
      "err_fin    209593.15625\n",
      "err_fin2    205432.953125\n",
      "sparsity check 0.28228759765625\n",
      "layer1.1.conv3 0.28240966796875 torch.Size([256, 64, 1, 1]) 255576.75 205432.953125 \n",
      "err_prefin 6091756.0\n",
      "err_fin    2304855.0\n",
      "err_fin2    1807763.75\n",
      "sparsity check 0.12139892578125\n",
      "layer1.2.conv1 0.12152099609375 torch.Size([64, 256, 1, 1]) 1535052.375 1807763.75 bad\n",
      "err_prefin 4440164.0\n",
      "err_fin    2749006.5\n",
      "err_fin2    1665399.375\n",
      "sparsity check 0.18522135416666666\n",
      "layer1.2.conv2 0.1852756076388889 torch.Size([64, 64, 3, 3]) 1387377.5 1665399.375 bad\n",
      "err_prefin 254454.125\n",
      "err_fin    180026.65625\n",
      "err_fin2    178256.375\n",
      "sparsity check 0.28228759765625\n",
      "layer1.2.conv3 0.28240966796875 torch.Size([256, 64, 1, 1]) 259815.3125 178256.375 \n",
      "err_prefin 10543388.0\n",
      "err_fin    3050370.25\n",
      "err_fin2    2579513.75\n",
      "sparsity check 0.16668701171875\n",
      "layer2.0.conv1 0.166748046875 torch.Size([128, 256, 1, 1]) 3653419.0 2579513.75 \n",
      "err_prefin 3423744.0\n",
      "err_fin    1708272.0\n",
      "err_fin2    1207850.0\n",
      "sparsity check 0.1500786675347222\n",
      "layer2.0.conv2 0.1500922309027778 torch.Size([128, 128, 3, 3]) 1109724.0 1207850.0 bad\n",
      "err_prefin 548572.25\n",
      "err_fin    279588.8125\n",
      "err_fin2    274631.84375\n",
      "sparsity check 0.254150390625\n",
      "layer2.0.conv3 0.254180908203125 torch.Size([512, 128, 1, 1]) 385049.59375 274631.84375 \n",
      "err_prefin 2608328.0\n",
      "err_fin    803321.125\n",
      "err_fin2    713538.25\n",
      "sparsity check 0.15007781982421875\n",
      "layer2.0.downsample.0 0.15009307861328125 torch.Size([512, 256, 1, 1]) 948079.5 713538.25 \n",
      "err_prefin 514511.25\n",
      "err_fin    190300.59375\n",
      "err_fin2    141036.484375\n",
      "sparsity check 0.150054931640625\n",
      "layer2.1.conv1 0.15008544921875 torch.Size([128, 512, 1, 1]) 173098.96875 141036.484375 \n",
      "err_prefin 3022651.5\n",
      "err_fin    582711.1875\n",
      "err_fin2    471264.28125\n",
      "sparsity check 0.07975260416666667\n",
      "layer2.1.conv2 0.07976616753472222 torch.Size([128, 128, 3, 3]) 256422.453125 471264.28125 bad\n",
      "err_prefin 610936.25\n",
      "err_fin    427276.5625\n",
      "err_fin2    409610.5625\n",
      "sparsity check 0.0984344482421875\n",
      "layer2.1.conv3 0.0984649658203125 torch.Size([512, 128, 1, 1]) 319872.875 409610.5625 bad\n",
      "err_prefin 1572108.0\n",
      "err_fin    705519.25\n",
      "err_fin2    500200.59375\n",
      "sparsity check 0.1852569580078125\n",
      "layer2.2.conv1 0.1852874755859375 torch.Size([128, 512, 1, 1]) 502911.25 500200.59375 \n",
      "err_prefin 4220247.0\n",
      "err_fin    1843608.75\n",
      "err_fin2    1436846.375\n",
      "sparsity check 0.09846327039930555\n",
      "layer2.2.conv2 0.0984768337673611 torch.Size([128, 128, 3, 3]) 1221081.75 1436846.375 bad\n",
      "err_prefin 229127.96875\n",
      "err_fin    107335.0\n",
      "err_fin2    105811.3125\n",
      "sparsity check 0.3486328125\n",
      "layer2.2.conv3 0.348663330078125 torch.Size([512, 128, 1, 1]) 197878.90625 105811.3125 \n",
      "err_prefin 7027040.0\n",
      "err_fin    3226210.0\n",
      "err_fin2    2965640.5\n",
      "sparsity check 0.071746826171875\n",
      "layer2.3.conv1 0.07177734375 torch.Size([128, 512, 1, 1]) 2963943.75 2965640.5 bad\n",
      "err_prefin 3001089.5\n",
      "err_fin    1610144.25\n",
      "err_fin2    1236240.625\n",
      "sparsity check 0.13507080078125\n",
      "layer2.3.conv2 0.13508436414930555 torch.Size([128, 128, 3, 3]) 1139468.75 1236240.625 bad\n",
      "err_prefin 131894.625\n",
      "err_fin    77573.71875\n",
      "err_fin2    76872.34375\n",
      "sparsity check 0.3873748779296875\n",
      "layer2.3.conv3 0.3874053955078125 torch.Size([512, 128, 1, 1]) 145862.40625 76872.34375 \n",
      "err_prefin 2103340.75\n",
      "err_fin    1058899.75\n",
      "err_fin2    853329.1875\n",
      "sparsity check 0.28240966796875\n",
      "layer3.0.conv1 0.2824249267578125 torch.Size([256, 512, 1, 1]) 1594296.25 853329.1875 \n",
      "err_prefin 629585.4375\n",
      "err_fin    403177.15625\n",
      "err_fin2    279877.34375\n",
      "sparsity check 0.2541826036241319\n",
      "layer3.0.conv2 0.2541859944661458 torch.Size([256, 256, 3, 3]) 380160.71875 279877.34375 \n",
      "err_prefin 242008.71875\n",
      "err_fin    90252.59375\n",
      "err_fin2    88793.34375\n",
      "sparsity check 0.3486671447753906\n",
      "layer3.0.conv3 0.3486747741699219 torch.Size([1024, 256, 1, 1]) 181254.796875 88793.34375 \n",
      "err_prefin 2288972.75\n",
      "err_fin    1299164.25\n",
      "err_fin2    1171514.5\n",
      "sparsity check 0.12157249450683594\n",
      "layer3.0.downsample.0 0.12157630920410156 torch.Size([1024, 512, 1, 1]) 1340536.25 1171514.5 \n",
      "err_prefin 755230.5\n",
      "err_fin    337813.59375\n",
      "err_fin2    269405.1875\n",
      "sparsity check 0.1215667724609375\n",
      "layer3.1.conv1 0.12157440185546875 torch.Size([256, 1024, 1, 1]) 333123.625 269405.1875 \n",
      "err_prefin 961092.875\n",
      "err_fin    475658.4375\n",
      "err_fin2    345664.375\n",
      "sparsity check 0.15009053548177084\n",
      "layer3.1.conv2 0.1500939263237847 torch.Size([256, 256, 3, 3]) 359667.21875 345664.375 \n",
      "err_prefin 106902.046875\n",
      "err_fin    57244.53125\n",
      "err_fin2    56462.9921875\n",
      "sparsity check 0.3486671447753906\n",
      "layer3.1.conv3 0.3486747741699219 torch.Size([1024, 256, 1, 1]) 107610.9921875 56462.9921875 \n",
      "err_prefin 174411.203125\n",
      "err_fin    92170.890625\n",
      "err_fin2    60751.21875\n",
      "sparsity check 0.3138008117675781\n",
      "layer3.2.conv1 0.3138084411621094 torch.Size([256, 1024, 1, 1]) 92160.5625 60751.21875 \n",
      "err_prefin 1259087.75\n",
      "err_fin    627610.1875\n",
      "err_fin2    500876.8125\n",
      "sparsity check 0.09847344292534722\n",
      "layer3.2.conv2 0.0984768337673611 torch.Size([256, 256, 3, 3]) 496333.25 500876.8125 bad\n",
      "err_prefin 103127.671875\n",
      "err_fin    58268.796875\n",
      "err_fin2    57496.8515625\n",
      "sparsity check 0.3138008117675781\n",
      "layer3.2.conv3 0.3138084411621094 torch.Size([1024, 256, 1, 1]) 111967.390625 57496.8515625 \n",
      "err_prefin 266649.0\n",
      "err_fin    149698.984375\n",
      "err_fin2    106122.46875\n",
      "sparsity check 0.3138008117675781\n",
      "layer3.3.conv1 0.3138084411621094 torch.Size([256, 1024, 1, 1]) 159325.953125 106122.46875 \n",
      "err_prefin 738317.8125\n",
      "err_fin    414787.375\n",
      "err_fin2    320820.5625\n",
      "sparsity check 0.15009053548177084\n",
      "layer3.3.conv2 0.1500939263237847 torch.Size([256, 256, 3, 3]) 340482.0625 320820.5625 \n",
      "err_prefin 24312.3046875\n",
      "err_fin    13796.65234375\n",
      "err_fin2    13658.080078125\n",
      "sparsity check 0.4782867431640625\n",
      "layer3.3.conv3 0.47829437255859375 torch.Size([1024, 256, 1, 1]) 34852.06640625 13658.080078125 \n",
      "err_prefin 2376954.25\n",
      "err_fin    1253416.625\n",
      "err_fin2    1182549.5\n",
      "sparsity check 0.07178115844726562\n",
      "layer3.4.conv1 0.07178878784179688 torch.Size([256, 1024, 1, 1]) 1264393.25 1182549.5 \n",
      "err_prefin 516310.09375\n",
      "err_fin    299541.21875\n",
      "err_fin2    224957.78125\n",
      "sparsity check 0.18529764811197916\n",
      "layer3.4.conv2 0.18530103895399305 torch.Size([256, 256, 3, 3]) 253659.6875 224957.78125 \n",
      "err_prefin 32815.6328125\n",
      "err_fin    19672.140625\n",
      "err_fin2    19472.32421875\n",
      "sparsity check 0.43045806884765625\n",
      "layer3.4.conv3 0.4304656982421875 torch.Size([1024, 256, 1, 1]) 44993.36328125 19472.32421875 \n",
      "err_prefin 190306.328125\n",
      "err_fin    118558.859375\n",
      "err_fin2    92303.6640625\n",
      "sparsity check 0.3874092102050781\n",
      "layer3.5.conv1 0.3874168395996094 torch.Size([256, 1024, 1, 1]) 168779.65625 92303.6640625 \n",
      "err_prefin 740106.375\n",
      "err_fin    437909.34375\n",
      "err_fin2    346309.375\n",
      "sparsity check 0.15009053548177084\n",
      "layer3.5.conv2 0.1500939263237847 torch.Size([256, 256, 3, 3]) 421897.1875 346309.375 \n",
      "err_prefin 48075.57421875\n",
      "err_fin    25988.125\n",
      "err_fin2    25690.767578125\n",
      "sparsity check 0.43045806884765625\n",
      "layer3.5.conv3 0.4304656982421875 torch.Size([1024, 256, 1, 1]) 59980.3125 25690.767578125 \n",
      "err_prefin 191526.46875\n",
      "err_fin    126200.765625\n",
      "err_fin2    110469.3046875\n",
      "sparsity check 0.4304618835449219\n",
      "layer4.0.conv1 0.4304656982421875 torch.Size([512, 1024, 1, 1]) 347434.3125 110469.3046875 \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "err_prefin 82400.53125\n",
      "err_fin    57309.78125\n",
      "err_fin2    41713.375\n",
      "sparsity check 0.3486773173014323\n",
      "layer4.0.conv2 0.34867816501193577 torch.Size([512, 512, 3, 3]) 83596.25 41713.375 \n",
      "err_prefin 28323.662109375\n",
      "err_fin    12818.5263671875\n",
      "err_fin2    12665.55859375\n",
      "sparsity check 0.47829437255859375\n",
      "layer4.0.conv3 0.47829627990722656 torch.Size([2048, 512, 1, 1]) 37507.96875 12665.55859375 \n",
      "err_prefin 49622.78125\n",
      "err_fin    28954.6171875\n",
      "err_fin2    28306.537109375\n",
      "sparsity check 0.3486771583557129\n",
      "layer4.0.downsample.0 0.3486781120300293 torch.Size([2048, 1024, 1, 1]) 88051.0234375 28306.537109375 \n",
      "err_prefin 591586.0\n",
      "err_fin    292809.875\n",
      "err_fin2    183296.140625\n",
      "sparsity check 0.3874177932739258\n",
      "layer4.1.conv1 0.3874197006225586 torch.Size([512, 2048, 1, 1]) 348949.75 183296.140625 \n",
      "err_prefin 48011.2265625\n",
      "err_fin    34919.7109375\n",
      "err_fin2    24890.3515625\n",
      "sparsity check 0.43046612209743923\n",
      "layer4.1.conv2 0.4304669698079427 torch.Size([512, 512, 3, 3]) 50902.359375 24890.3515625 \n",
      "err_prefin 12960.853515625\n",
      "err_fin    7664.9638671875\n",
      "err_fin2    7583.9521484375\n",
      "sparsity check 0.5314388275146484\n",
      "layer4.1.conv3 0.5314407348632812 torch.Size([2048, 512, 1, 1]) 24170.521484375 7583.9521484375 \n",
      "err_prefin 655162.875\n",
      "err_fin    280272.75\n",
      "err_fin2    154152.40625\n",
      "sparsity check 0.5314388275146484\n",
      "layer4.2.conv1 0.5314407348632812 torch.Size([512, 2048, 1, 1]) 362123.3125 154152.40625 \n",
      "err_prefin 43244.6796875\n",
      "err_fin    31156.546875\n",
      "err_fin2    21775.88671875\n",
      "sparsity check 0.3874193827311198\n",
      "layer4.2.conv2 0.38742023044162327 torch.Size([512, 512, 3, 3]) 67118.84375 21775.88671875 \n",
      "err_prefin 9432.5107421875\n",
      "err_fin    5138.1865234375\n",
      "err_fin2    5057.681640625\n",
      "sparsity check 0.5314388275146484\n",
      "layer4.2.conv3 0.5314407348632812 torch.Size([2048, 512, 1, 1]) 20718.384765625 5057.681640625 \n"
     ]
    }
   ],
   "source": [
    "sd_pruned = modelp.state_dict()\n",
    "out_admm = {}\n",
    "\n",
    "for n, m in model_orig.named_modules():\n",
    "    if type(m) == nn.Conv2d and m.weight.shape[1] > 3:\n",
    "        w_pruned = sd_pruned[n+\".weight\"].flatten(1)\n",
    "        sparsity = (w_pruned != 0).sum().item() / w_pruned.numel()\n",
    "        w_orig = m.weight.flatten(1)\n",
    "        w_admm, facts = factorize(m.XX, w_orig, sparsity)\n",
    "        e1 = (w_orig - w_pruned).matmul(m.XX).matmul((w_orig - w_pruned).T).diag().sum().item()\n",
    "        e2 = (w_orig - w_admm).matmul(m.XX).matmul((w_orig - w_admm).T).diag().sum().item()\n",
    "        print(n, sparsity, m.weight.shape, \n",
    "              e1,\n",
    "              e2,\n",
    "              \"bad\" if e1 < e2 else \"\"\n",
    "             )\n",
    "        out_admm[n] = (w_admm.reshape(w_pruned.shape), facts)\n",
    "        #m.XX = None\n",
    "        \n",
    "for n, m in modelp.named_modules():\n",
    "    if n in out_admm:\n",
    "        m.weight.data = out_admm[n][0].reshape(m.weight.shape)\n",
    "        m.weight.facts = out_admm[n][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "99b05797",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating ...\n",
      "74.44\n"
     ]
    }
   ],
   "source": [
    "# Test-set evaluation of the model with factorized weights, before BatchNorm re-calibration\n",
    "# (the printed number is presumably top-1 accuracy in %).\n",
    "test(modelp, testloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8b943db1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batchnorm tuning ...\n",
      "0.8818617463111877\n",
      "000\n",
      "001\n",
      "002\n",
      "003\n",
      "004\n",
      "005\n",
      "006\n",
      "007\n",
      "008\n",
      "009\n",
      "010\n",
      "011\n",
      "012\n",
      "013\n",
      "014\n",
      "015\n",
      "016\n",
      "017\n",
      "018\n",
      "019\n",
      "020\n",
      "021\n",
      "022\n",
      "023\n",
      "024\n",
      "025\n",
      "026\n",
      "027\n",
      "028\n",
      "029\n",
      "030\n",
      "031\n",
      "032\n",
      "033\n",
      "034\n",
      "035\n",
      "036\n",
      "037\n",
      "038\n",
      "039\n",
      "040\n",
      "041\n",
      "042\n",
      "043\n",
      "044\n",
      "045\n",
      "046\n",
      "047\n",
      "048\n",
      "049\n",
      "050\n",
      "051\n",
      "052\n",
      "053\n",
      "054\n",
      "055\n",
      "056\n",
      "057\n",
      "058\n",
      "059\n",
      "060\n",
      "061\n",
      "062\n",
      "063\n",
      "064\n",
      "065\n",
      "066\n",
      "067\n",
      "068\n",
      "069\n",
      "070\n",
      "071\n",
      "072\n",
      "073\n",
      "074\n",
      "075\n",
      "076\n",
      "077\n",
      "078\n",
      "079\n",
      "080\n",
      "081\n",
      "082\n",
      "083\n",
      "084\n",
      "085\n",
      "086\n",
      "087\n",
      "088\n",
      "089\n",
      "090\n",
      "091\n",
      "092\n",
      "093\n",
      "094\n",
      "095\n",
      "096\n",
      "097\n",
      "098\n",
      "099\n",
      "0.8327280580997467\n"
     ]
    }
   ],
   "source": [
    "print('Batchnorm tuning ...')\n",
    "\n",
    "loss = 0\n",
    "with torch.no_grad():\n",
    "    for batch in dataloader:\n",
    "        loss += run(modelp, batch, loss=True)\n",
    "print(loss / 1024)\n",
    "\n",
    "batchnorms = find_layers(modelp, [nn.BatchNorm2d])\n",
    "for bn in batchnorms.values():\n",
    "    bn.reset_running_stats()\n",
    "    bn.momentum = .1\n",
    "modelp.train()\n",
    "with torch.no_grad():\n",
    "    i = 0\n",
    "    while i < 100:\n",
    "        for batch in dataloader:\n",
    "            if i == 100:\n",
    "                break\n",
    "            print('%03d' % i)\n",
    "            run(modelp, batch)\n",
    "            i += 1\n",
    "modelp.eval()\n",
    "\n",
    "loss = 0\n",
    "with torch.no_grad():\n",
    "    for batch in dataloader:\n",
    "        loss += run(modelp, batch, loss=True)\n",
    "print(loss / 1024)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "daaca80f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating ...\n",
      "74.95\n"
     ]
    }
   ],
   "source": [
    "# Re-evaluate after the BatchNorm re-calibration in the previous cell.\n",
    "test(modelp, testloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "0b5404a2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "conv1.weight 0.28241921768707484\n",
      "layer1.0.conv1.weight 0.58984375 ff\n",
      "layer1.0.conv2.weight 0.15003797743055555 ff\n",
      "layer1.0.conv3.weight 0.13494873046875 ff\n",
      "layer1.0.downsample.0.weight 0.38726806640625 ff\n",
      "layer1.0.downsample.1.weight 1.0\n",
      "layer1.1.conv1.weight 0.2286376953125 ff\n",
      "layer1.1.conv2.weight 0.135009765625 ff\n",
      "layer1.1.conv3.weight 0.28228759765625 ff\n",
      "layer1.2.conv1.weight 0.12139892578125 ff\n",
      "layer1.2.conv2.weight 0.18522135416666666 ff\n",
      "layer1.2.conv3.weight 0.28228759765625 ff\n",
      "layer2.0.conv1.weight 0.16668701171875 ff\n",
      "layer2.0.conv2.weight 0.1500786675347222 ff\n",
      "layer2.0.conv3.weight 0.254150390625 ff\n",
      "layer2.0.downsample.0.weight 0.15007781982421875 ff\n",
      "layer2.0.downsample.1.weight 1.0\n",
      "layer2.1.conv1.weight 0.150054931640625 ff\n",
      "layer2.1.conv2.weight 0.07975260416666667 ff\n",
      "layer2.1.conv3.weight 0.0984344482421875 ff\n",
      "layer2.2.conv1.weight 0.1852569580078125 ff\n",
      "layer2.2.conv2.weight 0.09846327039930555 ff\n",
      "layer2.2.conv3.weight 0.3486328125 ff\n",
      "layer2.3.conv1.weight 0.071746826171875 ff\n",
      "layer2.3.conv2.weight 0.13507080078125 ff\n",
      "layer2.3.conv3.weight 0.3873748779296875 ff\n",
      "layer3.0.conv1.weight 0.28240966796875 ff\n",
      "layer3.0.conv2.weight 0.2541826036241319 ff\n",
      "layer3.0.conv3.weight 0.3486671447753906 ff\n",
      "layer3.0.downsample.0.weight 0.12157249450683594 ff\n",
      "layer3.0.downsample.1.weight 1.0\n",
      "layer3.1.conv1.weight 0.1215667724609375 ff\n",
      "layer3.1.conv2.weight 0.15009053548177084 ff\n",
      "layer3.1.conv3.weight 0.3486671447753906 ff\n",
      "layer3.2.conv1.weight 0.3138008117675781 ff\n",
      "layer3.2.conv2.weight 0.09847344292534722 ff\n",
      "layer3.2.conv3.weight 0.3138008117675781 ff\n",
      "layer3.3.conv1.weight 0.3138008117675781 ff\n",
      "layer3.3.conv2.weight 0.15009053548177084 ff\n",
      "layer3.3.conv3.weight 0.4782867431640625 ff\n",
      "layer3.4.conv1.weight 0.07178115844726562 ff\n",
      "layer3.4.conv2.weight 0.18529764811197916 ff\n",
      "layer3.4.conv3.weight 0.43045806884765625 ff\n",
      "layer3.5.conv1.weight 0.3874092102050781 ff\n",
      "layer3.5.conv2.weight 0.15009053548177084 ff\n",
      "layer3.5.conv3.weight 0.43045806884765625 ff\n",
      "layer4.0.conv1.weight 0.4304618835449219 ff\n",
      "layer4.0.conv2.weight 0.3486773173014323 ff\n",
      "layer4.0.conv3.weight 0.47829437255859375 ff\n",
      "layer4.0.downsample.0.weight 0.3486771583557129 ff\n",
      "layer4.0.downsample.1.weight 1.0\n",
      "layer4.1.conv1.weight 0.3874177932739258 ff\n",
      "layer4.1.conv2.weight 0.43046612209743923 ff\n",
      "layer4.1.conv3.weight 0.5314388275146484 ff\n",
      "layer4.2.conv1.weight 0.5314388275146484 ff\n",
      "layer4.2.conv2.weight 0.3874193827311198 ff\n",
      "layer4.2.conv3.weight 0.5314388275146484 ff\n",
      "fc.weight 1.0\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(10194502, 25506752)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "total_nz = 0\n",
    "total = 0\n",
    "\n",
    "for n, p in modelp.named_parameters():\n",
    "    if \"weight\" not in n or \"bn\" in n:\n",
    "        continue\n",
    "    \n",
    "    if hasattr(p, \"facts\"):\n",
    "        ff = (p.facts[0] != 0).sum().item() + (p.facts[1] != 0).sum().item() #(p != 0).sum().item()\n",
    "        total_nz += ff\n",
    "        print(n, ff / p.numel(), \"ff\")\n",
    "    else:\n",
    "        total_nz += (p != 0).sum().item()\n",
    "        print(n, (p != 0).sum().item() / p.numel())\n",
    "    total += p.numel()\n",
    "    \n",
    "total_nz, total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "66a47ed7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the layer used for the hand-built equivalence check in the cells below.\n",
    "modelp.layer3[4].conv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "314eccbf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([256, 256]), torch.Size([256, 2304]))"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shapes of the two factors stored on that layer's weight by the factorization cell.\n",
    "modelp.layer3[4].conv2.weight.facts[0].shape, modelp.layer3[4].conv2.weight.facts[1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "ed541870",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(device(type='cuda', index=0), device(type='cuda', index=0))"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f1 = nn.Conv2d(256, 256, 1, bias=False)\n",
    "f2 = nn.Conv2d(256, 256, 3, padding=1, bias=False)\n",
    "f1.weight.data = modelp.layer3[4].conv2.weight.facts[0].reshape(f1.weight.shape)\n",
    "f2.weight.data = modelp.layer3[4].conv2.weight.facts[1].reshape(f2.weight.shape)\n",
    "f1 = f1.cuda()\n",
    "f2 = f2.cuda()\n",
    "f1.weight.device, f2.weight.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "617b70bc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0020, device='cuda:0', dtype=torch.float16, grad_fn=<MaxBackward1>)\n"
     ]
    }
   ],
   "source": [
    "xx = torch.randn(10,256,15,15).cuda()\n",
    "with torch.amp.autocast(\"cuda\"):\n",
    "    o1 = modelp.layer3[4].conv2(xx)\n",
    "    ot = f2(xx)\n",
    "    o2 = f1(ot)\n",
    "    print((o1 - o2).abs().max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "70a285d2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "layer1.0\n",
      "layer1.1\n",
      "layer1.2\n",
      "layer2.0\n",
      "layer2.1\n",
      "layer2.2\n",
      "layer2.3\n",
      "layer3.0\n",
      "layer3.1\n",
      "layer3.2\n",
      "layer3.3\n",
      "layer3.4\n",
      "layer3.5\n",
      "layer4.0\n",
      "layer4.1\n",
      "layer4.2\n"
     ]
    }
   ],
   "source": [
    "def boo(m, i, o):\n",
    "    print(\"boo\", i[0].shape)\n",
    "\n",
    "for n, m in modelp.named_modules():\n",
    "    if \"Bottleneck\" in str(type(m)):\n",
    "        print(n)\n",
    "        if hasattr(m, \"conv1b\"):\n",
    "            m.conv1 = m.conv1b\n",
    "        ff = m.conv1.weight.facts\n",
    "        m.conv1b = m.conv1\n",
    "        m.conv1 = nn.Sequential(\n",
    "            nn.Conv2d(m.conv1b.in_channels, m.conv1b.out_channels, 1, bias=False),\n",
    "            nn.Conv2d(m.conv1b.out_channels, m.conv1b.out_channels, 1, bias=False)\n",
    "        )\n",
    "        m.conv1[0].weight.data = ff[1].reshape(m.conv1[0].weight.shape)\n",
    "        m.conv1[1].weight.data = ff[0].reshape(m.conv1[1].weight.shape)\n",
    "        m.conv1.cuda()\n",
    "        \n",
    "        #print(n)\n",
    "        if hasattr(m, \"conv2b\"):\n",
    "            m.conv2 = m.conv2b\n",
    "        ff = m.conv2.weight.facts\n",
    "        m.conv2b = m.conv2\n",
    "        m.conv2 = nn.Sequential(\n",
    "            nn.Conv2d(m.conv2b.in_channels, m.conv2b.out_channels, 3, padding=1, stride=m.conv2b.stride, bias=False),\n",
    "            nn.Conv2d(m.conv2b.out_channels, m.conv2b.out_channels, 1, bias=False)\n",
    "        )\n",
    "        #m.conv2[0].register_forward_hook(boo)\n",
    "        m.conv2[0].weight.data = ff[1].reshape(m.conv2[0].weight.shape)\n",
    "        m.conv2[1].weight.data = ff[0].reshape(m.conv2[1].weight.shape)\n",
    "        m.conv2.cuda()\n",
    "        \n",
    "        if hasattr(m, \"conv3b\"):\n",
    "            m.conv3 = m.conv3b\n",
    "        ff = m.conv3.weight.facts\n",
    "        m.conv3b = m.conv3\n",
    "        m.conv3 = nn.Sequential(\n",
    "            nn.Conv2d(m.conv3b.in_channels, m.conv3b.in_channels, 1, bias=False),\n",
    "            nn.Conv2d(m.conv3b.in_channels, m.conv3b.out_channels, 1, bias=False)\n",
    "        )\n",
    "        m.conv3[0].weight.data = ff[1].reshape(m.conv3[0].weight.shape)\n",
    "        m.conv3[1].weight.data = ff[0].reshape(m.conv3[1].weight.shape)\n",
    "        m.conv3.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "177807d9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating ...\n",
      "74.95\n"
     ]
    }
   ],
   "source": [
    "# Re-evaluate after swapping in the two-conv Bottleneck layers; the score is expected\n",
    "# to be (nearly) unchanged from the evaluation with merged weights.\n",
    "test(modelp, testloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "32c55597",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "  (1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       ")"
      ]
     },
     "execution_count": 66,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect layer2[0].conv1 after the Bottleneck rewrite (an nn.Sequential of two convs).\n",
    "modelp.layer2[0].conv1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "bc956b1e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): the saved output shows a plain Conv2d, but the Bottleneck-rewrite cell\n",
    "# replaces conv2 with an nn.Sequential -- this output comes from a different kernel\n",
    "# state (execution count 67); re-run the notebook in order.\n",
    "modelp.layer2[0].conv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "08523486",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): the saved output shows a plain Conv2d, but the Bottleneck-rewrite cell\n",
    "# replaces conv3 with an nn.Sequential -- this output comes from a different kernel\n",
    "# state (execution count 68); re-run the notebook in order.\n",
    "modelp.layer2[0].conv3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12ffb56e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
