{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6ENR6EQ2b8R8",
        "outputId": "fe08393f-a8c3-481b-fb12-9722919cf815"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Found existing installation: datasets 4.0.0\n",
            "Uninstalling datasets-4.0.0:\n",
            "  Successfully uninstalled datasets-4.0.0\n",
            "Collecting datasets==2.18.0\n",
            "  Downloading datasets-2.18.0-py3-none-any.whl.metadata (20 kB)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (3.19.1)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (2.0.2)\n",
            "Requirement already satisfied: pyarrow>=12.0.0 in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (18.1.0)\n",
            "Collecting pyarrow-hotfix (from datasets==2.18.0)\n",
            "  Downloading pyarrow_hotfix-0.7-py3-none-any.whl.metadata (3.6 kB)\n",
            "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (0.3.8)\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (2.2.2)\n",
            "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (2.32.4)\n",
            "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (4.67.1)\n",
            "Requirement already satisfied: xxhash in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (3.5.0)\n",
            "Requirement already satisfied: multiprocess in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (0.70.16)\n",
            "Collecting fsspec<=2024.2.0,>=2023.1.0 (from fsspec[http]<=2024.2.0,>=2023.1.0->datasets==2.18.0)\n",
            "  Downloading fsspec-2024.2.0-py3-none-any.whl.metadata (6.8 kB)\n",
            "Requirement already satisfied: aiohttp in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (3.12.15)\n",
            "Requirement already satisfied: huggingface-hub>=0.19.4 in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (0.35.0)\n",
            "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (25.0)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (6.0.2)\n",
            "Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.18.0) (2.6.1)\n",
            "Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.18.0) (1.4.0)\n",
            "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.18.0) (25.3.0)\n",
            "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.18.0) (1.7.0)\n",
            "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.18.0) (6.6.4)\n",
            "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.18.0) (0.3.2)\n",
            "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.18.0) (1.20.1)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.19.4->datasets==2.18.0) (4.15.0)\n",
            "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.19.4->datasets==2.18.0) (1.1.10)\n",
            "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->datasets==2.18.0) (3.4.3)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->datasets==2.18.0) (3.10)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->datasets==2.18.0) (2.5.0)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->datasets==2.18.0) (2025.8.3)\n",
            "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets==2.18.0) (2.9.0.post0)\n",
            "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets==2.18.0) (2025.2)\n",
            "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets==2.18.0) (2025.2)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.8.2->pandas->datasets==2.18.0) (1.17.0)\n",
            "Downloading datasets-2.18.0-py3-none-any.whl (510 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m510.5/510.5 kB\u001b[0m \u001b[31m5.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading fsspec-2024.2.0-py3-none-any.whl (170 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m170.9/170.9 kB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading pyarrow_hotfix-0.7-py3-none-any.whl (7.9 kB)\n",
            "Installing collected packages: pyarrow-hotfix, fsspec, datasets\n",
            "  Attempting uninstall: fsspec\n",
            "    Found existing installation: fsspec 2025.3.0\n",
            "    Uninstalling fsspec-2025.3.0:\n",
            "      Successfully uninstalled fsspec-2025.3.0\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "gcsfs 2025.3.0 requires fsspec==2025.3.0, but you have fsspec 2024.2.0 which is incompatible.\u001b[0m\u001b[31m\n",
            "\u001b[0mSuccessfully installed datasets-2.18.0 fsspec-2024.2.0 pyarrow-hotfix-0.7\n"
          ]
        }
      ],
      "source": [
        "!pip uninstall -y datasets\n",
        "!pip install datasets==2.18.0\n",
        "!pip install evaluate\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Prune Layers Based on Entropy Rate Using MNLI Dataset\n",
        "\n",
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import math\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 2) Entropy‐Rate / Hook Utilities (implements Theorem 2)\n",
        "# ========================================================\n",
        "def register_er_hooks(model):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {\n",
        "        i: {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}\n",
        "        for i in range(len(layers)-1)\n",
        "    }\n",
        "    hooks = []\n",
        "    for i in range(len(layers)-1):\n",
        "        def hook_x(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_X'] = out.detach()\n",
        "        def hook_y(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_Y'] = out.detach()\n",
        "\n",
        "        hooks.append(layers[i].output.dense.register_forward_hook(hook_x))\n",
        "        hooks.append(layers[i+1].output.dense.register_forward_hook(hook_y))\n",
        "\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "def compute_batch_entropy(activations, sigma2=1.0):\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "\n",
        "        if None in (X_prev, Y_prev, X_curr, Y_curr):\n",
        "            buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "            buf['curr_X'], buf['curr_Y'] = None, None\n",
        "            continue\n",
        "\n",
        "        B = X_curr.size(0)\n",
        "        dX = (X_curr - X_prev).view(B, -1)\n",
        "        dY = (Y_curr - Y_prev).view(B, -1)\n",
        "\n",
        "        if B < 2:\n",
        "            er = 0.0\n",
        "        else:\n",
        "            cos_squares = []\n",
        "            for i in range(1, B):\n",
        "                c2 = F.cosine_similarity(dY[i].unsqueeze(0), dX[i].unsqueeze(0), dim=1, eps=1e-8)\n",
        "                cos_squares.append(c2.item())\n",
        "            er = sum(cos_squares) / (2 * (B - 1))\n",
        "\n",
        "        er_scores[idx] = er\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "\n",
        "    return er_scores\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune low‑ER)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx+1 for idx, _ in sorted_layers[:num_prune] if idx+1 < len(model.roberta.encoder.layer)]\n",
        "    for idx in prune_idxs:\n",
        "        layer = model.roberta.encoder.layer[idx]\n",
        "        layer.intermediate.dense = nn.Identity()\n",
        "        layer.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    loras = {}\n",
        "    for idx, layer in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(layer.output, 'dense'):\n",
        "            continue\n",
        "        W0 = layer.output.dense.weight.data\n",
        "        lora = LoRA(W0, r, alpha).to(W0.device)\n",
        "        def fwd(x, layer=layer, lora=lora):\n",
        "            return F.linear(x, lora(), layer.output.dense.bias)\n",
        "        layer.output.dense.forward = fwd\n",
        "        loras[idx] = lora\n",
        "    return loras\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "\n",
        "\n",
        "def preprocess_function(examples, tok, max_length=128):\n",
        "    return tok(examples['premise'],\n",
        "               examples['hypothesis'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using ER instead of MIR)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device):\n",
        "    print(\"=== Stage 1: Full Finetuning & ER Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\"roberta-base\", num_labels=3).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*3)\n",
        "    scaler = GradScaler()\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        er_sums, er_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device), attention_mask=b['attention_mask'].to(device), labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            batch_er = compute_batch_entropy(activations)\n",
        "            for idx, v in batch_er.items():\n",
        "                er_sums[idx] += v\n",
        "                er_counts[idx] += 1\n",
        "\n",
        "        epoch_er = {idx: er_sums[idx]/er_counts[idx] for idx in er_sums if er_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Entropy Rate:\", epoch_er)\n",
        "        last_er = epoch_er\n",
        "\n",
        "    acc = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune MNLI Acc: {acc:.4f}\")\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_er\n",
        "\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores):\n",
        "    print(\"=== Stage 2: Prune (High‐ER) & Finetuning ===\")\n",
        "    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐ER):\", prune_idxs)\n",
        "    opt = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*3)\n",
        "    for epoch in range(3):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device), attention_mask=b['attention_mask'].to(device), labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] MNLI Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters(): p.requires_grad = False\n",
        "    for p in model.classifier.parameters(): p.requires_grad = True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad = True\n",
        "        l.B.requires_grad = True\n",
        "    opt = torch.optim.Adam(\n",
        "        list(model.classifier.parameters()) + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*3)\n",
        "    scaler = GradScaler()\n",
        "    for epoch in range(3):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device), attention_mask=b['attention_mask'].to(device), labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] MNLI Acc: {acc:.4f}\")\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    train_ds = load_dataset(\"glue\", \"mnli\", split=\"train[:5000]\").shuffle(seed)\n",
        "    dev_ds = load_dataset(\"glue\", \"mnli\", split=\"validation_matched[:1000]\")\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "\n",
        "\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True,\n",
        "                         remove_columns=[\"premise\",\"hypothesis\",\"idx\"])\\\n",
        "                    .rename_column(\"label\",\"labels\")\n",
        "    dev   = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True,\n",
        "                         remove_columns=[\"premise\",\"hypothesis\",\"idx\"])\\\n",
        "                    .rename_column(\"label\",\"labels\")\n",
        "\n",
        "\n",
        "\n",
        "    collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "    model, er_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, er_scores)\n",
        "    lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "hHX8whpCb9ec"
      },
      "execution_count": null,
      "outputs": []
    },
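    {
      "cell_type": "code",
      "source": [
        "# Minimal sanity check (a sketch, not part of the training pipeline): it\n",
        "# assumes the previous cell has been run so that torch and\n",
        "# compute_batch_entropy are already defined. The synthetic buffers stand in\n",
        "# for two consecutive batches captured by the forward hooks; the shapes\n",
        "# (batch, seq_len, hidden) are illustrative only.\n",
        "torch.manual_seed(0)\n",
        "fake_activations = {\n",
        "    0: {\n",
        "        'prev_X': torch.randn(4, 8, 16), 'prev_Y': torch.randn(4, 8, 16),\n",
        "        'curr_X': torch.randn(4, 8, 16), 'curr_Y': torch.randn(4, 8, 16),\n",
        "    }\n",
        "}\n",
        "print(compute_batch_entropy(fake_activations))  # -> {0: value in [0, 0.5]}\n"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },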
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "ElRviRUwjPgg"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "nl6TJHILjQVX"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import math\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 2) Entropy‐Rate / Hook Utilities (implements Theorem 2)\n",
        "# ========================================================\n",
        "def register_er_hooks(model):\n",
        "    \"\"\"\n",
        "    we watch each pair of adjacent\n",
        "    layers' output.dense activations.\n",
        "    \"\"\"\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {\n",
        "        i: {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}\n",
        "        for i in range(len(layers)-1)\n",
        "    }\n",
        "    hooks = []\n",
        "    for i in range(len(layers)-1):\n",
        "        def hook_x(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_X'] = out.detach()\n",
        "        def hook_y(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_Y'] = out.detach()\n",
        "\n",
        "        hooks.append(layers[i].output.dense.register_forward_hook(hook_x))\n",
        "        hooks.append(layers[i+1].output.dense.register_forward_hook(hook_y))\n",
        "\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "def compute_batch_entropy(activations, sigma2=1.0):\n",
        "    \"\"\"\n",
        "    For each adjacent layer‐pair idx,\n",
        "    approximate the conditional entropy rate via\n",
        "      H ≈ d/2 * ln(2πe σ²) + (1/[2(B-1)]) * Σ_{i=1..B-1} cos²(ΔY_i, ΔX_i)\n",
        "    We return only the cosine‐sum term; the additive constant is the\n",
        "    same for all layers and can be dropped for pruning.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "\n",
        "        # not enough history yet\n",
        "        if None in (X_prev, Y_prev, X_curr, Y_curr):\n",
        "            buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "            buf['curr_X'], buf['curr_Y'] = None, None\n",
        "            continue\n",
        "\n",
        "        B = X_curr.size(0)\n",
        "        # flatten across all non‐batch dims\n",
        "        dX = (X_curr - X_prev).view(B, -1)\n",
        "        dY = (Y_curr - Y_prev).view(B, -1)\n",
        "\n",
        "        if B < 2:\n",
        "            er = 0.0\n",
        "        else:\n",
        "            cos_squares = []\n",
        "            for i in range(1, B):\n",
        "                c2 = F.cosine_similarity(\n",
        "                    dY[i].unsqueeze(0),\n",
        "                    dX[i].unsqueeze(0),\n",
        "                    dim=1, eps=1e-8\n",
        "                )  # [1]\n",
        "                cos_squares.append(c2.item())\n",
        "            # sum of cos² over i=1..B-1, then multiplied by 1/(2(B-1))\n",
        "            er = sum(cos_squares) / (2 * (B - 1))\n",
        "\n",
        "        er_scores[idx] = er\n",
        "\n",
        "        # shift history\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "\n",
        "    return er_scores\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune low‑ER)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    # sort descending by ER → highest‐entropy layers first\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [\n",
        "        idx+1\n",
        "        for idx, _ in sorted_layers[:num_prune]\n",
        "        if idx+1 < len(model.roberta.encoder.layer)\n",
        "    ]\n",
        "    for idx in prune_idxs:\n",
        "        layer = model.roberta.encoder.layer[idx]\n",
        "        layer.intermediate.dense = nn.Identity()\n",
        "        layer.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    loras = {}\n",
        "    for idx, layer in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(layer.output, 'dense'):\n",
        "            continue\n",
        "        W0 = layer.output.dense.weight.data\n",
        "        lora = LoRA(W0, r, alpha).to(W0.device)\n",
        "        def fwd(x, layer=layer, lora=lora):\n",
        "            return F.linear(x, lora(), layer.output.dense.bias)\n",
        "        layer.output.dense.forward = fwd\n",
        "        loras[idx] = lora\n",
        "    return loras\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(examples['sentence1'],\n",
        "               examples['sentence2'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using ER instead of MIR)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device):\n",
        "    print(\"=== Stage 1: Full Finetuning & ER Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        er_sums, er_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            # batch‐level ER\n",
        "            batch_er = compute_batch_entropy(activations)\n",
        "            for idx, v in batch_er.items():\n",
        "                er_sums[idx]   += v\n",
        "                er_counts[idx] += 1\n",
        "\n",
        "        # epoch‐level ER\n",
        "        epoch_er = {idx: er_sums[idx]/er_counts[idx]\n",
        "                    for idx in er_sums if er_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Entropy Rate:\", epoch_er)\n",
        "        last_er = epoch_er\n",
        "\n",
        "    acc = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune MRPC Acc: {acc:.4f}\")\n",
        "\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_er\n",
        "\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores):\n",
        "    print(\"=== Stage 2: Prune (High‐ER) & Finetuning ===\")\n",
        "    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐ER):\", prune_idxs)\n",
        "\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*3)\n",
        "\n",
        "    for epoch in range(5):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] MRPC Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):\n",
        "    # (unchanged LoRA stage)\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters(): p.requires_grad=False\n",
        "    for p in model.classifier.parameters(): p.requires_grad=True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad=True\n",
        "        l.B.requires_grad=True\n",
        "\n",
        "    opt   = torch.optim.Adam(\n",
        "        list(model.classifier.parameters())\n",
        "        + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    # If you want to continue monitoring ER during LoRA, you can re-hook here.\n",
        "    for epoch in range(6):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] MRPC Acc: {acc:.4f}\")\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # load & preprocess MRPC subset\n",
        "    train_ds = load_dataset(\"glue\", \"mrpc\", split=\"train\")\\\n",
        "               .shuffle(seed).select(range(1000))\n",
        "    dev_ds   = load_dataset(\"glue\", \"mrpc\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True,\n",
        "                         remove_columns=[\"sentence1\",\"sentence2\",\"idx\"])\\\n",
        "                    .rename_column(\"label\",\"labels\")\n",
        "    dev   = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                       batched=True,\n",
        "                       remove_columns=[\"sentence1\",\"sentence2\",\"idx\"])\\\n",
        "                  .rename_column(\"label\",\"labels\")\n",
        "\n",
        "    collator     = DataCollatorWithPadding(tokenizer,\n",
        "                                           padding=\"max_length\",\n",
        "                                           max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True,\n",
        "                              collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev,   batch_size=16, shuffle=False,\n",
        "                              collate_fn=collator)\n",
        "\n",
        "    model, er_scores = full_finetuning(train_loader,\n",
        "                                       dev_loader,\n",
        "                                       device)\n",
        "    model = prune_and_finetuning(model,\n",
        "                                 train_loader,\n",
        "                                 dev_loader,\n",
        "                                 device,\n",
        "                                 er_scores)\n",
        "    lora_only_finetuning(model,\n",
        "                         train_loader,\n",
        "                         dev_loader,\n",
        "                         device)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "3wlsElsZjjzU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "Q1bHamnpb-S6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import math\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 2) Entropy‐Rate / Hook Utilities (implements Theorem 2)\n",
        "# ========================================================\n",
        "def register_er_hooks(model):\n",
        "    \"\"\"\n",
        "    Exactly the same hooks as MIR: we watch each pair of adjacent\n",
        "    layers' output.dense activations.\n",
        "    \"\"\"\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {\n",
        "        i: {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}\n",
        "        for i in range(len(layers)-1)\n",
        "    }\n",
        "    hooks = []\n",
        "    for i in range(len(layers)-1):\n",
        "        def hook_x(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_X'] = out.detach()\n",
        "        def hook_y(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_Y'] = out.detach()\n",
        "\n",
        "        hooks.append(layers[i].output.dense.register_forward_hook(hook_x))\n",
        "        hooks.append(layers[i+1].output.dense.register_forward_hook(hook_y))\n",
        "\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "def compute_batch_entropy(activations, sigma2=1.0):\n",
        "    \"\"\"\n",
        "    For each adjacent layer‐pair idx,\n",
        "    approximate the conditional entropy rate via\n",
        "      H ≈ d/2 * ln(2πe σ²) + (1/[2(B-1)]) * Σ_{i=1..B-1} cos²(ΔY_i, ΔX_i)\n",
        "    We return only the cosine‐sum term; the additive constant is the\n",
        "    same for all layers and can be dropped for pruning.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "\n",
        "        # not enough history yet\n",
        "        if None in (X_prev, Y_prev, X_curr, Y_curr):\n",
        "            buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "            buf['curr_X'], buf['curr_Y'] = None, None\n",
        "            continue\n",
        "\n",
        "        B = X_curr.size(0)\n",
        "        # flatten across all non‐batch dims\n",
        "        dX = (X_curr - X_prev).view(B, -1)\n",
        "        dY = (Y_curr - Y_prev).view(B, -1)\n",
        "\n",
        "        if B < 2:\n",
        "            er = 0.0\n",
        "        else:\n",
        "            cos_squares = []\n",
        "            for i in range(1, B):\n",
        "                c2 = F.cosine_similarity(\n",
        "                    dY[i].unsqueeze(0),\n",
        "                    dX[i].unsqueeze(0),\n",
        "                    dim=1, eps=1e-8\n",
        "                )  # [1]\n",
        "                cos_squares.append(c2.item())\n",
        "            # sum of cos² over i=1..B-1, then multiplied by 1/(2(B-1))\n",
        "            er = sum(cos_squares) / (2 * (B - 1))\n",
        "\n",
        "        er_scores[idx] = er\n",
        "\n",
        "        # shift history\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "\n",
        "    return er_scores\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune low‑ER)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    # sort descending by ER → highest‐entropy layers first\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [\n",
        "        idx+1\n",
        "        for idx, _ in sorted_layers[:num_prune]\n",
        "        if idx+1 < len(model.roberta.encoder.layer)\n",
        "    ]\n",
        "    for idx in prune_idxs:\n",
        "        layer = model.roberta.encoder.layer[idx]\n",
        "        layer.intermediate.dense = nn.Identity()\n",
        "        layer.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    loras = {}\n",
        "    for idx, layer in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(layer.output, 'dense'):\n",
        "            continue\n",
        "        W0 = layer.output.dense.weight.data\n",
        "        lora = LoRA(W0, r, alpha).to(W0.device)\n",
        "        def fwd(x, layer=layer, lora=lora):\n",
        "            return F.linear(x, lora(), layer.output.dense.bias)\n",
        "        layer.output.dense.forward = fwd\n",
        "        loras[idx] = lora\n",
        "    return loras\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(examples['sentence1'],\n",
        "               examples['sentence2'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using ER instead of MIR)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device):\n",
        "    print(\"=== Stage 1: Full Finetuning & ER Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        er_sums, er_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            # batch‐level ER\n",
        "            batch_er = compute_batch_entropy(activations)\n",
        "            for idx, v in batch_er.items():\n",
        "                er_sums[idx]   += v\n",
        "                er_counts[idx] += 1\n",
        "\n",
        "        # epoch‐level ER\n",
        "        epoch_er = {idx: er_sums[idx]/er_counts[idx]\n",
        "                    for idx in er_sums if er_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Entropy Rate:\", epoch_er)\n",
        "        last_er = epoch_er\n",
        "\n",
        "    acc = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune SST2 Acc: {acc:.4f}\")\n",
        "\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_er\n",
        "\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores):\n",
        "    print(\"=== Stage 2: Prune (High‐ER) & Finetuning ===\")\n",
        "    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐ER):\", prune_idxs)\n",
        "\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*3)\n",
        "\n",
        "    for epoch in range(5):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] SST2 Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):\n",
        "    # (unchanged LoRA stage)\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters(): p.requires_grad=False\n",
        "    for p in model.classifier.parameters(): p.requires_grad=True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad=True\n",
        "        l.B.requires_grad=True\n",
        "\n",
        "    opt   = torch.optim.Adam(\n",
        "        list(model.classifier.parameters())\n",
        "        + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    # If you want to continue monitoring ER during LoRA, you can re-hook here.\n",
        "    for epoch in range(6):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] SST2 Acc: {acc:.4f}\")\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "\n",
        "# load & preprocess SST-2 subset\n",
        "    train_ds = load_dataset(\"glue\", \"sst2\", split=\"train\").shuffle(seed).select(range(5000))\n",
        "    dev_ds   = load_dataset(\"glue\", \"sst2\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    train = train_ds.map(lambda ex: tokenizer(ex[\"sentence\"], truncation=True, padding='max_length', max_length=64),\n",
        "                         batched=True)\\\n",
        "                    .rename_column(\"label\", \"labels\")\\\n",
        "                    .remove_columns([\"sentence\", \"idx\"])\n",
        "    dev = dev_ds.map(lambda ex: tokenizer(ex[\"sentence\"], truncation=True, padding='max_length', max_length=64),\n",
        "                         batched=True)\\\n",
        "                    .rename_column(\"label\", \"labels\")\\\n",
        "                    .remove_columns([\"sentence\", \"idx\"])\n",
        "\n",
        "\n",
        "\n",
        "    collator     = DataCollatorWithPadding(tokenizer,\n",
        "                                           padding=\"max_length\",\n",
        "                                           max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True,\n",
        "                              collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev,   batch_size=16, shuffle=False,\n",
        "                              collate_fn=collator)\n",
        "\n",
        "    model, er_scores = full_finetuning(train_loader,\n",
        "                                       dev_loader,\n",
        "                                       device)\n",
        "    model = prune_and_finetuning(model,\n",
        "                                 train_loader,\n",
        "                                 dev_loader,\n",
        "                                 device,\n",
        "                                 er_scores)\n",
        "    lora_only_finetuning(model,\n",
        "                         train_loader,\n",
        "                         dev_loader,\n",
        "                         device)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "SOnZT6CA8O1h"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import numpy as np\n",
        "import random\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# ----------- Entropy Rate / Hook Utilities (Theorem 2) -----------\n",
        "def register_er_hooks(model):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {\n",
        "        i: {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}\n",
        "        for i in range(len(layers)-1)\n",
        "    }\n",
        "    hooks = []\n",
        "    for i in range(len(layers)-1):\n",
        "        def hook_x(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_X'] = out.detach()\n",
        "        def hook_y(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_Y'] = out.detach()\n",
        "        hooks.append(layers[i].output.dense.register_forward_hook(hook_x))\n",
        "        hooks.append(layers[i+1].output.dense.register_forward_hook(hook_y))\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "def compute_batch_entropy(activations, sigma2=1.0):\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "        if None in (X_prev, Y_prev, X_curr, Y_curr):\n",
        "            buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "            buf['curr_X'], buf['curr_Y'] = None, None\n",
        "            continue\n",
        "        B = X_curr.size(0)\n",
        "        dX = (X_curr - X_prev).view(B, -1)\n",
        "        dY = (Y_curr - Y_prev).view(B, -1)\n",
        "        if B < 2:\n",
        "            er = 0.0\n",
        "        else:\n",
        "            cos_squares = []\n",
        "            for i in range(1, B):\n",
        "                c2 = F.cosine_similarity(\n",
        "                    dY[i].unsqueeze(0),\n",
        "                    dX[i].unsqueeze(0),\n",
        "                    dim=1, eps=1e-8\n",
        "                )\n",
        "                cos_squares.append(c2.item())\n",
        "            er = sum(cos_squares) / (2 * (B - 1))\n",
        "        er_scores[idx] = er\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "    return er_scores\n",
        "\n",
        "def estimate_entropy_rate(model, loader, device, max_batches=10):\n",
        "    model.eval()\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    er_sum, er_count = {}, {}\n",
        "    loader_iter = iter(loader)   # <-- FIXED LINE\n",
        "    for _ in range(max_batches):\n",
        "        try:\n",
        "            batch = next(loader_iter)\n",
        "        except StopIteration:\n",
        "            break\n",
        "        _ = model(input_ids=batch['input_ids'].to(device),\n",
        "                  attention_mask=batch['attention_mask'].to(device))\n",
        "        batch_er = compute_batch_entropy(activations)\n",
        "        for idx, val in batch_er.items():\n",
        "            er_sum[idx] = er_sum.get(idx, 0.0) + val\n",
        "            er_count[idx] = er_count.get(idx, 0) + 1\n",
        "    remove_hooks(hooks)\n",
        "    er_avg = {k: (er_sum[k]/er_count[k] if er_count[k] > 0 else 0.0) for k in er_sum}\n",
        "    return er_avg\n",
        "\n",
        "\n",
        "# ----------- Pruning Utilities -----------\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [\n",
        "        idx+1\n",
        "        for idx, _ in sorted_layers[:num_prune]\n",
        "        if idx+1 < len(model.roberta.encoder.layer)\n",
        "    ]\n",
        "    for idx in prune_idxs:\n",
        "        layer = model.roberta.encoder.layer[idx]\n",
        "        layer.intermediate.dense = nn.Identity()\n",
        "        layer.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "# ----------- Standard Preprocessing & Training -----------\n",
        "\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(examples['sentence'], truncation=True, padding='max_length', max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "def finetune_model(model, train_loader, dev_loader, device, epochs):\n",
        "    model.train()\n",
        "    opt = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(\n",
        "        opt, num_warmup_steps=0, num_training_steps=len(train_loader)*epochs)\n",
        "    scaler = GradScaler()\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Epoch {epoch+1}] SST-2 Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    train_ds = load_dataset(\"glue\", \"sst2\", split=\"train\").shuffle(seed).select(range(5000))\n",
        "    dev_ds   = load_dataset(\"glue\", \"sst2\", split=\"validation\")\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True, remove_columns=[\"sentence\"]).rename_column(\"label\", \"labels\")\n",
        "    dev = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                     batched=True, remove_columns=[\"sentence\"]).rename_column(\"label\", \"labels\")\n",
        "\n",
        "    collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    print(\"\\n=== Stage 1: Full Fine-Tuning (No Pruning) ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\"roberta-base\", num_labels=2).to(device)\n",
        "    model = finetune_model(model, train_loader, dev_loader, device, epochs=6)\n",
        "    acc_full = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"\\nAccuracy after 6-epoch full fine-tuning: {acc_full:.4f}\")\n",
        "\n",
        "    print(\"\\n=== Stage 2: Entropy Rate Pruning (Remove 4 highest-ER layers) ===\")\n",
        "    dev_iter = iter(dev_loader)\n",
        "    er_scores = estimate_entropy_rate(model, dev_loader, device, max_batches=10)\n",
        "    print(\"Layer-wise Entropy Rate:\", er_scores)\n",
        "    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(f\"Pruned layer indices: {prune_idxs}\")\n",
        "\n",
        "    print(\"\\n=== Stage 3: Fine-Tune Pruned Model (5 epochs) ===\")\n",
        "    model = finetune_model(model, train_loader, dev_loader, device, epochs=5)\n",
        "    acc_pruned = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"\\nAccuracy after 5-epoch post-pruning fine-tuning: {acc_pruned:.4f}\")\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "joUSON5PFo6W"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import math\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 2) Entropy‐Rate / Hook Utilities (implements Theorem 2)\n",
        "# ========================================================\n",
        "def register_er_hooks(model):\n",
        "    \"\"\"\n",
        "    Exactly the same hooks as MIR: we watch each pair of adjacent\n",
        "    layers' output.dense activations.\n",
        "    \"\"\"\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {\n",
        "        i: {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}\n",
        "        for i in range(len(layers)-1)\n",
        "    }\n",
        "    hooks = []\n",
        "    for i in range(len(layers)-1):\n",
        "        def hook_x(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_X'] = out.detach()\n",
        "        def hook_y(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_Y'] = out.detach()\n",
        "\n",
        "        hooks.append(layers[i].output.dense.register_forward_hook(hook_x))\n",
        "        hooks.append(layers[i+1].output.dense.register_forward_hook(hook_y))\n",
        "\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "def compute_batch_entropy(activations, sigma2=1.0):\n",
        "    \"\"\"\n",
        "    For each adjacent layer‐pair idx,\n",
        "    approximate the conditional entropy rate via\n",
        "      H ≈ d/2 * ln(2πe σ²) + (1/[2(B-1)]) * Σ_{i=1..B-1} cos²(ΔY_i, ΔX_i)\n",
        "    We return only the cosine‐sum term; the additive constant is the\n",
        "    same for all layers and can be dropped for pruning.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "\n",
        "        # not enough history yet\n",
        "        if None in (X_prev, Y_prev, X_curr, Y_curr):\n",
        "            buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "            buf['curr_X'], buf['curr_Y'] = None, None\n",
        "            continue\n",
        "\n",
        "        B = X_curr.size(0)\n",
        "        # flatten across all non‐batch dims\n",
        "        dX = (X_curr - X_prev).view(B, -1)\n",
        "        dY = (Y_curr - Y_prev).view(B, -1)\n",
        "\n",
        "        if B < 2:\n",
        "            er = 0.0\n",
        "        else:\n",
        "            cos_squares = []\n",
        "            for i in range(1, B):\n",
        "                c2 = F.cosine_similarity(\n",
        "                    dY[i].unsqueeze(0),\n",
        "                    dX[i].unsqueeze(0),\n",
        "                    dim=1, eps=1e-8\n",
        "                )  # [1]\n",
        "                cos_squares.append(c2.item())\n",
        "            # sum of cos² over i=1..B-1, then multiplied by 1/(2(B-1))\n",
        "            er = sum(cos_squares) / (2 * (B - 1))\n",
        "\n",
        "        er_scores[idx] = er\n",
        "\n",
        "        # shift history\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "\n",
        "    return er_scores\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune low‑ER)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    # sort descending by ER → highest‐entropy layers first\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [\n",
        "        idx+1\n",
        "        for idx, _ in sorted_layers[:num_prune]\n",
        "        if idx+1 < len(model.roberta.encoder.layer)\n",
        "    ]\n",
        "    for idx in prune_idxs:\n",
        "        layer = model.roberta.encoder.layer[idx]\n",
        "        layer.intermediate.dense = nn.Identity()\n",
        "        layer.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    loras = {}\n",
        "    for idx, layer in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(layer.output, 'dense'):\n",
        "            continue\n",
        "        W0 = layer.output.dense.weight.data\n",
        "        lora = LoRA(W0, r, alpha).to(W0.device)\n",
        "        def fwd(x, layer=layer, lora=lora):\n",
        "            return F.linear(x, lora(), layer.output.dense.bias)\n",
        "        layer.output.dense.forward = fwd\n",
        "        loras[idx] = lora\n",
        "    return loras\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(examples['sentence1'],\n",
        "               examples['sentence2'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using ER instead of MIR)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device):\n",
        "    print(\"=== Stage 1: Full Finetuning & ER Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        er_sums, er_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            # batch‐level ER\n",
        "            batch_er = compute_batch_entropy(activations)\n",
        "            for idx, v in batch_er.items():\n",
        "                er_sums[idx]   += v\n",
        "                er_counts[idx] += 1\n",
        "\n",
        "        # epoch‐level ER\n",
        "        epoch_er = {idx: er_sums[idx]/er_counts[idx]\n",
        "                    for idx in er_sums if er_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Entropy Rate:\", epoch_er)\n",
        "        last_er = epoch_er\n",
        "\n",
        "    acc = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune CoLA Acc: {acc:.4f}\")\n",
        "\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_er\n",
        "\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores):\n",
        "    print(\"=== Stage 2: Prune (High‐ER) & Finetuning ===\")\n",
        "    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐ER):\", prune_idxs)\n",
        "\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*3)\n",
        "\n",
        "    for epoch in range(5):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] CoLA Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):\n",
        "    # (unchanged LoRA stage)\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters(): p.requires_grad=False\n",
        "    for p in model.classifier.parameters(): p.requires_grad=True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad=True\n",
        "        l.B.requires_grad=True\n",
        "\n",
        "    opt   = torch.optim.Adam(\n",
        "        list(model.classifier.parameters())\n",
        "        + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    # If you want to continue monitoring ER during LoRA, you can re-hook here.\n",
        "    for epoch in range(6):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] CoLA Acc: {acc:.4f}\")\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "\n",
        "      # load & preprocess CoLA subset\n",
        "    train_ds = load_dataset(\"glue\", \"cola\", split=\"train\").shuffle(seed).select(range(5000))\n",
        "    dev_ds   = load_dataset(\"glue\", \"cola\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    train = train_ds.map(lambda ex: tokenizer(ex[\"sentence\"], truncation=True, padding='max_length', max_length=64),\n",
        "                         batched=True)\\\n",
        "                    .rename_column(\"label\", \"labels\")\\\n",
        "                    .remove_columns([\"sentence\", \"idx\"])\n",
        "    dev = dev_ds.map(lambda ex: tokenizer(ex[\"sentence\"], truncation=True, padding='max_length', max_length=64),\n",
        "                     batched=True)\\\n",
        "                .rename_column(\"label\", \"labels\")\\\n",
        "                .remove_columns([\"sentence\", \"idx\"])\n",
        "\n",
        "\n",
        "\n",
        "    collator     = DataCollatorWithPadding(tokenizer,\n",
        "                                           padding=\"max_length\",\n",
        "                                           max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True,\n",
        "                              collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev,   batch_size=16, shuffle=False,\n",
        "                              collate_fn=collator)\n",
        "\n",
        "    model, er_scores = full_finetuning(train_loader,\n",
        "                                       dev_loader,\n",
        "                                       device)\n",
        "    model = prune_and_finetuning(model,\n",
        "                                 train_loader,\n",
        "                                 dev_loader,\n",
        "                                 device,\n",
        "                                 er_scores)\n",
        "    lora_only_finetuning(model,\n",
        "                         train_loader,\n",
        "                         dev_loader,\n",
        "                         device)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "Lh5rJFMZkoe0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import math\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 2) Entropy‐Rate / Hook Utilities (implements Theorem 2)\n",
        "# ========================================================\n",
        "def register_er_hooks(model):\n",
        "    \"\"\"\n",
        "    Exactly the same hooks as MIR: we watch each pair of adjacent\n",
        "    layers' output.dense activations.\n",
        "    \"\"\"\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {\n",
        "        i: {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}\n",
        "        for i in range(len(layers)-1)\n",
        "    }\n",
        "    hooks = []\n",
        "    for i in range(len(layers)-1):\n",
        "        def hook_x(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_X'] = out.detach()\n",
        "        def hook_y(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_Y'] = out.detach()\n",
        "\n",
        "        hooks.append(layers[i].output.dense.register_forward_hook(hook_x))\n",
        "        hooks.append(layers[i+1].output.dense.register_forward_hook(hook_y))\n",
        "\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "def compute_batch_entropy(activations, sigma2=1.0):\n",
        "    \"\"\"\n",
        "    For each adjacent layer‐pair idx,\n",
        "    approximate the conditional entropy rate via\n",
        "      H ≈ d/2 * ln(2πe σ²) + (1/[2(B-1)]) * Σ_{i=1..B-1} cos²(ΔY_i, ΔX_i)\n",
        "    We return only the cosine‐sum term; the additive constant is the\n",
        "    same for all layers and can be dropped for pruning.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "\n",
        "        # not enough history yet\n",
        "        if None in (X_prev, Y_prev, X_curr, Y_curr):\n",
        "            buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "            buf['curr_X'], buf['curr_Y'] = None, None\n",
        "            continue\n",
        "\n",
        "        B = X_curr.size(0)\n",
        "        # flatten across all non‐batch dims\n",
        "        dX = (X_curr - X_prev).view(B, -1)\n",
        "        dY = (Y_curr - Y_prev).view(B, -1)\n",
        "\n",
        "        if B < 2:\n",
        "            er = 0.0\n",
        "        else:\n",
        "            cos_squares = []\n",
        "            for i in range(1, B):\n",
        "                c2 = F.cosine_similarity(\n",
        "                    dY[i].unsqueeze(0),\n",
        "                    dX[i].unsqueeze(0),\n",
        "                    dim=1, eps=1e-8\n",
        "                ) # [1]\n",
        "                cos_squares.append(c2.item())\n",
        "            # sum of cos² over i=1..B-1, then multiplied by 1/(2(B-1))\n",
        "            er = sum(cos_squares) / (2 * (B - 1))\n",
        "\n",
        "        er_scores[idx] = er\n",
        "\n",
        "        # shift history\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "\n",
        "    return er_scores\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune low‑ER)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    # sort descending by ER → highest‐entropy layers first\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [\n",
        "        idx+1\n",
        "        for idx, _ in sorted_layers[:num_prune]\n",
        "        if idx+1 < len(model.roberta.encoder.layer)\n",
        "    ]\n",
        "    for idx in prune_idxs:\n",
        "        layer = model.roberta.encoder.layer[idx]\n",
        "        layer.intermediate.dense = nn.Identity()\n",
        "        layer.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    loras = {}\n",
        "    for idx, layer in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(layer.output, 'dense'):\n",
        "            continue\n",
        "        W0 = layer.output.dense.weight.data\n",
        "        lora = LoRA(W0, r, alpha).to(W0.device)\n",
        "        def fwd(x, layer=layer, lora=lora):\n",
        "            return F.linear(x, lora(), layer.output.dense.bias)\n",
        "        layer.output.dense.forward = fwd\n",
        "        loras[idx] = lora\n",
        "    return loras\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(examples['sentence1'],\n",
        "               examples['sentence2'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using ER instead of MIR)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device):\n",
        "    print(\"=== Stage 1: Full Finetuning & ER Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        er_sums, er_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            # batch‐level ER\n",
        "            batch_er = compute_batch_entropy(activations)\n",
        "            for idx, v in batch_er.items():\n",
        "                er_sums[idx]   += v\n",
        "                er_counts[idx] += 1\n",
        "\n",
        "        # epoch‐level ER\n",
        "        epoch_er = {idx: er_sums[idx]/er_counts[idx]\n",
        "                    for idx in er_sums if er_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Entropy Rate:\", epoch_er)\n",
        "        last_er = epoch_er\n",
        "\n",
        "    acc = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune QNLI Acc: {acc:.4f}\")\n",
        "\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_er\n",
        "\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores):\n",
        "    print(\"=== Stage 2: Prune (High‐ER) & Finetuning ===\")\n",
        "    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐ER):\", prune_idxs)\n",
        "\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*3)\n",
        "\n",
        "    for epoch in range(5):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] QNLI Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):\n",
        "    # (unchanged LoRA stage)\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters(): p.requires_grad=False\n",
        "    for p in model.classifier.parameters(): p.requires_grad=True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad=True\n",
        "        l.B.requires_grad=True\n",
        "\n",
        "    opt   = torch.optim.Adam(\n",
        "        list(model.classifier.parameters())\n",
        "        + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    # If you want to continue monitoring ER during LoRA, you can re-hook here.\n",
        "    for epoch in range(6):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] QNLI Acc: {acc:.4f}\")\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # Load & preprocess QNLI\n",
        "    from datasets import load_dataset\n",
        "    train_ds = load_dataset(\"glue\", \"qnli\", split=\"train\").shuffle(seed).select(range(5000))\n",
        "    dev_ds   = load_dataset(\"glue\", \"qnli\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    def preprocess(examples):\n",
        "        return tokenizer(examples[\"question\"],\n",
        "                         examples[\"sentence\"],\n",
        "                         truncation=True,\n",
        "                         padding='max_length',\n",
        "                         max_length=128)\n",
        "\n",
        "    train = train_ds.map(preprocess, batched=True)\\\n",
        "                    .rename_column(\"label\", \"labels\")\\\n",
        "                    .remove_columns([\"question\", \"sentence\", \"idx\"])\n",
        "    dev = dev_ds.map(preprocess, batched=True)\\\n",
        "                .rename_column(\"label\", \"labels\")\\\n",
        "                .remove_columns([\"question\", \"sentence\", \"idx\"])\n",
        "\n",
        "    collator     = DataCollatorWithPadding(tokenizer,\n",
        "                                           padding=\"max_length\",\n",
        "                                           max_length=128)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    model, er_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, er_scores)\n",
        "    lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "Y0s4HbYJTlfC"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import math\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 2) Entropy‐Rate / Hook Utilities (implements Theorem 2)\n",
        "# ========================================================\n",
        "def register_er_hooks(model):\n",
        "    \"\"\"\n",
        "    Exactly the same hooks as MIR: we watch each pair of adjacent\n",
        "    layers' output.dense activations.\n",
        "    \"\"\"\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {\n",
        "        i: {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}\n",
        "        for i in range(len(layers)-1)\n",
        "    }\n",
        "    hooks = []\n",
        "    for i in range(len(layers)-1):\n",
        "        def hook_x(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_X'] = out.detach()\n",
        "        def hook_y(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_Y'] = out.detach()\n",
        "\n",
        "        hooks.append(layers[i].output.dense.register_forward_hook(hook_x))\n",
        "        hooks.append(layers[i+1].output.dense.register_forward_hook(hook_y))\n",
        "\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "def compute_batch_entropy(activations, sigma2=1.0):\n",
        "    \"\"\"\n",
        "    For each adjacent layer‐pair idx,\n",
        "    approximate the conditional entropy rate via\n",
        "      H ≈ d/2 * ln(2πe σ²) + (1/[2(B-1)]) * Σ_{i=1..B-1} cos²(ΔY_i, ΔX_i)\n",
        "    We return only the cosine‐sum term; the additive constant is the\n",
        "    same for all layers and can be dropped for pruning.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "\n",
        "        # not enough history yet\n",
        "        if None in (X_prev, Y_prev, X_curr, Y_curr):\n",
        "            buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "            buf['curr_X'], buf['curr_Y'] = None, None\n",
        "            continue\n",
        "\n",
        "        B = X_curr.size(0)\n",
        "        # flatten across all non‐batch dims\n",
        "        dX = (X_curr - X_prev).view(B, -1)\n",
        "        dY = (Y_curr - Y_prev).view(B, -1)\n",
        "\n",
        "        if B < 2:\n",
        "            er = 0.0\n",
        "        else:\n",
        "            cos_squares = []\n",
        "            for i in range(1, B):\n",
        "                c2 = F.cosine_similarity(\n",
        "                    dY[i].unsqueeze(0),\n",
        "                    dX[i].unsqueeze(0),\n",
        "                    dim=1, eps=1e-8\n",
        "                )  # [1]\n",
        "                cos_squares.append(c2.item())\n",
        "            # sum of cos² over i=1..B-1, then multiplied by 1/(2(B-1))\n",
        "            er = sum(cos_squares) / (2 * (B - 1))\n",
        "\n",
        "        er_scores[idx] = er\n",
        "\n",
        "        # shift history\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "\n",
        "    return er_scores\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune low‑ER)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    # sort descending by ER → highest‐entropy layers first\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [\n",
        "        idx+1\n",
        "        for idx, _ in sorted_layers[:num_prune]\n",
        "        if idx+1 < len(model.roberta.encoder.layer)\n",
        "    ]\n",
        "    for idx in prune_idxs:\n",
        "        layer = model.roberta.encoder.layer[idx]\n",
        "        layer.intermediate.dense = nn.Identity()\n",
        "        layer.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    loras = {}\n",
        "    for idx, layer in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(layer.output, 'dense'):\n",
        "            continue\n",
        "        W0 = layer.output.dense.weight.data\n",
        "        lora = LoRA(W0, r, alpha).to(W0.device)\n",
        "        def fwd(x, layer=layer, lora=lora):\n",
        "            return F.linear(x, lora(), layer.output.dense.bias)\n",
        "        layer.output.dense.forward = fwd\n",
        "        loras[idx] = lora\n",
        "    return loras\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(examples['sentence1'],\n",
        "               examples['sentence2'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using ER instead of MIR)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device):\n",
        "    print(\"=== Stage 1: Full Finetuning & ER Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        er_sums, er_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            # batch‐level ER\n",
        "            batch_er = compute_batch_entropy(activations)\n",
        "            for idx, v in batch_er.items():\n",
        "                er_sums[idx]   += v\n",
        "                er_counts[idx] += 1\n",
        "\n",
        "        # epoch‐level ER\n",
        "        epoch_er = {idx: er_sums[idx]/er_counts[idx]\n",
        "                    for idx in er_sums if er_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Entropy Rate:\", epoch_er)\n",
        "        last_er = epoch_er\n",
        "\n",
        "    acc = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune QQP Acc: {acc:.4f}\")\n",
        "\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_er\n",
        "\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores):\n",
        "    print(\"=== Stage 2: Prune (High‐ER) & Finetuning ===\")\n",
        "    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐ER):\", prune_idxs)\n",
        "\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*3)\n",
        "\n",
        "    for epoch in range(5):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] QQP Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):\n",
        "    # (unchanged LoRA stage)\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters(): p.requires_grad=False\n",
        "    for p in model.classifier.parameters(): p.requires_grad=True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad=True\n",
        "        l.B.requires_grad=True\n",
        "\n",
        "    opt   = torch.optim.Adam(\n",
        "        list(model.classifier.parameters())\n",
        "        + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    # If you want to continue monitoring ER during LoRA, you can re-hook here.\n",
        "    for epoch in range(6):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] QQP Acc: {acc:.4f}\")\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # Load & preprocess QQP\n",
        "    from datasets import load_dataset\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "\n",
        "    train_ds = load_dataset(\"glue\", \"qqp\", split=\"train\").shuffle(seed).select(range(5000))\n",
        "    dev_ds   = load_dataset(\"glue\", \"qqp\", split=\"validation\")\n",
        "\n",
        "    def preprocess(examples):\n",
        "        return tokenizer(examples[\"question1\"],\n",
        "                         examples[\"question2\"],\n",
        "                         truncation=True,\n",
        "                         padding=\"max_length\",\n",
        "                         max_length=128)\n",
        "\n",
        "    train = train_ds.map(preprocess, batched=True)\\\n",
        "                    .rename_column(\"label\", \"labels\")\\\n",
        "                    .remove_columns([\"question1\", \"question2\", \"idx\"])\n",
        "    dev = dev_ds.map(preprocess, batched=True)\\\n",
        "                .rename_column(\"label\", \"labels\")\\\n",
        "                .remove_columns([\"question1\", \"question2\", \"idx\"])\n",
        "\n",
        "    collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=128)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    model, er_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, er_scores)\n",
        "    lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "SkbNUsq3Uap6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import math\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 2) Entropy‐Rate / Hook Utilities (implements Theorem 2)\n",
        "# ========================================================\n",
        "def register_er_hooks(model):\n",
        "    \"\"\"\n",
        "    Exactly the same hooks as MIR: we watch each pair of adjacent\n",
        "    layers' output.dense activations.\n",
        "    \"\"\"\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {\n",
        "        i: {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}\n",
        "        for i in range(len(layers)-1)\n",
        "    }\n",
        "    hooks = []\n",
        "    for i in range(len(layers)-1):\n",
        "        def hook_x(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_X'] = out.detach()\n",
        "        def hook_y(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_Y'] = out.detach()\n",
        "\n",
        "        hooks.append(layers[i].output.dense.register_forward_hook(hook_x))\n",
        "        hooks.append(layers[i+1].output.dense.register_forward_hook(hook_y))\n",
        "\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "\n",
        "\n",
        "def compute_batch_entropy(activations, sigma2=1.0):\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "\n",
        "        if None in (X_prev, Y_prev, X_curr, Y_curr):\n",
        "            buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "            buf['curr_X'], buf['curr_Y'] = None, None\n",
        "            continue\n",
        "\n",
        "        if X_prev.size(0) != X_curr.size(0) or Y_prev.size(0) != Y_curr.size(0):\n",
        "            # Skip batch if sizes mismatch\n",
        "            buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "            buf['curr_X'], buf['curr_Y'] = None, None\n",
        "            continue\n",
        "\n",
        "        B = X_curr.size(0)\n",
        "        dX = (X_curr - X_prev).view(B, -1)\n",
        "        dY = (Y_curr - Y_prev).view(B, -1)\n",
        "\n",
        "        if B < 2:\n",
        "            er = 0.0\n",
        "        else:\n",
        "            cos_squares = [\n",
        "                F.cosine_similarity(dY[i].unsqueeze(0), dX[i].unsqueeze(0), dim=1, eps=1e-8).item()\n",
        "                for i in range(1, B)\n",
        "            ]\n",
        "            er = sum(cos_squares) / (2 * (B - 1))\n",
        "\n",
        "        er_scores[idx] = er\n",
        "\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "\n",
        "    return er_scores\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune low‑ER)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    # sort descending by ER → highest‐entropy layers first\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [\n",
        "        idx+1\n",
        "        for idx, _ in sorted_layers[:num_prune]\n",
        "        if idx+1 < len(model.roberta.encoder.layer)\n",
        "    ]\n",
        "    for idx in prune_idxs:\n",
        "        layer = model.roberta.encoder.layer[idx]\n",
        "        layer.intermediate.dense = nn.Identity()\n",
        "        layer.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    loras = {}\n",
        "    for idx, layer in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(layer.output, 'dense'):\n",
        "            continue\n",
        "        W0 = layer.output.dense.weight.data\n",
        "        lora = LoRA(W0, r, alpha).to(W0.device)\n",
        "        def fwd(x, layer=layer, lora=lora):\n",
        "            return F.linear(x, lora(), layer.output.dense.bias)\n",
        "        layer.output.dense.forward = fwd\n",
        "        loras[idx] = lora\n",
        "    return loras\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(examples['sentence1'],\n",
        "               examples['sentence2'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using ER instead of MIR)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device):\n",
        "    print(\"=== Stage 1: Full Finetuning & ER Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        er_sums, er_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            # batch‐level ER\n",
        "            batch_er = compute_batch_entropy(activations)\n",
        "            for idx, v in batch_er.items():\n",
        "                er_sums[idx]   += v\n",
        "                er_counts[idx] += 1\n",
        "\n",
        "        # epoch‐level ER\n",
        "        epoch_er = {idx: er_sums[idx]/er_counts[idx]\n",
        "                    for idx in er_sums if er_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Entropy Rate:\", epoch_er)\n",
        "        last_er = epoch_er\n",
        "\n",
        "    acc = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune RTE Acc: {acc:.4f}\")\n",
        "\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_er\n",
        "\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores):\n",
        "    print(\"=== Stage 2: Prune (High‐ER) & Finetuning ===\")\n",
        "    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐ER):\", prune_idxs)\n",
        "\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*3)\n",
        "\n",
        "    for epoch in range(5):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] RTE Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):\n",
        "    # (unchanged LoRA stage)\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters(): p.requires_grad=False\n",
        "    for p in model.classifier.parameters(): p.requires_grad=True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad=True\n",
        "        l.B.requires_grad=True\n",
        "\n",
        "    opt   = torch.optim.Adam(\n",
        "        list(model.classifier.parameters())\n",
        "        + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    # If you want to continue monitoring ER during LoRA, you can re-hook here.\n",
        "    for epoch in range(6):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] RTE Acc: {acc:.4f}\")\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # Load & preprocess RTE\n",
        "    from datasets import load_dataset\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "\n",
        "    train_ds = load_dataset(\"glue\", \"rte\", split=\"train\").shuffle(seed)\n",
        "    dev_ds   = load_dataset(\"glue\", \"rte\", split=\"validation\")\n",
        "\n",
        "    def preprocess(examples):\n",
        "        return tokenizer(examples[\"sentence1\"],\n",
        "                         examples[\"sentence2\"],\n",
        "                         truncation=True,\n",
        "                         padding=\"max_length\",\n",
        "                         max_length=128)\n",
        "\n",
        "    train = train_ds.map(preprocess, batched=True)\\\n",
        "                    .rename_column(\"label\", \"labels\")\\\n",
        "                    .remove_columns([\"sentence1\", \"sentence2\", \"idx\"])\n",
        "    dev = dev_ds.map(preprocess, batched=True)\\\n",
        "                .rename_column(\"label\", \"labels\")\\\n",
        "                .remove_columns([\"sentence1\", \"sentence2\", \"idx\"])\n",
        "\n",
        "    collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=128)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    model, er_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, er_scores)\n",
        "    lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "d7uteGgVU6ui"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "gAJz_p3jPzEo"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# ========================================================\n",
        "# 1) Entropy‐Rate / Hook Utilities (implements Theorem 2)\n",
        "# ========================================================\n",
        "def register_er_hooks(model):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {\n",
        "        i: {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}\n",
        "        for i in range(len(layers)-1)\n",
        "    }\n",
        "    hooks = []\n",
        "    for i in range(len(layers)-1):\n",
        "        def hook_x(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_X'] = out.detach()\n",
        "        def hook_y(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_Y'] = out.detach()\n",
        "        hooks.append(layers[i].output.dense.register_forward_hook(hook_x))\n",
        "        hooks.append(layers[i+1].output.dense.register_forward_hook(hook_y))\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "def compute_batch_entropy(activations, sigma2=1.0):\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "        if None in (X_prev, Y_prev, X_curr, Y_curr):\n",
        "            buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "            buf['curr_X'], buf['curr_Y'] = None, None\n",
        "            continue\n",
        "        if X_prev.size(0) != X_curr.size(0) or Y_prev.size(0) != Y_curr.size(0):\n",
        "            buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "            buf['curr_X'], buf['curr_Y'] = None, None\n",
        "            continue\n",
        "        B = X_curr.size(0)\n",
        "        dX = (X_curr - X_prev).view(B, -1)\n",
        "        dY = (Y_curr - Y_prev).view(B, -1)\n",
        "        if B < 2:\n",
        "            er = 0.0\n",
        "        else:\n",
        "            cos_squares = [\n",
        "                F.cosine_similarity(dY[i].unsqueeze(0), dX[i].unsqueeze(0), dim=1, eps=1e-8).item()\n",
        "                for i in range(1, B)\n",
        "            ]\n",
        "            er = sum(cos_squares) / (2 * (B - 1))\n",
        "        er_scores[idx] = er\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "    return er_scores\n",
        "\n",
        "def estimate_entropy_rate(model, loader, device, max_batches=10):\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    model.eval()\n",
        "    er_sums = {idx: 0.0 for idx in activations}\n",
        "    er_counts = {idx: 0 for idx in activations}\n",
        "    num_batches = 0\n",
        "    for batch in loader:\n",
        "        with torch.no_grad():\n",
        "            ids = batch['input_ids'].to(device)\n",
        "            mask = batch['attention_mask'].to(device)\n",
        "            _ = model(input_ids=ids, attention_mask=mask)\n",
        "        batch_er = compute_batch_entropy(activations)\n",
        "        for idx, v in batch_er.items():\n",
        "            if v != 0.0:\n",
        "                er_sums[idx] += v\n",
        "                er_counts[idx] += 1\n",
        "        num_batches += 1\n",
        "        if num_batches >= max_batches:\n",
        "            break\n",
        "    remove_hooks(hooks)\n",
        "    er_scores = {idx: er_sums[idx] / er_counts[idx] for idx in er_sums if er_counts[idx] > 0}\n",
        "    return er_scores\n",
        "\n",
        "# ========================================================\n",
        "# 2) Pruning Utilities with SkipFF (prune high-ER)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [\n",
        "        idx+1\n",
        "        for idx, _ in sorted_layers[:num_prune]\n",
        "        if idx+1 < len(model.roberta.encoder.layer)\n",
        "    ]\n",
        "    for idx in prune_idxs:\n",
        "        layer = model.roberta.encoder.layer[idx]\n",
        "        layer.intermediate.dense = nn.Identity()\n",
        "        layer.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "# ========================================================\n",
        "# 3) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    loras = {}\n",
        "    for idx, layer in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(layer.output, 'dense'):\n",
        "            continue\n",
        "        W0 = layer.output.dense.weight.data\n",
        "        lora = LoRA(W0, r, alpha).to(W0.device)\n",
        "        def fwd(x, layer=layer, lora=lora):\n",
        "            return F.linear(x, lora(), layer.output.dense.bias)\n",
        "        layer.output.dense.forward = fwd\n",
        "        loras[idx] = lora\n",
        "    return loras\n",
        "\n",
        "# ========================================================\n",
        "# 4) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(examples['sentence1'],\n",
        "               examples['sentence2'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "# ========================================================\n",
        "# 5) Standard fine-tuning (your version, used everywhere)\n",
        "# ========================================================\n",
        "def finetune_model(model, train_loader, dev_loader, device, epochs):\n",
        "    model.train()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*epochs)\n",
        "    scaler = GradScaler()\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Epoch {epoch+1}] RTE Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# ========================================================\n",
        "# 6) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    train_ds = load_dataset(\"glue\", \"rte\", split=\"train\").shuffle(seed)\n",
        "    dev_ds   = load_dataset(\"glue\", \"rte\", split=\"validation\")\n",
        "\n",
        "    def preprocess(examples):\n",
        "        return tokenizer(examples[\"sentence1\"],\n",
        "                         examples[\"sentence2\"],\n",
        "                         truncation=True,\n",
        "                         padding=\"max_length\",\n",
        "                         max_length=128)\n",
        "\n",
        "    train = train_ds.map(preprocess, batched=True)\\\n",
        "                    .rename_column(\"label\", \"labels\")\\\n",
        "                    .remove_columns([\"sentence1\", \"sentence2\", \"idx\"])\n",
        "    dev = dev_ds.map(preprocess, batched=True)\\\n",
        "                .rename_column(\"label\", \"labels\")\\\n",
        "                .remove_columns([\"sentence1\", \"sentence2\", \"idx\"])\n",
        "\n",
        "    collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=128)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    # --- Stage 1: Full Fine-Tuning ---\n",
        "    print(\"\\n=== Stage 1: Full Fine-Tuning (No Pruning) ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\"roberta-base\", num_labels=2).to(device)\n",
        "    model = finetune_model(model, train_loader, dev_loader, device, epochs=6)\n",
        "    acc_full = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"\\nAccuracy after 6-epoch full fine-tuning: {acc_full:.4f}\")\n",
        "\n",
        "    # --- Stage 2: ER-Based Pruning ---\n",
        "    print(\"\\n=== Stage 2: ER-Based Pruning (Remove 4 highest-ER layers) ===\")\n",
        "    er_scores = estimate_entropy_rate(model, dev_loader, device, max_batches=10)\n",
        "    print(\"Estimated ER per layer pair:\", er_scores)\n",
        "    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(f\"Pruning 4 layers with highest ER: {prune_idxs}\")\n",
        "\n",
        "    # --- Stage 3: Fine-tune pruned model ---\n",
        "    print(\"\\n=== Stage 3: Fine-Tune Pruned Model (5 epochs) ===\")\n",
        "    model = finetune_model(model, train_loader, dev_loader, device, epochs=5)\n",
        "    acc_pruned = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"\\nAccuracy after 5-epoch post-pruning fine-tuning: {acc_pruned:.4f}\")\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "s3BN8CskRebs"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "uWn8MV69Rh53"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "import random\n",
        "import math\n",
        "import warnings\n",
        "\n",
        "# Monkey‑patch numpy.array to ignore the copy argument (workaround for NumPy 2.0)\n",
        "_np_array = np.array\n",
        "def _patched_array(obj, *args, copy=False, **kwargs):\n",
        "    return _np_array(obj, *args, **kwargs)\n",
        "np.array = _patched_array\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from datasets import load_dataset\n",
        "\n",
        "from collections import defaultdict\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# ─── 1) Entropy‐Rate Hooks (Theorem 2) ─────────────────────────────────────\n",
        "def register_er_hooks(model):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {\n",
        "        i: {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}\n",
        "        for i in range(len(layers)-1)\n",
        "    }\n",
        "    hooks = []\n",
        "    for i in range(len(layers)-1):\n",
        "        def hook_x(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_X'] = out.detach()\n",
        "        def hook_y(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_Y'] = out.detach()\n",
        "        hooks.append(layers[i].output.dense.register_forward_hook(hook_x))\n",
        "        hooks.append(layers[i+1].output.dense.register_forward_hook(hook_y))\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "def compute_batch_entropy(activations, sigma2=1.0):\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "        # need valid previous + current and same batch size\n",
        "        if None in (X_prev, Y_prev, X_curr, Y_curr) or \\\n",
        "           X_prev.size(0) != X_curr.size(0):\n",
        "            buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "            buf['curr_X'], buf['curr_Y'] = None, None\n",
        "            continue\n",
        "        B = X_curr.size(0)\n",
        "        dX = (X_curr - X_prev).view(B, -1)\n",
        "        dY = (Y_curr - Y_prev).view(B, -1)\n",
        "        if B < 2:\n",
        "            er = 0.0\n",
        "        else:\n",
        "            cos_squares = [\n",
        "                F.cosine_similarity(\n",
        "                    dY[i].unsqueeze(0),\n",
        "                    dX[i].unsqueeze(0),\n",
        "                    dim=1, eps=1e-8\n",
        "                ).item()\n",
        "                for i in range(1, B)\n",
        "            ]\n",
        "            er = sum(cos_squares) / (2 * (B - 1))\n",
        "        er_scores[idx] = er\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "    return er_scores\n",
        "\n",
        "# ─── 2) Pruning Utilities ─────────────────────────────────────────────────\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx+1 for idx, _ in sorted_layers[:num_prune]\n",
        "                  if idx+1 < len(model.roberta.encoder.layer)]\n",
        "    for idx in prune_idxs:\n",
        "        layer = model.roberta.encoder.layer[idx]\n",
        "        layer.intermediate.dense = nn.Identity()\n",
        "        layer.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "# ─── 3) LoRA Modules ──────────────────────────────────────────────────────\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    loras = {}\n",
        "    for idx, layer in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(layer.output, 'dense'):\n",
        "            continue\n",
        "        W0 = layer.output.dense.weight.data\n",
        "        lora = LoRA(W0, r, alpha).to(W0.device)\n",
        "        def fwd(x, layer=layer, lora=lora):\n",
        "            return F.linear(x, lora(), layer.output.dense.bias)\n",
        "        layer.output.dense.forward = fwd\n",
        "        loras[idx] = lora\n",
        "    return loras\n",
        "\n",
        "# ─── 4) STS-B Evaluation with Flattened References ─────────────────────────\n",
        "def evaluate_stsb(model, dataloader, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"glue\", \"stsb\")\n",
        "    preds, refs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dataloader:\n",
        "            out = model(\n",
        "                input_ids=batch[\"input_ids\"].to(device),\n",
        "                attention_mask=batch[\"attention_mask\"].to(device),\n",
        "            )\n",
        "            # flatten predictions\n",
        "            p = out.logits.squeeze(-1).cpu().tolist()\n",
        "            preds.extend(p if isinstance(p, list) else [p])\n",
        "            # flatten references (handle [[5.0], [4.75], ...])\n",
        "            r = batch[\"labels\"].cpu().tolist()\n",
        "            # r might be list-of-lists or list-of-floats\n",
        "            for x in r:\n",
        "                if isinstance(x, (list, tuple, np.ndarray)):\n",
        "                    refs.append(float(x[0]))\n",
        "                else:\n",
        "                    refs.append(float(x))\n",
        "    return metric.compute(predictions=preds, references=refs)\n",
        "\n",
        "# ─── 5) Training Stages ────────────────────────────────────────────────────\n",
        "def full_finetuning(train_loader, dev_loader, device):\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=1\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(\n",
        "        opt, num_warmup_steps=0, num_training_steps=len(train_loader)*6\n",
        "    )\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "    for epoch in range(6):\n",
        "        er_sums, er_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(\n",
        "                    input_ids=batch[\"input_ids\"].to(device),\n",
        "                    attention_mask=batch[\"attention_mask\"].to(device),\n",
        "                    labels=batch[\"labels\"].to(device),\n",
        "                )\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            batch_er = compute_batch_entropy(activations)\n",
        "            for idx, v in batch_er.items():\n",
        "                er_sums[idx] += v\n",
        "                er_counts[idx] += 1\n",
        "\n",
        "        epoch_er = {\n",
        "            idx: er_sums[idx] / er_counts[idx]\n",
        "            for idx in er_sums if er_counts[idx] > 0\n",
        "        }\n",
        "        print(f\"[Epoch {epoch+1}] ER:\", epoch_er)\n",
        "        last_er = epoch_er\n",
        "\n",
        "    metrics = evaluate_stsb(model, dev_loader, device)\n",
        "    print(f\"STS‑B Pearson: {metrics['pearson']:.4f}, Spearman: {metrics['spearmanr']:.4f}\")\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_er\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores):\n",
        "    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(\"Pruned layers:\", prune_idxs)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=1e-5)\n",
        "    sched = get_linear_schedule_with_warmup(\n",
        "        opt, num_warmup_steps=0, num_training_steps=len(train_loader)*3\n",
        "    )\n",
        "    for epoch in range(5):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(\n",
        "                input_ids=batch[\"input_ids\"].to(device),\n",
        "                attention_mask=batch[\"attention_mask\"].to(device),\n",
        "                labels=batch[\"labels\"].to(device),\n",
        "            )\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "\n",
        "        metrics = evaluate_stsb(model, dev_loader, device)\n",
        "        print(f\"[Prune Epoch {epoch+1}] Pearson: {metrics['pearson']:.4f}\")\n",
        "    return model\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device):\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model)\n",
        "    for p in model.roberta.parameters(): p.requires_grad = False\n",
        "    for p in model.classifier.parameters(): p.requires_grad = True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad = True\n",
        "        l.B.requires_grad = True\n",
        "\n",
        "    opt = torch.optim.AdamW(\n",
        "        list(model.classifier.parameters())\n",
        "        + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(\n",
        "        opt, num_warmup_steps=0, num_training_steps=len(train_loader)*6\n",
        "    )\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    for epoch in range(6):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(\n",
        "                    input_ids=batch[\"input_ids\"].to(device),\n",
        "                    attention_mask=batch[\"attention_mask\"].to(device),\n",
        "                    labels=batch[\"labels\"].to(device),\n",
        "                )\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "        metrics = evaluate_stsb(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] Pearson: {metrics['pearson']:.4f}\")\n",
        "\n",
        "# ─── 6) Main Entrypoint ────────────────────────────────────────────────────\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    train_ds = load_dataset(\"glue\", \"stsb\", split=\"train\").shuffle(seed)\n",
        "    dev_ds   = load_dataset(\"glue\", \"stsb\", split=\"validation\")\n",
        "\n",
        "    def preprocess(ex):\n",
        "        return tokenizer(\n",
        "            ex[\"sentence1\"], ex[\"sentence2\"],\n",
        "            truncation=True, padding=\"max_length\", max_length=128\n",
        "        )\n",
        "\n",
        "    train_ds = train_ds.map(preprocess, batched=True)\n",
        "    dev_ds   = dev_ds.map(preprocess, batched=True)\n",
        "\n",
        "    # Cast labels to flat float32\n",
        "    train_ds = train_ds.map(lambda x: {\"labels\": float(x[\"label\"])}, batched=False)\n",
        "    dev_ds   = dev_ds.map(lambda x: {\"labels\": float(x[\"label\"])}, batched=False)\n",
        "\n",
        "    train_ds = train_ds.remove_columns([\"sentence1\", \"sentence2\", \"label\", \"idx\"])\n",
        "    dev_ds   = dev_ds.remove_columns([\"sentence1\", \"sentence2\", \"label\", \"idx\"])\n",
        "\n",
        "    train_ds.set_format(\"torch\", columns=[\"input_ids\", \"attention_mask\", \"labels\"])\n",
        "    dev_ds.set_format(\"torch\", columns=[\"input_ids\", \"attention_mask\", \"labels\"])\n",
        "\n",
        "    collator     = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=128)\n",
        "    train_loader = DataLoader(train_ds, batch_size=8, shuffle=True,  collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev_ds,   batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    model, er_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, er_scores)\n",
        "    lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "-ae7zj04U5u5"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "U3wDNkD86y-D"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}