{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ce8ebf1-20c7-43e9-be5e-32a4fac01e80",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import pprint\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import scipy.special as sp\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from scipy.special import jv\n",
    "\n",
    "class _SignSTE(torch.autograd.Function):\n",
    "    \"\"\"Sign in the forward pass, identity gradient in the backward pass.\"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def forward(ctx, x):\n",
    "        return torch.sign(x)\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_output):\n",
    "        # Straight-through: treat the forward op as the identity.\n",
    "        return grad_output\n",
    "\n",
    "\n",
    "def sign_ste(x):\n",
    "    \"\"\"Return ``torch.sign(x)`` with a straight-through gradient.\n",
    "\n",
    "    The forward value is exactly ``torch.sign(x)``. Plain ``torch.sign``\n",
    "    back-propagates a zero gradient, so nothing upstream of it can learn;\n",
    "    here the incoming gradient is passed through unchanged instead.\n",
    "\n",
    "    Args:\n",
    "        x: Input tensor.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of the same shape as ``x`` holding the element-wise sign.\n",
    "    \"\"\"\n",
    "    return _SignSTE.apply(x)\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class BitLinearNTK_from_pretrained(nn.Module):\n",
    "    \"\"\"Two-layer ReLU block whose hidden weights are binarized in the forward pass.\n",
    "\n",
    "    ``W`` (``input_dim x m``) is the trainable latent weight. ``a``\n",
    "    (``m x output_dim``) holds random +/-1 entries and is frozen\n",
    "    (``requires_grad=False``). The output is divided by ``sqrt(m)``.\n",
    "\n",
    "    Args:\n",
    "        input_dim: Size of the last dimension of the input.\n",
    "        output_dim: Size of the last dimension of the output.\n",
    "        m: Hidden width of the block.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, output_dim, m=1000):\n",
    "        super(BitLinearNTK_from_pretrained, self).__init__()\n",
    "        self.m = m\n",
    "        self.W = nn.Parameter(torch.randn(input_dim, self.m))\n",
    "        self.relu = nn.ReLU()\n",
    "        # randint(0, 2) gives {0, 1}; 2 * v - 1 maps that to {-1, +1}.\n",
    "        a_tensor = (2.0 * torch.randint(0, 2, (self.m, output_dim)) - 1.0).float()\n",
    "        self.a = nn.Parameter(a_tensor, requires_grad=False)\n",
    "\n",
    "    @staticmethod\n",
    "    def _sign_ste(t):\n",
    "        \"\"\"Sign of ``t`` with a straight-through (identity) gradient.\n",
    "\n",
    "        ``t - t.detach()`` is zero in the forward pass for finite ``t``, so\n",
    "        the value equals ``torch.sign(t)``; in the backward pass it is the\n",
    "        only term that carries gradient. Plain ``torch.sign`` would\n",
    "        back-propagate zeros.\n",
    "        \"\"\"\n",
    "        return torch.sign(t).detach() + (t - t.detach())\n",
    "\n",
    "    def forward(self, X):\n",
    "        \"\"\"Apply the binarized hidden layer, ReLU and the fixed projection.\n",
    "\n",
    "        Args:\n",
    "            X: Input tensor whose last dimension is ``input_dim``.\n",
    "\n",
    "        Returns:\n",
    "            Tensor whose last dimension is ``output_dim``.\n",
    "        \"\"\"\n",
    "        # Use device-local tensors instead of re-assigning self.W / self.a:\n",
    "        # Parameter.to() returns a plain tensor when it actually moves data,\n",
    "        # and assigning that to a registered parameter raises TypeError.\n",
    "        W = self.W.to(X.device)\n",
    "        a = self.a.to(X.device)\n",
    "\n",
    "        E = W.mean(0).unsqueeze(0)  # per-column mean, shape (1, m)\n",
    "        centred = W - E\n",
    "        # L2 norm of every centred column, shape (1, m).\n",
    "        # NOTE: a column whose entries are all equal (always the case when\n",
    "        # input_dim == 1) gives 0 here and the division below yields NaN.\n",
    "        sqrt_V = (centred ** 2).sum(0).unsqueeze(0) ** 0.5\n",
    "        wt_W = self._sign_ste(centred / sqrt_V)\n",
    "        # Rescale the binarized product by the column norm and add the\n",
    "        # column-mean contribution E * sum(X) back in.\n",
    "        Y = self.relu(torch.matmul(X, wt_W) * sqrt_V + E * X.sum(-1).unsqueeze(-1))\n",
    "        return torch.matmul(Y, a) / (self.m ** 0.5)\n",
    "\n",
    "class LinearNTK_from_pretrained(nn.Module):\n",
    "    \"\"\"Full-precision two-layer ReLU block: ``relu(X @ W) @ a / sqrt(m)``.\n",
    "\n",
    "    ``W`` (``input_dim x m``) is trainable. ``a`` (``m x output_dim``) holds\n",
    "    random +/-1 entries and is frozen (``requires_grad=False``).\n",
    "\n",
    "    Args:\n",
    "        input_dim: Size of the last dimension of the input.\n",
    "        output_dim: Size of the last dimension of the output.\n",
    "        m: Hidden width of the block.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, output_dim, m=1000):\n",
    "        super(LinearNTK_from_pretrained, self).__init__()\n",
    "        self.m = m\n",
    "        self.W = nn.Parameter(torch.randn(input_dim, self.m))\n",
    "        self.relu = nn.ReLU()\n",
    "        # randint(0, 2) gives {0, 1}; 2 * v - 1 maps that to {-1, +1}.\n",
    "        a_tensor = (2.0 * torch.randint(0, 2, (self.m, output_dim)) - 1.0).float()\n",
    "        self.a = nn.Parameter(a_tensor, requires_grad=False)\n",
    "\n",
    "    def forward(self, X):\n",
    "        \"\"\"Apply the hidden layer, ReLU and the fixed projection.\n",
    "\n",
    "        Args:\n",
    "            X: Input tensor whose last dimension is ``input_dim``.\n",
    "\n",
    "        Returns:\n",
    "            Tensor whose last dimension is ``output_dim``.\n",
    "        \"\"\"\n",
    "        # Use device-local tensors instead of re-assigning self.W / self.a:\n",
    "        # Parameter.to() returns a plain tensor when it actually moves data,\n",
    "        # and assigning that to a registered parameter raises TypeError.\n",
    "        W = self.W.to(X.device)\n",
    "        a = self.a.to(X.device)\n",
    "\n",
    "        Y = self.relu(torch.matmul(X, W))\n",
    "        return torch.matmul(Y, a) / (self.m ** 0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f7e2b57-1ca3-4be6-9c85-2c8e34ff489a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class BitLinearNTKNetwork(nn.Module):\n",
    "    \"\"\"Stack of ``BitLinearNTK_from_pretrained`` blocks applied in sequence.\n",
    "\n",
    "    With ``k == 1`` a single input -> output block is built. Otherwise the\n",
    "    stack is an input -> hidden block, ``k - 2`` hidden -> hidden blocks and\n",
    "    a hidden -> output block. Every block gets hidden width ``m``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, hidden_dim, output_dim, k, m=1000):\n",
    "        super(BitLinearNTKNetwork, self).__init__()\n",
    "        if k == 1:\n",
    "            shapes = [(input_dim, output_dim)]\n",
    "        else:\n",
    "            shapes = [(input_dim, hidden_dim)]\n",
    "            shapes += [(hidden_dim, hidden_dim)] * (k - 2)\n",
    "            shapes.append((hidden_dim, output_dim))\n",
    "        # Blocks are instantiated in stack order.\n",
    "        self.layers = nn.ModuleList(\n",
    "            [BitLinearNTK_from_pretrained(d_in, d_out, m) for d_in, d_out in shapes]\n",
    "        )\n",
    "\n",
    "    def forward(self, X):\n",
    "        \"\"\"Feed ``X`` through every block in order and return the result.\"\"\"\n",
    "        for block in self.layers:\n",
    "            X = block(X)\n",
    "        return X\n",
    "\n",
    "class LinearNTKNetwork(nn.Module):\n",
    "    \"\"\"Stack of ``LinearNTK_from_pretrained`` blocks applied in sequence.\n",
    "\n",
    "    With ``k == 1`` a single input -> output block is built. Otherwise the\n",
    "    stack is an input -> hidden block, ``k - 2`` hidden -> hidden blocks and\n",
    "    a hidden -> output block. Every block gets hidden width ``m``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, hidden_dim, output_dim, k, m=1000):\n",
    "        super(LinearNTKNetwork, self).__init__()\n",
    "        if k == 1:\n",
    "            shapes = [(input_dim, output_dim)]\n",
    "        else:\n",
    "            shapes = [(input_dim, hidden_dim)]\n",
    "            shapes += [(hidden_dim, hidden_dim)] * (k - 2)\n",
    "            shapes.append((hidden_dim, output_dim))\n",
    "        # Blocks are instantiated in stack order.\n",
    "        self.layers = nn.ModuleList(\n",
    "            [LinearNTK_from_pretrained(d_in, d_out, m) for d_in, d_out in shapes]\n",
    "        )\n",
    "\n",
    "    def forward(self, X):\n",
    "        \"\"\"Feed ``X`` through every block in order and return the result.\"\"\"\n",
    "        for block in self.layers:\n",
    "            X = block(X)\n",
    "        return X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44866a3f-49ca-4f8f-9923-156396278673",
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_parameters(input_dim, hidden_dim, output_dim, k):\n",
    "    \"\"\"Count the weights of a ``k``-block network with width ``m = 10 * hidden_dim``.\n",
    "\n",
    "    Each block contributes its hidden weight ``W`` (``in x m``) and its\n",
    "    projection ``a`` (``m x out``). For ``k > 1`` the blocks are\n",
    "    input -> hidden, ``k - 2`` times hidden -> hidden, and hidden -> output.\n",
    "    The single-block term is no longer added on top of the multi-block\n",
    "    total, which overcounted by ``input_dim * m + m * output_dim``.\n",
    "\n",
    "    Args:\n",
    "        input_dim: Input size of the first block.\n",
    "        hidden_dim: Size of the intermediate representations.\n",
    "        output_dim: Output size of the last block.\n",
    "        k: Number of blocks; values below 2 are counted as a single block.\n",
    "\n",
    "    Returns:\n",
    "        Total number of entries in all ``W`` and ``a`` matrices.\n",
    "    \"\"\"\n",
    "    m = 10 * hidden_dim\n",
    "    if not k > 1:\n",
    "        return (input_dim * m) + (m * output_dim)\n",
    "\n",
    "    first_block = (input_dim * m) + (m * hidden_dim)\n",
    "    middle_blocks = (k - 2) * ((hidden_dim * m) + (m * hidden_dim))\n",
    "    last_block = (hidden_dim * m) + (m * output_dim)\n",
    "    return first_block + middle_blocks + last_block"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "109942b3-eb2f-4dab-bc2a-5494adc0fa14",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TargetFunction:\n",
    "    \"\"\"Base class for regression targets that can sample their own data.\n",
    "\n",
    "    Subclasses implement ``__call__`` (the function itself) and\n",
    "    ``latex_format`` (a LaTeX string describing it).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, output_dim):\n",
    "        self.input_dim, self.output_dim = input_dim, output_dim\n",
    "\n",
    "    def __call__(self, x):\n",
    "        \"\"\"Evaluate the target on ``x``; must be overridden.\"\"\"\n",
    "        raise NotImplementedError(\"Subclasses must implement this method\")\n",
    "\n",
    "    def generate_data(self, num_samples, low=-1, high=1):\n",
    "        \"\"\"Draw ``num_samples`` inputs uniformly from ``[low, high)`` and label them.\n",
    "\n",
    "        Returns:\n",
    "            ``(x, y)`` where ``x`` has shape ``(num_samples, input_dim)`` and\n",
    "            ``y`` is ``self(x)``.\n",
    "        \"\"\"\n",
    "        width = high - low\n",
    "        samples = torch.rand(num_samples, self.input_dim) * width + low\n",
    "        return samples, self(samples)\n",
    "\n",
    "    def latex_format(self):\n",
    "        \"\"\"Return the LaTeX description of the target; must be overridden.\"\"\"\n",
    "        raise NotImplementedError(\"Subclasses must provide their LaTeX format\")\n",
    "\n",
    "\n",
    "class ExpSinCosTargetFunction(TargetFunction):\n",
    "    \"\"\"Target ``exp((1/5) * sum_i cos^2(pi * x_i / 2))`` over five inputs.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__(input_dim=5, output_dim=1)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        \"\"\"Evaluate the target row-wise; the result has shape ``(rows, 1)``.\"\"\"\n",
    "        temp = (torch.pi * x) / 2\n",
    "        cos_squared = torch.cos(temp) ** 2\n",
    "        sum_cos_squared = cos_squared.sum(dim=1)\n",
    "        exp_argument = (1 / 5) * sum_cos_squared\n",
    "        y = torch.exp(exp_argument)\n",
    "        return y.unsqueeze(1)\n",
    "\n",
    "    def latex_format(self):\n",
    "        \"\"\"Return the formula as a LaTeX math string.\"\"\"\n",
    "        # The formula uses cos^2 to match __call__ (it previously said sin^2)\n",
    "        # and is wrapped in $...$ like the other targets' formulas.\n",
    "        return r\"$f(x_1, \\dots, x_5) = \\exp\\left(\\frac{1}{5} \\left( \\sum_{i=1}^{5} \\cos^2 \\left(\\frac{\\pi x_i}{2}\\right) \\right)\\right)$\"\n",
    "\n",
    "\n",
    "class LogSquareSineExpTargetFunction(TargetFunction):\n",
    "    \"\"\"Target ``log(1 + |x_1|) + x_2^2 - x_2 + sin(x_3) - exp(x_4)``.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__(input_dim=4, output_dim=1)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        \"\"\"Evaluate the target row-wise; the result has shape ``(rows, 1)``.\"\"\"\n",
    "        x1, x2, x3, x4 = x[:, 0], x[:, 1], x[:, 2], x[:, 3]\n",
    "        total = torch.log1p(torch.abs(x1)) + (x2 ** 2 - x2) + torch.sin(x3) - torch.exp(x4)\n",
    "        return total[:, None]\n",
    "\n",
    "    def latex_format(self):\n",
    "        \"\"\"Return the formula as a LaTeX math string.\"\"\"\n",
    "        return r\"$f(x_1, x_2, x_3, x_4) = \\log(1 + |x_1|) + x_2^2 - x_2 + \\sin(x_3) - \\exp(x_4)$\"\n",
    "\n",
    "class CustomXYZTargetFunction(TargetFunction):\n",
    "    \"\"\"Target ``x_1 * x_2 - x_3``.\n",
    "\n",
    "    ``input_dim`` is 4, but only the first three columns are read.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__(input_dim=4, output_dim=1)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        \"\"\"Evaluate the target row-wise; the result has shape ``(rows, 1)``.\"\"\"\n",
    "        x1, x2, x3 = x[:, 0], x[:, 1], x[:, 2]\n",
    "        return (x1 * x2 - x3)[:, None]\n",
    "\n",
    "    def latex_format(self):\n",
    "        \"\"\"Return the formula as a LaTeX math string.\"\"\"\n",
    "        return r\"$f(x_1, x_2, x_3) = x_1\\times x_2 - x_3$\"\n",
    "\n",
    "\n",
    "class CustomXYTargetFunction(TargetFunction):\n",
    "    \"\"\"Target ``x_1 * sin(x_2) + cos(x_3) - 0.5 * x_4``.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__(input_dim=4, output_dim=1)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        \"\"\"Evaluate the target row-wise; the result has shape ``(rows, 1)``.\"\"\"\n",
    "        x1, x2, x3, x4 = x[:, 0], x[:, 1], x[:, 2], x[:, 3]\n",
    "        total = x1 * torch.sin(x2) + torch.cos(x3) - 0.5 * x4\n",
    "        return total[:, None]\n",
    "\n",
    "    def latex_format(self):\n",
    "        \"\"\"Return the formula as a LaTeX math string.\"\"\"\n",
    "        return r\"$f(x_1, x_2, x_3, x_4) = x_1 \\sin(x_2) + \\cos(x_3) - 0.5 x_4$\"\n",
    "\n",
    "class CustomTanHargetFunction(TargetFunction):\n",
    "    \"\"\"Target ``x_1^2 / (1 + |x_2|) - exp(x_3) + tanh(x_4) + sqrt(|x_1 * x_3|)``.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__(input_dim=4, output_dim=1)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        \"\"\"Evaluate the target row-wise; the result has shape ``(rows, 1)``.\"\"\"\n",
    "        x1, x2, x3, x4 = x[:, 0], x[:, 1], x[:, 2], x[:, 3]\n",
    "        ratio = (x1 ** 2) / (1 + torch.abs(x2))\n",
    "        root = torch.sqrt(torch.abs(x1 * x3))\n",
    "        total = ratio - torch.exp(x3) + torch.tanh(x4) + root\n",
    "        return total[:, None]\n",
    "\n",
    "    def latex_format(self):\n",
    "        \"\"\"Return the formula as a LaTeX math string.\"\"\"\n",
    "        return r\"$f(x_1, x_2, x_3, x_4) = \\frac{x_1^2}{1 + |x_2|} - e^{x_3} + \\tanh(x_4) + \\sqrt{|x_1 \\cdot x_3|}$\"\n",
    "\n",
    "class LambertWGammaTargetFunction(TargetFunction):\n",
    "    \"\"\"Target ``Re W(x_1 x_2) + x_3 / log(1 + exp(x_4)) - Gamma(x_2) / (1 + |x_1|)``.\n",
    "\n",
    "    The Lambert W and Gamma terms are evaluated with ``scipy.special``\n",
    "    (imported as ``sp``) on CPU NumPy copies and converted back to the\n",
    "    device and dtype of the input.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__(input_dim=4, output_dim=1)\n",
    "\n",
    "    def _from_scipy(self, values, like):\n",
    "        \"\"\"Convert a NumPy result to a tensor with the device/dtype of ``like``.\"\"\"\n",
    "        # Without the dtype cast SciPy's float64 output would silently\n",
    "        # promote the whole target to float64.\n",
    "        return torch.from_numpy(values.copy()).to(device=like.device, dtype=like.dtype)\n",
    "\n",
    "    def lambertw(self, x):\n",
    "        \"\"\"Real part of ``scipy.special.lambertw`` applied element-wise to ``x``.\"\"\"\n",
    "        # .copy() in _from_scipy makes the strided ``.real`` view contiguous.\n",
    "        return self._from_scipy(sp.lambertw(x.detach().cpu().numpy()).real, x)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        \"\"\"Evaluate the target row-wise; the result has shape ``(rows, 1)``.\"\"\"\n",
    "        # Call the method on self: the bare name ``lambertw`` is not defined\n",
    "        # at module level and raised NameError.\n",
    "        term1 = self.lambertw(x[:, 0] * x[:, 1])\n",
    "        term2 = x[:, 2] / torch.log(1 + torch.exp(x[:, 3]))\n",
    "        gamma_x2 = self._from_scipy(sp.gamma(x[:, 1].detach().cpu().numpy()), x)\n",
    "        term3 = -gamma_x2 / (1 + torch.abs(x[:, 0]))\n",
    "\n",
    "        y = term1 + term2 + term3\n",
    "        return y.unsqueeze(1)\n",
    "\n",
    "    def latex_format(self):\n",
    "        \"\"\"Return the formula as a LaTeX math string.\"\"\"\n",
    "        return r\"$f(x_1, x_2, x_3, x_4) = LambertW(x_1 x_2) + \\frac{x_3}{\\log(1 + \\exp(x_4))} - \\frac{\\Gamma(x_2)}{1 + |x_1|}$\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45afe1b0-435e-4289-8cf2-a47a53f58570",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, optimizer, scheduler, criterion, X_train, Y_train, device, num_epochs=100000, clip_value=1.0, print_every=100):\n",
    "    \"\"\"Run full-batch training and return the smallest logged loss.\n",
    "\n",
    "    Every epoch performs one optimizer step on the whole training set with\n",
    "    gradient-norm clipping, then steps ``scheduler`` with the epoch's loss.\n",
    "    The loss is ``0.5 * criterion(outputs, Y_train)``.\n",
    "\n",
    "    Args:\n",
    "        model: Module to train; moved to ``device``.\n",
    "        optimizer: Optimizer over ``model``'s parameters.\n",
    "        scheduler: Scheduler whose ``step`` accepts the loss as its metric.\n",
    "        criterion: Loss callable taking ``(outputs, targets)``.\n",
    "        X_train: Training inputs; moved to ``device``.\n",
    "        Y_train: Training targets; moved to ``device``.\n",
    "        device: Device on which training runs.\n",
    "        num_epochs: Number of optimizer steps; must be positive.\n",
    "        clip_value: Maximum gradient norm passed to ``clip_grad_norm_``.\n",
    "        print_every: Interval, in epochs, at which the loss is printed and recorded.\n",
    "\n",
    "    Returns:\n",
    "        The minimum of the losses recorded every ``print_every`` epochs. If\n",
    "        training ends before anything was recorded, the last epoch's loss.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``num_epochs`` is not positive.\n",
    "    \"\"\"\n",
    "    if num_epochs <= 0:\n",
    "        raise ValueError(\"num_epochs must be positive\")\n",
    "\n",
    "    model = model.to(device)\n",
    "    # Keep the data on the same device as the model; no-ops if already there.\n",
    "    X_train = X_train.to(device)\n",
    "    Y_train = Y_train.to(device)\n",
    "    model.train()\n",
    "\n",
    "    losses = []\n",
    "    loss_value = None\n",
    "    for epoch in range(num_epochs):\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        outputs = model(X_train)\n",
    "        loss = criterion(outputs, Y_train) * 0.5  # Scale loss by 0.5\n",
    "\n",
    "        loss.backward()\n",
    "        nn.utils.clip_grad_norm_(model.parameters(), clip_value)\n",
    "        optimizer.step()\n",
    "\n",
    "        # Hand the scheduler a plain float rather than a graph-attached tensor.\n",
    "        loss_value = loss.item()\n",
    "        scheduler.step(loss_value)\n",
    "\n",
    "        if (epoch + 1) % print_every == 0:\n",
    "            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss_value:.4f}')\n",
    "            losses.append(loss_value)\n",
    "\n",
    "    # With num_epochs < print_every nothing was recorded and min() would\n",
    "    # raise on an empty list; fall back to the final loss.\n",
    "    if not losses:\n",
    "        losses.append(loss_value)\n",
    "    return min(losses)\n",
    "\n",
    "def run_experiment_1bit(X_train, Y_train, K, HIDDEN_DIM, INPUT_DIM, OUTPUT_DIM, device, num_epochs=10000):\n",
    "    \"\"\"Train a ``BitLinearNTKNetwork`` for every depth in ``K`` and width in ``HIDDEN_DIM``.\n",
    "\n",
    "    Each network uses block width ``10 * hidden_dim``, Adam (lr 1e-3), a\n",
    "    ``ReduceLROnPlateau`` scheduler and a summed MSE loss.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping ``(hidden_dim, k)`` to the value returned by ``train_model``.\n",
    "    \"\"\"\n",
    "    loss_fn = nn.MSELoss(reduction='sum')\n",
    "    best_losses = {}\n",
    "\n",
    "    for depth in K:\n",
    "        for width in HIDDEN_DIM:\n",
    "            print(f\"Running experiment with hidden_dim={width}, k={depth}\")\n",
    "\n",
    "            net = BitLinearNTKNetwork(INPUT_DIM, width, OUTPUT_DIM, depth, 10 * width)\n",
    "            adam = optim.Adam(net.parameters(), lr=1e-3)\n",
    "            plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(adam, mode='min', factor=0.5, patience=5)\n",
    "\n",
    "            best_losses[(width, depth)] = train_model(\n",
    "                net, adam, plateau, loss_fn, X_train, Y_train, device, num_epochs\n",
    "            )\n",
    "\n",
    "    return best_losses\n",
    "\n",
    "def run_experiment_standard(X_train, Y_train, K, HIDDEN_DIM, INPUT_DIM, OUTPUT_DIM, device, num_epochs=10000):\n",
    "    \"\"\"Train a ``LinearNTKNetwork`` for every depth in ``K`` and width in ``HIDDEN_DIM``.\n",
    "\n",
    "    Each network uses block width ``10 * hidden_dim``, Adam (lr 1e-3), a\n",
    "    ``ReduceLROnPlateau`` scheduler and a summed MSE loss.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping ``(hidden_dim, k)`` to the value returned by ``train_model``.\n",
    "    \"\"\"\n",
    "    loss_fn = nn.MSELoss(reduction='sum')\n",
    "    best_losses = {}\n",
    "\n",
    "    for depth in K:\n",
    "        for width in HIDDEN_DIM:\n",
    "            print(f\"Running experiment with hidden_dim={width}, k={depth}\")\n",
    "\n",
    "            net = LinearNTKNetwork(INPUT_DIM, width, OUTPUT_DIM, depth, 10 * width)\n",
    "            adam = optim.Adam(net.parameters(), lr=1e-3)\n",
    "            plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(adam, mode='min', factor=0.5, patience=5)\n",
    "\n",
    "            best_losses[(width, depth)] = train_model(\n",
    "                net, adam, plateau, loss_fn, X_train, Y_train, device, num_epochs\n",
    "            )\n",
    "\n",
    "    return best_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "020be325-29f4-4576-ae5a-ef285d1f9dc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the global RNG so the sampled data and the weight initialisations\n",
    "# are the same on every Restart & Run All.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "func = LogSquareSineExpTargetFunction()\n",
    "\n",
    "# Take the dimensions from the target function so that swapping ``func``\n",
    "# cannot leave the networks with a mismatched input or output size.\n",
    "INPUT_DIM = func.input_dim\n",
    "HIDDEN_DIM = [256, 128, 64, 32]\n",
    "OUTPUT_DIM = func.output_dim\n",
    "K = [5, 3]  # network depths (number of stacked blocks)\n",
    "\n",
    "num_epochs = 100000\n",
    "num_samples = 100\n",
    "device = torch.device(\"cpu\")\n",
    "X_train, Y_train = func.generate_data(num_samples)\n",
    "X_train, Y_train = X_train.float().to(device), Y_train.float().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6d57765-7f02-4430-acb0-9f4818b75c9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the full-precision (LinearNTKNetwork) baselines for every depth in K\n",
    "# and width in HIDDEN_DIM; the result maps (hidden_dim, k) to a loss.\n",
    "results_standard = run_experiment_standard(X_train, Y_train, K, HIDDEN_DIM, INPUT_DIM, OUTPUT_DIM, device, num_epochs)\n",
    "# Same grid and data for the 1-bit (BitLinearNTKNetwork) models.\n",
    "results_1bit = run_experiment_1bit(X_train, Y_train, K, HIDDEN_DIM, INPUT_DIM, OUTPUT_DIM, device, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b5d7cad-4204-4d49-b443-a54bb6fa778b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import rc\n",
    "import numpy as np\n",
    "from matplotlib.ticker import LogLocator\n",
    "\n",
    "\n",
    "def _group_by_depth(flat_results, depths):\n",
    "    \"\"\"Turn ``{(hidden_dim, depth): loss}`` into ``{depth: {hidden_dim: loss}}``.\"\"\"\n",
    "    return {\n",
    "        depth: {key[0]: loss for key, loss in flat_results.items() if key[1] == depth}\n",
    "        for depth in depths\n",
    "    }\n",
    "\n",
    "\n",
    "layer_sizes = [3, 5]\n",
    "# Regroup into new names instead of overwriting results_1bit / results_standard.\n",
    "# Overwriting made this cell fail when run a second time, because the regrouped\n",
    "# dicts are keyed by plain ints and ``key[1]`` is then invalid.\n",
    "results_1bit_by_depth = _group_by_depth(results_1bit, layer_sizes)\n",
    "results_standard_by_depth = _group_by_depth(results_standard, layer_sizes)\n",
    "\n",
    "rc('font', family='sans-serif', size=20)\n",
    "rc('text', usetex=True)\n",
    "colors = ['#FF3E30', '#176BEF', '#107C10', '#FFA900']\n",
    "markers = ['o', '^', 's', 'D']\n",
    "linestyles = ['-', '-', '-', '--']\n",
    "markersizes = [14, 18, 10, 14]\n",
    "lw = 7.0\n",
    "lsize = 55\n",
    "xsize = 55\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(15, 12))\n",
    "# (grouped results, legend prefix, offset into ``colors``). The 1-bit curves\n",
    "# are drawn first so the legend order stays the same.\n",
    "curve_groups = [\n",
    "    (results_1bit_by_depth, 'MLP 1-bit', 0),\n",
    "    (results_standard_by_depth, 'MLP FP32', 2),\n",
    "]\n",
    "for grouped, prefix, color_offset in curve_groups:\n",
    "    for i, layer in enumerate(layer_sizes):\n",
    "        x_vals = sorted(grouped[layer].keys())\n",
    "        y_vals = [grouped[layer][x] for x in x_vals]\n",
    "        p_vals = [count_parameters(INPUT_DIM, x, OUTPUT_DIM, layer) for x in x_vals]\n",
    "        ax.plot(p_vals, y_vals, marker=markers[i], linestyle=linestyles[i], color=colors[i + color_offset],\n",
    "                linewidth=lw, markersize=markersizes[i], markeredgecolor='k', markeredgewidth=1.5,\n",
    "                label=f'{prefix} (depth {layer})')\n",
    "\n",
    "ax.tick_params(axis='both', which='major', labelsize=xsize, pad=10)\n",
    "# The x values are parameter counts (p_vals), not hidden dimensions.\n",
    "ax.set_xlabel('Number of Parameters', size=55, labelpad=10)\n",
    "ax.set_ylabel('Loss', size=55, labelpad=15)\n",
    "ax.set_xscale('log', base=2)\n",
    "ax.set_yscale('log', base=10)\n",
    "\n",
    "ax.grid()\n",
    "\n",
    "ax.legend(loc='lower left', fontsize=38, framealpha=0.5, edgecolor='k')\n",
    "\n",
    "ax.set_title(func.latex_format(), fontsize=55)\n",
    "\n",
    "fig.tight_layout(pad=2)\n",
    "fig.savefig(\"f6.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da02cb86-0c77-4fcf-808b-47c8f32485d3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "kvsketch",
   "language": "python",
   "name": "kv_sketch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
