{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "%%writefile train_entropy_pruned_distilbert_advanced.py\n",
        "\n",
        "import time\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import numpy as np\n",
        "from torch.utils.data import DataLoader\n",
        "import random\n",
        "from datasets import load_dataset\n",
        "from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_scheduler\n",
        "from torch.optim import AdamW\n",
        "from sklearn.metrics import accuracy_score, roc_auc_score\n",
        "from tqdm import tqdm\n",
        "torch.manual_seed(42)\n",
        "random.seed(42)\n",
        "np.random.seed(42)\n",
        "\n",
        "# -------------------\n",
        "# Trainable Entropy Gate\n",
        "# -------------------\n",
        "class EntropyGate(nn.Module):\n",
        "    def __init__(self, hidden_size=768, keep_ratio=0.75):\n",
        "        super().__init__()\n",
        "        self.keep_ratio = keep_ratio\n",
        "        self.scorer = nn.Linear(hidden_size, 2)  # learnable\n",
        "        self.budget_lambda = 0.1  # regularization strength\n",
        "\n",
        "    def forward(self, hidden_states, attention_mask, labels=None):\n",
        "        logits = self.scorer(hidden_states)          # [B, L, 2]\n",
        "        probs = F.softmax(logits, dim=-1)            # [B, L, 2]\n",
        "        ent = -(probs * probs.log()).sum(dim=-1)     # [B, L]\n",
        "        k = max(1, int(self.keep_ratio * hidden_states.size(1)))\n",
        "        topk_idx = ent.topk(k, dim=-1, largest=False).indices\n",
        "        mask = torch.zeros_like(ent, dtype=torch.bool)\n",
        "        mask.scatter_(1, topk_idx, True)\n",
        "        pruned_states = hidden_states * mask.unsqueeze(-1)\n",
        "\n",
        "        # Budget loss to encourage ~ρ fraction kept\n",
        "        if labels is not None:\n",
        "            avg_keep = mask.float().mean()\n",
        "            budget_loss = self.budget_lambda * (avg_keep - self.keep_ratio) ** 2\n",
        "        else:\n",
        "            budget_loss = 0.0\n",
        "\n",
        "        return pruned_states, attention_mask, budget_loss\n",
        "\n",
        "\n",
        "# -------------------\n",
        "# DistilBERT + Gate\n",
        "# -------------------\n",
        "from transformers.models.distilbert.modeling_distilbert import DistilBertModel\n",
        "\n",
        "class DistilBertWithGate(nn.Module):\n",
        "    def __init__(self, keep_ratio=0.5, num_labels=2):\n",
        "        super().__init__()\n",
        "        self.keep_ratio = keep_ratio\n",
        "        self.bert = DistilBertModel.from_pretrained(\"distilbert-base-uncased\")\n",
        "        self.gate = EntropyGate(hidden_size=self.bert.config.dim, keep_ratio=keep_ratio)\n",
        "        self.classifier = nn.Linear(self.bert.config.dim, num_labels)\n",
        "\n",
        "    def forward(self, input_ids, attention_mask, labels=None):\n",
        "        outputs = self.bert(input_ids=input_ids,\n",
        "                            attention_mask=attention_mask,\n",
        "                            output_hidden_states=True)\n",
        "        hidden_states = outputs.hidden_states[1]  # after first block\n",
        "        gated_states, _, budget_loss = self.gate(hidden_states, attention_mask, labels)\n",
        "        outputs.last_hidden_state[:, :, :] = gated_states\n",
        "        pooled = outputs.last_hidden_state[:, 0]\n",
        "        logits = self.classifier(pooled)\n",
        "\n",
        "        loss = None\n",
        "        if labels is not None:\n",
        "            ce_loss = F.cross_entropy(logits, labels)\n",
        "            loss = ce_loss + budget_loss\n",
        "        return {\"loss\": loss, \"logits\": logits}\n",
        "\n",
        "\n",
        "# -------------------\n",
        "# Metrics + Efficiency\n",
        "# -------------------\n",
        "def compute_metrics(preds, labels):\n",
        "    acc = accuracy_score(labels, preds.argmax(-1))\n",
        "    try:\n",
        "        auc = roc_auc_score(labels, F.softmax(torch.tensor(preds), -1)[:, 1])\n",
        "    except:\n",
        "        auc = 0.0\n",
        "    return {\"acc\": acc, \"auc\": auc}\n",
        "\n",
        "def flops_proxy(seq_len, hidden_dim=768, rho=1.0, layers=12):\n",
        "    \"\"\"\n",
        "    Analytical FLOPs estimate:\n",
        "    FLOPs ~ layers * (L^2 * d) for attention.\n",
        "    If pruning applies after 1 layer, adjust accordingly.\n",
        "    \"\"\"\n",
        "    # here: first layer always full, later layers pruned\n",
        "    return ( (seq_len**2) * hidden_dim ) + ( (layers-1) * ((rho*seq_len)**2) * hidden_dim )\n",
        "\n",
        "# -------------------\n",
        "# Training & Evaluation\n",
        "# -------------------\n",
        "def train(model, loader, device, epochs=2):\n",
        "    optimizer = AdamW(model.parameters(), lr=5e-5)\n",
        "    num_training_steps = len(loader) * epochs\n",
        "    lr_scheduler = get_scheduler(\"linear\", optimizer=optimizer,\n",
        "                                 num_warmup_steps=0,\n",
        "                                 num_training_steps=num_training_steps)\n",
        "    model.train()\n",
        "    for epoch in range(epochs):\n",
        "        loop = tqdm(loader, desc=f\"Epoch {epoch+1}\")\n",
        "        for batch in loop:\n",
        "            batch = {k: v.to(device) for k, v in batch.items()}\n",
        "            if \"label\" in batch:\n",
        "                batch[\"labels\"] = batch.pop(\"label\")\n",
        "            outputs = model(**batch)\n",
        "            loss = outputs[\"loss\"]\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "            lr_scheduler.step()\n",
        "            optimizer.zero_grad()\n",
        "            loop.set_postfix(loss=loss.item())\n",
        "\n",
        "\n",
        "def evaluate(model, loader, device):\n",
        "    model.eval()\n",
        "    preds, labels = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in loader:\n",
        "            batch = {k: v.to(device) for k, v in batch.items()}\n",
        "            if \"label\" in batch:\n",
        "                batch[\"labels\"] = batch.pop(\"label\")\n",
        "            outputs = model(**batch)\n",
        "            logits = outputs[\"logits\"].cpu()\n",
        "            preds.append(logits)\n",
        "            labels.append(batch[\"labels\"].cpu())\n",
        "    preds = torch.cat(preds)\n",
        "    labels = torch.cat(labels)\n",
        "    return compute_metrics(preds, labels)\n",
        "\n",
        "\n",
        "# -------------------\n",
        "# Main\n",
        "# -------------------\n",
        "def main():\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    dataset = load_dataset(\"glue\", \"sst2\")\n",
        "    tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")\n",
        "\n",
        "    def tokenize_fn(ex):\n",
        "        return tokenizer(ex[\"sentence\"], truncation=True,\n",
        "                         padding=\"max_length\", max_length=128)\n",
        "\n",
        "    dataset = dataset.map(tokenize_fn, batched=True)\n",
        "    dataset.set_format(type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"label\"])\n",
        "\n",
        "    train_loader = DataLoader(dataset[\"train\"], batch_size=16, shuffle=True)\n",
        "    val_loader = DataLoader(dataset[\"validation\"], batch_size=32)\n",
        "\n",
        "    # baseline\n",
        "    baseline = AutoModelForSequenceClassification.from_pretrained(\n",
        "        \"distilbert-base-uncased\", num_labels=2\n",
        "    ).to(device)\n",
        "    train(baseline, train_loader, device, epochs=1)\n",
        "    base_metrics = evaluate(baseline, val_loader, device)\n",
        "\n",
        "    # proposed\n",
        "    rho = 0.75\n",
        "    pruned = DistilBertWithGate(keep_ratio=0.75, num_labels=2).to(device)\n",
        "    train(pruned, train_loader, device, epochs=1)\n",
        "    pruned_metrics = evaluate(pruned, val_loader, device)\n",
        "\n",
        "    # FLOPs estimation\n",
        "    seq_len = 128\n",
        "    hidden_dim = 768\n",
        "    layers = 12\n",
        "    flops_base = flops_proxy(seq_len, hidden_dim, rho=1.0, layers=layers)\n",
        "    flops_pruned = flops_proxy(seq_len, hidden_dim, rho=rho, layers=layers)\n",
        "    flops_red = (1 - flops_pruned / flops_base) * 100\n",
        "\n",
        "    print(\"\\n==== Validation Results (SST-2) ====\")\n",
        "    print(f\"Baseline: Acc={base_metrics['acc']:.4f}, AUC={base_metrics['auc']:.4f}, FLOPs={flops_base:.2e}\")\n",
        "    print(f\"Proposed: Acc={pruned_metrics['acc']:.4f}, AUC={pruned_metrics['auc']:.4f}, \"\n",
        "          f\"FLOPs={flops_pruned:.2e} ({flops_red:.1f}%↓)\")\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "CV7fJpWkV-8D",
        "outputId": "a7e2ca68-0e1c-4f54-9f8e-160f3fb7ba4c"
      },
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Overwriting train_entropy_pruned_distilbert_advanced.py\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!python train_entropy_pruned_distilbert_advanced.py\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "aufRQKQwWIeH",
        "outputId": "233ea973-7424-494b-a102-e7469b1ef1a8"
      },
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "2025-09-08 21:58:20.037590: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
            "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
            "E0000 00:00:1757368700.057212   14558 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
            "E0000 00:00:1757368700.063147   14558 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
            "W0000 00:00:1757368700.078061   14558 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
            "W0000 00:00:1757368700.078084   14558 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
            "W0000 00:00:1757368700.078088   14558 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
            "W0000 00:00:1757368700.078091   14558 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
            "2025-09-08 21:58:20.082544: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
            "To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
            "Map: 100% 872/872 [00:00<00:00, 7526.99 examples/s]\n",
            "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "Epoch 1: 100% 4210/4210 [12:07<00:00,  5.79it/s, loss=0.0158]\n",
            "/content/train_entropy_pruned_distilbert_advanced.py:84: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  auc = roc_auc_score(labels, F.softmax(torch.tensor(preds), -1)[:, 1])\n",
            "Epoch 1: 100% 4210/4210 [11:50<00:00,  5.92it/s, loss=0.916]\n",
            "/content/train_entropy_pruned_distilbert_advanced.py:84: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  auc = roc_auc_score(labels, F.softmax(torch.tensor(preds), -1)[:, 1])\n",
            "\n",
            "==== Validation Results (SST-2) ====\n",
            "Baseline: Acc=0.9140, AUC=0.9725, FLOPs=1.51e+08\n",
            "Proposed: Acc=0.8268, AUC=0.8806, FLOPs=9.04e+07 (40.1%↓)\n"
          ]
        }
      ]
    }
  ]
}