{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hydra import initialize, compose\n",
    "import dotenv\n",
    "import os\n",
    "import pathlib\n",
    "import torch\n",
    "\n",
    "from rigl_torch.utils.checkpoint import Checkpoint\n",
    "from rigl_torch.models import ModelFactory\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mod(run_id: str, device):\n",
    "    with initialize(\"../configs\", version_base=\"1.2.0\"):\n",
    "        cfg = compose(\n",
    "            \"config.yaml\",\n",
    "            overrides=[\n",
    "                \"compute.distributed=False\",\n",
    "                \"dataset=imagenet\",\n",
    "                \"model=vit\",\n",
    "                f\"experiment.run_id={run_id}\",\n",
    "                \"training.batch_size=2\",\n",
    "            ],\n",
    "        )\n",
    "    dotenv.load_dotenv(\"../.env\", override=True)\n",
    "    os.environ[\"IMAGE_NET_PATH\"]\n",
    "    checkpoint_dir = pathlib.Path(f\"../artifacts/checkpoints/20230601_{run_id}\")\n",
    "    checkpoint = Checkpoint.load_best_checkpoint(checkpoint_dir=checkpoint_dir)\n",
    "    model_state = checkpoint.model\n",
    "    model = ModelFactory.load_model(\n",
    "        model=cfg.model.name, dataset=cfg.dataset.name, diet=cfg.rigl.diet\n",
    "    )\n",
    "    model.to(device)\n",
    "    try:\n",
    "        model.load_state_dict(model_state)\n",
    "    except RuntimeError:\n",
    "        model_state = (\n",
    "            checkpoint.get_single_process_model_state_from_distributed_state()\n",
    "        )\n",
    "        model.load_state_dict(model_state)\n",
    "    return model.get_submodule(\"encoder.layers.encoder_layer_11.mlp.3\")\n",
    "\n",
    "\n",
    "__RUN_IDS = {90: \"nrblbn15\"}\n",
    "\n",
    "# t_fc = get_mod(__RUN_IDS[90], \"cpu\") # Run me if you have the artifact on this device\n",
    "\n",
    "with open(\"../artifacts/trained_vit_layers/vit16-mlp-layer-90-torch.pkl\", \"rb\") as handle:  # TODO: try skinnier layer\n",
    "    t_fc = torch.load(handle)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "class CondensedLinearFineGrained(nn.Module):\n",
    "    def __init__(\n",
    "        self, module: nn.Module, dtype: torch.typename = torch.float32\n",
    "    ):\n",
    "        super().__init__()\n",
    "        if dtype is None:\n",
    "            dtype = module.weight.dtype\n",
    "        with torch.no_grad():\n",
    "            active_neuron_idx = module.weight.sum(dim=1) != 0\n",
    "            fine_grained_idx = (module.weight[active_neuron_idx] != 0).to(\n",
    "                torch.bool\n",
    "            )\n",
    "            _, self.input_mask = fine_grained_idx.nonzero(as_tuple=True)\n",
    "            self.input_mask = self.input_mask.reshape(\n",
    "                shape=(module.weight[active_neuron_idx].shape[0], -1)\n",
    "            ).to(torch.uint8)\n",
    "            weight = module.weight[active_neuron_idx].detach().type(dtype)\n",
    "            self.condensed_weight = nn.Parameter(\n",
    "                torch.clone(\n",
    "                    weight[fine_grained_idx]\n",
    "                    .reshape(shape=(weight.shape[0], -1))\n",
    "                    .detach()\n",
    "                    .type(dtype)\n",
    "                ),\n",
    "                requires_grad=False,\n",
    "            )\n",
    "            if hasattr(module, \"bias\"):\n",
    "                self.bias = nn.Parameter(\n",
    "                    torch.clone(\n",
    "                        module.bias[active_neuron_idx].detach().type(dtype)\n",
    "                    ),\n",
    "                    requires_grad=False,\n",
    "                )\n",
    "            else:\n",
    "                self.register_parameter(\"bias\", None)\n",
    "\n",
    "    def forward(self, input: torch.Tensor) -> torch.Tensor:\n",
    "        return (\n",
    "            torch.sum(\n",
    "                self.condensed_weight * input[..., self.input_mask],\n",
    "                dim=input.dim(),\n",
    "            )\n",
    "            + self.bias\n",
    "        )\n",
    "\n",
    "condensed_linear = CondensedLinearFineGrained(t_fc, dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Linear(in_features=768, out_features=3072, bias=True)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t_fc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 768])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.rand(1, 768)\n",
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "77.6 µs ± 1.86 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "t_fc(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mike/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/cuda/__init__.py:107: UserWarning: CUDA initialization: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:109.)\n",
      "  return torch._C._cuda_getDeviceCount() > 0\n"
     ]
    },
    {
     "ename": "Unsupported",
     "evalue": "dynamic shape operator: aten.index.Tensor\n\nfrom user code:\n   File \"/tmp/ipykernel_2951702/2667213753.py\", line 41, in forward\n    self.condensed_weight * input[..., self.input_mask],\n\nSet torch._dynamo.config.verbose=True for more information\n\n\nYou can suppress this exception and fall back to eager by setting:\n    torch._dynamo.config.suppress_errors = True\n",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mDynamicOutputShapeException\u001b[0m               Traceback (most recent call last)",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/utils.py:1194\u001b[0m, in \u001b[0;36mrun_node\u001b[0;34m(output_graph, node, args, kwargs, nnmodule)\u001b[0m\n\u001b[1;32m   1193\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m op \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcall_function\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m-> 1194\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mnode\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtarget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1195\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m op \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcall_method\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/utils/_stats.py:20\u001b[0m, in \u001b[0;36mcount.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     19\u001b[0m simple_call_counter[fn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m] \u001b[38;5;241m=\u001b[39m simple_call_counter[fn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m] \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m---> 20\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:987\u001b[0m, in \u001b[0;36mFakeTensorMode.__torch_dispatch__\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m    986\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 987\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdispatch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    988\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1162\u001b[0m, in \u001b[0;36mFakeTensorMode.dispatch\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m   1161\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m run_impl_check(func):\n\u001b[0;32m-> 1162\u001b[0m     op_impl_out \u001b[38;5;241m=\u001b[39m \u001b[43mop_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1163\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m op_impl_out \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mNotImplemented\u001b[39m:\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:453\u001b[0m, in \u001b[0;36mindex_tensor\u001b[0;34m(fake_mode, func, *args, **kwargs)\u001b[0m\n\u001b[1;32m    450\u001b[0m \u001b[38;5;129m@register_op_impl\u001b[39m(aten\u001b[38;5;241m.\u001b[39mindex\u001b[38;5;241m.\u001b[39mTensor)\n\u001b[1;32m    451\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mindex_tensor\u001b[39m(fake_mode, func, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    452\u001b[0m     \u001b[38;5;66;03m# dynamic shape op if indices are bool/uint8\u001b[39;00m\n\u001b[0;32m--> 453\u001b[0m     \u001b[43mcheck_no_bool_index_tensors\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    455\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:432\u001b[0m, in \u001b[0;36mcheck_no_bool_index_tensors\u001b[0;34m(func, self, indices)\u001b[0m\n\u001b[1;32m    431\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m index \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m index\u001b[38;5;241m.\u001b[39mdtype \u001b[38;5;129;01min\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39mbool, torch\u001b[38;5;241m.\u001b[39muint8):\n\u001b[0;32m--> 432\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m DynamicOutputShapeException(func)\n",
      "\u001b[0;31mDynamicOutputShapeException\u001b[0m: aten.index.Tensor",
      "\nThe above exception was the direct cause of the following exception:\n",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/utils.py:1152\u001b[0m, in \u001b[0;36mget_fake_value\u001b[0;34m(node, tx)\u001b[0m\n\u001b[1;32m   1151\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m tx\u001b[38;5;241m.\u001b[39mfake_mode, enable_python_dispatcher():\n\u001b[0;32m-> 1152\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mwrap_fake_exception\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1153\u001b[0m \u001b[43m            \u001b[49m\u001b[38;5;28;43;01mlambda\u001b[39;49;00m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mrun_node\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moutput\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnnmodule\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1154\u001b[0m \u001b[43m        \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1155\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m Unsupported:\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/utils.py:808\u001b[0m, in \u001b[0;36mwrap_fake_exception\u001b[0;34m(fn)\u001b[0m\n\u001b[1;32m    807\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 808\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    809\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m UnsupportedFakeTensorException \u001b[38;5;28;01mas\u001b[39;00m e:\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/utils.py:1153\u001b[0m, in \u001b[0;36mget_fake_value.<locals>.<lambda>\u001b[0;34m()\u001b[0m\n\u001b[1;32m   1151\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m tx\u001b[38;5;241m.\u001b[39mfake_mode, enable_python_dispatcher():\n\u001b[1;32m   1152\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m wrap_fake_exception(\n\u001b[0;32m-> 1153\u001b[0m             \u001b[38;5;28;01mlambda\u001b[39;00m: \u001b[43mrun_node\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moutput\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnnmodule\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1154\u001b[0m         )\n\u001b[1;32m   1155\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m Unsupported:\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/utils.py:1206\u001b[0m, in \u001b[0;36mrun_node\u001b[0;34m(output_graph, node, args, kwargs, nnmodule)\u001b[0m\n\u001b[1;32m   1205\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m-> 1206\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m   1207\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFailed running \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mop\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnode\u001b[38;5;241m.\u001b[39mtarget\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m(*\u001b[39m\u001b[38;5;132;01m{\u001b[39;00margs\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, **\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mkwargs\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m):\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m(scroll up for backtrace)\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   1208\u001b[0m     ) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n\u001b[1;32m   1209\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAssertionError\u001b[39;00m(op)\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Failed running call_function <built-in function getitem>(*(FakeTensor(FakeTensor(..., device='meta', size=(1, 768)), cpu), (Ellipsis, FakeTensor(FakeTensor(..., device='meta', size=(1145, 206), dtype=torch.uint8), cpu))), **{}):\naten.index.Tensor\n(scroll up for backtrace)",
      "\nDuring handling of the above exception, another exception occurred:\n",
      "\u001b[0;31mUnsupported\u001b[0m                               Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[8], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m comp_mod \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcompile(condensed_linear, mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmax-autotune\u001b[39m\u001b[38;5;124m\"\u001b[39m, fullgraph\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, backend\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minductor\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m \u001b[43mcomp_mod\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1499\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1500\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:82\u001b[0m, in \u001b[0;36mOptimizedModule.forward\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m     81\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m---> 82\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdynamo_ctx\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_orig_mod\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:209\u001b[0m, in \u001b[0;36m_TorchDynamoContext.__call__.<locals>._fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    207\u001b[0m dynamic_ctx\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__enter__\u001b[39m()\n\u001b[1;32m    208\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 209\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    210\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m    211\u001b[0m     set_eval_frame(prior)\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:337\u001b[0m, in \u001b[0;36mcatch_errors_wrapper.<locals>.catch_errors\u001b[0;34m(frame, cache_size)\u001b[0m\n\u001b[1;32m    334\u001b[0m             \u001b[38;5;28;01mreturn\u001b[39;00m hijacked_callback(frame, cache_size, hooks)\n\u001b[1;32m    336\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m compile_lock:\n\u001b[0;32m--> 337\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcallback\u001b[49m\u001b[43m(\u001b[49m\u001b[43mframe\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcache_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhooks\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:104\u001b[0m, in \u001b[0;36mwrap_convert_context.<locals>._fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    102\u001b[0m torch\u001b[38;5;241m.\u001b[39mfx\u001b[38;5;241m.\u001b[39mgraph_module\u001b[38;5;241m.\u001b[39m_forward_from_src \u001b[38;5;241m=\u001b[39m fx_forward_from_src_skip_result\n\u001b[1;32m    103\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 104\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    105\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m    106\u001b[0m     torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_set_grad_enabled(prior_grad_mode)\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:262\u001b[0m, in \u001b[0;36mconvert_frame_assert.<locals>._convert_frame_assert\u001b[0;34m(frame, cache_size, hooks)\u001b[0m\n\u001b[1;32m    259\u001b[0m \u001b[38;5;28;01mglobal\u001b[39;00m initial_grad_state\n\u001b[1;32m    260\u001b[0m initial_grad_state \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mis_grad_enabled()\n\u001b[0;32m--> 262\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_compile\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    263\u001b[0m \u001b[43m    \u001b[49m\u001b[43mframe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mf_code\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    264\u001b[0m \u001b[43m    \u001b[49m\u001b[43mframe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mf_globals\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    265\u001b[0m \u001b[43m    \u001b[49m\u001b[43mframe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mf_locals\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    266\u001b[0m \u001b[43m    \u001b[49m\u001b[43mframe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mf_builtins\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    267\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcompiler_fn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    268\u001b[0m \u001b[43m    \u001b[49m\u001b[43mone_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    269\u001b[0m \u001b[43m    \u001b[49m\u001b[43mexport\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    270\u001b[0m \u001b[43m    \u001b[49m\u001b[43mhooks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    271\u001b[0m \u001b[43m    \u001b[49m\u001b[43mframe\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    272\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/utils.py:163\u001b[0m, in \u001b[0;36mdynamo_timed.<locals>.dynamo_timed_inner.<locals>.time_wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    161\u001b[0m     compilation_metrics[key] \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m    162\u001b[0m t0 \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[0;32m--> 163\u001b[0m r \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    164\u001b[0m time_spent \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime() \u001b[38;5;241m-\u001b[39m t0\n\u001b[1;32m    165\u001b[0m \u001b[38;5;66;03m# print(f\"Dynamo timer: key={key}, latency={latency:.2f} sec\")\u001b[39;00m\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:324\u001b[0m, in \u001b[0;36m_compile\u001b[0;34m(code, globals, locals, builtins, compiler_fn, one_graph, export, hooks, frame)\u001b[0m\n\u001b[1;32m    322\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m attempt \u001b[38;5;129;01min\u001b[39;00m itertools\u001b[38;5;241m.\u001b[39mcount():\n\u001b[1;32m    323\u001b[0m     \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 324\u001b[0m         out_code \u001b[38;5;241m=\u001b[39m \u001b[43mtransform_code_object\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtransform\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    325\u001b[0m         orig_code_map[out_code] \u001b[38;5;241m=\u001b[39m code\n\u001b[1;32m    326\u001b[0m         \u001b[38;5;28;01mbreak\u001b[39;00m\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py:445\u001b[0m, in \u001b[0;36mtransform_code_object\u001b[0;34m(code, transformations, safe)\u001b[0m\n\u001b[1;32m    442\u001b[0m instructions \u001b[38;5;241m=\u001b[39m cleaned_instructions(code, safe)\n\u001b[1;32m    443\u001b[0m propagate_line_nums(instructions)\n\u001b[0;32m--> 445\u001b[0m \u001b[43mtransformations\u001b[49m\u001b[43m(\u001b[49m\u001b[43minstructions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcode_options\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    446\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m clean_and_assemble_instructions(instructions, keys, code_options)[\u001b[38;5;241m1\u001b[39m]\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:311\u001b[0m, in \u001b[0;36m_compile.<locals>.transform\u001b[0;34m(instructions, code_options)\u001b[0m\n\u001b[1;32m    298\u001b[0m \u001b[38;5;28;01mnonlocal\u001b[39;00m output\n\u001b[1;32m    299\u001b[0m tracer \u001b[38;5;241m=\u001b[39m InstructionTranslator(\n\u001b[1;32m    300\u001b[0m     instructions,\n\u001b[1;32m    301\u001b[0m     code,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    309\u001b[0m     mutated_closure_cell_contents,\n\u001b[1;32m    310\u001b[0m )\n\u001b[0;32m--> 311\u001b[0m \u001b[43mtracer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    312\u001b[0m output \u001b[38;5;241m=\u001b[39m tracer\u001b[38;5;241m.\u001b[39moutput\n\u001b[1;32m    313\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m output \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1726\u001b[0m, in \u001b[0;36mInstructionTranslator.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1724\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrun\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m   1725\u001b[0m     _step_logger()(logging\u001b[38;5;241m.\u001b[39mINFO, \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtorchdynamo start tracing \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mf_code\u001b[38;5;241m.\u001b[39mco_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m-> 1726\u001b[0m     \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:576\u001b[0m, in \u001b[0;36mInstructionTranslatorBase.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    571\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    572\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput\u001b[38;5;241m.\u001b[39mpush_tx(\u001b[38;5;28mself\u001b[39m)\n\u001b[1;32m    573\u001b[0m     \u001b[38;5;28;01mwhile\u001b[39;00m (\n\u001b[1;32m    574\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minstruction_pointer \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m    575\u001b[0m         \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput\u001b[38;5;241m.\u001b[39mshould_exit\n\u001b[0;32m--> 576\u001b[0m         \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    577\u001b[0m     ):\n\u001b[1;32m    578\u001b[0m         \u001b[38;5;28;01mpass\u001b[39;00m\n\u001b[1;32m    579\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m BackendCompilerFailed:\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:540\u001b[0m, in \u001b[0;36mInstructionTranslatorBase.step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    538\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(\u001b[38;5;28mself\u001b[39m, inst\u001b[38;5;241m.\u001b[39mopname):\n\u001b[1;32m    539\u001b[0m         unimplemented(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmissing: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00minst\u001b[38;5;241m.\u001b[39mopname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 540\u001b[0m     \u001b[38;5;28;43mgetattr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minst\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mopname\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[43minst\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    542\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m inst\u001b[38;5;241m.\u001b[39mopname \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRETURN_VALUE\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    543\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m BackendCompilerFailed:\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:342\u001b[0m, in \u001b[0;36mbreak_graph_if_unsupported.<locals>.decorator.<locals>.wrapper\u001b[0;34m(self, inst)\u001b[0m\n\u001b[1;32m    340\u001b[0m reason \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m    341\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 342\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minst\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    343\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m Unsupported \u001b[38;5;28;01mas\u001b[39;00m excp:\n\u001b[1;32m    344\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhas_backedge() \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mshould_compile_partial_graph():\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:148\u001b[0m, in \u001b[0;36mstack_op.<locals>.impl\u001b[0;34m(self, inst)\u001b[0m\n\u001b[1;32m    146\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(fn)\n\u001b[1;32m    147\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mimpl\u001b[39m(\u001b[38;5;28mself\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInstructionTranslatorBase\u001b[39m\u001b[38;5;124m\"\u001b[39m, inst: Instruction):\n\u001b[0;32m--> 148\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpush(\u001b[43mfn_var\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcall_function\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpopn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnargs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m{\u001b[49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m)\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py:518\u001b[0m, in \u001b[0;36mBuiltinVariable.call_function\u001b[0;34m(self, tx, args, kwargs)\u001b[0m\n\u001b[1;32m    514\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfn \u001b[38;5;129;01mis\u001b[39;00m operator\u001b[38;5;241m.\u001b[39mtruediv \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\n\u001b[1;32m    515\u001b[0m             args[\u001b[38;5;241m0\u001b[39m], variables\u001b[38;5;241m.\u001b[39mUnspecializedPythonVariable\n\u001b[1;32m    516\u001b[0m         ):\n\u001b[1;32m    517\u001b[0m             args[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m=\u001b[39m args[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mconvert_to_constant(tx)\n\u001b[0;32m--> 518\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mwrap_fx_proxy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mproxy\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    520\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m:\n\u001b[1;32m    521\u001b[0m     unimplemented(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpartial tensor op: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;132;01m{\u001b[39;00margs\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mkwargs\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:754\u001b[0m, in \u001b[0;36mwrap_fx_proxy\u001b[0;34m(tx, proxy, example_value, **options)\u001b[0m\n\u001b[1;32m    753\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrap_fx_proxy\u001b[39m(tx, proxy, example_value\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39moptions):\n\u001b[0;32m--> 754\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mwrap_fx_proxy_cls\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    755\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtarget_cls\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mTensorVariable\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    756\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    757\u001b[0m \u001b[43m        \u001b[49m\u001b[43mproxy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxy\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    758\u001b[0m \u001b[43m        \u001b[49m\u001b[43mexample_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mexample_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    759\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    760\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:789\u001b[0m, in \u001b[0;36mwrap_fx_proxy_cls\u001b[0;34m(target_cls, tx, proxy, example_value, ignore_subclass, **options)\u001b[0m\n\u001b[1;32m    787\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m preserve_rng_state():\n\u001b[1;32m    788\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m example_value \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 789\u001b[0m         example_value \u001b[38;5;241m=\u001b[39m \u001b[43mget_fake_value\u001b[49m\u001b[43m(\u001b[49m\u001b[43mproxy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    791\u001b[0m     \u001b[38;5;66;03m# Handle recursive calls here\u001b[39;00m\n\u001b[1;32m    792\u001b[0m     \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(example_value, FakeTensor):\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/utils.py:1168\u001b[0m, in \u001b[0;36mget_fake_value\u001b[0;34m(node, tx)\u001b[0m\n\u001b[1;32m   1164\u001b[0m     unimplemented(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdata dependent operator: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcause\u001b[38;5;241m.\u001b[39mfunc\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m   1165\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\n\u001b[1;32m   1166\u001b[0m     cause, torch\u001b[38;5;241m.\u001b[39m_subclasses\u001b[38;5;241m.\u001b[39mfake_tensor\u001b[38;5;241m.\u001b[39mDynamicOutputShapeException\n\u001b[1;32m   1167\u001b[0m ):\n\u001b[0;32m-> 1168\u001b[0m     \u001b[43munimplemented\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mdynamic shape operator: \u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mcause\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfunc\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1169\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\n\u001b[1;32m   1170\u001b[0m     cause, torch\u001b[38;5;241m.\u001b[39mfx\u001b[38;5;241m.\u001b[39mexperimental\u001b[38;5;241m.\u001b[39msymbolic_shapes\u001b[38;5;241m.\u001b[39mGuardOnDataDependentSymNode\n\u001b[1;32m   1171\u001b[0m ):\n\u001b[1;32m   1172\u001b[0m     unimplemented(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mguard on data-dependent symbolic int/float\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
      "File \u001b[0;32m~/condensed-sparsity/.venv/lib/python3.10/site-packages/torch/_dynamo/exc.py:71\u001b[0m, in \u001b[0;36munimplemented\u001b[0;34m(msg)\u001b[0m\n\u001b[1;32m     69\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21munimplemented\u001b[39m(msg: \u001b[38;5;28mstr\u001b[39m):\n\u001b[1;32m     70\u001b[0m     \u001b[38;5;28;01massert\u001b[39;00m msg \u001b[38;5;241m!=\u001b[39m os\u001b[38;5;241m.\u001b[39menviron\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBREAK\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m---> 71\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m Unsupported(msg)\n",
      "\u001b[0;31mUnsupported\u001b[0m: dynamic shape operator: aten.index.Tensor\n\nfrom user code:\n   File \"/tmp/ipykernel_2951702/2667213753.py\", line 41, in forward\n    self.condensed_weight * input[..., self.input_mask],\n\nSet torch._dynamo.config.verbose=True for more information\n\n\nYou can suppress this exception and fall back to eager by setting:\n    torch._dynamo.config.suppress_errors = True\n"
     ]
    }
   ],
   "source": [
    "comp_mod = torch.compile(condensed_linear, mode=\"max-autotune\", fullgraph=True, backend=\"inductor\")\n",
    "comp_mod(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "98.7 µs ± 649 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "# NOTE(review): comp_mod is defined by the torch.compile cell above, whose saved output\n",
    "# is `Unsupported: dynamic shape operator`; this timing presumably comes from an earlier,\n",
    "# differently configured run and will fail with NameError on a fresh kernel.\n",
    "comp_mod(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
      "I0000 00:00:1696012059.441441 3998717 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.\n",
      "2023-09-29 12:27:39.452337: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:276] failed call to cuInit: CUDA_ERROR_COMPAT_NOT_SUPPORTED_ON_DEVICE: forward compatibility was attempted on non supported HW\n",
      "2023-09-29 12:27:39.452809: E external/xla/xla/stream_executor/cuda/cuda_diagnostics.cc:312] kernel version 535.86.10 does not match DSO version 535.104.5 -- cannot find working devices in this configuration\n",
      "No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "import jax\n",
    "from typing import Any, Callable, Sequence, Optional, Tuple, Union\n",
    "from jax import random, vmap, numpy as jnp\n",
    "import flax\n",
    "from flax import linen as nn\n",
    "import numpy as np\n",
    "from functools import partial\n",
    "\n",
    "# _dtype = jnp.bfloat16 # faster on gpu\n",
    "_dtype = jnp.float32 # faster on cpu @ batch size 1. slower at 64\n",
    "# t_fc = t_fc.to(torch.bfloat16) # try bf16, Time to beat (176micro for dense, 137 micro for fastest condensed)\n",
    "# Convert the torch Linear layer `t_fc` into Flax `nn.Dense` variables and check that both\n",
    "# layers agree on a random batch.\n",
    "# NOTE(review): `t_fc` is presumably produced by the commented-out `get_mod(...)` call near\n",
    "# the top of the notebook; on a fresh kernel without that artifact this cell fails.\n",
    "with torch.no_grad():\n",
    "    kernel = t_fc.weight.detach().cpu().numpy()\n",
    "    bias = t_fc.bias.detach().cpu().numpy()\n",
    "\n",
    "    # [outC, inC] -> [inC, outC]\n",
    "    kernel = jnp.transpose(kernel, (1, 0)).astype(_dtype)\n",
    "\n",
    "    # Random check batch of 64 rows; created with JAX's default float dtype, not `_dtype`.\n",
    "    key = random.key(0)\n",
    "    x = random.normal(key, (64, t_fc.in_features))\n",
    "\n",
    "    # Flax Dense variables: kernel is [in, out] (transposed above), bias is [out].\n",
    "    variables = {'params': {'kernel': kernel, 'bias': bias.astype(_dtype)}}\n",
    "    j_fc = nn.Dense(features=t_fc.out_features)\n",
    "    j_out = j_fc.apply(variables, x)\n",
    "\n",
    "    # Run the same batch through the torch layer for comparison.\n",
    "    t_x = torch.from_numpy(np.array(x))\n",
    "    t_out = t_fc(t_x)\n",
    "    t_out = t_out.detach().cpu().numpy()\n",
    "\n",
    "    # Loose tolerance (2 decimals), presumably so the bf16 experiment above also passes.\n",
    "    np.testing.assert_almost_equal(j_out, t_out, decimal=2)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Benchmark configuration: layer dims come from the torch layer; batch size is fixed here.\n",
    "batch_size = 16\n",
    "input_size, layer_width = t_fc.in_features, t_fc.out_features\n",
    "\n",
    "# Fresh seeded benchmark batch in `_dtype` (replaces the 64-row conversion-check `x`).\n",
    "key, subkey = random.split(random.PRNGKey(42))\n",
    "x = jax.device_put(random.normal(subkey, (batch_size, input_size), dtype=_dtype))\n",
    "\n",
    "# Dense baseline that reuses the variables converted from the torch layer.\n",
    "dense_layer = nn.Dense(features=layer_width, use_bias=True)\n",
    "dense_params = variables\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9.87 ms ± 1.02 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
     ]
    }
   ],
   "source": [
    "# Baseline: eager (un-jitted) Flax Dense forward on the benchmark batch `x`.\n",
    "%timeit dense_layer.apply(dense_params, x).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# jit-compiled Dense forward; `dense_params` is read from the notebook globals by the lambda.\n",
    "dense_fast = jax.jit(lambda x: dense_layer.apply(dense_params, x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([[ 2.2639807e-12,  7.5584012e-01,  1.9262955e-12, ...,\n",
       "        -1.2273963e-12,  8.1965402e-02, -8.9613508e-12],\n",
       "       [ 2.2639807e-12, -2.0178687e-02,  1.9262955e-12, ...,\n",
       "        -1.2273963e-12,  1.0754541e+00, -8.9613508e-12],\n",
       "       [ 2.2639807e-12,  7.9683006e-01,  1.9262955e-12, ...,\n",
       "        -1.2273963e-12, -4.9104637e-01, -8.9613508e-12],\n",
       "       ...,\n",
       "       [ 2.2639807e-12,  3.0003309e-01,  1.9262955e-12, ...,\n",
       "        -1.2273963e-12, -1.6757858e-01, -8.9613508e-12],\n",
       "       [ 2.2639807e-12, -1.5575090e-02,  1.9262955e-12, ...,\n",
       "        -1.2273963e-12, -5.2756774e-01, -8.9613508e-12],\n",
       "       [ 2.2639807e-12, -2.2859134e-01,  1.9262955e-12, ...,\n",
       "        -1.2273963e-12,  1.3307133e-01, -8.9613508e-12]], dtype=float32)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Warm-up call: triggers jit compilation so the %timeit below measures the compiled function.\n",
    "dense_fast(x).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "323 µs ± 7.51 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "# jitted Dense forward (compiled by the warm-up call above).\n",
    "%timeit dense_fast(x).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(3072,)\n",
      "(768, 3072)\n",
      "(1145,)\n",
      "(1145, 206)\n",
      "(1145, 206)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'params': {'bias': None, 'indx_seqs': None, 'kernel': None}}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Flax condensed sparsity\n",
    "\n",
    "from numpy.typing import DTypeLike\n",
    "from jax.typing import ArrayLike\n",
    "from flax.core.scope import VariableDict\n",
    "from copy import deepcopy\n",
    "\n",
    "def condensed_param_converter(dense_params: VariableDict, dtype: Optional[DTypeLike]=None) -> VariableDict:\n",
    "    \"\"\"Convert sparse `nn.Dense` parameters into `CondensedLinear` parameters.\n",
    "\n",
    "    All-zero neurons (output features) are dropped. For every remaining neuron only its\n",
    "    non-zero weights are kept, together with the input indices they apply to.\n",
    "\n",
    "    Args:\n",
    "        dense_params: Flax variables `{\"params\": {\"kernel\": [in, out], \"bias\": [out]}}`.\n",
    "            Not modified (a deep copy is taken).\n",
    "        dtype: dtype of the returned kernel and bias; defaults to the dense kernel's dtype.\n",
    "\n",
    "    Returns:\n",
    "        `{\"params\": {\"kernel\": [F, K], \"bias\": [F], \"indx_seqs\": [F, K]}}`, where F is the\n",
    "        number of active neurons and K the number of non-zero weights per neuron.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the active neurons do not all have the same number of non-zero\n",
    "            weights (the condensed format requires a constant fan-in).\n",
    "    \"\"\"\n",
    "    dense_params = deepcopy(dense_params)\n",
    "    # Flax stores the kernel as [in, out]; transpose so each row is one output neuron.\n",
    "    # (Without the transpose the original condensed implementation hit broadcasting issues.)\n",
    "    kernel = dense_params[\"params\"][\"kernel\"].T\n",
    "    bias = dense_params[\"params\"][\"bias\"]\n",
    "    if dtype is None:\n",
    "        dtype = kernel.dtype\n",
    "    active_neuron_idx = _get_active_neuron_idx(kernel)\n",
    "    fine_grained_idx = _get_fine_grained_idx(kernel, active_neuron_idx)\n",
    "    # Validate before reshaping: unequal fan-ins could otherwise be silently mis-reshaped.\n",
    "    fan_ins = fine_grained_idx.sum(axis=1)\n",
    "    if fan_ins.size and not bool((fan_ins == fan_ins[0]).all()):\n",
    "        raise ValueError(\n",
    "            \"condensed format requires the same number of non-zero weights per active \"\n",
    "            f\"neuron, got between {int(fan_ins.min())} and {int(fan_ins.max())}\"\n",
    "        )\n",
    "    struct_kernel = kernel[active_neuron_idx]\n",
    "    n_active = struct_kernel.shape[0]\n",
    "    # Boolean-mask indexing flattens in row-major order, so with a constant fan-in the\n",
    "    # reshape restores one row of K weights per active neuron.\n",
    "    condensed_kernel = struct_kernel[fine_grained_idx].reshape(n_active, -1)\n",
    "    # Column indices of the non-zeros, also in row-major order, so they line up with\n",
    "    # condensed_kernel; vectorized replacement for a per-neuron argwhere loop.\n",
    "    indx_seqs = jnp.nonzero(fine_grained_idx)[1].reshape(n_active, -1)\n",
    "    return dict(\n",
    "        params=dict(\n",
    "            kernel=condensed_kernel.astype(dtype),\n",
    "            bias=bias[active_neuron_idx].astype(dtype),\n",
    "            indx_seqs=indx_seqs,\n",
    "        )\n",
    "    )\n",
    "\n",
    "def _get_active_neuron_idx(kernel: ArrayLike) -> jax.Array:\n",
    "    \"\"\"Boolean mask over dim 0: True for rows (neurons) with at least one non-zero weight.\"\"\"\n",
    "    # Test for any non-zero instead of `sum != 0`: a row whose weights cancel out\n",
    "    # (e.g. [1., -1.]) is still an active neuron.\n",
    "    return (kernel != 0).any(axis=tuple(range(1, kernel.ndim)))\n",
    "\n",
    "\n",
    "def _get_fine_grained_idx(\n",
    "    kernel: ArrayLike, active_neuron_idx: ArrayLike\n",
    ") -> jax.Array:\n",
    "    \"\"\"Boolean mask of the non-zero weights of the active rows of `kernel`.\"\"\"\n",
    "    return (kernel[active_neuron_idx] != 0).astype(\"bool\")\n",
    "\n",
    "class CondensedLinear(nn.Module):\n",
    "    \"\"\"Linear layer over condensed sparse weights with a constant fan-in.\n",
    "\n",
    "    Output neuron f computes sum_k kernel[f, k] * input[..., indx_seqs[f, k]] + bias[f].\n",
    "    Meant to be applied with parameters produced by `condensed_param_converter`.\n",
    "\n",
    "    Attributes:\n",
    "        features: number of output neurons F.\n",
    "        fan_in: number of non-zero weights K per neuron.\n",
    "        kernel_init: initializer for the (F, K) condensed kernel.\n",
    "        bias_init: initializer for the (F,) bias.\n",
    "    \"\"\"\n",
    "    features: int\n",
    "    fan_in: int\n",
    "    kernel_init: Callable = nn.initializers.lecun_normal()\n",
    "    bias_init: Callable = nn.initializers.zeros_init()\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, input: ArrayLike) -> jax.Array:\n",
    "        \"\"\"Apply the layer to `input` of shape (..., in_features); returns (..., features).\"\"\"\n",
    "        kernel = self.param(\"kernel\", self.kernel_init, (self.features, self.fan_in))\n",
    "        bias = self.param(\"bias\", self.bias_init, (self.features,))\n",
    "        # Indices must be integers in [0, in_features). A float initializer such as\n",
    "        # kernel_init makes the gather below fail at `init` time, so random integer\n",
    "        # indices are used as a placeholder; real ones come from condensed_param_converter.\n",
    "        # NOTE(review): integer entries in the \"params\" collection are not differentiable;\n",
    "        # exclude them from gradient updates when training.\n",
    "        in_features = jnp.shape(input)[-1]\n",
    "        indx_seqs = self.param(\n",
    "            \"indx_seqs\",\n",
    "            lambda rng, shape: jax.random.randint(rng, shape, 0, in_features),\n",
    "            (self.features, self.fan_in),\n",
    "        )\n",
    "        # input[..., indx_seqs] gathers along the last axis -> (..., features, fan_in).\n",
    "        return jnp.sum(kernel * input[..., indx_seqs], axis=-1) + bias\n",
    "\n",
    "\n",
    "condensed_params = condensed_param_converter(variables)\n",
    "jax.tree_util.tree_map(lambda x: print(x.shape), variables)\n",
    "jax.tree_util.tree_map(lambda x: print(x.shape), condensed_params);\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1145, 206)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "condensed_params[\"params\"][\"kernel\"].shape  # (features, fan_in) -> positional args of CondensedLinear"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([[ 0.75584006, -0.17063873,  1.6859026 , ...,  0.18504576,\n",
       "        -0.00192568,  0.08196551],\n",
       "       [-0.02017869, -0.71082234,  0.16039723, ...,  0.41155237,\n",
       "        -0.4476214 ,  1.0754542 ],\n",
       "       [ 0.7968303 , -0.10603629,  1.2908362 , ..., -0.32634115,\n",
       "        -0.48441118, -0.49104658],\n",
       "       ...,\n",
       "       [ 0.30003327, -0.17096956,  0.10126591, ...,  0.6488617 ,\n",
       "        -0.1293261 , -0.16757864],\n",
       "       [-0.01557497, -0.39440617,  0.21288827, ...,  0.32918108,\n",
       "        -0.12092257, -0.527568  ],\n",
       "       [-0.22859119, -0.54928684,  0.2176865 , ...,  0.11042877,\n",
       "         0.5686784 ,  0.13307133]], dtype=float32)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the layer from the condensed kernel shape (features, fan_in) and jit its apply.\n",
    "cl = CondensedLinear(*condensed_params[\"params\"][\"kernel\"].shape)\n",
    "cl_fast = jax.jit(lambda x: cl.apply(condensed_params, x))\n",
    "# Warm-up call (triggers jit compilation) before timing.\n",
    "cl_fast(x).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.8 ms ± 377 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "# jitted CondensedLinear forward (compiled by the warm-up call in the previous cell).\n",
    "cl_fast(x).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack the converted condensed parameters for the standalone forward-pass experiments below.\n",
    "weights = jax.device_put(condensed_params['params']['kernel'])\n",
    "bias = jax.device_put(condensed_params['params']['bias'])\n",
    "indx_seqs = jax.device_put(condensed_params['params']['indx_seqs'])\n",
    "# NOTE: `input` shadows the Python builtin; kept because the cells below use this name.\n",
    "input = x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(16, 768)\n",
      "(1145, 206)\n",
      "(1145, 206)\n"
     ]
    }
   ],
   "source": [
    "# Sanity-check shapes: input (batch, in_features); weights and indx_seqs (F, K).\n",
    "for arr in (input, weights, indx_seqs):\n",
    "    print(arr.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([[ 0.75584006, -0.17063873,  1.6859026 , ...,  0.18504576,\n",
       "        -0.00192568,  0.08196551],\n",
       "       [-0.02017869, -0.71082234,  0.16039723, ...,  0.41155237,\n",
       "        -0.4476214 ,  1.0754542 ],\n",
       "       [ 0.7968303 , -0.10603629,  1.2908362 , ..., -0.32634115,\n",
       "        -0.48441118, -0.49104658],\n",
       "       ...,\n",
       "       [ 0.30003327, -0.17096956,  0.10126591, ...,  0.6488617 ,\n",
       "        -0.1293261 , -0.16757864],\n",
       "       [-0.01557497, -0.39440617,  0.21288827, ...,  0.32918108,\n",
       "        -0.12092257, -0.527568  ],\n",
       "       [-0.22859119, -0.54928684,  0.2176865 , ...,  0.11042877,\n",
       "         0.5686784 ,  0.13307133]], dtype=float32)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def forward_orig(input: jnp.ndarray, weights: jnp.ndarray, indx_seqs: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:\n",
    "    \"\"\"Reference condensed forward pass via gather + broadcast.\n",
    "\n",
    "    Args:\n",
    "        input: activations of shape (..., in_features); any number of leading batch dims.\n",
    "        weights: condensed kernel of shape (F, K).\n",
    "        indx_seqs: integer input indices of shape (F, K), aligned with `weights`.\n",
    "        bias: bias of shape (F,).\n",
    "\n",
    "    Returns:\n",
    "        Array of shape (..., F).\n",
    "    \"\"\"\n",
    "    # input[..., indx_seqs] gathers along the last axis -> (..., F, K); reduce over K.\n",
    "    return jnp.sum(weights * input[..., indx_seqs], axis=-1) + bias\n",
    "\n",
    "forward_orig_fast = jax.jit(forward_orig)\n",
    "forward_orig_fast(input, weights, indx_seqs, bias).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([[ 0.75584006, -0.17063873,  1.6859026 , ...,  0.18504576,\n",
       "        -0.00192568,  0.08196551],\n",
       "       [-0.02017869, -0.71082234,  0.16039723, ...,  0.41155237,\n",
       "        -0.4476214 ,  1.0754542 ],\n",
       "       [ 0.7968303 , -0.10603629,  1.2908362 , ..., -0.32634115,\n",
       "        -0.48441118, -0.49104658],\n",
       "       ...,\n",
       "       [ 0.30003327, -0.17096956,  0.10126591, ...,  0.6488617 ,\n",
       "        -0.1293261 , -0.16757864],\n",
       "       [-0.01557497, -0.39440617,  0.21288827, ...,  0.32918108,\n",
       "        -0.12092257, -0.527568  ],\n",
       "       [-0.22859119, -0.54928684,  0.2176865 , ...,  0.11042877,\n",
       "         0.5686784 ,  0.13307133]], dtype=float32)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Second call with the same shapes/dtypes reuses the cached compilation.\n",
    "forward_orig_fast(input, weights, indx_seqs, bias).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6.69 ms ± 113 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
     ]
    }
   ],
   "source": [
    "# Eager (un-jitted) reference forward.\n",
    "%timeit forward_orig(input, weights, indx_seqs, bias).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.73 ms ± 32.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
     ]
    }
   ],
   "source": [
    "# jitted reference forward; all arrays passed as jit arguments.\n",
    "%timeit forward_orig_fast(input, weights, indx_seqs, bias).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# weights/indx_seqs/bias are bound via partial, so only `input` is a jit argument.\n",
    "forward_orig_faster = jax.jit(partial(forward_orig, weights=weights, indx_seqs=indx_seqs, bias=bias))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([[ 0.75584006, -0.17063873,  1.6859026 , ...,  0.18504576,\n",
       "        -0.00192568,  0.08196551],\n",
       "       [-0.02017869, -0.71082234,  0.16039723, ...,  0.41155237,\n",
       "        -0.4476214 ,  1.0754542 ],\n",
       "       [ 0.7968303 , -0.10603629,  1.2908362 , ..., -0.32634115,\n",
       "        -0.48441118, -0.49104658],\n",
       "       ...,\n",
       "       [ 0.30003327, -0.17096956,  0.10126591, ...,  0.6488617 ,\n",
       "        -0.1293261 , -0.16757864],\n",
       "       [-0.01557497, -0.39440617,  0.21288827, ...,  0.32918108,\n",
       "        -0.12092257, -0.527568  ],\n",
       "       [-0.22859119, -0.54928684,  0.2176865 , ...,  0.11042877,\n",
       "         0.5686784 ,  0.13307133]], dtype=float32)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Warm-up call (triggers jit compilation).\n",
    "forward_orig_faster(input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.63 ms ± 30.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
     ]
    }
   ],
   "source": [
    "# jitted reference forward with weights/indices/bias bound via partial.\n",
    "%timeit forward_orig_faster(input).block_until_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Method #1: Use slicing/indexing and broadcasting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Eager reference output that the vmap variants below are checked against.\n",
    "orig_output = forward_orig(input, weights, indx_seqs, bias)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Variant A: vmap over the neuron axis (inner) and the batch axis (outer) ---\n",
    "def forward_neuron_single(input: jnp.ndarray, weights: jnp.ndarray, indices: jnp.ndarray) -> jnp.ndarray:\n",
    "    \"\"\"One sample, one neuron: sum of elementwise products of gathered inputs and weights.\"\"\"\n",
    "    gathered = input[indices]\n",
    "    return jnp.sum(gathered * weights)\n",
    "\n",
    "def forward_neuron_v(input: jnp.ndarray, weights: jnp.ndarray, indx_seqs: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:\n",
    "    \"\"\"One sample, all neurons: vmap over axis 0 (neurons) of weights/indx_seqs, then add bias.\"\"\"\n",
    "    per_neuron = vmap(lambda w, idx: forward_neuron_single(input, w, idx), in_axes=0, out_axes=0)\n",
    "    return per_neuron(weights, indx_seqs) + bias\n",
    "\n",
    "def forward_neuron(input: jnp.ndarray, weights: jnp.ndarray, indx_seqs: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:\n",
    "    \"\"\"Batched forward: vmap forward_neuron_v over axis 0 (samples) of `input`.\"\"\"\n",
    "    def per_sample(row):\n",
    "        return forward_neuron_v(row, weights=weights, indx_seqs=indx_seqs, bias=bias)\n",
    "    return vmap(per_sample)(input)\n",
    "\n",
    "# --- Variant B: vmap over the sparsity axis (inner) and the batch axis (outer) ---\n",
    "def forward_sparsity_single(input: jnp.ndarray, weights: jnp.ndarray, indices: jnp.ndarray) -> jnp.ndarray:\n",
    "    \"\"\"One sample, one sparsity slot: per-neuron products of gathered inputs and weights.\"\"\"\n",
    "    return input[indices] * weights\n",
    "\n",
    "def forward_sparsity_v(input: jnp.ndarray, weights: jnp.ndarray, indx_seqs: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:\n",
    "    \"\"\"One sample, all neurons: vmap over axis 1 (sparsity) of weights/indx_seqs, then reduce it.\"\"\"\n",
    "    per_slot = vmap(lambda w, idx: forward_sparsity_single(input, w, idx), in_axes=1, out_axes=1)\n",
    "    products = per_slot(weights, indx_seqs)  # slots re-stacked along axis 1\n",
    "    return jnp.sum(products, axis=1) + bias\n",
    "\n",
    "def forward_sparsity(input: jnp.ndarray, weights: jnp.ndarray, indx_seqs: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:\n",
    "    \"\"\"Batched forward: vmap forward_sparsity_v over axis 0 (samples) of `input`.\"\"\"\n",
    "    def per_sample(row):\n",
    "        return forward_sparsity_v(row, weights=weights, indx_seqs=indx_seqs, bias=bias)\n",
    "    return vmap(per_sample)(input)\n",
    "\n",
    "# *_fast: all arrays are jit arguments; *_faster: weights/indices/bias bound via partial.\n",
    "forward_neuron_fast = jax.jit(forward_neuron)\n",
    "forward_neuron_faster = jax.jit(partial(forward_neuron, weights=weights, indx_seqs=indx_seqs, bias=bias))\n",
    "forward_sparsity_fast = jax.jit(forward_sparsity)\n",
    "forward_sparsity_faster = jax.jit(partial(forward_sparsity, weights=weights, indx_seqs=indx_seqs, bias=bias))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Method #2: vmap over the batch axis, then the sparsity axis (`forward_sparsity`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "ename": "AssertionError",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[1;32m/home/mike/condensed-sparsity/notebooks/jax_benchmarks.ipynb Cell 26\u001b[0m line \u001b[0;36m4\n\u001b[1;32m      <a href='vscode-notebook-cell://ssh-remote%2Bhector/home/mike/condensed-sparsity/notebooks/jax_benchmarks.ipynb#Y100sdnNjb2RlLXJlbW90ZQ%3D%3D?line=1'>2</a>\u001b[0m fast_sparsity_output \u001b[39m=\u001b[39m forward_sparsity_fast(\u001b[39minput\u001b[39m, weights, indx_seqs, bias)\n\u001b[1;32m      <a href='vscode-notebook-cell://ssh-remote%2Bhector/home/mike/condensed-sparsity/notebooks/jax_benchmarks.ipynb#Y100sdnNjb2RlLXJlbW90ZQ%3D%3D?line=2'>3</a>\u001b[0m fast_sparsity_output_faster \u001b[39m=\u001b[39m forward_sparsity_faster(\u001b[39minput\u001b[39m)\n\u001b[0;32m----> <a href='vscode-notebook-cell://ssh-remote%2Bhector/home/mike/condensed-sparsity/notebooks/jax_benchmarks.ipynb#Y100sdnNjb2RlLXJlbW90ZQ%3D%3D?line=3'>4</a>\u001b[0m \u001b[39massert\u001b[39;00m jnp\u001b[39m.\u001b[39mallclose(orig_output, fast_sparsity_output)\n\u001b[1;32m      <a href='vscode-notebook-cell://ssh-remote%2Bhector/home/mike/condensed-sparsity/notebooks/jax_benchmarks.ipynb#Y100sdnNjb2RlLXJlbW90ZQ%3D%3D?line=4'>5</a>\u001b[0m \u001b[39massert\u001b[39;00m jnp\u001b[39m.\u001b[39mallclose(orig_output, fast_sparsity_output_faster)\n",
      "\u001b[0;31mAssertionError\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Call once so jit compilation happens outside the timed cells, then check agreement with\n",
    "# the eager reference. The jitted and eager paths may reduce in a different order, so use a\n",
    "# float32-appropriate tolerance: allclose's default atol=1e-8 is too strict for outputs near 0.\n",
    "fast_sparsity_output = forward_sparsity_fast(input, weights, indx_seqs, bias)\n",
    "fast_sparsity_output_faster = forward_sparsity_faster(input)\n",
    "np.testing.assert_allclose(fast_sparsity_output, orig_output, rtol=1e-5, atol=1e-5)\n",
    "np.testing.assert_allclose(fast_sparsity_output_faster, orig_output, rtol=1e-5, atol=1e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Eager (un-jitted) sparsity-axis vmap variant.\n",
    "%timeit forward_sparsity(input, weights, indx_seqs, bias).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# jitted sparsity-axis variant; all arrays passed as jit arguments.\n",
    "%timeit forward_sparsity_fast(input, weights, indx_seqs, bias).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit forward_sparsity_faster(input).block_until_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Method #3: vmap over sparsity/neuron axes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call each variant once so JIT compilation is excluded from the %timeit cells below,\n",
    "# then check both against the reference output (reporting the max error on failure).\n",
    "fast_neuron_output = forward_neuron_fast(input, weights, indx_seqs, bias)\n",
    "faster_neuron_output = forward_neuron_faster(input)\n",
    "assert jnp.allclose(orig_output, fast_neuron_output), (\n",
    "    f\"forward_neuron_fast mismatch: max |diff| = {jnp.max(jnp.abs(orig_output - fast_neuron_output))}\"\n",
    ")\n",
    "assert jnp.allclose(orig_output, faster_neuron_output), (\n",
    "    f\"forward_neuron_faster mismatch: max |diff| = {jnp.max(jnp.abs(orig_output - faster_neuron_output))}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: forward_neuron called directly. JAX dispatches asynchronously, so\n",
    "# block_until_ready() makes %timeit include the actual computation.\n",
    "%timeit forward_neuron(input, weights, indx_seqs, bias).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Jitted variant (compiled by the warm-up cell above), taking weights/indices/bias as arguments.\n",
    "%timeit forward_neuron_fast(input, weights, indx_seqs, bias).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Variant taking only `input`; weights/indices/bias are presumably bound before jitting.\n",
    "%timeit forward_neuron_faster(input).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CondensedLinearVmapNeuron(nn.Module):\n",
    "    \"\"\"Condensed linear layer evaluated by vmapping `forward_neuron_v` over the input.\n",
    "\n",
    "    Holds three parameters: `kernel` (features, fan_in), `bias` (features,) and\n",
    "    `indx_seqs` (features, fan_in), presumably the input columns read by each neuron.\n",
    "    `forward_neuron_v` is defined earlier in this notebook.\n",
    "    \"\"\"\n",
    "\n",
    "    features: int\n",
    "    fan_in: int\n",
    "    kernel_init: Callable = nn.initializers.lecun_normal()\n",
    "    bias_init: Callable = nn.initializers.zeros_init()\n",
    "\n",
    "    def forward_neuron(self, input: jnp.ndarray, weights: jnp.ndarray, indx_seqs: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:\n",
    "        \"\"\"Apply `forward_neuron_v` to each slice of `input` along its leading axis.\"\"\"\n",
    "        return vmap(partial(forward_neuron_v, weights=weights, indx_seqs=indx_seqs, bias=bias))(input)\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, input: ArrayLike) -> jax.Array:\n",
    "        kernel = self.param(\"kernel\", self.kernel_init, (self.features, self.fan_in))\n",
    "        bias = self.param(\"bias\", self.bias_init, (self.features,))\n",
    "        # NOTE(review): indx_seqs is initialised with kernel_init (floats). That only matters\n",
    "        # for `init`; `apply` uses the stored value. Integer indices are presumably\n",
    "        # expected -- confirm before calling `init` on this module.\n",
    "        indx_seqs = self.param(\"indx_seqs\", self.kernel_init, (self.features, self.fan_in))\n",
    "        # Use this module's own `kernel`, not the notebook-global `weights`.\n",
    "        return self.forward_neuron(input, kernel, indx_seqs, bias)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Size the layer from the condensed kernel, jit a closure over the fixed params,\n",
    "# and call it once so compilation is excluded from the timings below.\n",
    "kernel_shape = condensed_params[\"params\"][\"kernel\"].shape\n",
    "cl = CondensedLinearVmapNeuron(*kernel_shape)\n",
    "\n",
    "\n",
    "def _apply_condensed(x):\n",
    "    return cl.apply(condensed_params, x)\n",
    "\n",
    "\n",
    "cl_fast = jax.jit(_apply_condensed)\n",
    "cl_fast(x).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Jitted closure over condensed_params (compiled by the call in the cell above).\n",
    "%timeit cl_fast(x).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same module call without jax.jit, for comparison.\n",
    "%timeit cl.apply(condensed_params, x).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batched variants: each neuron processes the whole batch at once.\n",
    "def forward_batch_neuron_single(input: jnp.ndarray, weights: jnp.ndarray, indices: jnp.ndarray) -> jnp.ndarray:\n",
    "    \"\"\"Forward pass of a single output neuron over the whole batch.\n",
    "\n",
    "    Args:\n",
    "        input: (batch, in_features) activations.\n",
    "        weights: (fan_in,) condensed weights of this neuron.\n",
    "        indices: (fan_in,) columns of `input` read by this neuron.\n",
    "\n",
    "    Returns:\n",
    "        (batch,) outputs of this neuron, without bias.\n",
    "    \"\"\"\n",
    "    return jnp.sum(input[:, indices] * weights[None, :], axis=1)\n",
    "\n",
    "\n",
    "def forward_batch_neuron(input: jnp.ndarray, weights: jnp.ndarray, indx_seqs: jnp.ndarray, bias: \"jnp.ndarray | None\" = None) -> jnp.ndarray:\n",
    "    \"\"\"Condensed forward pass: vmap over neurons, stack as (features, batch), transpose.\n",
    "\n",
    "    Args:\n",
    "        input: (batch, in_features) activations.\n",
    "        weights: (features, fan_in) condensed weights, one row per neuron.\n",
    "        indx_seqs: (features, fan_in) input-column indices, one row per neuron.\n",
    "        bias: optional (features,) bias added to every output row; None means no bias.\n",
    "\n",
    "    Returns:\n",
    "        (batch, features) outputs.\n",
    "    \"\"\"\n",
    "    out = vmap(partial(forward_batch_neuron_single, input), in_axes=0, out_axes=0)(weights, indx_seqs).T\n",
    "    return out if bias is None else out + bias\n",
    "\n",
    "\n",
    "def forward_batch_sparsity(input: jnp.ndarray, weights: jnp.ndarray, indx_seqs: jnp.ndarray, bias: \"jnp.ndarray | None\" = None) -> jnp.ndarray:\n",
    "    \"\"\"Same computation as `forward_batch_neuron`, but vmap places the neuron axis\n",
    "    at axis 1 (out_axes=1), so no explicit transpose is needed.\n",
    "\n",
    "    Args and return value are identical to `forward_batch_neuron`.\n",
    "    \"\"\"\n",
    "    out = vmap(partial(forward_batch_neuron_single, input), in_axes=0, out_axes=1)(weights, indx_seqs)\n",
    "    return out if bias is None else out + bias\n",
    "\n",
    "\n",
    "# Jitted variants; the first call of each triggers compilation.\n",
    "forward_batch_neuron_fast = jax.jit(forward_batch_neuron)\n",
    "forward_batch_sparsity_fast = jax.jit(forward_batch_sparsity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call once so JIT happens, then compare against the reference output.\n",
    "forward_batch_neuron_output = forward_batch_neuron_fast(input, weights, indx_seqs)\n",
    "forward_batch_sparsity_fast_output = forward_batch_sparsity_fast(input, weights, indx_seqs)\n",
    "# TODO: Add bias to the calls above; orig_output presumably includes `bias`, so these\n",
    "# checks can only pass while the bias is (close to) zero.\n",
    "assert jnp.allclose(orig_output, forward_batch_neuron_output), (\n",
    "    f\"forward_batch_neuron_fast mismatch: max |diff| = {jnp.max(jnp.abs(orig_output - forward_batch_neuron_output))}\"\n",
    ")\n",
    "assert jnp.allclose(orig_output, forward_batch_sparsity_fast_output), (\n",
    "    f\"forward_batch_sparsity_fast mismatch: max |diff| = {jnp.max(jnp.abs(orig_output - forward_batch_sparsity_fast_output))}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Un-jitted forward_batch_neuron, for comparison with the jitted variant below.\n",
    "%timeit forward_batch_neuron(input, weights, indx_seqs).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# jax.jit(forward_batch_neuron), compiled by the assertion cell above.\n",
    "%timeit forward_batch_neuron_fast(input, weights, indx_seqs).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Un-jitted forward_batch_sparsity (out_axes=1, no transpose).\n",
    "%timeit forward_batch_sparsity(input, weights, indx_seqs).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# jax.jit(forward_batch_sparsity), compiled by the assertion cell above.\n",
    "%timeit forward_batch_sparsity_fast(input, weights, indx_seqs).block_until_ready()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
