{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d6f149d5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/anon/miniconda3/envs/neurips_2025/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Loaded /home/anon/miniconda3/envs/neurips_2025/lib/python3.12/site-packages/gemlite/configs/4090.json config.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<module 'utils' from '/home/anon/phd_master/neurips_2025/clean/utils.py'>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import utils\n",
    "from importlib import reload\n",
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "import tqdm\n",
    "from gemlite.core import GemLiteLinearTriton\n",
    "reload(utils)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bdb9d4a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory that holds the DBF checkpoint files listed in `experiments` below.\n",
    "base_checkpoint_path = \"../checkpoints\"\n",
    "\n",
    "# Set to False if you want to benchmark the base model without double binary factorization (DBF).\n",
    "dbf = True\n",
    "# Index into `experiments` selecting which (checkpoint, base model) pair to run.\n",
    "experiment_id = 0\n",
    "# Passed to utils.CustomGenerator and used as the token count in the tokens/s figure.\n",
    "LEN_TOKENS_FOR_GENERATION = 128\n",
    "# Number of timed generate() calls in the benchmark loop.\n",
    "REPETITIONS = 100\n",
    "\n",
    "# A Triton autotune config is provided. Loading it saves the time spent autotuning,\n",
    "# but the results may not be optimal for your GPU.\n",
    "load_triton_autotune_config = True\n",
    "\n",
    "# Each entry is (DBF checkpoint file name, base model name passed to the utils loaders).\n",
    "experiments = [\n",
    "    (\"latest_llama2-1b-10.pt\", \"meta-llama/Llama-2-7b-hf\"),\n",
    "    (\"latest_llama2-1b-15.pt\", \"meta-llama/Llama-2-7b-hf\"),\n",
    "    (\"latest_llama2-1b-20.pt\", \"meta-llama/Llama-2-7b-hf\"),\n",
    "    (\"latest_llama2-1b-23.pt\", \"meta-llama/Llama-2-7b-hf\"),\n",
    "    (\"latest_llama31b10.pt\", \"meta-llama/Meta-Llama-3-8B\"),\n",
    "    (\"latest_llama31b15.pt\", \"meta-llama/Meta-Llama-3-8B\"),\n",
    "    (\"latest_llama31b20.pt\", \"meta-llama/Meta-Llama-3-8B\"),\n",
    "    (\"latest_llama31b23.pt\", \"meta-llama/Meta-Llama-3-8B\"),\n",
    "]\n",
    "dbp_checkpoint_name, base_model_str = experiments[experiment_id]\n",
    "\n",
    "# Full path of the selected checkpoint file.\n",
    "checkpoint_path = os.path.join(base_checkpoint_path, dbp_checkpoint_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f0170e1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optionally reuse the provided Triton autotune config instead of autotuning from scratch.\n",
    "# The file name is resolved relative to the notebook's working directory.\n",
    "if load_triton_autotune_config:\n",
    "    GemLiteLinearTriton.load_config('test_config_full_eval.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3b96eb6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: `get_tokezizer` is spelled exactly as this notebook calls it in utils; do not \"fix\" it here alone.\n",
    "tokenizer = utils.get_tokezizer(base_model_str)\n",
    "# Tokenized empty prompt moved to the GPU; no later cell in this notebook references it.\n",
    "input_ids = tokenizer(\"\", return_tensors=\"pt\").to(\"cuda\")\n",
    "# The DBF checkpoint is only consumed when `dbf` is enabled, so skip the load for the\n",
    "# base-model benchmark: it then needs neither the checkpoint file nor the extra memory.\n",
    "dbf_weights = torch.load(checkpoint_path) if dbf else None\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "37c04d96",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|███████████████████████████████████████| 2/2 [00:02<00:00,  1.37s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1.66 s, sys: 1.28 s, total: 2.94 s\n",
      "Wall time: 3.39 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Build the base model for the selected experiment via the project helper;\n",
    "# %%time reports how long this step takes.\n",
    "model = utils.get_llama(base_model_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b70bde33",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<module 'utils' from '/home/anon/phd_master/neurips_2025/clean/utils.py'>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Re-execute utils.py so edits made to it since the first cell take effect without a kernel restart.\n",
    "reload(utils)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8c193b9c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "8\n",
      "9\n",
      "10\n",
      "11\n",
      "12\n",
      "13\n",
      "14\n",
      "15\n",
      "16\n",
      "17\n",
      "18\n",
      "19\n",
      "20\n",
      "21\n",
      "22\n",
      "23\n",
      "24\n",
      "25\n",
      "26\n",
      "27\n",
      "28\n",
      "29\n",
      "30\n",
      "31\n",
      "CPU times: user 2min 58s, sys: 49.3 s, total: 3min 47s\n",
      "Wall time: 3min 4s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Replace every attention/MLP projection of each decoder layer with a\n",
    "# utils.DoubleBinaryLinear built from the loaded DBF checkpoint.\n",
    "# Entries are \"<submodule>.<projection>\" paths relative to a decoder layer; the\n",
    "# same string is passed to DoubleBinaryLinear as the projection's name.\n",
    "DBF_PROJECTIONS = (\n",
    "    \"self_attn.q_proj\",\n",
    "    \"self_attn.k_proj\",\n",
    "    \"self_attn.v_proj\",\n",
    "    \"self_attn.o_proj\",\n",
    "    \"mlp.gate_proj\",\n",
    "    \"mlp.up_proj\",\n",
    "    \"mlp.down_proj\",\n",
    ")\n",
    "\n",
    "if dbf:\n",
    "    # Walk the model's own layer list instead of assuming exactly 32 layers.\n",
    "    for i, layer in enumerate(model.model.layers):\n",
    "        print(i)\n",
    "        for proj_path in DBF_PROJECTIONS:\n",
    "            parent_name, proj_name = proj_path.split(\".\")\n",
    "            # Attribute assignment on the parent module swaps the projection in place.\n",
    "            setattr(\n",
    "                getattr(layer, parent_name),\n",
    "                proj_name,\n",
    "                utils.DoubleBinaryLinear(dbf_weights, i, proj_path),\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f82d1f4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the model (DBF-patched when `dbf` is True) and tokenizer in the project's generator.\n",
    "# The third argument is presumably the number of tokens produced per generate() call -\n",
    "# confirm against utils.CustomGenerator.\n",
    "generator = utils.CustomGenerator(model, tokenizer, LEN_TOKENS_FOR_GENERATION)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a447ff7e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/anon/miniconda3/envs/neurips_2025/lib/python3.12/site-packages/torch/_inductor/compile_fx.py:194: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 54.6 s, sys: 16.6 s, total: 1min 11s\n",
      "Wall time: 1min 12s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Warm-up call: the first generate() is slow because it triggers compilation,\n",
    "# so it is run here, outside the timed benchmark loop further down.\n",
    "out = generator.generate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "022e2d2c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 643 ms, sys: 7.9 ms, total: 651 ms\n",
      "Wall time: 650 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Second call, after the warm-up above: %%time now gives a rough wall-clock\n",
    "# figure for one generate() without the one-off compilation cost.\n",
    "out = generator.generate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "38809af7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Return to Forest From Panels From Cages Trials Divilation 19.ng/ Library\n",
      "Home › Collections › Forestmah › Forest ›\n",
      "NEWSPAN_29-12-2012_PH4288.jpg (145.4 KB, 173 W) Add: 1 (1.06 MB) Click here to find the college's WGC page. Download\n",
      "Going full\n",
      "(94.25 MB, 15.99 MB) Added: 10 (9\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: show the most recent generate() output as text via the project helper.\n",
    "print(utils.to_text(out, tokenizer))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "964b1f12",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████| 100/100 [01:05<00:00,  1.54it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(np.float64(647.1263427734375),\n",
       " np.float64(650.03896484375),\n",
       " np.float64(656.0706176757812))"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Time REPETITIONS calls to generate() with CUDA events; one elapsed time (ms) per call.\n",
    "t_milis = []\n",
    "for _ in tqdm.tqdm(range(REPETITIONS)):\n",
    "    tic = torch.cuda.Event(enable_timing=True)\n",
    "    toc = torch.cuda.Event(enable_timing=True)\n",
    "\n",
    "    tic.record()\n",
    "    out = generator.generate()\n",
    "    toc.record()\n",
    "\n",
    "    # Wait for the GPU to finish so both events have completed before reading the timing.\n",
    "    torch.cuda.synchronize()\n",
    "    t_milis.append(tic.elapsed_time(toc))\n",
    "\n",
    "# (fastest run, mean of the last five runs, slowest run), all in milliseconds\n",
    "np.min(t_milis), np.mean(t_milis[-5:]), np.max(t_milis)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e9ebac21",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Inference time 650.03896484375 milis\n",
      "Tokens/s = 196.91127289695223\n"
     ]
    }
   ],
   "source": [
    "# Summarise the benchmark: mean latency (ms) of the last five timed runs and the\n",
    "# throughput it implies for LEN_TOKENS_FOR_GENERATION tokens per run.\n",
    "inference_time = np.mean(t_milis[-5:])\n",
    "tokens_per_second = LEN_TOKENS_FOR_GENERATION/(inference_time/1000)\n",
    "print(f\"Inference time {inference_time} milis\")\n",
    "print(f\"Tokens/s = {tokens_per_second}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3db68b6d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
