{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fb1a8110",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Process-wide runtime settings, applied before the cell that imports torch:\n",
    "# expose only CUDA device 0 and cap OpenMP at a single thread.\n",
    "os.environ.update({\n",
    "    \"CUDA_VISIBLE_DEVICES\": \"0\",\n",
    "    \"OMP_NUM_THREADS\": \"1\",\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "35287ee5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-26T11:38:17.209421Z",
     "iopub.status.busy": "2023-12-26T11:38:17.208962Z",
     "iopub.status.idle": "2023-12-26T11:38:20.906875Z",
     "shell.execute_reply": "2023-12-26T11:38:20.905911Z"
    },
    "papermill": {
     "duration": 3.727175,
     "end_time": "2023-12-26T11:38:20.909562",
     "exception": false,
     "start_time": "2023-12-26T11:38:17.182387",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from modelutils import *\n",
    "from datautils import *\n",
    "\n",
    "import datasets\n",
    "import matplotlib.pyplot as plt\n",
    "from bf16_fused_adam import BF16FusedAdamW\n",
    "from collections import Counter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ecf0db9a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-26T11:38:20.927014Z",
     "iopub.status.busy": "2023-12-26T11:38:20.926394Z",
     "iopub.status.idle": "2023-12-26T11:38:20.933689Z",
     "shell.execute_reply": "2023-12-26T11:38:20.932806Z"
    },
    "papermill": {
     "duration": 0.017022,
     "end_time": "2023-12-26T11:38:20.935379",
     "exception": false,
     "start_time": "2023-12-26T11:38:20.918357",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_opt(model, cache_dir=\"/mnt/nvme/llm_weights\", seqlen=1024 * 8):\n",
    "    \"\"\"Load a pretrained causal LM and record the sequence length to use with it.\n",
    "\n",
    "    Args:\n",
    "        model: Model name or path forwarded to\n",
    "            ``AutoModelForCausalLM.from_pretrained``.\n",
    "        cache_dir: Directory forwarded as ``cache_dir`` to ``from_pretrained``.\n",
    "        seqlen: Value stored on the returned model as ``model.seqlen``.\n",
    "\n",
    "    Returns:\n",
    "        The loaded model, with an extra ``seqlen`` attribute set.\n",
    "    \"\"\"\n",
    "    from transformers import AutoModelForCausalLM\n",
    "\n",
    "    def skip(*args, **kwargs):\n",
    "        pass\n",
    "\n",
    "    # Swap these torch.nn.init functions for no-ops while the model is being\n",
    "    # loaded, then always restore the originals so that modules created later\n",
    "    # in the same kernel are initialised normally again.\n",
    "    init = torch.nn.init\n",
    "    patched_names = (\"kaiming_uniform_\", \"uniform_\", \"normal_\")\n",
    "    originals = {name: getattr(init, name) for name in patched_names}\n",
    "    for name in patched_names:\n",
    "        setattr(init, name, skip)\n",
    "    try:\n",
    "        loaded = AutoModelForCausalLM.from_pretrained(\n",
    "            model, torch_dtype='auto', cache_dir=cache_dir\n",
    "        )\n",
    "    finally:\n",
    "        for name, fn in originals.items():\n",
    "            setattr(init, name, fn)\n",
    "\n",
    "    # Show the context length declared by the checkpoint's config for reference;\n",
    "    # the length actually used downstream is the explicit `seqlen` argument.\n",
    "    print(\"ms\", loaded.config.max_position_embeddings)\n",
    "    loaded.seqlen = seqlen\n",
    "    return loaded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d3d09b70",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-26T11:38:20.949258Z",
     "iopub.status.busy": "2023-12-26T11:38:20.948788Z",
     "iopub.status.idle": "2023-12-26T11:38:40.523396Z",
     "shell.execute_reply": "2023-12-26T11:38:40.522532Z"
    },
    "papermill": {
     "duration": 19.585753,
     "end_time": "2023-12-26T11:38:40.527104",
     "exception": false,
     "start_time": "2023-12-26T11:38:20.941351",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a1dd6c0423414b5f9a33db5912de25bc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ms 8192\n",
      "8030261248 6979584000\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "LlamaForCausalLM(\n",
       "  (model): LlamaModel(\n",
       "    (embed_tokens): Embedding(128256, 4096)\n",
       "    (layers): ModuleList(\n",
       "      (0-31): 32 x LlamaDecoderLayer(\n",
       "        (self_attn): LlamaAttention(\n",
       "          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "        )\n",
       "        (mlp): LlamaMLP(\n",
       "          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "          (act_fn): SiLU()\n",
       "        )\n",
       "        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)\n",
       "        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)\n",
       "      )\n",
       "    )\n",
       "    (norm): LlamaRMSNorm((4096,), eps=1e-05)\n",
       "    (rotary_emb): LlamaRotaryEmbedding()\n",
       "  )\n",
       "  (lm_head): Linear(in_features=4096, out_features=128256, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Checkpoint to load. TARGET_BITS is only defined here; it is not used in this cell\n",
    "# (presumably the bits-per-weight target for later cells -- TODO confirm).\n",
    "model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
    "TARGET_BITS = 2.0\n",
    "\n",
    "model = get_opt(model_name)\n",
    "model.eval()\n",
    "\n",
    "# Parameter counts: the whole model first, then only `model.model.layers`.\n",
    "print(*(\n",
    "    sum(p.numel() for p in module.parameters())\n",
    "    for module in (model, model.model.layers)\n",
    "))\n",
    "\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bf2732ff",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-26T11:38:40.545044Z",
     "iopub.status.busy": "2023-12-26T11:38:40.544530Z",
     "iopub.status.idle": "2023-12-26T11:39:00.491876Z",
     "shell.execute_reply": "2023-12-26T11:39:00.491195Z"
    },
    "papermill": {
     "duration": 19.957607,
     "end_time": "2023-12-26T11:39:00.494561",
     "exception": false,
     "start_time": "2023-12-26T11:38:40.536954",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([256, 8192])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of rows to draw from the pre-tokenized dataset.\n",
    "n_samples = 256\n",
    "\n",
    "ds = datasets.load_from_disk(\"redpajama_tokenized_llama3/\")\n",
    "\n",
    "# Fixed seed so the same row indices are drawn on every run; indices are\n",
    "# sampled independently, so the same row can be picked more than once.\n",
    "np.random.seed(47)\n",
    "inds = np.random.randint(low=0, high=len(ds), size=n_samples)\n",
    "\n",
    "# Despite its name, `dataloader` is a plain LongTensor of token ids\n",
    "# (one row per sampled index), not a torch DataLoader.\n",
    "dataloader = torch.LongTensor(ds[inds][\"input_ids\"])\n",
    "dataloader.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "17b0c3ae",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model.embed_tokens.weight\n",
      "model.layers.0.self_attn.q_proj.weight\n",
      "model.layers.0.self_attn.k_proj.weight\n",
      "model.layers.0.self_attn.v_proj.weight\n",
      "model.layers.0.self_attn.o_proj.weight\n",
      "model.layers.0.mlp.gate_proj.weight\n",
      "model.layers.0.mlp.up_proj.weight\n",
      "model.layers.0.mlp.down_proj.weight\n",
      "model.layers.0.input_layernorm.weight\n",
      "model.layers.0.post_attention_layernorm.weight\n",
      "model.layers.1.self_attn.q_proj.weight\n",
      "model.layers.1.self_attn.k_proj.weight\n",
      "model.layers.1.self_attn.v_proj.weight\n",
      "model.layers.1.self_attn.o_proj.weight\n",
      "model.layers.1.mlp.gate_proj.weight\n",
      "model.layers.1.mlp.up_proj.weight\n",
      "model.layers.1.mlp.down_proj.weight\n",
      "model.layers.1.input_layernorm.weight\n",
      "model.layers.1.post_attention_layernorm.weight\n",
      "model.layers.2.self_attn.q_proj.weight\n",
      "model.layers.2.self_attn.k_proj.weight\n",
      "model.layers.2.self_attn.v_proj.weight\n",
      "model.layers.2.self_attn.o_proj.weight\n",
      "model.layers.2.mlp.gate_proj.weight\n",
      "model.layers.2.mlp.up_proj.weight\n",
      "model.layers.2.mlp.down_proj.weight\n",
      "model.layers.2.input_layernorm.weight\n",
      "model.layers.2.post_attention_layernorm.weight\n",
      "model.layers.3.self_attn.q_proj.weight\n",
      "model.layers.3.self_attn.k_proj.weight\n",
      "model.layers.3.self_attn.v_proj.weight\n",
      "model.layers.3.self_attn.o_proj.weight\n",
      "model.layers.3.mlp.gate_proj.weight\n",
      "model.layers.3.mlp.up_proj.weight\n",
      "model.layers.3.mlp.down_proj.weight\n",
      "model.layers.3.input_layernorm.weight\n",
      "model.layers.3.post_attention_layernorm.weight\n",
      "model.layers.4.self_attn.q_proj.weight\n",
      "model.layers.4.self_attn.k_proj.weight\n",
      "model.layers.4.self_attn.v_proj.weight\n",
      "model.layers.4.self_attn.o_proj.weight\n",
      "model.layers.4.mlp.gate_proj.weight\n",
      "model.layers.4.mlp.up_proj.weight\n",
      "model.layers.4.mlp.down_proj.weight\n",
      "model.layers.4.input_layernorm.weight\n",
      "model.layers.4.post_attention_layernorm.weight\n",
      "model.layers.5.self_attn.q_proj.weight\n",
      "model.layers.5.self_attn.k_proj.weight\n",
      "model.layers.5.self_attn.v_proj.weight\n",
      "model.layers.5.self_attn.o_proj.weight\n",
      "model.layers.5.mlp.gate_proj.weight\n",
      "model.layers.5.mlp.up_proj.weight\n",
      "model.layers.5.mlp.down_proj.weight\n",
      "model.layers.5.input_layernorm.weight\n",
      "model.layers.5.post_attention_layernorm.weight\n",
      "model.layers.6.self_attn.q_proj.weight\n",
      "model.layers.6.self_attn.k_proj.weight\n",
      "model.layers.6.self_attn.v_proj.weight\n",
      "model.layers.6.self_attn.o_proj.weight\n",
      "model.layers.6.mlp.gate_proj.weight\n",
      "model.layers.6.mlp.up_proj.weight\n",
      "model.layers.6.mlp.down_proj.weight\n",
      "model.layers.6.input_layernorm.weight\n",
      "model.layers.6.post_attention_layernorm.weight\n",
      "model.layers.7.self_attn.q_proj.weight\n",
      "model.layers.7.self_attn.k_proj.weight\n",
      "model.layers.7.self_attn.v_proj.weight\n",
      "model.layers.7.self_attn.o_proj.weight\n",
      "model.layers.7.mlp.gate_proj.weight\n",
      "model.layers.7.mlp.up_proj.weight\n",
      "model.layers.7.mlp.down_proj.weight\n",
      "model.layers.7.input_layernorm.weight\n",
      "model.layers.7.post_attention_layernorm.weight\n",
      "model.layers.8.self_attn.q_proj.weight\n",
      "model.layers.8.self_attn.k_proj.weight\n",
      "model.layers.8.self_attn.v_proj.weight\n",
      "model.layers.8.self_attn.o_proj.weight\n",
      "model.layers.8.mlp.gate_proj.weight\n",
      "model.layers.8.mlp.up_proj.weight\n",
      "model.layers.8.mlp.down_proj.weight\n",
      "model.layers.8.input_layernorm.weight\n",
      "model.layers.8.post_attention_layernorm.weight\n",
      "model.layers.9.self_attn.q_proj.weight\n",
      "model.layers.9.self_attn.k_proj.weight\n",
      "model.layers.9.self_attn.v_proj.weight\n",
      "model.layers.9.self_attn.o_proj.weight\n",
      "model.layers.9.mlp.gate_proj.weight\n",
      "model.layers.9.mlp.up_proj.weight\n",
      "model.layers.9.mlp.down_proj.weight\n",
      "model.layers.9.input_layernorm.weight\n",
      "model.layers.9.post_attention_layernorm.weight\n",
      "model.layers.10.self_attn.q_proj.weight\n",
      "model.layers.10.self_attn.k_proj.weight\n",
      "model.layers.10.self_attn.v_proj.weight\n",
      "model.layers.10.self_attn.o_proj.weight\n",
      "model.layers.10.mlp.gate_proj.weight\n",
      "model.layers.10.mlp.up_proj.weight\n",
      "model.layers.10.mlp.down_proj.weight\n",
      "model.layers.10.input_layernorm.weight\n",
      "model.layers.10.post_attention_layernorm.weight\n",
      "model.layers.11.self_attn.q_proj.weight\n",
      "model.layers.11.self_attn.k_proj.weight\n",
      "model.layers.11.self_attn.v_proj.weight\n",
      "model.layers.11.self_attn.o_proj.weight\n",
      "model.layers.11.mlp.gate_proj.weight\n",
      "model.layers.11.mlp.up_proj.weight\n",
      "model.layers.11.mlp.down_proj.weight\n",
      "model.layers.11.input_layernorm.weight\n",
      "model.layers.11.post_attention_layernorm.weight\n",
      "model.layers.12.self_attn.q_proj.weight\n",
      "model.layers.12.self_attn.k_proj.weight\n",
      "model.layers.12.self_attn.v_proj.weight\n",
      "model.layers.12.self_attn.o_proj.weight\n",
      "model.layers.12.mlp.gate_proj.weight\n",
      "model.layers.12.mlp.up_proj.weight\n",
      "model.layers.12.mlp.down_proj.weight\n",
      "model.layers.12.input_layernorm.weight\n",
      "model.layers.12.post_attention_layernorm.weight\n",
      "model.layers.13.self_attn.q_proj.weight\n",
      "model.layers.13.self_attn.k_proj.weight\n",
      "model.layers.13.self_attn.v_proj.weight\n",
      "model.layers.13.self_attn.o_proj.weight\n",
      "model.layers.13.mlp.gate_proj.weight\n",
      "model.layers.13.mlp.up_proj.weight\n",
      "model.layers.13.mlp.down_proj.weight\n",
      "model.layers.13.input_layernorm.weight\n",
      "model.layers.13.post_attention_layernorm.weight\n",
      "model.layers.14.self_attn.q_proj.weight\n",
      "model.layers.14.self_attn.k_proj.weight\n",
      "model.layers.14.self_attn.v_proj.weight\n",
      "model.layers.14.self_attn.o_proj.weight\n",
      "model.layers.14.mlp.gate_proj.weight\n",
      "model.layers.14.mlp.up_proj.weight\n",
      "model.layers.14.mlp.down_proj.weight\n",
      "model.layers.14.input_layernorm.weight\n",
      "model.layers.14.post_attention_layernorm.weight\n",
      "model.layers.15.self_attn.q_proj.weight\n",
      "model.layers.15.self_attn.k_proj.weight\n",
      "model.layers.15.self_attn.v_proj.weight\n",
      "model.layers.15.self_attn.o_proj.weight\n",
      "model.layers.15.mlp.gate_proj.weight\n",
      "model.layers.15.mlp.up_proj.weight\n",
      "model.layers.15.mlp.down_proj.weight\n",
      "model.layers.15.input_layernorm.weight\n",
      "model.layers.15.post_attention_layernorm.weight\n",
      "model.layers.16.self_attn.q_proj.weight\n",
      "model.layers.16.self_attn.k_proj.weight\n",
      "model.layers.16.self_attn.v_proj.weight\n",
      "model.layers.16.self_attn.o_proj.weight\n",
      "model.layers.16.mlp.gate_proj.weight\n",
      "model.layers.16.mlp.up_proj.weight\n",
      "model.layers.16.mlp.down_proj.weight\n",
      "model.layers.16.input_layernorm.weight\n",
      "model.layers.16.post_attention_layernorm.weight\n",
      "model.layers.17.self_attn.q_proj.weight\n",
      "model.layers.17.self_attn.k_proj.weight\n",
      "model.layers.17.self_attn.v_proj.weight\n",
      "model.layers.17.self_attn.o_proj.weight\n",
      "model.layers.17.mlp.gate_proj.weight\n",
      "model.layers.17.mlp.up_proj.weight\n",
      "model.layers.17.mlp.down_proj.weight\n",
      "model.layers.17.input_layernorm.weight\n",
      "model.layers.17.post_attention_layernorm.weight\n",
      "model.layers.18.self_attn.q_proj.weight\n",
      "model.layers.18.self_attn.k_proj.weight\n",
      "model.layers.18.self_attn.v_proj.weight\n",
      "model.layers.18.self_attn.o_proj.weight\n",
      "model.layers.18.mlp.gate_proj.weight\n",
      "model.layers.18.mlp.up_proj.weight\n",
      "model.layers.18.mlp.down_proj.weight\n",
      "model.layers.18.input_layernorm.weight\n",
      "model.layers.18.post_attention_layernorm.weight\n",
      "model.layers.19.self_attn.q_proj.weight\n",
      "model.layers.19.self_attn.k_proj.weight\n",
      "model.layers.19.self_attn.v_proj.weight\n",
      "model.layers.19.self_attn.o_proj.weight\n",
      "model.layers.19.mlp.gate_proj.weight\n",
      "model.layers.19.mlp.up_proj.weight\n",
      "model.layers.19.mlp.down_proj.weight\n",
      "model.layers.19.input_layernorm.weight\n",
      "model.layers.19.post_attention_layernorm.weight\n",
      "model.layers.20.self_attn.q_proj.weight\n",
      "model.layers.20.self_attn.k_proj.weight\n",
      "model.layers.20.self_attn.v_proj.weight\n",
      "model.layers.20.self_attn.o_proj.weight\n",
      "model.layers.20.mlp.gate_proj.weight\n",
      "model.layers.20.mlp.up_proj.weight\n",
      "model.layers.20.mlp.down_proj.weight\n",
      "model.layers.20.input_layernorm.weight\n",
      "model.layers.20.post_attention_layernorm.weight\n",
      "model.layers.21.self_attn.q_proj.weight\n",
      "model.layers.21.self_attn.k_proj.weight\n",
      "model.layers.21.self_attn.v_proj.weight\n",
      "model.layers.21.self_attn.o_proj.weight\n",
      "model.layers.21.mlp.gate_proj.weight\n",
      "model.layers.21.mlp.up_proj.weight\n",
      "model.layers.21.mlp.down_proj.weight\n",
      "model.layers.21.input_layernorm.weight\n",
      "model.layers.21.post_attention_layernorm.weight\n",
      "model.layers.22.self_attn.q_proj.weight\n",
      "model.layers.22.self_attn.k_proj.weight\n",
      "model.layers.22.self_attn.v_proj.weight\n",
      "model.layers.22.self_attn.o_proj.weight\n",
      "model.layers.22.mlp.gate_proj.weight\n",
      "model.layers.22.mlp.up_proj.weight\n",
      "model.layers.22.mlp.down_proj.weight\n",
      "model.layers.22.input_layernorm.weight\n",
      "model.layers.22.post_attention_layernorm.weight\n",
      "model.layers.23.self_attn.q_proj.weight\n",
      "model.layers.23.self_attn.k_proj.weight\n",
      "model.layers.23.self_attn.v_proj.weight\n",
      "model.layers.23.self_attn.o_proj.weight\n",
      "model.layers.23.mlp.gate_proj.weight\n",
      "model.layers.23.mlp.up_proj.weight\n",
      "model.layers.23.mlp.down_proj.weight\n",
      "model.layers.23.input_layernorm.weight\n",
      "model.layers.23.post_attention_layernorm.weight\n",
      "model.layers.24.self_attn.q_proj.weight\n",
      "model.layers.24.self_attn.k_proj.weight\n",
      "model.layers.24.self_attn.v_proj.weight\n",
      "model.layers.24.self_attn.o_proj.weight\n",
      "model.layers.24.mlp.gate_proj.weight\n",
      "model.layers.24.mlp.up_proj.weight\n",
      "model.layers.24.mlp.down_proj.weight\n",
      "model.layers.24.input_layernorm.weight\n",
      "model.layers.24.post_attention_layernorm.weight\n",
      "model.layers.25.self_attn.q_proj.weight\n",
      "model.layers.25.self_attn.k_proj.weight\n",
      "model.layers.25.self_attn.v_proj.weight\n",
      "model.layers.25.self_attn.o_proj.weight\n",
      "model.layers.25.mlp.gate_proj.weight\n",
      "model.layers.25.mlp.up_proj.weight\n",
      "model.layers.25.mlp.down_proj.weight\n",
      "model.layers.25.input_layernorm.weight\n",
      "model.layers.25.post_attention_layernorm.weight\n",
      "model.layers.26.self_attn.q_proj.weight\n",
      "model.layers.26.self_attn.k_proj.weight\n",
      "model.layers.26.self_attn.v_proj.weight\n",
      "model.layers.26.self_attn.o_proj.weight\n",
      "model.layers.26.mlp.gate_proj.weight\n",
      "model.layers.26.mlp.up_proj.weight\n",
      "model.layers.26.mlp.down_proj.weight\n",
      "model.layers.26.input_layernorm.weight\n",
      "model.layers.26.post_attention_layernorm.weight\n",
      "model.layers.27.self_attn.q_proj.weight\n",
      "model.layers.27.self_attn.k_proj.weight\n",
      "model.layers.27.self_attn.v_proj.weight\n",
      "model.layers.27.self_attn.o_proj.weight\n",
      "model.layers.27.mlp.gate_proj.weight\n",
      "model.layers.27.mlp.up_proj.weight\n",
      "model.layers.27.mlp.down_proj.weight\n",
      "model.layers.27.input_layernorm.weight\n",
      "model.layers.27.post_attention_layernorm.weight\n",
      "model.layers.28.self_attn.q_proj.weight\n",
      "model.layers.28.self_attn.k_proj.weight\n",
      "model.layers.28.self_attn.v_proj.weight\n",
      "model.layers.28.self_attn.o_proj.weight\n",
      "model.layers.28.mlp.gate_proj.weight\n",
      "model.layers.28.mlp.up_proj.weight\n",
      "model.layers.28.mlp.down_proj.weight\n",
      "model.layers.28.input_layernorm.weight\n",
      "model.layers.28.post_attention_layernorm.weight\n",
      "model.layers.29.self_attn.q_proj.weight\n",
      "model.layers.29.self_attn.k_proj.weight\n",
      "model.layers.29.self_attn.v_proj.weight\n",
      "model.layers.29.self_attn.o_proj.weight\n",
      "model.layers.29.mlp.gate_proj.weight\n",
      "model.layers.29.mlp.up_proj.weight\n",
      "model.layers.29.mlp.down_proj.weight\n",
      "model.layers.29.input_layernorm.weight\n",
      "model.layers.29.post_attention_layernorm.weight\n",
      "model.layers.30.self_attn.q_proj.weight\n",
      "model.layers.30.self_attn.k_proj.weight\n",
      "model.layers.30.self_attn.v_proj.weight\n",
      "model.layers.30.self_attn.o_proj.weight\n",
      "model.layers.30.mlp.gate_proj.weight\n",
      "model.layers.30.mlp.up_proj.weight\n",
      "model.layers.30.mlp.down_proj.weight\n",
      "model.layers.30.input_layernorm.weight\n",
      "model.layers.30.post_attention_layernorm.weight\n",
      "model.layers.31.self_attn.q_proj.weight\n",
      "model.layers.31.self_attn.k_proj.weight\n",
      "model.layers.31.self_attn.v_proj.weight\n",
      "model.layers.31.self_attn.o_proj.weight\n",
      "model.layers.31.mlp.gate_proj.weight\n",
      "model.layers.31.mlp.up_proj.weight\n",
      "model.layers.31.mlp.down_proj.weight\n",
      "model.layers.31.input_layernorm.weight\n",
      "model.layers.31.post_attention_layernorm.weight\n",
      "model.norm.weight\n",
      "lm_head.weight\n",
      "model.layers.0.self_attn.q_proj\n",
      "model.layers.0.self_attn.k_proj\n",
      "model.layers.0.self_attn.v_proj\n",
      "model.layers.0.self_attn.o_proj\n",
      "model.layers.0.mlp.gate_proj\n",
      "model.layers.0.mlp.up_proj\n",
      "model.layers.0.mlp.down_proj\n",
      "model.layers.1.self_attn.q_proj\n",
      "model.layers.1.self_attn.k_proj\n",
      "model.layers.1.self_attn.v_proj\n",
      "model.layers.1.self_attn.o_proj\n",
      "model.layers.1.mlp.gate_proj\n",
      "model.layers.1.mlp.up_proj\n",
      "model.layers.1.mlp.down_proj\n",
      "model.layers.2.self_attn.q_proj\n",
      "model.layers.2.self_attn.k_proj\n",
      "model.layers.2.self_attn.v_proj\n",
      "model.layers.2.self_attn.o_proj\n",
      "model.layers.2.mlp.gate_proj\n",
      "model.layers.2.mlp.up_proj\n",
      "model.layers.2.mlp.down_proj\n",
      "model.layers.3.self_attn.q_proj\n",
      "model.layers.3.self_attn.k_proj\n",
      "model.layers.3.self_attn.v_proj\n",
      "model.layers.3.self_attn.o_proj\n",
      "model.layers.3.mlp.gate_proj\n",
      "model.layers.3.mlp.up_proj\n",
      "model.layers.3.mlp.down_proj\n",
      "model.layers.4.self_attn.q_proj\n",
      "model.layers.4.self_attn.k_proj\n",
      "model.layers.4.self_attn.v_proj\n",
      "model.layers.4.self_attn.o_proj\n",
      "model.layers.4.mlp.gate_proj\n",
      "model.layers.4.mlp.up_proj\n",
      "model.layers.4.mlp.down_proj\n",
      "model.layers.5.self_attn.q_proj\n",
      "model.layers.5.self_attn.k_proj\n",
      "model.layers.5.self_attn.v_proj\n",
      "model.layers.5.self_attn.o_proj\n",
      "model.layers.5.mlp.gate_proj\n",
      "model.layers.5.mlp.up_proj\n",
      "model.layers.5.mlp.down_proj\n",
      "model.layers.6.self_attn.q_proj\n",
      "model.layers.6.self_attn.k_proj\n",
      "model.layers.6.self_attn.v_proj\n",
      "model.layers.6.self_attn.o_proj\n",
      "model.layers.6.mlp.gate_proj\n",
      "model.layers.6.mlp.up_proj\n",
      "model.layers.6.mlp.down_proj\n",
      "model.layers.7.self_attn.q_proj\n",
      "model.layers.7.self_attn.k_proj\n",
      "model.layers.7.self_attn.v_proj\n",
      "model.layers.7.self_attn.o_proj\n",
      "model.layers.7.mlp.gate_proj\n",
      "model.layers.7.mlp.up_proj\n",
      "model.layers.7.mlp.down_proj\n",
      "model.layers.8.self_attn.q_proj\n",
      "model.layers.8.self_attn.k_proj\n",
      "model.layers.8.self_attn.v_proj\n",
      "model.layers.8.self_attn.o_proj\n",
      "model.layers.8.mlp.gate_proj\n",
      "model.layers.8.mlp.up_proj\n",
      "model.layers.8.mlp.down_proj\n",
      "model.layers.9.self_attn.q_proj\n",
      "model.layers.9.self_attn.k_proj\n",
      "model.layers.9.self_attn.v_proj\n",
      "model.layers.9.self_attn.o_proj\n",
      "model.layers.9.mlp.gate_proj\n",
      "model.layers.9.mlp.up_proj\n",
      "model.layers.9.mlp.down_proj\n",
      "model.layers.10.self_attn.q_proj\n",
      "model.layers.10.self_attn.k_proj\n",
      "model.layers.10.self_attn.v_proj\n",
      "model.layers.10.self_attn.o_proj\n",
      "model.layers.10.mlp.gate_proj\n",
      "model.layers.10.mlp.up_proj\n",
      "model.layers.10.mlp.down_proj\n",
      "model.layers.11.self_attn.q_proj\n",
      "model.layers.11.self_attn.k_proj\n",
      "model.layers.11.self_attn.v_proj\n",
      "model.layers.11.self_attn.o_proj\n",
      "model.layers.11.mlp.gate_proj\n",
      "model.layers.11.mlp.up_proj\n",
      "model.layers.11.mlp.down_proj\n",
      "model.layers.12.self_attn.q_proj\n",
      "model.layers.12.self_attn.k_proj\n",
      "model.layers.12.self_attn.v_proj\n",
      "model.layers.12.self_attn.o_proj\n",
      "model.layers.12.mlp.gate_proj\n",
      "model.layers.12.mlp.up_proj\n",
      "model.layers.12.mlp.down_proj\n",
      "model.layers.13.self_attn.q_proj\n",
      "model.layers.13.self_attn.k_proj\n",
      "model.layers.13.self_attn.v_proj\n",
      "model.layers.13.self_attn.o_proj\n",
      "model.layers.13.mlp.gate_proj\n",
      "model.layers.13.mlp.up_proj\n",
      "model.layers.13.mlp.down_proj\n",
      "model.layers.14.self_attn.q_proj\n",
      "model.layers.14.self_attn.k_proj\n",
      "model.layers.14.self_attn.v_proj\n",
      "model.layers.14.self_attn.o_proj\n",
      "model.layers.14.mlp.gate_proj\n",
      "model.layers.14.mlp.up_proj\n",
      "model.layers.14.mlp.down_proj\n",
      "model.layers.15.self_attn.q_proj\n",
      "model.layers.15.self_attn.k_proj\n",
      "model.layers.15.self_attn.v_proj\n",
      "model.layers.15.self_attn.o_proj\n",
      "model.layers.15.mlp.gate_proj\n",
      "model.layers.15.mlp.up_proj\n",
      "model.layers.15.mlp.down_proj\n",
      "model.layers.16.self_attn.q_proj\n",
      "model.layers.16.self_attn.k_proj\n",
      "model.layers.16.self_attn.v_proj\n",
      "model.layers.16.self_attn.o_proj\n",
      "model.layers.16.mlp.gate_proj\n",
      "model.layers.16.mlp.up_proj\n",
      "model.layers.16.mlp.down_proj\n",
      "model.layers.17.self_attn.q_proj\n",
      "model.layers.17.self_attn.k_proj\n",
      "model.layers.17.self_attn.v_proj\n",
      "model.layers.17.self_attn.o_proj\n",
      "model.layers.17.mlp.gate_proj\n",
      "model.layers.17.mlp.up_proj\n",
      "model.layers.17.mlp.down_proj\n",
      "model.layers.18.self_attn.q_proj\n",
      "model.layers.18.self_attn.k_proj\n",
      "model.layers.18.self_attn.v_proj\n",
      "model.layers.18.self_attn.o_proj\n",
      "model.layers.18.mlp.gate_proj\n",
      "model.layers.18.mlp.up_proj\n",
      "model.layers.18.mlp.down_proj\n",
      "model.layers.19.self_attn.q_proj\n",
      "model.layers.19.self_attn.k_proj\n",
      "model.layers.19.self_attn.v_proj\n",
      "model.layers.19.self_attn.o_proj\n",
      "model.layers.19.mlp.gate_proj\n",
      "model.layers.19.mlp.up_proj\n",
      "model.layers.19.mlp.down_proj\n",
      "model.layers.20.self_attn.q_proj\n",
      "model.layers.20.self_attn.k_proj\n",
      "model.layers.20.self_attn.v_proj\n",
      "model.layers.20.self_attn.o_proj\n",
      "model.layers.20.mlp.gate_proj\n",
      "model.layers.20.mlp.up_proj\n",
      "model.layers.20.mlp.down_proj\n",
      "model.layers.21.self_attn.q_proj\n",
      "model.layers.21.self_attn.k_proj\n",
      "model.layers.21.self_attn.v_proj\n",
      "model.layers.21.self_attn.o_proj\n",
      "model.layers.21.mlp.gate_proj\n",
      "model.layers.21.mlp.up_proj\n",
      "model.layers.21.mlp.down_proj\n",
      "model.layers.22.self_attn.q_proj\n",
      "model.layers.22.self_attn.k_proj\n",
      "model.layers.22.self_attn.v_proj\n",
      "model.layers.22.self_attn.o_proj\n",
      "model.layers.22.mlp.gate_proj\n",
      "model.layers.22.mlp.up_proj\n",
      "model.layers.22.mlp.down_proj\n",
      "model.layers.23.self_attn.q_proj\n",
      "model.layers.23.self_attn.k_proj\n",
      "model.layers.23.self_attn.v_proj\n",
      "model.layers.23.self_attn.o_proj\n",
      "model.layers.23.mlp.gate_proj\n",
      "model.layers.23.mlp.up_proj\n",
      "model.layers.23.mlp.down_proj\n",
      "model.layers.24.self_attn.q_proj\n",
      "model.layers.24.self_attn.k_proj\n",
      "model.layers.24.self_attn.v_proj\n",
      "model.layers.24.self_attn.o_proj\n",
      "model.layers.24.mlp.gate_proj\n",
      "model.layers.24.mlp.up_proj\n",
      "model.layers.24.mlp.down_proj\n",
      "model.layers.25.self_attn.q_proj\n",
      "model.layers.25.self_attn.k_proj\n",
      "model.layers.25.self_attn.v_proj\n",
      "model.layers.25.self_attn.o_proj\n",
      "model.layers.25.mlp.gate_proj\n",
      "model.layers.25.mlp.up_proj\n",
      "model.layers.25.mlp.down_proj\n",
      "model.layers.26.self_attn.q_proj\n",
      "model.layers.26.self_attn.k_proj\n",
      "model.layers.26.self_attn.v_proj\n",
      "model.layers.26.self_attn.o_proj\n",
      "model.layers.26.mlp.gate_proj\n",
      "model.layers.26.mlp.up_proj\n",
      "model.layers.26.mlp.down_proj\n",
      "model.layers.27.self_attn.q_proj\n",
      "model.layers.27.self_attn.k_proj\n",
      "model.layers.27.self_attn.v_proj\n",
      "model.layers.27.self_attn.o_proj\n",
      "model.layers.27.mlp.gate_proj\n",
      "model.layers.27.mlp.up_proj\n",
      "model.layers.27.mlp.down_proj\n",
      "model.layers.28.self_attn.q_proj\n",
      "model.layers.28.self_attn.k_proj\n",
      "model.layers.28.self_attn.v_proj\n",
      "model.layers.28.self_attn.o_proj\n",
      "model.layers.28.mlp.gate_proj\n",
      "model.layers.28.mlp.up_proj\n",
      "model.layers.28.mlp.down_proj\n",
      "model.layers.29.self_attn.q_proj\n",
      "model.layers.29.self_attn.k_proj\n",
      "model.layers.29.self_attn.v_proj\n",
      "model.layers.29.self_attn.o_proj\n",
      "model.layers.29.mlp.gate_proj\n",
      "model.layers.29.mlp.up_proj\n",
      "model.layers.29.mlp.down_proj\n",
      "model.layers.30.self_attn.q_proj\n",
      "model.layers.30.self_attn.k_proj\n",
      "model.layers.30.self_attn.v_proj\n",
      "model.layers.30.self_attn.o_proj\n",
      "model.layers.30.mlp.gate_proj\n",
      "model.layers.30.mlp.up_proj\n",
      "model.layers.30.mlp.down_proj\n",
      "model.layers.31.self_attn.q_proj\n",
      "model.layers.31.self_attn.k_proj\n",
      "model.layers.31.self_attn.v_proj\n",
      "model.layers.31.self_attn.o_proj\n",
      "model.layers.31.mlp.gate_proj\n",
      "model.layers.31.mlp.up_proj\n",
      "model.layers.31.mlp.down_proj\n",
      "torch.Size([8192])\n",
      "tensor(1.9536, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "0 tensor(1.9536, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1599, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "1 tensor(2.1599, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3329, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "2 tensor(2.3329, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.7351, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "3 tensor(2.7351, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2082, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "4 tensor(2.2082, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1849, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "5 tensor(2.1849, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(0.8060, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "6 tensor(0.8060, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(0.6443, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "7 tensor(0.6443, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.5981, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "8 tensor(2.5981, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1800, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "9 tensor(2.1800, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1666, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "10 tensor(2.1666, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(0.5693, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "11 tensor(0.5693, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.6246, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "12 tensor(2.6246, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2584, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "13 tensor(2.2584, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1381, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "14 tensor(2.1381, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0045, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "15 tensor(2.0045, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4470, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "16 tensor(2.4470, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.6136, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "17 tensor(2.6136, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3502, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "18 tensor(2.3502, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.9820, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "19 tensor(1.9820, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0571, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "20 tensor(2.0571, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.5343, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "21 tensor(2.5343, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2933, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "22 tensor(2.2933, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.8074, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "23 tensor(1.8074, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2841, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "24 tensor(2.2841, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1971, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "25 tensor(2.1971, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2671, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "26 tensor(2.2671, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2430, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "27 tensor(2.2430, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.7640, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "28 tensor(1.7640, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4365, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "29 tensor(2.4365, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0500, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "30 tensor(2.0500, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.0783, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "31 tensor(1.0783, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.6860, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "32 tensor(2.6860, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.5594, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "33 tensor(2.5594, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4109, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "34 tensor(2.4109, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0763, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "35 tensor(2.0763, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.7021, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "36 tensor(2.7021, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.6653, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "37 tensor(1.6653, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.9072, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "38 tensor(1.9072, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1671, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "39 tensor(2.1671, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.5892, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "40 tensor(2.5892, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1367, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "41 tensor(2.1367, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.3846, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "42 tensor(1.3846, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.3707, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "43 tensor(1.3707, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2824, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "44 tensor(2.2824, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4258, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "45 tensor(2.4258, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2772, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "46 tensor(2.2772, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2920, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "47 tensor(2.2920, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.7777, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "48 tensor(1.7777, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2618, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "49 tensor(2.2618, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1823, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "50 tensor(2.1823, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.8886, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "51 tensor(2.8886, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.7736, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "52 tensor(2.7736, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0298, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "53 tensor(2.0298, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3486, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "54 tensor(2.3486, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.7426, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "55 tensor(2.7426, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4559, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "56 tensor(2.4559, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3677, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "57 tensor(2.3677, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0648, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "58 tensor(2.0648, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0640, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "59 tensor(2.0640, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(0.7398, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "60 tensor(0.7398, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(0.9153, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "61 tensor(0.9153, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3077, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "62 tensor(2.3077, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1110, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "63 tensor(2.1110, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.8985, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "64 tensor(1.8985, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.6451, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "65 tensor(2.6451, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.7406, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "66 tensor(1.7406, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4653, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "67 tensor(2.4653, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.6738, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "68 tensor(2.6738, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(0.9872, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "69 tensor(0.9872, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.9605, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "70 tensor(1.9605, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1523, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "71 tensor(2.1523, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3034, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "72 tensor(2.3034, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1862, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "73 tensor(2.1862, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.6759, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "74 tensor(2.6759, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.9141, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "75 tensor(1.9141, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4506, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "76 tensor(2.4506, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2807, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "77 tensor(2.2807, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4387, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "78 tensor(2.4387, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4428, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "79 tensor(2.4428, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.8702, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "80 tensor(1.8702, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0660, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "81 tensor(2.0660, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0459, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "82 tensor(2.0459, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.9109, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "83 tensor(1.9109, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4840, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "84 tensor(2.4840, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4150, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "85 tensor(2.4150, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3720, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "86 tensor(2.3720, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3502, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "87 tensor(2.3502, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.6636, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "88 tensor(1.6636, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3485, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "89 tensor(2.3485, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0592, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "90 tensor(2.0592, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2677, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "91 tensor(2.2677, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4561, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "92 tensor(2.4561, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0907, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "93 tensor(2.0907, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.9452, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "94 tensor(1.9452, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0322, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "95 tensor(2.0322, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.8215, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "96 tensor(2.8215, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3790, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "97 tensor(2.3790, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3926, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "98 tensor(2.3926, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1767, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "99 tensor(2.1767, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0910, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "100 tensor(2.0910, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2886, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "101 tensor(2.2886, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0524, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "102 tensor(2.0524, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4366, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "103 tensor(2.4366, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2397, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "104 tensor(2.2397, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1395, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "105 tensor(2.1395, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0334, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "106 tensor(2.0334, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3601, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "107 tensor(2.3601, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2118, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "108 tensor(2.2118, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0997, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "109 tensor(2.0997, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0487, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "110 tensor(2.0487, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0353, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "111 tensor(2.0353, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0727, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "112 tensor(2.0727, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(0.5206, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "113 tensor(0.5206, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.5168, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "114 tensor(2.5168, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.9185, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "115 tensor(1.9185, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3911, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "116 tensor(2.3911, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2990, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "117 tensor(2.2990, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.1211, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "118 tensor(1.1211, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.0971, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "119 tensor(1.0971, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2213, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "120 tensor(2.2213, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1008, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "121 tensor(2.1008, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1933, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "122 tensor(2.1933, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3205, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "123 tensor(2.3205, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1122, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "124 tensor(2.1122, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2109, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "125 tensor(2.2109, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.8008, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "126 tensor(1.8008, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.0277, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "127 tensor(1.0277, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(0.6984, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "128 tensor(0.6984, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4740, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "129 tensor(2.4740, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.5605, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "130 tensor(1.5605, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0070, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "131 tensor(2.0070, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4372, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "132 tensor(2.4372, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.3486, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "133 tensor(1.3486, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.9191, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "134 tensor(2.9191, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.5348, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "135 tensor(2.5348, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3598, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "136 tensor(2.3598, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3841, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "137 tensor(2.3841, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.6741, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "138 tensor(2.6741, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2952, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "139 tensor(2.2952, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.9239, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "140 tensor(1.9239, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2977, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "141 tensor(2.2977, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1784, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "142 tensor(2.1784, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1824, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "143 tensor(2.1824, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2444, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "144 tensor(2.2444, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.4210, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "145 tensor(1.4210, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.6697, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "146 tensor(1.6697, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.6542, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "147 tensor(1.6542, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.9687, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "148 tensor(1.9687, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.1562, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "149 tensor(1.1562, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4414, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "150 tensor(2.4414, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(0.5191, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "151 tensor(0.5191, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3083, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "152 tensor(2.3083, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.7463, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "153 tensor(1.7463, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.3556, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "154 tensor(1.3556, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4148, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "155 tensor(2.4148, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.5501, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "156 tensor(1.5501, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.5071, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "157 tensor(1.5071, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.7704, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "158 tensor(1.7704, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0582, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "159 tensor(2.0582, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2894, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "160 tensor(2.2894, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.6237, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "161 tensor(2.6237, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1235, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "162 tensor(2.1235, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.7740, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "163 tensor(2.7740, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.7697, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "164 tensor(1.7697, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4305, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "165 tensor(2.4305, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4179, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "166 tensor(2.4179, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0165, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "167 tensor(2.0165, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.7632, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "168 tensor(2.7632, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(0.8390, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "169 tensor(0.8390, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.2779, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "170 tensor(1.2779, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.4718, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "171 tensor(1.4718, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0439, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "172 tensor(2.0439, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2804, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "173 tensor(2.2804, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.5981, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "174 tensor(2.5981, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4443, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "175 tensor(2.4443, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2288, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "176 tensor(2.2288, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2725, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "177 tensor(2.2725, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(0.9876, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "178 tensor(0.9876, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3952, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "179 tensor(2.3952, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(0.8974, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "180 tensor(0.8974, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.5261, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "181 tensor(2.5261, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.8294, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "182 tensor(2.8294, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3520, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "183 tensor(2.3520, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1880, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "184 tensor(2.1880, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.9733, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "185 tensor(1.9733, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.9845, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "186 tensor(1.9845, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2709, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "187 tensor(2.2709, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.2250, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "188 tensor(1.2250, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2898, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "189 tensor(2.2898, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1158, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "190 tensor(2.1158, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0212, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "191 tensor(2.0212, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0204, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "192 tensor(2.0204, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.7576, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "193 tensor(2.7576, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.7935, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "194 tensor(1.7935, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.9026, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "195 tensor(1.9026, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.6179, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "196 tensor(2.6179, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.1181, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "197 tensor(1.1181, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3344, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "198 tensor(2.3344, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1448, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "199 tensor(2.1448, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4218, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "200 tensor(2.4218, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4022, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "201 tensor(2.4022, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3494, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "202 tensor(2.3494, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.8851, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "203 tensor(1.8851, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2647, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "204 tensor(2.2647, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.6946, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "205 tensor(2.6946, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.8461, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "206 tensor(2.8461, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0867, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "207 tensor(2.0867, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.5610, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "208 tensor(2.5610, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3114, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "209 tensor(2.3114, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.9154, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "210 tensor(1.9154, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0355, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "211 tensor(2.0355, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(0.9215, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "212 tensor(0.9215, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3132, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "213 tensor(2.3132, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.5270, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "214 tensor(1.5270, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.7217, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "215 tensor(2.7217, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0213, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "216 tensor(2.0213, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4833, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "217 tensor(2.4833, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3792, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "218 tensor(2.3792, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.7943, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "219 tensor(2.7943, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2509, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "220 tensor(2.2509, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.8421, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "221 tensor(2.8421, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.8779, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "222 tensor(2.8779, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.5588, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "223 tensor(2.5588, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1238, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "224 tensor(2.1238, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(0.7832, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "225 tensor(0.7832, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.5952, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "226 tensor(1.5952, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.7719, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "227 tensor(2.7719, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1810, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "228 tensor(2.1810, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.9224, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "229 tensor(1.9224, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.8702, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "230 tensor(2.8702, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0987, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "231 tensor(2.0987, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1258, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "232 tensor(2.1258, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2540, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "233 tensor(2.2540, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.7031, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "234 tensor(1.7031, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.4016, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "235 tensor(2.4016, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1180, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "236 tensor(2.1180, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.3153, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "237 tensor(2.3153, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.9767, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "238 tensor(1.9767, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.1904, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "239 tensor(2.1904, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.7589, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "240 tensor(2.7589, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.4716, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "241 tensor(1.4716, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(0.5360, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "242 tensor(0.5360, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0428, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "243 tensor(2.0428, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2762, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "244 tensor(2.2762, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2028, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "245 tensor(2.2028, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.6558, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "246 tensor(2.6558, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.9736, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "247 tensor(1.9736, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.5209, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "248 tensor(2.5209, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.2056, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "249 tensor(2.2056, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.8513, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "250 tensor(2.8513, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(0.8398, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "251 tensor(0.8398, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.0749, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "252 tensor(2.0749, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(2.6764, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "253 tensor(2.6764, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.8660, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "254 tensor(1.8660, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "torch.Size([8192])\n",
      "tensor(1.8761, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n",
      "255 tensor(1.8761, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>)\n"
     ]
    }
   ],
   "source": [
    "from cut_cross_entropy import linear_cross_entropy\n",
    "\n",
    "model.cuda()\n",
    "model.gradient_checkpointing_enable()\n",
    "model.config.use_cache = False\n",
    "model.train()\n",
    "\n",
    "def f_hook(m, i, o):\n",
    "    \"\"\"Forward hook: accumulate the per-feature mean square of a layer's input.\n",
    "\n",
    "    Reads only ``i[0]`` from the hook's input tuple, flattens every leading\n",
    "    dimension into rows, and adds the column-wise mean of the squared values\n",
    "    to ``m.i_norm``. The module output ``o`` is ignored.\n",
    "    \"\"\"\n",
    "    acts = i[0].detach().float()\n",
    "    n_features = acts.shape[-1]\n",
    "    rows = acts.reshape(-1, n_features)\n",
    "    m.i_norm += rows.square().mean(dim=0)\n",
    "    \n",
    "def b_hook(m, _, go):\n",
    "    \"\"\"Backward hook: accumulate the per-feature mean square of the output gradient.\n",
    "\n",
    "    Reads only ``go[0]``, flattens every leading dimension into rows, and adds\n",
    "    the column-wise mean of the squared values, multiplied by 1e6, to\n",
    "    ``m.o_norm``. The second positional argument is ignored.\n",
    "    \"\"\"\n",
    "    grads = go[0].detach().float()\n",
    "    n_features = grads.shape[-1]\n",
    "    rows = grads.reshape(-1, n_features)\n",
    "    m.o_norm += rows.square().mean(dim=0) * 1e6\n",
    "\n",
    "# Freeze every parameter whose name does not contain \"embed_tokens\".\n",
    "# NOTE(review): the embeddings are presumably left trainable so that autograd\n",
    "# still runs backward through every layer and b_hook fires - confirm.\n",
    "for n, p in model.named_parameters():\n",
    "    print(n)\n",
    "    if \"embed_tokens\" not in n:\n",
    "        p.requires_grad = False\n",
    "\n",
    "handles = []\n",
    "\n",
    "# Registration and the calibration loop share one try/finally so the hooks are\n",
    "# removed even when a batch fails (OOM, interrupt). Hooks left behind would\n",
    "# fire a second time after the cell is re-run and double-count the statistics.\n",
    "try:\n",
    "    for n, m in model.named_modules():\n",
    "        # Plain nn.Linear layers only; lm_head is skipped.\n",
    "        if type(m) == nn.Linear and \"lm_head\" not in n:\n",
    "            print(n)\n",
    "            # One accumulator entry per input feature / per output feature.\n",
    "            m.i_norm = torch.zeros(m.weight.shape[1], device=m.weight.device)\n",
    "            m.o_norm = torch.zeros(m.weight.shape[0], device=m.weight.device)\n",
    "            handles.append(m.register_forward_hook(f_hook))\n",
    "            handles.append(m.register_full_backward_hook(b_hook))\n",
    "\n",
    "    for idx, bx in enumerate(dataloader):\n",
    "        print(bx.shape)\n",
    "        # Move the batch to the GPU once and add a leading batch dimension.\n",
    "        bx = bx.cuda().unsqueeze(0)\n",
    "        embs = model.model(bx)[0]\n",
    "        # The loss is computed from the hidden states and the lm_head weight\n",
    "        # directly (shift=1), instead of building logits and calling\n",
    "        # nn.CrossEntropyLoss on shifted logits/labels.\n",
    "        loss = linear_cross_entropy(embs, model.lm_head.weight, bx, shift=1)\n",
    "        print(loss)\n",
    "        print(idx, loss)\n",
    "        # Only the hook side effects (i_norm / o_norm) are used from this pass.\n",
    "        loss.backward()\n",
    "finally:\n",
    "    for h in handles:\n",
    "        h.remove()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "813e9daa",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 2.2889,  1.1245, -0.0119])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch check: reducing a (3, 5) tensor over dim=1 leaves one value per row.\n",
    "a = torch.randn(3, 5)\n",
    "torch.sum(a, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d8596d9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def power_iteration(A, num_iters=5):\n",
    "    \"\"\"\n",
    "    Estimate the dominant singular triplet of ``A`` by power iteration.\n",
    "    \n",
    "    Arguments:\n",
    "        A (torch.Tensor): The input matrix of shape (m, n).\n",
    "        num_iters (int): Number of alternating A / A^T multiplications.\n",
    "    \n",
    "    Returns:\n",
    "        u (torch.Tensor): Dominant left singular vector (m,).\n",
    "        sigma (torch.Tensor): Dominant singular value (scalar).\n",
    "        v (torch.Tensor): Dominant right singular vector (n,).\n",
    "    \n",
    "    When ``A @ v`` is exactly zero (e.g. ``A`` is the zero matrix), ``u`` is\n",
    "    returned as a zero vector and ``sigma`` as 0 rather than the NaNs that a\n",
    "    0/0 division would produce.\n",
    "    \"\"\"\n",
    "    # Random unit start vector. It takes A's dtype as well as its device so\n",
    "    # torch.mv also accepts non-float32 matrices.\n",
    "    v = torch.randn(A.shape[1], device=A.device, dtype=A.dtype)\n",
    "    v = v / torch.norm(v)\n",
    "    \n",
    "    for _ in range(num_iters):\n",
    "        # Left update: u <- A v / ||A v||\n",
    "        u = torch.mv(A, v)\n",
    "        u_norm = torch.norm(u)\n",
    "        if u_norm == 0:\n",
    "            break\n",
    "        u = u / u_norm\n",
    "        \n",
    "        # Right update: v <- A^T u / ||A^T u||. On a degenerate zero result the\n",
    "        # previous unit-norm v is kept.\n",
    "        v_next = torch.mv(A.t(), u)\n",
    "        v_norm = torch.norm(v_next)\n",
    "        if v_norm == 0:\n",
    "            break\n",
    "        v = v_next / v_norm\n",
    "    \n",
    "    # sigma = ||A v||; the product is computed once and reused for u.\n",
    "    Av = torch.mv(A, v)\n",
    "    sigma = torch.norm(Av)\n",
    "    if sigma == 0:\n",
    "        return torch.zeros_like(Av), sigma, v\n",
    "    return Av / sigma, sigma, v\n",
    "\n",
    "def svd_abs(W):\n",
    "    \"\"\"Rank-1 approximation of ``|W|`` with the signs of ``W`` re-applied.\n",
    "\n",
    "    The top singular triplet of ``W.abs()`` comes from five steps of\n",
    "    ``power_iteration``; the rank-1 outer product is then multiplied\n",
    "    element-wise by the sign of ``W``, where zero entries count as +1.\n",
    "    \"\"\"\n",
    "    signs = W.sign()\n",
    "    signs = torch.where(signs == 0, torch.ones_like(signs), signs)\n",
    "    left, top_sv, right = power_iteration(W.abs(), num_iters=5)\n",
    "    return top_sv * torch.outer(left, right) * signs\n",
    "\n",
    "def find_other2(A, W, nnz, Z, U, print_sc=None, debug=False, reg=0, rho_start=0.03, iters=3, prune_iters=1, flip=False, final=False):\n",
    "    \"\"\"Refine one factor of ``A @ B ~ W`` by alternating ridge solves and ``svd_abs`` projections.\n",
    "\n",
    "    ``Z`` is the projected copy of ``B`` (every returned ``Z`` is an output of\n",
    "    ``svd_abs``) and ``U`` accumulates the residual ``B - Z``; the update\n",
    "    pattern resembles scaled-form ADMM.\n",
    "\n",
    "    Arguments:\n",
    "        A: left matrix; its Gram matrix ``A.T @ A`` is formed and inverted.\n",
    "        W: target matrix.\n",
    "        Z, U: projected variable and residual carried over from the previous call.\n",
    "        reg: the Gram diagonal is increased by ``reg`` times its own mean.\n",
    "        rho_start: penalty weight of the first solve only; later solves use 1.\n",
    "        iters: number of projection steps (one is performed even if ``iters < 1``).\n",
    "        nnz, print_sc, debug, prune_iters, flip, final: accepted but never read.\n",
    "\n",
    "    Returns:\n",
    "        The pair ``(Z, U)`` after the last projection.\n",
    "    \"\"\"\n",
    "    gram = A.T.matmul(A)\n",
    "    gram += torch.diag(torch.ones_like(gram.diag())) * gram.diag().mean() * reg\n",
    "\n",
    "    rho = 1\n",
    "    rhs = A.T.matmul(W)\n",
    "    eye = torch.eye(gram.shape[1], device=gram.device)\n",
    "    inv_rho = torch.inverse(gram + eye*rho)\n",
    "    inv_start = torch.inverse(gram + eye*rho_start)\n",
    "\n",
    "    # First solve is warm-started from the incoming (Z, U) with the caller's penalty.\n",
    "    B = inv_start.matmul(rhs + rho_start*(Z-U))\n",
    "\n",
    "    # Each step projects and updates the residual; every step except the last\n",
    "    # also re-solves for B with the fixed penalty rho.\n",
    "    last_step = iters - 1\n",
    "    for step in range(max(iters, 1)):\n",
    "        Z = svd_abs(B+U)\n",
    "        U = U + (B - Z)\n",
    "        if step < last_step:\n",
    "            B = inv_rho.matmul(rhs + rho*(Z-U))\n",
    "\n",
    "    return Z, U\n",
    "\n",
    "\n",
    "def factorizeT(W, XX, o_norm, asp=0.16, sp=0.16, iters=80):\n",
    "    \"\"\"Factorize ``W`` into two factors by alternating ``find_other2`` updates (transposed layout).\n",
    "\n",
    "    ``factorizef`` calls this with ``W.T`` when the weight has at least as many\n",
    "    rows as columns. The rows of ``W`` are scaled by ``sqrt(XX)`` and its columns\n",
    "    by ``sqrt(o_norm)`` before fitting; the scaling is divided out of the\n",
    "    returned factors.\n",
    "\n",
    "    Arguments:\n",
    "        W: matrix to factorize.\n",
    "        XX: one scale value per row of ``W`` (its square root is used).\n",
    "        o_norm: one scale value per column of ``W`` (its square root is used).\n",
    "        asp, sp: only used to derive the ``nnz`` budgets passed to ``find_other2``.\n",
    "        iters: number of alternating rounds.\n",
    "\n",
    "    Returns:\n",
    "        ``(approx.T, (Bz / norm_o).T, (Az / norm).T, 1 / a_col_norm)`` where\n",
    "        ``a_col_norm`` holds the column norms of ``Az``. The first element is the\n",
    "        product of the two unscaled factors and does not include the\n",
    "        ``1 / a_col_norm`` factor.\n",
    "    \"\"\"\n",
    "    nza = int(W.shape[0]**2 * asp)\n",
    "    nzb = int(W.numel() * sp - nza)\n",
    "    \n",
    "    norm = XX.sqrt().unsqueeze(1) + 1e-8\n",
    "    norm_o = o_norm.sqrt() + 1e-8\n",
    "    Wn = W * norm * norm_o\n",
    "    \n",
    "    # Inner dimension of the factorization, derived from the global TARGET_BITS.\n",
    "    rank = int(TARGET_BITS*(W.shape[0]*W.shape[1]) / (W.shape[0] + W.shape[1]))\n",
    "    \n",
    "    Az = torch.randn((W.shape[0], rank), device=W.device)\n",
    "    Au = torch.zeros_like(Az)\n",
    "\n",
    "    Bz = torch.randn((rank, W.shape[1]), device=W.device)\n",
    "    Bu = torch.zeros_like(Bz)\n",
    "    \n",
    "    # Defined up front so the return value is a tensor even when iters == 0.\n",
    "    a_col_norm = Az.norm(dim=0) + 1e-12\n",
    "    # rho_start ramps cubically from 0 to 1 over the first iters-3 rounds; the\n",
    "    # ramp length is clamped to 1 so iters <= 3 neither divides by zero nor\n",
    "    # yields a negative penalty.\n",
    "    ramp_len = max(iters - 3, 1)\n",
    "    \n",
    "    for itt in range(iters):\n",
    "        rho_start = min(1.0, itt / ramp_len)**3\n",
    "        final = itt == iters - 1\n",
    "        \n",
    "        # Update A against B with unit-norm rows; find_other2 works on the\n",
    "        # transposed problem, so its results are transposed back.\n",
    "        b_row_norm = Bz.norm(dim=1) + 1e-12\n",
    "        Az_t, Au_t = find_other2(Bz.T / b_row_norm, Wn.T, nza, Az.T, Au.T, reg=3e-2, debug=False, rho_start=rho_start, flip=True, final=final)\n",
    "        Az, Au = Az_t.T, Au_t.T\n",
    "        \n",
    "        # Update B against A with unit-norm columns.\n",
    "        a_col_norm = Az.norm(dim=0) + 1e-12\n",
    "        Bz, Bu = find_other2(Az / a_col_norm, Wn, nzb, Bz, Bu, reg=3e-2, debug=False, rho_start=rho_start, final=final)\n",
    "        \n",
    "        if final:\n",
    "            print(\"err\", itt, ((Az / a_col_norm).matmul(Bz) - Wn).square().sum().item(), (Wn).square().sum().item())\n",
    "            \n",
    "    return ((Az / norm).matmul(Bz / norm_o)).T, (Bz / norm_o).T, (Az / norm).T, 1 / a_col_norm\n",
    "\n",
    "\n",
    "def factorizef(W, XX, o_norm, asp=0.16, sp=0.16, iters=80, l_prev=None):\n",
    "    \"\"\"Factorize ``W`` into two factors by alternating ``find_other2`` updates.\n",
    "\n",
    "    When ``W`` has at least as many rows as columns the work is delegated to\n",
    "    ``factorizeT`` on ``W.T``. Otherwise the columns of ``W`` are scaled by\n",
    "    ``sqrt(XX)`` and its rows by ``sqrt(o_norm)`` before fitting; the scaling is\n",
    "    divided out of the returned factors.\n",
    "\n",
    "    Arguments:\n",
    "        W: matrix to factorize.\n",
    "        XX: one scale value per column of ``W`` (its square root is used).\n",
    "        o_norm: one scale value per row of ``W`` (its square root is used).\n",
    "        asp, sp: only used to derive the ``nnz`` budgets passed to ``find_other2``.\n",
    "        iters: number of alternating rounds.\n",
    "        l_prev: accepted but never read.\n",
    "\n",
    "    Returns:\n",
    "        ``(approx, Az / norm_o, Bz / norm, 1 / a_col_norm)`` where ``a_col_norm``\n",
    "        holds the column norms of ``Az``. The first element is the product of the\n",
    "        two unscaled factors and does not include the ``1 / a_col_norm`` factor.\n",
    "    \"\"\"\n",
    "    if W.shape[0] >= W.shape[1]:\n",
    "        return factorizeT(W.T, XX, o_norm, asp, sp=sp, iters=iters)\n",
    "    \n",
    "    nza = int(W.shape[0]**2 * asp)\n",
    "    nzb = int(W.numel() * sp - nza)\n",
    "    norm = XX.sqrt() + 1e-8\n",
    "    norm_o = (o_norm.sqrt() + 1e-8).unsqueeze(1)\n",
    "\n",
    "    Wn = W * norm * norm_o\n",
    "    # Inner dimension of the factorization, derived from the global TARGET_BITS.\n",
    "    rank = int(TARGET_BITS*(W.shape[0]*W.shape[1]) / (W.shape[0] + W.shape[1]))\n",
    "    \n",
    "    Az = torch.randn((W.shape[0], rank), device=W.device)\n",
    "    Au = torch.zeros_like(Az)\n",
    "\n",
    "    Bz = torch.randn((rank, W.shape[1]), device=W.device)\n",
    "    Bu = torch.zeros_like(Bz)\n",
    "    \n",
    "    # Defined up front so the return value is a tensor even when iters == 0.\n",
    "    a_col_norm = Az.norm(dim=0) + 1e-12\n",
    "    # rho_start ramps cubically from 0 to 1 over the first iters-3 rounds; the\n",
    "    # ramp length is clamped to 1 so iters <= 3 neither divides by zero nor\n",
    "    # yields a negative penalty.\n",
    "    ramp_len = max(iters - 3, 1)\n",
    "    \n",
    "    for itt in range(iters):\n",
    "        rho_start = min(1.0, itt / ramp_len)**3\n",
    "        final = itt == iters - 1\n",
    "        \n",
    "        # Update A against B with unit-norm rows; find_other2 works on the\n",
    "        # transposed problem, so its results are transposed back.\n",
    "        b_row_norm = Bz.norm(dim=1) + 1e-12\n",
    "        Az_t, Au_t = find_other2(Bz.T / b_row_norm, Wn.T, nza, Az.T, Au.T, reg=3e-2, debug=False, rho_start=rho_start, final=final)\n",
    "        Az, Au = Az_t.T, Au_t.T\n",
    "        \n",
    "        # Update B against A with unit-norm columns.\n",
    "        a_col_norm = Az.norm(dim=0) + 1e-12\n",
    "        Bz, Bu = find_other2(Az / a_col_norm, Wn, nzb, Bz, Bu, reg=3e-2, debug=False, rho_start=rho_start, flip=True, final=final)\n",
    "        \n",
    "        if final:\n",
    "            print(\"err\", itt, ((Az / a_col_norm).matmul(Bz) - Wn).square().sum().item(), (Wn).square().sum().item())\n",
    "            \n",
    "    return (Az / norm_o).matmul(Bz / norm), Az / norm_o, Bz / norm, 1 / a_col_norm\n",
    "\n",
    "def factorize(lx, l_prev=None):\n",
    "    \"\"\"Factorize the weight of layer ``lx`` and return the rebalanced factors.\n",
    "\n",
    "    Steps:\n",
    "      1. Sweep 301 densities in [0.1, 0.4] and keep the last one whose estimated\n",
    "         cost per weight (twice the nonzero count plus the binary entropy of\n",
    "         the mask, divided by the number of weights) is below 1.25.\n",
    "      2. Run ``factorizef`` for 260 rounds using ``lx.i_norm`` / ``lx.o_norm``.\n",
    "      3. Rescale both factors in place so their Frobenius norms match.\n",
    "\n",
    "    Arguments:\n",
    "        lx: layer exposing ``weight``, ``i_norm`` and ``o_norm``.\n",
    "        l_prev: forwarded unchanged to ``factorizef``.\n",
    "\n",
    "    Returns:\n",
    "        ``(W3, Ac, Bc, mid)`` with ``W3 = (Ac * mid) @ Bc`` of the same shape as\n",
    "        ``lx.weight``.\n",
    "    \"\"\"\n",
    "    W = lx.weight.detach().float()\n",
    "\n",
    "    sm, lg = min(W.shape), max(W.shape)\n",
    "    mid = sm\n",
    "    print(\"mid\", mid)\n",
    "\n",
    "    lim = 1.25\n",
    "    n_weights = sm*lg\n",
    "    mask_size = sm*mid + mid*lg\n",
    "    for density in np.linspace(0.1, 0.4, 301):\n",
    "        total_ones2 = n_weights * density\n",
    "        p = total_ones2 / mask_size\n",
    "        ent = -p*np.log2(p)-(1-p)*np.log2(1-p)\n",
    "        cost = total_ones2*2 + mask_size*ent\n",
    "        if cost / n_weights < lim:\n",
    "            sp = density\n",
    "\n",
    "    # Square weights get half the density for the first factor.\n",
    "    asp = sp/2 if W.shape[0] == W.shape[1] else sp\n",
    "    _, Ac, Bc, mid = factorizef(W, lx.i_norm, lx.o_norm, asp=asp, sp=sp, l_prev=l_prev, iters=260)\n",
    "\n",
    "    # Balance the two factors in place; their product is unchanged.\n",
    "    An = Ac.norm() + 1e-12\n",
    "    Bn = Bc.norm() + 1e-12\n",
    "    Ac *= (Bn/An).sqrt()\n",
    "    Bc *= (An/Bn).sqrt()\n",
    "\n",
    "    W3 = (Ac * mid).matmul(Bc)\n",
    "    assert W3.shape == lx.weight.shape\n",
    "    print(\"sparsity check\", ((Ac != 0).sum() + (Bc != 0).sum()).item() / W3.numel())\n",
    "    return W3, Ac, Bc, mid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "18157288",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting ...\n",
      "Ready.\n",
      "0 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 0.4383566975593567 9841.158203125\n",
      "sparsity check 2.0\n",
      "err_spxsp 533.7374267578125\n",
      "0 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 13.441218376159668 1171.7940673828125\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 33.56348419189453\n",
      "0 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 9.076663970947266 436.83599853515625\n",
      "sparsity check 2.0\n",
      "err_spxsp 134.54632568359375\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.0025646664507803507\n",
      "0 0.002087427475998993\n",
      "1 0.0014450392134222056\n",
      "2 0.001128885020079906\n",
      "3 0.0012699482997504674\n",
      "4 0.001326797969340987\n",
      "5 0.0009757646530488273\n",
      "6 0.0008886884204457601\n",
      "7 0.0008791659238340799\n",
      "err after  0.0008783414446043025\n",
      "0 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 0.3716686964035034 378.6822204589844\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 253.10183715820312\n",
      "0 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 27.28969383239746 394.8713073730469\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 693.1514892578125\n",
      "0 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 18.347957611083984 493.4566650390625\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 765.1322021484375\n",
      "0 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 52.37645721435547 820.6001586914062\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 686.4556884765625\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.01808585242542904\n",
      "0 0.016851811386004556\n",
      "1 0.014367779323947616\n",
      "2 0.026264477517543128\n",
      "3 0.024207827842474217\n",
      "4 0.011612243975832826\n",
      "5 0.009818340449783136\n",
      "6 0.008422214383244864\n",
      "7 0.008298395910969703\n",
      "err after  0.00826640568448056\n",
      "1 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 0.2651078701019287 109.70814514160156\n",
      "sparsity check 2.0\n",
      "err_spxsp 844.5587158203125\n",
      "1 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 157.90121459960938 6257.216796875\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 35.313636779785156\n",
      "1 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 13.667130470275879 258.56719970703125\n",
      "sparsity check 2.0\n",
      "err_spxsp 143.6864013671875\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.012349620868917555\n",
      "0 0.011516244863742031\n",
      "1 0.010397890690001077\n",
      "2 0.00963682854671788\n",
      "3 0.009177189986075973\n",
      "4 0.008965833689217106\n",
      "5 0.008749285840167431\n",
      "6 0.008619137512141606\n",
      "7 0.008582541950090672\n",
      "err after  0.008575950658268994\n",
      "1 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 1.0014801025390625 417.93408203125\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 286.8722229003906\n",
      "1 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 30.493595123291016 385.5950927734375\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 686.8384399414062\n",
      "1 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 26.490291595458984 494.0457763671875\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 782.54443359375\n",
      "1 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 40.47160720825195 79962.875\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 761.7555541992188\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.018124480098776985\n",
      "0 0.015328943209169665\n",
      "1 0.013691916730749654\n",
      "2 0.012734979569358984\n",
      "3 0.012026866679661907\n",
      "4 0.011369384854333475\n",
      "5 0.010784542220790172\n",
      "6 0.01059744668964413\n",
      "7 0.010548497553827474\n",
      "err after  0.010543801716266898\n",
      "2 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 1.8635152578353882 197.3865509033203\n",
      "sparsity check 2.0\n",
      "err_spxsp 657.75341796875\n",
      "2 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 867.773681640625 25664.57421875\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 19.05646514892578\n",
      "2 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 9.839994430541992 195.8936767578125\n",
      "sparsity check 2.0\n",
      "err_spxsp 107.89219665527344\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.012611227302841144\n",
      "0 0.012164022577053402\n",
      "1 0.011602019974816358\n",
      "2 0.011164158113388112\n",
      "3 0.01085955741655198\n",
      "4 0.01073462059503072\n",
      "5 0.010516604894291959\n",
      "6 0.01040321897380636\n",
      "7 0.01037955285028147\n",
      "err after  0.010376626734796446\n",
      "2 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 4.287783145904541 1178.958984375\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 375.341064453125\n",
      "2 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 29.521862030029297 347.12432861328125\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 675.037841796875\n",
      "2 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 26.727970123291016 319.8951110839844\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 836.9511108398438\n",
      "2 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 37.59506607055664 460.8402404785156\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 679.9669189453125\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.01799469510660856\n",
      "0 0.01601151305658277\n",
      "1 0.014370955319463974\n",
      "2 0.01342105827643536\n",
      "3 0.012904902891023085\n",
      "4 0.012686196700087748\n",
      "5 0.012245994552358752\n",
      "6 0.012028518431179691\n",
      "7 0.011975696863373742\n",
      "err after  0.01197057701574522\n",
      "3 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 2.694887161254883 132.48208618164062\n",
      "sparsity check 2.0\n",
      "err_spxsp 627.0409545898438\n",
      "3 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 531.7685546875 7588.1376953125\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 27.946956634521484\n",
      "3 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 22.554100036621094 371.1871337890625\n",
      "sparsity check 2.0\n",
      "err_spxsp 159.08544921875\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.01576290280354442\n",
      "0 0.01500272080374998\n",
      "1 0.014090203414525604\n",
      "2 0.01349149868474342\n",
      "3 0.013109055034874473\n",
      "4 0.01288742407632526\n",
      "5 0.012752986145642353\n",
      "6 0.012658308884056169\n",
      "7 0.012629997629119316\n",
      "err after  0.012626429517695215\n",
      "3 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 3.360809564590454 249.56704711914062\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 320.2950744628906\n",
      "3 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 29.382097244262695 334.70672607421875\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 662.8419189453125\n",
      "3 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 22.60019874572754 283.2327880859375\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 939.9226684570312\n",
      "3 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 32.49919509887695 380.55523681640625\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 637.4522705078125\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.01722065897047287\n",
      "0 0.014889820631651673\n",
      "1 0.013514260601368733\n",
      "2 0.012838315276894718\n",
      "3 0.012583590767462738\n",
      "4 0.012298809564526891\n",
      "5 0.011979710943705868\n",
      "6 0.011801816523075104\n",
      "7 0.011758336322600371\n",
      "err after  0.011754107357774046\n",
      "4 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 1.9980478286743164 125.30489349365234\n",
      "sparsity check 2.0\n",
      "err_spxsp 615.87158203125\n",
      "4 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 183.30621337890625 2748.644287109375\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 31.928035736083984\n",
      "4 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 19.540258407592773 308.2362976074219\n",
      "sparsity check 2.0\n",
      "err_spxsp 166.2259979248047\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.01469156074381317\n",
      "0 0.014096178601903375\n",
      "1 0.013439059435768286\n",
      "2 0.012949244988703867\n",
      "3 0.012655032260227017\n",
      "4 0.012484039150876924\n",
      "5 0.012354875934761367\n",
      "6 0.012259080707735848\n",
      "7 0.01223361700795067\n",
      "err after  0.012230732681928203\n",
      "4 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 2.6643009185791016 227.1650390625\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 309.04095458984375\n",
      "4 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 29.023677825927734 337.42767333984375\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 643.997314453125\n",
      "4 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 18.47871208190918 251.12530517578125\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1077.70654296875\n",
      "4 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 32.229679107666016 350.5721435546875\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 599.53564453125\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.015534370100795059\n",
      "0 0.013741825849137967\n",
      "1 0.012541511536255712\n",
      "2 0.011880437312356662\n",
      "3 0.011577351808227831\n",
      "4 0.011319408418785315\n",
      "5 0.011101019794296008\n",
      "6 0.010968252900056541\n",
      "7 0.010933547306194669\n",
      "err after  0.01093008476527757\n",
      "5 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 1.8321239948272705 109.50778198242188\n",
      "sparsity check 2.0\n",
      "err_spxsp 566.4998779296875\n",
      "5 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 75.28874206542969 1282.9521484375\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 21.328161239624023\n",
      "5 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 12.424322128295898 217.54364013671875\n",
      "sparsity check 2.0\n",
      "err_spxsp 133.34030151367188\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.012629105072846869\n",
      "0 0.012324218900175765\n",
      "1 0.011929628934012726\n",
      "2 0.011604658167925663\n",
      "3 0.011390682917408412\n",
      "4 0.011280388516752282\n",
      "5 0.011189399348950246\n",
      "6 0.011093433529822505\n",
      "7 0.011072977466028533\n",
      "err after  0.01107065176256583\n",
      "5 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 2.27120304107666 148.06173706054688\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 278.033447265625\n",
      "5 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 24.03139877319336 290.8399658203125\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 641.9815673828125\n",
      "5 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 14.628948211669922 210.00018310546875\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1056.09765625\n",
      "5 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 25.637454986572266 293.5966796875\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 597.7841796875\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.013804632195387967\n",
      "0 0.012412732368829893\n",
      "1 0.011483517504530028\n",
      "2 0.010952260621706955\n",
      "3 0.010684954057069262\n",
      "4 0.010498882256797515\n",
      "5 0.010306521860911744\n",
      "6 0.010199242093221983\n",
      "7 0.010169903260248248\n",
      "err after  0.010166901940465323\n",
      "6 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 1.638476848602295 82.12611389160156\n",
      "sparsity check 2.0\n",
      "err_spxsp 633.423828125\n",
      "6 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 48.807456970214844 625.0521240234375\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 25.419565200805664\n",
      "6 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 13.192723274230957 189.84561157226562\n",
      "sparsity check 2.0\n",
      "err_spxsp 142.28805541992188\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.011803853092715144\n",
      "0 0.011534323486557696\n",
      "1 0.011171033671416808\n",
      "2 0.010873507006181171\n",
      "3 0.010700680051741074\n",
      "4 0.010623515046972898\n",
      "5 0.01050529542590084\n",
      "6 0.010433409441247932\n",
      "7 0.010415849050332326\n",
      "err after  0.010413881675049197\n",
      "6 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 2.1251354217529297 123.88055419921875\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 318.7065734863281\n",
      "6 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 18.918113708496094 235.29473876953125\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 641.658447265625\n",
      "6 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 11.126502990722656 165.78916931152344\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1065.3203125\n",
      "6 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 20.380733489990234 230.57264709472656\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 591.8184204101562\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.013055689913016977\n",
      "0 0.011953887129493523\n",
      "1 0.011163153656525537\n",
      "2 0.010722628623625496\n",
      "3 0.010470512166648405\n",
      "4 0.010410948903881945\n",
      "5 0.010194644884904847\n",
      "6 0.010070597772937617\n",
      "7 0.01004104303137865\n",
      "err after  0.010038130032626214\n",
      "7 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 1.7433581352233887 81.38662719726562\n",
      "sparsity check 2.0\n",
      "err_spxsp 521.4410400390625\n",
      "7 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 33.345672607421875 474.5511474609375\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 22.049713134765625\n",
      "7 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 12.46328353881836 189.58480834960938\n",
      "sparsity check 2.0\n",
      "err_spxsp 141.44741821289062\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.0115635645815928\n",
      "0 0.01130438584004878\n",
      "1 0.010963003162032692\n",
      "2 0.010687015008443268\n",
      "3 0.01051423148601316\n",
      "4 0.010417049776151543\n",
      "5 0.010347602728870697\n",
      "6 0.010286515394909657\n",
      "7 0.01026944216391712\n",
      "err after  0.010267441002724809\n",
      "7 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 1.9482711553573608 112.07147216796875\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 296.8409118652344\n",
      "7 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 14.611572265625 187.4718780517578\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 658.7019653320312\n",
      "7 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 8.765278816223145 139.20465087890625\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1032.37451171875\n",
      "7 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 15.816585540771484 181.6197509765625\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 591.665283203125\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.012246523288922617\n",
      "0 0.011470594494312536\n",
      "1 0.010781033070088597\n",
      "2 0.010359945103118662\n",
      "3 0.010152879349334398\n",
      "4 0.009993135128752328\n",
      "5 0.009848331043031067\n",
      "6 0.009777571925951634\n",
      "7 0.009749875862326007\n",
      "err after  0.009746926860316307\n",
      "8 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 1.4691622257232666 73.22645568847656\n",
      "sparsity check 2.0\n",
      "err_spxsp 504.1269226074219\n",
      "8 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 21.081140518188477 325.3179931640625\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 24.77642059326172\n",
      "8 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 10.015251159667969 169.4737548828125\n",
      "sparsity check 2.0\n",
      "err_spxsp 141.87322998046875\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.011007799777871696\n",
      "0 0.010830589835677529\n",
      "1 0.01056427211733535\n",
      "2 0.010336330866266508\n",
      "3 0.010203063542576274\n",
      "4 0.01013124167729984\n",
      "5 0.01004375730008178\n",
      "6 0.009993595676860423\n",
      "7 0.009979278560422244\n",
      "err after  0.00997753983392613\n",
      "8 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 1.471785306930542 110.6685791015625\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 280.321044921875\n",
      "8 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 11.345083236694336 139.44224548339844\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 639.3282470703125\n",
      "8 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 6.538627624511719 98.22527313232422\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1002.5269775390625\n",
      "8 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 13.27122688293457 152.05625915527344\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 584.1429443359375\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.012203720663819695\n",
      "0 0.011434844494942809\n",
      "1 0.010809251445607515\n",
      "2 0.010445255320519209\n",
      "3 0.010226015227090102\n",
      "4 0.010108615992066916\n",
      "5 0.010032097135990625\n",
      "6 0.009904568663841928\n",
      "7 0.009878643355477834\n",
      "err after  0.009876066444121534\n",
      "9 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 1.2607781887054443 79.44091796875\n",
      "sparsity check 2.0\n",
      "err_spxsp 493.0610046386719\n",
      "9 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 21.515365600585938 489.63568115234375\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 36.05767059326172\n",
      "9 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 9.936771392822266 155.14752197265625\n",
      "sparsity check 2.0\n",
      "err_spxsp 162.10711669921875\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.011264992568612797\n",
      "0 0.01106553441422875\n",
      "1 0.010768551645014668\n",
      "2 0.01052332093968289\n",
      "3 0.010370584230258828\n",
      "4 0.010286258348060073\n",
      "5 0.010223003999271896\n",
      "6 0.010172033164053573\n",
      "7 0.010156817706956645\n",
      "err after  0.010155030091482331\n",
      "9 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 1.2103445529937744 154.6431884765625\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 306.51605224609375\n",
      "9 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 10.346481323242188 125.51142883300781\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 640.962158203125\n",
      "9 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 5.763970375061035 85.802001953125\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1013.8162841796875\n",
      "9 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 12.561807632446289 146.12332153320312\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 588.6904907226562\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.012405393361405004\n",
      "0 0.011748683187761344\n",
      "1 0.011171011916303542\n",
      "2 0.010817369493452134\n",
      "3 0.010600043446174823\n",
      "4 0.010457479125761893\n",
      "5 0.010360207332269056\n",
      "6 0.010280726786731975\n",
      "7 0.010253670192469144\n",
      "err after  0.010251106294163037\n",
      "10 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 1.0494799613952637 72.96328735351562\n",
      "sparsity check 2.0\n",
      "err_spxsp 529.4990844726562\n",
      "10 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 18.556365966796875 322.4383544921875\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 23.80730628967285\n",
      "10 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 8.046539306640625 134.53640747070312\n",
      "sparsity check 2.0\n",
      "err_spxsp 148.47036743164062\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.011376040056347847\n",
      "0 0.011211005672521424\n",
      "1 0.010979117014358053\n",
      "2 0.010780813587189186\n",
      "3 0.010668639788491419\n",
      "4 0.010596818636258831\n",
      "5 0.010511995293200016\n",
      "6 0.010469914846908068\n",
      "7 0.01045666609206819\n",
      "err after  0.010455168474436505\n",
      "10 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 1.1143028736114502 93.9135513305664\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 273.847900390625\n",
      "10 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 9.404840469360352 114.09749603271484\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 653.7384643554688\n",
      "10 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 5.32466983795166 78.11729431152344\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 955.2476806640625\n",
      "10 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 11.330167770385742 133.0609130859375\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 594.197998046875\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.011917142532183789\n",
      "0 0.011326657124300255\n",
      "1 0.010798770912515465\n",
      "2 0.010487311006727396\n",
      "3 0.010309080949809868\n",
      "4 0.010203475885646185\n",
      "5 0.0101326347066788\n",
      "6 0.010034569164417917\n",
      "7 0.01001115996405133\n",
      "err after  0.010008921057305997\n",
      "11 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 1.4038289785385132 83.28546905517578\n",
      "sparsity check 2.0\n",
      "err_spxsp 480.3752136230469\n",
      "11 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 11.998983383178711 233.2915496826172\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 25.566545486450195\n",
      "11 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 7.842517852783203 150.13401794433594\n",
      "sparsity check 2.0\n",
      "err_spxsp 151.88119506835938\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.011003885349055054\n",
      "0 0.010857698005565908\n",
      "1 0.010647540118952747\n",
      "2 0.010464763894560747\n",
      "3 0.010347178489610087\n",
      "4 0.010282116527378093\n",
      "5 0.010225081132375635\n",
      "6 0.010182731071836315\n",
      "7 0.010169634559133556\n",
      "err after  0.01016809275461128\n",
      "11 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 1.3137884140014648 106.02893829345703\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 276.81109619140625\n",
      "11 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 8.58592414855957 109.29302978515625\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 675.0576782226562\n",
      "11 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 4.799404144287109 75.43949890136719\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 941.9658813476562\n",
      "11 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 10.055471420288086 126.36055755615234\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 607.7005615234375\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.012768793465511408\n",
      "0 0.011912685662537115\n",
      "1 0.011326486182952067\n",
      "2 0.011058868072723271\n",
      "3 0.010887753127462929\n",
      "4 0.010701055289246142\n",
      "5 0.010603110149531858\n",
      "6 0.01054356749591534\n",
      "7 0.010520929190533934\n",
      "err after  0.010518503422645153\n",
      "12 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 1.0536408424377441 41.11920928955078\n",
      "sparsity check 2.0\n",
      "err_spxsp 459.1900634765625\n",
      "12 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 15.26688003540039 246.15289306640625\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 36.69236373901367\n",
      "12 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 8.835906982421875 130.31703186035156\n",
      "sparsity check 2.0\n",
      "err_spxsp 166.2218780517578\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.011680256284307688\n",
      "0 0.011511545511893928\n",
      "1 0.011281517217867076\n",
      "2 0.011080106040026294\n",
      "3 0.010972268584737321\n",
      "4 0.010900349272560561\n",
      "5 0.010817146561748814\n",
      "6 0.010772371952043613\n",
      "7 0.010758812863059575\n",
      "err after  0.01075728200521553\n",
      "12 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 1.0973320007324219 91.41328430175781\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 267.69512939453125\n",
      "12 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 8.453657150268555 110.77204132080078\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 717.133056640625\n",
      "12 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 4.568536758422852 73.53372192382812\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 945.400146484375\n",
      "12 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 10.96445083618164 132.66867065429688\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 660.3679809570312\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.013202037000155542\n",
      "0 0.012619454450032208\n",
      "1 0.012027703629428288\n",
      "2 0.011649117888737237\n",
      "3 0.011526216829224722\n",
      "4 0.011380293450201862\n",
      "5 0.011226163202081807\n",
      "6 0.011152973485877737\n",
      "7 0.011129598631669069\n",
      "err after  0.01112697615462821\n",
      "13 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 1.1543505191802979 59.82682800292969\n",
      "sparsity check 2.0\n",
      "err_spxsp 450.80181884765625\n",
      "13 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 11.747334480285645 185.76260375976562\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 31.266223907470703\n",
      "13 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 8.877817153930664 165.21554565429688\n",
      "sparsity check 2.0\n",
      "err_spxsp 159.42941284179688\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.012357333354884759\n",
      "0 0.012181090613012202\n",
      "1 0.01192260424795677\n",
      "2 0.011704497435857775\n",
      "3 0.011591922913794406\n",
      "4 0.011503139005071716\n",
      "5 0.011409330421884079\n",
      "6 0.011371677865099628\n",
      "7 0.011358196912624408\n",
      "err after  0.011356572278600652\n",
      "13 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 0.8556956052780151 83.57119750976562\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 261.1387939453125\n",
      "13 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 8.683857917785645 112.44146728515625\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 709.5204467773438\n",
      "13 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 4.571908950805664 73.38868713378906\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 933.2982177734375\n",
      "13 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 12.386608123779297 145.78729248046875\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 656.08837890625\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.014688299939734861\n",
      "0 0.013980448246002197\n",
      "1 0.013348888136533787\n",
      "2 0.012933531765156658\n",
      "3 0.01274378636298934\n",
      "4 0.012557752757857088\n",
      "5 0.01244178778870264\n",
      "6 0.012363544326944975\n",
      "7 0.012335378232819494\n",
      "err after  0.012332451344263973\n",
      "14 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 1.1062229871749878 60.596107482910156\n",
      "sparsity check 2.0\n",
      "err_spxsp 456.7262878417969\n",
      "14 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 10.858565330505371 161.92788696289062\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 29.403362274169922\n",
      "14 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 9.675824165344238 150.44400024414062\n",
      "sparsity check 2.0\n",
      "err_spxsp 157.71975708007812\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.01361592358807684\n",
      "0 0.013452887789753731\n",
      "1 0.01319628882993129\n",
      "2 0.012963408866198733\n",
      "3 0.012820304866181687\n",
      "4 0.01273233005849761\n",
      "5 0.012649694577703485\n",
      "6 0.012604121948243119\n",
      "7 0.012588933248480316\n",
      "err after  0.012587095494382083\n",
      "14 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 0.9846162796020508 113.75788879394531\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 302.10992431640625\n",
      "14 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 9.985013961791992 124.92034912109375\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 723.0130615234375\n",
      "14 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 5.178436279296875 78.21601867675781\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 995.3681640625\n",
      "14 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 14.565177917480469 165.3238525390625\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 666.6700439453125\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.016558893697947497\n",
      "0 0.015873397715040483\n",
      "1 0.015162286970735295\n",
      "2 0.014674632784590358\n",
      "3 0.014425722856685752\n",
      "4 0.01423988269016263\n",
      "5 0.01410364512776141\n",
      "6 0.014000522263813764\n",
      "7 0.013969862709927838\n",
      "err after  0.013966282102046534\n",
      "15 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 0.8965739011764526 37.78181457519531\n",
      "sparsity check 2.0\n",
      "err_spxsp 538.7940673828125\n",
      "15 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 25.800323486328125 590.1173706054688\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 46.57728958129883\n",
      "15 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 9.305561065673828 135.094482421875\n",
      "sparsity check 2.0\n",
      "err_spxsp 183.79490661621094\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.015692073397076456\n",
      "0 0.015460615057236282\n",
      "1 0.015092555848241318\n",
      "2 0.014789033913984895\n",
      "3 0.01460336883974378\n",
      "4 0.014496216750558233\n",
      "5 0.014410255174880149\n",
      "6 0.014354776871186914\n",
      "7 0.014336418767925352\n",
      "err after  0.014334153485833667\n",
      "15 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 1.086221694946289 91.72576904296875\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 281.8848876953125\n",
      "15 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 11.518994331359863 140.71780395507812\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 699.4571533203125\n",
      "15 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 5.888193130493164 87.75413513183594\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1035.799560546875\n",
      "15 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 19.18511962890625 210.92532348632812\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 662.188232421875\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.019563648322218796\n",
      "0 0.01893433675650158\n",
      "1 0.018150283765862696\n",
      "2 0.017546110961120576\n",
      "3 0.017175782933918526\n",
      "4 0.01689730708312709\n",
      "5 0.016737709196604555\n",
      "6 0.016632030910841422\n",
      "7 0.016598112077190308\n",
      "err after  0.016594188226008555\n",
      "16 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 0.8187479376792908 45.16477966308594\n",
      "sparsity check 2.0\n",
      "err_spxsp 575.4278564453125\n",
      "16 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 26.52083969116211 431.65234375\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 33.92329406738281\n",
      "16 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 9.120551109313965 126.54366302490234\n",
      "sparsity check 2.0\n",
      "err_spxsp 171.87008666992188\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.01811602019733982\n",
      "0 0.017944338756933575\n",
      "1 0.017635429052461404\n",
      "2 0.017323188159934944\n",
      "3 0.017106508035794832\n",
      "4 0.01696334268854116\n",
      "5 0.01687078487520921\n",
      "6 0.01681827605716535\n",
      "7 0.016797279291495215\n",
      "err after  0.016794619099528063\n",
      "16 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 0.9432898759841919 107.8004150390625\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 311.1025390625\n",
      "16 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 12.065853118896484 145.7529754638672\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 673.8551025390625\n",
      "16 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 7.0396270751953125 97.90751647949219\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1083.3653564453125\n",
      "16 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 20.104537963867188 226.42813110351562\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 646.5306396484375\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.023010257275018375\n",
      "0 0.022318434224871453\n",
      "1 0.02148074177239323\n",
      "2 0.020721071407024283\n",
      "3 0.020260979370505083\n",
      "4 0.02001146144903032\n",
      "5 0.01984687532603857\n",
      "6 0.019702802845131373\n",
      "7 0.019659896850498626\n",
      "err after  0.019655126521683997\n",
      "17 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 0.7203125953674316 35.38856506347656\n",
      "sparsity check 2.0\n",
      "err_spxsp 572.8230590820312\n",
      "17 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 35.53400421142578 604.23974609375\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 47.75688552856445\n",
      "17 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 7.136258125305176 99.0333251953125\n",
      "sparsity check 2.0\n",
      "err_spxsp 199.6610870361328\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.021124257320479956\n",
      "0 0.020937877336109523\n",
      "1 0.02063682059815619\n",
      "2 0.020324936558608897\n",
      "3 0.020097308592085028\n",
      "4 0.019945408574130852\n",
      "5 0.019839862492517568\n",
      "6 0.019778730344114592\n",
      "7 0.019754235461732605\n",
      "err after  0.019751050629565725\n",
      "17 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 0.8534136414527893 103.6420669555664\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 303.3584899902344\n",
      "17 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 13.04890251159668 151.93582153320312\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 666.41162109375\n",
      "17 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 8.142053604125977 105.82382202148438\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1104.326904296875\n",
      "17 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 24.00214195251465 265.18402099609375\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 649.3059692382812\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.025187228791764937\n",
      "0 0.024682217779627535\n",
      "1 0.023896208374935668\n",
      "2 0.02311683003063081\n",
      "3 0.02263599314755993\n",
      "4 0.022357920170179568\n",
      "5 0.022146694944240153\n",
      "6 0.0220265208336059\n",
      "7 0.021985532825055998\n",
      "err after  0.021981108926411252\n",
      "18 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 0.842880129814148 35.704105377197266\n",
      "sparsity check 2.0\n",
      "err_spxsp 600.9321899414062\n",
      "18 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 38.32423400878906 654.5672607421875\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 32.436439514160156\n",
      "18 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 4.606368064880371 63.8070068359375\n",
      "sparsity check 2.0\n",
      "err_spxsp 190.03732299804688\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.022925547025806736\n",
      "0 0.02282107694190927\n",
      "1 0.022617022121266928\n",
      "2 0.022355052664352115\n",
      "3 0.022137507214210927\n",
      "4 0.02198424420930678\n",
      "5 0.02187872477225028\n",
      "6 0.021817698747327086\n",
      "7 0.021792920510051772\n",
      "err after  0.02178987718798453\n",
      "18 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 0.657600998878479 55.1666259765625\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 294.7596435546875\n",
      "18 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 12.549680709838867 148.70860290527344\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 652.4073486328125\n",
      "18 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 8.978606224060059 113.97874450683594\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1112.32666015625\n",
      "18 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 21.350658416748047 237.48094177246094\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 639.740478515625\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.027278992820356507\n",
      "0 0.0267374141185428\n",
      "1 0.02597478637471795\n",
      "2 0.02521421437995741\n",
      "3 0.024757086248428095\n",
      "4 0.02448817272670567\n",
      "5 0.02424421245814301\n",
      "6 0.024132811289746314\n",
      "7 0.024094297899864614\n",
      "err after  0.02408998018654529\n",
      "19 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 0.6422399282455444 25.521358489990234\n",
      "sparsity check 2.0\n",
      "err_spxsp 614.726806640625\n",
      "19 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 48.12587356567383 934.3994140625\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 38.693572998046875\n",
      "19 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 3.704493522644043 50.46122360229492\n",
      "sparsity check 2.0\n",
      "err_spxsp 198.8514404296875\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.024947259873442817\n",
      "0 0.024844977881002706\n",
      "1 0.024654018408909906\n",
      "2 0.024397922308708075\n",
      "3 0.02417741721001221\n",
      "4 0.02401942468713969\n",
      "5 0.02391280105803162\n",
      "6 0.02385204115853412\n",
      "7 0.02382731159013929\n",
      "err after  0.023824242052796762\n",
      "19 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 0.6489914059638977 64.94819641113281\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 318.56787109375\n",
      "19 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 12.171798706054688 135.9532470703125\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 643.900634765625\n",
      "19 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 9.321552276611328 112.07603454589844\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1125.5382080078125\n",
      "19 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 19.837736129760742 223.196533203125\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 641.4732055664062\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.028434515690605622\n",
      "0 0.027953509859798942\n",
      "1 0.02727591614529956\n",
      "2 0.026561332422716077\n",
      "3 0.026088686368893832\n",
      "4 0.025795352994464338\n",
      "5 0.025606446972233243\n",
      "6 0.025499436786049046\n",
      "7 0.025461879500653595\n",
      "err after  0.02545759661734337\n",
      "20 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 0.6799367666244507 32.2389030456543\n",
      "sparsity check 2.0\n",
      "err_spxsp 640.6380615234375\n",
      "20 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 46.90347671508789 883.014892578125\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 42.777320861816406\n",
      "20 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 3.999985456466675 49.88782501220703\n",
      "sparsity check 2.0\n",
      "err_spxsp 193.60299682617188\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.02634023937571328\n",
      "0 0.026246796616760548\n",
      "1 0.026062625423946884\n",
      "2 0.025810949126025662\n",
      "3 0.02559630880568875\n",
      "4 0.025443288403039332\n",
      "5 0.02533991917880485\n",
      "6 0.025279913177655544\n",
      "7 0.025255450447730254\n",
      "err after  0.025252462080970872\n",
      "20 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 0.5778592228889465 76.3643798828125\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 304.57275390625\n",
      "20 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 12.085344314575195 136.78370666503906\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 647.2633056640625\n",
      "20 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 9.77895736694336 169.0137176513672\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1144.9801025390625\n",
      "20 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 18.825273513793945 213.01547241210938\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 647.3140869140625\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.030739346315385774\n",
      "0 0.030226949536881875\n",
      "1 0.029547663281846326\n",
      "2 0.028847738220065366\n",
      "3 0.0283953756588744\n",
      "4 0.028126326673373114\n",
      "5 0.027914878490264527\n",
      "6 0.027797442679002415\n",
      "7 0.027756161136494484\n",
      "err after  0.027751318775699474\n",
      "21 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 0.5968358516693115 28.373140335083008\n",
      "sparsity check 2.0\n",
      "err_spxsp 604.4771728515625\n",
      "21 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 46.0216178894043 756.3050537109375\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 46.10675048828125\n",
      "21 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 5.027043342590332 81.43279266357422\n",
      "sparsity check 2.0\n",
      "err_spxsp 218.36328125\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.028836567456892226\n",
      "0 0.028708675286907237\n",
      "1 0.028481333487434313\n",
      "2 0.0282004988839617\n",
      "3 0.027972669013252016\n",
      "4 0.027810346204205416\n",
      "5 0.027700296632247046\n",
      "6 0.02763894999952754\n",
      "7 0.027613255369942635\n",
      "err after  0.027609929762547836\n",
      "21 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 0.6338109970092773 96.96500396728516\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 363.63372802734375\n",
      "21 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 12.000635147094727 136.90646362304688\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 653.4180908203125\n",
      "21 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 9.76732063293457 116.12586975097656\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1150.7127685546875\n",
      "21 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 19.204326629638672 220.45437622070312\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 663.041748046875\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.03179352600272978\n",
      "0 0.031435275785042904\n",
      "1 0.03083380414318526\n",
      "2 0.03015702161792433\n",
      "3 0.029702571904635988\n",
      "4 0.02940676650905516\n",
      "5 0.029213936286396347\n",
      "6 0.029112364936736412\n",
      "7 0.029074711732391734\n",
      "err after  0.029070319600577932\n",
      "22 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 0.5560173988342285 27.844636917114258\n",
      "sparsity check 2.0\n",
      "err_spxsp 578.7540893554688\n",
      "22 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 37.391136169433594 590.348388671875\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 49.14501953125\n",
      "22 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 4.012679576873779 48.151573181152344\n",
      "sparsity check 2.0\n",
      "err_spxsp 211.20281982421875\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.02983620318264002\n",
      "0 0.02976506989216432\n",
      "1 0.029605156829347834\n",
      "2 0.02936495627363911\n",
      "3 0.029152152710594237\n",
      "4 0.028996797831496224\n",
      "5 0.0288903184919036\n",
      "6 0.028830510505940765\n",
      "7 0.028805320893297903\n",
      "err after  0.028802209635614417\n",
      "22 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 0.43601250648498535 64.17420959472656\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 296.73614501953125\n",
      "22 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 10.959365844726562 121.59285736083984\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 649.2923583984375\n",
      "22 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 9.265790939331055 110.70352172851562\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1141.9068603515625\n",
      "22 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 17.047664642333984 191.748779296875\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 663.1114501953125\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.03137130292452639\n",
      "0 0.031022135830426123\n",
      "1 0.030497298044792842\n",
      "2 0.02989920888649067\n",
      "3 0.029485927916539367\n",
      "4 0.029217916169727687\n",
      "5 0.029043696420558263\n",
      "6 0.02894812096201349\n",
      "7 0.028913593414472416\n",
      "err after  0.028909776905493345\n",
      "23 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 0.5684727430343628 23.40460968017578\n",
      "sparsity check 2.0\n",
      "err_spxsp 558.482666015625\n",
      "23 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 52.083221435546875 815.3155517578125\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 53.47233963012695\n",
      "23 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 3.328648328781128 36.30033493041992\n",
      "sparsity check 2.0\n",
      "err_spxsp 218.69863891601562\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.02958168458280852\n",
      "0 0.02950789780152263\n",
      "1 0.029356077226111665\n",
      "2 0.029135617172869388\n",
      "3 0.02893820337340003\n",
      "4 0.028793646226404235\n",
      "5 0.028695263303234242\n",
      "6 0.028638601885177195\n",
      "7 0.028615064242330845\n",
      "err after  0.028612352682102937\n",
      "23 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 0.5518499612808228 83.38567352294922\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 312.19744873046875\n",
      "23 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 10.279671669006348 166.23971557617188\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 661.18505859375\n",
      "23 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 9.015881538391113 121.94808959960938\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1155.67236328125\n",
      "23 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 14.379469871520996 165.74075317382812\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 664.399658203125\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.02976358777959831\n",
      "0 0.029500364646082744\n",
      "1 0.029052498917735647\n",
      "2 0.02853003288328182\n",
      "3 0.028170911537017673\n",
      "4 0.027938226419792045\n",
      "5 0.027771641201979946\n",
      "6 0.027681084065989126\n",
      "7 0.02764921580819646\n",
      "err after  0.02764572932937881\n",
      "24 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 0.48186182975769043 24.291229248046875\n",
      "sparsity check 2.0\n",
      "err_spxsp 541.4312744140625\n",
      "24 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 54.446529388427734 919.687255859375\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 60.89148712158203\n",
      "24 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 3.225983142852783 37.3940315246582\n",
      "sparsity check 2.0\n",
      "err_spxsp 245.49705505371094\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.02821943472372368\n",
      "0 0.02816016050928738\n",
      "1 0.028036626528773922\n",
      "2 0.027841218026878778\n",
      "3 0.027662202170176897\n",
      "4 0.027534371525689494\n",
      "5 0.02744965927558951\n",
      "6 0.02739043111068895\n",
      "7 0.027367761569621507\n",
      "err after  0.027365361172996927\n",
      "24 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 0.42945343255996704 51.73432922363281\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 250.64002990722656\n",
      "24 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 9.484593391418457 109.25418090820312\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 661.8756103515625\n",
      "24 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 8.937222480773926 280.4832458496094\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1202.8740234375\n",
      "24 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 11.946869850158691 143.02789306640625\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 667.042236328125\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.027467288353363983\n",
      "0 0.027238150607445277\n",
      "1 0.026874012895859778\n",
      "2 0.026449471908563282\n",
      "3 0.026146151380089577\n",
      "4 0.025935310863133054\n",
      "5 0.025793877670366783\n",
      "6 0.025714406205224805\n",
      "7 0.025685046362923458\n",
      "err after  0.025681825223728083\n",
      "25 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 0.5773006677627563 30.23148536682129\n",
      "sparsity check 2.0\n",
      "err_spxsp 545.1275634765625\n",
      "25 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 56.0528678894043 1031.706298828125\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 64.39720153808594\n",
      "25 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 2.7167601585388184 33.37953567504883\n",
      "sparsity check 2.0\n",
      "err_spxsp 251.8390350341797\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.026250883063767105\n",
      "0 0.02617141195514705\n",
      "1 0.026044053251098376\n",
      "2 0.02586749251349829\n",
      "3 0.025709469977300614\n",
      "4 0.025604166949051432\n",
      "5 0.02553075510513736\n",
      "6 0.025470066160778515\n",
      "7 0.02545110548089724\n",
      "err after  0.02544902310182806\n",
      "25 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 0.7366563081741333 610.078857421875\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 309.1354675292969\n",
      "25 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 9.00748348236084 113.49372863769531\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 679.1292724609375\n",
      "25 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 8.33479118347168 105.682373046875\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1196.6412353515625\n",
      "25 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 9.826909065246582 117.90719604492188\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 673.9280395507812\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.025259216752601787\n",
      "0 0.025025481372722425\n",
      "1 0.024696457774552982\n",
      "2 0.024316695205925498\n",
      "3 0.024044504589255666\n",
      "4 0.02386586584179895\n",
      "5 0.023745150050672237\n",
      "6 0.023675436663324945\n",
      "7 0.023650116705539403\n",
      "err after  0.023647085654374678\n",
      "26 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 0.4288507103919983 28.94350242614746\n",
      "sparsity check 2.0\n",
      "err_spxsp 573.1712646484375\n",
      "26 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 36.36290740966797 546.2421875\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 80.79936981201172\n",
      "26 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 2.906109571456909 36.42784118652344\n",
      "sparsity check 2.0\n",
      "err_spxsp 259.379638671875\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.024288725173391867\n",
      "0 0.02419719374665874\n",
      "1 0.02406501007862971\n",
      "2 0.02389685258458485\n",
      "3 0.023751577900839038\n",
      "4 0.0236474831981468\n",
      "5 0.02357921098882798\n",
      "6 0.02353317561573931\n",
      "7 0.023514912336395355\n",
      "err after  0.02351287754936493\n",
      "26 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 0.4946408271789551 544.052978515625\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 342.1585388183594\n",
      "26 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 8.172761917114258 95.6904296875\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 694.31494140625\n",
      "26 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 7.561790943145752 89.83293151855469\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1208.903076171875\n",
      "26 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 8.579063415527344 103.9195556640625\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 687.3512573242188\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.023102058992662933\n",
      "0 0.02284279267405509\n",
      "1 0.022529675603436772\n",
      "2 0.02219224713917356\n",
      "3 0.02196825707142125\n",
      "4 0.021804654286825098\n",
      "5 0.021690835637855344\n",
      "6 0.021625174435030203\n",
      "7 0.021601829430437647\n",
      "err after  0.021598853138129925\n",
      "27 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 0.601203203201294 33.05913162231445\n",
      "sparsity check 2.0\n",
      "err_spxsp 554.4873046875\n",
      "27 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 37.732093811035156 714.23291015625\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 86.81948852539062\n",
      "27 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 2.689044952392578 37.32691192626953\n",
      "sparsity check 2.0\n",
      "err_spxsp 280.5586242675781\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.022105975771410158\n",
      "0 0.022033114553778432\n",
      "1 0.02193150186940329\n",
      "2 0.02178963339247275\n",
      "3 0.02166272066097008\n",
      "4 0.021567924275586847\n",
      "5 0.021502602194232168\n",
      "6 0.0214645759187988\n",
      "7 0.0214479223905073\n",
      "err after  0.021445799531647936\n",
      "27 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 0.7049944996833801 168.62220764160156\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 368.85552978515625\n",
      "27 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 7.828252792358398 110.34883117675781\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 723.2877807617188\n",
      "27 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 7.378094673156738 145.7476043701172\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1256.736328125\n",
      "27 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 7.921375274658203 105.46082305908203\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 712.4649658203125\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.020759374951012433\n",
      "0 0.02049518350395374\n",
      "1 0.0202079479022359\n",
      "2 0.019907996858819388\n",
      "3 0.019700806151377037\n",
      "4 0.01955601592635503\n",
      "5 0.01945043894374976\n",
      "6 0.01939562880579615\n",
      "7 0.019376512318558525\n",
      "err after  0.019373434552107938\n",
      "28 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 0.5555449724197388 21.63748550415039\n",
      "sparsity check 2.0\n",
      "err_spxsp 545.1629638671875\n",
      "28 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 35.158653259277344 604.4564208984375\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 103.95515441894531\n",
      "28 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 2.6622750759124756 35.09123611450195\n",
      "sparsity check 2.0\n",
      "err_spxsp 281.4079895019531\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.01994455818203278\n",
      "0 0.019852431116305524\n",
      "1 0.019744544315472012\n",
      "2 0.01960637186493841\n",
      "3 0.019488072648528032\n",
      "4 0.019400310120545328\n",
      "5 0.01933854895105469\n",
      "6 0.01930383240687661\n",
      "7 0.019289168529212475\n",
      "err after  0.019286895007098792\n",
      "28 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 0.7451112866401672 281.6340637207031\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 278.9210205078125\n",
      "28 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 7.547174453735352 108.29251098632812\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 762.899169921875\n",
      "28 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 6.830465316772461 103.35108947753906\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1219.1165771484375\n",
      "28 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 7.925426483154297 102.682373046875\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 752.8750610351562\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.01941807012190111\n",
      "0 0.019103583428659476\n",
      "1 0.01879856798404944\n",
      "2 0.018511339665565174\n",
      "3 0.0183184019770124\n",
      "4 0.018189924478065223\n",
      "5 0.018093939608661458\n",
      "6 0.018038386300759157\n",
      "7 0.018019796203589067\n",
      "err after  0.018016699239524314\n",
      "29 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 0.5651804804801941 21.62911605834961\n",
      "sparsity check 2.0\n",
      "err_spxsp 559.6349487304688\n",
      "29 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 43.95543670654297 1466.1036376953125\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 128.37374877929688\n",
      "29 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 2.033418893814087 34.09743118286133\n",
      "sparsity check 2.0\n",
      "err_spxsp 338.1463317871094\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.01864520704111783\n",
      "0 0.018482438084902242\n",
      "1 0.01834874172345735\n",
      "2 0.018219680914626224\n",
      "3 0.01811103801810532\n",
      "4 0.01802924846924725\n",
      "5 0.01797060553144547\n",
      "6 0.017938548440724844\n",
      "7 0.01792475457841647\n",
      "err after  0.017922415703651495\n",
      "29 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 1.006880760192871 1204.791015625\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 489.66864013671875\n",
      "29 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 7.552248001098633 278.0796203613281\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 863.7265014648438\n",
      "29 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 6.16586971282959 121.91730499267578\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1212.032470703125\n",
      "29 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 8.14897632598877 113.78445434570312\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 818.3211669921875\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.021423443751700688\n",
      "0 0.020862451900029555\n",
      "1 0.020417111281858524\n",
      "2 0.020048693178978283\n",
      "3 0.01982537435469567\n",
      "4 0.019682768790517002\n",
      "5 0.01956660057112458\n",
      "6 0.019489637823426165\n",
      "7 0.019465249508357374\n",
      "err after  0.019461098239844432\n",
      "30 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 0.29063719511032104 20.7009334564209\n",
      "sparsity check 2.0\n",
      "err_spxsp 415.78302001953125\n",
      "30 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 30.817367553710938 795.9965209960938\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 191.913818359375\n",
      "30 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 3.0982017517089844 41.68157958984375\n",
      "sparsity check 2.0\n",
      "err_spxsp 312.5718078613281\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.020538237840810325\n",
      "0 0.02027097796963062\n",
      "1 0.02008192921493901\n",
      "2 0.019908016543922713\n",
      "3 0.01976847442347207\n",
      "4 0.01966718406038126\n",
      "5 0.019595351001044037\n",
      "6 0.01955320822162321\n",
      "7 0.019535365427145734\n",
      "err after  0.019531925856426824\n",
      "30 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 0.8062591552734375 528.6221313476562\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 273.29901123046875\n",
      "30 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 8.725545883178711 231.55091857910156\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 912.0755004882812\n",
      "30 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 6.816112518310547 113.24508666992188\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1273.58203125\n",
      "30 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 10.829842567443848 179.37896728515625\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 881.7763671875\n",
      "dict_keys(['self_attn.q_proj.weight', 'self_attn.k_proj.weight', 'self_attn.v_proj.weight', 'self_attn.o_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.06405745532538276\n",
      "0 0.06085182288370561\n",
      "1 0.058100425550946966\n",
      "2 0.05684900423511863\n",
      "3 0.056295137022971176\n",
      "4 0.05588839766278397\n",
      "5 0.05534700020507444\n",
      "6 0.054917694280447904\n",
      "7 0.05478246131679043\n",
      "err after  0.0547542526255711\n",
      "31 self_attn.q_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 0.43306291103363037 12.985605239868164\n",
      "sparsity check 2.0\n",
      "err_spxsp 564.2828369140625\n",
      "31 self_attn.v_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 14.136063575744629 560.2271728515625\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 140.1163330078125\n",
      "31 self_attn.o_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 4.143568992614746 108.0620346069336\n",
      "sparsity check 2.0\n",
      "err_spxsp 374.3160095214844\n",
      "dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])\n",
      "err before 0.05761323377373628\n",
      "0 0.056677902830415405\n",
      "1 0.0561331365533988\n",
      "2 0.055663020655629225\n",
      "3 0.055358834950311575\n",
      "4 0.055068255998776294\n",
      "5 0.0547105807781918\n",
      "6 0.054522344275028445\n",
      "7 0.05444257224735338\n",
      "err after  0.05442805673374096\n",
      "31 self_attn.k_proj\n",
      "Pruning ...\n",
      "mid 1024\n",
      "err 259 1.1408226490020752 229.0838165283203\n",
      "sparsity check 1.99951171875\n",
      "err_spxsp 332.3433837890625\n",
      "31 mlp.up_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 15.306169509887695 515.9136962890625\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1307.301025390625\n",
      "31 mlp.gate_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 10.632704734802246 207.28602600097656\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1680.9462890625\n",
      "31 mlp.down_proj\n",
      "Pruning ...\n",
      "mid 4096\n",
      "err 259 38.15721893310547 2327.344970703125\n",
      "sparsity check 1.9998256138392858\n",
      "err_spxsp 1113.85693359375\n",
      "total time 28280.0925488472\n"
     ]
    }
   ],
   "source": [
    "import cupy\n",
    "\n",
    "def my_pack(x):\n",
    "    x = (x == 1).to(torch.uint8)\n",
    "    out = torch.zeros((x.shape[0]//8), device=x.device, dtype=torch.uint8)\n",
    "    for i in range(8):\n",
    "        out += x[i::8] << (7 - i)\n",
    "    return out\n",
    "\n",
    "@torch.compile\n",
    "def my_unpack(x):\n",
    "    out = torch.zeros((x.shape[0], 8), device=x.device, dtype=torch.int8)\n",
    "    for i in range(8):\n",
    "        out[:,i] = (x >> (7 - i)) & 1\n",
    "    return out.flatten() * 2 - 1\n",
    "\n",
    "def power_iteration(A, num_iters=5):\n",
    "    \"\"\"\n",
    "    Performs power iteration to compute the top singular vectors and value.\n",
    "    \n",
    "    Arguments:\n",
    "        A (torch.Tensor): The input matrix of shape (m, n).\n",
    "        num_iters (int): Number of iterations to perform.\n",
    "    \n",
    "    Returns:\n",
    "        u (torch.Tensor): Dominant left singular vector (m,).\n",
    "        sigma (torch.Tensor): Dominant singular value (scalar).\n",
    "        v (torch.Tensor): Dominant right singular vector (n,).\n",
    "    \"\"\"\n",
    "    # Start with a random vector on the appropriate device\n",
    "    n = A.shape[1]\n",
    "    v = torch.randn(n, device=A.device)\n",
    "    v = v / torch.norm(v)\n",
    "    \n",
    "    for _ in range(num_iters):\n",
    "        # Multiply A*v\n",
    "        u = torch.mv(A, v)\n",
    "        u_norm = torch.norm(u)\n",
    "        if u_norm == 0:\n",
    "            break\n",
    "        u = u / u_norm\n",
    "        \n",
    "        # Multiply A^T*u\n",
    "        v = torch.mv(A.t(), u)\n",
    "        v_norm = torch.norm(v)\n",
    "        if v_norm == 0:\n",
    "            break\n",
    "        v = v / v_norm\n",
    "    \n",
    "    # Estimate the dominant singular value as ||A*v||\n",
    "    sigma = torch.norm(torch.mv(A, v))\n",
    "    # The left singular vector corresponding to sigma:\n",
    "    u = torch.mv(A, v) / sigma\n",
    "    return u, sigma, v\n",
    "\n",
    "def svd_abs2(W):\n",
    "    \"\"\"Split W into a sign matrix and rank-1 factors of its magnitudes.\n",
    "\n",
    "    Returns:\n",
    "        (u * s, Sg, v) where ``Sg`` is the elementwise sign of ``W`` with zeros\n",
    "        mapped to +1, and ``u``, ``s``, ``v`` are the dominant singular triple of\n",
    "        ``W.abs()`` as estimated by ``power_iteration`` (5 iterations), so that\n",
    "        ``|W|`` is approximated by the outer product of ``u * s`` and ``v``.\n",
    "    \"\"\"\n",
    "    Sg = W.sign()\n",
    "    Sg[Sg == 0] = 1\n",
    "    u, s, v = power_iteration(W.abs(), num_iters=5)\n",
    "    # The dense rank-1 reconstruction is never needed here, so it is not materialized.\n",
    "    return u * s, Sg, v\n",
    "\n",
    "class BitLinear(nn.Module):\n",
    "    \"\"\"Bias-free linear map whose weight entries are +1/-1, stored bit-packed.\n",
    "\n",
    "    The sign matrix ``b`` is flattened and packed with ``my_pack`` into the\n",
    "    ``bp`` buffer; ``forward`` unpacks it on every call.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, b):\n",
    "        super().__init__()\n",
    "        # Remember the 2-D layout so forward() can restore it after unpacking.\n",
    "        self.shape = b.shape\n",
    "        self.register_buffer(\"bp\", my_pack(b.flatten()))\n",
    "\n",
    "    def forward(self, x):\n",
    "        signs = my_unpack(self.bp).reshape(self.shape)\n",
    "        weight_t = signs.T.to(x.dtype)\n",
    "        return x.matmul(weight_t)\n",
    "        \n",
    "\n",
    "        \n",
    "class Mul(nn.Module):\n",
    "    def __init__(self, w):\n",
    "        super().__init__()\n",
    "        #print(\"w\", w.amin().item(), w.median().item(), w.amax().item())\n",
    "        \n",
    "        self.register_buffer(\"w\", w)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return x * self.w.to(x.dtype)\n",
    "        \n",
    "\n",
    "def replace(lx):\n",
    "    \"\"\"Build a bit-packed stand-in for a factorized linear layer.\n",
    "\n",
    "    Reads the factors ``lx.weight.B``, ``lx.weight.mid`` and ``lx.weight.A``\n",
    "    (attached to the weight by ``opt_sequential``), decomposes ``B`` and ``A``\n",
    "    with ``svd_abs2`` and returns an ``nn.Sequential`` that applies, in order:\n",
    "    scale by v1, sign matrix of B, scale by u1 * mid * v2, sign matrix of A,\n",
    "    scale by u2.\n",
    "\n",
    "    NOTE(review): assumes ``mid`` broadcasts against ``u1`` and ``v2`` as a\n",
    "    per-channel vector -- confirm against ``factorize``.\n",
    "    \"\"\"\n",
    "    m1 = lx.weight.B\n",
    "    m2 = lx.weight.A\n",
    "\n",
    "    u1, b1, v1 = svd_abs2(m1.float())\n",
    "    u2, b2, v2 = svd_abs2(m2.float())\n",
    "\n",
    "    lx2 = nn.Sequential(\n",
    "        Mul(v1),\n",
    "        BitLinear(b1),\n",
    "        # The three adjacent per-channel scales are folded into one multiply.\n",
    "        Mul(u1*lx.weight.mid*v2),\n",
    "        BitLinear(b2),\n",
    "        Mul(u2)\n",
    "    )\n",
    "    return lx2\n",
    "\n",
    "@torch.no_grad()\n",
    "def opt_sequential(model, dataloader, dev):\n",
    "    \"\"\"Compress the model one decoder layer at a time, fine-tuning as it goes.\n",
    "\n",
    "    For every layer in ``model.model.layers`` this function:\n",
    "      1. records the layer's outputs on the calibration inputs before any of its\n",
    "         sub-modules are replaced; these are the reconstruction targets;\n",
    "      2. walks the linear sub-modules in a fixed order; before ``q_proj`` (except\n",
    "         in layer 0) and before ``k_proj`` it fine-tunes the layer's remaining\n",
    "         non-layernorm weight parameters for 8 epochs so that the layer's output\n",
    "         on the compressed-path inputs matches the targets;\n",
    "      3. factorizes each sub-module with ``factorize`` and swaps it for the module\n",
    "         returned by ``replace``.\n",
    "\n",
    "    Relies on names defined elsewhere in the notebook: ``n_samples``,\n",
    "    ``find_layers``, ``factorize``, ``BF16FusedAdamW`` and an ``o_norm``\n",
    "    attribute on each layer's ``mlp.down_proj``.\n",
    "\n",
    "    Arguments:\n",
    "        model: causal LM exposing ``model.layers``, ``model.embed_tokens``,\n",
    "            ``model.rotary_emb``, ``seqlen`` and ``config``.\n",
    "        dataloader: iterable of token-id tensors; each item is unsqueezed into a\n",
    "            batch of one. At most ``n_samples`` items can be captured.\n",
    "        dev: device on which each layer is processed.\n",
    "\n",
    "    The model is modified in place and nothing is returned.\n",
    "    ``model.config.use_cache`` is restored to its original value on normal exit.\n",
    "    \"\"\"\n",
    "    print('Starting ...')\n",
    "\n",
    "    model.cpu()\n",
    "    model.gradient_checkpointing_disable()\n",
    "    model.eval()\n",
    "    use_cache = model.config.use_cache\n",
    "    model.config.use_cache = False\n",
    "    layers = model.model.layers\n",
    "\n",
    "    model.model.embed_tokens = model.model.embed_tokens.to(dev)\n",
    "    model.model.rotary_emb = model.model.rotary_emb.to(dev)\n",
    "    layers[0] = layers[0].to(dev)\n",
    "\n",
    "    dtype = next(iter(model.parameters())).dtype\n",
    "    # Calibration activations live on the CPU and are moved to `dev` one sample at a time.\n",
    "    inps = torch.zeros(\n",
    "        (n_samples, model.seqlen, model.config.hidden_size), dtype=dtype, device=\"cpu\"\n",
    "    )\n",
    "    cache = {'i': 0, 'attention_mask': None}\n",
    "\n",
    "    class Catcher(nn.Module):\n",
    "        \"\"\"Stand-in for layer 0 that records its inputs, then aborts the forward pass.\"\"\"\n",
    "        def __init__(self, module):\n",
    "            super().__init__()\n",
    "            self.module = module\n",
    "        def forward(self, inp, **kwargs):\n",
    "            inps[cache['i']] = inp\n",
    "            cache['i'] += 1\n",
    "            cache['attention_mask'] = kwargs['attention_mask']\n",
    "            cache['position_embeddings'] = kwargs['position_embeddings']\n",
    "            raise ValueError\n",
    "    layers[0] = Catcher(layers[0])\n",
    "    for batch in dataloader:\n",
    "        try:\n",
    "            model(batch.unsqueeze(0).to(dev))\n",
    "        except ValueError:\n",
    "            # Raised on purpose by Catcher once the layer-0 input has been captured.\n",
    "            pass\n",
    "    layers[0] = layers[0].module\n",
    "\n",
    "    layers[0] = layers[0].cpu()\n",
    "    model.model.embed_tokens = model.model.embed_tokens.cpu()\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    # `inps` follows the not-yet-compressed layers (targets);\n",
    "    # `comp_inps` follows the compressed layers (actual inputs).\n",
    "    comp_inps = inps.clone()\n",
    "    attention_mask = cache['attention_mask']\n",
    "    position_embeddings = cache['position_embeddings']\n",
    "\n",
    "    print('Ready.')\n",
    "\n",
    "    layers = model.model.layers\n",
    "\n",
    "    for i in range(len(layers)):\n",
    "        layer = layers[i].to(dev)\n",
    "\n",
    "        subset = find_layers(layer)\n",
    "        # Targets: this layer's outputs before any of its sub-modules are replaced.\n",
    "        for j in range(n_samples):\n",
    "            inps[j] = layer(inps[j].unsqueeze(0).to(dev), attention_mask=attention_mask, position_embeddings=position_embeddings)[0]\n",
    "\n",
    "        # Weighting applied to the squared output error; `o_norm` is attached outside this cell.\n",
    "        imp = layer.mlp.down_proj.o_norm\n",
    "\n",
    "        for name in [\n",
    "            \"self_attn.q_proj\",\n",
    "            \"self_attn.v_proj\",\n",
    "            \"self_attn.o_proj\",\n",
    "            \"self_attn.k_proj\",\n",
    "            \"mlp.up_proj\",\n",
    "            \"mlp.gate_proj\",\n",
    "            \"mlp.down_proj\",\n",
    "        ]:\n",
    "            # Sub-modules already swapped by replace() hold buffers only, so they drop out here.\n",
    "            to_opt = {n: p for n, p in layer.named_parameters() if \"weight\" in n and \"layernorm\" not in n}\n",
    "            if len(to_opt) > 0 and ((\"q_proj\" in name and i >= 1) or \"k_proj\" in name):\n",
    "\n",
    "                for n, p in to_opt.items():\n",
    "                    p.requires_grad = True\n",
    "                print(to_opt.keys())\n",
    "\n",
    "                lr = 3e-5\n",
    "                opt = BF16FusedAdamW(to_opt.values(), lr, weight_decay=1e-4)\n",
    "                # 8 epochs with one optimizer step per 8 samples.\n",
    "                sch = torch.optim.lr_scheduler.OneCycleLR(opt, lr, total_steps=n_samples*8 // 8, cycle_momentum=False)\n",
    "                err_before = 0\n",
    "                for j in range(n_samples):\n",
    "                    cur_out = layer(comp_inps[j].unsqueeze(0).to(dev), attention_mask=attention_mask, position_embeddings=position_embeddings)[0]\n",
    "                    err_before += ((cur_out.float() - inps[j].to(dev).float()).square() * imp).mean().item()\n",
    "\n",
    "                print(\"err before\", err_before)\n",
    "\n",
    "                # The function runs under no_grad; re-enable autograd for the fine-tuning loop only.\n",
    "                with torch.enable_grad():\n",
    "                    for ep in range(8):\n",
    "                        err_total = 0\n",
    "                        for j in range(n_samples):\n",
    "                            cur_out = layer(comp_inps[j].unsqueeze(0).to(dev), attention_mask=attention_mask, position_embeddings=position_embeddings)[0]\n",
    "                            err = ((cur_out.float() - inps[j].to(dev).float()).square() * imp).sum()\n",
    "                            err.backward()\n",
    "                            # Gradients accumulate over 8 samples before each step.\n",
    "                            if j % 8 == 7:\n",
    "                                opt.step()\n",
    "                                sch.step()\n",
    "                                layer.zero_grad(set_to_none=True)\n",
    "                            err_total += err.item() / inps.shape[1] / inps.shape[2]\n",
    "                        print(ep, err_total)\n",
    "\n",
    "                err_after = 0\n",
    "                for j in range(n_samples):\n",
    "                    cur_out = layer(comp_inps[j].unsqueeze(0).to(dev), attention_mask=attention_mask, position_embeddings=position_embeddings)[0]\n",
    "                    err_after += ((cur_out.float() - inps[j].to(dev).float()).square() * imp).mean().item()\n",
    "                print(\"err after \", err_after)\n",
    "\n",
    "            print(i, name)\n",
    "            print('Pruning ...')\n",
    "            lx = subset[name]\n",
    "\n",
    "            W2, Ac, Bb, mid = factorize(lx)\n",
    "            err_spxsp = (W2 - lx.weight).square().sum().item()\n",
    "            print(\"err_spxsp\", err_spxsp)\n",
    "\n",
    "            # Stash the reconstruction and its factors on the weight so replace() can read them.\n",
    "            lx.weight.data = W2.to(lx.weight)\n",
    "            lx.weight.A = Ac\n",
    "            lx.weight.B = Bb\n",
    "            lx.weight.mid = mid\n",
    "            parts = name.split('.')\n",
    "            block = getattr(layer, parts[0])\n",
    "            setattr(block, parts[1], replace(lx))\n",
    "\n",
    "        # Advance the compressed-path activations through the finished layer.\n",
    "        for j in range(n_samples):\n",
    "            comp_inps[j] = layer(comp_inps[j].unsqueeze(0).to(dev), attention_mask=attention_mask, position_embeddings=position_embeddings)[0]\n",
    "\n",
    "        layers[i] = layer.cpu()\n",
    "        del layer\n",
    "\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "    # Restore the caller's setting, matching what opt_eval does.\n",
    "    model.config.use_cache = use_cache\n",
    "\n",
    "        \n",
    "# Run the layer-by-layer compression on the whole model and report wall-clock seconds.\n",
    "# `model`, `dataloader` and `DEV` come from earlier cells.\n",
    "start = time.time()\n",
    "model.cpu()\n",
    "opt_sequential(model, dataloader, DEV)\n",
    "print(\"total time\", time.time() - start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "fef7480f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-26T11:53:30.401542Z",
     "iopub.status.busy": "2023-12-26T11:53:30.401087Z",
     "iopub.status.idle": "2023-12-26T11:53:30.420226Z",
     "shell.execute_reply": "2023-12-26T11:53:30.419662Z"
    },
    "papermill": {
     "duration": 0.038429,
     "end_time": "2023-12-26T11:53:30.422134",
     "exception": false,
     "start_time": "2023-12-26T11:53:30.383705",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def opt_eval(model, testenc, dev, dataset: str, log_wandb: bool = False):\n",
    "    \"\"\"Compute and print the model's perplexity on a tokenized test set.\n",
    "\n",
    "    ``testenc.input_ids`` is split into ``numel // model.seqlen`` consecutive\n",
    "    chunks of ``model.seqlen`` tokens. Decoder layers are moved to ``dev`` one at\n",
    "    a time, applied to every chunk, and moved back to the CPU. The perplexity is\n",
    "    ``exp(sum of per-chunk NLL / (nsamples * model.seqlen))``.\n",
    "\n",
    "    Arguments:\n",
    "        model: causal LM exposing ``model.layers``, ``model.embed_tokens``,\n",
    "            ``model.rotary_emb``, ``model.norm``, ``lm_head``, ``seqlen``, ``config``.\n",
    "        testenc: object with an ``input_ids`` tensor (indexed as ``[:, start:end]``).\n",
    "        dev: device used for the forward passes.\n",
    "        dataset: name used only as the wandb metric prefix.\n",
    "        log_wandb: when True, also log the perplexity via ``wandb.log``.\n",
    "\n",
    "    Nothing is returned. ``model.config.use_cache`` is restored before returning.\n",
    "    \"\"\"\n",
    "    print('Evaluating ...')\n",
    "\n",
    "    testenc = testenc.input_ids\n",
    "    nsamples = testenc.numel() // model.seqlen\n",
    "\n",
    "    use_cache = model.config.use_cache\n",
    "    model.config.use_cache = False\n",
    "    layers = model.model.layers\n",
    "\n",
    "    model.model.embed_tokens = model.model.embed_tokens.to(dev)\n",
    "    model.model.rotary_emb = model.model.rotary_emb.to(dev)\n",
    "    layers[0] = layers[0].to(dev)\n",
    "\n",
    "    dtype = next(iter(model.parameters())).dtype\n",
    "    inps = torch.zeros(\n",
    "        (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev\n",
    "    )\n",
    "    cache = {'i': 0, 'attention_mask': None}\n",
    "\n",
    "    class Catcher(nn.Module):\n",
    "        \"\"\"Stand-in for layer 0 that records its inputs, then aborts the forward pass.\"\"\"\n",
    "        def __init__(self, module):\n",
    "            super().__init__()\n",
    "            self.module = module\n",
    "        def forward(self, inp, **kwargs):\n",
    "            inps[cache['i']] = inp\n",
    "            cache['i'] += 1\n",
    "            cache['attention_mask'] = kwargs['attention_mask']\n",
    "            cache['position_embeddings'] = kwargs['position_embeddings']\n",
    "            raise ValueError\n",
    "    layers[0] = Catcher(layers[0])\n",
    "    for i in range(nsamples):\n",
    "        batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)\n",
    "        try:\n",
    "            model(batch)\n",
    "        except ValueError:\n",
    "            # Raised on purpose by Catcher once the layer-0 input has been captured.\n",
    "            pass\n",
    "    layers[0] = layers[0].module\n",
    "\n",
    "    layers[0] = layers[0].cpu()\n",
    "    model.model.embed_tokens = model.model.embed_tokens.cpu()\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    outs = torch.zeros_like(inps)\n",
    "    attention_mask = cache['attention_mask']\n",
    "    position_embeddings = cache['position_embeddings']\n",
    "\n",
    "    for i in range(len(layers)):\n",
    "        print(i)\n",
    "        layer = layers[i].to(dev)\n",
    "\n",
    "        \n",
    "        for j in range(nsamples):\n",
    "            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_embeddings=position_embeddings)[0]\n",
    "        layers[i] = layer.cpu()\n",
    "        del layer\n",
    "        torch.cuda.empty_cache()\n",
    "        # Ping-pong the two buffers: this layer's outputs are the next layer's inputs.\n",
    "        inps, outs = outs, inps\n",
    "\n",
    "    if model.model.norm is not None:\n",
    "        model.model.norm = model.model.norm.to(dev)\n",
    "    model.lm_head = model.lm_head.to(dev)\n",
    "\n",
    "    testenc = testenc.to(dev)\n",
    "    nlls = []\n",
    "    for i in range(nsamples):\n",
    "        hidden_states = inps[i].unsqueeze(0)\n",
    "        if model.model.norm is not None:\n",
    "            hidden_states = model.model.norm(hidden_states)\n",
    "        lm_logits = model.lm_head(hidden_states)\n",
    "        # Next-token prediction: logits at position t are scored against the token at t + 1.\n",
    "        shift_logits = lm_logits[:, :-1, :].contiguous()\n",
    "        shift_labels = testenc[\n",
    "            :, (i * model.seqlen):((i + 1) * model.seqlen)\n",
    "        ][:, 1:]\n",
    "        loss_fct = nn.CrossEntropyLoss()\n",
    "        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))\n",
    "        # The mean loss is scaled by seqlen to give this chunk's contribution to the total.\n",
    "        neg_log_likelihood = loss.float() * model.seqlen\n",
    "        nlls.append(neg_log_likelihood)\n",
    "    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))\n",
    "    # NOTE(review): `:3f` is a minimum width of 3 with default precision, not `.3f`.\n",
    "    print(f\"Perplexity: {ppl.item():3f}\")\n",
    "    if log_wandb:\n",
    "         # NOTE(review): `wandb` is not imported in the cells visible here; confirm it is in scope.\n",
    "         wandb.log({f'{dataset}/perplexity': ppl.item()})\n",
    "\n",
    "    model.config.use_cache = use_cache"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "448fd6f6",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-26T11:53:30.497357Z",
     "iopub.status.busy": "2023-12-26T11:53:30.496880Z",
     "iopub.status.idle": "2023-12-26T11:56:24.045683Z",
     "shell.execute_reply": "2023-12-26T11:56:24.044885Z"
    },
    "papermill": {
     "duration": 173.580751,
     "end_time": "2023-12-26T11:56:24.048044",
     "exception": false,
     "start_time": "2023-12-26T11:53:30.467293",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tok 128000 128001\n",
      "wikitext2\n",
      "Evaluating ...\n",
      "0\n",
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "8\n",
      "9\n",
      "10\n",
      "11\n",
      "12\n",
      "13\n",
      "14\n",
      "15\n",
      "16\n",
      "17\n",
      "18\n",
      "19\n",
      "20\n",
      "21\n",
      "22\n",
      "23\n",
      "24\n",
      "25\n",
      "26\n",
      "27\n",
      "28\n",
      "29\n",
      "30\n",
      "31\n",
      "Perplexity: 9.049220\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the current `model` on each listed dataset and print its perplexity.\n",
    "model.gradient_checkpointing_disable()\n",
    "model.eval()\n",
    "\n",
    "for dataset in ['wikitext2']:\n",
    "    # NOTE: this rebinds the notebook-level `dataloader` name.\n",
    "    dataloader, testloader = get_loaders(\n",
    "        dataset, seed=0, model=model_name, seqlen=model.seqlen\n",
    "    )\n",
    "    print(dataset)\n",
    "    opt_eval(model, testloader, DEV, dataset, False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "1397c75c",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), \"/mnt/nvme/llamapush/llama3-8B-dsf1bit-20-ft.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c5e695e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 1090.722145,
   "end_time": "2023-12-26T11:56:25.741463",
   "environment_variables": {},
   "exception": null,
   "input_path": "Llama prune-params-WU.ipynb",
   "output_path": "outputs/Llama-0.6-per_layer-admm-20-iterp15WU.ipynb",
   "parameters": {
    "iterative_prune": 15,
    "iters": 20,
    "sparsity": 0.6
   },
   "start_time": "2023-12-26T11:38:15.019318",
   "version": "2.5.0"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {
     "06d87db571ab4d88b82f52214601bb42": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "FloatProgressModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "FloatProgressModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "ProgressView",
       "bar_style": "success",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_b7ceac1ff43d4e279a2e938bbe2a9758",
       "max": 33,
       "min": 0,
       "orientation": "horizontal",
       "style": "IPY_MODEL_b25d62f1cc0c4f73986ede9999476992",
       "tabbable": null,
       "tooltip": null,
       "value": 33
      }
     },
     "10900902714f449e933af1fbb5b4c688": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "4d15d5f8d6c74e0a8b2a6a56df56d8f2": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "62690d5acae24b38b50cb74283306762": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_9c30a1d8ffe74cfdb177e547c15b771d",
       "placeholder": "​",
       "style": "IPY_MODEL_b64151aeaf904f719d868882ef9e600f",
       "tabbable": null,
       "tooltip": null,
       "value": "Loading checkpoint shards: 100%"
      }
     },
     "84e75e03350d40dd8bf62fdd06a61df7": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HBoxModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HBoxModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HBoxView",
       "box_style": "",
       "children": [
        "IPY_MODEL_62690d5acae24b38b50cb74283306762",
        "IPY_MODEL_06d87db571ab4d88b82f52214601bb42",
        "IPY_MODEL_e6f6ce88afbc4f6199222b97e1218d81"
       ],
       "layout": "IPY_MODEL_c4d01b52ee784af28749c6928e95c0dd",
       "tabbable": null,
       "tooltip": null
      }
     },
     "9c30a1d8ffe74cfdb177e547c15b771d": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "b25d62f1cc0c4f73986ede9999476992": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "ProgressStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "ProgressStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "bar_color": null,
       "description_width": ""
      }
     },
     "b64151aeaf904f719d868882ef9e600f": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLStyleModel",
      "state": {
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLStyleModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "StyleView",
       "background": null,
       "description_width": "",
       "font_size": null,
       "text_color": null
      }
     },
     "b7ceac1ff43d4e279a2e938bbe2a9758": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "c4d01b52ee784af28749c6928e95c0dd": {
      "model_module": "@jupyter-widgets/base",
      "model_module_version": "2.0.0",
      "model_name": "LayoutModel",
      "state": {
       "_model_module": "@jupyter-widgets/base",
       "_model_module_version": "2.0.0",
       "_model_name": "LayoutModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/base",
       "_view_module_version": "2.0.0",
       "_view_name": "LayoutView",
       "align_content": null,
       "align_items": null,
       "align_self": null,
       "border_bottom": null,
       "border_left": null,
       "border_right": null,
       "border_top": null,
       "bottom": null,
       "display": null,
       "flex": null,
       "flex_flow": null,
       "grid_area": null,
       "grid_auto_columns": null,
       "grid_auto_flow": null,
       "grid_auto_rows": null,
       "grid_column": null,
       "grid_gap": null,
       "grid_row": null,
       "grid_template_areas": null,
       "grid_template_columns": null,
       "grid_template_rows": null,
       "height": null,
       "justify_content": null,
       "justify_items": null,
       "left": null,
       "margin": null,
       "max_height": null,
       "max_width": null,
       "min_height": null,
       "min_width": null,
       "object_fit": null,
       "object_position": null,
       "order": null,
       "overflow": null,
       "padding": null,
       "right": null,
       "top": null,
       "visibility": null,
       "width": null
      }
     },
     "e6f6ce88afbc4f6199222b97e1218d81": {
      "model_module": "@jupyter-widgets/controls",
      "model_module_version": "2.0.0",
      "model_name": "HTMLModel",
      "state": {
       "_dom_classes": [],
       "_model_module": "@jupyter-widgets/controls",
       "_model_module_version": "2.0.0",
       "_model_name": "HTMLModel",
       "_view_count": null,
       "_view_module": "@jupyter-widgets/controls",
       "_view_module_version": "2.0.0",
       "_view_name": "HTMLView",
       "description": "",
       "description_allow_html": false,
       "layout": "IPY_MODEL_10900902714f449e933af1fbb5b4c688",
       "placeholder": "​",
       "style": "IPY_MODEL_4d15d5f8d6c74e0a8b2a6a56df56d8f2",
       "tabbable": null,
       "tooltip": null,
       "value": " 33/33 [00:18&lt;00:00,  1.72it/s]"
      }
     }
    },
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
