{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "82d10b5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import math\n",
    "\n",
    "\"\"\"credit to opensource https://github.com/Blealtan/efficient-kan\"\"\"\n",
    "\n",
    "class KANLinear(torch.nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        in_features,\n",
    "        out_features,\n",
    "        grid_size=5,\n",
    "        spline_order=3,\n",
    "        scale_noise=0.1,\n",
    "        scale_base=1.0,\n",
    "        scale_spline=1.0,\n",
    "        enable_standalone_scale_spline=True,\n",
    "        base_activation=torch.nn.SiLU,\n",
    "        grid_eps=0.02,\n",
    "        grid_range=[-1, 1],\n",
    "    ):\n",
    "        super(KANLinear, self).__init__()\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "        self.grid_size = grid_size\n",
    "        self.spline_order = spline_order\n",
    "\n",
    "        h = (grid_range[1] - grid_range[0]) / grid_size\n",
    "        grid = (\n",
    "            (\n",
    "                torch.arange(-spline_order, grid_size + spline_order + 1) * h\n",
    "                + grid_range[0]\n",
    "            )\n",
    "            .expand(in_features, -1)\n",
    "            .contiguous()\n",
    "        )\n",
    "        self.register_buffer(\"grid\", grid)\n",
    "\n",
    "        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))\n",
    "        self.spline_weight = torch.nn.Parameter(\n",
    "            torch.Tensor(out_features, in_features, grid_size + spline_order)\n",
    "        )\n",
    "        if enable_standalone_scale_spline:\n",
    "            self.spline_scaler = torch.nn.Parameter(\n",
    "                torch.Tensor(out_features, in_features)\n",
    "            )\n",
    "\n",
    "        self.scale_noise = scale_noise\n",
    "        self.scale_base = scale_base\n",
    "        self.scale_spline = scale_spline\n",
    "        self.enable_standalone_scale_spline = enable_standalone_scale_spline\n",
    "        self.base_activation = base_activation()\n",
    "        self.grid_eps = grid_eps\n",
    "\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)\n",
    "        with torch.no_grad():\n",
    "            noise = (\n",
    "                (\n",
    "                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)\n",
    "                    - 1 / 2\n",
    "                )\n",
    "                * self.scale_noise\n",
    "                / self.grid_size\n",
    "            )\n",
    "            self.spline_weight.data.copy_(\n",
    "                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)\n",
    "                * self.curve2coeff(\n",
    "                    self.grid.T[self.spline_order : -self.spline_order],\n",
    "                    noise,\n",
    "                )\n",
    "            )\n",
    "            if self.enable_standalone_scale_spline:\n",
    "                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)\n",
    "                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)\n",
    "\n",
    "    def b_splines(self, x: torch.Tensor):\n",
    "        \"\"\"\n",
    "        Compute the B-spline bases for the given input tensor.\n",
    "\n",
    "        Args:\n",
    "            x (torch.Tensor): Input tensor of shape (batch_size, in_features).\n",
    "\n",
    "        Returns:\n",
    "            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).\n",
    "        \"\"\"\n",
    "        assert x.dim() == 2 and x.size(1) == self.in_features\n",
    "\n",
    "        grid: torch.Tensor = (\n",
    "            self.grid\n",
    "        )  # (in_features, grid_size + 2 * spline_order + 1)\n",
    "        x = x.unsqueeze(-1)\n",
    "        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)\n",
    "        for k in range(1, self.spline_order + 1):\n",
    "            bases = (\n",
    "                (x - grid[:, : -(k + 1)])\n",
    "                / (grid[:, k:-1] - grid[:, : -(k + 1)])\n",
    "                * bases[:, :, :-1]\n",
    "            ) + (\n",
    "                (grid[:, k + 1 :] - x)\n",
    "                / (grid[:, k + 1 :] - grid[:, 1:(-k)])\n",
    "                * bases[:, :, 1:]\n",
    "            )\n",
    "\n",
    "        assert bases.size() == (\n",
    "            x.size(0),\n",
    "            self.in_features,\n",
    "            self.grid_size + self.spline_order,\n",
    "        )\n",
    "        return bases.contiguous()\n",
    "\n",
    "    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):\n",
    "        \"\"\"\n",
    "        Compute the coefficients of the curve that interpolates the given points.\n",
    "\n",
    "        Args:\n",
    "            x (torch.Tensor): Input tensor of shape (batch_size, in_features).\n",
    "            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).\n",
    "\n",
    "        Returns:\n",
    "            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).\n",
    "        \"\"\"\n",
    "        assert x.dim() == 2 and x.size(1) == self.in_features\n",
    "        assert y.size() == (x.size(0), self.in_features, self.out_features)\n",
    "\n",
    "        A = self.b_splines(x).transpose(\n",
    "            0, 1\n",
    "        )  # (in_features, batch_size, grid_size + spline_order)\n",
    "        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)\n",
    "        solution = torch.linalg.lstsq(\n",
    "            A, B\n",
    "        ).solution  # (in_features, grid_size + spline_order, out_features)\n",
    "        result = solution.permute(\n",
    "            2, 0, 1\n",
    "        )  # (out_features, in_features, grid_size + spline_order)\n",
    "\n",
    "        assert result.size() == (\n",
    "            self.out_features,\n",
    "            self.in_features,\n",
    "            self.grid_size + self.spline_order,\n",
    "        )\n",
    "        return result.contiguous()\n",
    "\n",
    "    @property\n",
    "    def scaled_spline_weight(self):\n",
    "        return self.spline_weight * (\n",
    "            self.spline_scaler.unsqueeze(-1)\n",
    "            if self.enable_standalone_scale_spline\n",
    "            else 1.0\n",
    "        )\n",
    "\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        assert x.size(-1) == self.in_features\n",
    "        original_shape = x.shape\n",
    "        x = x.reshape(-1, self.in_features)\n",
    "\n",
    "        base_output = F.linear(self.base_activation(x), self.base_weight)\n",
    "        spline_output = F.linear(\n",
    "            self.b_splines(x).view(x.size(0), -1),\n",
    "            self.scaled_spline_weight.view(self.out_features, -1),\n",
    "        )\n",
    "        output = base_output + spline_output\n",
    "        \n",
    "        output = output.reshape(*original_shape[:-1], self.out_features)\n",
    "        return output\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def update_grid(self, x: torch.Tensor, margin=0.01):\n",
    "        assert x.dim() == 2 and x.size(1) == self.in_features\n",
    "        batch = x.size(0)\n",
    "\n",
    "        splines = self.b_splines(x)  # (batch, in, coeff)\n",
    "        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)\n",
    "        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)\n",
    "        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)\n",
    "        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)\n",
    "        unreduced_spline_output = unreduced_spline_output.permute(\n",
    "            1, 0, 2\n",
    "        )  # (batch, in, out)\n",
    "\n",
    "        # sort each channel individually to collect data distribution\n",
    "        x_sorted = torch.sort(x, dim=0)[0]\n",
    "        grid_adaptive = x_sorted[\n",
    "            torch.linspace(\n",
    "                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device\n",
    "            )\n",
    "        ]\n",
    "\n",
    "        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size\n",
    "        grid_uniform = (\n",
    "            torch.arange(\n",
    "                self.grid_size + 1, dtype=torch.float32, device=x.device\n",
    "            ).unsqueeze(1)\n",
    "            * uniform_step\n",
    "            + x_sorted[0]\n",
    "            - margin\n",
    "        )\n",
    "\n",
    "        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive\n",
    "        grid = torch.concatenate(\n",
    "            [\n",
    "                grid[:1]\n",
    "                - uniform_step\n",
    "                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),\n",
    "                grid,\n",
    "                grid[-1:]\n",
    "                + uniform_step\n",
    "                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),\n",
    "            ],\n",
    "            dim=0,\n",
    "        )\n",
    "\n",
    "        self.grid.copy_(grid.T)\n",
    "        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))\n",
    "\n",
    "    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):\n",
    "        \"\"\"\n",
    "        Compute the regularization loss.\n",
    "\n",
    "        This is a dumb simulation of the original L1 regularization as stated in the\n",
    "        paper, since the original one requires computing absolutes and entropy from the\n",
    "        expanded (batch, in_features, out_features) intermediate tensor, which is hidden\n",
    "        behind the F.linear function if we want an memory efficient implementation.\n",
    "\n",
    "        The L1 regularization is now computed as mean absolute value of the spline\n",
    "        weights. The authors implementation also includes this term in addition to the\n",
    "        sample-based regularization.\n",
    "        \"\"\"\n",
    "        l1_fake = self.spline_weight.abs().mean(-1)\n",
    "        regularization_loss_activation = l1_fake.sum()\n",
    "        p = l1_fake / regularization_loss_activation\n",
    "        regularization_loss_entropy = -torch.sum(p * p.log())\n",
    "        return (\n",
    "            regularize_activation * regularization_loss_activation\n",
    "            + regularize_entropy * regularization_loss_entropy\n",
    "        )\n",
    "\n",
    "\n",
    "class KAN(torch.nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        layers_hidden,\n",
    "        grid_size=5,\n",
    "        spline_order=3,\n",
    "        scale_noise=0.1,\n",
    "        scale_base=1.0,\n",
    "        scale_spline=1.0,\n",
    "        base_activation=torch.nn.SiLU,\n",
    "        grid_eps=0.02,\n",
    "        grid_range=[-1, 1],\n",
    "    ):\n",
    "        super(KAN, self).__init__()\n",
    "        self.grid_size = grid_size\n",
    "        self.spline_order = spline_order\n",
    "\n",
    "        self.layers = torch.nn.ModuleList()\n",
    "        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):\n",
    "            self.layers.append(\n",
    "                KANLinear(\n",
    "                    in_features,\n",
    "                    out_features,\n",
    "                    grid_size=grid_size,\n",
    "                    spline_order=spline_order,\n",
    "                    scale_noise=scale_noise,\n",
    "                    scale_base=scale_base,\n",
    "                    scale_spline=scale_spline,\n",
    "                    base_activation=base_activation,\n",
    "                    grid_eps=grid_eps,\n",
    "                    grid_range=grid_range,\n",
    "                )\n",
    "            )\n",
    "\n",
    "    def forward(self, x: torch.Tensor, update_grid=False):\n",
    "        for layer in self.layers:\n",
    "            if update_grid:\n",
    "                layer.update_grid(x)\n",
    "            x = layer(x)\n",
    "        return x\n",
    "\n",
    "    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):\n",
    "        return sum(\n",
    "            layer.regularization_loss(regularize_activation, regularize_entropy)\n",
    "            for layer in self.layers\n",
    "        )"
   ]
  },
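  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2a9c3e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal shape sanity check for KANLinear (an illustrative sketch, not part of\n",
    "# training): b_splines yields grid_size + spline_order basis values per feature.\n",
    "_layer = KANLinear(in_features=4, out_features=3)\n",
    "_x = torch.randn(8, 4)\n",
    "assert _layer.b_splines(_x).shape == (8, 4, _layer.grid_size + _layer.spline_order)\n",
    "assert _layer(_x).shape == (8, 3)"
   ]
  },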
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5b97921c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 938/938 [00:06<00:00, 136.33it/s, accuracy=0.875, loss=0.314, lr=0.001] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, Val Loss: 0.23603566991058506, Val Accuracy: 0.9290406050955414\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 938/938 [00:06<00:00, 136.44it/s, accuracy=1, loss=0.0354, lr=0.0008]    \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2, Val Loss: 0.15943998358194617, Val Accuracy: 0.9548168789808917\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 938/938 [00:06<00:00, 139.90it/s, accuracy=0.969, loss=0.0508, lr=0.00064]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3, Val Loss: 0.14269541657179785, Val Accuracy: 0.9600915605095541\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 938/938 [00:07<00:00, 118.11it/s, accuracy=0.969, loss=0.134, lr=0.000512] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4, Val Loss: 0.11162164308735804, Val Accuracy: 0.9666600318471338\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 938/938 [00:09<00:00, 101.68it/s, accuracy=1, loss=0.00867, lr=0.00041]   \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5, Val Loss: 0.10873399607119429, Val Accuracy: 0.9674562101910829\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 938/938 [00:07<00:00, 125.10it/s, accuracy=1, loss=0.00895, lr=0.000328]   \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 6, Val Loss: 0.09907284306709292, Val Accuracy: 0.9704418789808917\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 938/938 [00:07<00:00, 126.58it/s, accuracy=1, loss=0.0279, lr=0.000262]    \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 7, Val Loss: 0.09754395741557073, Val Accuracy: 0.9692476114649682\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 938/938 [00:08<00:00, 109.79it/s, accuracy=0.969, loss=0.0769, lr=0.00021]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 8, Val Loss: 0.09209574455362715, Val Accuracy: 0.9721337579617835\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 938/938 [00:08<00:00, 105.37it/s, accuracy=1, loss=0.0155, lr=0.000168]    \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 9, Val Loss: 0.0904715131185237, Val Accuracy: 0.9710390127388535\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 938/938 [00:07<00:00, 120.90it/s, accuracy=0.969, loss=0.078, lr=0.000134] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10, Val Loss: 0.08993846593294175, Val Accuracy: 0.9722332802547771\n"
     ]
    }
   ],
   "source": [
    "# Train on MNIST\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Load MNIST\n",
    "transform = transforms.Compose(\n",
    "    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]\n",
    ")\n",
    "trainset = torchvision.datasets.MNIST(\n",
    "    root=\"./data\", train=True, download=True, transform=transform\n",
    ")\n",
    "valset = torchvision.datasets.MNIST(\n",
    "    root=\"./data\", train=False, download=True, transform=transform\n",
    ")\n",
    "trainloader = DataLoader(trainset, batch_size=64, shuffle=True)\n",
    "valloader = DataLoader(valset, batch_size=64, shuffle=False)\n",
    "\n",
    "model = KAN([28 * 28, 64, 10])\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)\n",
    "optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)\n",
    "scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "for epoch in range(10):\n",
    "    # Train\n",
    "    model.train()\n",
    "    with tqdm(trainloader) as pbar:\n",
    "        for i, (images, labels) in enumerate(pbar):\n",
    "            images = images.view(-1, 28 * 28).to(device)\n",
    "            optimizer.zero_grad()\n",
    "            output = model(images)\n",
    "            loss = criterion(output, labels.to(device))\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            accuracy = (output.argmax(dim=1) == labels.to(device)).float().mean()\n",
    "            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item(), lr=optimizer.param_groups[0]['lr'])\n",
    "\n",
    "    # Validation\n",
    "    model.eval()\n",
    "    val_loss = 0\n",
    "    val_accuracy = 0\n",
    "    with torch.no_grad():\n",
    "        for images, labels in valloader:\n",
    "            images = images.view(-1, 28 * 28).to(device)\n",
    "            output = model(images)\n",
    "            val_loss += criterion(output, labels.to(device)).item()\n",
    "            val_accuracy += (\n",
    "                (output.argmax(dim=1) == labels.to(device)).float().mean().item()\n",
    "            )\n",
    "    val_loss /= len(valloader)\n",
    "    val_accuracy /= len(valloader)\n",
    "\n",
    "    # Update learning rate\n",
    "    scheduler.step()\n",
    "\n",
    "    print(\n",
    "        f\"Epoch {epoch + 1}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1bebe768",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original  - val_loss=0.0903, val_acc=0.9721\n",
      "Kept node indices per hidden layer: [[0, 1, 2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 15, 16, 18, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 35, 36, 37, 38, 39, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 60, 61, 62]]\n",
      "Pruned    - val_loss=0.1350, val_acc=0.9581\n"
     ]
    }
   ],
   "source": [
    "@torch.no_grad()\n",
    "def eval_model(model, loader, device):\n",
    "    model.eval()\n",
    "    ce = nn.CrossEntropyLoss()\n",
    "    tot_loss, tot_corr, tot_n = 0.0, 0, 0\n",
    "    for x, y in loader:\n",
    "        x, y = x.view(x.size(0), -1).to(device), y.to(device)\n",
    "        logits = model(x)\n",
    "        loss = ce(logits, y)\n",
    "        tot_loss += loss.item() * x.size(0)\n",
    "        tot_corr += (logits.argmax(1) == y).sum().item()\n",
    "        tot_n += x.size(0)\n",
    "    return tot_loss / tot_n, tot_corr / tot_n\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def kan_node_scores(model: KAN):\n",
    "    scores = []\n",
    "    L = len(model.layers)  # number of KANLinear layers\n",
    "    # per-layer |phi|_1 matrix (out, in)\n",
    "    phi_l1 = []\n",
    "    for layer in model.layers:\n",
    "        # use scaled spline weights (includes spline_scaler if present)\n",
    "        sw = layer.scaled_spline_weight  # (out, in, coeff)\n",
    "        m = sw.abs().mean(dim=-1)        # (out, in)\n",
    "        phi_l1.append(m)\n",
    "\n",
    "    for l in range(1, L):\n",
    "        prev = model.layers[l-1]             # maps n_{l-1} -> n_l\n",
    "        prev_m = phi_l1[l-1]                 # shape (n_l, n_{l-1})\n",
    "        next_m = phi_l1[l] if l < L else None\n",
    "\n",
    "        I = prev_m.max(dim=1).values  # (n_l,)\n",
    "        O = phi_l1[l].max(dim=0).values if l < L else torch.zeros_like(I)\n",
    "        score = torch.maximum(I, O)\n",
    "        scores.append(score.cpu())\n",
    "    return scores  # length = #hidden layers\n",
    "\n",
    "\n",
    "def EMP_kan(score_1d: torch.Tensor):\n",
    "    s = score_1d.clone().float()\n",
    "    s[s < 0] = 0\n",
    "    if s.sum() <= 0:\n",
    "        s = torch.ones_like(s)\n",
    "    w = s / s.sum()\n",
    "    w_sorted, idx_sorted = torch.sort(w, descending=True)\n",
    "    neff = int(torch.floor(1.0 / torch.sum(w**2)).item())\n",
    "    neff = max(1, min(neff, w.numel()))\n",
    "    keep_idx = idx_sorted[:neff]\n",
    "    keep_idx, _ = torch.sort(keep_idx)  # stable order\n",
    "    return keep_idx.tolist(), neff\n",
    "\n",
    "\n",
    "def _new_kan_like(old: KAN, layers_hidden):\n",
    "    return KAN(\n",
    "        layers_hidden=layers_hidden,\n",
    "        grid_size=old.grid_size,\n",
    "        spline_order=old.spline_order,\n",
    "        scale_noise=old.layers[0].scale_noise,\n",
    "        scale_base=old.layers[0].scale_base,\n",
    "        scale_spline=old.layers[0].scale_spline,\n",
    "        base_activation=type(old.layers[0].base_activation),\n",
    "        grid_eps=old.layers[0].grid_eps,\n",
    "        grid_range=[old.layers[0].grid[0,0].item(), old.layers[0].grid[0,-1].item()],\n",
    "    )\n",
    "\n",
    "@torch.no_grad()\n",
    "def _slice_layer(old_layer: KANLinear, keep_in, keep_out):\n",
    "    new_layer = KANLinear(\n",
    "        in_features=len(keep_in),\n",
    "        out_features=len(keep_out),\n",
    "        grid_size=old_layer.grid_size,\n",
    "        spline_order=old_layer.spline_order,\n",
    "        scale_noise=old_layer.scale_noise,\n",
    "        scale_base=old_layer.scale_base,\n",
    "        scale_spline=old_layer.scale_spline,\n",
    "        enable_standalone_scale_spline=old_layer.enable_standalone_scale_spline,\n",
    "        base_activation=type(old_layer.base_activation),\n",
    "        grid_eps=old_layer.grid_eps,\n",
    "        grid_range=[old_layer.grid[0,0].item(), old_layer.grid[0,-1].item()],\n",
    "    )\n",
    "    device = old_layer.base_weight.device\n",
    "    new_layer = new_layer.to(device)\n",
    "\n",
    "    ki = torch.as_tensor(keep_in, dtype=torch.long, device=device)\n",
    "    ko = torch.as_tensor(keep_out, dtype=torch.long, device=device)\n",
    "\n",
    "    # base weights\n",
    "    new_layer.base_weight.data.copy_(old_layer.base_weight.data.index_select(0, ko).index_select(1, ki))\n",
    "    # spline weights\n",
    "    new_layer.spline_weight.data.copy_(old_layer.spline_weight.data.index_select(0, ko).index_select(1, ki))\n",
    "    # spline scaler (if exists)\n",
    "    if old_layer.enable_standalone_scale_spline:\n",
    "        new_layer.spline_scaler.data.copy_(old_layer.spline_scaler.data.index_select(0, ko).index_select(1, ki))\n",
    "    # copy the selected input grids\n",
    "    new_layer.grid.data.copy_(old_layer.grid.index_select(0, ki).contiguous())\n",
    "    return new_layer\n",
    "\n",
    "@torch.no_grad()\n",
    "def prune_kan_by_neff(model: KAN, device):\n",
    "    model.eval()\n",
    "    # 1) node scores per hidden layer\n",
    "    scores = kan_node_scores(model)  # list of tensors\n",
    "    keep_lists = []\n",
    "    for s in scores:\n",
    "        keep, neff = EMP_kan(s)\n",
    "        keep_lists.append(keep)\n",
    "\n",
    "    # 2) new widths\n",
    "    widths_old = [model.layers[0].in_features] + [ly.out_features for ly in model.layers]\n",
    "    hidden_kept = [len(k) for k in keep_lists]\n",
    "    widths_new = [widths_old[0]] + hidden_kept + [widths_old[-1]]\n",
    "\n",
    "    # 3) construct new model and copy slices\n",
    "    pruned = _new_kan_like(model, widths_new).to(device)\n",
    "    # layer 0: keep_out = keep_lists[0], keep_in = all inputs\n",
    "    all_in0 = list(range(widths_old[0]))\n",
    "    pruned.layers[0] = _slice_layer(model.layers[0], keep_in=all_in0, keep_out=keep_lists[0])\n",
    "\n",
    "    # middle layers (if any)\n",
    "    for l in range(1, len(model.layers)-1):\n",
    "        keep_in = keep_lists[l-1]\n",
    "        keep_out = keep_lists[l]\n",
    "        pruned.layers[l] = _slice_layer(model.layers[l], keep_in=keep_in, keep_out=keep_out)\n",
    "\n",
    "    # last layer: keep_in = last hidden keep, keep_out = all outputs\n",
    "    last = len(model.layers)-1\n",
    "    keep_in_last = keep_lists[-1]\n",
    "    all_out_last = list(range(model.layers[last].out_features))\n",
    "    pruned.layers[last] = _slice_layer(model.layers[last], keep_in=keep_in_last, keep_out=all_out_last)\n",
    "\n",
    "    return pruned, keep_lists\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.eval()\n",
    "orig_val_loss, orig_val_acc = eval_model(model, valloader, device)\n",
    "print(f\"Original  - val_loss={orig_val_loss:.4f}, val_acc={orig_val_acc:.4f}\")\n",
    "\n",
    "pruned_model, kept = prune_kan_by_neff(model, device)\n",
    "pruned_val_loss, pruned_val_acc = eval_model(pruned_model, valloader, device)\n",
    "print(\"Kept node indices per hidden layer:\", kept)\n",
    "print(f\"Pruned    - val_loss={pruned_val_loss:.4f}, val_acc={pruned_val_acc:.4f}\")\n"
   ]
  },
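  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7d41e02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative follow-up: compare parameter counts before and after pruning.\n",
    "n_orig = sum(p.numel() for p in model.parameters())\n",
    "n_pruned = sum(p.numel() for p in pruned_model.parameters())\n",
    "print(f\"params: {n_orig} -> {n_pruned} ({n_pruned / n_orig:.1%} kept)\")"
   ]
  },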
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93e90870",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
