{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ce8ebf1-20c7-43e9-be5e-32a4fac01e80",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from scipy.special import jv\n",
    "import pprint\n",
    "import json\n",
    "\n",
    "def sign_ste(x):\n",
    "    \"\"\"Sign with a straight-through gradient estimator.\n",
    "\n",
    "    Forward: returns ``torch.sign(x)`` (exactly, for finite ``x``).\n",
    "    Backward: the gradient with respect to ``x`` is passed through unchanged.\n",
    "\n",
    "    ``torch.sign`` on its own has zero gradient almost everywhere, so a weight\n",
    "    fed through it would never be updated. ``x - x.detach()`` is zero in the\n",
    "    forward pass but has derivative 1, which routes the gradient to ``x``.\n",
    "    \"\"\"\n",
    "    return torch.sign(x).detach() + (x - x.detach())\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class BitLinearNTK_from_pretrained(nn.Module):\n",
    "    \"\"\"Two-layer ReLU block whose first-layer weights are binarized with ``sign``.\n",
    "\n",
    "    Computes ``relu(X @ sign(W) + b1) @ a / sqrt(m) + b2``. ``W`` is a real-valued\n",
    "    latent weight trained through a straight-through estimator; ``a`` is a fixed\n",
    "    random +/-1 matrix.\n",
    "\n",
    "    Args:\n",
    "        input_dim: size of the last dimension of the input ``X``.\n",
    "        output_dim: size of the last dimension of the output.\n",
    "        m: width of the hidden ReLU layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, output_dim, m=1000):\n",
    "        super().__init__()\n",
    "        self.m = m\n",
    "        # Latent full-precision weights; only their sign is used in forward().\n",
    "        self.W = nn.Parameter(torch.randn(input_dim, self.m))\n",
    "        self.relu = nn.ReLU()\n",
    "        # Random +/-1 read-out. Frozen (requires_grad=False) but registered as a\n",
    "        # Parameter so it follows the module through .to(device) and state_dict.\n",
    "        a_tensor = (2.0 * torch.randint(0, 2, (self.m, output_dim)) - 1.0).float()\n",
    "        self.a = nn.Parameter(a_tensor, requires_grad=False)\n",
    "        self.b1 = nn.Parameter(torch.randn(self.m))\n",
    "        self.b2 = nn.Parameter(torch.randn(output_dim))\n",
    "\n",
    "    @staticmethod\n",
    "    def _sign_ste(w):\n",
    "        \"\"\"Return ``sign(w)`` in the forward pass with an identity gradient in the backward pass.\"\"\"\n",
    "        # torch.sign alone has zero gradient, which would leave W untrained.\n",
    "        return torch.sign(w).detach() + (w - w.detach())\n",
    "\n",
    "    def forward(self, X):\n",
    "        \"\"\"Map ``X`` of shape ``(..., input_dim)`` to shape ``(..., output_dim)``.\n",
    "\n",
    "        The module's parameters must already be on ``X.device``; move the module\n",
    "        with ``.to(device)`` rather than reassigning parameters here (assigning a\n",
    "        plain tensor to a registered Parameter attribute raises ``TypeError``).\n",
    "        \"\"\"\n",
    "        wt_W = self._sign_ste(self.W)\n",
    "        Y = self.relu(torch.matmul(X, wt_W) + self.b1)\n",
    "        return torch.matmul(Y, self.a) / (self.m ** 0.5) + self.b2\n",
    "\n",
    "class LinearNTK_from_pretrained(nn.Module):\n",
    "    \"\"\"Two-layer ReLU block with full-precision first-layer weights.\n",
    "\n",
    "    Computes ``relu(X @ W + b1) @ a / sqrt(m) + b2`` where ``a`` is a fixed random\n",
    "    +/-1 matrix; the full-precision variant of ``BitLinearNTK_from_pretrained``.\n",
    "\n",
    "    Args:\n",
    "        input_dim: size of the last dimension of the input ``X``.\n",
    "        output_dim: size of the last dimension of the output.\n",
    "        m: width of the hidden ReLU layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, output_dim, m=1000):\n",
    "        super().__init__()\n",
    "        self.m = m\n",
    "        self.W = nn.Parameter(torch.randn(input_dim, self.m))\n",
    "        self.relu = nn.ReLU()\n",
    "        # Random +/-1 read-out. Frozen (requires_grad=False) but registered as a\n",
    "        # Parameter so it follows the module through .to(device) and state_dict.\n",
    "        a_tensor = (2.0 * torch.randint(0, 2, (self.m, output_dim)) - 1.0).float()\n",
    "        self.a = nn.Parameter(a_tensor, requires_grad=False)\n",
    "        self.b1 = nn.Parameter(torch.randn(self.m))\n",
    "        self.b2 = nn.Parameter(torch.randn(output_dim))\n",
    "\n",
    "    def forward(self, X):\n",
    "        \"\"\"Map ``X`` of shape ``(..., input_dim)`` to shape ``(..., output_dim)``.\n",
    "\n",
    "        The module's parameters must already be on ``X.device`` (use ``.to(device)``).\n",
    "        \"\"\"\n",
    "        Y = self.relu(torch.matmul(X, self.W) + self.b1)\n",
    "        return torch.matmul(Y, self.a) / (self.m ** 0.5) + self.b2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f7e2b57-1ca3-4be6-9c85-2c8e34ff489a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class BitLinearNTKNetwork(nn.Module):\n",
    "    \"\"\"Chain of ``k`` binarized NTK blocks: input_dim -> hidden_dim (x k-1) -> output_dim.\n",
    "\n",
    "    Args:\n",
    "        input_dim: input feature size of the first block.\n",
    "        hidden_dim: feature size between consecutive blocks.\n",
    "        output_dim: output feature size of the last block.\n",
    "        k: number of blocks; ``k == 1`` maps ``input_dim`` straight to ``output_dim``.\n",
    "        m: hidden ReLU width inside every block.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``k < 1``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, hidden_dim, output_dim, k, m=1000):\n",
    "        super().__init__()\n",
    "        if k < 1:\n",
    "            raise ValueError(f\"k must be >= 1, got {k}\")\n",
    "        # Feature size at every block boundary, e.g. k=3 -> [in, hidden, hidden, out].\n",
    "        dims = [input_dim] + [hidden_dim] * (k - 1) + [output_dim]\n",
    "        self.layers = nn.ModuleList(\n",
    "            BitLinearNTK_from_pretrained(d_in, d_out, m)\n",
    "            for d_in, d_out in zip(dims[:-1], dims[1:])\n",
    "        )\n",
    "\n",
    "    def forward(self, X):\n",
    "        \"\"\"Apply the blocks in order.\"\"\"\n",
    "        out = X\n",
    "        for layer in self.layers:\n",
    "            out = layer(out)\n",
    "        return out\n",
    "\n",
    "class LinearNTKNetwork(nn.Module):\n",
    "    \"\"\"Chain of ``k`` full-precision NTK blocks: input_dim -> hidden_dim (x k-1) -> output_dim.\n",
    "\n",
    "    Args:\n",
    "        input_dim: input feature size of the first block.\n",
    "        hidden_dim: feature size between consecutive blocks.\n",
    "        output_dim: output feature size of the last block.\n",
    "        k: number of blocks; ``k == 1`` maps ``input_dim`` straight to ``output_dim``.\n",
    "        m: hidden ReLU width inside every block.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``k < 1``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, hidden_dim, output_dim, k, m=1000):\n",
    "        super().__init__()\n",
    "        if k < 1:\n",
    "            raise ValueError(f\"k must be >= 1, got {k}\")\n",
    "        # Feature size at every block boundary, e.g. k=3 -> [in, hidden, hidden, out].\n",
    "        dims = [input_dim] + [hidden_dim] * (k - 1) + [output_dim]\n",
    "        self.layers = nn.ModuleList(\n",
    "            LinearNTK_from_pretrained(d_in, d_out, m)\n",
    "            for d_in, d_out in zip(dims[:-1], dims[1:])\n",
    "        )\n",
    "\n",
    "    def forward(self, X):\n",
    "        \"\"\"Apply the blocks in order.\"\"\"\n",
    "        out = X\n",
    "        for layer in self.layers:\n",
    "            out = layer(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44866a3f-49ca-4f8f-9923-156396278673",
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_parameters(input_dim, hidden_dim, output_dim, k, m=None):\n",
    "    \"\"\"Count the ``W`` and ``a`` weight entries of a ``k``-block NTK network.\n",
    "\n",
    "    Follows the layout built by ``BitLinearNTKNetwork`` / ``LinearNTKNetwork``:\n",
    "    a block mapping ``d_in -> d_out`` holds ``W`` (``d_in x m``) and ``a``\n",
    "    (``m x d_out``). Biases are not counted.\n",
    "\n",
    "    Args:\n",
    "        input_dim, hidden_dim, output_dim: feature sizes, as for the network classes.\n",
    "        k: number of blocks (>= 1).\n",
    "        m: hidden width inside each block; defaults to ``10 * hidden_dim``, the\n",
    "            width the experiment runners use.\n",
    "\n",
    "    Returns:\n",
    "        Total number of ``W`` and ``a`` entries.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``k < 1``.\n",
    "    \"\"\"\n",
    "    if k < 1:\n",
    "        raise ValueError(f\"k must be >= 1, got {k}\")\n",
    "    if m is None:\n",
    "        m = 10 * hidden_dim\n",
    "    # Each block is counted exactly once; k == 1 is the single input -> output block.\n",
    "    dims = [input_dim] + [hidden_dim] * (k - 1) + [output_dim]\n",
    "    return sum(d_in * m + m * d_out for d_in, d_out in zip(dims[:-1], dims[1:]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "120b380e-1141-48df-a70a-c5d98519c450",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TargetFunction:\n",
    "    \"\"\"Base class for target functions used to generate regression data.\n",
    "\n",
    "    Subclasses implement ``__call__`` (tensor in, tensor out) and ``latex_format``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, output_dim):\n",
    "        self.input_dim = input_dim\n",
    "        self.output_dim = output_dim\n",
    "\n",
    "    def __call__(self, x):\n",
    "        raise NotImplementedError(\"Subclasses must implement this method\")\n",
    "\n",
    "    def generate_train_data(self, num_samples, low=-np.pi, high=np.pi):\n",
    "        \"\"\"Generate uniformly spaced training data.\n",
    "\n",
    "        Returns ``x`` of shape ``(num_samples, 1)`` evenly spaced over ``[low, high]``\n",
    "        and ``y = self(x)``. Only a 1-D grid is built, so this assumes ``input_dim == 1``.\n",
    "        \"\"\"\n",
    "        x = torch.linspace(low, high, num_samples).unsqueeze(1)\n",
    "        y = self(x)\n",
    "        return x, y\n",
    "\n",
    "    def generate_test_data(self, num_samples, low=-np.pi, high=np.pi, seed=42):\n",
    "        \"\"\"Generate randomly spaced testing data.\n",
    "\n",
    "        Draws ``num_samples`` points uniformly from ``[low, high)`` in each of the\n",
    "        ``input_dim`` coordinates and orders the rows by their first coordinate.\n",
    "\n",
    "        Note: reseeds the *global* torch and numpy RNGs with ``seed``, so random\n",
    "        draws made after this call (e.g. model initialization) depend on it.\n",
    "        \"\"\"\n",
    "        torch.manual_seed(seed)\n",
    "        np.random.seed(seed)\n",
    "        x = torch.rand(num_samples, self.input_dim) * (high - low) + low\n",
    "        # Reorder whole rows; sorting each column independently (torch.sort(x, dim=0))\n",
    "        # would re-pair coordinates across columns whenever input_dim > 1.\n",
    "        x = x[torch.argsort(x[:, 0])]\n",
    "        y = self(x)\n",
    "        return x, y\n",
    "\n",
    "    def latex_format(self):\n",
    "        raise NotImplementedError(\"Subclasses must provide their LaTeX format\")\n",
    "\n",
    "\n",
    "class BesselTargetFunction(TargetFunction):\n",
    "    \"\"\"Target ``f(x) = J_0(20 x)``, the order-0 Bessel function of the first kind.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__(input_dim=1, output_dim=1)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        # scipy needs a CPU NumPy array: detach (x may require grad) and move to CPU\n",
    "        # first, then return the result on the same device as the input.\n",
    "        values = jv(0, 20 * x.detach().cpu().numpy())\n",
    "        return torch.as_tensor(values, device=x.device)\n",
    "\n",
    "    def latex_format(self):\n",
    "        return r\"f(x) = J_0(20x)\"\n",
    "\n",
    "\n",
    "class LogSinTanExpTargetFunction(TargetFunction):\n",
    "    \"\"\"Target mixing a log-amplitude sine with a Gaussian-damped tangent term.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__(input_dim=1, output_dim=1)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        amplitude = torch.log(1 + torch.abs(x))\n",
    "        wave = amplitude * torch.sin(10 * x)\n",
    "        damped_tan = 10 * torch.tan(x / 6) * torch.exp(-(x ** 2))\n",
    "        return wave + damped_tan\n",
    "\n",
    "    def latex_format(self):\n",
    "        return r\"f(x) = \\log(1 + |x|) \\sin(10x) + 10 \\tan\\left(\\frac{x}{6}\\right) \\exp(-x^2)\"\n",
    "\n",
    "\n",
    "class ExpSinHeavisideTargetFunction(TargetFunction):\n",
    "    \"\"\"Target: damped sine ``exp(-0.5|x|) sin(5x)`` plus a unit step at ``x = 1``.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__(input_dim=1, output_dim=1)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        # torch.heaviside requires `values` to share the input's dtype, so build the\n",
    "        # value used at x == 1 (zero) from x itself; this also keeps it on x's device.\n",
    "        step = torch.heaviside(x - 1, x.new_zeros(()))\n",
    "        return torch.exp(-0.5 * torch.abs(x)) * torch.sin(5 * x) + step\n",
    "\n",
    "    def latex_format(self):\n",
    "        return r\"f(x) = \\exp(-0.5 |x|) \\sin(5x) + H(x - 1)\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1917ea53-8b2a-46c6-a686-91bd4f88ad73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration.\n",
    "INPUT_DIM, OUTPUT_DIM = 1, 1\n",
    "HIDDEN_DIM = [1000]  # hidden sizes to sweep; each block uses m = 10 * hidden_dim\n",
    "K = [2]             # numbers of stacked blocks to sweep\n",
    "num_epochs = 10000\n",
    "num_samples = 100\n",
    "device = torch.device(\"cpu\")\n",
    "\n",
    "func = BesselTargetFunction()\n",
    "X_train, Y_train = func.generate_train_data(num_samples, low=-np.pi, high=np.pi)\n",
    "# generate_test_data reseeds the global RNGs, so later model initialization is reproducible.\n",
    "X_test, Y_test = func.generate_test_data(num_samples, low=-np.pi, high=np.pi)\n",
    "X_train, Y_train, X_test, Y_test = (\n",
    "    t.float().to(device) for t in (X_train, Y_train, X_test, Y_test)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2aa42136-11f7-4854-a55a-67c6e2b0b889",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, optimizer, scheduler, criterion, X_train, Y_train, device, num_epochs=100000,\n",
    "                clip_value=1.0, print_every=100, X_test=None, Y_test=None, history=None):\n",
    "    \"\"\"Full-batch training loop; returns the trained model's predictions on the test inputs.\n",
    "\n",
    "    Every epoch takes one optimizer step on ``0.5 * criterion(model(X_train), Y_train)``\n",
    "    with gradient-norm clipping, then calls ``scheduler.step`` with the loss (a\n",
    "    metric-driven scheduler such as ``ReduceLROnPlateau`` is expected). Every\n",
    "    ``print_every`` epochs the loss is printed and the train/test residual norms are recorded.\n",
    "\n",
    "    Args:\n",
    "        model, optimizer, scheduler, criterion: training components.\n",
    "        X_train, Y_train: full training batch.\n",
    "        device: device the model and data are moved to.\n",
    "        num_epochs: number of full-batch steps.\n",
    "        clip_value: max gradient norm passed to ``clip_grad_norm_``.\n",
    "        print_every: logging / evaluation interval in epochs.\n",
    "        X_test, Y_test: evaluation data. When omitted, the notebook-level globals\n",
    "            ``X_test`` / ``Y_test`` are used, which existing callers rely on.\n",
    "        history: optional dict; if given, its ``\"train\"`` and ``\"test\"`` entries are\n",
    "            set to the lists of residual norms recorded every ``print_every`` epochs.\n",
    "\n",
    "    Returns:\n",
    "        ``model(X_test)`` computed in eval mode without gradient tracking.\n",
    "    \"\"\"\n",
    "    # Backward compatibility: fall back to the notebook's global test set.\n",
    "    if X_test is None:\n",
    "        X_test = globals()[\"X_test\"]\n",
    "    if Y_test is None:\n",
    "        Y_test = globals().get(\"Y_test\")\n",
    "\n",
    "    model = model.to(device)\n",
    "    X_train, Y_train, X_test = X_train.to(device), Y_train.to(device), X_test.to(device)\n",
    "    if Y_test is not None:\n",
    "        Y_test = Y_test.to(device)\n",
    "    error_train, error_test = [], []\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        model.train()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        outputs = model(X_train)\n",
    "        # Halved loss: 0.5 * sum of squared errors when criterion is MSELoss(reduction='sum').\n",
    "        loss = criterion(outputs, Y_train) * 0.5\n",
    "\n",
    "        loss.backward()\n",
    "        nn.utils.clip_grad_norm_(model.parameters(), clip_value)\n",
    "        optimizer.step()\n",
    "\n",
    "        loss_value = loss.item()\n",
    "        scheduler.step(loss_value)\n",
    "\n",
    "        if (epoch + 1) % print_every == 0:\n",
    "            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss_value:.4f}')\n",
    "            # Evaluation only: no autograd graph is needed.\n",
    "            model.eval()\n",
    "            with torch.no_grad():\n",
    "                error_train.append(torch.norm(model(X_train) - Y_train).item())\n",
    "                if Y_test is not None:\n",
    "                    error_test.append(torch.norm(model(X_test) - Y_test).item())\n",
    "\n",
    "    if history is not None:\n",
    "        history[\"train\"] = error_train\n",
    "        history[\"test\"] = error_test\n",
    "\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        y_pred = model(X_test)\n",
    "    return y_pred\n",
    "\n",
    "def run_experiment_1bit(X_train, Y_train, K, HIDDEN_DIM, INPUT_DIM, OUTPUT_DIM, device, num_epochs=10000):\n",
    "    \"\"\"Train one ``BitLinearNTKNetwork`` for every (k, hidden_dim) combination.\n",
    "\n",
    "    Each block uses a hidden width of ``10 * hidden_dim``. Returns a dict keyed by\n",
    "    ``(hidden_dim, k)`` whose values are the test-set predictions from ``train_model``.\n",
    "    \"\"\"\n",
    "    criterion = nn.MSELoss(reduction='sum')\n",
    "    predictions = {}\n",
    "    for k, hd in ((depth, width) for depth in K for width in HIDDEN_DIM):\n",
    "        print(f\"Running experiment with hidden_dim={hd}, k={k}\")\n",
    "        model = BitLinearNTKNetwork(INPUT_DIM, hd, OUTPUT_DIM, k, 10 * hd)\n",
    "        optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)\n",
    "        predictions[(hd, k)] = train_model(\n",
    "            model, optimizer, scheduler, criterion, X_train, Y_train, device, num_epochs\n",
    "        )\n",
    "    return predictions\n",
    "\n",
    "def run_experiment_standard(X_train, Y_train, K, HIDDEN_DIM, INPUT_DIM, OUTPUT_DIM, device, num_epochs=10000):\n",
    "    \"\"\"Train one ``LinearNTKNetwork`` for every (k, hidden_dim) combination.\n",
    "\n",
    "    Each block uses a hidden width of ``10 * hidden_dim``. Returns a dict keyed by\n",
    "    ``(hidden_dim, k)`` whose values are the test-set predictions from ``train_model``.\n",
    "    \"\"\"\n",
    "    criterion = nn.MSELoss(reduction='sum')\n",
    "    predictions = {}\n",
    "    for k, hd in ((depth, width) for depth in K for width in HIDDEN_DIM):\n",
    "        print(f\"Running experiment with hidden_dim={hd}, k={k}\")\n",
    "        model = LinearNTKNetwork(INPUT_DIM, hd, OUTPUT_DIM, k, 10 * hd)\n",
    "        optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)\n",
    "        predictions[(hd, k)] = train_model(\n",
    "            model, optimizer, scheduler, criterion, X_train, Y_train, device, num_epochs\n",
    "        )\n",
    "    return predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efc64dbe-8743-4646-baed-a2ff0f393759",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train both model families on identical data; each result maps (hidden_dim, k) -> test predictions.\n",
    "results_standard = run_experiment_standard(\n",
    "    X_train, Y_train, K, HIDDEN_DIM, INPUT_DIM, OUTPUT_DIM, device, num_epochs=num_epochs\n",
    ")\n",
    "results_1bit = run_experiment_1bit(\n",
    "    X_train, Y_train, K, HIDDEN_DIM, INPUT_DIM, OUTPUT_DIM, device, num_epochs=num_epochs\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad20269f-8a1e-4aea-8b14-4734d4ec682b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79c9fb9d-f481-45c4-9a6f-f81e71e1eff8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import rc\n",
    "import numpy as np\n",
    "rc('font', family='sans-serif', size=20)\n",
    "rc('text', usetex=True)\n",
    "colors = ['#176BEF', '#FF3E30', '#107C10', '#FFA900', '#cc3226', '#0c5e0c', '#cc8800']\n",
    "markers = ['o', '*', 's', '^', 'D']\n",
    "linestyles = ['-', '-', '-', '--']\n",
    "\n",
    "# Plot the first configuration that was actually trained rather than a hard-coded key.\n",
    "plot_key = (HIDDEN_DIM[0], K[0])\n",
    "x_plot = X_test.cpu().numpy()\n",
    "\n",
    "plt.figure(figsize=(10,8))\n",
    "# latex_format() returns bare LaTeX; wrap it in $...$ so it is typeset in math mode.\n",
    "# With usetex=True, subscripts and math commands outside math mode make LaTeX fail.\n",
    "plt.plot(x_plot, Y_test.cpu().numpy(), label=f\"${func.latex_format()}$\", color=colors[0], alpha=0.5)\n",
    "\n",
    "plt.scatter(x_plot, results_standard[plot_key].cpu().numpy(), label='Standard MLP', color=colors[6], marker=markers[2], alpha=1, s=50)\n",
    "plt.scatter(x_plot, results_1bit[plot_key].cpu().numpy(), label='1-bit MLP', color=colors[5], marker=markers[1], alpha=1, s=30)\n",
    "\n",
    "plt.legend(loc='upper left', ncol=1, fontsize=25)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.grid(True)\n",
    "plt.savefig(\"1d-f3.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "798174a7-0dd2-4b91-9fbc-2d35eafb4609",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "kvsketch",
   "language": "python",
   "name": "kv_sketch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
