{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.stats import norm\n",
    "\n",
    "def compute_mse(grid):\n",
    "    \"\"\"\n",
    "    Compute the mean squared error (MSE) of quantizing a standard normal variable onto `grid`.\n",
    "\n",
    "    Each sample is mapped to the nearest grid level, so the decision boundaries are the\n",
    "    midpoints between neighbouring levels and the two outermost cells extend to +/- infinity.\n",
    "\n",
    "    For a cell [l, r] with level c the contribution is M2 - 2*c*M1 + c**2 * P, where for the\n",
    "    standard normal pdf `phi` and cdf `Phi`:\n",
    "        P  = Phi(r) - Phi(l)\n",
    "        M1 = phi(l) - phi(r)\n",
    "        M2 = P + l*phi(l) - r*phi(r)\n",
    "    These closed forms are exact, so no numerical integration is needed and the result is a\n",
    "    smooth function of the grid (which matters for gradient-based optimizers).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    grid : array_like\n",
    "        Quantization levels. Any order is accepted (levels are sorted internally) and\n",
    "        repeated levels are allowed.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.float64\n",
    "        Expected squared quantization error for X ~ N(0, 1).\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `grid` is empty.\n",
    "    \"\"\"\n",
    "    levels = np.sort(np.asarray(grid, dtype=float).ravel())\n",
    "    if levels.size == 0:\n",
    "        raise ValueError(\"grid must contain at least one quantization level\")\n",
    "\n",
    "    # Interior decision boundaries: midpoints between neighbouring levels\n",
    "    mids = (levels[:-1] + levels[1:]) / 2\n",
    "    edges = np.concatenate(([-np.inf], mids, [np.inf]))\n",
    "\n",
    "    cdf = norm.cdf(edges)\n",
    "    pdf = norm.pdf(edges)  # exactly 0 at +/- inf\n",
    "    # t * phi(t) tends to 0 at +/- inf; set it explicitly to avoid inf * 0 = nan\n",
    "    t_pdf = np.concatenate(([0.0], mids * norm.pdf(mids), [0.0]))\n",
    "\n",
    "    # Probability, first and second moments of every cell (left edge minus right edge)\n",
    "    P = cdf[1:] - cdf[:-1]\n",
    "    M1 = pdf[:-1] - pdf[1:]\n",
    "    M2 = P + t_pdf[:-1] - t_pdf[1:]\n",
    "\n",
    "    # Sum of per-cell errors E[(X - c)^2 ; X in cell]\n",
    "    return np.sum(M2 - 2 * levels * M1 + levels**2 * P)\n",
    "\n",
    "\n",
    "def get_uniform_grid(a, N):\n",
    "    \"\"\"Return N evenly spaced quantization centers spanning [-a, a], both endpoints included.\"\"\"\n",
    "    return np.linspace(start=-a, stop=a, num=N)\n",
    "\n",
    "def get_fp4_grid(a: float = 1):\n",
    "    \"\"\"\n",
    "    Build the 16-entry FP4 grid scaled by `a`.\n",
    "\n",
    "    Unscaled magnitudes are the subnormal 0.5 and the normals 1, 1.5, 2, 3, 4, 6, each\n",
    "    present with both signs, plus zero listed twice. The returned float array is in\n",
    "    ascending order for positive `a`.\n",
    "    \"\"\"\n",
    "    magnitudes = [1 * 2 ** (-1)]  # the single subnormal value\n",
    "    for e in range(1, 4):\n",
    "        for m in range(1, 3):\n",
    "            # (1 + m) / 2 is the significand (1.0 or 1.5); 2 ** (e - 1) is the exponent scale\n",
    "            magnitudes.append((1 + m) / 2 * 2 ** (e - 1))\n",
    "\n",
    "    # Mirror the ascending magnitudes around the two zero entries\n",
    "    negatives = [-value for value in reversed(magnitudes)]\n",
    "    levels = negatives + [0, 0] + magnitudes\n",
    "    return a * np.array(levels)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimum MSE found for each grid, keyed by bit width (int) or format name (str)\n",
    "GRID_MSES = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1: optimal scaling parameter (a): 0.7978845587140913, Minimum MSE: 0.3633802276324186\n",
      "2: optimal scaling parameter (a): 1.493534520977036, Minimum MSE: 0.11884605038769407\n",
      "3: optimal scaling parameter (a): 2.0510679063024964, Minimum MSE: 0.03743965939152373\n",
      "4: optimal scaling parameter (a): 2.5139324513630887, Minimum MSE: 0.011542884500323213\n",
      "5: optimal scaling parameter (a): 2.9160897658147453, Minimum MSE: 0.003495211376111403\n",
      "6: optimal scaling parameter (a): 3.276597435983721, Minimum MSE: 0.0010400475795804263\n",
      "7: optimal scaling parameter (a): 3.6010436416224247, Minimum MSE: 0.00030436603842457166\n",
      "8: optimal scaling parameter (a): 3.884997364699907, Minimum MSE: 8.782117814336654e-05\n"
     ]
    }
   ],
   "source": [
    "from scipy.optimize import minimize\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "for bits in range(1, 9):\n",
    "    # A uniform quantizer with this many bits has 2**bits levels\n",
    "    N = 2 ** bits\n",
    "\n",
    "    def objective(params):\n",
    "        \"\"\"MSE of the N-level uniform grid whose half-width is params[0].\"\"\"\n",
    "        return compute_mse(get_uniform_grid(params[0], N))\n",
    "\n",
    "    # Start the search at a = 2.0 and keep a positive by restricting it to [0.1, 10.0]\n",
    "    a0 = [2.0]\n",
    "    bounds = [(0.1, 10.0)]\n",
    "    result = minimize(objective, a0, method='L-BFGS-B', bounds=bounds)\n",
    "\n",
    "    # Best half-width and the MSE it achieves for this bit width\n",
    "    optimal_a, minimum_mse = result.x[0], result.fun\n",
    "\n",
    "    GRID_MSES[bits] = minimum_mse\n",
    "    print(f\"{bits}: optimal scaling parameter (a): {optimal_a}, Minimum MSE: {minimum_mse}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimal scaling parameter (a): 0.487079483934662\n",
      "Minimum MSE: 0.012684904138719949\n"
     ]
    }
   ],
   "source": [
    "from scipy.optimize import minimize\n",
    "\n",
    "# Number of FP4 levels; get_fp4_grid takes no size argument, so N is informational here\n",
    "N = 16\n",
    "\n",
    "def objective(params):\n",
    "    \"\"\"MSE of the FP4 grid scaled by params[0].\"\"\"\n",
    "    return compute_mse(get_fp4_grid(params[0]))\n",
    "\n",
    "# Start the search at a = 1.0 and keep a positive by restricting it to [0.1, 10.0]\n",
    "a0 = [1.0]\n",
    "bounds = [(0.1, 10.0)]\n",
    "result = minimize(objective, a0, method='L-BFGS-B', bounds=bounds)\n",
    "\n",
    "# Best scale and the MSE it achieves\n",
    "optimal_a, minimum_mse = result.x[0], result.fun\n",
    "\n",
    "GRID_MSES[\"fp4\"] = minimum_mse\n",
    "\n",
    "print(f\"Optimal scaling parameter (a): {optimal_a}\")\n",
    "print(f\"Minimum MSE: {minimum_mse}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{1: np.float64(0.3633802276324186),\n",
       " 2: np.float64(0.11884605038769407),\n",
       " 3: np.float64(0.03743965939152373),\n",
       " 4: np.float64(0.011542884500323213),\n",
       " 8: np.float64(8.782117814336654e-05),\n",
       " 'fp4': np.float64(0.012684904138719949)}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the MSE values collected so far\n",
    "GRID_MSES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimal scaling parameter (a): 1.2240089519030855\n",
      "Minimum MSE: 0.19017403925019966\n"
     ]
    }
   ],
   "source": [
    "from scipy.optimize import minimize\n",
    "\n",
    "# Three-level uniform grid: -a, 0, +a\n",
    "N = 3\n",
    "\n",
    "def objective(params):\n",
    "    \"\"\"MSE of the N-level uniform grid whose half-width is params[0].\"\"\"\n",
    "    return compute_mse(get_uniform_grid(params[0], N))\n",
    "\n",
    "# Start the search at a = 2.0 and keep a positive by restricting it to [0.1, 10.0]\n",
    "a0 = [2.0]\n",
    "bounds = [(0.1, 10.0)]\n",
    "result = minimize(objective, a0, method='L-BFGS-B', bounds=bounds)\n",
    "\n",
    "# Best half-width and the MSE it achieves\n",
    "optimal_a, minimum_mse = result.x[0], result.fun\n",
    "\n",
    "print(f\"Optimal scaling parameter (a): {optimal_a}\")\n",
    "print(f\"Minimum MSE: {minimum_mse}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-1.22400895,  0.        ,  1.22400895])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Grid levels for whichever optimal_a and N the most recently run optimization cell left in the kernel\n",
    "get_uniform_grid(optimal_a, N)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
