{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3ed44999",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e57ab430",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from datautils import *\n",
    "from database import *\n",
    "from modelutils import *\n",
    "from quant import *\n",
    "import time\n",
    "from timm.optim import Lamb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e8fe2fac",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
     ]
    }
   ],
   "source": [
    "dataloader, testloader = get_loaders(\n",
    "    \"imagenet\", path=\"\",\n",
    "    batchsize=-1, workers=8,\n",
    "    nsamples=1024, seed=0,\n",
    "    noaug=False\n",
    ")\n",
    "get_model, test, run = get_functions(\"rn50\")\n",
    "modelp = get_model()\n",
    "model_orig = get_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fd1c695e",
   "metadata": {},
   "outputs": [],
   "source": [
    "db = SparsityDatabase(\"unstr\", \"rn50\", prefix='', dev='cpu')\n",
    "modelp = modelp.to('cpu')\n",
    "layersp = find_layers(modelp)\n",
    "with open(\"rn50_unstr_200x_dp.txt\", 'r') as f:\n",
    "    config = {}\n",
    "    for l in f.readlines():\n",
    "        level, name = l.strip().split(' ')\n",
    "        config[name] = level \n",
    "db.stitch(layersp, config)\n",
    "modelp = modelp.to(DEV)\n",
    "layersp = find_layers(modelp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "99802eb5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "conv1.weight 0.38732993197278914\n",
      "bn1.weight 1.0\n",
      "layer1.0.conv1.weight 0.984375\n",
      "layer1.0.bn1.weight 1.0\n",
      "layer1.0.conv2.weight 0.4304470486111111\n",
      "layer1.0.bn2.weight 1.0\n",
      "layer1.0.conv3.weight 0.254150390625\n",
      "layer1.0.bn3.weight 1.0\n",
      "layer1.0.downsample.0.weight 0.478271484375\n",
      "layer1.0.downsample.1.weight 1.0\n",
      "layer1.1.conv1.weight 0.3486328125\n",
      "layer1.1.bn1.weight 1.0\n",
      "layer1.1.conv2.weight 0.478271484375\n",
      "layer1.1.bn2.weight 1.0\n",
      "layer1.1.conv3.weight 0.5904541015625\n",
      "layer1.1.bn3.weight 1.0\n",
      "layer1.2.conv1.weight 0.12152099609375\n",
      "layer1.2.bn1.weight 1.0\n",
      "layer1.2.conv2.weight 0.2824164496527778\n",
      "layer1.2.bn2.weight 1.0\n",
      "layer1.2.conv3.weight 0.65606689453125\n",
      "layer1.2.bn3.weight 1.0\n",
      "layer2.0.conv1.weight 0.478271484375\n",
      "layer2.0.bn1.weight 1.0\n",
      "layer2.0.conv2.weight 0.5904880099826388\n",
      "layer2.0.bn2.weight 1.0\n",
      "layer2.0.conv3.weight 0.656097412109375\n",
      "layer2.0.bn3.weight 1.0\n",
      "layer2.0.downsample.0.weight 0.2058868408203125\n",
      "layer2.0.downsample.1.weight 1.0\n",
      "layer2.1.conv1.weight 0.4782867431640625\n",
      "layer2.1.bn1.weight 1.0\n",
      "layer2.1.conv2.weight 0.20588514539930555\n",
      "layer2.1.bn2.weight 1.0\n",
      "layer2.1.conv3.weight 0.109405517578125\n",
      "layer2.1.bn3.weight 1.0\n",
      "layer2.2.conv1.weight 0.2824249267578125\n",
      "layer2.2.bn1.weight 1.0\n",
      "layer2.2.conv2.weight 0.4304606119791667\n",
      "layer2.2.bn2.weight 1.0\n",
      "layer2.2.conv3.weight 0.348663330078125\n",
      "layer2.2.bn3.weight 1.0\n",
      "layer2.3.conv1.weight 0.07177734375\n",
      "layer2.3.bn1.weight 1.0\n",
      "layer2.3.conv2.weight 0.6560940212673612\n",
      "layer2.3.bn2.weight 1.0\n",
      "layer2.3.conv3.weight 0.656097412109375\n",
      "layer2.3.bn3.weight 1.0\n",
      "layer3.0.conv1.weight 0.656097412109375\n",
      "layer3.0.bn1.weight 1.0\n",
      "layer3.0.conv2.weight 0.2824283175998264\n",
      "layer3.0.bn2.weight 1.0\n",
      "layer3.0.conv3.weight 0.7289962768554688\n",
      "layer3.0.bn3.weight 1.0\n",
      "layer3.0.downsample.0.weight 0.7289981842041016\n",
      "layer3.0.downsample.1.weight 1.0\n",
      "layer3.1.conv1.weight 0.3138084411621094\n",
      "layer3.1.bn1.weight 1.0\n",
      "layer3.1.conv2.weight 0.3486768934461806\n",
      "layer3.1.bn2.weight 1.0\n",
      "layer3.1.conv3.weight 0.7289962768554688\n",
      "layer3.1.bn3.weight 1.0\n",
      "layer3.2.conv1.weight 0.5904884338378906\n",
      "layer3.2.bn1.weight 1.0\n",
      "layer3.2.conv2.weight 0.18530103895399305\n",
      "layer3.2.bn2.weight 1.0\n",
      "layer3.2.conv3.weight 0.5314407348632812\n",
      "layer3.2.bn3.weight 1.0\n",
      "layer3.3.conv1.weight 0.656097412109375\n",
      "layer3.3.bn1.weight 1.0\n",
      "layer3.3.conv2.weight 0.5314398871527778\n",
      "layer3.3.bn2.weight 1.0\n",
      "layer3.3.conv3.weight 0.47829437255859375\n",
      "layer3.3.bn3.weight 1.0\n",
      "layer3.4.conv1.weight 0.5314407348632812\n",
      "layer3.4.bn1.weight 1.0\n",
      "layer3.4.conv2.weight 0.2287665473090278\n",
      "layer3.4.bn2.weight 1.0\n",
      "layer3.4.conv3.weight 0.80999755859375\n",
      "layer3.4.bn3.weight 1.0\n",
      "layer3.5.conv1.weight 0.4304656982421875\n",
      "layer3.5.bn1.weight 1.0\n",
      "layer3.5.conv2.weight 0.4782952202690972\n",
      "layer3.5.bn2.weight 1.0\n",
      "layer3.5.conv3.weight 0.5904884338378906\n",
      "layer3.5.bn3.weight 1.0\n",
      "layer4.0.conv1.weight 0.5904884338378906\n",
      "layer4.0.bn1.weight 1.0\n",
      "layer4.0.conv2.weight 0.5314407348632812\n",
      "layer4.0.bn2.weight 1.0\n",
      "layer4.0.conv3.weight 0.8099994659423828\n",
      "layer4.0.bn3.weight 1.0\n",
      "layer4.0.downsample.0.weight 0.7289996147155762\n",
      "layer4.0.downsample.1.weight 1.0\n",
      "layer4.1.conv1.weight 0.6560993194580078\n",
      "layer4.1.bn1.weight 1.0\n",
      "layer4.1.conv2.weight 0.8099996778700087\n",
      "layer4.1.bn2.weight 1.0\n",
      "layer4.1.conv3.weight 1.0\n",
      "layer4.1.bn3.weight 1.0\n",
      "layer4.2.conv1.weight 0.728999137878418\n",
      "layer4.2.bn1.weight 1.0\n",
      "layer4.2.conv2.weight 0.6560999552408854\n",
      "layer4.2.bn2.weight 1.0\n",
      "layer4.2.conv3.weight 0.8099994659423828\n",
      "layer4.2.bn3.weight 1.0\n",
      "fc.weight 1.0\n"
     ]
    }
   ],
   "source": [
    "total_nz = 0\n",
    "\n",
    "for n, p in modelp.named_parameters():\n",
    "    if \"weight\" not in n:\n",
    "        continue\n",
    "    print(n, (p != 0).sum().item() / p.numel())\n",
    "    total_nz += (p != 0).sum().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1afe1ed7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating ...\n",
      "75.47\n"
     ]
    }
   ],
   "source": [
    "test(modelp, testloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bb66c0dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0\n",
      "0 1\n",
      "0 2\n",
      "0 3\n",
      "0 4\n",
      "0 5\n",
      "0 6\n",
      "0 7\n",
      "1 0\n",
      "1 1\n",
      "1 2\n",
      "1 3\n",
      "1 4\n",
      "1 5\n",
      "1 6\n",
      "1 7\n",
      "2 0\n",
      "2 1\n",
      "2 2\n",
      "2 3\n",
      "2 4\n",
      "2 5\n",
      "2 6\n",
      "2 7\n",
      "3 0\n",
      "3 1\n",
      "3 2\n",
      "3 3\n",
      "3 4\n",
      "3 5\n",
      "3 6\n",
      "3 7\n",
      "4 0\n",
      "4 1\n",
      "4 2\n",
      "4 3\n",
      "4 4\n",
      "4 5\n",
      "4 6\n",
      "4 7\n",
      "5 0\n",
      "5 1\n",
      "5 2\n",
      "5 3\n",
      "5 4\n",
      "5 5\n",
      "5 6\n",
      "5 7\n",
      "6 0\n",
      "6 1\n",
      "6 2\n",
      "6 3\n",
      "6 4\n",
      "6 5\n",
      "6 6\n",
      "6 7\n",
      "7 0\n",
      "7 1\n",
      "7 2\n",
      "7 3\n",
      "7 4\n",
      "7 5\n",
      "7 6\n",
      "7 7\n",
      "8 0\n",
      "8 1\n",
      "8 2\n",
      "8 3\n",
      "8 4\n",
      "8 5\n",
      "8 6\n",
      "8 7\n",
      "9 0\n",
      "9 1\n",
      "9 2\n",
      "9 3\n",
      "9 4\n",
      "9 5\n",
      "9 6\n",
      "9 7\n"
     ]
    }
   ],
   "source": [
    "handles = []\n",
    "\n",
    "def add_batch(layer, inp, out):\n",
    "    layer.batches = [(inp[0].detach(), out.detach())]\n",
    "    X = inp[0].detach().float()\n",
    "    #print(X.shape)\n",
    "    #assert X.shape[2] == 1\n",
    "    # TODO: unfold\n",
    "    #X = X.permute(0, 2, 3, 1)\n",
    "    #X = X.reshape(-1, X.shape[-1])\n",
    "    if isinstance(layer, nn.Conv2d):\n",
    "        unfold = nn.Unfold(\n",
    "            layer.kernel_size,\n",
    "            dilation=layer.dilation,\n",
    "            padding=layer.padding,\n",
    "            stride=layer.stride\n",
    "        )\n",
    "        X = unfold(X)\n",
    "        X = X.permute([1, 0, 2])\n",
    "        X = X.flatten(1)\n",
    "    layer.XX += X.matmul(X.T)\n",
    "\n",
    "for n, m in model_orig.named_modules():\n",
    "    if type(m) == nn.Conv2d:\n",
    "        Wf = m.weight.flatten(1)\n",
    "        m.XX = torch.zeros(Wf.shape[1], Wf.shape[1], device=m.weight.device)\n",
    "        handles.append(m.register_forward_hook(add_batch))\n",
    "        \n",
    "for i in range(10):\n",
    "    for j, batch in enumerate(dataloader):\n",
    "        print(i, j)\n",
    "        with torch.no_grad():\n",
    "            run(model_orig, batch)\n",
    "        \n",
    "for h in handles:\n",
    "    h.remove()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "17c99a91",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_other2(A, W, nnz, Z, U, print_sc=None, debug=False, reg=0, rho_start=0.03, iters=5, prune_iters=2):\n",
    "    XX = A.T.matmul(A)\n",
    "    norm2 = torch.diag(XX).sqrt() + 1e-8\n",
    "    An = A / norm2\n",
    "    XX = An.T.matmul(An)\n",
    "    XX += torch.diag(torch.ones_like(XX.diag())) * XX.diag().mean() * reg\n",
    "    \n",
    "    #norm2 = torch.ones_like(norm2)\n",
    "    Wnn = W# * norm2.unsqueeze(1)\n",
    "    rho = 1\n",
    "    XY = An.T.matmul(Wnn)\n",
    "    XXinv = torch.inverse(XX + torch.eye(XX.shape[1], device=XX.device)*rho)\n",
    "    XXinv2 = torch.inverse(XX + torch.eye(XX.shape[1], device=XX.device)*rho_start)\n",
    "    U = U * norm2.unsqueeze(1)\n",
    "    Z = Z * norm2.unsqueeze(1)\n",
    "    \n",
    "    #B = torch.linalg.solve(XX, XY)\n",
    "    B = XXinv2.matmul(XY + rho_start*(Z-U))\n",
    "    \n",
    "    #U = torch.zeros_like(B)\n",
    "    \n",
    "    #Z = B\n",
    "    \n",
    "    bsparsity = min(0.99, 1 - nnz/B.numel())\n",
    "    #print(\"bs\", bsparsity)\n",
    "\n",
    "\n",
    "    for itt in range(iters):\n",
    "        if itt < prune_iters:\n",
    "            cur_sparsity = bsparsity# - bsparsity * (1 - (itt + 1) / iterative_prune) ** 3\n",
    "            thres = (B+U).abs().flatten().sort()[0][int(B.numel() * cur_sparsity)]\n",
    "            mask = ((B+U).abs() > thres)\n",
    "            del thres\n",
    "\n",
    "        Z = (B + U) * mask    \n",
    "\n",
    "        U = U + (B - Z)    \n",
    "\n",
    "        B = XXinv.matmul(XY + rho*(Z-U))\n",
    "        #B = torch.linalg.solve(XX + torch.eye(XX.shape[1], device=XX.device)*rho, XY + rho*(Z-U))\n",
    "        if debug:\n",
    "            print(itt, cur_sparsity, (Z != 0).sum().item() / Z.numel())\n",
    "            print_sc(A.matmul(B / norm2.unsqueeze(1)))\n",
    "            print_sc(A.matmul(Z / norm2.unsqueeze(1)))\n",
    "            print(((An != 0).sum() + (Z != 0).sum()) / W.numel())\n",
    "            print(\"-------\")\n",
    "    if debug:\n",
    "        print(\"opt end\")\n",
    "\n",
    "    return Z / norm2.unsqueeze(1), U / norm2.unsqueeze(1)    \n",
    "    \n",
    "def mag_prune(W, sp=0.6):\n",
    "    thres = (W).abs().flatten().sort()[0][int(W.numel() * sp)]\n",
    "    mask = ((W).abs() > thres)\n",
    "    return W * mask\n",
    "\n",
    "def ent(p):\n",
    "    return -(p * np.log2(p) + (1-p) * np.log2(1-p))\n",
    "\n",
    "def factorizeT(W, XX, asp=0.16, sp=0.4, iters=40):\n",
    "    #W = lx.weight.detach().T.float()\n",
    "    nza = int(W.shape[0]**2 * asp)\n",
    "    nzb = int(W.numel() * sp - nza)\n",
    "    \n",
    "    Az = torch.eye(W.shape[0], device=W.device)\n",
    "    Au = torch.zeros_like(Az)\n",
    "    norm = XX.diag().sqrt().unsqueeze(1) + 1e-8\n",
    "    norm = torch.ones_like(norm)\n",
    "       \n",
    "    Wn = W * norm\n",
    "       \n",
    "    Bz = mag_prune(Wn, (1 - nzb/2/W.numel()))\n",
    "    Bu = torch.zeros_like(Bz)\n",
    "    \n",
    "    for itt in range(iters):\n",
    "        #if itt < 10:\n",
    "        #    rho_start = 0.0\n",
    "        #elif itt < 15:\n",
    "        #    rho_start = 0.00\n",
    "        #else:\n",
    "        #    rho_start = 0.1\n",
    "        rho_start = min(1.0, itt / (iters-3))**3\n",
    "        Az, Au = (x.T for x in find_other2(Bz.T, Wn.T, nza, Az.T, Au.T, reg=1e-2, debug=False, rho_start=rho_start))\n",
    "                \n",
    "        Bz, Bu = find_other2(Az, Wn, nzb, Bz, Bu, reg=1e-2, debug=False, rho_start=rho_start)\n",
    "    \n",
    "    #print(((Az != 0).sum() + (Bz != 0).sum()).item() / W.numel(), (Az != 0).sum().item() / Az.numel(),\n",
    "    #      (Bz != 0).sum().item() / Bz.numel(), Az.shape, Bz.shape,\n",
    "    #     (Az.numel()*ent((Az != 0).sum().item() / Az.numel()) + Bz.numel()*ent((Bz != 0).sum().item() / Bz.numel())) / W.numel(), \n",
    "    #    ent(0.4), ent(0.5))\n",
    "    return ((Az / norm).matmul(Bz)).T, Bz.T, (Az / norm).T\n",
    "\n",
    "\n",
    "def factorizef(W, XX, asp=0.16, sp=0.4, iters=200, l_prev=None):\n",
    "    s_time = time.time()\n",
    "    if W.shape[0] >= W.shape[1]:\n",
    "        return factorizeT(W.T, XX, sp=sp, asp=asp, iters=iters)\n",
    "    \n",
    "    nza = int(W.shape[0]**2 * asp)\n",
    "    nzb = int(W.numel() * sp - nza)\n",
    "    norm = XX.diag().sqrt() + 1e-8\n",
    "    norm = torch.ones_like(norm)\n",
    "\n",
    "    Wn = W * norm\n",
    "    \n",
    "    Az = torch.eye(W.shape[0], device=W.device)\n",
    "    Au = torch.zeros_like(Az)\n",
    "\n",
    "    Bz = mag_prune(Wn, (1 - nzb/2/W.numel()))\n",
    "    Bu = torch.zeros_like(Bz)\n",
    "    \n",
    "    for itt in range(iters):\n",
    "        #if itt < 10:\n",
    "        #    rho_start = 0.0\n",
    "        #elif itt < 15:\n",
    "        #    rho_start = 0.00\n",
    "        #else:\n",
    "        #    rho_start = 0.1\n",
    "            \n",
    "        rho_start = min(1.0, itt / (iters-3))**3\n",
    "        Az, Au = (x.T for x in find_other2(Bz.T, Wn.T, nza, Az.T, Au.T, reg=1e-2, debug=False, rho_start=rho_start))\n",
    "                \n",
    "        Bz, Bu = find_other2(Az, Wn, nzb, Bz, Bu, reg=1e-2, debug=False, rho_start=rho_start)\n",
    "        \n",
    "        #print(itt, time.time() - s_time, end =\" \") \n",
    "        #print_scores(Az.matmul(Bz / norm))\n",
    "        \n",
    "        \n",
    "    #print(((Az != 0).sum() + (Bz != 0).sum()).item() / W.numel(), (Az != 0).sum().item() / Az.numel(),\n",
    "    #      (Bz != 0).sum().item() / Bz.numel(), Az.shape, Bz.shape,\n",
    "    #     (Az.numel()*ent((Az != 0).sum().item() / Az.numel()) + Bz.numel()*ent((Bz != 0).sum().item() / Bz.numel())) / W.numel(), \n",
    "    #    ent(0.4), ent(0.5))\n",
    "    return Az.matmul(Bz / norm), Az, Bz / norm\n",
    "\n",
    "def finalize(XXb, W, Ab, Bb):\n",
    "    fsparsity = 1 - (Ab != 0).sum() / Ab.numel()\n",
    "    mask = (Ab != 0).T\n",
    "\n",
    "    XX = Bb.matmul(XXb).matmul(Bb.T)\n",
    "    XY = Bb.matmul(XXb).matmul(W.detach().float().T)\n",
    "\n",
    "    norm2 = torch.diag(XX).sqrt() + 1e-8\n",
    "    XX = XX / norm2 / norm2.unsqueeze(1)\n",
    "    XY = XY / norm2.unsqueeze(1)\n",
    "    Ax = (Ab * norm2).T.clone()\n",
    "    #Ax = torch.linalg.solve(XX, XY)\n",
    "\n",
    "    rho = 1\n",
    "    XXinv = torch.inverse(XX + torch.eye(XX.shape[1], device=XX.device)*rho)\n",
    "    U = torch.zeros_like(Ax)\n",
    "    for itt in range(200):\n",
    "        #if itt < 150:\n",
    "        #    cur_sparsity = fsparsity - fsparsity * (1 - (itt + 1) / 150) ** 3\n",
    "        #    thres = (Ax+U).abs().flatten().sort()[0][int(Ax.numel() * cur_sparsity)]\n",
    "        #    mask = ((Ax+U).abs() > thres)\n",
    "        #    del thres\n",
    "\n",
    "        \n",
    "        Z = (Ax + U) * mask    \n",
    "\n",
    "        U = U + (Ax - Z)    \n",
    "\n",
    "        Ax = XXinv.matmul(XY + rho*(Z-U))\n",
    "\n",
    "    Ac = Z.T / norm2\n",
    "    return Ac\n",
    "\n",
    "def find_a(B, Za, Ua, rho, D, Q, E, R, XX, W):\n",
    "    F = rho*(Za-Ua) + XX.matmul(W).matmul(B.T)\n",
    "    \n",
    "    right = Q.T.matmul(F).matmul(R)\n",
    "    \n",
    "    div = D.unsqueeze(1).matmul(E.unsqueeze(0)) + rho\n",
    "    QAR = right / div\n",
    "    \n",
    "    A3 = Q.matmul(QAR).matmul(R.T)\n",
    "    return A3\n",
    "\n",
    "\n",
    "def get_at(XX, W, A, B):\n",
    "    mask = (A != 0)\n",
    "    \n",
    "    norm2 = torch.diag(XX).sqrt() + 1e-8\n",
    "    XXn = XX / norm2 / norm2.unsqueeze(1)\n",
    "    \n",
    "    #XXn += torch.diag(XXn.diag()*0 + 0.01*XXn.diag().mean())\n",
    "    \n",
    "    Wn = W * norm2.unsqueeze(1)\n",
    "    #XY = XY / norm2.unsqueeze(1)\n",
    "    \n",
    "    normB = torch.norm(B, dim=1) + 1e-8\n",
    "    Bn = B / normB.unsqueeze(1)\n",
    "    BBn = Bn.matmul(Bn.T)\n",
    "    #BBn += torch.diag(BBn.diag()*0 + 0.01*BBn.diag().mean())\n",
    "    #print(BBn.diag())\n",
    "    \n",
    "    D, Q = torch.linalg.eigh(XXn)\n",
    "    E, R = torch.linalg.eigh(BBn)\n",
    "    \n",
    "    #print(D, E)\n",
    "    \n",
    "    Za = A * norm2.unsqueeze(1) * normB\n",
    "    Ua = torch.zeros_like(Za)\n",
    "    rho = 1\n",
    "    \n",
    "    for itt in range(20):\n",
    "        A2 = find_a(Bn, Za, Ua, rho, D, Q, E, R, XXn, Wn)\n",
    "        Wx = (A2 / norm2.unsqueeze(1) / normB).matmul(B)\n",
    "        #print(itt)\n",
    "        #print(\"   errx\", (Wx - W).T.matmul(XX).matmul((Wx - W)).diag().sum().item())\n",
    "        \n",
    "        Za = (A2 + Ua) * mask\n",
    "        Ua = Ua + (A2 - Za)\n",
    "        Wx = (Za / norm2.unsqueeze(1) / normB).matmul(B)\n",
    "        #print(\"   errz\", (Wx - W).T.matmul(XX).matmul((Wx - W)).diag().sum().item())\n",
    "    return Za / norm2.unsqueeze(1) / normB\n",
    "\n",
    "def factorize(XX, W, sp, l_prev=None):\n",
    "    W = W.detach().float()\n",
    "    asp = max(0.16, sp/2)\n",
    "    W2, Ab, Bb = factorizef(W, XX, sp=sp, asp=asp, l_prev=l_prev)\n",
    "    print(\"err_prefin\", (W2 - W).matmul(XX).matmul((W2 - W).T).diag().sum().item())\n",
    "    Ac = finalize(XX, W, Ab, Bb)\n",
    "    W3 = Ac.matmul(Bb)\n",
    "    assert W3.shape == W.shape\n",
    "    print(\"err_fin   \", (W3 - W).matmul(XX).matmul((W3 - W).T).diag().sum().item())\n",
    "    #fin_b(XX, W, Ac, Bb)\n",
    "    \n",
    "    Bc = get_at(XX, W.T, Bb.T, Ac.T).T\n",
    "    \n",
    "    W4 = Ac.matmul(Bc)\n",
    "    assert W3.shape == W.shape\n",
    "    print(\"err_fin2   \", (W4 - W).matmul(XX).matmul((W4 - W).T).diag().sum().item())\n",
    "    \n",
    "    print(\"sparsity check\", ((Ac != 0).sum() + (Bc != 0).sum()).item() / W3.numel())\n",
    "    return W4, (Ac.cpu(), Bc.cpu())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3cd37466",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "err_prefin 12655.3623046875\n",
      "err_fin    444.36785888671875\n",
      "err_fin2    348.65228271484375\n",
      "sparsity check 0.98388671875\n",
      "layer1.0.conv1 0.984375 torch.Size([64, 64, 1, 1]) 0.0 348.65228271484375 bad\n",
      "err_prefin 354029.125\n",
      "err_fin    49131.17578125\n",
      "err_fin2    24805.939453125\n",
      "sparsity check 0.4303927951388889\n",
      "layer1.0.conv2 0.4304470486111111 torch.Size([64, 64, 3, 3]) 21052.04296875 24805.939453125 bad\n",
      "err_prefin 1210875.5\n",
      "err_fin    226695.09375\n",
      "err_fin2    222362.5\n",
      "sparsity check 0.2540283203125\n",
      "layer1.0.conv3 0.254150390625 torch.Size([256, 64, 1, 1]) 155570.5625 222362.5 bad\n",
      "err_prefin 1489075.625\n",
      "err_fin    154516.40625\n",
      "err_fin2    152525.078125\n",
      "sparsity check 0.4781494140625\n",
      "layer1.0.downsample.0 0.478271484375 torch.Size([256, 64, 1, 1]) 220369.59375 152525.078125 \n",
      "err_prefin 547405.5\n",
      "err_fin    181329.84375\n",
      "err_fin2    70531.9765625\n",
      "sparsity check 0.3485107421875\n",
      "layer1.1.conv1 0.3486328125 torch.Size([64, 256, 1, 1]) 69011.8203125 70531.9765625 bad\n",
      "err_prefin 334649.375\n",
      "err_fin    113771.171875\n",
      "err_fin2    45417.421875\n",
      "sparsity check 0.4782172309027778\n",
      "layer1.1.conv2 0.478271484375 torch.Size([64, 64, 3, 3]) 46829.203125 45417.421875 \n",
      "err_prefin 35330.3359375\n",
      "err_fin    12040.1416015625\n",
      "err_fin2    11934.857421875\n",
      "sparsity check 0.59033203125\n",
      "layer1.1.conv3 0.5904541015625 torch.Size([256, 64, 1, 1]) 13709.0185546875 11934.857421875 \n",
      "err_prefin 6091756.0\n",
      "err_fin    2304855.0\n",
      "err_fin2    1807763.75\n",
      "sparsity check 0.12139892578125\n",
      "layer1.2.conv1 0.12152099609375 torch.Size([64, 256, 1, 1]) 1535052.375 1807763.75 bad\n",
      "err_prefin 2296895.0\n",
      "err_fin    1450545.0\n",
      "err_fin2    744670.0\n",
      "sparsity check 0.2823621961805556\n",
      "layer1.2.conv2 0.2824164496527778 torch.Size([64, 64, 3, 3]) 659578.75 744670.0 bad\n",
      "err_prefin 14213.068359375\n",
      "err_fin    7908.6484375\n",
      "err_fin2    7851.63671875\n",
      "sparsity check 0.65594482421875\n",
      "layer1.2.conv3 0.65606689453125 torch.Size([256, 64, 1, 1]) 8256.5947265625 7851.63671875 \n",
      "err_prefin 737360.1875\n",
      "err_fin    191887.65625\n",
      "err_fin2    116754.9375\n",
      "sparsity check 0.47821044921875\n",
      "layer2.0.conv1 0.478271484375 torch.Size([128, 256, 1, 1]) 256275.40625 116754.9375 \n",
      "err_prefin 86283.296875\n",
      "err_fin    45027.61328125\n",
      "err_fin2    21936.904296875\n",
      "sparsity check 0.5904744466145834\n",
      "layer2.0.conv2 0.5904880099826388 torch.Size([128, 128, 3, 3]) 40280.140625 21936.904296875 \n",
      "err_prefin 10868.2353515625\n",
      "err_fin    3316.9794921875\n",
      "err_fin2    3273.54541015625\n",
      "sparsity check 0.65606689453125\n",
      "layer2.0.conv3 0.656097412109375 torch.Size([512, 128, 1, 1]) 7861.669921875 3273.54541015625 \n",
      "err_prefin 1149402.75\n",
      "err_fin    307029.75\n",
      "err_fin2    281336.1875\n",
      "sparsity check 0.20587158203125\n",
      "layer2.0.downsample.0 0.2058868408203125 torch.Size([512, 256, 1, 1]) 534038.875 281336.1875 \n",
      "err_prefin 17991.56640625\n",
      "err_fin    6085.1357421875\n",
      "err_fin2    3241.47119140625\n",
      "sparsity check 0.4782562255859375\n",
      "layer2.1.conv1 0.4782867431640625 torch.Size([128, 512, 1, 1]) 7349.62890625 3241.47119140625 \n",
      "err_prefin 993926.8125\n",
      "err_fin    159554.5\n",
      "err_fin2    84001.703125\n",
      "sparsity check 0.20587158203125\n",
      "layer2.1.conv2 0.20588514539930555 torch.Size([128, 128, 3, 3]) 48617.76953125 84001.703125 bad\n",
      "err_prefin 510674.8125\n",
      "err_fin    338430.53125\n",
      "err_fin2    326936.375\n",
      "sparsity check 0.109375\n",
      "layer2.1.conv3 0.109405517578125 torch.Size([512, 128, 1, 1]) 274602.25 326936.375 bad\n",
      "err_prefin 805693.625\n",
      "err_fin    344825.5625\n",
      "err_fin2    202001.84375\n",
      "sparsity check 0.2823944091796875\n",
      "layer2.2.conv1 0.2824249267578125 torch.Size([128, 512, 1, 1]) 222112.25 202001.84375 \n",
      "err_prefin 320674.375\n",
      "err_fin    138340.796875\n",
      "err_fin2    67606.828125\n",
      "sparsity check 0.4304470486111111\n",
      "layer2.2.conv2 0.4304606119791667 torch.Size([128, 128, 3, 3]) 85943.8828125 67606.828125 \n",
      "err_prefin 229127.96875\n",
      "err_fin    107335.0\n",
      "err_fin2    105811.3125\n",
      "sparsity check 0.3486328125\n",
      "layer2.2.conv3 0.348663330078125 torch.Size([512, 128, 1, 1]) 197878.90625 105811.3125 \n",
      "err_prefin 7027040.0\n",
      "err_fin    3226210.0\n",
      "err_fin2    2965640.5\n",
      "sparsity check 0.071746826171875\n",
      "layer2.3.conv1 0.07177734375 torch.Size([128, 512, 1, 1]) 2963943.75 2965640.5 bad\n",
      "err_prefin 46404.70703125\n",
      "err_fin    23681.744140625\n",
      "err_fin2    12207.75390625\n",
      "sparsity check 0.6560872395833334\n",
      "layer2.3.conv2 0.6560940212673612 torch.Size([128, 128, 3, 3]) 25688.31640625 12207.75390625 \n",
      "err_prefin 14176.36328125\n",
      "err_fin    7396.99609375\n",
      "err_fin2    7317.49072265625\n",
      "sparsity check 0.65606689453125\n",
      "layer2.3.conv3 0.656097412109375 torch.Size([512, 128, 1, 1]) 18818.65625 7317.49072265625 \n",
      "err_prefin 92548.4296875\n",
      "err_fin    37069.046875\n",
      "err_fin2    26354.83984375\n",
      "sparsity check 0.6560821533203125\n",
      "layer3.0.conv1 0.656097412109375 torch.Size([256, 512, 1, 1]) 100137.265625 26354.83984375 \n",
      "err_prefin 523231.0\n",
      "err_fin    338168.75\n",
      "err_fin2    227764.03125\n",
      "sparsity check 0.2824249267578125\n",
      "layer3.0.conv2 0.2824283175998264 torch.Size([256, 256, 3, 3]) 312430.875 227764.03125 \n",
      "err_prefin 6988.478515625\n",
      "err_fin    1747.9176025390625\n",
      "err_fin2    1720.9501953125\n",
      "sparsity check 0.7289886474609375\n",
      "layer3.0.conv3 0.7289962768554688 torch.Size([1024, 256, 1, 1]) 6336.5849609375 1720.9501953125 \n",
      "err_prefin 7308.271484375\n",
      "err_fin    1943.276123046875\n",
      "err_fin2    1875.6319580078125\n",
      "sparsity check 0.7289943695068359\n",
      "layer3.0.downsample.0 0.7289981842041016 torch.Size([1024, 512, 1, 1]) 8350.31640625 1875.6319580078125 \n",
      "err_prefin 133787.984375\n",
      "err_fin    61663.75\n",
      "err_fin2    35690.6875\n",
      "sparsity check 0.3138008117675781\n",
      "layer3.1.conv1 0.3138084411621094 torch.Size([256, 1024, 1, 1]) 58538.73828125 35690.6875 \n",
      "err_prefin 211113.359375\n",
      "err_fin    106078.6953125\n",
      "err_fin2    58835.4453125\n",
      "sparsity check 0.3486735026041667\n",
      "layer3.1.conv2 0.3486768934461806 torch.Size([256, 256, 3, 3]) 77035.6875 58835.4453125 \n",
      "err_prefin 3732.7685546875\n",
      "err_fin    1528.9849853515625\n",
      "err_fin2    1505.3184814453125\n",
      "sparsity check 0.7289886474609375\n",
      "layer3.1.conv3 0.7289962768554688 torch.Size([1024, 256, 1, 1]) 5140.01171875 1505.3184814453125 \n",
      "err_prefin 15862.17578125\n",
      "err_fin    7437.0322265625\n",
      "err_fin2    4555.73095703125\n",
      "sparsity check 0.5904808044433594\n",
      "layer3.2.conv1 0.5904884338378906 torch.Size([256, 1024, 1, 1]) 11982.296875 4555.73095703125 \n",
      "err_prefin 574693.875\n",
      "err_fin    309389.75\n",
      "err_fin2    215710.5\n",
      "sparsity check 0.18529764811197916\n",
      "layer3.2.conv2 0.18530103895399305 torch.Size([256, 256, 3, 3]) 226082.46875 215710.5 \n",
      "err_prefin 18831.3515625\n",
      "err_fin    9512.8388671875\n",
      "err_fin2    9393.466796875\n",
      "sparsity check 0.53143310546875\n",
      "layer3.2.conv3 0.5314407348632812 torch.Size([1024, 256, 1, 1]) 26943.498046875 9393.466796875 \n",
      "err_prefin 13032.3662109375\n",
      "err_fin    6690.4755859375\n",
      "err_fin2    4279.998046875\n",
      "sparsity check 0.6560897827148438\n",
      "layer3.3.conv1 0.656097412109375 torch.Size([256, 1024, 1, 1]) 12841.90625 4279.998046875 \n",
      "err_prefin 36921.03125\n",
      "err_fin    21927.96484375\n",
      "err_fin2    13196.923828125\n",
      "sparsity check 0.5314364963107638\n",
      "layer3.3.conv2 0.5314398871527778 torch.Size([256, 256, 3, 3]) 25136.2265625 13196.923828125 \n",
      "err_prefin 24312.3046875\n",
      "err_fin    13796.65234375\n",
      "err_fin2    13658.080078125\n",
      "sparsity check 0.4782867431640625\n",
      "layer3.3.conv3 0.47829437255859375 torch.Size([1024, 256, 1, 1]) 34852.06640625 13658.080078125 \n",
      "err_prefin 48055.1953125\n",
      "err_fin    28429.37109375\n",
      "err_fin2    20127.453125\n",
      "sparsity check 0.53143310546875\n",
      "layer3.4.conv1 0.5314407348632812 torch.Size([256, 1024, 1, 1]) 50097.734375 20127.453125 \n",
      "err_prefin 376362.03125\n",
      "err_fin    219965.09375\n",
      "err_fin2    158764.53125\n",
      "sparsity check 0.2287631564670139\n",
      "layer3.4.conv2 0.2287665473090278 torch.Size([256, 256, 3, 3]) 186770.6875 158764.53125 \n",
      "err_prefin 918.0406494140625\n",
      "err_fin    413.91326904296875\n",
      "err_fin2    406.43609619140625\n",
      "sparsity check 0.8099899291992188\n",
      "layer3.4.conv3 0.80999755859375 torch.Size([1024, 256, 1, 1]) 1237.635498046875 406.43609619140625 \n",
      "err_prefin 128751.1875\n",
      "err_fin    82834.984375\n",
      "err_fin2    63508.4296875\n",
      "sparsity check 0.43045806884765625\n",
      "layer3.5.conv1 0.4304656982421875 torch.Size([256, 1024, 1, 1]) 129100.71875 63508.4296875 \n",
      "err_prefin 60773.03125\n",
      "err_fin    37502.5078125\n",
      "err_fin2    24327.388671875\n",
      "sparsity check 0.4782918294270833\n",
      "layer3.5.conv2 0.4782952202690972 torch.Size([256, 256, 3, 3]) 45262.74609375 24327.388671875 \n",
      "err_prefin 13072.6982421875\n",
      "err_fin    6469.1328125\n",
      "err_fin2    6396.455078125\n",
      "sparsity check 0.5904808044433594\n",
      "layer3.5.conv3 0.5904884338378906 torch.Size([1024, 256, 1, 1]) 18919.142578125 6396.455078125 \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "err_prefin 51164.5703125\n",
      "err_fin    31196.986328125\n",
      "err_fin2    26808.36328125\n",
      "sparsity check 0.590484619140625\n",
      "layer4.0.conv1 0.5904884338378906 torch.Size([512, 1024, 1, 1]) 115178.9765625 26808.36328125 \n",
      "err_prefin 18906.97265625\n",
      "err_fin    13446.25390625\n",
      "err_fin2    9080.3759765625\n",
      "sparsity check 0.5314398871527778\n",
      "layer4.0.conv2 0.5314407348632812 torch.Size([512, 512, 3, 3]) 23763.48046875 9080.3759765625 \n",
      "err_prefin 1069.96875\n",
      "err_fin    379.2816162109375\n",
      "err_fin2    373.0364990234375\n",
      "sparsity check 0.80999755859375\n",
      "layer4.0.conv3 0.8099994659423828 torch.Size([2048, 512, 1, 1]) 1548.8809814453125 373.0364990234375 \n",
      "err_prefin 1953.743896484375\n",
      "err_fin    930.8853759765625\n",
      "err_fin2    902.2492065429688\n",
      "sparsity check 0.7289986610412598\n",
      "layer4.0.downsample.0 0.7289996147155762 torch.Size([2048, 1024, 1, 1]) 4873.19580078125 902.2492065429688 \n",
      "err_prefin 54108.1640625\n",
      "err_fin    22543.6640625\n",
      "err_fin2    13911.64453125\n",
      "sparsity check 0.656097412109375\n",
      "layer4.1.conv1 0.6560993194580078 torch.Size([512, 2048, 1, 1]) 46442.94140625 13911.64453125 \n",
      "err_prefin 1275.6796875\n",
      "err_fin    748.1761474609375\n",
      "err_fin2    450.983154296875\n",
      "sparsity check 0.8099992540147569\n",
      "layer4.1.conv2 0.8099996778700087 torch.Size([512, 512, 3, 3]) 1407.343505859375 450.983154296875 \n",
      "err_prefin 91.6263427734375\n",
      "err_fin    10.563888549804688\n",
      "err_fin2    10.082448959350586\n",
      "sparsity check 0.9999980926513672\n",
      "layer4.1.conv3 1.0 torch.Size([2048, 512, 1, 1]) 0.0 10.082448959350586 bad\n",
      "err_prefin 89887.8203125\n",
      "err_fin    34042.484375\n",
      "err_fin2    18228.36328125\n",
      "sparsity check 0.7289972305297852\n",
      "layer4.2.conv1 0.728999137878418 torch.Size([512, 2048, 1, 1]) 61554.19140625 18228.36328125 \n",
      "err_prefin 4044.995849609375\n",
      "err_fin    2933.62548828125\n",
      "err_fin2    1780.184814453125\n",
      "sparsity check 0.6560991075303819\n",
      "layer4.2.conv2 0.6560999552408854 torch.Size([512, 512, 3, 3]) 7937.763671875 1780.184814453125 \n",
      "err_prefin 854.5765380859375\n",
      "err_fin    268.6675109863281\n",
      "err_fin2    263.4511413574219\n",
      "sparsity check 0.80999755859375\n",
      "layer4.2.conv3 0.8099994659423828 torch.Size([2048, 512, 1, 1]) 1210.29736328125 263.4511413574219 \n"
     ]
    }
   ],
   "source": [
    "sd_pruned = modelp.state_dict()\n",
    "out_admm = {}\n",
    "\n",
    "for n, m in model_orig.named_modules():\n",
    "    if type(m) == nn.Conv2d and m.weight.shape[1] > 3:\n",
    "        w_pruned = sd_pruned[n+\".weight\"].flatten(1)\n",
    "        sparsity = (w_pruned != 0).sum().item() / w_pruned.numel()\n",
    "        w_orig = m.weight.flatten(1)\n",
    "        w_admm, facts = factorize(m.XX, w_orig, sparsity)\n",
    "        e1 = (w_orig - w_pruned).matmul(m.XX).matmul((w_orig - w_pruned).T).diag().sum().item()\n",
    "        e2 = (w_orig - w_admm).matmul(m.XX).matmul((w_orig - w_admm).T).diag().sum().item()\n",
    "        print(n, sparsity, m.weight.shape, \n",
    "              e1,\n",
    "              e2,\n",
    "              \"bad\" if e1 < e2 else \"\"\n",
    "             )\n",
    "        out_admm[n] = (w_admm.reshape(w_pruned.shape), facts)\n",
    "        #m.XX = None\n",
    "        \n",
    "for n, m in modelp.named_modules():\n",
    "    if n in out_admm:\n",
    "        m.weight.data = out_admm[n][0].reshape(m.weight.shape)\n",
    "        m.weight.facts = out_admm[n][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "99b05797",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating ...\n",
      "75.65\n"
     ]
    }
   ],
   "source": [
    "test(modelp, testloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8b943db1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batchnorm tuning ...\n",
      "0.7961734533309937\n",
      "000\n",
      "001\n",
      "002\n",
      "003\n",
      "004\n",
      "005\n",
      "006\n",
      "007\n",
      "008\n",
      "009\n",
      "010\n",
      "011\n",
      "012\n",
      "013\n",
      "014\n",
      "015\n",
      "016\n",
      "017\n",
      "018\n",
      "019\n",
      "020\n",
      "021\n",
      "022\n",
      "023\n",
      "024\n",
      "025\n",
      "026\n",
      "027\n",
      "028\n",
      "029\n",
      "030\n",
      "031\n",
      "032\n",
      "033\n",
      "034\n",
      "035\n",
      "036\n",
      "037\n",
      "038\n",
      "039\n",
      "040\n",
      "041\n",
      "042\n",
      "043\n",
      "044\n",
      "045\n",
      "046\n",
      "047\n",
      "048\n",
      "049\n",
      "050\n",
      "051\n",
      "052\n",
      "053\n",
      "054\n",
      "055\n",
      "056\n",
      "057\n",
      "058\n",
      "059\n",
      "060\n",
      "061\n",
      "062\n",
      "063\n",
      "064\n",
      "065\n",
      "066\n",
      "067\n",
      "068\n",
      "069\n",
      "070\n",
      "071\n",
      "072\n",
      "073\n",
      "074\n",
      "075\n",
      "076\n",
      "077\n",
      "078\n",
      "079\n",
      "080\n",
      "081\n",
      "082\n",
      "083\n",
      "084\n",
      "085\n",
      "086\n",
      "087\n",
      "088\n",
      "089\n",
      "090\n",
      "091\n",
      "092\n",
      "093\n",
      "094\n",
      "095\n",
      "096\n",
      "097\n",
      "098\n",
      "099\n",
      "0.7827620953321457\n"
     ]
    }
   ],
   "source": [
    "print('Batchnorm tuning ...')\n",
    "\n",
    "loss = 0\n",
    "with torch.no_grad():\n",
    "    for batch in dataloader:\n",
    "        loss += run(modelp, batch, loss=True)\n",
    "print(loss / 1024)\n",
    "\n",
    "batchnorms = find_layers(modelp, [nn.BatchNorm2d])\n",
    "for bn in batchnorms.values():\n",
    "    bn.reset_running_stats()\n",
    "    bn.momentum = .1\n",
    "modelp.train()\n",
    "with torch.no_grad():\n",
    "    i = 0\n",
    "    while i < 100:\n",
    "        for batch in dataloader:\n",
    "            if i == 100:\n",
    "                break\n",
    "            print('%03d' % i)\n",
    "            run(modelp, batch)\n",
    "            i += 1\n",
    "modelp.eval()\n",
    "\n",
    "loss = 0\n",
    "with torch.no_grad():\n",
    "    for batch in dataloader:\n",
    "        loss += run(modelp, batch, loss=True)\n",
    "print(loss / 1024)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "daaca80f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating ...\n",
      "75.78\n"
     ]
    }
   ],
   "source": [
    "test(modelp, testloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "0b5404a2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "conv1.weight 0.38732993197278914\n",
      "layer1.0.conv1.weight 0.98388671875 ff\n",
      "layer1.0.conv2.weight 0.4303927951388889 ff\n",
      "layer1.0.conv3.weight 0.2540283203125 ff\n",
      "layer1.0.downsample.0.weight 0.4781494140625 ff\n",
      "layer1.0.downsample.1.weight 1.0\n",
      "layer1.1.conv1.weight 0.3485107421875 ff\n",
      "layer1.1.conv2.weight 0.4782172309027778 ff\n",
      "layer1.1.conv3.weight 0.59033203125 ff\n",
      "layer1.2.conv1.weight 0.12139892578125 ff\n",
      "layer1.2.conv2.weight 0.2823621961805556 ff\n",
      "layer1.2.conv3.weight 0.65594482421875 ff\n",
      "layer2.0.conv1.weight 0.47821044921875 ff\n",
      "layer2.0.conv2.weight 0.5904744466145834 ff\n",
      "layer2.0.conv3.weight 0.65606689453125 ff\n",
      "layer2.0.downsample.0.weight 0.20587158203125 ff\n",
      "layer2.0.downsample.1.weight 1.0\n",
      "layer2.1.conv1.weight 0.4782562255859375 ff\n",
      "layer2.1.conv2.weight 0.20587158203125 ff\n",
      "layer2.1.conv3.weight 0.109375 ff\n",
      "layer2.2.conv1.weight 0.2823944091796875 ff\n",
      "layer2.2.conv2.weight 0.4304470486111111 ff\n",
      "layer2.2.conv3.weight 0.3486328125 ff\n",
      "layer2.3.conv1.weight 0.071746826171875 ff\n",
      "layer2.3.conv2.weight 0.6560872395833334 ff\n",
      "layer2.3.conv3.weight 0.65606689453125 ff\n",
      "layer3.0.conv1.weight 0.6560821533203125 ff\n",
      "layer3.0.conv2.weight 0.2824249267578125 ff\n",
      "layer3.0.conv3.weight 0.7289886474609375 ff\n",
      "layer3.0.downsample.0.weight 0.7289943695068359 ff\n",
      "layer3.0.downsample.1.weight 1.0\n",
      "layer3.1.conv1.weight 0.3138008117675781 ff\n",
      "layer3.1.conv2.weight 0.3486735026041667 ff\n",
      "layer3.1.conv3.weight 0.7289886474609375 ff\n",
      "layer3.2.conv1.weight 0.5904808044433594 ff\n",
      "layer3.2.conv2.weight 0.18529764811197916 ff\n",
      "layer3.2.conv3.weight 0.53143310546875 ff\n",
      "layer3.3.conv1.weight 0.6560897827148438 ff\n",
      "layer3.3.conv2.weight 0.5314364963107638 ff\n",
      "layer3.3.conv3.weight 0.4782867431640625 ff\n",
      "layer3.4.conv1.weight 0.53143310546875 ff\n",
      "layer3.4.conv2.weight 0.2287631564670139 ff\n",
      "layer3.4.conv3.weight 0.8099899291992188 ff\n",
      "layer3.5.conv1.weight 0.43045806884765625 ff\n",
      "layer3.5.conv2.weight 0.4782918294270833 ff\n",
      "layer3.5.conv3.weight 0.5904808044433594 ff\n",
      "layer4.0.conv1.weight 0.590484619140625 ff\n",
      "layer4.0.conv2.weight 0.5314398871527778 ff\n",
      "layer4.0.conv3.weight 0.80999755859375 ff\n",
      "layer4.0.downsample.0.weight 0.7289986610412598 ff\n",
      "layer4.0.downsample.1.weight 1.0\n",
      "layer4.1.conv1.weight 0.656097412109375 ff\n",
      "layer4.1.conv2.weight 0.8099992540147569 ff\n",
      "layer4.1.conv3.weight 0.9999980926513672 ff\n",
      "layer4.2.conv1.weight 0.7289972305297852 ff\n",
      "layer4.2.conv2.weight 0.6560991075303819 ff\n",
      "layer4.2.conv3.weight 0.80999755859375 ff\n",
      "fc.weight 1.0\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(16740648, 25506752)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "total_nz = 0\n",
    "total = 0\n",
    "\n",
    "for n, p in modelp.named_parameters():\n",
    "    if \"weight\" not in n or \"bn\" in n:\n",
    "        continue\n",
    "    \n",
    "    if hasattr(p, \"facts\"):\n",
    "        ff = (p.facts[0] != 0).sum().item() + (p.facts[1] != 0).sum().item() #(p != 0).sum().item()\n",
    "        total_nz += ff\n",
    "        print(n, ff / p.numel(), \"ff\")\n",
    "    else:\n",
    "        total_nz += (p != 0).sum().item()\n",
    "        print(n, (p != 0).sum().item() / p.numel())\n",
    "    total += p.numel()\n",
    "    \n",
    "total_nz, total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "66a47ed7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "modelp.layer3[4].conv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "314eccbf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([256, 256]), torch.Size([256, 2304]))"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "modelp.layer3[4].conv2.weight.facts[0].shape, modelp.layer3[4].conv2.weight.facts[1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "ed541870",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(device(type='cuda', index=0), device(type='cuda', index=0))"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f1 = nn.Conv2d(256, 256, 1, bias=False)\n",
    "f2 = nn.Conv2d(256, 256, 3, padding=1, bias=False)\n",
    "f1.weight.data = modelp.layer3[4].conv2.weight.facts[0].reshape(f1.weight.shape)\n",
    "f2.weight.data = modelp.layer3[4].conv2.weight.facts[1].reshape(f2.weight.shape)\n",
    "f1 = f1.cuda()\n",
    "f2 = f2.cuda()\n",
    "f1.weight.device, f2.weight.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "617b70bc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0039, device='cuda:0', dtype=torch.float16, grad_fn=<MaxBackward1>)\n"
     ]
    }
   ],
   "source": [
    "xx = torch.randn(10,256,15,15).cuda()\n",
    "with torch.amp.autocast(\"cuda\"):\n",
    "    o1 = modelp.layer3[4].conv2(xx)\n",
    "    ot = f2(xx)\n",
    "    o2 = f1(ot)\n",
    "    print((o1 - o2).abs().max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "70a285d2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "layer1.0\n",
      "layer1.1\n",
      "layer1.2\n",
      "layer2.0\n",
      "layer2.1\n",
      "layer2.2\n",
      "layer2.3\n",
      "layer3.0\n",
      "layer3.1\n",
      "layer3.2\n",
      "layer3.3\n",
      "layer3.4\n",
      "layer3.5\n",
      "layer4.0\n",
      "layer4.1\n",
      "layer4.2\n"
     ]
    }
   ],
   "source": [
    "def boo(m, i, o):\n",
    "    print(\"boo\", i[0].shape)\n",
    "\n",
    "for n, m in modelp.named_modules():\n",
    "    if \"Bottleneck\" in str(type(m)):\n",
    "        print(n)\n",
    "        if hasattr(m, \"conv1b\"):\n",
    "            m.conv1 = m.conv1b\n",
    "        ff = m.conv1.weight.facts\n",
    "        m.conv1b = m.conv1\n",
    "        m.conv1 = nn.Sequential(\n",
    "            nn.Conv2d(m.conv1b.in_channels, m.conv1b.out_channels, 1, bias=False),\n",
    "            nn.Conv2d(m.conv1b.out_channels, m.conv1b.out_channels, 1, bias=False)\n",
    "        )\n",
    "        m.conv1[0].weight.data = ff[1].reshape(m.conv1[0].weight.shape)\n",
    "        m.conv1[1].weight.data = ff[0].reshape(m.conv1[1].weight.shape)\n",
    "        m.conv1.cuda()\n",
    "        \n",
    "        #print(n)\n",
    "        if hasattr(m, \"conv2b\"):\n",
    "            m.conv2 = m.conv2b\n",
    "        ff = m.conv2.weight.facts\n",
    "        m.conv2b = m.conv2\n",
    "        m.conv2 = nn.Sequential(\n",
    "            nn.Conv2d(m.conv2b.in_channels, m.conv2b.out_channels, 3, padding=1, stride=m.conv2b.stride, bias=False),\n",
    "            nn.Conv2d(m.conv2b.out_channels, m.conv2b.out_channels, 1, bias=False)\n",
    "        )\n",
    "        #m.conv2[0].register_forward_hook(boo)\n",
    "        m.conv2[0].weight.data = ff[1].reshape(m.conv2[0].weight.shape)\n",
    "        m.conv2[1].weight.data = ff[0].reshape(m.conv2[1].weight.shape)\n",
    "        m.conv2.cuda()\n",
    "        \n",
    "        if hasattr(m, \"conv3b\"):\n",
    "            m.conv3 = m.conv3b\n",
    "        ff = m.conv3.weight.facts\n",
    "        m.conv3b = m.conv3\n",
    "        m.conv3 = nn.Sequential(\n",
    "            nn.Conv2d(m.conv3b.in_channels, m.conv3b.in_channels, 1, bias=False),\n",
    "            nn.Conv2d(m.conv3b.in_channels, m.conv3b.out_channels, 1, bias=False)\n",
    "        )\n",
    "        m.conv3[0].weight.data = ff[1].reshape(m.conv3[0].weight.shape)\n",
    "        m.conv3[1].weight.data = ff[0].reshape(m.conv3[1].weight.shape)\n",
    "        m.conv3.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "177807d9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating ...\n",
      "75.78\n"
     ]
    }
   ],
   "source": [
    "test(modelp, testloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "32c55597",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "  (1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       ")"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "modelp.layer2[0].conv1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "bc956b1e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "  (1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       ")"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "modelp.layer2[0].conv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "08523486",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "  (1): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       ")"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "modelp.layer2[0].conv3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12ffb56e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
