{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "020d480e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: CUDA_VISIBLE_DEVICES=8\n",
      "env: CUDA_LAUNCH_BLOCKING=1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/nfs/scistore19/alistgrp/apanfero/micromamba/envs/llmb/lib/python3.10/site-packages/torch/cuda/__init__.py:799: UserWarning: Can't initialize NVML\n",
      "  warnings.warn(\"Can't initialize NVML\")\n"
     ]
    }
   ],
   "source": [
    "%env CUDA_VISIBLE_DEVICES=8\n",
    "%env CUDA_LAUNCH_BLOCKING=1\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"../src\")\n",
    "\n",
    "import torch\n",
    "from models.quantization.quantizers import QuestMXFP4Quantizer, AlbertTsengQuantizer, EdenSRQuantizer, IsolatedEdenQuantizer, QuestNvfp4Quantizer, Nvfp4Quantizer\n",
    "\n",
    "from tqdm.auto import trange, tqdm\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "741673de",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/nfs/scistore19/alistgrp/apanfero/micromamba/envs/llmb/lib/python3.10/site-packages/torch/cuda/__init__.py:799: UserWarning: Can't initialize NVML\n",
      "  warnings.warn(\"Can't initialize NVML\")\n",
      "/nfs/scistore19/alistgrp/apanfero/micromamba/envs/llmb/lib/python3.10/site-packages/torch/cuda/__init__.py:799: UserWarning: Can't initialize NVML\n",
      "  warnings.warn(\"Can't initialize NVML\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| Scales DType | Group Size | Unbiased | MSE        | Rate, bits | Magnitude Misalignment |\n",
      "|--------------|------------|----------|------------|------------|------------------------|\n",
      "| E4M3         | 16         | NO       |        9.0 |       3.40 |                0.00792 |\n",
      "| E4M3         | 16         | SR       |       23.5 |       2.71 |                0.00000 |\n",
      "| E4M3         | 16         | EDEN     |        9.8 |       3.34 |                0.00056 |\n"
     ]
    }
   ],
   "source": [
    "# Synthetic input: 2**20 rows of 128 Gaussian samples; row i is multiplied by a\n",
    "# factor log-spaced between 2**0 and 2**8, then everything is flattened.\n",
    "x = (\n",
    "    torch.randn(2**20, 128, device=\"cuda\")\n",
    "    * torch.logspace(0, 8, 2**20, base=2, device=\"cuda\").unsqueeze(1)\n",
    ").flatten()\n",
    "\n",
    "# (scale dtype, group size) configurations to evaluate.\n",
    "scale_dtype_group = [\n",
    "    # (\"fp32\", 128),\n",
    "    (\"e4m3\", 16),\n",
    "    # (\"e8m0\", 32),\n",
    "]\n",
    "\n",
    "# scale_override handed to the quantizer, keyed by rounding mode, then by scale dtype.\n",
    "optimal_scale_override = {\n",
    "    \"eden\": {\"fp32\": 0.96, \"e4m3\": 0.93, \"e8m0\": 0.91},\n",
    "    \"no\": {\"fp32\": 0.96, \"e4m3\": 0.93, \"e8m0\": 0.84},\n",
    "    \"sr\": {\"fp32\": 1.00, \"e4m3\": 1.00, \"e8m0\": 1.00},\n",
    "}\n",
    "\n",
    "table_rows = []  # one tuple per configuration, in evaluation order\n",
    "data = {}  # (group_dim, scale_dtype, unbiased) -> (rate in bits, misalignment)\n",
    "\n",
    "for scale_dtype, group_dim in scale_dtype_group:\n",
    "    for unbiased in (\"no\", \"sr\", \"eden\"):\n",
    "        quantizer = EdenSRQuantizer(\n",
    "            hadamard_dim=128,\n",
    "            group_dim=group_dim,\n",
    "            scale_dtype=scale_dtype,\n",
    "            unbiased=unbiased,\n",
    "            scale_override=optimal_scale_override[unbiased][scale_dtype],\n",
    "            four_over_six=False,\n",
    "        )\n",
    "\n",
    "        # The quantizer output is multiplied block-wise by the quantizer's\n",
    "        # hadamard_matrix and compared with the input viewed in the same blocks.\n",
    "        ref = x.view(-1, quantizer.hadamard_dim)\n",
    "        dq = quantizer(x).view(-1, quantizer.hadamard_dim) @ quantizer.hadamard_matrix\n",
    "\n",
    "        # Mean over blocks of the relative squared error ||ref - dq||^2 / ||ref||^2.\n",
    "        quad_err = ((ref - dq).pow(2).sum(dim=-1) / ref.pow(2).sum(dim=-1)).mean()\n",
    "        # The same error expressed in bits: -log2(err) / 2.\n",
    "        eff_bitwidth = (-torch.log2(quad_err) / 2).item()\n",
    "        # Mean over blocks of <ref, dq> / <ref, ref>; reported as its distance from 1.\n",
    "        magnitude_alignment = ((ref * dq).sum(dim=-1) / (ref * ref).sum(dim=-1)).mean().item()\n",
    "        misalignment = 1 - magnitude_alignment\n",
    "\n",
    "        data[(group_dim, scale_dtype, unbiased)] = (eff_bitwidth, misalignment)\n",
    "        table_rows.append((scale_dtype, group_dim, unbiased, quad_err, eff_bitwidth, misalignment))\n",
    "\n",
    "# Print markdown table (the MSE column is printed multiplied by 1e3)\n",
    "print(\"| Scales DType | Group Size | Unbiased | MSE        | Rate, bits | Magnitude Misalignment |\")\n",
    "print(\"|--------------|------------|----------|------------|------------|------------------------|\")\n",
    "for scale_dtype, group_dim, unbiased, quad_err, eff_bitwidth, misalignment in table_rows:\n",
    "    print(f\"| {scale_dtype.upper():<12} | {str(group_dim):<10} | {unbiased.upper():<8} | {quad_err * 1e3:>10.1f} | {eff_bitwidth:>10.2f} | {misalignment:>22.5f} |\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "064f8940",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "four_over_six=False,square=False: 9.0 3.394\n",
      "four_over_six=True,square=False: 7.6 3.524\n",
      "four_over_six=False,square=True: 12.4 3.167\n",
      "four_over_six=True,square=True: 12.4 3.168\n"
     ]
    }
   ],
   "source": [
    "# Relative quantization error of Nvfp4Quantizer on a Gaussian matrix, for every\n",
    "# combination of its `square` and `four_over_six` options.\n",
    "x = torch.randn((8192, 8192), device=\"cuda\")\n",
    "\n",
    "for square in [False, True]:\n",
    "    for four_over_six in [False, True]:\n",
    "        quantizer = Nvfp4Quantizer(square=square, four_over_six=four_over_six)\n",
    "        dq = quantizer(x)\n",
    "        # Per-row relative squared error ||x - dq||^2 / ||x||^2, averaged over rows.\n",
    "        quad_err = ((x - dq).pow(2).sum(dim=-1) / x.pow(2).sum(dim=-1)).mean()\n",
    "        # The same error expressed in bits: -log2(err) / 2.\n",
    "        eff_bitwidth = (-torch.log2(quad_err) / 2).item()\n",
    "\n",
    "        # Printed columns: relative error multiplied by 1e3, then the rate in bits.\n",
    "        print(f\"{four_over_six=},{square=}: {quad_err*1e3:.1f} {eff_bitwidth:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7e707a16",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "994ed770ce0f49ce88015e9e07494ecd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "90ce2e99cc2d471eb5db5d4a15fb6836",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1: 35.84 2.401\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "04725f5d44f64de2b716685d6980d52f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4: 9.11 3.389\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ad0abf774cb94ec3b790bf2132331867",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/16 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "16: 2.44 4.339\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9f90ac5ea40b4823bfa7d2f2c4b54fc7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/64 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "64: 0.77 5.170\n"
     ]
    }
   ],
   "source": [
    "# Error of a product of two quantized matrices when the product is averaged\n",
    "# over several passes, re-randomizing the quantizer before each pass.\n",
    "# Relies on `optimal_scale_override` defined in an earlier cell.\n",
    "x = torch.randn((4096, 4096), device=\"cuda\")\n",
    "y = torch.randn((4096, 4096), device=\"cuda\")\n",
    "\n",
    "# The exact product is the same for every acc_steps value, so compute it once\n",
    "# instead of twice per outer iteration.\n",
    "ref_prod = x @ y.T\n",
    "ref_energy = ref_prod.pow(2).mean()\n",
    "\n",
    "unbiased = \"sr\"\n",
    "\n",
    "quantizer = EdenSRQuantizer(\n",
    "    hadamard_dim=128,\n",
    "    group_dim=16,\n",
    "    scale_dtype=\"e4m3\",\n",
    "    unbiased=unbiased,\n",
    "    scale_override=optimal_scale_override[unbiased][\"e4m3\"],\n",
    "    rerotate='signs',\n",
    "    four_over_six=True,\n",
    ")\n",
    "\n",
    "for acc_steps in tqdm([1, 4, 16, 64]):\n",
    "    acc_prod = torch.zeros((x.shape[0], y.shape[0]), device=\"cuda\")\n",
    "    for step in trange(acc_steps, leave=False):\n",
    "        # x and y are both quantized after a single re_randomize() call.\n",
    "        # NOTE(review): presumably this redraws the random rotation shared by\n",
    "        # the two operands - confirm against EdenSRQuantizer.\n",
    "        quantizer.re_randomize()\n",
    "        xq = quantizer(x)\n",
    "        yq = quantizer(y)\n",
    "        acc_prod += xq @ yq.T\n",
    "\n",
    "    # Relative squared error of the averaged product, and the same value in bits.\n",
    "    quad_err = (acc_prod / acc_steps - ref_prod).pow(2).mean() / ref_energy\n",
    "    eff_bitwidth = (-torch.log2(quad_err) / 2).item()\n",
    "    print(f\"{acc_steps}: {quad_err * 1e3:.2f} {eff_bitwidth:.3f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c039f1cc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d9922e1c889d471cb1201a427f80ad78",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/6 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5e502c1c5afa490aa885346ad2c3a98f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1: 2.197\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3725c4c2c97648618617110198e5a330",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4: 3.196\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "10919c1f3e8742dfbc5fa53da3ef83df",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/16 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "16: 4.196\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "33c13c5d2ee146c29465d8f8640e58ec",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/64 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "64: 5.197\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d13baf656cf745d79ce32916150c3f5d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/256 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "256: 6.197\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d3ffc78fda684e3ea02c7859b33eba00",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1024 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1024: 7.196\n"
     ]
    }
   ],
   "source": [
    "from models.quantization.quantizers.nvfp4_triton import sr_1x16s_fp4_kernel_wrapper\n",
    "\n",
    "# Error of an averaged product of two matrices quantized by calling the Triton\n",
    "# wrapper directly; this cell does not need any quantizer object from earlier cells.\n",
    "x = torch.randn((4096, 4096), device=\"cuda\")\n",
    "y = torch.randn((4096, 4096), device=\"cuda\")\n",
    "\n",
    "# The exact product is the same for every acc_steps value, so compute it once\n",
    "# instead of twice per outer iteration.\n",
    "ref_prod = x @ y.T\n",
    "ref_energy = ref_prod.pow(2).mean()\n",
    "\n",
    "for acc_steps in tqdm([1, 4, 16, 64, 256, 1024]):\n",
    "    acc_prod = torch.zeros((x.shape[0], y.shape[0]), device=\"cuda\")\n",
    "    for step in trange(acc_steps, leave=False):\n",
    "        # NOTE(review): the meaning of the positional arguments (17/16, 16, False)\n",
    "        # is defined by the wrapper's signature - confirm in nvfp4_triton.\n",
    "        xq = sr_1x16s_fp4_kernel_wrapper(x, 17/16, 16, False)\n",
    "        yq = sr_1x16s_fp4_kernel_wrapper(y, 17/16, 16, False)\n",
    "        acc_prod += xq @ yq.T\n",
    "\n",
    "    # Relative squared error of the averaged product, reported in bits.\n",
    "    quad_err = (acc_prod / acc_steps - ref_prod).pow(2).mean() / ref_energy\n",
    "    eff_bitwidth = (-torch.log2(quad_err) / 2).item()\n",
    "    print(f\"{acc_steps}: {eff_bitwidth:.3f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "eeeb0269",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "502881a97a0441d096d36a33923865a8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f75a8e9797804813abeb9fd69bb70ea7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1:\n",
      "\tx: 2.847 bits\n",
      "\tw: 2.834 bits\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "243d6238b44d4d279168da47120e0deb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4:\n",
      "\tx: 3.839 bits\n",
      "\tw: 3.848 bits\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2f1d4dca9e574215b8ad8d926fda18ed",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/16 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "16:\n",
      "\tx: 4.843 bits\n",
      "\tw: 4.828 bits\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "186351720f2743b6b3652040392a4d8f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/64 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "64:\n",
      "\tx: 5.840 bits\n",
      "\tw: 5.797 bits\n"
     ]
    }
   ],
   "source": [
    "from models.quantization.schemes.tetrajetv2 import TetraJetV2Linear, TetraJetV2_fn\n",
    "from models.quantization.schemes.quartet_2 import Quartet_II_Linear\n",
    "\n",
    "\n",
    "def get_loss(x, linear, head, target):\n",
    "    \"\"\"Return the mean squared difference between head(linear(x)) and target.\"\"\"\n",
    "    prediction = head(linear(x))\n",
    "    return (prediction - target).pow(2).mean()\n",
    "\n",
    "\n",
    "def grad_eff_bitwidth(acc_grad, ref_grad, acc_steps):\n",
    "    \"\"\"Return -log2(e) / 2, where e is the relative squared error of acc_grad / acc_steps against ref_grad.\"\"\"\n",
    "    quad_err = (acc_grad / acc_steps - ref_grad).pow(2).mean() / ref_grad.pow(2).mean()\n",
    "    return (-torch.log2(quad_err) / 2).item()\n",
    "\n",
    "\n",
    "x = torch.randn((3, 128, 4096), device=\"cuda\", requires_grad=True)\n",
    "\n",
    "# linear = TetraJetV2Linear(4096, 1024, device=\"cuda\", dtype=torch.float32, bias=False, disable_forward_quant=True)\n",
    "linear = Quartet_II_Linear(\n",
    "    4096,\n",
    "    1024,\n",
    "    device=\"cuda\",\n",
    "    dtype=torch.float32,\n",
    "    bias=False,\n",
    "    disable_forward_quant=True,\n",
    "    hadamard_dim=128,\n",
    ")\n",
    "\n",
    "head = torch.nn.Linear(1024, 1, device=\"cuda\")\n",
    "target = torch.randn(3, 128, 1, device=\"cuda\")\n",
    "\n",
    "# Reference gradients: a single backward pass with disable_backward_quant set.\n",
    "linear.disable_backward_quant = True\n",
    "x.grad = None\n",
    "linear.weight.grad = None\n",
    "get_loss(x, linear, head, target).backward()\n",
    "ref_x_grad = x.grad.detach().clone()\n",
    "ref_w_grad = linear.weight.grad.detach().clone()\n",
    "\n",
    "# Measured gradients: .grad accumulates over acc_steps backward passes through\n",
    "# one retained graph, and the accumulated sum is divided by acc_steps.\n",
    "linear.disable_backward_quant = False\n",
    "for acc_steps in tqdm([1, 4, 16, 64]):\n",
    "    x.grad = None\n",
    "    linear.weight.grad = None\n",
    "\n",
    "    loss = get_loss(x, linear, head, target)\n",
    "    for step in trange(acc_steps, leave=False):\n",
    "        loss.backward(retain_graph=True)\n",
    "\n",
    "    x_eff_bitwidth = grad_eff_bitwidth(x.grad, ref_x_grad, acc_steps)\n",
    "    w_eff_bitwidth = grad_eff_bitwidth(linear.weight.grad, ref_w_grad, acc_steps)\n",
    "    print(f\"{acc_steps}:\\n\\tx: {x_eff_bitwidth:.3f} bits\")\n",
    "    print(f\"\\tw: {w_eff_bitwidth:.3f} bits\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "787e56a4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
