{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3ed44999",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e57ab430",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from datautils import *\n",
    "from database import *\n",
    "from modelutils import *\n",
    "from quant import *\n",
    "import time\n",
    "from timm.optim import Lamb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e8fe2fac",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
     ]
    }
   ],
   "source": [
    "dataloader, testloader = get_loaders(\n",
    "    \"imagenet\", path=\"\",\n",
    "    batchsize=-1, workers=8,\n",
    "    nsamples=1024, seed=0,\n",
    "    noaug=False\n",
    ")\n",
    "get_model, test, run = get_functions(\"rn50\")\n",
    "modelp = get_model()\n",
    "model_orig = get_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fd1c695e",
   "metadata": {},
   "outputs": [],
   "source": [
    "db = SparsityDatabase(\"unstr\", \"rn50\", prefix='', dev='cpu')\n",
    "modelp = modelp.to('cpu')\n",
    "layersp = find_layers(modelp)\n",
    "with open(\"rn50_unstr_300x_dp.txt\", 'r') as f:\n",
    "    config = {}\n",
    "    for l in f.readlines():\n",
    "        level, name = l.strip().split(' ')\n",
    "        config[name] = level \n",
    "db.stitch(layersp, config)\n",
    "modelp = modelp.to(DEV)\n",
    "layersp = find_layers(modelp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "99802eb5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "conv1.weight 0.38732993197278914\n",
      "bn1.weight 1.0\n",
      "layer1.0.conv1.weight 0.59033203125\n",
      "layer1.0.bn1.weight 1.0\n",
      "layer1.0.conv2.weight 0.166748046875\n",
      "layer1.0.bn2.weight 1.0\n",
      "layer1.0.conv3.weight 0.254150390625\n",
      "layer1.0.bn3.weight 1.0\n",
      "layer1.0.downsample.0.weight 0.430419921875\n",
      "layer1.0.downsample.1.weight 1.0\n",
      "layer1.1.conv1.weight 0.28240966796875\n",
      "layer1.1.bn1.weight 1.0\n",
      "layer1.1.conv2.weight 0.2824164496527778\n",
      "layer1.1.bn2.weight 1.0\n",
      "layer1.1.conv3.weight 0.31378173828125\n",
      "layer1.1.bn3.weight 1.0\n",
      "layer1.2.conv1.weight 0.12152099609375\n",
      "layer1.2.bn1.weight 1.0\n",
      "layer1.2.conv2.weight 0.1852756076388889\n",
      "layer1.2.bn2.weight 1.0\n",
      "layer1.2.conv3.weight 0.28240966796875\n",
      "layer1.2.bn3.weight 1.0\n",
      "layer2.0.conv1.weight 0.38739013671875\n",
      "layer2.0.bn1.weight 1.0\n",
      "layer2.0.conv2.weight 0.3486735026041667\n",
      "layer2.0.bn2.weight 1.0\n",
      "layer2.0.conv3.weight 0.2824249267578125\n",
      "layer2.0.bn3.weight 1.0\n",
      "layer2.0.downsample.0.weight 0.2058868408203125\n",
      "layer2.0.downsample.1.weight 1.0\n",
      "layer2.1.conv1.weight 0.15008544921875\n",
      "layer2.1.bn1.weight 1.0\n",
      "layer2.1.conv2.weight 0.07976616753472222\n",
      "layer2.1.bn2.weight 1.0\n",
      "layer2.1.conv3.weight 0.0984649658203125\n",
      "layer2.1.bn3.weight 1.0\n",
      "layer2.2.conv1.weight 0.1852874755859375\n",
      "layer2.2.bn1.weight 1.0\n",
      "layer2.2.conv2.weight 0.3486735026041667\n",
      "layer2.2.bn2.weight 1.0\n",
      "layer2.2.conv3.weight 0.348663330078125\n",
      "layer2.2.bn3.weight 1.0\n",
      "layer2.3.conv1.weight 0.07177734375\n",
      "layer2.3.bn1.weight 1.0\n",
      "layer2.3.conv2.weight 0.13508436414930555\n",
      "layer2.3.bn2.weight 1.0\n",
      "layer2.3.conv3.weight 0.3874053955078125\n",
      "layer2.3.bn3.weight 1.0\n",
      "layer3.0.conv1.weight 0.34867095947265625\n",
      "layer3.0.bn1.weight 1.0\n",
      "layer3.0.conv2.weight 0.2541859944661458\n",
      "layer3.0.bn2.weight 1.0\n",
      "layer3.0.conv3.weight 0.7289962768554688\n",
      "layer3.0.bn3.weight 1.0\n",
      "layer3.0.downsample.0.weight 0.2824287414550781\n",
      "layer3.0.downsample.1.weight 1.0\n",
      "layer3.1.conv1.weight 0.12157440185546875\n",
      "layer3.1.bn1.weight 1.0\n",
      "layer3.1.conv2.weight 0.1500939263237847\n",
      "layer3.1.bn2.weight 1.0\n",
      "layer3.1.conv3.weight 0.5904884338378906\n",
      "layer3.1.bn3.weight 1.0\n",
      "layer3.2.conv1.weight 0.3138084411621094\n",
      "layer3.2.bn1.weight 1.0\n",
      "layer3.2.conv2.weight 0.12157524956597222\n",
      "layer3.2.bn2.weight 1.0\n",
      "layer3.2.conv3.weight 0.3138084411621094\n",
      "layer3.2.bn3.weight 1.0\n",
      "layer3.3.conv1.weight 0.3138084411621094\n",
      "layer3.3.bn1.weight 1.0\n",
      "layer3.3.conv2.weight 0.3486768934461806\n",
      "layer3.3.bn2.weight 1.0\n",
      "layer3.3.conv3.weight 0.47829437255859375\n",
      "layer3.3.bn3.weight 1.0\n",
      "layer3.4.conv1.weight 0.5314407348632812\n",
      "layer3.4.bn1.weight 1.0\n",
      "layer3.4.conv2.weight 0.2287665473090278\n",
      "layer3.4.bn2.weight 1.0\n",
      "layer3.4.conv3.weight 0.47829437255859375\n",
      "layer3.4.bn3.weight 1.0\n",
      "layer3.5.conv1.weight 0.4304656982421875\n",
      "layer3.5.bn1.weight 1.0\n",
      "layer3.5.conv2.weight 0.3486768934461806\n",
      "layer3.5.bn2.weight 1.0\n",
      "layer3.5.conv3.weight 0.4304656982421875\n",
      "layer3.5.bn3.weight 1.0\n",
      "layer4.0.conv1.weight 0.5904884338378906\n",
      "layer4.0.bn1.weight 1.0\n",
      "layer4.0.conv2.weight 0.5314407348632812\n",
      "layer4.0.bn2.weight 1.0\n",
      "layer4.0.conv3.weight 0.47829627990722656\n",
      "layer4.0.bn3.weight 1.0\n",
      "layer4.0.downsample.0.weight 0.3486781120300293\n",
      "layer4.0.downsample.1.weight 1.0\n",
      "layer4.1.conv1.weight 0.3874197006225586\n",
      "layer4.1.bn1.weight 1.0\n",
      "layer4.1.conv2.weight 0.47829649183485246\n",
      "layer4.1.bn2.weight 1.0\n",
      "layer4.1.conv3.weight 0.8099994659423828\n",
      "layer4.1.bn3.weight 1.0\n",
      "layer4.2.conv1.weight 0.5314407348632812\n",
      "layer4.2.bn1.weight 1.0\n",
      "layer4.2.conv2.weight 0.47829649183485246\n",
      "layer4.2.bn2.weight 1.0\n",
      "layer4.2.conv3.weight 0.6560993194580078\n",
      "layer4.2.bn3.weight 1.0\n",
      "fc.weight 1.0\n"
     ]
    }
   ],
   "source": [
    "total_nz = 0\n",
    "\n",
    "for n, p in modelp.named_parameters():\n",
    "    if \"weight\" not in n:\n",
    "        continue\n",
    "    print(n, (p != 0).sum().item() / p.numel())\n",
    "    total_nz += (p != 0).sum().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1afe1ed7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating ...\n",
      "74.52\n"
     ]
    }
   ],
   "source": [
    "test(modelp, testloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bb66c0dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0\n",
      "0 1\n",
      "0 2\n",
      "0 3\n",
      "0 4\n",
      "0 5\n",
      "0 6\n",
      "0 7\n",
      "1 0\n",
      "1 1\n",
      "1 2\n",
      "1 3\n",
      "1 4\n",
      "1 5\n",
      "1 6\n",
      "1 7\n",
      "2 0\n",
      "2 1\n",
      "2 2\n",
      "2 3\n",
      "2 4\n",
      "2 5\n",
      "2 6\n",
      "2 7\n",
      "3 0\n",
      "3 1\n",
      "3 2\n",
      "3 3\n",
      "3 4\n",
      "3 5\n",
      "3 6\n",
      "3 7\n",
      "4 0\n",
      "4 1\n",
      "4 2\n",
      "4 3\n",
      "4 4\n",
      "4 5\n",
      "4 6\n",
      "4 7\n",
      "5 0\n",
      "5 1\n",
      "5 2\n",
      "5 3\n",
      "5 4\n",
      "5 5\n",
      "5 6\n",
      "5 7\n",
      "6 0\n",
      "6 1\n",
      "6 2\n",
      "6 3\n",
      "6 4\n",
      "6 5\n",
      "6 6\n",
      "6 7\n",
      "7 0\n",
      "7 1\n",
      "7 2\n",
      "7 3\n",
      "7 4\n",
      "7 5\n",
      "7 6\n",
      "7 7\n",
      "8 0\n",
      "8 1\n",
      "8 2\n",
      "8 3\n",
      "8 4\n",
      "8 5\n",
      "8 6\n",
      "8 7\n",
      "9 0\n",
      "9 1\n",
      "9 2\n",
      "9 3\n",
      "9 4\n",
      "9 5\n",
      "9 6\n",
      "9 7\n"
     ]
    }
   ],
   "source": [
    "handles = []\n",
    "\n",
    "def add_batch(layer, inp, out):\n",
    "    layer.batches = [(inp[0].detach(), out.detach())]\n",
    "    X = inp[0].detach().float()\n",
    "    #print(X.shape)\n",
    "    #assert X.shape[2] == 1\n",
    "    # TODO: unfold\n",
    "    #X = X.permute(0, 2, 3, 1)\n",
    "    #X = X.reshape(-1, X.shape[-1])\n",
    "    if isinstance(layer, nn.Conv2d):\n",
    "        unfold = nn.Unfold(\n",
    "            layer.kernel_size,\n",
    "            dilation=layer.dilation,\n",
    "            padding=layer.padding,\n",
    "            stride=layer.stride\n",
    "        )\n",
    "        X = unfold(X)\n",
    "        X = X.permute([1, 0, 2])\n",
    "        X = X.flatten(1)\n",
    "    layer.XX += X.matmul(X.T)\n",
    "\n",
    "for n, m in model_orig.named_modules():\n",
    "    if type(m) == nn.Conv2d:\n",
    "        Wf = m.weight.flatten(1)\n",
    "        m.XX = torch.zeros(Wf.shape[1], Wf.shape[1], device=m.weight.device)\n",
    "        handles.append(m.register_forward_hook(add_batch))\n",
    "        \n",
    "for i in range(10):\n",
    "    for j, batch in enumerate(dataloader):\n",
    "        print(i, j)\n",
    "        with torch.no_grad():\n",
    "            run(model_orig, batch)\n",
    "        \n",
    "for h in handles:\n",
    "    h.remove()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "17c99a91",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_other2(A, W, nnz, Z, U, print_sc=None, debug=False, reg=0, rho_start=0.03, iters=5, prune_iters=2):\n",
    "    XX = A.T.matmul(A)\n",
    "    norm2 = torch.diag(XX).sqrt() + 1e-8\n",
    "    An = A / norm2\n",
    "    XX = An.T.matmul(An)\n",
    "    XX += torch.diag(torch.ones_like(XX.diag())) * XX.diag().mean() * reg\n",
    "    \n",
    "    #norm2 = torch.ones_like(norm2)\n",
    "    Wnn = W# * norm2.unsqueeze(1)\n",
    "    rho = 1\n",
    "    XY = An.T.matmul(Wnn)\n",
    "    XXinv = torch.inverse(XX + torch.eye(XX.shape[1], device=XX.device)*rho)\n",
    "    XXinv2 = torch.inverse(XX + torch.eye(XX.shape[1], device=XX.device)*rho_start)\n",
    "    U = U * norm2.unsqueeze(1)\n",
    "    Z = Z * norm2.unsqueeze(1)\n",
    "    \n",
    "    #B = torch.linalg.solve(XX, XY)\n",
    "    B = XXinv2.matmul(XY + rho_start*(Z-U))\n",
    "    \n",
    "    #U = torch.zeros_like(B)\n",
    "    \n",
    "    #Z = B\n",
    "    \n",
    "    bsparsity = min(0.99, 1 - nnz/B.numel())\n",
    "    #print(\"bs\", bsparsity)\n",
    "\n",
    "\n",
    "    for itt in range(iters):\n",
    "        if itt < prune_iters:\n",
    "            cur_sparsity = bsparsity# - bsparsity * (1 - (itt + 1) / iterative_prune) ** 3\n",
    "            thres = (B+U).abs().flatten().sort()[0][int(B.numel() * cur_sparsity)]\n",
    "            mask = ((B+U).abs() > thres)\n",
    "            del thres\n",
    "\n",
    "        Z = (B + U) * mask    \n",
    "\n",
    "        U = U + (B - Z)    \n",
    "\n",
    "        B = XXinv.matmul(XY + rho*(Z-U))\n",
    "        #B = torch.linalg.solve(XX + torch.eye(XX.shape[1], device=XX.device)*rho, XY + rho*(Z-U))\n",
    "        if debug:\n",
    "            print(itt, cur_sparsity, (Z != 0).sum().item() / Z.numel())\n",
    "            print_sc(A.matmul(B / norm2.unsqueeze(1)))\n",
    "            print_sc(A.matmul(Z / norm2.unsqueeze(1)))\n",
    "            print(((An != 0).sum() + (Z != 0).sum()) / W.numel())\n",
    "            print(\"-------\")\n",
    "    if debug:\n",
    "        print(\"opt end\")\n",
    "\n",
    "    return Z / norm2.unsqueeze(1), U / norm2.unsqueeze(1)    \n",
    "    \n",
    "def mag_prune(W, sp=0.6):\n",
    "    thres = (W).abs().flatten().sort()[0][int(W.numel() * sp)]\n",
    "    mask = ((W).abs() > thres)\n",
    "    return W * mask\n",
    "\n",
    "def ent(p):\n",
    "    return -(p * np.log2(p) + (1-p) * np.log2(1-p))\n",
    "\n",
    "def factorizeT(W, XX, asp=0.16, sp=0.4, iters=40):\n",
    "    #W = lx.weight.detach().T.float()\n",
    "    nza = int(W.shape[0]**2 * asp)\n",
    "    nzb = int(W.numel() * sp - nza)\n",
    "    \n",
    "    Az = torch.eye(W.shape[0], device=W.device)\n",
    "    Au = torch.zeros_like(Az)\n",
    "    norm = XX.diag().sqrt().unsqueeze(1) + 1e-8\n",
    "    norm = torch.ones_like(norm)\n",
    "       \n",
    "    Wn = W * norm\n",
    "       \n",
    "    Bz = mag_prune(Wn, (1 - nzb/2/W.numel()))\n",
    "    Bu = torch.zeros_like(Bz)\n",
    "    \n",
    "    for itt in range(iters):\n",
    "        #if itt < 10:\n",
    "        #    rho_start = 0.0\n",
    "        #elif itt < 15:\n",
    "        #    rho_start = 0.00\n",
    "        #else:\n",
    "        #    rho_start = 0.1\n",
    "        rho_start = min(1.0, itt / (iters-3))**3\n",
    "        Az, Au = (x.T for x in find_other2(Bz.T, Wn.T, nza, Az.T, Au.T, reg=1e-2, debug=False, rho_start=rho_start))\n",
    "                \n",
    "        Bz, Bu = find_other2(Az, Wn, nzb, Bz, Bu, reg=1e-2, debug=False, rho_start=rho_start)\n",
    "    \n",
    "    #print(((Az != 0).sum() + (Bz != 0).sum()).item() / W.numel(), (Az != 0).sum().item() / Az.numel(),\n",
    "    #      (Bz != 0).sum().item() / Bz.numel(), Az.shape, Bz.shape,\n",
    "    #     (Az.numel()*ent((Az != 0).sum().item() / Az.numel()) + Bz.numel()*ent((Bz != 0).sum().item() / Bz.numel())) / W.numel(), \n",
    "    #    ent(0.4), ent(0.5))\n",
    "    return ((Az / norm).matmul(Bz)).T, Bz.T, (Az / norm).T\n",
    "\n",
    "\n",
    "def factorizef(W, XX, asp=0.16, sp=0.4, iters=200, l_prev=None):\n",
    "    s_time = time.time()\n",
    "    if W.shape[0] >= W.shape[1]:\n",
    "        return factorizeT(W.T, XX, sp=sp, asp=asp, iters=iters)\n",
    "    \n",
    "    nza = int(W.shape[0]**2 * asp)\n",
    "    nzb = int(W.numel() * sp - nza)\n",
    "    norm = XX.diag().sqrt() + 1e-8\n",
    "    norm = torch.ones_like(norm)\n",
    "\n",
    "    Wn = W * norm\n",
    "    \n",
    "    Az = torch.eye(W.shape[0], device=W.device)\n",
    "    Au = torch.zeros_like(Az)\n",
    "\n",
    "    Bz = mag_prune(Wn, (1 - nzb/2/W.numel()))\n",
    "    Bu = torch.zeros_like(Bz)\n",
    "    \n",
    "    for itt in range(iters):\n",
    "        #if itt < 10:\n",
    "        #    rho_start = 0.0\n",
    "        #elif itt < 15:\n",
    "        #    rho_start = 0.00\n",
    "        #else:\n",
    "        #    rho_start = 0.1\n",
    "            \n",
    "        rho_start = min(1.0, itt / (iters-3))**3\n",
    "        Az, Au = (x.T for x in find_other2(Bz.T, Wn.T, nza, Az.T, Au.T, reg=1e-2, debug=False, rho_start=rho_start))\n",
    "                \n",
    "        Bz, Bu = find_other2(Az, Wn, nzb, Bz, Bu, reg=1e-2, debug=False, rho_start=rho_start)\n",
    "        \n",
    "        #print(itt, time.time() - s_time, end =\" \") \n",
    "        #print_scores(Az.matmul(Bz / norm))\n",
    "        \n",
    "        \n",
    "    #print(((Az != 0).sum() + (Bz != 0).sum()).item() / W.numel(), (Az != 0).sum().item() / Az.numel(),\n",
    "    #      (Bz != 0).sum().item() / Bz.numel(), Az.shape, Bz.shape,\n",
    "    #     (Az.numel()*ent((Az != 0).sum().item() / Az.numel()) + Bz.numel()*ent((Bz != 0).sum().item() / Bz.numel())) / W.numel(), \n",
    "    #    ent(0.4), ent(0.5))\n",
    "    return Az.matmul(Bz / norm), Az, Bz / norm\n",
    "\n",
    "def finalize(XXb, W, Ab, Bb):\n",
    "    fsparsity = 1 - (Ab != 0).sum() / Ab.numel()\n",
    "    mask = (Ab != 0).T\n",
    "\n",
    "    XX = Bb.matmul(XXb).matmul(Bb.T)\n",
    "    XY = Bb.matmul(XXb).matmul(W.detach().float().T)\n",
    "\n",
    "    norm2 = torch.diag(XX).sqrt() + 1e-8\n",
    "    XX = XX / norm2 / norm2.unsqueeze(1)\n",
    "    XY = XY / norm2.unsqueeze(1)\n",
    "    Ax = (Ab * norm2).T.clone()\n",
    "    #Ax = torch.linalg.solve(XX, XY)\n",
    "\n",
    "    rho = 1\n",
    "    XXinv = torch.inverse(XX + torch.eye(XX.shape[1], device=XX.device)*rho)\n",
    "    U = torch.zeros_like(Ax)\n",
    "    for itt in range(200):\n",
    "        #if itt < 150:\n",
    "        #    cur_sparsity = fsparsity - fsparsity * (1 - (itt + 1) / 150) ** 3\n",
    "        #    thres = (Ax+U).abs().flatten().sort()[0][int(Ax.numel() * cur_sparsity)]\n",
    "        #    mask = ((Ax+U).abs() > thres)\n",
    "        #    del thres\n",
    "\n",
    "        \n",
    "        Z = (Ax + U) * mask    \n",
    "\n",
    "        U = U + (Ax - Z)    \n",
    "\n",
    "        Ax = XXinv.matmul(XY + rho*(Z-U))\n",
    "\n",
    "    Ac = Z.T / norm2\n",
    "    return Ac\n",
    "\n",
    "def find_a(B, Za, Ua, rho, D, Q, E, R, XX, W):\n",
    "    F = rho*(Za-Ua) + XX.matmul(W).matmul(B.T)\n",
    "    \n",
    "    right = Q.T.matmul(F).matmul(R)\n",
    "    \n",
    "    div = D.unsqueeze(1).matmul(E.unsqueeze(0)) + rho\n",
    "    QAR = right / div\n",
    "    \n",
    "    A3 = Q.matmul(QAR).matmul(R.T)\n",
    "    return A3\n",
    "\n",
    "\n",
    "def get_at(XX, W, A, B):\n",
    "    mask = (A != 0)\n",
    "    \n",
    "    norm2 = torch.diag(XX).sqrt() + 1e-8\n",
    "    XXn = XX / norm2 / norm2.unsqueeze(1)\n",
    "    \n",
    "    #XXn += torch.diag(XXn.diag()*0 + 0.01*XXn.diag().mean())\n",
    "    \n",
    "    Wn = W * norm2.unsqueeze(1)\n",
    "    #XY = XY / norm2.unsqueeze(1)\n",
    "    \n",
    "    normB = torch.norm(B, dim=1) + 1e-8\n",
    "    Bn = B / normB.unsqueeze(1)\n",
    "    BBn = Bn.matmul(Bn.T)\n",
    "    #BBn += torch.diag(BBn.diag()*0 + 0.01*BBn.diag().mean())\n",
    "    #print(BBn.diag())\n",
    "    \n",
    "    D, Q = torch.linalg.eigh(XXn)\n",
    "    E, R = torch.linalg.eigh(BBn)\n",
    "    \n",
    "    #print(D, E)\n",
    "    \n",
    "    Za = A * norm2.unsqueeze(1) * normB\n",
    "    Ua = torch.zeros_like(Za)\n",
    "    rho = 1\n",
    "    \n",
    "    for itt in range(20):\n",
    "        A2 = find_a(Bn, Za, Ua, rho, D, Q, E, R, XXn, Wn)\n",
    "        Wx = (A2 / norm2.unsqueeze(1) / normB).matmul(B)\n",
    "        #print(itt)\n",
    "        #print(\"   errx\", (Wx - W).T.matmul(XX).matmul((Wx - W)).diag().sum().item())\n",
    "        \n",
    "        Za = (A2 + Ua) * mask\n",
    "        Ua = Ua + (A2 - Za)\n",
    "        Wx = (Za / norm2.unsqueeze(1) / normB).matmul(B)\n",
    "        #print(\"   errz\", (Wx - W).T.matmul(XX).matmul((Wx - W)).diag().sum().item())\n",
    "    return Za / norm2.unsqueeze(1) / normB\n",
    "\n",
    "def factorize(XX, W, sp, l_prev=None):\n",
    "    W = W.detach().float()\n",
    "    asp = max(0.16, sp/2)\n",
    "    W2, Ab, Bb = factorizef(W, XX, sp=sp, asp=asp, l_prev=l_prev)\n",
    "    print(\"err_prefin\", (W2 - W).matmul(XX).matmul((W2 - W).T).diag().sum().item())\n",
    "    Ac = finalize(XX, W, Ab, Bb)\n",
    "    W3 = Ac.matmul(Bb)\n",
    "    assert W3.shape == W.shape\n",
    "    print(\"err_fin   \", (W3 - W).matmul(XX).matmul((W3 - W).T).diag().sum().item())\n",
    "    #fin_b(XX, W, Ac, Bb)\n",
    "    \n",
    "    Bc = get_at(XX, W.T, Bb.T, Ac.T).T\n",
    "    \n",
    "    W4 = Ac.matmul(Bc)\n",
    "    assert W3.shape == W.shape\n",
    "    print(\"err_fin2   \", (W4 - W).matmul(XX).matmul((W4 - W).T).diag().sum().item())\n",
    "    \n",
    "    print(\"sparsity check\", ((Ac != 0).sum() + (Bc != 0).sum()).item() / W3.numel())\n",
    "    return W4, (Ac.cpu(), Bc.cpu())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3cd37466",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "err_prefin 216261.34375\n",
      "err_fin    18613.828125\n",
      "err_fin2    15016.3046875\n",
      "sparsity check 0.58984375\n",
      "layer1.0.conv1 0.59033203125 torch.Size([64, 64, 1, 1]) 27640.7734375 15016.3046875 \n",
      "err_prefin 4613095.0\n",
      "err_fin    674008.625\n",
      "err_fin2    436372.875\n",
      "sparsity check 0.1666937934027778\n",
      "layer1.0.conv2 0.166748046875 torch.Size([64, 64, 3, 3]) 270299.5625 436372.875 bad\n",
      "err_prefin 1210875.5\n",
      "err_fin    226695.09375\n",
      "err_fin2    222362.5\n",
      "sparsity check 0.2540283203125\n",
      "layer1.0.conv3 0.254150390625 torch.Size([256, 64, 1, 1]) 155570.5625 222362.5 bad\n",
      "err_prefin 2750552.0\n",
      "err_fin    272261.8125\n",
      "err_fin2    267898.15625\n",
      "sparsity check 0.4302978515625\n",
      "layer1.0.downsample.0 0.430419921875 torch.Size([256, 64, 1, 1]) 361443.9375 267898.15625 \n",
      "err_prefin 943236.875\n",
      "err_fin    304767.3125\n",
      "err_fin2    148208.984375\n",
      "sparsity check 0.28228759765625\n",
      "layer1.1.conv1 0.28240966796875 torch.Size([64, 256, 1, 1]) 133102.5625 148208.984375 bad\n",
      "err_prefin 1640304.375\n",
      "err_fin    714469.75\n",
      "err_fin2    337674.0625\n",
      "sparsity check 0.2823621961805556\n",
      "layer1.1.conv2 0.2824164496527778 torch.Size([64, 64, 3, 3]) 233889.859375 337674.0625 bad\n",
      "err_prefin 372065.75\n",
      "err_fin    154249.0625\n",
      "err_fin2    151873.625\n",
      "sparsity check 0.31365966796875\n",
      "layer1.1.conv3 0.31378173828125 torch.Size([256, 64, 1, 1]) 193306.09375 151873.625 \n",
      "err_prefin 6091756.0\n",
      "err_fin    2304855.0\n",
      "err_fin2    1807763.75\n",
      "sparsity check 0.12139892578125\n",
      "layer1.2.conv1 0.12152099609375 torch.Size([64, 256, 1, 1]) 1535052.375 1807763.75 bad\n",
      "err_prefin 4440164.0\n",
      "err_fin    2749006.5\n",
      "err_fin2    1665399.375\n",
      "sparsity check 0.18522135416666666\n",
      "layer1.2.conv2 0.1852756076388889 torch.Size([64, 64, 3, 3]) 1387377.5 1665399.375 bad\n",
      "err_prefin 254454.125\n",
      "err_fin    180026.65625\n",
      "err_fin2    178256.375\n",
      "sparsity check 0.28228759765625\n",
      "layer1.2.conv3 0.28240966796875 torch.Size([256, 64, 1, 1]) 259815.3125 178256.375 \n",
      "err_prefin 1605288.25\n",
      "err_fin    452022.0\n",
      "err_fin2    290838.71875\n",
      "sparsity check 0.3873291015625\n",
      "layer2.0.conv1 0.38739013671875 torch.Size([128, 256, 1, 1]) 547095.8125 290838.71875 \n",
      "err_prefin 757876.8125\n",
      "err_fin    396273.96875\n",
      "err_fin2    215700.625\n",
      "sparsity check 0.3486599392361111\n",
      "layer2.0.conv2 0.3486735026041667 torch.Size([128, 128, 3, 3]) 243216.0625 215700.625 \n",
      "err_prefin 415472.0\n",
      "err_fin    210704.5\n",
      "err_fin2    207597.703125\n",
      "sparsity check 0.2823944091796875\n",
      "layer2.0.conv3 0.2824249267578125 torch.Size([512, 128, 1, 1]) 307151.9375 207597.703125 \n",
      "err_prefin 1149402.75\n",
      "err_fin    307029.75\n",
      "err_fin2    281336.1875\n",
      "sparsity check 0.20587158203125\n",
      "layer2.0.downsample.0 0.2058868408203125 torch.Size([512, 256, 1, 1]) 534038.875 281336.1875 \n",
      "err_prefin 514511.25\n",
      "err_fin    190300.59375\n",
      "err_fin2    141036.484375\n",
      "sparsity check 0.150054931640625\n",
      "layer2.1.conv1 0.15008544921875 torch.Size([128, 512, 1, 1]) 173098.96875 141036.484375 \n",
      "err_prefin 3022651.5\n",
      "err_fin    582711.1875\n",
      "err_fin2    471264.28125\n",
      "sparsity check 0.07975260416666667\n",
      "layer2.1.conv2 0.07976616753472222 torch.Size([128, 128, 3, 3]) 256422.453125 471264.28125 bad\n",
      "err_prefin 610936.25\n",
      "err_fin    427276.5625\n",
      "err_fin2    409610.5625\n",
      "sparsity check 0.0984344482421875\n",
      "layer2.1.conv3 0.0984649658203125 torch.Size([512, 128, 1, 1]) 319872.875 409610.5625 bad\n",
      "err_prefin 1572108.0\n",
      "err_fin    705519.25\n",
      "err_fin2    500200.59375\n",
      "sparsity check 0.1852569580078125\n",
      "layer2.2.conv1 0.1852874755859375 torch.Size([128, 512, 1, 1]) 502911.25 500200.59375 \n",
      "err_prefin 595265.625\n",
      "err_fin    279926.6875\n",
      "err_fin2    144490.703125\n",
      "sparsity check 0.3486599392361111\n",
      "layer2.2.conv2 0.3486735026041667 torch.Size([128, 128, 3, 3]) 152981.40625 144490.703125 \n",
      "err_prefin 229127.96875\n",
      "err_fin    107335.0\n",
      "err_fin2    105811.3125\n",
      "sparsity check 0.3486328125\n",
      "layer2.2.conv3 0.348663330078125 torch.Size([512, 128, 1, 1]) 197878.90625 105811.3125 \n",
      "err_prefin 7027040.0\n",
      "err_fin    3226210.0\n",
      "err_fin2    2965640.5\n",
      "sparsity check 0.071746826171875\n",
      "layer2.3.conv1 0.07177734375 torch.Size([128, 512, 1, 1]) 2963943.75 2965640.5 bad\n",
      "err_prefin 3001089.5\n",
      "err_fin    1610144.25\n",
      "err_fin2    1236240.625\n",
      "sparsity check 0.13507080078125\n",
      "layer2.3.conv2 0.13508436414930555 torch.Size([128, 128, 3, 3]) 1139468.75 1236240.625 bad\n",
      "err_prefin 131894.625\n",
      "err_fin    77573.71875\n",
      "err_fin2    76872.34375\n",
      "sparsity check 0.3873748779296875\n",
      "layer2.3.conv3 0.3874053955078125 torch.Size([512, 128, 1, 1]) 145862.40625 76872.34375 \n",
      "err_prefin 1235279.375\n",
      "err_fin    622507.1875\n",
      "err_fin2    481052.1875\n",
      "sparsity check 0.34865570068359375\n",
      "layer3.0.conv1 0.34867095947265625 torch.Size([256, 512, 1, 1]) 1018197.0625 481052.1875 \n",
      "err_prefin 629585.4375\n",
      "err_fin    403177.15625\n",
      "err_fin2    279877.34375\n",
      "sparsity check 0.2541826036241319\n",
      "layer3.0.conv2 0.2541859944661458 torch.Size([256, 256, 3, 3]) 380160.71875 279877.34375 \n",
      "err_prefin 6988.478515625\n",
      "err_fin    1747.9176025390625\n",
      "err_fin2    1720.9501953125\n",
      "sparsity check 0.7289886474609375\n",
      "layer3.0.conv3 0.7289962768554688 torch.Size([1024, 256, 1, 1]) 6336.5849609375 1720.9501953125 \n",
      "err_prefin 381735.9375\n",
      "err_fin    160588.78125\n",
      "err_fin2    153949.953125\n",
      "sparsity check 0.2824249267578125\n",
      "layer3.0.downsample.0 0.2824287414550781 torch.Size([1024, 512, 1, 1]) 380189.25 153949.953125 \n",
      "err_prefin 755230.5\n",
      "err_fin    337813.59375\n",
      "err_fin2    269405.1875\n",
      "sparsity check 0.1215667724609375\n",
      "layer3.1.conv1 0.12157440185546875 torch.Size([256, 1024, 1, 1]) 333123.625 269405.1875 \n",
      "err_prefin 961092.875\n",
      "err_fin    475658.4375\n",
      "err_fin2    345664.375\n",
      "sparsity check 0.15009053548177084\n",
      "layer3.1.conv2 0.1500939263237847 torch.Size([256, 256, 3, 3]) 359667.21875 345664.375 \n",
      "err_prefin 15077.365234375\n",
      "err_fin    6845.02490234375\n",
      "err_fin2    6753.8349609375\n",
      "sparsity check 0.5904808044433594\n",
      "layer3.1.conv3 0.5904884338378906 torch.Size([1024, 256, 1, 1]) 19731.44921875 6753.8349609375 \n",
      "err_prefin 174411.203125\n",
      "err_fin    92170.890625\n",
      "err_fin2    60751.21875\n",
      "sparsity check 0.3138008117675781\n",
      "layer3.2.conv1 0.3138084411621094 torch.Size([256, 1024, 1, 1]) 92160.5625 60751.21875 \n",
      "err_prefin 1015147.125\n",
      "err_fin    513084.0\n",
      "err_fin2    393983.25\n",
      "sparsity check 0.12157185872395833\n",
      "layer3.2.conv2 0.12157524956597222 torch.Size([256, 256, 3, 3]) 391634.875 393983.25 bad\n",
      "err_prefin 103127.671875\n",
      "err_fin    58268.796875\n",
      "err_fin2    57496.8515625\n",
      "sparsity check 0.3138008117675781\n",
      "layer3.2.conv3 0.3138084411621094 torch.Size([1024, 256, 1, 1]) 111967.390625 57496.8515625 \n",
      "err_prefin 266649.0\n",
      "err_fin    149698.984375\n",
      "err_fin2    106122.46875\n",
      "sparsity check 0.3138008117675781\n",
      "layer3.3.conv1 0.3138084411621094 torch.Size([256, 1024, 1, 1]) 159325.953125 106122.46875 \n",
      "err_prefin 168828.09375\n",
      "err_fin    99552.59375\n",
      "err_fin2    64053.1796875\n",
      "sparsity check 0.3486735026041667\n",
      "layer3.3.conv2 0.3486768934461806 torch.Size([256, 256, 3, 3]) 86669.734375 64053.1796875 \n",
      "err_prefin 24312.3046875\n",
      "err_fin    13796.65234375\n",
      "err_fin2    13658.080078125\n",
      "sparsity check 0.4782867431640625\n",
      "layer3.3.conv3 0.47829437255859375 torch.Size([1024, 256, 1, 1]) 34852.06640625 13658.080078125 \n",
      "err_prefin 48055.1953125\n",
      "err_fin    28429.37109375\n",
      "err_fin2    20127.453125\n",
      "sparsity check 0.53143310546875\n",
      "layer3.4.conv1 0.5314407348632812 torch.Size([256, 1024, 1, 1]) 50097.734375 20127.453125 \n",
      "err_prefin 376362.03125\n",
      "err_fin    219965.09375\n",
      "err_fin2    158764.53125\n",
      "sparsity check 0.2287631564670139\n",
      "layer3.4.conv2 0.2287665473090278 torch.Size([256, 256, 3, 3]) 186770.6875 158764.53125 \n",
      "err_prefin 22415.859375\n",
      "err_fin    13319.857421875\n",
      "err_fin2    13183.2041015625\n",
      "sparsity check 0.4782867431640625\n",
      "layer3.4.conv3 0.47829437255859375 torch.Size([1024, 256, 1, 1]) 32696.341796875 13183.2041015625 \n",
      "err_prefin 128751.1875\n",
      "err_fin    82834.984375\n",
      "err_fin2    63508.4296875\n",
      "sparsity check 0.43045806884765625\n",
      "layer3.5.conv1 0.4304656982421875 torch.Size([256, 1024, 1, 1]) 129100.71875 63508.4296875 \n",
      "err_prefin 175985.4375\n",
      "err_fin    107409.828125\n",
      "err_fin2    72440.2109375\n",
      "sparsity check 0.3486735026041667\n",
      "layer3.5.conv2 0.3486768934461806 torch.Size([256, 256, 3, 3]) 107564.765625 72440.2109375 \n",
      "err_prefin 48075.57421875\n",
      "err_fin    25988.125\n",
      "err_fin2    25690.767578125\n",
      "sparsity check 0.43045806884765625\n",
      "layer3.5.conv3 0.4304656982421875 torch.Size([1024, 256, 1, 1]) 59980.3125 25690.767578125 \n",
      "err_prefin 51164.5703125\n",
      "err_fin    31196.986328125\n",
      "err_fin2    26808.36328125\n",
      "sparsity check 0.590484619140625\n",
      "layer4.0.conv1 0.5904884338378906 torch.Size([512, 1024, 1, 1]) 115178.9765625 26808.36328125 \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "err_prefin 18906.97265625\n",
      "err_fin    13446.25390625\n",
      "err_fin2    9080.3759765625\n",
      "sparsity check 0.5314398871527778\n",
      "layer4.0.conv2 0.5314407348632812 torch.Size([512, 512, 3, 3]) 23763.48046875 9080.3759765625 \n",
      "err_prefin 28323.662109375\n",
      "err_fin    12818.5263671875\n",
      "err_fin2    12665.55859375\n",
      "sparsity check 0.47829437255859375\n",
      "layer4.0.conv3 0.47829627990722656 torch.Size([2048, 512, 1, 1]) 37507.96875 12665.55859375 \n",
      "err_prefin 49622.78125\n",
      "err_fin    28954.6171875\n",
      "err_fin2    28306.537109375\n",
      "sparsity check 0.3486771583557129\n",
      "layer4.0.downsample.0 0.3486781120300293 torch.Size([2048, 1024, 1, 1]) 88051.0234375 28306.537109375 \n",
      "err_prefin 591586.0\n",
      "err_fin    292809.875\n",
      "err_fin2    183296.140625\n",
      "sparsity check 0.3874177932739258\n",
      "layer4.1.conv1 0.3874197006225586 torch.Size([512, 2048, 1, 1]) 348949.75 183296.140625 \n",
      "err_prefin 32276.4140625\n",
      "err_fin    23840.4140625\n",
      "err_fin2    16659.689453125\n",
      "sparsity check 0.47829564412434894\n",
      "layer4.1.conv2 0.47829649183485246 torch.Size([512, 512, 3, 3]) 36791.4296875 16659.689453125 \n",
      "err_prefin 758.2838134765625\n",
      "err_fin    374.8442077636719\n",
      "err_fin2    367.9161376953125\n",
      "sparsity check 0.80999755859375\n",
      "layer4.1.conv3 0.8099994659423828 torch.Size([2048, 512, 1, 1]) 1455.04248046875 367.9161376953125 \n",
      "err_prefin 655162.875\n",
      "err_fin    280272.75\n",
      "err_fin2    154152.40625\n",
      "sparsity check 0.5314388275146484\n",
      "layer4.2.conv1 0.5314407348632812 torch.Size([512, 2048, 1, 1]) 362123.3125 154152.40625 \n",
      "err_prefin 19483.11328125\n",
      "err_fin    14229.76171875\n",
      "err_fin2    9536.4765625\n",
      "sparsity check 0.47829564412434894\n",
      "layer4.2.conv2 0.47829649183485246 torch.Size([512, 512, 3, 3]) 35742.44140625 9536.4765625 \n",
      "err_prefin 3610.513427734375\n",
      "err_fin    1588.97119140625\n",
      "err_fin2    1564.582275390625\n",
      "sparsity check 0.656097412109375\n",
      "layer4.2.conv3 0.6560993194580078 torch.Size([2048, 512, 1, 1]) 7593.02783203125 1564.582275390625 \n"
     ]
    }
   ],
   "source": [
    "sd_pruned = modelp.state_dict()\n",
    "out_admm = {}\n",
    "\n",
    "for n, m in model_orig.named_modules():\n",
    "    if type(m) == nn.Conv2d and m.weight.shape[1] > 3:\n",
    "        w_pruned = sd_pruned[n+\".weight\"].flatten(1)\n",
    "        sparsity = (w_pruned != 0).sum().item() / w_pruned.numel()\n",
    "        w_orig = m.weight.flatten(1)\n",
    "        w_admm, facts = factorize(m.XX, w_orig, sparsity)\n",
    "        e1 = (w_orig - w_pruned).matmul(m.XX).matmul((w_orig - w_pruned).T).diag().sum().item()\n",
    "        e2 = (w_orig - w_admm).matmul(m.XX).matmul((w_orig - w_admm).T).diag().sum().item()\n",
    "        print(n, sparsity, m.weight.shape, \n",
    "              e1,\n",
    "              e2,\n",
    "              \"bad\" if e1 < e2 else \"\"\n",
    "             )\n",
    "        out_admm[n] = (w_admm.reshape(w_pruned.shape), facts)\n",
    "        #m.XX = None\n",
    "        \n",
    "for n, m in modelp.named_modules():\n",
    "    if n in out_admm:\n",
    "        m.weight.data = out_admm[n][0].reshape(m.weight.shape)\n",
    "        m.weight.facts = out_admm[n][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "99b05797",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating ...\n",
      "75.35\n"
     ]
    }
   ],
   "source": [
    "test(modelp, testloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8b943db1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batchnorm tuning ...\n",
      "0.8246843367815018\n",
      "000\n",
      "001\n",
      "002\n",
      "003\n",
      "004\n",
      "005\n",
      "006\n",
      "007\n",
      "008\n",
      "009\n",
      "010\n",
      "011\n",
      "012\n",
      "013\n",
      "014\n",
      "015\n",
      "016\n",
      "017\n",
      "018\n",
      "019\n",
      "020\n",
      "021\n",
      "022\n",
      "023\n",
      "024\n",
      "025\n",
      "026\n",
      "027\n",
      "028\n",
      "029\n",
      "030\n",
      "031\n",
      "032\n",
      "033\n",
      "034\n",
      "035\n",
      "036\n",
      "037\n",
      "038\n",
      "039\n",
      "040\n",
      "041\n",
      "042\n",
      "043\n",
      "044\n",
      "045\n",
      "046\n",
      "047\n",
      "048\n",
      "049\n",
      "050\n",
      "051\n",
      "052\n",
      "053\n",
      "054\n",
      "055\n",
      "056\n",
      "057\n",
      "058\n",
      "059\n",
      "060\n",
      "061\n",
      "062\n",
      "063\n",
      "064\n",
      "065\n",
      "066\n",
      "067\n",
      "068\n",
      "069\n",
      "070\n",
      "071\n",
      "072\n",
      "073\n",
      "074\n",
      "075\n",
      "076\n",
      "077\n",
      "078\n",
      "079\n",
      "080\n",
      "081\n",
      "082\n",
      "083\n",
      "084\n",
      "085\n",
      "086\n",
      "087\n",
      "088\n",
      "089\n",
      "090\n",
      "091\n",
      "092\n",
      "093\n",
      "094\n",
      "095\n",
      "096\n",
      "097\n",
      "098\n",
      "099\n",
      "0.8033170402050018\n"
     ]
    }
   ],
   "source": [
    "print('Batchnorm tuning ...')\n",
    "\n",
    "loss = 0\n",
    "with torch.no_grad():\n",
    "    for batch in dataloader:\n",
    "        loss += run(modelp, batch, loss=True)\n",
    "print(loss / 1024)\n",
    "\n",
    "batchnorms = find_layers(modelp, [nn.BatchNorm2d])\n",
    "for bn in batchnorms.values():\n",
    "    bn.reset_running_stats()\n",
    "    bn.momentum = .1\n",
    "modelp.train()\n",
    "with torch.no_grad():\n",
    "    i = 0\n",
    "    while i < 100:\n",
    "        for batch in dataloader:\n",
    "            if i == 100:\n",
    "                break\n",
    "            print('%03d' % i)\n",
    "            run(modelp, batch)\n",
    "            i += 1\n",
    "modelp.eval()\n",
    "\n",
    "loss = 0\n",
    "with torch.no_grad():\n",
    "    for batch in dataloader:\n",
    "        loss += run(modelp, batch, loss=True)\n",
    "print(loss / 1024)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "daaca80f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating ...\n",
      "75.56\n"
     ]
    }
   ],
   "source": [
    "test(modelp, testloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "0b5404a2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "conv1.weight 0.38732993197278914\n",
      "layer1.0.conv1.weight 0.58984375 ff\n",
      "layer1.0.conv2.weight 0.1666937934027778 ff\n",
      "layer1.0.conv3.weight 0.2540283203125 ff\n",
      "layer1.0.downsample.0.weight 0.4302978515625 ff\n",
      "layer1.0.downsample.1.weight 1.0\n",
      "layer1.1.conv1.weight 0.28228759765625 ff\n",
      "layer1.1.conv2.weight 0.2823621961805556 ff\n",
      "layer1.1.conv3.weight 0.31365966796875 ff\n",
      "layer1.2.conv1.weight 0.12139892578125 ff\n",
      "layer1.2.conv2.weight 0.18522135416666666 ff\n",
      "layer1.2.conv3.weight 0.28228759765625 ff\n",
      "layer2.0.conv1.weight 0.3873291015625 ff\n",
      "layer2.0.conv2.weight 0.3486599392361111 ff\n",
      "layer2.0.conv3.weight 0.2823944091796875 ff\n",
      "layer2.0.downsample.0.weight 0.20587158203125 ff\n",
      "layer2.0.downsample.1.weight 1.0\n",
      "layer2.1.conv1.weight 0.150054931640625 ff\n",
      "layer2.1.conv2.weight 0.07975260416666667 ff\n",
      "layer2.1.conv3.weight 0.0984344482421875 ff\n",
      "layer2.2.conv1.weight 0.1852569580078125 ff\n",
      "layer2.2.conv2.weight 0.3486599392361111 ff\n",
      "layer2.2.conv3.weight 0.3486328125 ff\n",
      "layer2.3.conv1.weight 0.071746826171875 ff\n",
      "layer2.3.conv2.weight 0.13507080078125 ff\n",
      "layer2.3.conv3.weight 0.3873748779296875 ff\n",
      "layer3.0.conv1.weight 0.34865570068359375 ff\n",
      "layer3.0.conv2.weight 0.2541826036241319 ff\n",
      "layer3.0.conv3.weight 0.7289886474609375 ff\n",
      "layer3.0.downsample.0.weight 0.2824249267578125 ff\n",
      "layer3.0.downsample.1.weight 1.0\n",
      "layer3.1.conv1.weight 0.1215667724609375 ff\n",
      "layer3.1.conv2.weight 0.15009053548177084 ff\n",
      "layer3.1.conv3.weight 0.5904808044433594 ff\n",
      "layer3.2.conv1.weight 0.3138008117675781 ff\n",
      "layer3.2.conv2.weight 0.12157185872395833 ff\n",
      "layer3.2.conv3.weight 0.3138008117675781 ff\n",
      "layer3.3.conv1.weight 0.3138008117675781 ff\n",
      "layer3.3.conv2.weight 0.3486735026041667 ff\n",
      "layer3.3.conv3.weight 0.4782867431640625 ff\n",
      "layer3.4.conv1.weight 0.53143310546875 ff\n",
      "layer3.4.conv2.weight 0.2287631564670139 ff\n",
      "layer3.4.conv3.weight 0.4782867431640625 ff\n",
      "layer3.5.conv1.weight 0.43045806884765625 ff\n",
      "layer3.5.conv2.weight 0.3486735026041667 ff\n",
      "layer3.5.conv3.weight 0.43045806884765625 ff\n",
      "layer4.0.conv1.weight 0.590484619140625 ff\n",
      "layer4.0.conv2.weight 0.5314398871527778 ff\n",
      "layer4.0.conv3.weight 0.47829437255859375 ff\n",
      "layer4.0.downsample.0.weight 0.3486771583557129 ff\n",
      "layer4.0.downsample.1.weight 1.0\n",
      "layer4.1.conv1.weight 0.3874177932739258 ff\n",
      "layer4.1.conv2.weight 0.47829564412434894 ff\n",
      "layer4.1.conv3.weight 0.80999755859375 ff\n",
      "layer4.2.conv1.weight 0.5314388275146484 ff\n",
      "layer4.2.conv2.weight 0.47829564412434894 ff\n",
      "layer4.2.conv3.weight 0.656097412109375 ff\n",
      "fc.weight 1.0\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(12227236, 25506752)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "total_nz = 0\n",
    "total = 0\n",
    "\n",
    "for n, p in modelp.named_parameters():\n",
    "    if \"weight\" not in n or \"bn\" in n:\n",
    "        continue\n",
    "    \n",
    "    if hasattr(p, \"facts\"):\n",
    "        ff = (p.facts[0] != 0).sum().item() + (p.facts[1] != 0).sum().item() #(p != 0).sum().item()\n",
    "        total_nz += ff\n",
    "        print(n, ff / p.numel(), \"ff\")\n",
    "    else:\n",
    "        total_nz += (p != 0).sum().item()\n",
    "        print(n, (p != 0).sum().item() / p.numel())\n",
    "    total += p.numel()\n",
    "    \n",
    "total_nz, total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "66a47ed7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "modelp.layer3[4].conv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "314eccbf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([256, 256]), torch.Size([256, 2304]))"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "modelp.layer3[4].conv2.weight.facts[0].shape, modelp.layer3[4].conv2.weight.facts[1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "ed541870",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(device(type='cuda', index=0), device(type='cuda', index=0))"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f1 = nn.Conv2d(256, 256, 1, bias=False)\n",
    "f2 = nn.Conv2d(256, 256, 3, padding=1, bias=False)\n",
    "f1.weight.data = modelp.layer3[4].conv2.weight.facts[0].reshape(f1.weight.shape)\n",
    "f2.weight.data = modelp.layer3[4].conv2.weight.facts[1].reshape(f2.weight.shape)\n",
    "f1 = f1.cuda()\n",
    "f2 = f2.cuda()\n",
    "f1.weight.device, f2.weight.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "617b70bc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0039, device='cuda:0', dtype=torch.float16, grad_fn=<MaxBackward1>)\n"
     ]
    }
   ],
   "source": [
    "xx = torch.randn(10,256,15,15).cuda()\n",
    "with torch.amp.autocast(\"cuda\"):\n",
    "    o1 = modelp.layer3[4].conv2(xx)\n",
    "    ot = f2(xx)\n",
    "    o2 = f1(ot)\n",
    "    print((o1 - o2).abs().max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "70a285d2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "layer1.0\n",
      "layer1.1\n",
      "layer1.2\n",
      "layer2.0\n",
      "layer2.1\n",
      "layer2.2\n",
      "layer2.3\n",
      "layer3.0\n",
      "layer3.1\n",
      "layer3.2\n",
      "layer3.3\n",
      "layer3.4\n",
      "layer3.5\n",
      "layer4.0\n",
      "layer4.1\n",
      "layer4.2\n"
     ]
    }
   ],
   "source": [
    "def boo(m, i, o):\n",
    "    print(\"boo\", i[0].shape)\n",
    "\n",
    "for n, m in modelp.named_modules():\n",
    "    if \"Bottleneck\" in str(type(m)):\n",
    "        print(n)\n",
    "        if hasattr(m, \"conv1b\"):\n",
    "            m.conv1 = m.conv1b\n",
    "        ff = m.conv1.weight.facts\n",
    "        m.conv1b = m.conv1\n",
    "        m.conv1 = nn.Sequential(\n",
    "            nn.Conv2d(m.conv1b.in_channels, m.conv1b.out_channels, 1, bias=False),\n",
    "            nn.Conv2d(m.conv1b.out_channels, m.conv1b.out_channels, 1, bias=False)\n",
    "        )\n",
    "        m.conv1[0].weight.data = ff[1].reshape(m.conv1[0].weight.shape)\n",
    "        m.conv1[1].weight.data = ff[0].reshape(m.conv1[1].weight.shape)\n",
    "        m.conv1.cuda()\n",
    "        \n",
    "        #print(n)\n",
    "        if hasattr(m, \"conv2b\"):\n",
    "            m.conv2 = m.conv2b\n",
    "        ff = m.conv2.weight.facts\n",
    "        m.conv2b = m.conv2\n",
    "        m.conv2 = nn.Sequential(\n",
    "            nn.Conv2d(m.conv2b.in_channels, m.conv2b.out_channels, 3, padding=1, stride=m.conv2b.stride, bias=False),\n",
    "            nn.Conv2d(m.conv2b.out_channels, m.conv2b.out_channels, 1, bias=False)\n",
    "        )\n",
    "        #m.conv2[0].register_forward_hook(boo)\n",
    "        m.conv2[0].weight.data = ff[1].reshape(m.conv2[0].weight.shape)\n",
    "        m.conv2[1].weight.data = ff[0].reshape(m.conv2[1].weight.shape)\n",
    "        m.conv2.cuda()\n",
    "        \n",
    "        if hasattr(m, \"conv3b\"):\n",
    "            m.conv3 = m.conv3b\n",
    "        ff = m.conv3.weight.facts\n",
    "        m.conv3b = m.conv3\n",
    "        m.conv3 = nn.Sequential(\n",
    "            nn.Conv2d(m.conv3b.in_channels, m.conv3b.in_channels, 1, bias=False),\n",
    "            nn.Conv2d(m.conv3b.in_channels, m.conv3b.out_channels, 1, bias=False)\n",
    "        )\n",
    "        m.conv3[0].weight.data = ff[1].reshape(m.conv3[0].weight.shape)\n",
    "        m.conv3[1].weight.data = ff[0].reshape(m.conv3[1].weight.shape)\n",
    "        m.conv3.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "177807d9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating ...\n",
      "75.56\n"
     ]
    }
   ],
   "source": [
    "test(modelp, testloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "32c55597",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "  (1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       ")"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "modelp.layer2[0].conv1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "bc956b1e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "  (1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       ")"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "modelp.layer2[0].conv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "08523486",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "  (1): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       ")"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "modelp.layer2[0].conv3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12ffb56e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
