{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "41bcf8c1",
   "metadata": {},
   "source": [
    "# Toy Examples\n",
    "\n",
    "This notebook builds a small synthetic linear-layer quantization problem.\n",
    "The objective is to replace $\\boldsymbol{W}$ by a quantized matrix $\\boldsymbol{Q}$ while keeping the layer outputs $\\boldsymbol{X}\\boldsymbol{Q}$ close to $\\boldsymbol{X}\\boldsymbol{W}$ under squared error.\n",
    "\n",
    "Please make sure `ipykernel` and `torch` are installed before executing this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "62a48395",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3787da5",
   "metadata": {},
   "source": [
    "We fix the random seed for reproducibility and choose the execution device and floating-point precision."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e6db0e67",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(seed=0)\n",
    "device: torch.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "dtype: torch.dtype = torch.float64"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a586831b",
   "metadata": {},
   "source": [
    "These variables define the synthetic problem size: number of calibration samples, input dimension, and output dimension."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "33d9d616",
   "metadata": {},
   "outputs": [],
   "source": [
    "n, c, r = 1024, 128, 32  # batch size, input dimension, output dimension\n",
    "assert n >= c"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c64076c",
   "metadata": {},
   "source": [
    "We generate random weights $\\boldsymbol{W}$, calibration activations $\\boldsymbol{X}$, quantization scales $\\boldsymbol{S}$, and a small damping term $\\lambda$ for the Hessian-based methods."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "86393738",
   "metadata": {},
   "outputs": [],
   "source": [
    "W: torch.Tensor = torch.randn(c, r, dtype=dtype, device=device)  # (c, r)\n",
    "X: torch.Tensor = torch.randn(n, c, dtype=dtype, device=device) @ torch.randn(c, dtype=dtype, device=device).diag_embed() @ torch.randn(c, c, dtype=dtype, device=device)  # (n, c)\n",
    "S: torch.Tensor = torch.randn_like(W).abs().clamp(min=1e-2) * (torch.randint_like(W, low=0, high=2) * 2. - 1.) * 1e-1  # (c, r)\n",
    "λ: float = (X ** 2.).sum().item() / (1e2 * c)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e2c9320",
   "metadata": {},
   "source": [
    "This function computes the squared output error between the original and quantized linear layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "41dd2581",
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss(\n",
    "        Q: torch.Tensor,  # (c, r)\n",
    "        W: torch.Tensor,  # (c, r)\n",
    "        X: torch.Tensor,  # (n, c)\n",
    ") -> float:\n",
    "    return ((X @ Q - X @ W) ** 2.).sum().item()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57c97151",
   "metadata": {},
   "source": [
    "The functions below compare three quantization procedures on the same problem.\n",
    "RTN rounds each weight independently, while GPTQ and Babai use the Hessian induced by X to update the remaining elements after each rounding step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b89e270e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def RTN(\n",
    "        W: torch.Tensor,  # (c, r)\n",
    "        S: torch.Tensor,  # (c, r)\n",
    "        ℤ_min: float = -torch.inf,\n",
    "        ℤ_max: float = torch.inf,\n",
    ") -> tuple[torch.Tensor, torch.Tensor]:\n",
    "    \"\"\"Round-to-Nearest\"\"\"\n",
    "\n",
    "    Z = (W / S).clamp(min=ℤ_min, max=ℤ_max).round()  # (c, r)\n",
    "    Q = Z * S  # (c, r)\n",
    "\n",
    "    return Z, Q  # (c, r), (c, r)\n",
    "\n",
    "\n",
    "def GPTQ(\n",
    "        W: torch.Tensor,  # (c, r)\n",
    "        S: torch.Tensor,  # (c, r)\n",
    "        X: torch.Tensor,  # (n, c)\n",
    "        P: torch.Tensor,  # (c, c)\n",
    "        λ: float = 0.,\n",
    "        ℤ_min: float = -torch.inf,\n",
    "        ℤ_max: float = torch.inf,\n",
    ") -> tuple[torch.Tensor, torch.Tensor]:\n",
    "    \"\"\"Algorithm 1\"\"\"\n",
    "\n",
    "    dtype, device = W.dtype, W.device\n",
    "    c, r = W.shape\n",
    "    I = torch.eye(c, dtype=dtype, device=device)  # (c, c)\n",
    "\n",
    "    H = P.t() @ (X.t() @ X + λ * I) @ P  # (c, c)\n",
    "    L = torch.linalg.cholesky(H.inverse()); L = L / L.diagonal()  # (c, c)\n",
    "    W, S = P.inverse() @ W, P.inverse() @ S  # (c, r), (c, r)\n",
    "    Q, Z = W.clone(), torch.zeros_like(W)  # (c, r), (c, r)\n",
    "    for j in range(c):\n",
    "        ζ = W[j, :] / S[j, :]  # (r)\n",
    "        Z[j, :] = ζ.clamp(min=ℤ_min, max=ℤ_max).round()  # (r)\n",
    "        Q[j, :] = Z[j, :] * S[j, :]  # (r)\n",
    "        ε = Q[j, :] - W[j, :]  # (r)\n",
    "        W[j:, :] = W[j:, :] + L[j:, j][:, None] @ ε[None, :]  # (c - j, r)\n",
    "    Z, Q = P @ Z, P @ Q  # (c, r), (c, r)\n",
    "\n",
    "    return Z, Q  # (c, r), (c, r)\n",
    "\n",
    "\n",
    "def Babais_Quantize(\n",
    "        W: torch.Tensor,  # (c, r)\n",
    "        S: torch.Tensor,  # (c, r)\n",
    "        X: torch.Tensor,  # (n, c)\n",
    "        T: torch.Tensor,  # (c, c)\n",
    "        λ: float = 0.,\n",
    "        ℤ_min: float = -torch.inf,\n",
    "        ℤ_max: float = torch.inf,\n",
    ") -> tuple[torch.Tensor, torch.Tensor]:\n",
    "    \"\"\"Algorithm 4\"\"\"\n",
    "\n",
    "    dtype, device = W.dtype, W.device\n",
    "    c, r = W.shape\n",
    "    I = torch.eye(c, dtype=dtype, device=device)  # (c, c)\n",
    "\n",
    "    H = T.t() @ (X.t() @ X + λ * I) @ T  # (c, c)\n",
    "    A = torch.linalg.cholesky(H).t()  # (c, c)\n",
    "    W, S = T.inverse() @ W, T.inverse() @ S  # (c, r), (c, r)\n",
    "    Y, Q, Z = A @ W, W.clone(), torch.zeros_like(W)  # (c, r), (c, r), (c, r)\n",
    "    for j in range(c - 1, -1, -1):\n",
    "        ω = Y[j, :] / A[j, j]  # (r)\n",
    "        ζ = ω / S[j, :]  # (r)\n",
    "        Z[j, :] = ζ.clamp(min=ℤ_min, max=ℤ_max).round()  # (r)\n",
    "        Q[j, :] = Z[j, :] * S[j, :]  # (r)\n",
    "        Y = Y - A[:, j][:, None] @ Q[j, :][None, :]  # (c, r)\n",
    "    Z, Q = T @ Z, T @ Q  # (c, r), (c, r)\n",
    "\n",
    "    return Z, Q  # (c, r), (c, r)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a066f0e9",
   "metadata": {},
   "source": [
    "This function computes the no-clipping error bound associated with the chosen quantization order.\n",
    "It will be used to compare the observed GPTQ loss against the theoretical worst-case bound and the average-case reference (one-third of worst-case bound)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c367f200",
   "metadata": {},
   "outputs": [],
   "source": [
    "def GPTQ_Error_Bound(\n",
    "        S: torch.Tensor,  # (c, r)\n",
    "        X: torch.Tensor,  # (n, c)\n",
    "        T: torch.Tensor,  # (c, c)\n",
    "        λ: float = 0.,\n",
    ") -> float:\n",
    "    \"\"\"Theorem 5\"\"\"\n",
    "\n",
    "    dtype, device = W.dtype, W.device\n",
    "    c, r = W.shape\n",
    "    I = torch.eye(c, dtype=dtype, device=device)  # (c, c)\n",
    "\n",
    "    H = T.t() @ (X.t() @ X + λ * I) @ T  # (c, c)\n",
    "    D = (torch.linalg.cholesky(H).diagonal() ** 2.).diag_embed()  # (c, c)\n",
    "\n",
    "    bound = ((T.inverse() @ S).t()[:, None, :] @ D @ (T.inverse() @ S).t()[:, :, None]).sum().item() / 4.\n",
    "    return bound"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47b2f3a1",
   "metadata": {},
   "source": [
    "This section checks the main theoretical claims on a random example.\n",
    "It verifies that Babai's algorithm matches GPTQ under reversed ordering, shows that these methods usually improve over RTN, and confirms that the no-clipping loss does not exceed the theoretical worst-case bound."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a01473a1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Basic Verifications\n",
      "Therorem 4: Babai's Quantize (Algorithm 4) is identical to GPTQ (Algorithm 1) with reversed ordering.\n",
      "Babai loss (ℓ_babai=1.9725e+08, ℓ_babai_no_clip=1.2901e+05) is usually less than RTN loss (ℓ_rtn=2.4883e+08, ℓ_rtn_no_clip=4.8143e+05).\n",
      "Therorem 5: GPTQ loss without clipping (ℓ_babai_no_clip=1.2901e+05) is similar to the theoretical average case bound (average_case_no_clip_bound=1.3965e+05) and not larger than the worst case bound (worst_case_no_clip_bound=4.1894e+05)\n"
     ]
    }
   ],
   "source": [
    "def basic_verifications() -> None:\n",
    "    print('Basic Verifications')\n",
    "\n",
    "    ℤ_min, ℤ_max = -8., 7.\n",
    "    order_random: torch.Tensor = torch.randperm(c, device=device)  # (c,)\n",
    "    P: torch.Tensor = torch.eye(c, dtype=dtype, device=device)[order_random, :]  # (c, c)\n",
    "    T: torch.Tensor = P.flip(dims=(1,))  # (c, c)\n",
    "\n",
    "    Z_rtn, Q_rtn = RTN(W, S, ℤ_min, ℤ_max)  # (c, r), (c, r)\n",
    "    Z_rtn_no_clip, Q_rtn_no_clip = RTN(W, S)  # (c, r), (c, r)\n",
    "    Z_gptq, Q_gptq = GPTQ(W, S, X, P, λ, ℤ_min, ℤ_max)  # (c, r), (c, r)\n",
    "    Z_gptq_no_clip, Q_gptq_no_clip = GPTQ(W, S, X, P, λ)  # (c, r), (c, r)\n",
    "    Z_babai, Q_babai = Babais_Quantize(W, S, X, T, λ, ℤ_min, ℤ_max)  # (c, r), (c, r)\n",
    "    Z_babai_no_clip, Q_babai_no_clip = Babais_Quantize(W, S, X, T, λ)  # (c, r), (c, r)\n",
    "\n",
    "    assert Q_babai.equal(Q_gptq) and Z_babai.equal(Z_gptq)\n",
    "    assert Q_babai_no_clip.equal(Q_gptq_no_clip) and Z_babai_no_clip.equal(Z_gptq_no_clip)\n",
    "    print('Therorem 4: Babai\\'s Quantize (Algorithm 4) is identical to GPTQ (Algorithm 1) with reversed ordering.')\n",
    "\n",
    "    ℓ_rtn: float = loss(Q_rtn, W, X)\n",
    "    ℓ_rtn_no_clip: float = loss(Q_rtn_no_clip, W, X)\n",
    "    ℓ_babai: float = loss(Q_babai, W, X)\n",
    "    ℓ_babai_no_clip: float = loss(Q_babai_no_clip, W, X)\n",
    "    print(f'Babai loss ({ℓ_babai=:.4e}, {ℓ_babai_no_clip=:.4e}) is usually less than RTN loss ({ℓ_rtn=:.4e}, {ℓ_rtn_no_clip=:.4e}).')\n",
    "\n",
    "    worst_case_no_clip_bound: float = GPTQ_Error_Bound(S, X, T, λ)\n",
    "    average_case_no_clip_bound: float = worst_case_no_clip_bound / 3.\n",
    "    assert ℓ_babai_no_clip <= worst_case_no_clip_bound\n",
    "    print(f'Therorem 5: GPTQ loss without clipping ({ℓ_babai_no_clip=:.4e}) is similar to the theoretical average case bound ({average_case_no_clip_bound=:.4e}) and not larger than the worst case bound ({worst_case_no_clip_bound=:.4e})')\n",
    "\n",
    "\n",
    "basic_verifications()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d1ac4ce",
   "metadata": {},
   "source": [
    "The final experiment studies how quantization order affects both the loss and the bound.\n",
    "\n",
    "This function implements the min-pivot ordering heuristic, which greedily chooses the smallest current Hessian diagonal entry at each LDL decomposition step, following Algorithm 3 in the paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9a118c46",
   "metadata": {},
   "outputs": [],
   "source": [
    "def Min_Pivot(\n",
    "        H: torch.Tensor,  # (c, c)\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"Algorithm 3\"\"\"\n",
    "\n",
    "    dtype, device = H.dtype, H.device\n",
    "    c = H.size(-1)\n",
    "    inf = torch.as_tensor(torch.inf, dtype=dtype, device=device)  # ()\n",
    "    H = H.clone()  # (c, c)\n",
    "\n",
    "    J = torch.ones(c, dtype=torch.bool, device=device)  # (c), bool\n",
    "    T = torch.zeros_like(H)  # (c, c)\n",
    "    for j in range(c):\n",
    "        j_prime = torch.where(J, H.diagonal(), inf).argmin()  # ()\n",
    "        H = H - H[:, j_prime][:, None] @ H[j_prime, :][None, :] / H[j_prime, j_prime]  # (c, c)\n",
    "        T[j_prime, j] = 1.  # ()\n",
    "        J[j_prime] = False  # ()\n",
    "    return T  # (c, c)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb494263",
   "metadata": {},
   "source": [
    "For this comparison, all scales are set to the same value so that the effect of ordering is isolated.\n",
    "The printed table reports the measured loss together with the average-case and worst-case bounds for each ordering rule."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "905e7250",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Quantization Orders Comparison\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Order          \tLoss           \tAverage Bound  \tWorst Bound    \n",
      "back-to-front  \t8.2177e+04     \t8.8700e+04     \t2.6610e+05     \n",
      "front-to-back  \t7.8528e+04     \t8.6942e+04     \t2.6083e+05     \n",
      "random         \t8.0054e+04     \t8.6578e+04     \t2.5973e+05     \n",
      "act-order      \t7.8320e+04     \t8.8911e+04     \t2.6673e+05     \n",
      "min-pivot      \t6.5154e+04     \t7.3871e+04     \t2.2161e+05     \n",
      "The min-pivot order usually achieves the lowest loss and bound among other heuristic orders.\n"
     ]
    }
   ],
   "source": [
    "def quantization_orders_comparison(scale: float) -> None:\n",
    "    print('Quantization Orders Comparison')\n",
    "\n",
    "    S: torch.Tensor = torch.full_like(W, fill_value=scale)  # (c, r)\n",
    "\n",
    "    I: torch.Tensor = torch.eye(c, dtype=dtype, device=device)  # (c, c)\n",
    "    H: torch.Tensor = X.t() @ X + λ * I\n",
    "\n",
    "    T_back_to_front: torch.Tensor = I  # (c, c)\n",
    "    T_front_to_back: torch.Tensor = I.flip(dims=(1,))  # (c, c)\n",
    "    T_random: torch.Tensor = I[torch.randperm(c, device=device), :]  # (c, c)\n",
    "    T_act_order: torch.Tensor = I[H.diagonal().argsort(descending=False), :]  # (c, c)\n",
    "    T_min_pivot: torch.Tensor = Min_Pivot(H)  # (c, c)\n",
    "\n",
    "    ℓ_back_to_front: float = loss(Babais_Quantize(W, S, X, T_back_to_front, λ)[1], W, X)\n",
    "    ℓ_front_to_back: float = loss(Babais_Quantize(W, S, X, T_front_to_back, λ)[1], W, X)\n",
    "    ℓ_random: float = loss(Babais_Quantize(W, S, X, T_random, λ)[1], W, X)\n",
    "    ℓ_act_order: float = loss(Babais_Quantize(W, S, X, T_act_order, λ)[1], W, X)\n",
    "    ℓ_min_pivot: float = loss(Babais_Quantize(W, S, X, T_min_pivot, λ)[1], W, X)\n",
    "\n",
    "    bound_back_to_front: float = GPTQ_Error_Bound(S, X, T_back_to_front, λ)\n",
    "    bound_front_to_back: float = GPTQ_Error_Bound(S, X, T_front_to_back, λ)\n",
    "    bound_random: float = GPTQ_Error_Bound(S, X, T_random, λ)\n",
    "    bound_act_order: float = GPTQ_Error_Bound(S, X, T_act_order, λ)\n",
    "    bound_min_pivot: float = GPTQ_Error_Bound(S, X, T_min_pivot, λ)\n",
    "\n",
    "    print(f'{\"Order\":<15}', f'{\"Loss\":<15}', f'{\"Average Bound\":<15}', f'{\"Worst Bound\":<15}', sep='\\t')\n",
    "    print(f'{\"back-to-front\":<15}', f'{ℓ_back_to_front:<15.4e}', f'{bound_back_to_front / 3.:<15.4e}', f'{bound_back_to_front:<15.4e}', sep='\\t')\n",
    "    print(f'{\"front-to-back\":<15}', f'{ℓ_front_to_back:<15.4e}', f'{bound_front_to_back / 3.:<15.4e}', f'{bound_front_to_back:<15.4e}', sep='\\t')\n",
    "    print(f'{\"random\":<15}', f'{ℓ_random:<15.4e}', f'{bound_random / 3.:<15.4e}', f'{bound_random:<15.4e}', sep='\\t')\n",
    "    print(f'{\"act-order\":<15}', f'{ℓ_act_order:<15.4e}', f'{bound_act_order / 3.:<15.4e}', f'{bound_act_order:<15.4e}', sep='\\t')\n",
    "    print(f'{\"min-pivot\":<15}', f'{ℓ_min_pivot:<15.4e}', f'{bound_min_pivot / 3.:<15.4e}', f'{bound_min_pivot:<15.4e}', sep='\\t')\n",
    "\n",
    "    print('The min-pivot order usually achieves the lowest loss and bound among other heuristic orders.')\n",
    "\n",
    "\n",
    "quantization_orders_comparison(scale=S.abs().mean().item())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.14.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
