{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (anonymous)\n",
    "\n",
    "# Usage\n",
    "\n",
    "A quick guide to `weight_formats` library usage."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: TOKENIZERS_PARALLELISM=true\n"
     ]
    }
   ],
   "source": [
    "%env TOKENIZERS_PARALLELISM=true\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import torch\n",
    "\n",
    "# By convention, we use `import weight_formats.quantisation as Q`, etc.\n",
    "import weight_formats.analysis as A\n",
    "import weight_formats.experiments as E\n",
    "import weight_formats.experiments.token_prediction as ET\n",
    "import weight_formats.fit as F\n",
    "import weight_formats.model_quantisation as M\n",
    "import weight_formats.quantisation as Q\n",
    "import weight_formats.sensitivity as S\n",
    "\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load a model and dataset (with calculation of reference output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the model, passing the selected device and bfloat16 as the dtype argument.\n",
    "model = E.RequantisableModel.load(\"meta-llama/Llama-3.2-1B\", DEVICE, torch.bfloat16)\n",
    "\n",
    "# Detached views of every 2D parameter, keyed by parameter name.\n",
    "params = {\n",
    "    name: tensor.detach()\n",
    "    for name, tensor in model.model.named_parameters()\n",
    "    if tensor.ndim == 2\n",
    "}\n",
    "\n",
    "# Small dataset for a quick test: 1 batch of shape (16, 256).\n",
    "data = ET.Dataset.load(\n",
    "    model.model,\n",
    "    batch_size=16,\n",
    "    sequence_limit=16,\n",
    "    sequence_length=256,\n",
    "    kl_topk=128,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Quantise a single parameter, and evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "b = 4.00\n",
      "R = 0.072\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'cross_entropy': tensor([2.6887, 2.6609, 2.9141, 3.2589, 2.1054, 2.9393, 1.9774, 3.2417, 2.9185,\n",
       "         2.8662, 3.1528, 3.0990, 2.6329, 3.1736, 2.8883, 3.0719]),\n",
       " 'kl_div': tensor([0.0011, 0.0011, 0.0011, 0.0012, 0.0010, 0.0010, 0.0012, 0.0013, 0.0012,\n",
       "         0.0011, 0.0012, 0.0013, 0.0013, 0.0011, 0.0011, 0.0012])}"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Reset the model (presumably restoring the unquantised weights - confirm in\n",
    "# `RequantisableModel.reset`), then pick one weight tensor to quantise.\n",
    "model.reset()\n",
    "weight = model.model.state_dict()[\"model.layers.15.self_attn.v_proj.weight\"]\n",
    "\n",
    "# Train the format on this tensor; the second argument is a quarter of its std.\n",
    "quarter_std = weight.std().item() / 4\n",
    "fmt = Q.CompressedLUTFormat.train_grid(weight, quarter_std)\n",
    "\n",
    "# Report bits per element and the normalised quantisation RMSE for this tensor.\n",
    "bits_per_element = fmt.count_bits_tensor(weight) / weight.nelement()\n",
    "print(f\"b = {bits_per_element:.2f}\")\n",
    "relative_rmse = Q.qrmse_norm(fmt, weight).cpu()\n",
    "print(f\"R = {relative_rmse:.3f}\")\n",
    "\n",
    "# Overwrite the tensor in place with its quantised values, then evaluate the model.\n",
    "weight[...] = fmt.quantise(weight)\n",
    "\n",
    "metrics = data.evaluate(model.model)\n",
    "display({name: values.cpu() for name, values in metrics.items()})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Quantise all parameters with `M.quantise_2d_fixed`, using an `F.Scaled` fit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "b = 4.25\n",
      "model.layers.0.mlp.gate_proj.weight: {'bits': 71303168, 'rmse': 0.0017555853119120002}\n"
     ]
    }
   ],
   "source": [
    "# This takes a few seconds, as it runs k-means per parameter tensor.\n",
    "model.reset()\n",
    "\n",
    "# Fit specification handed to the model-level quantiser.\n",
    "fit_spec = F.Scaled(\n",
    "    4,\n",
    "    \"lloyd_max\",\n",
    "    Q.BFLOAT16,\n",
    "    (1, 64),\n",
    "    \"absmax\",\n",
    "    compressor=None,\n",
    "    sparse_ratio=0,\n",
    ")\n",
    "log = M.quantise_2d_fixed(model.model, fit_spec)\n",
    "print(f\"b = {log['bits_per_param']:.2f}\")\n",
    "\n",
    "# Show the per-parameter log entry for one example tensor.\n",
    "example_name = \"model.layers.0.mlp.gate_proj.weight\"\n",
    "print(f\"{example_name}: {log['params'][example_name]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Run a mini-sweep"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:12<00:00, 12.02s/it]\n",
      "100%|██████████| 1/1 [00:12<00:00, 12.16s/it]\n"
     ]
    }
   ],
   "source": [
    "# Two tests that differ only in the last `F.Scaled` argument.\n",
    "tests = [\n",
    "    ET.QuantiseFixed(F.Scaled(4, \"int\", Q.BFLOAT16, (1, 64), scaling))\n",
    "    for scaling in (\"absmax\", \"signmax\")\n",
    "]\n",
    "\n",
    "# One run per test, all recorded under the \"dev\" experiment name.\n",
    "sweep = [ET.Run(\"dev\", test, \"meta-llama/Llama-3.2-1B\") for test in tests]\n",
    "ET.run_sweep(sweep)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "14 runs\n",
      "\n",
      "### 4b-int{1,64:BFLOAT16:absmax:search} ###\n",
      "     b = 4.25\n",
      "  D_KL = 0.196\n",
      "\n",
      "### 4b-int{1,64:BFLOAT16:signmax:search} ###\n",
      "     b = 4.25\n",
      "  D_KL = 0.163\n",
      "\n"
     ]
    }
   ],
   "source": [
    "runs = E.runs(\"dev\")\n",
    "print(f\"{len(runs)} runs\")\n",
    "print()\n",
    "\n",
    "# Summarise the last two entries of `runs`: format string, bits per parameter\n",
    "# and the mean of the recorded KL divergences.\n",
    "for run in runs[-2:]:\n",
    "    mean_kl = torch.tensor(run.summary.kl_div).mean().item()\n",
    "    report = [\n",
    "        f\"### {run.config.test.fmt_str} ###\",\n",
    "        f\"     b = {run.summary.bits_per_param:.2f}\",\n",
    "        f\"  D_KL = {mean_kl:.3f}\",\n",
    "        \"\",\n",
    "    ]\n",
    "    print(*report, sep=\"\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
