{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ce8ebf1-20c7-43e9-be5e-32a4fac01e80",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from scipy.special import jv\n",
    "import pprint\n",
    "import json\n",
    "\n",
    "def sign_ste(x):\n",
    "    return torch.sign(x)\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class BitLinearNTK_from_pretrained(nn.Module):\n",
    "    def __init__(self, input_dim, output_dim, m=1000):\n",
    "        super(BitLinearNTK_from_pretrained, self).__init__()\n",
    "        self.m = m\n",
    "        self.W = nn.Parameter(torch.randn(input_dim, self.m)) \n",
    "        self.relu = nn.ReLU()\n",
    "        a_tensor = (2.0 * torch.randint(0, 2, (self.m, output_dim)) - 1.0).float()\n",
    "        self.a = nn.Parameter(a_tensor, requires_grad=False)\n",
    "        self.b1 = nn.Parameter(torch.randn(self.m)) \n",
    "        self.b2 = nn.Parameter(torch.randn(output_dim)) \n",
    "\n",
    "    def forward(self, X):\n",
    "        device = X.device\n",
    "        self.W = self.W.to(device)\n",
    "        self.a = self.a.to(device)\n",
    "        \n",
    "        def sign_ste(input):\n",
    "            return torch.sign(input)\n",
    "        \n",
    "        wt_W = sign_ste((self.W))\n",
    "        Y = self.relu(torch.matmul(X, wt_W) + self.b1)\n",
    "        Y = torch.matmul(Y, self.a) / (self.m ** 0.5) + self.b2\n",
    "\n",
    "        return Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f7e2b57-1ca3-4be6-9c85-2c8e34ff489a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class BitLinearNTKNetwork(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dim, output_dim, k, m=1000):\n",
    "        super(BitLinearNTKNetwork, self).__init__()\n",
    "        layers = []\n",
    "        if k == 1:\n",
    "            layers.append(BitLinearNTK_from_pretrained(input_dim, output_dim, m))\n",
    "        else:\n",
    "            layers.append(BitLinearNTK_from_pretrained(input_dim, hidden_dim, m))\n",
    "            for _ in range(k - 2):\n",
    "                layers.append(BitLinearNTK_from_pretrained(hidden_dim, hidden_dim, m))\n",
    "            layers.append(BitLinearNTK_from_pretrained(hidden_dim, output_dim, m))\n",
    "        self.layers = nn.ModuleList(layers)\n",
    "\n",
    "    def forward(self, X):\n",
    "        out = X\n",
    "        for layer in self.layers:\n",
    "            out = layer(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aecbe2b4-842a-42f6-8410-50ea3b14ee2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_parameters(input_dim, hidden_dim, output_dim, k):\n",
    "    m = 10 * hidden_dim\n",
    "    total_params = (input_dim * m) + (m * output_dim)\n",
    "    if k > 1:\n",
    "        total_params += (input_dim * m) + (m * hidden_dim)\n",
    "        total_params += (k - 2)* ((hidden_dim * m) + (m * hidden_dim))\n",
    "        total_params += (hidden_dim * m) + (m * output_dim)\n",
    "    \n",
    "    return total_params\n",
    "\n",
    "def human_readable_params(num):\n",
    "    if num >= 1_000_000: \n",
    "        return f'{num / 1_000_000:.2f}M Parameters'\n",
    "    elif num >= 1_000: \n",
    "        return f'{num / 1_000:.2f}K Parameters'\n",
    "    else: \n",
    "        return f'{num} Parameters'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb9bb097-b200-4fe4-9494-c827d3cec562",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TargetFunction:\n",
    "    def __init__(self, input_dim, output_dim):\n",
    "        self.input_dim = input_dim\n",
    "        self.output_dim = output_dim\n",
    "\n",
    "    def __call__(self, x):\n",
    "        raise NotImplementedError(\"Subclasses must implement this method\")\n",
    "\n",
    "    def generate_train_data(self, num_samples, low=-np.pi, high=np.pi):\n",
    "        \"\"\"Generate uniformly spaced training data.\"\"\"\n",
    "        x = torch.linspace(low, high, num_samples).unsqueeze(1)\n",
    "        y = self(x)\n",
    "        return x, y\n",
    "\n",
    "    def generate_test_data(self, num_samples, low=-np.pi, high=np.pi, seed=42):\n",
    "        \"\"\"Generate randomly spaced testing data.\"\"\"\n",
    "        torch.manual_seed(seed)\n",
    "        np.random.seed(seed)\n",
    "        x = torch.rand(num_samples, self.input_dim) * (high - low) + low\n",
    "        x, _ = torch.sort(x, dim=0)\n",
    "        y = self(x)\n",
    "        return x, y\n",
    "\n",
    "    def latex_format(self):\n",
    "        raise NotImplementedError(\"Subclasses must provide their LaTeX format\")\n",
    "\n",
    "\n",
    "class BesselTargetFunction(TargetFunction):\n",
    "    def __init__(self):\n",
    "        super().__init__(input_dim=1, output_dim=1)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        return torch.from_numpy(jv(0, 20 * x.numpy()))\n",
    "\n",
    "    def latex_format(self):\n",
    "        return r\"f(x) = J_0(20x)\"\n",
    "\n",
    "\n",
    "class LogSinTanExpTargetFunction(TargetFunction):\n",
    "    def __init__(self):\n",
    "        super().__init__(input_dim=1, output_dim=1)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        return torch.log(1 + torch.abs(x)) * torch.sin(10 * x) + 10 * torch.tan(x / 6) * torch.exp(-x**2)\n",
    "\n",
    "    def latex_format(self):\n",
    "        return r\"f(x) = \\log(1 + |x|) \\sin(10x) + 10 \\tan\\left(\\frac{x}{6}\\right) \\exp(-x^2)\"\n",
    "\n",
    "\n",
    "class ExpSinHeavisideTargetFunction(TargetFunction):\n",
    "    def __init__(self):\n",
    "        super().__init__(input_dim=1, output_dim=1)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        return torch.exp(-0.5 * torch.abs(x)) * torch.sin(5 * x) + torch.heaviside(x - 1, torch.tensor(0.0))\n",
    "\n",
    "    def latex_format(self):\n",
    "        return r\"f(x) = \\exp(-0.5 |x|) \\sin(5x) + H(x - 1)\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1917ea53-8b2a-46c6-a686-91bd4f88ad73",
   "metadata": {},
   "outputs": [],
   "source": [
    "INPUT_DIM = 1\n",
    "HIDDEN_DIM = [10, 100, 1000]\n",
    "OUTPUT_DIM = 1      \n",
    "K = [2]\n",
    "num_epochs = 10000\n",
    "num_samples = 100\n",
    "device = torch.device(\"cpu\")\n",
    "\n",
    "func = BesselTargetFunction()\n",
    "X_train, Y_train = func.generate_train_data(num_samples, low=-np.pi, high=np.pi)\n",
    "X_test, Y_test = func.generate_test_data(num_samples, low=-np.pi, high=np.pi)\n",
    "X_train, Y_train = X_train.float().to(device), Y_train.float().to(device)\n",
    "X_test, Y_test = X_test.float().to(device), Y_test.float().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "672bbbdc-c780-48a9-908e-e0addfaeb23f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, optimizer, scheduler, criterion, X_train, Y_train, device, num_epochs=100000, clip_value=1.0, print_every=100):\n",
    "    model = model.to(device)\n",
    "    error_train, error_test = [], []\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        model.train()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        outputs = model(X_train)\n",
    "        loss = criterion(outputs, Y_train) * 0.5  # Scale loss by 0.5\n",
    "\n",
    "        loss.backward()\n",
    "        nn.utils.clip_grad_norm_(model.parameters(), clip_value)\n",
    "        optimizer.step()\n",
    "\n",
    "        scheduler.step(loss)\n",
    "\n",
    "        if (epoch + 1) % print_every == 0:\n",
    "            loss_value = loss.item()\n",
    "            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss_value:.4f}')\n",
    "            y_pred_test = model(X_test)\n",
    "            y_pred_train = model(X_train)\n",
    "            \n",
    "            error_train.append(torch.norm(y_pred_train - Y_train).cpu().detach().numpy())\n",
    "            error_test.append(torch.norm(y_pred_test - Y_test).cpu().detach().numpy())\n",
    "\n",
    "    return {\"train\" : error_train, \"test\": error_test}\n",
    "\n",
    "def run_experiment_1bit(X_train, Y_train, K, HIDDEN_DIM, INPUT_DIM, OUTPUT_DIM, device, num_epochs=10000):\n",
    "    res = {}\n",
    "    criterion = nn.MSELoss(reduction='sum')\n",
    "\n",
    "    for k in K:\n",
    "        for hd in HIDDEN_DIM:\n",
    "            M = 10 * hd\n",
    "            print(f\"Running experiment with hidden_dim={hd}, k={k}\")\n",
    "\n",
    "            model = BitLinearNTKNetwork(INPUT_DIM, hd, OUTPUT_DIM, k, M)\n",
    "            optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)\n",
    "\n",
    "            min_loss = train_model(model, optimizer, scheduler, criterion, X_train, Y_train, device, num_epochs)\n",
    "            res[(hd, k)] = min_loss\n",
    "\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57cbc07e-e289-4e29-bf59-80bd3e717bb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_1bit = run_experiment_1bit(X_train, Y_train, K, HIDDEN_DIM, INPUT_DIM, OUTPUT_DIM, device, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79c9fb9d-f481-45c4-9a6f-f81e71e1eff8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import rc\n",
    "import numpy as np\n",
    "\n",
    "# Configure plot fonts and styles\n",
    "rc('font', family='sans-serif', size=30)\n",
    "rc('text', usetex=True)\n",
    "\n",
    "colors = ['#176BEF', '#cc8800', '#FF3E30', '#107C10', '#FFA900', '#cc3226', '#0c5e0c', ]\n",
    "markers = ['o', '*', 's', '^', 'D', 'P', 'X']\n",
    "linestyles = ['-', '--', ':', '-.']\n",
    "\n",
    "def plot_train_test_errors_all(res_standard, configs):\n",
    "    plt.figure(figsize=(10, 10))\n",
    "\n",
    "    for idx, config_key in enumerate(configs):\n",
    "        train_errors = res_standard[config_key][0]\n",
    "        test_errors = res_standard[config_key][1]\n",
    "\n",
    "        train_errors = [x for x in train_errors]\n",
    "        test_errors = [x for x in test_errors]\n",
    "\n",
    "        epochs = range(1, len(train_errors) + 1)\n",
    "\n",
    "        plt.plot(epochs, train_errors, label=f'Train ({human_readable_params(count_parameters(1, config_key[0], 1, 2))})', linestyle='-', color=colors[idx], marker=markers[idx])\n",
    "        plt.plot(epochs, test_errors, label=f'Test ({human_readable_params(count_parameters(1, config_key[0], 1, 2))})', linestyle='-', color=colors[idx + 4], marker=markers[idx])\n",
    "\n",
    "    plt.xlabel('Epoch')\n",
    "    plt.ylabel('Error')\n",
    "    plt.grid(True)    \n",
    "    \n",
    "    # plt.legend(loc='upper center', fontsize=50, ncol=3)\n",
    "\n",
    "    # Show the plot\n",
    "    plt.savefig('error_f2.pdf', bbox_inches='tight')\n",
    "    plt.show()\n",
    "    \n",
    "\n",
    "configs = [(10, 2), (100, 2), (1000, 2)]\n",
    "plot_train_test_errors_all(results_1bit, configs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "kvsketch",
   "language": "python",
   "name": "kv_sketch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
