{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fQy4J05l_RX_"
      },
      "outputs": [],
      "source": [
        "# Install dependencies with %pip so they land in the running kernel's environment.\n",
        "# The existing datasets package is removed first, then version 2.18.0 is installed.\n",
        "%pip uninstall -y datasets\n",
        "%pip install datasets==2.18.0\n",
        "%pip install evaluate"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 2) Knowledge Entropy / Hook Utilities\n",
        "# ========================================================\n",
        "def register_ke_hooks(model):\n",
        "    \"\"\"Attach a capturing forward hook to each encoder layer's ``intermediate.dense``.\n",
        "\n",
        "    Every hook stores the detached first positional input of that module in\n",
        "    ``activations[layer_index]['pre_act']``, overwriting the previous value.\n",
        "\n",
        "    Args:\n",
        "        model: object exposing ``model.roberta.encoder.layer`` as a sequence of\n",
        "            layers, each having an ``intermediate.dense`` module.\n",
        "\n",
        "    Returns:\n",
        "        ``(hooks, activations)``: the hook handles in layer order, and the dict\n",
        "        ``{layer_index: {'pre_act': None}}`` that the hooks fill in.\n",
        "    \"\"\"\n",
        "    encoder_layers = model.roberta.encoder.layer\n",
        "    activations = {}\n",
        "    for layer_idx in range(len(encoder_layers)):\n",
        "        activations[layer_idx] = {'pre_act': None}\n",
        "\n",
        "    def make_hook(layer_idx):\n",
        "        # A factory gives each hook its own layer index.\n",
        "        def capture(module, inputs, output):\n",
        "            # NOTE(review): the module *input* is stored under the key 'pre_act';\n",
        "            # confirm this is the tensor the entropy estimate should use.\n",
        "            activations[layer_idx]['pre_act'] = inputs[0].detach()\n",
        "        return capture\n",
        "\n",
        "    hooks = [\n",
        "        enc_layer.intermediate.dense.register_forward_hook(make_hook(layer_idx))\n",
        "        for layer_idx, enc_layer in enumerate(encoder_layers)\n",
        "    ]\n",
        "    return hooks, activations\n",
        "\n",
        "def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):\n",
        "    \"\"\"Compute one entropy score per layer from the buffered hook tensors.\n",
        "\n",
        "    For every layer whose buffer is filled, the tensor is passed through\n",
        "    ``activation_fn``, normalised along ``dim=1`` into a distribution, and the\n",
        "    entropy along ``dim=1`` is averaged over the remaining dimensions. The\n",
        "    buffer is then reset to ``None`` so the same tensor is never scored twice.\n",
        "\n",
        "    Args:\n",
        "        activations: dict ``{layer_idx: {'pre_act': Tensor or None}}``.\n",
        "        activation_fn: non-linearity applied before normalising.\n",
        "        eps: constant added to the normaliser and inside the logarithm.\n",
        "\n",
        "    Returns:\n",
        "        dict ``{layer_idx: float}`` covering only the layers that had a tensor.\n",
        "    \"\"\"\n",
        "    ke_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        pre_act = buf['pre_act']\n",
        "        if pre_act is None:\n",
        "            continue\n",
        "        # Work in float32: in half precision eps=1e-8 rounds to zero, so a zero\n",
        "        # probability would produce 0 * log(0) = NaN.\n",
        "        act = activation_fn(pre_act.float())\n",
        "        # NOTE(review): dim=1 is the sequence axis if inputs are (batch, seq, hidden);\n",
        "        # confirm that this is the axis the distribution should be taken over.\n",
        "        probs = act / (act.sum(dim=1, keepdim=True) + eps)\n",
        "        entropy = -torch.sum(probs * torch.log(probs + eps), dim=1).mean()\n",
        "        ke_scores[idx] = entropy.item()\n",
        "        buf['pre_act'] = None\n",
        "    return ke_scores\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    \"\"\"Call ``remove()`` on every handle in ``hooks``, in iteration order.\"\"\"\n",
        "    for handle in hooks:\n",
        "        handle.remove()\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    \"\"\"Pass-through module: ``forward`` ignores ``hidden_states`` and returns ``input_tensor``.\n",
        "\n",
        "    When ``input_tensor`` is not supplied the result is ``None``.\n",
        "    \"\"\"\n",
        "\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        # Only the second argument is handed back; the first is unused.\n",
        "        return input_tensor\n",
        "\n",
        "def prune_ke_layers(model, ke_scores, num_prune=4):\n",
        "    \"\"\"Bypass the feed-forward path of the ``num_prune`` highest-scoring layers.\n",
        "\n",
        "    For each selected layer, ``intermediate.dense`` becomes ``nn.Identity()`` and\n",
        "    ``output`` becomes ``SkipFF()``.\n",
        "\n",
        "    Args:\n",
        "        model: object exposing ``model.roberta.encoder.layer`` (indexable).\n",
        "        ke_scores: dict ``{layer_idx: score}``.\n",
        "        num_prune: how many of the top-scoring layers to modify.\n",
        "\n",
        "    Returns:\n",
        "        list of the modified layer indices, highest score first.\n",
        "    \"\"\"\n",
        "    # Stable descending sort on the score; ties keep dict order.\n",
        "    ranked = sorted(ke_scores, key=ke_scores.get, reverse=True)\n",
        "    prune_idxs = ranked[:num_prune]\n",
        "    for layer_idx in prune_idxs:\n",
        "        target = model.roberta.encoder.layer[layer_idx]\n",
        "        target.intermediate.dense = nn.Identity()\n",
        "        target.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    \"\"\"Low-rank additive update on top of a fixed weight matrix.\n",
        "\n",
        "    ``forward()`` returns ``W0 + (alpha / r) * (B @ A)``. ``W0`` is a buffer holding\n",
        "    a detached copy of the matrix passed in; ``B`` (small random values) and ``A``\n",
        "    (zeros) are the trainable factors, so the initial result equals ``W0``.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        frozen_weight = W0.clone().detach()\n",
        "        self.register_buffer(\"W0\", frozen_weight)\n",
        "        out_dim, in_dim = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(out_dim, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, in_dim))\n",
        "        self.scaling = alpha / r\n",
        "\n",
        "    def forward(self):\n",
        "        low_rank_delta = self.B @ self.A\n",
        "        return self.W0 + self.scaling * low_rank_delta\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    \"\"\"Route every layer's ``output.dense`` through a LoRA-adjusted weight.\n",
        "\n",
        "    Layers whose ``output`` has no ``dense`` attribute are skipped. For the rest, a\n",
        "    ``LoRA`` module is built from the current dense weight and the dense module's\n",
        "    ``forward`` is replaced by a function computing\n",
        "    ``F.linear(x, lora(), layer.output.dense.bias)``.\n",
        "\n",
        "    Args:\n",
        "        model: object exposing ``model.roberta.encoder.layer`` as an iterable.\n",
        "        r: LoRA rank.\n",
        "        alpha: LoRA scaling numerator.\n",
        "\n",
        "    Returns:\n",
        "        dict ``{layer_idx: LoRA}`` for the layers that were patched.\n",
        "    \"\"\"\n",
        "    def make_forward(enc_layer, adapter):\n",
        "        # The bias is looked up through the layer at call time.\n",
        "        def lora_forward(x):\n",
        "            return F.linear(x, adapter(), enc_layer.output.dense.bias)\n",
        "        return lora_forward\n",
        "\n",
        "    adapters = {}\n",
        "    for layer_idx, enc_layer in enumerate(model.roberta.encoder.layer):\n",
        "        if hasattr(enc_layer.output, 'dense'):\n",
        "            base_weight = enc_layer.output.dense.weight.data\n",
        "            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)\n",
        "            enc_layer.output.dense.forward = make_forward(enc_layer, adapter)\n",
        "            adapters[layer_idx] = adapter\n",
        "    return adapters\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tokenizer, max_length=128):\n",
        "    \"\"\"Tokenise premise/hypothesis pairs, truncated and padded to ``max_length``.\n",
        "\n",
        "    Args:\n",
        "        examples: mapping with \"premise\" and \"hypothesis\" entries.\n",
        "        tokenizer: callable invoked as ``tokenizer(premises, hypotheses, ...)``.\n",
        "        max_length: length used for both truncation and padding.\n",
        "\n",
        "    Returns:\n",
        "        Whatever ``tokenizer`` returns.\n",
        "    \"\"\"\n",
        "    premises = examples[\"premise\"]\n",
        "    hypotheses = examples[\"hypothesis\"]\n",
        "    tokenizer_kwargs = dict(truncation=True, padding=\"max_length\", max_length=max_length)\n",
        "    return tokenizer(premises, hypotheses, **tokenizer_kwargs)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    \"\"\"Return the accuracy of ``model`` over every batch in ``dl``.\n",
        "\n",
        "    The model is switched to eval mode (and left there); predictions are the\n",
        "    argmax of the logits over the last dimension, compared with ``batch['labels']``.\n",
        "\n",
        "    Args:\n",
        "        model: callable taking ``input_ids`` / ``attention_mask`` and returning an\n",
        "            object with a ``logits`` attribute.\n",
        "        dl: iterable of batches with 'input_ids', 'attention_mask' and 'labels'.\n",
        "        device: device the inputs are moved to.\n",
        "\n",
        "    Returns:\n",
        "        The \"accuracy\" value computed by the ``evaluate`` accuracy metric.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    accuracy_metric = evaluate.load(\"accuracy\")\n",
        "    predictions, references = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch['input_ids'].to(device)\n",
        "            attention_mask = batch['attention_mask'].to(device)\n",
        "            references.extend(batch['labels'].cpu().numpy())\n",
        "            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits\n",
        "            predictions.extend(logits.argmax(dim=-1).cpu().numpy())\n",
        "    result = accuracy_metric.compute(predictions=predictions, references=references)\n",
        "    return result[\"accuracy\"]\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device):\n",
        "    \"\"\"Stage 1: fine-tune all of roberta-base (3 labels) and track per-layer KE.\n",
        "\n",
        "    Trains for six epochs with Adam (lr 2e-5), a linear decay schedule without\n",
        "    warm-up, mixed precision and gradient checkpointing. After every optimiser\n",
        "    step the hooked tensors are scored and averaged per layer over the epoch.\n",
        "\n",
        "    Args:\n",
        "        train_loader: iterable of batches with 'input_ids', 'attention_mask', 'labels'.\n",
        "        dev_loader: loader passed to ``evaluate_model`` after training.\n",
        "        device: torch device the model and batches are moved to.\n",
        "\n",
        "    Returns:\n",
        "        ``(model, last_ke)`` where ``last_ke`` maps layer index to the mean KE of\n",
        "        the final epoch.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Finetuning & KE Estimation ===\")\n",
        "    num_epochs = 6  # single source of truth for the loop and the LR schedule\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=3\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_ke_hooks(model)\n",
        "    last_ke = None\n",
        "\n",
        "    try:\n",
        "        for epoch in range(num_epochs):\n",
        "            ke_sums, ke_counts = defaultdict(float), defaultdict(int)\n",
        "            model.train()\n",
        "            for b in train_loader:\n",
        "                opt.zero_grad()\n",
        "                with autocast():\n",
        "                    out = model(input_ids=b['input_ids'].to(device),\n",
        "                                attention_mask=b['attention_mask'].to(device),\n",
        "                                labels=b['labels'].to(device))\n",
        "                # Backward runs outside the autocast region, as the AMP docs recommend.\n",
        "                scaler.scale(out.loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "\n",
        "                batch_ke = compute_batch_knowledge_entropy(activations)\n",
        "                for idx, v in batch_ke.items():\n",
        "                    ke_sums[idx] += v\n",
        "                    ke_counts[idx] += 1\n",
        "\n",
        "            epoch_ke = {idx: ke_sums[idx]/ke_counts[idx]\n",
        "                        for idx in ke_sums if ke_counts[idx] > 0}\n",
        "            print(f\"[Epoch {epoch+1}] approx Knowledge Entropy:\", epoch_ke)\n",
        "            last_ke = epoch_ke\n",
        "    finally:\n",
        "        # Detach the hooks even if training is interrupted, and before evaluation,\n",
        "        # so no hook keeps capturing tensors on the returned model.\n",
        "        remove_hooks(hooks)\n",
        "\n",
        "    acc = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune MNLI Acc: {acc:.4f}\")\n",
        "    return model, last_ke\n",
        "\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores):\n",
        "    \"\"\"Stage 2: bypass the feed-forward path of the 4 highest-KE layers, then fine-tune.\n",
        "\n",
        "    Trains the remaining parameters for five epochs with Adam (lr 1e-5) and a\n",
        "    linear decay schedule without warm-up, printing dev accuracy after each epoch.\n",
        "\n",
        "    Args:\n",
        "        model: the stage-1 model; it is modified in place.\n",
        "        train_loader: iterable of batches with 'input_ids', 'attention_mask', 'labels'.\n",
        "        dev_loader: loader passed to ``evaluate_model`` after every epoch.\n",
        "        device: torch device the batches are moved to.\n",
        "        ke_scores: dict ``{layer_idx: score}`` used to choose the layers to prune.\n",
        "\n",
        "    Returns:\n",
        "        The same ``model`` object after pruning and fine-tuning.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High‐KE) & Finetuning ===\")\n",
        "    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐KE):\", prune_idxs)\n",
        "\n",
        "    num_epochs = 5\n",
        "    opt = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    # The schedule has to span every optimiser step of the loop below; a shorter\n",
        "    # one decays the learning rate to zero before training ends, leaving the\n",
        "    # remaining epochs with no effective updates.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] MNLI Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):\n",
        "    \"\"\"Stage 3: freeze the encoder and train only the classifier and LoRA factors.\n",
        "\n",
        "    LoRA adapters are applied to every layer that still has ``output.dense``. All\n",
        "    ``model.roberta`` parameters are frozen, and the classifier plus the ``A``/``B``\n",
        "    factors are trained for six epochs with Adam (lr 2e-5), a linear decay\n",
        "    schedule without warm-up, and mixed precision. Dev accuracy is printed after\n",
        "    each epoch.\n",
        "\n",
        "    Args:\n",
        "        model: model to adapt; it is modified in place.\n",
        "        train_loader: iterable of batches with 'input_ids', 'attention_mask', 'labels'.\n",
        "        dev_loader: loader passed to ``evaluate_model`` after every epoch.\n",
        "        device: torch device the batches are moved to.\n",
        "        r: LoRA rank.\n",
        "        alpha: LoRA scaling numerator.\n",
        "\n",
        "    Returns:\n",
        "        None.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    num_epochs = 6  # single source of truth for the loop and the LR schedule\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters():\n",
        "        p.requires_grad = False\n",
        "    for p in model.classifier.parameters():\n",
        "        p.requires_grad = True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad = True\n",
        "        l.B.requires_grad = True\n",
        "\n",
        "    # With the whole encoder frozen, a reentrant gradient-checkpointed layer gets\n",
        "    # inputs that do not require grad and is cut out of the autograd graph, so the\n",
        "    # LoRA factors used inside it would get no gradient. Making the embedding\n",
        "    # output require grad keeps those layers in the graph.\n",
        "    if getattr(model, \"is_gradient_checkpointing\", False):\n",
        "        model.enable_input_require_grads()\n",
        "\n",
        "    opt = torch.optim.Adam(\n",
        "        list(model.classifier.parameters()) +\n",
        "        [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "            # Backward runs outside the autocast region, as the AMP docs recommend.\n",
        "            scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] MNLI Acc: {acc:.4f}\")\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    \"\"\"Run the three-stage MNLI pipeline: full fine-tune, prune + fine-tune, LoRA.\"\"\"\n",
        "    seed = 42\n",
        "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
        "        seed_fn(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    dataset = load_dataset(\"glue\", \"mnli\")\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "\n",
        "    # Only the first 10,000 shuffled training examples are used, for speed.\n",
        "    train_ds = dataset[\"train\"].shuffle(seed).select(range(10000))\n",
        "    dev_ds = dataset[\"validation_matched\"]\n",
        "\n",
        "    def tokenize(batch):\n",
        "        return preprocess_function(batch, tokenizer)\n",
        "\n",
        "    def prepare(split):\n",
        "        # Tokenise, drop the raw text columns, and rename the label column.\n",
        "        tokenized = split.map(tokenize,\n",
        "                              batched=True,\n",
        "                              remove_columns=[\"premise\", \"hypothesis\"])\n",
        "        return tokenized.rename_column(\"label\", \"labels\")\n",
        "\n",
        "    train = prepare(train_ds)\n",
        "    dev = prepare(dev_ds)\n",
        "\n",
        "    collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=128)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    # Each stage consumes the model produced by the previous one.\n",
        "    model, ke_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)\n",
        "    lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "8hICTgzJQ4Cj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Hard-coded knowledge-entropy values: {epoch: {0-based layer index: KE}}.\n",
        "ke_epochs = {\n",
        "    1: {0: 3.7403, 1: 3.8396, 2: 3.8184, 3: 3.7880, 4: 3.7816, 5: 3.7467, 6: 3.7648, 7: 3.7296, 8: 3.7192, 9: 3.6667, 10: 3.5540, 11: 3.5193},\n",
        "    2: {0: 3.7418, 1: 3.8355, 2: 3.8190, 3: 3.7905, 4: 3.7855, 5: 3.7613, 6: 3.7775, 7: 3.7418, 8: 3.7250, 9: 3.6640, 10: 3.5458, 11: 3.5028},\n",
        "    3: {0: 3.7434, 1: 3.8356, 2: 3.8193, 3: 3.7932, 4: 3.7921, 5: 3.7653, 6: 3.7784, 7: 3.7541, 8: 3.7352, 9: 3.6658, 10: 3.5442, 11: 3.5088},\n",
        "    4: {0: 3.7441, 1: 3.8372, 2: 3.8191, 3: 3.7979, 4: 3.7927, 5: 3.7678, 6: 3.7775, 7: 3.7542, 8: 3.7398, 9: 3.6725, 10: 3.5492, 11: 3.5134},\n",
        "    5: {0: 3.7451, 1: 3.8349, 2: 3.8234, 3: 3.8037, 4: 3.7976, 5: 3.7695, 6: 3.7790, 7: 3.7539, 8: 3.7457, 9: 3.6797, 10: 3.5499, 11: 3.5118},\n",
        "    6: {0: 3.7452, 1: 3.8372, 2: 3.8257, 3: 3.8062, 4: 3.8008, 5: 3.7704, 6: 3.7808, 7: 3.7537, 8: 3.7452, 9: 3.6729, 10: 3.5400, 11: 3.5013}\n",
        "}\n",
        "\n",
        "# One curve per epoch; layers are shown 1-based (1-12) on the x-axis.\n",
        "fig, ax = plt.subplots(figsize=(10, 6))\n",
        "for epoch, ke in ke_epochs.items():\n",
        "    layer_numbers = [layer_idx + 1 for layer_idx in ke]\n",
        "    entropies = [ke[layer_idx] for layer_idx in ke]\n",
        "    ax.plot(layer_numbers, entropies, marker='o', label=f\"Epoch {epoch}\")\n",
        "\n",
        "#ax.set_title(\"Knowledge Entropy vs Layer Index\", fontsize=16)\n",
        "ax.set_xlabel(\"Layer Index\", fontsize=16)\n",
        "ax.set_ylabel(\"Knowledge Entropy\", fontsize=16)\n",
        "ax.tick_params(axis='both', labelsize=14)\n",
        "ax.legend()\n",
        "ax.grid(True)\n",
        "fig.tight_layout()\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "ryiCc-_TQ6CE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 2) Knowledge Entropy / Hook Utilities\n",
        "# ========================================================\n",
        "def register_ke_hooks(model):\n",
        "    \"\"\"Attach a capturing forward hook to each encoder layer's ``intermediate.dense``.\n",
        "\n",
        "    Every hook stores the detached first positional input of that module in\n",
        "    ``activations[layer_index]['pre_act']``, overwriting the previous value.\n",
        "\n",
        "    Args:\n",
        "        model: object exposing ``model.roberta.encoder.layer`` as a sequence of\n",
        "            layers, each having an ``intermediate.dense`` module.\n",
        "\n",
        "    Returns:\n",
        "        ``(hooks, activations)``: the hook handles in layer order, and the dict\n",
        "        ``{layer_index: {'pre_act': None}}`` that the hooks fill in.\n",
        "    \"\"\"\n",
        "    encoder_layers = model.roberta.encoder.layer\n",
        "    activations = {}\n",
        "    for layer_idx in range(len(encoder_layers)):\n",
        "        activations[layer_idx] = {'pre_act': None}\n",
        "\n",
        "    def make_hook(layer_idx):\n",
        "        # A factory gives each hook its own layer index.\n",
        "        def capture(module, inputs, output):\n",
        "            # NOTE(review): the module *input* is stored under the key 'pre_act';\n",
        "            # confirm this is the tensor the entropy estimate should use.\n",
        "            activations[layer_idx]['pre_act'] = inputs[0].detach()\n",
        "        return capture\n",
        "\n",
        "    hooks = [\n",
        "        enc_layer.intermediate.dense.register_forward_hook(make_hook(layer_idx))\n",
        "        for layer_idx, enc_layer in enumerate(encoder_layers)\n",
        "    ]\n",
        "    return hooks, activations\n",
        "\n",
        "def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):\n",
        "    \"\"\"Compute one entropy score per layer from the buffered hook tensors.\n",
        "\n",
        "    For every layer whose buffer is filled, the tensor is passed through\n",
        "    ``activation_fn``, normalised along ``dim=1`` into a distribution, and the\n",
        "    entropy along ``dim=1`` is averaged over the remaining dimensions. The\n",
        "    buffer is then reset to ``None`` so the same tensor is never scored twice.\n",
        "\n",
        "    Args:\n",
        "        activations: dict ``{layer_idx: {'pre_act': Tensor or None}}``.\n",
        "        activation_fn: non-linearity applied before normalising.\n",
        "        eps: constant added to the normaliser and inside the logarithm.\n",
        "\n",
        "    Returns:\n",
        "        dict ``{layer_idx: float}`` covering only the layers that had a tensor.\n",
        "    \"\"\"\n",
        "    ke_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        pre_act = buf['pre_act']\n",
        "        if pre_act is None:\n",
        "            continue\n",
        "        # Work in float32: in half precision eps=1e-8 rounds to zero, so a zero\n",
        "        # probability would produce 0 * log(0) = NaN.\n",
        "        act = activation_fn(pre_act.float())\n",
        "        # NOTE(review): dim=1 is the sequence axis if inputs are (batch, seq, hidden);\n",
        "        # confirm that this is the axis the distribution should be taken over.\n",
        "        probs = act / (act.sum(dim=1, keepdim=True) + eps)\n",
        "        entropy = -torch.sum(probs * torch.log(probs + eps), dim=1).mean()\n",
        "        ke_scores[idx] = entropy.item()\n",
        "        buf['pre_act'] = None\n",
        "    return ke_scores\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    \"\"\"Call ``remove()`` on every handle in ``hooks``, in iteration order.\"\"\"\n",
        "    for handle in hooks:\n",
        "        handle.remove()\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune high‐KE)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    \"\"\"Pass-through module: ``forward`` ignores ``hidden_states`` and returns ``input_tensor``.\n",
        "\n",
        "    When ``input_tensor`` is not supplied the result is ``None``.\n",
        "    \"\"\"\n",
        "\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        # Only the second argument is handed back; the first is unused.\n",
        "        return input_tensor\n",
        "\n",
        "def prune_ke_layers(model, ke_scores, num_prune=4):\n",
        "    \"\"\"Bypass the feed-forward path of the ``num_prune`` highest-scoring layers.\n",
        "\n",
        "    For each selected layer, ``intermediate.dense`` becomes ``nn.Identity()`` and\n",
        "    ``output`` becomes ``SkipFF()``.\n",
        "\n",
        "    Args:\n",
        "        model: object exposing ``model.roberta.encoder.layer`` (indexable).\n",
        "        ke_scores: dict ``{layer_idx: score}``.\n",
        "        num_prune: how many of the top-scoring layers to modify.\n",
        "\n",
        "    Returns:\n",
        "        list of the modified layer indices, highest score first.\n",
        "    \"\"\"\n",
        "    # Stable descending sort on the score; ties keep dict order.\n",
        "    ranked = sorted(ke_scores, key=ke_scores.get, reverse=True)\n",
        "    prune_idxs = ranked[:num_prune]\n",
        "    for layer_idx in prune_idxs:\n",
        "        target = model.roberta.encoder.layer[layer_idx]\n",
        "        target.intermediate.dense = nn.Identity()\n",
        "        target.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    \"\"\"Low-rank additive update on top of a fixed weight matrix.\n",
        "\n",
        "    ``forward()`` returns ``W0 + (alpha / r) * (B @ A)``. ``W0`` is a buffer holding\n",
        "    a detached copy of the matrix passed in; ``B`` (small random values) and ``A``\n",
        "    (zeros) are the trainable factors, so the initial result equals ``W0``.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        frozen_weight = W0.clone().detach()\n",
        "        self.register_buffer(\"W0\", frozen_weight)\n",
        "        out_dim, in_dim = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(out_dim, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, in_dim))\n",
        "        self.scaling = alpha / r\n",
        "\n",
        "    def forward(self):\n",
        "        low_rank_delta = self.B @ self.A\n",
        "        return self.W0 + self.scaling * low_rank_delta\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    \"\"\"Route every layer's ``output.dense`` through a LoRA-adjusted weight.\n",
        "\n",
        "    Layers whose ``output`` has no ``dense`` attribute are skipped. For the rest, a\n",
        "    ``LoRA`` module is built from the current dense weight and the dense module's\n",
        "    ``forward`` is replaced by a function computing\n",
        "    ``F.linear(x, lora(), layer.output.dense.bias)``.\n",
        "\n",
        "    Args:\n",
        "        model: object exposing ``model.roberta.encoder.layer`` as an iterable.\n",
        "        r: LoRA rank.\n",
        "        alpha: LoRA scaling numerator.\n",
        "\n",
        "    Returns:\n",
        "        dict ``{layer_idx: LoRA}`` for the layers that were patched.\n",
        "    \"\"\"\n",
        "    def make_forward(enc_layer, adapter):\n",
        "        # The bias is looked up through the layer at call time.\n",
        "        def lora_forward(x):\n",
        "            return F.linear(x, adapter(), enc_layer.output.dense.bias)\n",
        "        return lora_forward\n",
        "\n",
        "    adapters = {}\n",
        "    for layer_idx, enc_layer in enumerate(model.roberta.encoder.layer):\n",
        "        if hasattr(enc_layer.output, 'dense'):\n",
        "            base_weight = enc_layer.output.dense.weight.data\n",
        "            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)\n",
        "            enc_layer.output.dense.forward = make_forward(enc_layer, adapter)\n",
        "            adapters[layer_idx] = adapter\n",
        "    return adapters\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    \"\"\"Tokenise sentence pairs, truncated and padded to ``max_length``.\n",
        "\n",
        "    Args:\n",
        "        examples: mapping with 'sentence1' and 'sentence2' entries.\n",
        "        tok: callable invoked as ``tok(first_sentences, second_sentences, ...)``.\n",
        "        max_length: length used for both truncation and padding.\n",
        "\n",
        "    Returns:\n",
        "        Whatever ``tok`` returns.\n",
        "    \"\"\"\n",
        "    first_sentences = examples['sentence1']\n",
        "    second_sentences = examples['sentence2']\n",
        "    tok_kwargs = dict(truncation=True, padding='max_length', max_length=max_length)\n",
        "    return tok(first_sentences, second_sentences, **tok_kwargs)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    \"\"\"Return the accuracy of ``model`` over every batch in ``dl``.\n",
        "\n",
        "    The model is switched to eval mode (and left there); predictions are the\n",
        "    argmax of the logits over the last dimension, compared with ``batch['labels']``.\n",
        "\n",
        "    Args:\n",
        "        model: callable taking ``input_ids`` / ``attention_mask`` and returning an\n",
        "            object with a ``logits`` attribute.\n",
        "        dl: iterable of batches with 'input_ids', 'attention_mask' and 'labels'.\n",
        "        device: device the inputs are moved to.\n",
        "\n",
        "    Returns:\n",
        "        The \"accuracy\" value computed by the ``evaluate`` accuracy metric.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    accuracy_metric = evaluate.load(\"accuracy\")\n",
        "    predictions, references = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch['input_ids'].to(device)\n",
        "            attention_mask = batch['attention_mask'].to(device)\n",
        "            references.extend(batch['labels'].cpu().numpy())\n",
        "            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits\n",
        "            predictions.extend(logits.argmax(dim=-1).cpu().numpy())\n",
        "    result = accuracy_metric.compute(predictions=predictions, references=references)\n",
        "    return result[\"accuracy\"]\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using KE instead of ER)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device):\n",
        "    \"\"\"Stage 1: fine-tune all of roberta-base (2 labels) and track per-layer KE.\n",
        "\n",
        "    Trains for six epochs with Adam (lr 2e-5), a linear decay schedule without\n",
        "    warm-up, mixed precision and gradient checkpointing. After every optimiser\n",
        "    step the hooked tensors are scored and averaged per layer over the epoch.\n",
        "\n",
        "    Args:\n",
        "        train_loader: iterable of batches with 'input_ids', 'attention_mask', 'labels'.\n",
        "        dev_loader: loader passed to ``evaluate_model`` after training.\n",
        "        device: torch device the model and batches are moved to.\n",
        "\n",
        "    Returns:\n",
        "        ``(model, last_ke)`` where ``last_ke`` maps layer index to the mean KE of\n",
        "        the final epoch.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Finetuning & KE Estimation ===\")\n",
        "    num_epochs = 6  # single source of truth for the loop and the LR schedule\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader) * num_epochs)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_ke_hooks(model)\n",
        "    last_ke = None\n",
        "\n",
        "    try:\n",
        "        for epoch in range(num_epochs):\n",
        "            ke_sums, ke_counts = defaultdict(float), defaultdict(int)\n",
        "            model.train()\n",
        "            for b in train_loader:\n",
        "                opt.zero_grad()\n",
        "                with autocast():\n",
        "                    out = model(input_ids=b['input_ids'].to(device),\n",
        "                                attention_mask=b['attention_mask'].to(device),\n",
        "                                labels=b['labels'].to(device))\n",
        "                # Backward runs outside the autocast region, as the AMP docs recommend.\n",
        "                scaler.scale(out.loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "\n",
        "                # batch-level KE\n",
        "                batch_ke = compute_batch_knowledge_entropy(activations)\n",
        "                for idx, v in batch_ke.items():\n",
        "                    ke_sums[idx]   += v\n",
        "                    ke_counts[idx] += 1\n",
        "\n",
        "            # epoch-level KE\n",
        "            epoch_ke = {idx: ke_sums[idx]/ke_counts[idx]\n",
        "                        for idx in ke_sums if ke_counts[idx] > 0}\n",
        "            print(f\"[Epoch {epoch+1}] approx Knowledge Entropy:\", epoch_ke)\n",
        "            last_ke = epoch_ke\n",
        "    finally:\n",
        "        # Detach the hooks even if training is interrupted, and before evaluation,\n",
        "        # so no hook keeps capturing tensors on the returned model.\n",
        "        remove_hooks(hooks)\n",
        "\n",
        "    acc = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune MRPC Acc: {acc:.4f}\")\n",
        "\n",
        "    return model, last_ke\n",
        "\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores):\n",
        "    \"\"\"Stage 2: bypass the feed-forward path of the 4 highest-KE layers, then fine-tune.\n",
        "\n",
        "    Trains the remaining parameters for five epochs with Adam (lr 1e-5) and a\n",
        "    linear decay schedule without warm-up, printing dev accuracy after each epoch.\n",
        "\n",
        "    Args:\n",
        "        model: the stage-1 model; it is modified in place.\n",
        "        train_loader: iterable of batches with 'input_ids', 'attention_mask', 'labels'.\n",
        "        dev_loader: loader passed to ``evaluate_model`` after every epoch.\n",
        "        device: torch device the batches are moved to.\n",
        "        ke_scores: dict ``{layer_idx: score}`` used to choose the layers to prune.\n",
        "\n",
        "    Returns:\n",
        "        The same ``model`` object after pruning and fine-tuning.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High‐KE) & Finetuning ===\")\n",
        "    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐KE):\", prune_idxs)\n",
        "\n",
        "    num_epochs = 5\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    # The schedule has to span every optimiser step of the loop below; a shorter\n",
        "    # one decays the learning rate to zero before training ends, leaving the\n",
        "    # remaining epochs with no effective updates.\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader) * num_epochs)\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] MRPC Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):\n",
        "    \"\"\"Stage 3: freeze the encoder and train only the classifier and LoRA factors.\n",
        "\n",
        "    LoRA adapters are applied to every layer that still has ``output.dense``. All\n",
        "    ``model.roberta`` parameters are frozen, and the classifier plus the ``A``/``B``\n",
        "    factors are trained for six epochs with Adam (lr 2e-5), a linear decay\n",
        "    schedule without warm-up, and mixed precision. Dev accuracy is printed after\n",
        "    each epoch.\n",
        "\n",
        "    Args:\n",
        "        model: model to adapt; it is modified in place.\n",
        "        train_loader: iterable of batches with 'input_ids', 'attention_mask', 'labels'.\n",
        "        dev_loader: loader passed to ``evaluate_model`` after every epoch.\n",
        "        device: torch device the batches are moved to.\n",
        "        r: LoRA rank.\n",
        "        alpha: LoRA scaling numerator.\n",
        "\n",
        "    Returns:\n",
        "        None.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    num_epochs = 6  # single source of truth for the loop and the LR schedule\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters():\n",
        "        p.requires_grad = False\n",
        "    for p in model.classifier.parameters():\n",
        "        p.requires_grad = True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad = True\n",
        "        l.B.requires_grad = True\n",
        "\n",
        "    # With the whole encoder frozen, a reentrant gradient-checkpointed layer gets\n",
        "    # inputs that do not require grad and is cut out of the autograd graph, so the\n",
        "    # LoRA factors used inside it would get no gradient. Making the embedding\n",
        "    # output require grad keeps those layers in the graph.\n",
        "    if getattr(model, \"is_gradient_checkpointing\", False):\n",
        "        model.enable_input_require_grads()\n",
        "\n",
        "    opt   = torch.optim.Adam(\n",
        "        list(model.classifier.parameters())\n",
        "        + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader) * num_epochs)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "            # Backward runs outside the autocast region, as the AMP docs recommend.\n",
        "            scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] MRPC Acc: {acc:.4f}\")\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    \"\"\"Run the three-stage MRPC pipeline: full fine-tune, prune + fine-tune, LoRA.\"\"\"\n",
        "    seed = 42\n",
        "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
        "        seed_fn(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # Only the first 1,000 shuffled training examples are used.\n",
        "    train_ds = load_dataset(\"glue\", \"mrpc\", split=\"train\").shuffle(seed).select(range(1000))\n",
        "    dev_ds = load_dataset(\"glue\", \"mrpc\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "\n",
        "    def tokenize(batch):\n",
        "        return preprocess_function(batch, tokenizer)\n",
        "\n",
        "    def prepare(split):\n",
        "        # Tokenise, drop the raw text and index columns, and rename the label column.\n",
        "        tokenized = split.map(tokenize,\n",
        "                              batched=True,\n",
        "                              remove_columns=[\"sentence1\", \"sentence2\", \"idx\"])\n",
        "        return tokenized.rename_column(\"label\", \"labels\")\n",
        "\n",
        "    train = prepare(train_ds)\n",
        "    dev = prepare(dev_ds)\n",
        "\n",
        "    collator = DataCollatorWithPadding(tokenizer,\n",
        "                                       padding=\"max_length\",\n",
        "                                       max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True,\n",
        "                              collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False,\n",
        "                            collate_fn=collator)\n",
        "\n",
        "    # Each stage consumes the model produced by the previous one.\n",
        "    model, ke_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)\n",
        "    lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "AFGs9h-f_Swa"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Knowledge Entropy values per epoch\n",
        "ke_epochs = [\n",
        "    {0: 3.239791645050049, 1: 3.2986808376312258, 2: 3.253356773376465, 3: 3.228304588317871, 4: 3.204816005706787, 5: 3.1601281089782716, 6: 3.167900978088379, 7: 3.14154211807251, 8: 3.136232151031494, 9: 3.108181224822998, 10: 3.070255741119385, 11: 3.150283061981201},\n",
        "    {0: 3.2390173606872557, 1: 3.2963816051483152, 2: 3.2457099170684813, 3: 3.224952474594116, 4: 3.199368993759155, 5: 3.1539372539520265, 6: 3.149820531845093, 7: 3.1081827564239504, 8: 3.0931982421875, 9: 3.0528202323913574, 10: 3.028363660812378, 11: 3.050001382827759},\n",
        "    {0: 3.239009340286255, 1: 3.2983984470367433, 2: 3.2463859825134276, 3: 3.2241379623413087, 4: 3.1976428546905518, 5: 3.1447392463684083, 6: 3.1279000110626223, 7: 3.0747772026062012, 8: 3.0223362197875976, 9: 2.9473570098876953, 10: 2.87524440574646, 11: 2.8674370727539062},\n",
        "    {0: 3.238827512741089, 1: 3.2974642314910887, 2: 3.2431242713928223, 3: 3.2216903343200682, 4: 3.199512315750122, 5: 3.1563696422576903, 6: 3.134328540802002, 7: 3.0738258819580078, 8: 2.99842280960083, 9: 2.912938259124756, 10: 2.8102419834136962, 11: 2.771777828216553},\n",
        "    {0: 3.2383913173675536, 1: 3.296243072509766, 2: 3.242586082458496, 3: 3.221456657409668, 4: 3.2004083824157714, 5: 3.152312059402466, 6: 3.1258499088287355, 7: 3.064895736694336, 8: 2.980252359390259, 9: 2.860691904067993, 10: 2.749246000289917, 11: 2.710493507385254},\n",
        "    {0: 3.238114908218384, 1: 3.2962895565032957, 2: 3.243076961517334, 3: 3.221760944366455, 4: 3.2005157108306883, 5: 3.1507289390563966, 6: 3.123993368148804, 7: 3.0618533477783205, 8: 2.9700901947021485, 9: 2.843407918930054, 10: 2.727255252838135, 11: 2.6856814765930177}\n",
        "]\n",
        "\n",
        "# Draw one KE-vs-layer curve per training epoch\n",
        "plt.figure(figsize=(10, 6))\n",
        "for epoch_no, epoch_ke in enumerate(ke_epochs, start=1):\n",
        "    layer_numbers = [layer_idx + 1 for layer_idx in epoch_ke]  # 0-based keys -> layers 1-12\n",
        "    ke_per_layer = [epoch_ke[layer_idx] for layer_idx in epoch_ke]\n",
        "    plt.plot(layer_numbers, ke_per_layer, label=f\"Epoch {epoch_no}\", marker='o')\n",
        "\n",
        "plt.xlabel(\"Layer\", fontsize=16)\n",
        "plt.ylabel(\"Knowledge Entropy\", fontsize=16)\n",
        "#plt.title(\"Knowledge Entropy vs Layers\", fontsize=16)\n",
        "# Same tick-label size on both axes\n",
        "for set_ticks in (plt.xticks, plt.yticks):\n",
        "    set_ticks(fontsize=14)\n",
        "plt.legend(fontsize=12)\n",
        "plt.grid(True)\n",
        "plt.tight_layout()\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "xHDCg2DUDvH4"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 2) Knowledge Entropy / Hook Utilities\n",
        "# ========================================================\n",
        "def register_ke_hooks(model):\n",
        "    \"\"\"Attach a forward hook to ``intermediate.dense`` of every encoder layer.\n",
        "\n",
        "    Each hook stores the detached tensor that was fed into that linear layer\n",
        "    under ``activations[layer_idx]['pre_act']``.\n",
        "\n",
        "    Returns:\n",
        "        (hooks, activations): the hook handles in layer order, and the\n",
        "        per-layer buffer dict the hooks write into (initially all ``None``).\n",
        "    \"\"\"\n",
        "    activations = {}\n",
        "    hooks = []\n",
        "\n",
        "    def make_hook(layer_idx):\n",
        "        # Factory so every closure is bound to its own layer index.\n",
        "        def hook_ffn_input(module, input, output):\n",
        "            activations[layer_idx]['pre_act'] = input[0].detach()\n",
        "        return hook_ffn_input\n",
        "\n",
        "    for layer_idx, encoder_layer in enumerate(model.roberta.encoder.layer):\n",
        "        activations[layer_idx] = {'pre_act': None}\n",
        "        ffn_in = encoder_layer.intermediate.dense\n",
        "        hooks.append(ffn_in.register_forward_hook(make_hook(layer_idx)))\n",
        "    return hooks, activations\n",
        "\n",
        "def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):\n",
        "    \"\"\"Compute one entropy score per layer from the captured hook buffers.\n",
        "\n",
        "    For every layer whose buffer holds a tensor, the tensor is passed through\n",
        "    ``activation_fn``, normalised to sum to one along ``dim=1``, and the\n",
        "    Shannon entropy along that same axis is averaged over the remaining\n",
        "    axes. The buffer is then reset to ``None`` so a stale tensor is never\n",
        "    scored twice.\n",
        "\n",
        "    Args:\n",
        "        activations: dict ``{layer_idx: {'pre_act': Tensor or None}}``.\n",
        "        activation_fn: non-linearity applied before normalising.\n",
        "        eps: small constant guarding the division and the logarithm.\n",
        "\n",
        "    Returns:\n",
        "        dict ``{layer_idx: float}``; layers with an empty buffer are omitted.\n",
        "    \"\"\"\n",
        "    # NOTE(review): for (batch, seq, hidden) inputs dim=1 is the sequence\n",
        "    # axis -- confirm this is the axis the entropy is meant to run over.\n",
        "    ke_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        pre_act = buf['pre_act']\n",
        "        if pre_act is None:\n",
        "            continue\n",
        "        # Score half-precision tensors in float32: in float16 eps=1e-8 rounds\n",
        "        # to zero, so zero activations would yield 0 * log(0) = NaN.\n",
        "        if pre_act.dtype in (torch.float16, torch.bfloat16):\n",
        "            pre_act = pre_act.float()\n",
        "        act = activation_fn(pre_act)\n",
        "        probs = act / (act.sum(dim=1, keepdim=True) + eps)\n",
        "        entropy = -torch.sum(probs * torch.log(probs + eps), dim=1).mean()\n",
        "        ke_scores[idx] = entropy.item()\n",
        "        buf['pre_act'] = None\n",
        "    return ke_scores\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    \"\"\"Detach every registered hook by calling ``remove()`` on each handle.\"\"\"\n",
        "    for handle in hooks:\n",
        "        handle.remove()\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune high‐KE)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    \"\"\"Stand-in for a layer's ``output`` block that bypasses the feed-forward path.\n",
        "\n",
        "    The first argument is ignored; the second (``input_tensor``) is returned\n",
        "    unchanged, or ``None`` when it is not supplied.\n",
        "    \"\"\"\n",
        "\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        passthrough = input_tensor\n",
        "        return passthrough\n",
        "\n",
        "def prune_ke_layers(model, ke_scores, num_prune=4):\n",
        "    \"\"\"Bypass the feed-forward block of the ``num_prune`` highest-scoring layers.\n",
        "\n",
        "    For each selected layer, ``intermediate.dense`` is replaced by\n",
        "    ``nn.Identity`` and ``output`` by ``SkipFF``.\n",
        "\n",
        "    Args:\n",
        "        model: object exposing ``roberta.encoder.layer`` (indexable by layer).\n",
        "        ke_scores: dict ``{layer_idx: score}``.\n",
        "        num_prune: how many of the top-scoring layers to bypass.\n",
        "\n",
        "    Returns:\n",
        "        list of the affected layer indices, highest score first.\n",
        "    \"\"\"\n",
        "    # Stable descending sort: ties keep the dict's insertion order.\n",
        "    ranked = sorted(ke_scores, key=ke_scores.get, reverse=True)\n",
        "    prune_idxs = ranked[:num_prune]\n",
        "    for layer_idx in prune_idxs:\n",
        "        target = model.roberta.encoder.layer[layer_idx]\n",
        "        target.intermediate.dense = nn.Identity()\n",
        "        target.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    \"\"\"Low-rank additive update ``W0 + (alpha / r) * (B @ A)`` of a fixed weight.\n",
        "\n",
        "    ``W0`` is kept as a buffer holding a detached copy. With ``W0`` of shape\n",
        "    (rows, cols), ``B`` is (rows, r) and starts as small Gaussian noise while\n",
        "    ``A`` is (r, cols) and starts at zero, so the initial result equals ``W0``.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        frozen_weight = W0.clone().detach()\n",
        "        self.register_buffer(\"W0\", frozen_weight)\n",
        "        rows, cols = W0.shape\n",
        "        # B is registered before A, which fixes the order of parameters().\n",
        "        self.B = nn.Parameter(torch.randn(rows, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, cols))\n",
        "        self.scaling = alpha / r\n",
        "\n",
        "    def forward(self):\n",
        "        low_rank_delta = self.B @ self.A\n",
        "        return self.W0 + self.scaling * low_rank_delta\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    \"\"\"Route every remaining FFN output projection through a LoRA adapter.\n",
        "\n",
        "    Layers whose ``output`` has no ``dense`` attribute are skipped. For the\n",
        "    rest, ``output.dense.forward`` is replaced by a closure that applies\n",
        "    ``F.linear`` with the adapter's effective weight and the layer's bias.\n",
        "\n",
        "    Returns:\n",
        "        dict ``{layer_idx: LoRA}`` for the layers that were wrapped.\n",
        "    \"\"\"\n",
        "    def make_forward(target_layer, adapter):\n",
        "        # The bias is looked up at call time, the weight comes from the adapter.\n",
        "        def fwd(x):\n",
        "            return F.linear(x, adapter(), target_layer.output.dense.bias)\n",
        "        return fwd\n",
        "\n",
        "    adapters = {}\n",
        "    for layer_idx, encoder_layer in enumerate(model.roberta.encoder.layer):\n",
        "        if hasattr(encoder_layer.output, 'dense'):\n",
        "            base_weight = encoder_layer.output.dense.weight.data\n",
        "            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)\n",
        "            encoder_layer.output.dense.forward = make_forward(encoder_layer, adapter)\n",
        "            adapters[layer_idx] = adapter\n",
        "    return adapters\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    \"\"\"Tokenise the ``sentence`` column, truncating and padding to ``max_length``.\"\"\"\n",
        "    tokenizer_kwargs = dict(\n",
        "        truncation=True,\n",
        "        padding='max_length',\n",
        "        max_length=max_length,\n",
        "    )\n",
        "    return tok(examples['sentence'], **tokenizer_kwargs)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    \"\"\"Return the accuracy of ``model`` over the batches of ``dl``.\n",
        "\n",
        "    The model is switched to eval mode (and left there); predictions are the\n",
        "    argmax of the logits, computed without gradient tracking.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    predictions, references = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch['input_ids'].to(device)\n",
        "            attention_mask = batch['attention_mask'].to(device)\n",
        "            references.extend(batch['labels'].cpu().numpy())\n",
        "            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits\n",
        "            predictions.extend(torch.argmax(logits, -1).cpu().numpy())\n",
        "    result = metric.compute(predictions=predictions, references=references)\n",
        "    return result[\"accuracy\"]\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using KE instead of ER)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device, num_epochs=6):\n",
        "    \"\"\"Stage 1: fine-tune all of roberta-base while tracking per-layer KE.\n",
        "\n",
        "    A fresh two-label ``roberta-base`` classifier is trained with Adam\n",
        "    (lr 2e-5), a linear decay schedule and mixed precision. After every\n",
        "    optimiser step the hooked activations are scored, and the scores are\n",
        "    averaged per layer over each epoch.\n",
        "\n",
        "    Args:\n",
        "        train_loader: iterable of batches with ``input_ids``,\n",
        "            ``attention_mask`` and ``labels`` tensors.\n",
        "        dev_loader: batches used for the final accuracy report.\n",
        "        device: device the model and batches are moved to.\n",
        "        num_epochs: number of passes over ``train_loader``; also sizes the\n",
        "            learning-rate schedule so the two always agree.\n",
        "\n",
        "    Returns:\n",
        "        (model, last_ke): the trained model (hooks removed) and the\n",
        "        ``{layer_idx: mean KE}`` dict of the final epoch (``None`` when no\n",
        "        epoch ran).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Finetuning & KE Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(\n",
        "        opt,\n",
        "        num_warmup_steps=0,\n",
        "        num_training_steps=len(train_loader) * num_epochs,\n",
        "    )\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_ke_hooks(model)\n",
        "    last_ke = None\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        ke_sums, ke_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            # autocast covers only the forward pass and the loss; the\n",
        "            # backward pass runs outside it, as the AMP docs recommend.\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "            scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            # Score the activations captured for this batch.\n",
        "            batch_ke = compute_batch_knowledge_entropy(activations)\n",
        "            for idx, v in batch_ke.items():\n",
        "                ke_sums[idx] += v\n",
        "                ke_counts[idx] += 1\n",
        "\n",
        "        epoch_ke = {idx: ke_sums[idx] / ke_counts[idx]\n",
        "                    for idx in ke_sums if ke_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Knowledge Entropy:\", epoch_ke)\n",
        "        last_ke = epoch_ke\n",
        "\n",
        "    acc = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune SST-2 Acc: {acc:.4f}\")\n",
        "\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_ke\n",
        "\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores, num_epochs=5):\n",
        "    \"\"\"Stage 2: bypass the FFN of the highest-KE layers, then fine-tune.\n",
        "\n",
        "    Args:\n",
        "        model: classifier returned by stage 1; modified in place.\n",
        "        train_loader: iterable of batches with ``input_ids``,\n",
        "            ``attention_mask`` and ``labels`` tensors.\n",
        "        dev_loader: batches used for the per-epoch accuracy report.\n",
        "        device: device the batches are moved to.\n",
        "        ke_scores: dict ``{layer_idx: score}`` used to pick the layers.\n",
        "        num_epochs: number of fine-tuning epochs; also sizes the\n",
        "            learning-rate schedule.\n",
        "\n",
        "    Returns:\n",
        "        The same model object after pruning and fine-tuning.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High‐KE) & Finetuning ===\")\n",
        "    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐KE):\", prune_idxs)\n",
        "\n",
        "    opt = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    # The schedule must span every epoch that is actually run. It used to be\n",
        "    # sized for 3 epochs while 5 were run, so the linear decay reached zero\n",
        "    # and the last two epochs could not change the weights.\n",
        "    sched = get_linear_schedule_with_warmup(\n",
        "        opt,\n",
        "        num_warmup_steps=0,\n",
        "        num_training_steps=len(train_loader) * num_epochs,\n",
        "    )\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] SST-2 Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0, num_epochs=6):\n",
        "    \"\"\"Stage 3: train only the LoRA adapters and the classifier head.\n",
        "\n",
        "    Adapters are attached to every layer that still has ``output.dense``;\n",
        "    all ``model.roberta`` parameters are frozen, while the classifier and\n",
        "    the adapter matrices ``A``/``B`` stay trainable.\n",
        "\n",
        "    Args:\n",
        "        model: classifier to adapt; modified in place.\n",
        "        train_loader: iterable of batches with ``input_ids``,\n",
        "            ``attention_mask`` and ``labels`` tensors.\n",
        "        dev_loader: batches used for the per-epoch accuracy report.\n",
        "        device: device the batches are moved to.\n",
        "        r: adapter rank.\n",
        "        alpha: adapter scaling numerator (effective scale is ``alpha / r``).\n",
        "        num_epochs: number of training epochs; also sizes the\n",
        "            learning-rate schedule so the two always agree.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    # Freeze the encoder; only the head and the adapter matrices learn.\n",
        "    for p in model.roberta.parameters():\n",
        "        p.requires_grad = False\n",
        "    for p in model.classifier.parameters():\n",
        "        p.requires_grad = True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad = True\n",
        "        l.B.requires_grad = True\n",
        "\n",
        "    opt = torch.optim.Adam(\n",
        "        list(model.classifier.parameters())\n",
        "        + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(\n",
        "        opt,\n",
        "        num_warmup_steps=0,\n",
        "        num_training_steps=len(train_loader) * num_epochs,\n",
        "    )\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            # autocast covers only the forward pass and the loss; the\n",
        "            # backward pass runs outside it, as the AMP docs recommend.\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "            scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] SST-2 Acc: {acc:.4f}\")\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    \"\"\"Run the three-stage SST-2 pipeline: full fine-tune, prune, LoRA.\"\"\"\n",
        "    seed = 42\n",
        "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
        "        seed_fn(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # 5000 shuffled training sentences; the full validation split for eval.\n",
        "    train_ds = load_dataset(\"glue\", \"sst2\", split=\"train\").shuffle(seed).select(range(5000))\n",
        "    dev_ds = load_dataset(\"glue\", \"sst2\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "\n",
        "    def tokenize(split):\n",
        "        # Tokenise, drop the raw columns, and expose the target as `labels`.\n",
        "        encoded = split.map(\n",
        "            lambda ex: preprocess_function(ex, tokenizer),\n",
        "            batched=True,\n",
        "            remove_columns=[\"sentence\", \"idx\"],\n",
        "        )\n",
        "        return encoded.rename_column(\"label\", \"labels\")\n",
        "\n",
        "    train = tokenize(train_ds)\n",
        "    dev = tokenize(dev_ds)\n",
        "\n",
        "    collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    model, ke_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)\n",
        "    lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "e0e5CB8l49O9"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# KE values for each epoch\n",
        "ke_by_epoch = {\n",
        "    1: [2.9651833793640137, 3.089046128463745, 3.0924095928192137, 3.0681774433135987, 3.0340894985198976, 2.9792041194915773, 2.9694000019073488, 2.9185178718566895, 2.9200110771179197, 2.82025255355835, 2.667954217147827, 2.6612180709838866],\n",
        "    2: [2.965893549346924, 3.085789897155762, 3.092639580154419, 3.078977996826172, 3.0649560813903807, 3.0321216072082517, 3.018477940368652, 2.9576354961395266, 2.911316421508789, 2.782052672576904, 2.6040575157165526, 2.5998096366882324],\n",
        "    3: [2.9650946979522703, 3.084409861755371, 3.0894264568328857, 3.077181778717041, 3.066806471252441, 3.0330316158294677, 3.026228702163696, 2.9687501598358152, 2.918505425262451, 2.781764380264282, 2.6073847118377684, 2.5956354915618896],\n",
        "    4: [2.9647007961273193, 3.0866590961456297, 3.088959024810791, 3.076308888244629, 3.0617130882263184, 3.0239853103637695, 3.0255811374664305, 2.9724936283111574, 2.91553058013916, 2.779527521133423, 2.608109250640869, 2.599828709793091],\n",
        "    5: [2.9640205013275147, 3.089625425720215, 3.0891447685241697, 3.0789716785430907, 3.066036195373535, 3.0240665218353273, 3.0201424140930175, 2.9597047756195067, 2.888725273895264, 2.7510012699127198, 2.584951690673828, 2.5864554821014405],\n",
        "    6: [2.963668116760254, 3.090019404220581, 3.089617932128906, 3.0806209617614746, 3.0694873622894288, 3.0306819889068604, 3.0209427997589113, 2.956924412536621, 2.8712899379730223, 2.7296667186737062, 2.57259878578186, 2.5810291679382322]\n",
        "}\n",
        "\n",
        "# One KE-vs-layer curve per epoch\n",
        "plt.figure(figsize=(10, 6))\n",
        "for epoch, ke_values in ke_by_epoch.items():\n",
        "    # Derive the 1-based layer axis from the data rather than hard-coding\n",
        "    # 12 layers, so x and y always have matching lengths.\n",
        "    layer_numbers = range(1, len(ke_values) + 1)\n",
        "    plt.plot(layer_numbers, ke_values, marker='o', label=f'Epoch {epoch}')\n",
        "\n",
        "#plt.title(\"Knowledge Entropy vs. Layer\", fontsize=16)\n",
        "plt.xlabel(\"Layer Index\", fontsize=16)\n",
        "plt.ylabel(\"Approx. Knowledge Entropy\", fontsize=16)\n",
        "plt.xticks(fontsize=14)\n",
        "plt.yticks(fontsize=14)\n",
        "plt.legend(fontsize=12)\n",
        "plt.grid(True)\n",
        "plt.tight_layout()\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "lh8kP42oO0ZO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# ========================================================\n",
        "# 2) Knowledge Entropy / Hook Utilities\n",
        "# ========================================================\n",
        "def register_ke_hooks(model):\n",
        "    \"\"\"Attach a forward hook to ``intermediate.dense`` of every encoder layer.\n",
        "\n",
        "    Each hook stores the detached tensor that was fed into that linear layer\n",
        "    under ``activations[layer_idx]['pre_act']``.\n",
        "\n",
        "    Returns:\n",
        "        (hooks, activations): the hook handles in layer order, and the\n",
        "        per-layer buffer dict the hooks write into (initially all ``None``).\n",
        "    \"\"\"\n",
        "    activations = {}\n",
        "    hooks = []\n",
        "\n",
        "    def make_hook(layer_idx):\n",
        "        # Factory so every closure is bound to its own layer index.\n",
        "        def hook_ffn_input(module, input, output):\n",
        "            activations[layer_idx]['pre_act'] = input[0].detach()\n",
        "        return hook_ffn_input\n",
        "\n",
        "    for layer_idx, encoder_layer in enumerate(model.roberta.encoder.layer):\n",
        "        activations[layer_idx] = {'pre_act': None}\n",
        "        ffn_in = encoder_layer.intermediate.dense\n",
        "        hooks.append(ffn_in.register_forward_hook(make_hook(layer_idx)))\n",
        "    return hooks, activations\n",
        "\n",
        "def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):\n",
        "    \"\"\"Compute one entropy score per layer from the captured hook buffers.\n",
        "\n",
        "    For every layer whose buffer holds a tensor, the tensor is passed through\n",
        "    ``activation_fn``, normalised to sum to one along ``dim=1``, and the\n",
        "    Shannon entropy along that same axis is averaged over the remaining\n",
        "    axes. The buffer is then reset to ``None`` so a stale tensor is never\n",
        "    scored twice.\n",
        "\n",
        "    Args:\n",
        "        activations: dict ``{layer_idx: {'pre_act': Tensor or None}}``.\n",
        "        activation_fn: non-linearity applied before normalising.\n",
        "        eps: small constant guarding the division and the logarithm.\n",
        "\n",
        "    Returns:\n",
        "        dict ``{layer_idx: float}``; layers with an empty buffer are omitted.\n",
        "    \"\"\"\n",
        "    # NOTE(review): for (batch, seq, hidden) inputs dim=1 is the sequence\n",
        "    # axis -- confirm this is the axis the entropy is meant to run over.\n",
        "    ke_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        pre_act = buf['pre_act']\n",
        "        if pre_act is None:\n",
        "            continue\n",
        "        # Score half-precision tensors in float32: in float16 eps=1e-8 rounds\n",
        "        # to zero, so zero activations would yield 0 * log(0) = NaN.\n",
        "        if pre_act.dtype in (torch.float16, torch.bfloat16):\n",
        "            pre_act = pre_act.float()\n",
        "        act = activation_fn(pre_act)\n",
        "        probs = act / (act.sum(dim=1, keepdim=True) + eps)\n",
        "        entropy = -torch.sum(probs * torch.log(probs + eps), dim=1).mean()\n",
        "        ke_scores[idx] = entropy.item()\n",
        "        buf['pre_act'] = None\n",
        "    return ke_scores\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    \"\"\"Detach every registered hook by calling ``remove()`` on each handle.\"\"\"\n",
        "    for handle in hooks:\n",
        "        handle.remove()\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune high‐KE)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    \"\"\"Stand-in for a layer's ``output`` block that bypasses the feed-forward path.\n",
        "\n",
        "    The first argument is ignored; the second (``input_tensor``) is returned\n",
        "    unchanged, or ``None`` when it is not supplied.\n",
        "    \"\"\"\n",
        "\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        passthrough = input_tensor\n",
        "        return passthrough\n",
        "\n",
        "def prune_ke_layers(model, ke_scores, num_prune=4):\n",
        "    \"\"\"Bypass the feed-forward block of the ``num_prune`` highest-scoring layers.\n",
        "\n",
        "    For each selected layer, ``intermediate.dense`` is replaced by\n",
        "    ``nn.Identity`` and ``output`` by ``SkipFF``.\n",
        "\n",
        "    Args:\n",
        "        model: object exposing ``roberta.encoder.layer`` (indexable by layer).\n",
        "        ke_scores: dict ``{layer_idx: score}``.\n",
        "        num_prune: how many of the top-scoring layers to bypass.\n",
        "\n",
        "    Returns:\n",
        "        list of the affected layer indices, highest score first.\n",
        "    \"\"\"\n",
        "    # Stable descending sort: ties keep the dict's insertion order.\n",
        "    ranked = sorted(ke_scores, key=ke_scores.get, reverse=True)\n",
        "    prune_idxs = ranked[:num_prune]\n",
        "    for layer_idx in prune_idxs:\n",
        "        target = model.roberta.encoder.layer[layer_idx]\n",
        "        target.intermediate.dense = nn.Identity()\n",
        "        target.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    \"\"\"Low-rank additive update ``W0 + (alpha / r) * (B @ A)`` of a fixed weight.\n",
        "\n",
        "    ``W0`` is kept as a buffer holding a detached copy. With ``W0`` of shape\n",
        "    (rows, cols), ``B`` is (rows, r) and starts as small Gaussian noise while\n",
        "    ``A`` is (r, cols) and starts at zero, so the initial result equals ``W0``.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        frozen_weight = W0.clone().detach()\n",
        "        self.register_buffer(\"W0\", frozen_weight)\n",
        "        rows, cols = W0.shape\n",
        "        # B is registered before A, which fixes the order of parameters().\n",
        "        self.B = nn.Parameter(torch.randn(rows, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, cols))\n",
        "        self.scaling = alpha / r\n",
        "\n",
        "    def forward(self):\n",
        "        low_rank_delta = self.B @ self.A\n",
        "        return self.W0 + self.scaling * low_rank_delta\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    \"\"\"Route every remaining FFN output projection through a LoRA adapter.\n",
        "\n",
        "    Layers whose ``output`` has no ``dense`` attribute are skipped. For the\n",
        "    rest, ``output.dense.forward`` is replaced by a closure that applies\n",
        "    ``F.linear`` with the adapter's effective weight and the layer's bias.\n",
        "\n",
        "    Returns:\n",
        "        dict ``{layer_idx: LoRA}`` for the layers that were wrapped.\n",
        "    \"\"\"\n",
        "    def make_forward(target_layer, adapter):\n",
        "        # The bias is looked up at call time, the weight comes from the adapter.\n",
        "        def fwd(x):\n",
        "            return F.linear(x, adapter(), target_layer.output.dense.bias)\n",
        "        return fwd\n",
        "\n",
        "    adapters = {}\n",
        "    for layer_idx, encoder_layer in enumerate(model.roberta.encoder.layer):\n",
        "        if hasattr(encoder_layer.output, 'dense'):\n",
        "            base_weight = encoder_layer.output.dense.weight.data\n",
        "            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)\n",
        "            encoder_layer.output.dense.forward = make_forward(encoder_layer, adapter)\n",
        "            adapters[layer_idx] = adapter\n",
        "    return adapters\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    \"\"\"Tokenise the ``sentence`` column, truncating and padding to ``max_length``.\"\"\"\n",
        "    tokenizer_kwargs = dict(\n",
        "        truncation=True,\n",
        "        padding='max_length',\n",
        "        max_length=max_length,\n",
        "    )\n",
        "    return tok(examples['sentence'], **tokenizer_kwargs)\n",
        "\n",
        "\n",
        "from sklearn.metrics import matthews_corrcoef\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    \"\"\"Return ``(accuracy, matthews_correlation)`` of ``model`` over ``dl``.\n",
        "\n",
        "    The model is switched to eval mode (and left there); predictions are the\n",
        "    argmax of the logits, computed without gradient tracking.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    acc_metric = evaluate.load(\"accuracy\")\n",
        "    mcc_metric = evaluate.load(\"matthews_correlation\")\n",
        "    predictions, references = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch['input_ids'].to(device)\n",
        "            attention_mask = batch['attention_mask'].to(device)\n",
        "            references.extend(batch['labels'].cpu().numpy())\n",
        "            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits\n",
        "            predictions.extend(torch.argmax(logits, -1).cpu().numpy())\n",
        "    acc_result = acc_metric.compute(predictions=predictions, references=references)\n",
        "    mcc_result = mcc_metric.compute(predictions=predictions, references=references)\n",
        "    return acc_result[\"accuracy\"], mcc_result[\"matthews_correlation\"]\n",
        "\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using KE instead of ER)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device, num_epochs=6):\n",
        "    \"\"\"Stage 1: fine-tune all of roberta-base while tracking per-layer KE.\n",
        "\n",
        "    A fresh two-label ``roberta-base`` classifier is trained with Adam\n",
        "    (lr 2e-5), a linear decay schedule and mixed precision. After every\n",
        "    optimiser step the hooked activations are scored; the scores are\n",
        "    averaged per layer over each epoch, and the dev set is evaluated at\n",
        "    the end of every epoch.\n",
        "\n",
        "    Args:\n",
        "        train_loader: iterable of batches with ``input_ids``,\n",
        "            ``attention_mask`` and ``labels`` tensors.\n",
        "        dev_loader: batches used for the per-epoch accuracy / MCC report.\n",
        "        device: device the model and batches are moved to.\n",
        "        num_epochs: number of passes over ``train_loader``; also sizes the\n",
        "            learning-rate schedule so the two always agree.\n",
        "\n",
        "    Returns:\n",
        "        (model, last_ke): the trained model (hooks removed) and the\n",
        "        ``{layer_idx: mean KE}`` dict of the final epoch (``None`` when no\n",
        "        epoch ran).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Finetuning & KE Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(\n",
        "        opt,\n",
        "        num_warmup_steps=0,\n",
        "        num_training_steps=len(train_loader) * num_epochs,\n",
        "    )\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_ke_hooks(model)\n",
        "    last_ke = None\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        ke_sums, ke_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            # autocast covers only the forward pass and the loss; the\n",
        "            # backward pass runs outside it, as the AMP docs recommend.\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "            scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            # batch-level KE\n",
        "            batch_ke = compute_batch_knowledge_entropy(activations)\n",
        "            for idx, v in batch_ke.items():\n",
        "                ke_sums[idx] += v\n",
        "                ke_counts[idx] += 1\n",
        "\n",
        "        # epoch-level KE\n",
        "        epoch_ke = {idx: ke_sums[idx] / ke_counts[idx]\n",
        "                    for idx in ke_sums if ke_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Knowledge Entropy:\", epoch_ke)\n",
        "        last_ke = epoch_ke\n",
        "\n",
        "        acc, mcc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"-> Full Finetune CoLA Acc: {acc:.4f} | MCC: {mcc:.4f}\")\n",
        "\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_ke\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores, num_epochs=5):\n",
        "    \"\"\"Stage 2: bypass the FFN of the highest-KE layers, then fine-tune.\n",
        "\n",
        "    Args:\n",
        "        model: classifier returned by stage 1; modified in place.\n",
        "        train_loader: iterable of batches with ``input_ids``,\n",
        "            ``attention_mask`` and ``labels`` tensors.\n",
        "        dev_loader: batches used for the per-epoch accuracy / MCC report.\n",
        "        device: device the batches are moved to.\n",
        "        ke_scores: dict ``{layer_idx: score}`` used to pick the layers.\n",
        "        num_epochs: number of fine-tuning epochs; also sizes the\n",
        "            learning-rate schedule.\n",
        "\n",
        "    Returns:\n",
        "        The same model object after pruning and fine-tuning.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High‐KE) & Finetuning ===\")\n",
        "    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐KE):\", prune_idxs)\n",
        "\n",
        "    opt = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    # The schedule must span every epoch that is actually run. It used to be\n",
        "    # sized for 3 epochs while 5 were run, so the linear decay reached zero\n",
        "    # and the last two epochs could not change the weights.\n",
        "    sched = get_linear_schedule_with_warmup(\n",
        "        opt,\n",
        "        num_warmup_steps=0,\n",
        "        num_training_steps=len(train_loader) * num_epochs,\n",
        "    )\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "\n",
        "        acc, mcc = evaluate_model(model, dev_loader, device)\n",
        "        # Label the line with this stage, not with stage 1's \"Full Finetune\".\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] CoLA Acc: {acc:.4f} | MCC: {mcc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0, num_epochs=6):\n",
        "    \"\"\"Stage 3: train only the LoRA adapters and the classifier head.\n",
        "\n",
        "    Adapters are attached to every layer that still has ``output.dense``;\n",
        "    all ``model.roberta`` parameters are frozen, while the classifier and\n",
        "    the adapter matrices ``A``/``B`` stay trainable.\n",
        "\n",
        "    Args:\n",
        "        model: classifier to adapt; modified in place.\n",
        "        train_loader: iterable of batches with ``input_ids``,\n",
        "            ``attention_mask`` and ``labels`` tensors.\n",
        "        dev_loader: batches used for the per-epoch accuracy / MCC report.\n",
        "        device: device the batches are moved to.\n",
        "        r: adapter rank.\n",
        "        alpha: adapter scaling numerator (effective scale is ``alpha / r``).\n",
        "        num_epochs: number of training epochs; also sizes the\n",
        "            learning-rate schedule so the two always agree.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    # Freeze the encoder; only the head and the adapter matrices learn.\n",
        "    for p in model.roberta.parameters():\n",
        "        p.requires_grad = False\n",
        "    for p in model.classifier.parameters():\n",
        "        p.requires_grad = True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad = True\n",
        "        l.B.requires_grad = True\n",
        "\n",
        "    opt = torch.optim.Adam(\n",
        "        list(model.classifier.parameters())\n",
        "        + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(\n",
        "        opt,\n",
        "        num_warmup_steps=0,\n",
        "        num_training_steps=len(train_loader) * num_epochs,\n",
        "    )\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            # autocast covers only the forward pass and the loss; the\n",
        "            # backward pass runs outside it, as the AMP docs recommend.\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "            scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "        acc, mcc = evaluate_model(model, dev_loader, device)\n",
        "        # Label the line with this stage, not with stage 1's \"Full Finetune\".\n",
        "        print(f\"[LoRA Epoch {epoch+1}] CoLA Acc: {acc:.4f} | MCC: {mcc:.4f}\")\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    \"\"\"Run the three-stage CoLA pipeline: full fine-tune, prune, LoRA.\"\"\"\n",
        "    seed = 42\n",
        "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
        "        seed_fn(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # 1000 shuffled training sentences; the full validation split for eval.\n",
        "    train_ds = load_dataset(\"glue\", \"cola\", split=\"train\").shuffle(seed).select(range(1000))\n",
        "    dev_ds = load_dataset(\"glue\", \"cola\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "\n",
        "    def tokenize(split):\n",
        "        # Tokenise, drop the raw columns, and expose the target as `labels`.\n",
        "        encoded = split.map(\n",
        "            lambda ex: preprocess_function(ex, tokenizer),\n",
        "            batched=True,\n",
        "            remove_columns=[\"sentence\", \"idx\"],\n",
        "        )\n",
        "        return encoded.rename_column(\"label\", \"labels\")\n",
        "\n",
        "    train = tokenize(train_ds)\n",
        "    dev = tokenize(dev_ds)\n",
        "\n",
        "    collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    model, ke_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)\n",
        "    lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "VLk75h1ionEm"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Knowledge Entropy values over 6 epochs\n",
        "ke_values = [\n",
        "    {0: 2.9448, 1: 3.0885, 2: 3.1181, 3: 3.0932, 4: 3.0441, 5: 2.9380, 6: 2.9606, 7: 2.9447, 8: 2.9687, 9: 2.9775, 10: 2.9016, 11: 2.9350},\n",
        "    {0: 2.9439, 1: 3.0961, 2: 3.1100, 3: 3.0816, 4: 3.0435, 5: 2.9906, 6: 3.0009, 7: 2.9850, 8: 2.9817, 9: 2.9741, 10: 2.8645, 11: 2.8777},\n",
        "    {0: 2.9447, 1: 3.0925, 2: 3.1075, 3: 3.0703, 4: 3.0372, 5: 2.9933, 6: 3.0018, 7: 2.9899, 8: 2.9857, 9: 2.9657, 10: 2.8410, 11: 2.8267},\n",
        "    {0: 2.9449, 1: 3.0911, 2: 3.1028, 3: 3.0704, 4: 3.0389, 5: 2.9948, 6: 2.9996, 7: 2.9699, 8: 2.9628, 9: 2.9423, 10: 2.8048, 11: 2.7881},\n",
        "    {0: 2.9457, 1: 3.0933, 2: 3.1071, 3: 3.0762, 4: 3.0477, 5: 3.0011, 6: 3.0038, 7: 2.9738, 8: 2.9548, 9: 2.9251, 10: 2.7719, 11: 2.7567},\n",
        "    {0: 2.9456, 1: 3.0939, 2: 3.1071, 3: 3.0768, 4: 3.0461, 5: 3.0022, 6: 3.0023, 7: 2.9735, 8: 2.9553, 9: 2.9216, 10: 2.7640, 11: 2.7494},\n",
        "]\n",
        "\n",
        "plt.figure(figsize=(10, 6))\n",
        "for i, epoch_ke in enumerate(ke_values, start=1):\n",
        "    layers = [l + 1 for l in epoch_ke.keys()]  # shift layer indices to 1–12\n",
        "    entropy = list(epoch_ke.values())\n",
        "    plt.plot(layers, entropy, marker='o', label=f\"Epoch {i}\")\n",
        "\n",
        "plt.xlabel(\"Layer Index\", fontsize=16)\n",
        "plt.ylabel(\"Knowledge Entropy\", fontsize=16)\n",
        "#plt.title(\"Knowledge Entropy vs. Layer Index (CoLA)\", fontsize=16)\n",
        "plt.xticks(fontsize=14)\n",
        "plt.yticks(fontsize=14)\n",
        "plt.legend(fontsize=12)\n",
        "plt.grid(True)\n",
        "plt.tight_layout()\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "cVv85s_wpaN-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "DNR1BHeYpaq6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 2) Knowledge Entropy / Hook Utilities\n",
        "# ========================================================\n",
        "def register_ke_hooks(model):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {i: {'pre_act': None} for i in range(len(layers))}\n",
        "    hooks = []\n",
        "    for i, layer in enumerate(layers):\n",
        "        def hook_ffn_input(module, input, output, idx=i):\n",
        "            activations[idx]['pre_act'] = input[0].detach()\n",
        "        hooks.append(layer.intermediate.dense.register_forward_hook(hook_ffn_input))\n",
        "    return hooks, activations\n",
        "\n",
        "def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):\n",
        "    ke_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        pre_act = buf['pre_act']\n",
        "        if pre_act is None:\n",
        "            continue\n",
        "        act = activation_fn(pre_act)\n",
        "        probs = act / (act.sum(dim=1, keepdim=True) + eps)\n",
        "        entropy = -torch.sum(probs * torch.log(probs + eps), dim=1).mean()\n",
        "        ke_scores[idx] = entropy.item()\n",
        "        buf['pre_act'] = None\n",
        "    return ke_scores\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune high‐KE)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_ke_layers(model, ke_scores, num_prune=4):\n",
        "    sorted_layers = sorted(ke_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx for idx, _ in sorted_layers[:num_prune]]\n",
        "    for idx in prune_idxs:\n",
        "        model.roberta.encoder.layer[idx].intermediate.dense = nn.Identity()\n",
        "        model.roberta.encoder.layer[idx].output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    loras = {}\n",
        "    for idx, layer in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(layer.output, 'dense'):\n",
        "            continue\n",
        "        W0 = layer.output.dense.weight.data\n",
        "        lora = LoRA(W0, r, alpha).to(W0.device)\n",
        "        def fwd(x, layer=layer, lora=lora):\n",
        "            return F.linear(x, lora(), layer.output.dense.bias)\n",
        "        layer.output.dense.forward = fwd\n",
        "        loras[idx] = lora\n",
        "    return loras\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(examples['question'],\n",
        "               examples['sentence'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using KE instead of ER)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device):\n",
        "    print(\"=== Stage 1: Full Finetuning & KE Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_ke_hooks(model)\n",
        "    last_ke = None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        ke_sums, ke_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            # batch‐level KE\n",
        "            batch_ke = compute_batch_knowledge_entropy(activations)\n",
        "            for idx, v in batch_ke.items():\n",
        "                ke_sums[idx]   += v\n",
        "                ke_counts[idx] += 1\n",
        "\n",
        "        # epoch‐level KE\n",
        "        epoch_ke = {idx: ke_sums[idx]/ke_counts[idx]\n",
        "                    for idx in ke_sums if ke_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Knowledge Entropy:\", epoch_ke)\n",
        "        last_ke = epoch_ke\n",
        "\n",
        "    acc = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune QNLI Acc: {acc:.4f}\")\n",
        "\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_ke\n",
        "\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores):\n",
        "    print(\"=== Stage 2: Prune (High‐KE) & Finetuning ===\")\n",
        "    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐KE):\", prune_idxs)\n",
        "\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*3)\n",
        "\n",
        "    for epoch in range(5):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] QNLI Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters(): p.requires_grad=False\n",
        "    for p in model.classifier.parameters(): p.requires_grad=True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad=True\n",
        "        l.B.requires_grad=True\n",
        "\n",
        "    opt   = torch.optim.Adam(\n",
        "        list(model.classifier.parameters())\n",
        "        + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    for epoch in range(6):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] QNLI Acc: {acc:.4f}\")\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    train_ds = load_dataset(\"glue\", \"qnli\", split=\"train\").shuffle(seed).select(range(2000))\n",
        "    dev_ds   = load_dataset(\"glue\", \"qnli\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True,\n",
        "                         remove_columns=[\"question\",\"sentence\",\"idx\"])\\\n",
        "                    .rename_column(\"label\",\"labels\")\n",
        "    dev   = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                       batched=True,\n",
        "                       remove_columns=[\"question\",\"sentence\",\"idx\"])\\\n",
        "                  .rename_column(\"label\",\"labels\")\n",
        "\n",
        "    collator     = DataCollatorWithPadding(tokenizer,\n",
        "                                           padding=\"max_length\",\n",
        "                                           max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True,\n",
        "                              collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev,   batch_size=16, shuffle=False,\n",
        "                              collate_fn=collator)\n",
        "\n",
        "    model, ke_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)\n",
        "    lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "M5lgTIkOonfn"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Knowledge Entropy (KE) values across layers for each epoch\n",
        "ke_epochs = {\n",
        "    1: {0: 3.2304, 1: 3.2924, 2: 3.2409, 3: 3.1981, 4: 3.1778, 5: 3.1335, 6: 3.1346, 7: 3.0527, 8: 3.0132, 9: 2.9415, 10: 2.8503, 11: 2.8969},\n",
        "    2: {0: 3.2314, 1: 3.2867, 2: 3.2390, 3: 3.2144, 4: 3.1975, 5: 3.1637, 6: 3.1601, 7: 3.0645, 8: 2.9672, 9: 2.8309, 10: 2.6973, 11: 2.6584},\n",
        "    3: {0: 3.2336, 1: 3.2851, 2: 3.2359, 3: 3.2133, 4: 3.2056, 5: 3.1733, 6: 3.1724, 7: 3.0849, 8: 2.9905, 9: 2.8596, 10: 2.7224, 11: 2.6792},\n",
        "    4: {0: 3.2347, 1: 3.2857, 2: 3.2389, 3: 3.2136, 4: 3.2055, 5: 3.1750, 6: 3.1752, 7: 3.0889, 8: 2.9981, 9: 2.8731, 10: 2.7439, 11: 2.6969},\n",
        "    5: {0: 3.2345, 1: 3.2861, 2: 3.2386, 3: 3.2143, 4: 3.2060, 5: 3.1730, 6: 3.1725, 7: 3.0911, 8: 2.9987, 9: 2.8678, 10: 2.7374, 11: 2.6870},\n",
        "    6: {0: 3.2349, 1: 3.2869, 2: 3.2387, 3: 3.2147, 4: 3.2044, 5: 3.1738, 6: 3.1718, 7: 3.0958, 8: 3.0050, 9: 2.8763, 10: 2.7463, 11: 2.6961}\n",
        "}\n",
        "\n",
        "# Plot\n",
        "plt.figure(figsize=(10, 6))\n",
        "for epoch, values in ke_epochs.items():\n",
        "    layers = [l + 1 for l in values.keys()]  # shift layer indices to 1–12\n",
        "    entropy = list(values.values())\n",
        "    plt.plot(layers, entropy, marker='o', label=f\"Epoch {epoch}\")\n",
        "\n",
        "plt.xlabel(\"Layer Index\", fontsize=16)\n",
        "plt.ylabel(\"Knowledge Entropy\", fontsize=16)\n",
        "#plt.title(\"Knowledge Entropy vs. Layer Index\", fontsize=16)\n",
        "plt.legend(fontsize=12)\n",
        "plt.grid(True)\n",
        "plt.xticks(fontsize=14)\n",
        "plt.yticks(fontsize=14)\n",
        "plt.tight_layout()\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "t35oIk4Cdoed"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "Q45PT1_AfhaE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "VCeLVFz-fh4i"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# ========================================================\n",
        "# 2) Knowledge Entropy / Hook Utilities\n",
        "# ========================================================\n",
        "def register_ke_hooks(model):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {i: {'pre_act': None} for i in range(len(layers))}\n",
        "    hooks = []\n",
        "    for i, layer in enumerate(layers):\n",
        "        def hook_ffn_input(module, input, output, idx=i):\n",
        "            activations[idx]['pre_act'] = input[0].detach()\n",
        "        hooks.append(layer.intermediate.dense.register_forward_hook(hook_ffn_input))\n",
        "    return hooks, activations\n",
        "\n",
        "def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):\n",
        "    ke_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        pre_act = buf['pre_act']\n",
        "        if pre_act is None:\n",
        "            continue\n",
        "        act = activation_fn(pre_act)\n",
        "        probs = act / (act.sum(dim=1, keepdim=True) + eps)\n",
        "        entropy = -torch.sum(probs * torch.log(probs + eps), dim=1).mean()\n",
        "        ke_scores[idx] = entropy.item()\n",
        "        buf['pre_act'] = None\n",
        "    return ke_scores\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune high‐KE)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_ke_layers(model, ke_scores, num_prune=4):\n",
        "    sorted_layers = sorted(ke_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx for idx, _ in sorted_layers[:num_prune]]\n",
        "    for idx in prune_idxs:\n",
        "        model.roberta.encoder.layer[idx].intermediate.dense = nn.Identity()\n",
        "        model.roberta.encoder.layer[idx].output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    loras = {}\n",
        "    for idx, layer in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(layer.output, 'dense'):\n",
        "            continue\n",
        "        W0 = layer.output.dense.weight.data\n",
        "        lora = LoRA(W0, r, alpha).to(W0.device)\n",
        "        def fwd(x, layer=layer, lora=lora):\n",
        "            return F.linear(x, lora(), layer.output.dense.bias)\n",
        "        layer.output.dense.forward = fwd\n",
        "        loras[idx] = lora\n",
        "    return loras\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(\n",
        "        examples['question1'],\n",
        "        examples['question2'],\n",
        "        truncation=True,\n",
        "        padding='max_length',\n",
        "        max_length=max_length\n",
        "    )\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using KE instead of ER)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device):\n",
        "    print(\"=== Stage 1: Full Finetuning & KE Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_ke_hooks(model)\n",
        "    last_ke = None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        ke_sums, ke_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            # batch‐level KE\n",
        "            batch_ke = compute_batch_knowledge_entropy(activations)\n",
        "            for idx, v in batch_ke.items():\n",
        "                ke_sums[idx]   += v\n",
        "                ke_counts[idx] += 1\n",
        "\n",
        "        # epoch‐level KE\n",
        "        epoch_ke = {idx: ke_sums[idx]/ke_counts[idx]\n",
        "                    for idx in ke_sums if ke_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Knowledge Entropy:\", epoch_ke)\n",
        "        last_ke = epoch_ke\n",
        "\n",
        "    acc = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune QQP Acc: {acc:.4f}\")\n",
        "\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_ke\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores):\n",
        "    print(\"=== Stage 2: Prune (High‐KE) & Finetuning ===\")\n",
        "    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐KE):\", prune_idxs)\n",
        "\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*3)\n",
        "\n",
        "    for epoch in range(5):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] QQP Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters(): p.requires_grad=False\n",
        "    for p in model.classifier.parameters(): p.requires_grad=True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad=True\n",
        "        l.B.requires_grad=True\n",
        "\n",
        "    opt   = torch.optim.Adam(\n",
        "        list(model.classifier.parameters())\n",
        "        + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    for epoch in range(6):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] QQP Acc: {acc:.4f}\")\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    train_ds = load_dataset(\"glue\", \"qqp\", split=\"train\").shuffle(seed).select(range(3000))\n",
        "    dev_ds   = load_dataset(\"glue\", \"qqp\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True,\n",
        "                         remove_columns=[\"question1\",\"question2\",\"idx\"])\\\n",
        "                    .rename_column(\"label\",\"labels\")\n",
        "    dev   = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                       batched=True,\n",
        "                       remove_columns=[\"question1\",\"question2\",\"idx\"])\\\n",
        "                  .rename_column(\"label\",\"labels\")\n",
        "\n",
        "    collator     = DataCollatorWithPadding(tokenizer,\n",
        "                                           padding=\"max_length\",\n",
        "                                           max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True,\n",
        "                              collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev,   batch_size=16, shuffle=False,\n",
        "                              collate_fn=collator)\n",
        "\n",
        "    model, ke_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)\n",
        "    lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "md4igvXakGTo"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# KE data per epoch\n",
        "ke_epochs = {\n",
        "    1: {0: 3.1468, 1: 3.2210, 2: 3.1869, 3: 3.1624, 4: 3.1201, 5: 3.0795, 6: 3.1073, 7: 3.0597, 8: 3.0786, 9: 3.0529, 10: 2.9826, 11: 2.9888},\n",
        "    2: {0: 3.1473, 1: 3.2196, 2: 3.1859, 3: 3.1655, 4: 3.1310, 5: 3.1000, 6: 3.1146, 7: 3.0777, 8: 3.0792, 9: 3.0537, 10: 2.9414, 11: 2.9214},\n",
        "    3: {0: 3.1469, 1: 3.2201, 2: 3.1880, 3: 3.1641, 4: 3.1386, 5: 3.1055, 6: 3.1073, 7: 3.0607, 8: 3.0420, 9: 2.9884, 10: 2.8447, 11: 2.8140},\n",
        "    4: {0: 3.1465, 1: 3.2205, 2: 3.1869, 3: 3.1643, 4: 3.1412, 5: 3.1092, 6: 3.1093, 7: 3.0606, 8: 3.0383, 9: 2.9692, 10: 2.8255, 11: 2.7930},\n",
        "    5: {0: 3.1463, 1: 3.2201, 2: 3.1878, 3: 3.1660, 4: 3.1404, 5: 3.1064, 6: 3.1102, 7: 3.0624, 8: 3.0388, 9: 2.9652, 10: 2.8183, 11: 2.7961},\n",
        "    6: {0: 3.1460, 1: 3.2191, 2: 3.1884, 3: 3.1685, 4: 3.1425, 5: 3.1084, 6: 3.1098, 7: 3.0626, 8: 3.0334, 9: 2.9510, 10: 2.8009, 11: 2.7761},\n",
        "}\n",
        "\n",
        "# Plotting\n",
        "plt.figure(figsize=(10, 6))\n",
        "for epoch, ke in ke_epochs.items():\n",
        "    layers = [l + 1 for l in ke.keys()]  # shift layer indices to 1–12\n",
        "    values = list(ke.values())\n",
        "    plt.plot(layers, values, marker='o', label=f\"Epoch {epoch}\")\n",
        "\n",
        "#plt.title(\"Knowledge Entropy vs Layers\", fontsize=16)\n",
        "plt.xlabel(\"Layer Index\", fontsize=16)\n",
        "plt.ylabel(\"Knowledge Entropy\", fontsize=16)\n",
        "plt.xticks(fontsize=16)\n",
        "plt.yticks(fontsize=16)\n",
        "plt.legend(fontsize=12)\n",
        "plt.grid(True)\n",
        "plt.tight_layout()\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "bq_y4H85ghyG"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# ========================================================\n",
        "# 2) Knowledge Entropy / Hook Utilities\n",
        "# ========================================================\n",
        "def register_ke_hooks(model):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {i: {'pre_act': None} for i in range(len(layers))}\n",
        "    hooks = []\n",
        "    for i, layer in enumerate(layers):\n",
        "        def hook_ffn_input(module, input, output, idx=i):\n",
        "            activations[idx]['pre_act'] = input[0].detach()\n",
        "        hooks.append(layer.intermediate.dense.register_forward_hook(hook_ffn_input))\n",
        "    return hooks, activations\n",
        "\n",
        "def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):\n",
        "    ke_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        pre_act = buf['pre_act']\n",
        "        if pre_act is None:\n",
        "            continue\n",
        "        act = activation_fn(pre_act)\n",
        "        probs = act / (act.sum(dim=1, keepdim=True) + eps)\n",
        "        entropy = -torch.sum(probs * torch.log(probs + eps), dim=1).mean()\n",
        "        ke_scores[idx] = entropy.item()\n",
        "        buf['pre_act'] = None\n",
        "    return ke_scores\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune high‐KE)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_ke_layers(model, ke_scores, num_prune=4):\n",
        "    sorted_layers = sorted(ke_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx for idx, _ in sorted_layers[:num_prune]]\n",
        "    for idx in prune_idxs:\n",
        "        model.roberta.encoder.layer[idx].intermediate.dense = nn.Identity()\n",
        "        model.roberta.encoder.layer[idx].output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    loras = {}\n",
        "    for idx, layer in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(layer.output, 'dense'):\n",
        "            continue\n",
        "        W0 = layer.output.dense.weight.data\n",
        "        lora = LoRA(W0, r, alpha).to(W0.device)\n",
        "        def fwd(x, layer=layer, lora=lora):\n",
        "            return F.linear(x, lora(), layer.output.dense.bias)\n",
        "        layer.output.dense.forward = fwd\n",
        "        loras[idx] = lora\n",
        "    return loras\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(\n",
        "        examples['sentence1'],\n",
        "        examples['sentence2'],\n",
        "        truncation=True,\n",
        "        padding='max_length',\n",
        "        max_length=max_length\n",
        "    )\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using KE instead of ER)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device):\n",
        "    print(\"=== Stage 1: Full Finetuning & KE Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_ke_hooks(model)\n",
        "    last_ke = None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        ke_sums, ke_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            # batch‐level KE\n",
        "            batch_ke = compute_batch_knowledge_entropy(activations)\n",
        "            for idx, v in batch_ke.items():\n",
        "                ke_sums[idx]   += v\n",
        "                ke_counts[idx] += 1\n",
        "\n",
        "        # epoch‐level KE\n",
        "        epoch_ke = {idx: ke_sums[idx]/ke_counts[idx]\n",
        "                    for idx in ke_sums if ke_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Knowledge Entropy:\", epoch_ke)\n",
        "        last_ke = epoch_ke\n",
        "\n",
        "    acc = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune RTE Acc: {acc:.4f}\")\n",
        "\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_ke\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores):\n",
        "    print(\"=== Stage 2: Prune (High‐KE) & Finetuning ===\")\n",
        "    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐KE):\", prune_idxs)\n",
        "\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*3)\n",
        "\n",
        "    for epoch in range(5):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] RTE Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters(): p.requires_grad=False\n",
        "    for p in model.classifier.parameters(): p.requires_grad=True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad=True\n",
        "        l.B.requires_grad=True\n",
        "\n",
        "    opt   = torch.optim.Adam(\n",
        "        list(model.classifier.parameters())\n",
        "        + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    for epoch in range(6):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] RTE Acc: {acc:.4f}\")\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    train_ds = load_dataset(\"glue\", \"rte\", split=\"train\").shuffle(seed)\n",
        "    dev_ds   = load_dataset(\"glue\", \"rte\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True,\n",
        "                         remove_columns=[\"sentence1\",\"sentence2\",\"idx\"])\\\n",
        "                    .rename_column(\"label\",\"labels\")\n",
        "    dev   = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                       batched=True,\n",
        "                       remove_columns=[\"sentence1\",\"sentence2\",\"idx\"])\\\n",
        "                  .rename_column(\"label\",\"labels\")\n",
        "\n",
        "    collator     = DataCollatorWithPadding(tokenizer,\n",
        "                                           padding=\"max_length\",\n",
        "                                           max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True,\n",
        "                              collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev,   batch_size=16, shuffle=False,\n",
        "                              collate_fn=collator)\n",
        "\n",
        "    model, ke_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)\n",
        "    lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "mT_gG5aqu50v"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# KE data for each epoch\n",
        "ke_data = {\n",
        "    1: {0: 3.234637494270618, 1: 3.2954297256775393, 2: 3.2472059474541592, 3: 3.2122703920572233, 4: 3.1996606771762552, 5: 3.166536705616193, 6: 3.192006678917469, 7: 3.1797323601368146, 8: 3.1996014416217804, 9: 3.198551667042268, 10: 3.1776862190319943, 11: 3.233222233179288},\n",
        "    2: {0: 3.2364144898377933, 1: 3.292912809512554, 2: 3.241034468779197, 3: 3.214976036395782, 4: 3.198220124611488, 5: 3.16093972325325, 6: 3.176389996057902, 7: 3.153661161661148, 8: 3.156559161650829, 9: 3.124322636769368, 10: 3.0624067340141687, 11: 3.0741044863676414},\n",
        "    3: {0: 3.2363225794755497, 1: 3.2945617039998374, 2: 3.2429233766519108, 3: 3.214366410023127, 4: 3.2025416669173117, 5: 3.16937870092881, 6: 3.180643222270868, 7: 3.16009650627772, 8: 3.1594576827990704, 9: 3.1243840715824027, 10: 3.053374951466536, 11: 3.0580264222927585},\n",
        "    4: {0: 3.2360405830236583, 1: 3.2945619431825786, 2: 3.2415615954460244, 3: 3.2135055791109037, 4: 3.2035804826479692, 5: 3.1748380936109104, 6: 3.1875738394566073, 7: 3.1665822023000474, 8: 3.1637751505925107, 9: 3.1298167407512665, 10: 3.052433254627081, 11: 3.038477917512258},\n",
        "    5: {0: 3.2361522951187234, 1: 3.294039017114884, 2: 3.241480989333911, 3: 3.2140066035282917, 4: 3.204068192304709, 5: 3.175252984731625, 6: 3.1887962214457684, 7: 3.1690831688734202, 8: 3.1620407494214864, 9: 3.1216949025789895, 10: 3.035130114127428, 11: 3.0117857922346163},\n",
        "    6: {0: 3.2358735440633235, 1: 3.2946181755799513, 2: 3.2427005072434745, 3: 3.213492139791831, 4: 3.204735367726057, 5: 3.174306563077829, 6: 3.1876255694108133, 7: 3.167302029255109, 8: 3.1590061837281938, 9: 3.115201425093871, 10: 3.0225293246599345, 11: 2.994033750051107}\n",
        "}\n",
        "\n",
        "# Plot\n",
        "plt.figure(figsize=(10, 6))\n",
        "for epoch, layer_data in ke_data.items():\n",
        "    layers = [l + 1 for l in layer_data.keys()]  # shift layer indices to 1–12\n",
        "    ke_values = list(layer_data.values())\n",
        "    plt.plot(layers, ke_values, marker='o', label=f\"Epoch {epoch}\")\n",
        "\n",
        "plt.xlabel(\"Layer Index\", fontsize=16)\n",
        "plt.ylabel(\"Knowledge Entropy\", fontsize=16)\n",
        "#plt.title(\"Knowledge Entropy vs. Layer\", fontsize=16)\n",
        "plt.xticks(fontsize=14)\n",
        "plt.yticks(fontsize=14)\n",
        "plt.legend(fontsize=12)\n",
        "plt.grid(True)\n",
        "plt.tight_layout()\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "OwQGPSMJhDZ_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# ========================================================\n",
        "# 2) Knowledge Entropy / Hook Utilities\n",
        "# ========================================================\n",
        "def register_ke_hooks(model):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {i: {'pre_act': None} for i in range(len(layers))}\n",
        "    hooks = []\n",
        "    for i, layer in enumerate(layers):\n",
        "        def hook_ffn_input(module, input, output, idx=i):\n",
        "            activations[idx]['pre_act'] = input[0].detach()\n",
        "        hooks.append(layer.intermediate.dense.register_forward_hook(hook_ffn_input))\n",
        "    return hooks, activations\n",
        "\n",
        "def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):\n",
        "    ke_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        pre_act = buf['pre_act']\n",
        "        if pre_act is None:\n",
        "            continue\n",
        "        act = activation_fn(pre_act)\n",
        "        probs = act / (act.sum(dim=1, keepdim=True) + eps)\n",
        "        entropy = -torch.sum(probs * torch.log(probs + eps), dim=1).mean()\n",
        "        ke_scores[idx] = entropy.item()\n",
        "        buf['pre_act'] = None\n",
        "    return ke_scores\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune high‐KE)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_ke_layers(model, ke_scores, num_prune=4):\n",
        "    sorted_layers = sorted(ke_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx for idx, _ in sorted_layers[:num_prune]]\n",
        "    for idx in prune_idxs:\n",
        "        model.roberta.encoder.layer[idx].intermediate.dense = nn.Identity()\n",
        "        model.roberta.encoder.layer[idx].output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    loras = {}\n",
        "    for idx, layer in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(layer.output, 'dense'):\n",
        "            continue\n",
        "        W0 = layer.output.dense.weight.data\n",
        "        lora = LoRA(W0, r, alpha).to(W0.device)\n",
        "        def fwd(x, layer=layer, lora=lora):\n",
        "            return F.linear(x, lora(), layer.output.dense.bias)\n",
        "        layer.output.dense.forward = fwd\n",
        "        loras[idx] = lora\n",
        "    return loras\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(\n",
        "        examples['sentence1'],\n",
        "        examples['sentence2'],\n",
        "        truncation=True,\n",
        "        padding='max_length',\n",
        "        max_length=max_length\n",
        "    )\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"pearsonr\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            # Regression head: output is [B, 1]\n",
        "            pred = out.logits.view(-1).cpu().numpy()\n",
        "            preds.extend(pred)\n",
        "    return metric.compute(predictions=preds, references=labs)[\"pearsonr\"]\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using KE instead of ER)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device):\n",
        "    print(\"=== Stage 1: Full Finetuning & KE Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=1  # num_labels=1 for regression!\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_ke_hooks(model)\n",
        "    last_ke = None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        ke_sums, ke_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device, dtype=torch.float).unsqueeze(1))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            # batch‐level KE\n",
        "            batch_ke = compute_batch_knowledge_entropy(activations)\n",
        "            for idx, v in batch_ke.items():\n",
        "                ke_sums[idx]   += v\n",
        "                ke_counts[idx] += 1\n",
        "\n",
        "        # epoch‐level KE\n",
        "        epoch_ke = {idx: ke_sums[idx]/ke_counts[idx]\n",
        "                    for idx in ke_sums if ke_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Knowledge Entropy:\", epoch_ke)\n",
        "        last_ke = epoch_ke\n",
        "\n",
        "    pearson = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune STS-B Pearson: {pearson:.4f}\")\n",
        "\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_ke\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores):\n",
        "    print(\"=== Stage 2: Prune (High‐KE) & Finetuning ===\")\n",
        "    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐KE):\", prune_idxs)\n",
        "\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*3)\n",
        "\n",
        "    for epoch in range(5):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device, dtype=torch.float).unsqueeze(1))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        pearson = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] STS-B Pearson: {pearson:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0,\n",
        "                         epochs=6, lr=2e-5):\n",
        "    \"\"\"Stage 3: fine-tune only the LoRA adapters and the classifier head.\n",
        "\n",
        "    Adapters are attached via ``apply_lora_to_all_layers``. Every parameter\n",
        "    under ``model.roberta`` is then frozen, and only the classifier head and\n",
        "    each adapter's ``A``/``B`` tensors are left trainable and given to Adam.\n",
        "    After each epoch the value returned by ``evaluate_model`` on ``dev_loader``\n",
        "    is printed as the STS-B Pearson score.\n",
        "\n",
        "    Args:\n",
        "        model: Model exposing ``roberta`` and ``classifier`` submodules; it is\n",
        "            modified in place.\n",
        "        train_loader: Iterable of batches with ``input_ids``,\n",
        "            ``attention_mask`` and ``labels`` tensors.\n",
        "        dev_loader: Loader handed to ``evaluate_model`` after each epoch.\n",
        "        device: Device the batches are moved to (``torch.device`` or string).\n",
        "        r, alpha: Forwarded to ``apply_lora_to_all_layers`` (presumably the\n",
        "            LoRA rank and scaling factor -- confirm against that helper).\n",
        "        epochs: Number of training epochs; also sizes the linear LR schedule.\n",
        "        lr: Adam learning rate.\n",
        "\n",
        "    Returns:\n",
        "        The same ``model`` object, for parity with ``prune_and_finetuning``.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "\n",
        "    # Freeze the backbone first, then re-enable the head and the adapters, so\n",
        "    # any adapter tensor registered under model.roberta still ends up trainable.\n",
        "    for p in model.roberta.parameters():\n",
        "        p.requires_grad = False\n",
        "    for p in model.classifier.parameters():\n",
        "        p.requires_grad = True\n",
        "    lora_params = [p for l in loras.values() for p in (l.A, l.B)]\n",
        "    for p in lora_params:\n",
        "        p.requires_grad = True\n",
        "\n",
        "    opt = torch.optim.Adam(\n",
        "        list(model.classifier.parameters()) + lora_params,\n",
        "        lr=lr\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader) * epochs)\n",
        "    # Mixed precision is only meaningful on CUDA; elsewhere run in plain fp32\n",
        "    # with a disabled (pass-through) scaler.\n",
        "    use_amp = torch.device(device).type == \"cuda\"\n",
        "    scaler = GradScaler(enabled=use_amp)\n",
        "\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            # autocast wraps the forward pass (and loss) only.\n",
        "            with autocast(enabled=use_amp):\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device, dtype=torch.float).unsqueeze(1))\n",
        "            # Backward runs outside autocast, as the PyTorch AMP docs recommend.\n",
        "            scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        pearson = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] STS-B Pearson: {pearson:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    \"\"\"Run the STS-B pipeline end to end.\n",
        "\n",
        "    Seeds Python, NumPy and PyTorch with 42, loads the GLUE STS-B train and\n",
        "    validation splits, tokenizes them with the ``roberta-base`` fast tokenizer,\n",
        "    builds the data loaders and then chains ``full_finetuning``,\n",
        "    ``prune_and_finetuning`` and ``lora_only_finetuning``.\n",
        "    \"\"\"\n",
        "    seed = 42\n",
        "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
        "        seed_fn(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    train_ds = load_dataset(\"glue\", \"stsb\", split=\"train\").shuffle(seed)\n",
        "    dev_ds = load_dataset(\"glue\", \"stsb\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "\n",
        "    def encode_split(split):\n",
        "        # Tokenize in batches, drop the raw text/id columns, expose the target as \"labels\".\n",
        "        encoded = split.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                            batched=True,\n",
        "                            remove_columns=[\"sentence1\", \"sentence2\", \"idx\"])\n",
        "        return encoded.rename_column(\"label\", \"labels\")\n",
        "\n",
        "    train = encode_split(train_ds)\n",
        "    dev = encode_split(dev_ds)\n",
        "\n",
        "    collator = DataCollatorWithPadding(tokenizer,\n",
        "                                       padding=\"max_length\",\n",
        "                                       max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    # Stage 1 -> Stage 2 -> Stage 3, each stage receiving the previous stage's model.\n",
        "    model, ke_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)\n",
        "    lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "# Call main() only when this code runs as the top-level module\n",
        "# (i.e. __name__ is \"__main__\"), not when it is imported.\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "oCixrJGeu6Qn"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Hard-coded Knowledge Entropy (KE) values used by the plot below.\n",
        "# Outer key: epoch number (1-6). Inner key: zero-based layer index (0-11);\n",
        "# inner value: the KE recorded for that layer in that epoch.\n",
        "# NOTE(review): presumably copied from a training run's logged output -- confirm provenance.\n",
        "ke_epochs = {\n",
        "    1: {0: 3.1216782207117624, 1: 3.205926341373831, 2: 3.1831439436725515, 3: 3.1425496576890164, 4: 3.108628351930451,\n",
        "        5: 3.0245226891879744, 6: 3.0113194900693085, 7: 2.96199728318481, 8: 2.9216685882032496, 9: 2.9011312645234386,\n",
        "        10: 2.7949637216056007, 11: 2.8036840522405337},\n",
        "    2: {0: 3.122385612946724, 1: 3.2053615688780255, 2: 3.1827276502431516, 3: 3.14758318588034, 4: 3.128111186047423,\n",
        "        5: 3.0808556003597083, 6: 3.0710045684527953, 7: 3.012045979334018, 8: 2.9473394376014634, 9: 2.890920449032737,\n",
        "        10: 2.706535710744632, 11: 2.694957201271959},\n",
        "    3: {0: 3.1233545473785824, 1: 3.2060421793119303, 2: 3.180608749721246, 3: 3.149725641096749, 4: 3.1337831318792944,\n",
        "        5: 3.087708413518022, 6: 3.0823230342175267, 7: 3.0231315847563978, 8: 2.9449381267908383, 9: 2.8736855148108513,\n",
        "        10: 2.6799341229636413, 11: 2.6526868296929624},\n",
        "    4: {0: 3.122849219697573, 1: 3.208289318190828, 2: 3.1786043458256836, 3: 3.152634641225547, 4: 3.140414782127518,\n",
        "        5: 3.0983164837695294, 6: 3.087455921942402, 7: 3.0256239744486164, 8: 2.950759364434509, 9: 2.883653296548899,\n",
        "        10: 2.686351388817205, 11: 2.6519999305130875},\n",
        "    5: {0: 3.1224169508969832, 1: 3.2059260482416696, 2: 3.17878786678606, 3: 3.156504187365069, 4: 3.1431414029860196,\n",
        "        5: 3.0961787657545403, 6: 3.0855182536951524, 7: 3.0326405927766844, 8: 2.9571537298353725, 9: 2.8834909653298877,\n",
        "        10: 2.6865991667348252, 11: 2.644667492788259},\n",
        "    6: {0: 3.122367330651953, 1: 3.20639664672512, 2: 3.179138040343644, 3: 3.157353766604491, 4: 3.1413243549755454,\n",
        "        5: 3.0952708057965954, 6: 3.0864931734612986, 7: 3.033865492267635, 8: 2.9587372055637324, 9: 2.8831359905725726,\n",
        "        10: 2.6869357043413524, 11: 2.6412745341140473}\n",
        "}\n",
        "\n",
        "# Plot one Knowledge Entropy curve per epoch against the 1-based layer index.\n",
        "AXIS_LABEL_SIZE = 16\n",
        "TICK_LABEL_SIZE = 14\n",
        "LEGEND_SIZE = 12\n",
        "\n",
        "plt.figure(figsize=(10, 6))\n",
        "for epoch_no, layer_ke in ke_epochs.items():\n",
        "    # Keys are zero-based layer indices; show them as 1-12 on the x-axis.\n",
        "    layer_ids = [idx + 1 for idx in layer_ke]\n",
        "    ke_values = [layer_ke[idx] for idx in layer_ke]\n",
        "    plt.plot(layer_ids, ke_values, marker='o', label=f'Epoch {epoch_no}')\n",
        "\n",
        "plt.xlabel('Layer Index', fontsize=AXIS_LABEL_SIZE)\n",
        "plt.ylabel('Knowledge Entropy', fontsize=AXIS_LABEL_SIZE)\n",
        "#plt.title('Knowledge Entropy vs Layer (fz=16)', fontsize=16)\n",
        "plt.xticks(fontsize=TICK_LABEL_SIZE)\n",
        "plt.yticks(fontsize=TICK_LABEL_SIZE)\n",
        "plt.legend(fontsize=LEGEND_SIZE)\n",
        "plt.grid(True)\n",
        "plt.tight_layout()\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "xiNpM5vKvAxC"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "Etp_DiKHvBTa"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "Caj-xs3GQVEE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "Dq83WlRPQVf8"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}