{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fQy4J05l_RX_"
      },
      "outputs": [],
      "source": [
        "!pip uninstall -y datasets\n",
        "!pip install datasets==2.18.0\n",
        "!pip install evaluate"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n"
      ],
      "metadata": {
        "id": "3UI1t0_yd9p7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict\n",
        "import numpy as np\n",
        "import random\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# ---- 1. Knowledge Entropy Hook Utilities ----\n",
        "from functools import partial\n",
        "\n",
        "def register_ke_hooks_t5(model):\n",
        "    enc_layers = model.encoder.block\n",
        "    dec_layers = model.decoder.block\n",
        "    enc_acts = {i: None for i in range(len(enc_layers))}\n",
        "    dec_acts = {i: None for i in range(len(dec_layers))}\n",
        "    enc_hooks, dec_hooks = [], []\n",
        "\n",
        "    for i, layer in enumerate(enc_layers):\n",
        "        enc_hooks.append(\n",
        "            layer.layer[1].DenseReluDense.register_forward_hook(\n",
        "                partial(lambda acts, module, inp, out, idx: acts.__setitem__(idx, inp[0].detach()), enc_acts, idx=i)\n",
        "            )\n",
        "        )\n",
        "    for i, layer in enumerate(dec_layers):\n",
        "        dec_hooks.append(\n",
        "            layer.layer[2].DenseReluDense.register_forward_hook(\n",
        "                partial(lambda acts, module, inp, out, idx: acts.__setitem__(idx, inp[0].detach()), dec_acts, idx=i)\n",
        "            )\n",
        "        )\n",
        "    return (enc_hooks, enc_acts), (dec_hooks, dec_acts)\n",
        "\n",
        "\n",
        "\n",
        "def compute_ke_batch(acts, act_fn=F.relu, eps=1e-8):\n",
        "    ke = {}\n",
        "    for idx, a in acts.items():\n",
        "        if a is None:\n",
        "            continue\n",
        "        act = act_fn(a)  # shape: (batch, seq, hidden)\n",
        "        # Sum over the last dimension (hidden), keep dims for broadcast\n",
        "        denom = act.sum(dim=-1, keepdim=True)\n",
        "        # To avoid division by zero, set any zero denominators to 1\n",
        "        denom = torch.where(denom == 0, torch.ones_like(denom), denom)\n",
        "        probs = act / (denom + eps)\n",
        "        # Clamp probabilities to avoid log(0)\n",
        "        probs = torch.clamp(probs, min=1e-8)\n",
        "        entropy = -torch.sum(probs * torch.log(probs), dim=-1).mean()\n",
        "        ke[idx] = entropy.item()\n",
        "        acts[idx] = None  # reset\n",
        "    return ke\n",
        "\n",
        "def remove_hooks(hook_sets):\n",
        "    for hooks, _ in hook_sets:\n",
        "        for h in hooks: h.remove()\n",
        "\n",
        "# ---- 2. Pruning Utilities ----\n",
        "class SkipFFN(nn.Module):\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "    def forward(self, hidden_states):\n",
        "        return hidden_states\n",
        "\n",
        "def prune_high_ke_ffn(blocks, ke_scores, num_prune=2, hidden_size=768):\n",
        "    sorted_layers = sorted(ke_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx for idx, _ in sorted_layers[:num_prune]]\n",
        "    for idx in prune_idxs:\n",
        "        # Replace feed-forward block with Identity/Skip\n",
        "        blocks[idx].layer[1].DenseReluDense = SkipFFN(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "# ---- 3. Data/Helper functions ----\n",
        "def make_t5_nli_prompt(premise, hypothesis):\n",
        "    return f\"nli premise: {premise} hypothesis: {hypothesis}\"\n",
        "\n",
        "def preprocess_function(batch, tokenizer, max_input_length=128, max_target_length=8):\n",
        "    inputs = [make_t5_nli_prompt(p, h) for p, h in zip(batch['premise'], batch['hypothesis'])]\n",
        "    model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=max_input_length)\n",
        "    label_list = [\"entailment\", \"neutral\", \"contradiction\"]\n",
        "    labels = [label_list[x] if isinstance(x, int) and x < len(label_list) else x for x in batch['label']]\n",
        "    target = tokenizer(labels, padding=\"max_length\", truncation=True, max_length=max_target_length)\n",
        "    model_inputs[\"labels\"] = target[\"input_ids\"]\n",
        "    return model_inputs\n",
        "\n",
        "def compute_accuracy(preds, refs):\n",
        "    correct = 0\n",
        "    for p, l in zip(preds, refs):\n",
        "        if p == l:\n",
        "            correct += 1\n",
        "    return correct / len(preds) if len(preds) > 0 else 0\n",
        "\n",
        "def evaluate_model(model, dl, tokenizer, device, label_texts):\n",
        "    model.eval()\n",
        "    preds, refs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch[\"input_ids\"].to(device)\n",
        "            attention_mask = batch[\"attention_mask\"].to(device)\n",
        "            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=2)\n",
        "            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
        "            label_ids = batch[\"labels\"].clone()\n",
        "            label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
        "            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
        "            preds.extend([p.strip().lower() for p in pred_texts])\n",
        "            refs.extend([l.strip().lower() for l in ref_texts])\n",
        "    return compute_accuracy(preds, refs)\n",
        "\n",
        "# ---- 4. Training/Fine-tuning Loops ----\n",
        "def full_finetuning(train_loader, dev_loader, device, tokenizer, label_texts):\n",
        "    print(\"=== Stage 1: Full Fine-Tuning & Knowledge Entropy Estimation ===\")\n",
        "    model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
        "    scaler = GradScaler()\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*3)\n",
        "    (enc_hooks, enc_acts), (dec_hooks, dec_acts) = register_ke_hooks_t5(model)\n",
        "    last_enc_ke, last_dec_ke = None, None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        enc_ke_sum, enc_ke_count = defaultdict(float), defaultdict(int)\n",
        "        dec_ke_sum, dec_ke_count = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                                attention_mask=batch['attention_mask'].to(device),\n",
        "                                labels=batch['labels'].to(device))\n",
        "                loss = outputs.loss\n",
        "                scaler.scale(loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "            # KE estimation\n",
        "            batch_enc_ke = compute_ke_batch(enc_acts)\n",
        "            for idx, v in batch_enc_ke.items():\n",
        "                enc_ke_sum[idx] += v\n",
        "                enc_ke_count[idx] += 1\n",
        "            batch_dec_ke = compute_ke_batch(dec_acts)\n",
        "            for idx, v in batch_dec_ke.items():\n",
        "                dec_ke_sum[idx] += v\n",
        "                dec_ke_count[idx] += 1\n",
        "\n",
        "        epoch_enc_ke = {idx: enc_ke_sum[idx]/enc_ke_count[idx] for idx in enc_ke_sum if enc_ke_count[idx] > 0}\n",
        "        epoch_dec_ke = {idx: dec_ke_sum[idx]/dec_ke_count[idx] for idx in dec_ke_sum if dec_ke_count[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Encoder KE: {epoch_enc_ke}\")\n",
        "        print(f\"[Epoch {epoch+1}] approx Decoder KE: {epoch_dec_ke}\")\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)\n",
        "        print(f\"[Epoch {epoch+1}] Dev Acc: {acc:.4f}\")\n",
        "        last_enc_ke, last_dec_ke = epoch_enc_ke, epoch_dec_ke\n",
        "\n",
        "    remove_hooks([(enc_hooks, enc_acts), (dec_hooks, dec_acts)])\n",
        "    return model, last_enc_ke, last_dec_ke\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, enc_ke, dec_ke, tokenizer, label_texts):\n",
        "    print(\"=== Stage 2: Prune (High-KE) & Fine-tuning ===\")\n",
        "    # You can set num_prune as you wish\n",
        "    num_prune = 4\n",
        "   # enc_prune_idxs = prune_high_ke_ffn(model.encoder.block, enc_ke, num_prune=num_prune, hidden_size=model.config.d_model)\n",
        "    dec_prune_idxs = prune_high_ke_ffn(model.decoder.block, dec_ke, num_prune=num_prune, hidden_size=model.config.d_model)\n",
        "   # print(\"Pruned encoder layers (highest KE):\", enc_prune_idxs)\n",
        "    print(\"Pruned decoder layers (highest KE):\", dec_prune_idxs)\n",
        "\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=5e-4)\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*2)\n",
        "    for epoch in range(5):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                            attention_mask=batch['attention_mask'].to(device),\n",
        "                            labels=batch['labels'].to(device))\n",
        "            loss = outputs.loss\n",
        "            loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] e-SNLI Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# ---- 5. Entrypoint ----\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "\n",
        "    data_files = {\n",
        "        \"train\": \"/content/drive/MyDrive/NLP_datasets/esnli/esnli_train.json\",\n",
        "        \"validation\": \"/content/drive/MyDrive/NLP_datasets/esnli/esnli_valid.json\",\n",
        "        \"test\": \"/content/drive/MyDrive/NLP_datasets/esnli/esnli_test.json\"\n",
        "    }\n",
        "    raw_datasets = load_dataset(\"json\", data_files=data_files)\n",
        "    tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "    label_texts = [\"entailment\", \"neutral\", \"contradiction\"]\n",
        "\n",
        "    train_ds = raw_datasets[\"train\"].shuffle(seed=seed).select(range(10000))\n",
        "    dev_ds = raw_datasets[\"validation\"].shuffle(seed=seed).select(range(2000))\n",
        "\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True, remove_columns=train_ds.column_names)\n",
        "    dev = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                     batched=True, remove_columns=dev_ds.column_names)\n",
        "\n",
        "    collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding=\"max_length\", max_length=128)\n",
        "    train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    model, enc_ke, dec_ke = full_finetuning(train_loader, dev_loader, device, tokenizer, label_texts)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, enc_ke, dec_ke, tokenizer, label_texts)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "sWY7v5xjLqr5"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Your data\n",
        "enc_ke_epochs = [\n",
        "    {0: 5.120778652954102, 1: 5.403757721710205, 2: 5.443886557769775, 3: 5.44828851776123, 4: 5.453934421539307, 5: 5.452359628295898, 6: 5.457024213409424, 7: 5.43810365524292, 8: 5.402327305603027, 9: 5.370745602416992, 10: 5.31574527053833, 11: 5.390125755310058},\n",
        "    {0: 5.1113538009643555, 1: 5.392516478729248, 2: 5.438380133056641, 3: 5.431306089782715, 4: 5.4405796981811525, 5: 5.4482001785278324, 6: 5.4542674026489255, 7: 5.456199799346924, 8: 5.428633112335205, 9: 5.4060304084777835, 10: 5.376359208679199, 11: 5.4560540000915525},\n",
        "    {0: 5.1189261985778804, 1: 5.415700047302246, 2: 5.461237496948242, 3: 5.445635594177246, 4: 5.448134069061279, 5: 5.454518992614746, 6: 5.457215517425537, 7: 5.458963891601562, 8: 5.441248722839355, 9: 5.416132154083252, 10: 5.3972449882507325, 11: 5.488359162902832},\n",
        "    {0: 5.116845191955567, 1: 5.412102639007569, 2: 5.457377126312256, 3: 5.440801442718506, 4: 5.446412419891358, 5: 5.4529715980529785, 6: 5.4542711631774905, 7: 5.458785958862305, 8: 5.443267332458496, 9: 5.41692006149292, 10: 5.398646075439453, 11: 5.48984780960083},\n",
        "    {0: 5.116689331817627, 1: 5.412071422576904, 2: 5.457381925201416, 3: 5.440740663146973, 4: 5.446404431915283, 5: 5.453024647521973, 6: 5.454191885375977, 7: 5.45875881652832, 8: 5.443202964019775, 9: 5.416829935455322, 10: 5.398244989776611, 11: 5.489683660888672},\n",
        "    {0: 5.116260417938232, 1: 5.411794747924804, 2: 5.457123635864257, 3: 5.440602420806885, 4: 5.44625763092041, 5: 5.452800505065918, 6: 5.454110127258301, 7: 5.458805889892578, 8: 5.443212223052979, 9: 5.4169437850952145, 10: 5.3985387763977055, 11: 5.489740788269043}\n",
        "]\n",
        "dec_ke_epochs = [\n",
        "    {0: 5.473455963897705, 1: 5.18053843383789, 2: 5.033648477172852, 3: 4.812303190612793, 4: 5.261005407714844, 5: 5.501210649108887, 6: 5.472148556518555, 7: 5.517318391418457, 8: 5.553192390441895, 9: 5.512592239379883, 10: 5.584075408172607, 11: 5.50434167175293},\n",
        "    {0: 5.472483442687988, 1: 5.169460816192627, 2: 5.043915207672119, 3: 4.82118962020874, 4: 5.270837689971924, 5: 5.517090723419189, 6: 5.498110526275635, 7: 5.5361996917724605, 8: 5.564452503204346, 9: 5.528676497650147, 10: 5.596128713989258, 11: 5.526001943969726},\n",
        "    {0: 5.469658322906494, 1: 5.168160199737549, 2: 5.06294663772583, 3: 4.844052698516846, 4: 5.287757330322266, 5: 5.529096306610107, 6: 5.515416881561279, 7: 5.5526224166870115, 8: 5.578667069244385, 9: 5.550772470855713, 10: 5.610965402221679, 11: 5.548971197509766},\n",
        "    {0: 5.468696701049804, 1: 5.168712305450439, 2: 5.070296308898926, 3: 4.852590578460694, 4: 5.291568259429932, 5: 5.531682807922364, 6: 5.5185763488769535, 7: 5.554903052520752, 8: 5.579686510467529, 9: 5.552240438842773, 10: 5.613069982147217, 11: 5.553303944396973},\n",
        "    {0: 5.468258437347412, 1: 5.169210543060303, 2: 5.070472750091553, 3: 4.852741773223877, 4: 5.2915089881896975, 5: 5.531434827423095, 6: 5.518463079833984, 7: 5.554849634552002, 8: 5.5797720138549805, 9: 5.552255963897705, 10: 5.61289744644165, 11: 5.553331785583496},\n",
        "    {0: 5.468696504974365, 1: 5.168731469726563, 2: 5.070022554016114, 3: 4.852040791320801, 4: 5.291398749542236, 5: 5.531332643890381, 6: 5.51825998916626, 7: 5.554569575500488, 8: 5.579519750976562, 9: 5.552110796356201, 10: 5.613056344604492, 11: 5.553367820739746}\n",
        "]\n",
        "\n",
        "layers = list(range(1, 13))  # 1-based layers for x-axis\n",
        "epochs = [f\"Epoch {i+1}\" for i in range(6)]\n",
        "\n",
        "# Convert dicts to lists (layer 1-12)\n",
        "enc_ke_list = [[epoch_ke[i-1] for i in layers] for epoch_ke in enc_ke_epochs]\n",
        "dec_ke_list = [[epoch_ke[i-1] for i in layers] for epoch_ke in dec_ke_epochs]\n",
        "\n",
        "# --- Encoder KE Plot ---\n",
        "plt.figure(figsize=(10, 6))\n",
        "for i, epoch in enumerate(epochs):\n",
        "    plt.plot(layers, enc_ke_list[i], marker='o', label=epoch)\n",
        "plt.xlabel(\"Layer\", fontsize=16)\n",
        "plt.ylabel(\"Knowledge Entropy\", fontsize=16)\n",
        "plt.xticks(layers, fontsize=14)\n",
        "plt.yticks(fontsize=14)\n",
        "#plt.title(\"Encoder KE vs Layer\", fontsize=18)\n",
        "plt.grid(True)\n",
        "plt.legend(fontsize=12)\n",
        "plt.tight_layout()\n",
        "plt.show()\n",
        "\n",
        "# --- Decoder KE Plot ---\n",
        "plt.figure(figsize=(10, 6))\n",
        "for i, epoch in enumerate(epochs):\n",
        "    plt.plot(layers, dec_ke_list[i], marker='o', label=epoch)\n",
        "plt.xlabel(\"Layer\", fontsize=16)\n",
        "plt.ylabel(\"Knowledge Entropy\", fontsize=16)\n",
        "plt.xticks(layers, fontsize=14)\n",
        "plt.yticks(fontsize=14)\n",
        "#plt.title(\"Decoder KE vs Layer\", fontsize=18)\n",
        "plt.grid(True)\n",
        "plt.legend(fontsize=12)\n",
        "plt.tight_layout()\n",
        "plt.show()\n",
        "\n",
        "\n"
      ],
      "metadata": {
        "id": "6WSViXfaJgqP"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "5HI_2RtFJhu_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Mount Google Drive if on Colab\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict\n",
        "import numpy as np\n",
        "import random\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# --- 1. Load CQA Data ---\n",
        "data_files = {\n",
        "    \"train\": \"/content/drive/MyDrive/NLP_datasets/cqa/cqa_train.json\",\n",
        "    \"test\":  \"/content/drive/MyDrive/NLP_datasets/cqa/cqa_test.json\"\n",
        "}\n",
        "dataset = load_dataset(\"json\", data_files=data_files)\n",
        "\n",
        "# --- 2. Preprocessing Functions ---\n",
        "def preprocess_cqa(batch, tokenizer, max_input_length=128, max_target_length=8, use_cot=False):\n",
        "    if use_cot and 'abstractive_explanation' in batch:\n",
        "        inputs = [\n",
        "            f\"question: {q} choices: {', '.join(choices)} rationale: {exp}\"\n",
        "            for q, choices, exp in zip(batch['question'], batch['choices'], batch['abstractive_explanation'])\n",
        "        ]\n",
        "    else:\n",
        "        inputs = [\n",
        "            f\"question: {q} choices: {', '.join(choices)}\"\n",
        "            for q, choices in zip(batch['question'], batch['choices'])\n",
        "        ]\n",
        "    targets = [str(ans).strip() for ans in batch['answer']]\n",
        "    model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=max_input_length)\n",
        "    target = tokenizer(targets, padding=\"max_length\", truncation=True, max_length=max_target_length)\n",
        "    model_inputs[\"labels\"] = target[\"input_ids\"]\n",
        "    return model_inputs\n",
        "\n",
        "tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "USE_COT = False\n",
        "\n",
        "train = dataset[\"train\"].map(lambda ex: preprocess_cqa(ex, tokenizer, use_cot=USE_COT),\n",
        "                            batched=True, remove_columns=dataset[\"train\"].column_names)\n",
        "dev   = dataset[\"test\"].map(lambda ex: preprocess_cqa(ex, tokenizer, use_cot=False),\n",
        "                            batched=True, remove_columns=dataset[\"test\"].column_names)\n",
        "\n",
        "collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding=\"max_length\", max_length=128)\n",
        "train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "# --- 3. Knowledge Entropy Hook Utilities ---\n",
        "from functools import partial\n",
        "\n",
        "def register_ke_hooks_t5(model):\n",
        "    enc_layers = model.encoder.block\n",
        "    dec_layers = model.decoder.block\n",
        "    enc_acts = {i: None for i in range(len(enc_layers))}\n",
        "    dec_acts = {i: None for i in range(len(dec_layers))}\n",
        "    enc_hooks, dec_hooks = [], []\n",
        "\n",
        "    for i, layer in enumerate(enc_layers):\n",
        "        enc_hooks.append(\n",
        "            layer.layer[1].DenseReluDense.register_forward_hook(\n",
        "                partial(lambda acts, module, inp, out, idx: acts.__setitem__(idx, inp[0].detach()), enc_acts, idx=i)\n",
        "            )\n",
        "        )\n",
        "    for i, layer in enumerate(dec_layers):\n",
        "        dec_hooks.append(\n",
        "            layer.layer[2].DenseReluDense.register_forward_hook(\n",
        "                partial(lambda acts, module, inp, out, idx: acts.__setitem__(idx, inp[0].detach()), dec_acts, idx=i)\n",
        "            )\n",
        "        )\n",
        "    return (enc_hooks, enc_acts), (dec_hooks, dec_acts)\n",
        "\n",
        "\n",
        "def compute_ke_batch(acts, act_fn=F.relu, eps=1e-8):\n",
        "    ke = {}\n",
        "    for idx, a in acts.items():\n",
        "        acts[idx] = None  # Always reset, even if skipping\n",
        "        if a is None:\n",
        "            continue\n",
        "        if not torch.isfinite(a).all():\n",
        "            continue\n",
        "        if a.numel() == 0 or a.abs().sum() == 0:\n",
        "            continue\n",
        "        act = act_fn(a)\n",
        "        denom = act.sum(dim=-1, keepdim=True)\n",
        "        denom = torch.where(denom == 0, torch.ones_like(denom), denom)\n",
        "        probs = act / (denom + eps)\n",
        "        probs = torch.clamp(probs, min=1e-8)\n",
        "        if not torch.isfinite(probs).all():\n",
        "            continue\n",
        "        entropy = -torch.sum(probs * torch.log(probs), dim=-1).mean()\n",
        "        if not torch.isfinite(entropy):\n",
        "            continue\n",
        "        ke[idx] = entropy.item()\n",
        "    return ke\n",
        "\n",
        "\n",
        "\n",
        "def remove_hooks(hook_sets):\n",
        "    for hooks, _ in hook_sets:\n",
        "        for h in hooks: h.remove()\n",
        "\n",
        "# --- 4. Pruning Utilities ---\n",
        "class SkipFFN(nn.Module):\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "    def forward(self, hidden_states):\n",
        "        return hidden_states\n",
        "\n",
        "def prune_high_ke_ffn(blocks, ke_scores, num_prune=4, hidden_size=768):\n",
        "    sorted_layers = sorted(ke_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx for idx, _ in sorted_layers[:num_prune]]\n",
        "    for idx in prune_idxs:\n",
        "        blocks[idx].layer[1].DenseReluDense = SkipFFN(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "# --- 5. Training/Eval/KE Pipeline ---\n",
        "def compute_accuracy(preds, refs):\n",
        "    correct = 0\n",
        "    for p, l in zip(preds, refs):\n",
        "        if p == l:\n",
        "            correct += 1\n",
        "    return correct / len(preds) if len(preds) > 0 else 0\n",
        "\n",
        "def evaluate_model(model, dl, tokenizer, device):\n",
        "    model.eval()\n",
        "    preds, refs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch[\"input_ids\"].to(device)\n",
        "            attention_mask = batch[\"attention_mask\"].to(device)\n",
        "            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=4)\n",
        "            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
        "            label_ids = batch[\"labels\"].clone()\n",
        "            label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
        "            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
        "            preds.extend([p.strip().lower() for p in pred_texts])\n",
        "            refs.extend([l.strip().lower() for l in ref_texts])\n",
        "    return compute_accuracy(preds, refs)\n",
        "\n",
        "def full_finetuning(train_loader, dev_loader, device, tokenizer):\n",
        "    print(\"=== Stage 1: Full Fine-Tuning & Knowledge Entropy Estimation ===\")\n",
        "    model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
        "    scaler = GradScaler()\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*3)\n",
        "    (enc_hooks, enc_acts), (dec_hooks, dec_acts) = register_ke_hooks_t5(model)\n",
        "    last_enc_ke, last_dec_ke = None, None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        enc_ke_sum, enc_ke_count = defaultdict(float), defaultdict(int)\n",
        "        dec_ke_sum, dec_ke_count = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                                attention_mask=batch['attention_mask'].to(device),\n",
        "                                labels=batch['labels'].to(device))\n",
        "                loss = outputs.loss\n",
        "                scaler.scale(loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "            batch_enc_ke = compute_ke_batch(enc_acts)\n",
        "            for idx, v in batch_enc_ke.items():\n",
        "                enc_ke_sum[idx] += v\n",
        "                enc_ke_count[idx] += 1\n",
        "            batch_dec_ke = compute_ke_batch(dec_acts)\n",
        "            for idx, v in batch_dec_ke.items():\n",
        "                dec_ke_sum[idx] += v\n",
        "                dec_ke_count[idx] += 1\n",
        "\n",
        "        epoch_enc_ke = {idx: enc_ke_sum[idx]/enc_ke_count[idx] for idx in enc_ke_sum if enc_ke_count[idx] > 0}\n",
        "        epoch_dec_ke = {idx: dec_ke_sum[idx]/dec_ke_count[idx] for idx in dec_ke_sum if dec_ke_count[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Encoder KE: {epoch_enc_ke}\")\n",
        "        print(f\"[Epoch {epoch+1}] approx Decoder KE: {epoch_dec_ke}\")\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "        print(f\"[Epoch {epoch+1}] Dev Acc: {acc:.4f}\")\n",
        "        last_enc_ke, last_dec_ke = epoch_enc_ke, epoch_dec_ke\n",
        "\n",
        "    remove_hooks([(enc_hooks, enc_acts), (dec_hooks, dec_acts)])\n",
        "    return model, last_enc_ke, last_dec_ke\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, enc_ke, dec_ke, tokenizer):\n",
        "    print(\"=== Stage 2: Prune (High-KE) & Fine-tuning ===\")\n",
        "#    enc_prune_idxs = prune_high_ke_ffn(model.encoder.block, enc_ke, num_prune=4, hidden_size=model.config.d_model)\n",
        "    dec_prune_idxs = prune_high_ke_ffn(model.decoder.block, dec_ke, num_prune=4, hidden_size=model.config.d_model)\n",
        "#    print(\"Pruned encoder layers (highest KE):\", enc_prune_idxs)\n",
        "    print(\"Pruned decoder layers (highest KE):\", dec_prune_idxs)\n",
        "\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=5e-4)\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*2)\n",
        "    for epoch in range(5):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                            attention_mask=batch['attention_mask'].to(device),\n",
        "                            labels=batch['labels'].to(device))\n",
        "            loss = outputs.loss\n",
        "            loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] CQA Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 6. Entrypoint ---\n",
        "def main():\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    model, enc_ke, dec_ke = full_finetuning(train_loader, dev_loader, device, tokenizer)\n",
        "    model = prune_and_finetuning(\n",
        "        model, train_loader, dev_loader, device,\n",
        "        enc_ke, dec_ke, tokenizer\n",
        "    )\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "xcQkLEUPLy59"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "\n",
        "# --- Data ---\n",
        "enc_ke_epochs = [\n",
        "    {0: 5.139904764485476, 1: 5.401092627952838, 2: 5.439678460702129, 3: 5.432176997117417, 4: 5.442642464817842, 5: 5.4454231058631235, 6: 5.463585686018119, 7: 5.456152987989103, 8: 5.42602445652528, 9: 5.3984725964676175, 10: 5.365565290591987, 11: 5.397532289749698},\n",
        "    {0: 5.139600602668299, 1: 5.4082865292215585, 2: 5.435035413513434, 3: 5.431703077357, 4: 5.4357455140851405, 5: 5.447143929736759, 6: 5.4573999727496565, 7: 5.448999722798665, 8: 5.4132760177887915, 9: 5.375146550302239, 10: 5.331545272288456, 11: 5.359056867131291},\n",
        "    {0: 5.132425843788485, 1: 5.400260408523635, 2: 5.425296262763013, 3: 5.425646370071887, 4: 5.4349209725954655, 5: 5.449994925794931, 6: 5.459187375305126, 7: 5.451061400678162, 8: 5.413698513519588, 9: 5.36510266693942, 10: 5.306283406631896, 11: 5.328084006954734},\n",
        "    {0: 5.132674681533538, 1: 5.40307938993858, 2: 5.426535778640722, 3: 5.425935121984121, 4: 5.435641975434152, 5: 5.449571485785624, 6: 5.459163357275852, 7: 5.450666630405119, 8: 5.419207827406759, 9: 5.381928584062798, 10: 5.340362237983541, 11: 5.369139033193854},\n",
        "    {0: 5.13272574069269, 1: 5.403113040235047, 2: 5.426666823513989, 3: 5.425940881613244, 4: 5.435654991757498, 5: 5.449515212737085, 6: 5.459123445457621, 7: 5.450610509255445, 8: 5.419158941027762, 9: 5.3816367816455255, 10: 5.33981858255045, 11: 5.36847594375485},\n",
        "    {0: 5.132623068021828, 1: 5.402830285196038, 2: 5.426464314922715, 3: 5.425696338534551, 4: 5.435371085732245, 5: 5.4492548917510435, 6: 5.4588225204956355, 7: 5.450303058123158, 8: 5.419105227553394, 9: 5.3820987663832796, 10: 5.341227367789483, 11: 5.37036663649098}\n",
        "]\n",
        "dec_ke_epochs = [\n",
        "    {0: 5.4599811254363315, 1: 5.3001703635642405, 2: 5.1401081963589315, 3: 5.0207898828544115, 4: 5.296078369021416, 5: 5.529929007354536, 6: 5.507063817821051, 7: 5.562702250323798, 8: 5.5851923988053676, 9: 5.551125566426077, 10: 5.5910814561341935, 11: 5.5818385175968475},\n",
        "    {0: 5.456549143280582, 1: 5.250347980361008, 2: 5.085265902750182, 3: 4.978233239992447, 4: 5.264991768895107, 5: 5.5131214478066966, 6: 5.495728558724166, 7: 5.558811952097608, 8: 5.581217587485148, 9: 5.551561930623439, 10: 5.592714675765061, 11: 5.592012528729007},\n",
        "    {0: 5.453973231929363, 1: 5.239954027012236, 2: 5.083703805117717, 3: 4.982208940455623, 4: 5.268532490966344, 5: 5.5198255513760905, 6: 5.506296765292832, 7: 5.569196505121665, 8: 5.591176272225459, 9: 5.5623457416056015, 10: 5.598043118373002, 11: 5.60101840440983},\n",
        "    {0: 5.452516633301533, 1: 5.234018758795727, 2: 5.084256347568556, 3: 4.9811194565495835, 4: 5.2666547043961645, 5: 5.518136622674751, 6: 5.504595130143691, 7: 5.568047467906682, 8: 5.590607944576219, 9: 5.562076567037548, 10: 5.598141102563767, 11: 5.602112366060905},\n",
        "    {0: 5.452368998566676, 1: 5.233688806273863, 2: 5.083870511141121, 3: 4.980399314993121, 4: 5.265982374181888, 5: 5.517764004580493, 6: 5.503982042052671, 7: 5.567702337243091, 8: 5.589939683528956, 9: 5.561352815925586, 10: 5.59754676850167, 11: 5.601407691567206},\n",
        "    {0: 5.4520845942149885, 1: 5.233260017357125, 2: 5.082963084542988, 3: 4.979661354165993, 4: 5.265588265381112, 5: 5.517923472732898, 6: 5.504299057240518, 7: 5.567891064858594, 8: 5.590223931318877, 9: 5.561733258481057, 10: 5.597703276091064, 11: 5.601514090765391}\n",
        "]\n",
        "\n",
        "layers = list(range(1, 13))  # 1-based layers for x-axis\n",
        "epochs = [f\"Epoch {i+1}\" for i in range(6)]\n",
        "\n",
        "# Convert dicts to lists (layer 1-12)\n",
        "enc_ke_list = [[epoch_ke[i-1] for i in layers] for epoch_ke in enc_ke_epochs]\n",
        "dec_ke_list = [[epoch_ke[i-1] for i in layers] for epoch_ke in dec_ke_epochs]\n",
        "\n",
        "# --- Encoder KE Plot ---\n",
        "plt.figure(figsize=(10, 6))\n",
        "for i, epoch in enumerate(epochs):\n",
        "    plt.plot(layers, enc_ke_list[i], marker='o', label=epoch)\n",
        "plt.xlabel(\"Layer\", fontsize=16)\n",
        "plt.ylabel(\"Knowledge Entropy\", fontsize=16)\n",
        "plt.xticks(layers, fontsize=14)\n",
        "plt.yticks(fontsize=14)\n",
        "plt.grid(True)\n",
        "plt.legend(fontsize=12)\n",
        "plt.tight_layout()\n",
        "plt.show()\n",
        "\n",
        "# --- Decoder KE Plot ---\n",
        "plt.figure(figsize=(10, 6))\n",
        "for i, epoch in enumerate(epochs):\n",
        "    plt.plot(layers, dec_ke_list[i], marker='o', label=epoch)\n",
        "plt.xlabel(\"Layer\", fontsize=16)\n",
        "plt.ylabel(\"Decoder Knowledge Entropy\", fontsize=16)\n",
        "plt.xticks(layers, fontsize=14)\n",
        "plt.yticks(fontsize=14)\n",
        "plt.grid(True)\n",
        "plt.legend(fontsize=12)\n",
        "plt.tight_layout()\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "v0ifEgX7M_Xu"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "OvMDBY8mM_yO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# --- Mount Google Drive if using Colab ---\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "# --- Standard Imports ---\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# --- 1. Load ANLI1 Dataset ---\n",
        "data_files = {\n",
        "    \"train\":      \"/content/drive/MyDrive/NLP_datasets/anli1/anli1_train.json\",\n",
        "    \"validation\": \"/content/drive/MyDrive/NLP_datasets/anli1/anli1_valid.json\",\n",
        "    \"test\":       \"/content/drive/MyDrive/NLP_datasets/anli1/anli1_test.json\"\n",
        "}\n",
        "dataset = load_dataset(\"json\", data_files=data_files)\n",
        "\n",
        "# --- 2. Preprocessing Function ---\n",
        "def make_t5_nli_prompt(premise, hypothesis):\n",
        "    return f\"nli premise: {premise} hypothesis: {hypothesis}\"\n",
        "\n",
        "def preprocess_anli(batch, tokenizer, max_input_length=128, max_target_length=8):\n",
        "    inputs = [make_t5_nli_prompt(p, h) for p, h in zip(batch['premise'], batch['hypothesis'])]\n",
        "    label_list = [\"entailment\", \"neutral\", \"contradiction\"]\n",
        "    labels = [label_list[int(x)] if isinstance(x, (int, float, str)) and str(x).isdigit() and int(x)<3 else str(x) for x in batch['label']]\n",
        "    model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=max_input_length)\n",
        "    target = tokenizer(labels, padding=\"max_length\", truncation=True, max_length=max_target_length)\n",
        "    model_inputs[\"labels\"] = target[\"input_ids\"]\n",
        "    return model_inputs\n",
        "\n",
        "tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "train = dataset[\"train\"].map(lambda ex: preprocess_anli(ex, tokenizer), batched=True, remove_columns=dataset[\"train\"].column_names)\n",
        "dev   = dataset[\"validation\"].map(lambda ex: preprocess_anli(ex, tokenizer), batched=True, remove_columns=dataset[\"validation\"].column_names)\n",
        "collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding=\"max_length\", max_length=128)\n",
        "train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "# --- 3. Knowledge Entropy Hook Utilities ---\n",
        "from functools import partial\n",
        "\n",
        "def register_ke_hooks_t5(model):\n",
        "    enc_layers = model.encoder.block\n",
        "    dec_layers = model.decoder.block\n",
        "    enc_acts = {i: None for i in range(len(enc_layers))}\n",
        "    dec_acts = {i: None for i in range(len(dec_layers))}\n",
        "    enc_hooks, dec_hooks = [], []\n",
        "\n",
        "    for i, layer in enumerate(enc_layers):\n",
        "        enc_hooks.append(\n",
        "            layer.layer[1].DenseReluDense.register_forward_hook(\n",
        "                partial(lambda acts, module, inp, out, idx: acts.__setitem__(idx, inp[0].detach()), enc_acts, idx=i)\n",
        "            )\n",
        "        )\n",
        "    for i, layer in enumerate(dec_layers):\n",
        "        dec_hooks.append(\n",
        "            layer.layer[2].DenseReluDense.register_forward_hook(\n",
        "                partial(lambda acts, module, inp, out, idx: acts.__setitem__(idx, inp[0].detach()), dec_acts, idx=i)\n",
        "            )\n",
        "        )\n",
        "    return (enc_hooks, enc_acts), (dec_hooks, dec_acts)\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "def compute_ke_batch(acts, act_fn=F.relu, eps=1e-8):\n",
        "    ke = {}\n",
        "    for idx, a in acts.items():\n",
        "        acts[idx] = None  # Always reset, even if skipping\n",
        "        if a is None:\n",
        "            continue\n",
        "        if not torch.isfinite(a).all():\n",
        "            continue\n",
        "        if a.numel() == 0 or a.abs().sum() == 0:\n",
        "            continue\n",
        "        act = act_fn(a)\n",
        "        denom = act.sum(dim=-1, keepdim=True)\n",
        "        denom = torch.where(denom == 0, torch.ones_like(denom), denom)\n",
        "        probs = act / (denom + eps)\n",
        "        probs = torch.clamp(probs, min=1e-8)\n",
        "        if not torch.isfinite(probs).all():\n",
        "            continue\n",
        "        entropy = -torch.sum(probs * torch.log(probs), dim=-1).mean()\n",
        "        if not torch.isfinite(entropy):\n",
        "            continue\n",
        "        ke[idx] = entropy.item()\n",
        "    return ke\n",
        "\n",
        "\n",
        "\n",
        "def remove_hooks(hook_sets):\n",
        "    for hooks, _ in hook_sets:\n",
        "        for h in hooks: h.remove()\n",
        "\n",
        "# --- 4. Pruning Utilities ---\n",
        "class SkipFFN(nn.Module):\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "    def forward(self, hidden_states):\n",
        "        return hidden_states\n",
        "\n",
        "def prune_high_ke_ffn(blocks, ke_scores, num_prune=4, hidden_size=768):\n",
        "    sorted_layers = sorted(ke_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx for idx, _ in sorted_layers[:num_prune]]\n",
        "    for idx in prune_idxs:\n",
        "        blocks[idx].layer[1].DenseReluDense = SkipFFN(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "# --- 5. Training/Eval/KE Pipeline ---\n",
        "def compute_accuracy(preds, refs):\n",
        "    correct = 0\n",
        "    for p, l in zip(preds, refs):\n",
        "        if p == l:\n",
        "            correct += 1\n",
        "    return correct / len(preds) if len(preds) > 0 else 0\n",
        "\n",
        "def evaluate_model(model, dl, tokenizer, device):\n",
        "    model.eval()\n",
        "    preds, refs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch[\"input_ids\"].to(device)\n",
        "            attention_mask = batch[\"attention_mask\"].to(device)\n",
        "            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=2)\n",
        "            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
        "            label_ids = batch[\"labels\"].clone()\n",
        "            label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
        "            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
        "            preds.extend([p.strip().lower() for p in pred_texts])\n",
        "            refs.extend([l.strip().lower() for l in ref_texts])\n",
        "    return compute_accuracy(preds, refs)\n",
        "\n",
        "def full_finetuning(train_loader, dev_loader, device, tokenizer):\n",
        "    print(\"=== Stage 1: Full Fine-Tuning & Knowledge Entropy Estimation ===\")\n",
        "    model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
        "    scaler = GradScaler()\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*3)\n",
        "    (enc_hooks, enc_acts), (dec_hooks, dec_acts) = register_ke_hooks_t5(model)\n",
        "    last_enc_ke, last_dec_ke = None, None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        enc_ke_sum, enc_ke_count = defaultdict(float), defaultdict(int)\n",
        "        dec_ke_sum, dec_ke_count = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                                attention_mask=batch['attention_mask'].to(device),\n",
        "                                labels=batch['labels'].to(device))\n",
        "                loss = outputs.loss\n",
        "                scaler.scale(loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "            batch_enc_ke = compute_ke_batch(enc_acts)\n",
        "            for idx, v in batch_enc_ke.items():\n",
        "                enc_ke_sum[idx] += v\n",
        "                enc_ke_count[idx] += 1\n",
        "            batch_dec_ke = compute_ke_batch(dec_acts)\n",
        "            for idx, v in batch_dec_ke.items():\n",
        "                dec_ke_sum[idx] += v\n",
        "                dec_ke_count[idx] += 1\n",
        "\n",
        "        epoch_enc_ke = {idx: enc_ke_sum[idx]/enc_ke_count[idx] for idx in enc_ke_sum if enc_ke_count[idx] > 0}\n",
        "        epoch_dec_ke = {idx: dec_ke_sum[idx]/dec_ke_count[idx] for idx in dec_ke_sum if dec_ke_count[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Encoder KE: {epoch_enc_ke}\")\n",
        "        print(f\"[Epoch {epoch+1}] approx Decoder KE: {epoch_dec_ke}\")\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "        print(f\"[Epoch {epoch+1}] Dev Acc: {acc:.4f}\")\n",
        "        last_enc_ke, last_dec_ke = epoch_enc_ke, epoch_dec_ke\n",
        "\n",
        "    remove_hooks([(enc_hooks, enc_acts), (dec_hooks, dec_acts)])\n",
        "    return model, last_enc_ke, last_dec_ke\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, enc_ke, dec_ke, tokenizer):\n",
        "    print(\"=== Stage 2: Prune (High-KE) & Fine-tuning ===\")\n",
        "#    enc_prune_idxs = prune_high_ke_ffn(model.encoder.block, enc_ke, num_prune=4, hidden_size=model.config.d_model)\n",
        "    dec_prune_idxs = prune_high_ke_ffn(model.decoder.block, dec_ke, num_prune=4, hidden_size=model.config.d_model)\n",
        "#    print(\"Pruned encoder layers (highest KE):\", enc_prune_idxs)\n",
        "    print(\"Pruned decoder layers (highest KE):\", dec_prune_idxs)\n",
        "\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=5e-4)\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*2)\n",
        "    for epoch in range(5):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                            attention_mask=batch['attention_mask'].to(device),\n",
        "                            labels=batch['labels'].to(device))\n",
        "            loss = outputs.loss\n",
        "            loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] ANLI1 Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 6. Entrypoint ---\n",
        "def main():\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    model, enc_ke, dec_ke = full_finetuning(train_loader, dev_loader, device, tokenizer)\n",
        "    # --- PRUNING AND CONTINUED FINETUNING ---\n",
        "    model = prune_and_finetuning(\n",
        "        model, train_loader, dev_loader, device,\n",
        "        enc_ke, dec_ke, tokenizer\n",
        "    )\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "lbaPEUrcLvY8"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Encoder and decoder KE per epoch (dicts)\n",
        "enc_ke_epochs = [\n",
        "    {0: 5.480030705793848, 1: 5.5496080641476615, 2: 5.53872952911089, 3: 5.511840214819278, 4: 5.516190859956561, 5: 5.521920965302665, 6: 5.521400574918063, 7: 5.5144495568185485, 8: 5.49060498093659, 9: 5.450951964816921, 10: 5.423030107613378, 11: 5.466182296822421},\n",
        "    {0: 5.4871524711824815, 1: 5.544603646476314, 2: 5.538442740800246, 3: 5.51321722111612, 4: 5.517586176350432, 5: 5.520688862170813, 6: 5.521329364236796, 7: 5.51392604315056, 8: 5.4896568541256885, 9: 5.451860777928001, 10: 5.422290218796926, 11: 5.473095425204849},\n",
        "    {0: 5.4869607552042545, 1: 5.544635340402711, 2: 5.538483706060446, 3: 5.5132214298788105, 4: 5.517588706286448, 5: 5.520692724551795, 6: 5.521303620878256, 7: 5.51382915883694, 8: 5.489492775359244, 9: 5.451637857365158, 10: 5.422221080787288, 11: 5.473087342549913},\n",
        "    {0: 5.486878904306663, 1: 5.5445804946827435, 2: 5.538352051770912, 3: 5.513102901656673, 4: 5.517401096955785, 5: 5.5205126964821005, 6: 5.521155353312222, 7: 5.513736206630491, 8: 5.489534128387019, 9: 5.451726898142929, 10: 5.4222523313800535, 11: 5.473024711809085},\n",
        "    {0: 5.486847182489791, 1: 5.544575284112175, 2: 5.5384362558148945, 3: 5.513173838381498, 4: 5.517494451324895, 5: 5.520579959311576, 6: 5.5212360184147675, 7: 5.513783328938034, 8: 5.489483414056166, 9: 5.451590445806396, 10: 5.422033144178845, 11: 5.472951911290487},\n",
        "    {0: 5.486859069680268, 1: 5.5446095003272005, 2: 5.538401910943805, 3: 5.513121429479347, 4: 5.517466274297463, 5: 5.520576953438093, 6: 5.521189136325188, 7: 5.513781391449695, 8: 5.489508598705508, 9: 5.451672037592474, 10: 5.422226219177246, 11: 5.473129866463798}\n",
        "]\n",
        "dec_ke_epochs = [\n",
        "    {0: 5.465231130587175, 1: 5.264995521004111, 2: 5.129532522346007, 3: 4.904289741644123, 4: 5.284679993001292, 5: 5.50792251817332, 6: 5.485444439223263, 7: 5.51603090317343, 8: 5.534246008096697, 9: 5.510181266989484, 10: 5.566446723553958, 11: 5.486786779712739},\n",
        "    {0: 5.469713579256041, 1: 5.189854344148881, 2: 4.998035216808773, 3: 4.745502117364263, 4: 5.174289006977791, 5: 5.4802458447655455, 6: 5.4423357679005235, 7: 5.488046237465538, 8: 5.5072504567236304, 9: 5.504863533096386, 10: 5.56821925942391, 11: 5.49581682125197},\n",
        "    {0: 5.469381145927364, 1: 5.188713194414278, 2: 4.996736016985805, 3: 4.743715989033911, 4: 5.17287224979201, 5: 5.479581630989895, 6: 5.441277936341987, 7: 5.487883858630818, 8: 5.507143537392512, 9: 5.505116776213206, 10: 5.5682468636618925, 11: 5.495751358235483},\n",
        "    {0: 5.469358630307758, 1: 5.188910964336104, 2: 4.996535207024057, 3: 4.743591746301141, 4: 5.172841481580079, 5: 5.4794120374526685, 6: 5.441188570197302, 7: 5.48781734431973, 8: 5.5070663813416285, 9: 5.504971349967345, 10: 5.568174341252742, 11: 5.495965877107081},\n",
        "    {0: 5.469419138318017, 1: 5.188836002803984, 2: 4.996781077612014, 3: 4.7440564200991675, 4: 5.173168115615844, 5: 5.4795689723605205, 6: 5.441124076389131, 7: 5.487770483834403, 8: 5.506965150378999, 9: 5.504970055534726, 10: 5.568170601981027, 11: 5.495966437657674},\n",
        "    {0: 5.469227241334461, 1: 5.1887324437640965, 2: 4.996399641491118, 3: 4.743360573450724, 4: 5.172784906568982, 5: 5.479544408434913, 6: 5.441251537232172, 7: 5.487974353971936, 8: 5.507155140922183, 9: 5.505052519298735, 10: 5.56826036453247, 11: 5.495957667032878}\n",
        "]\n",
        "layers = list(range(1, 13))  # 1-based layers for x-axis\n",
        "epochs = [f\"Epoch {i+1}\" for i in range(6)]\n",
        "\n",
        "# Convert dicts to lists (layer 1-12)\n",
        "enc_ke_list = [[epoch_ke[i-1] for i in layers] for epoch_ke in enc_ke_epochs]\n",
        "dec_ke_list = [[epoch_ke[i-1] for i in layers] for epoch_ke in dec_ke_epochs]\n",
        "\n",
        "# --- Encoder KE Plot ---\n",
        "plt.figure(figsize=(10, 6))\n",
        "for i, epoch in enumerate(epochs):\n",
        "    plt.plot(layers, enc_ke_list[i], marker='o', label=epoch)\n",
        "plt.xlabel(\"Layer\", fontsize=16)\n",
        "plt.ylabel(\"Knowledge Entropy\", fontsize=16)\n",
        "plt.xticks(layers, fontsize=14)\n",
        "plt.yticks(fontsize=14)\n",
        "plt.grid(True)\n",
        "plt.legend(fontsize=12)\n",
        "plt.tight_layout()\n",
        "plt.show()\n",
        "\n",
        "# --- Decoder KE Plot ---\n",
        "plt.figure(figsize=(10, 6))\n",
        "for i, epoch in enumerate(epochs):\n",
        "    plt.plot(layers, dec_ke_list[i], marker='o', label=epoch)\n",
        "plt.xlabel(\"Layer\", fontsize=16)\n",
        "plt.ylabel(\"Knowledge Entropy\", fontsize=16)\n",
        "plt.xticks(layers, fontsize=14)\n",
        "plt.yticks(fontsize=14)\n",
        "plt.grid(True)\n",
        "plt.legend(fontsize=12)\n",
        "plt.tight_layout()\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "p4moOELwNmdu"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "a5ME85ayNm56"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ===========================\n",
        "# 0. Google Drive Mount\n",
        "# ===========================\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "# ===========================\n",
        "# 1. Imports and Setup\n",
        "# ===========================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# ===========================\n",
        "# 2. Load SVAMP Dataset\n",
        "# ===========================\n",
        "data_files = {\n",
        "    \"train\": \"/content/drive/MyDrive/NLP_datasets/svamp/svamp_train.json\",\n",
        "    \"test\": \"/content/drive/MyDrive/NLP_datasets/svamp/svamp_test.json\"\n",
        "}\n",
        "dataset = load_dataset(\"json\", data_files=data_files)\n",
        "\n",
        "# ===========================\n",
        "# 3. Preprocessing\n",
        "# ===========================\n",
        "def preprocess_svamp(batch, tokenizer, max_input_length=128, max_target_length=8):\n",
        "    model_inputs = tokenizer(\n",
        "        batch[\"input\"], padding=\"max_length\", truncation=True, max_length=max_input_length\n",
        "    )\n",
        "    targets = [str(x) for x in batch[\"label\"]]\n",
        "    target_encodings = tokenizer(\n",
        "        targets, padding=\"max_length\", truncation=True, max_length=max_target_length\n",
        "    )\n",
        "    model_inputs[\"labels\"] = target_encodings[\"input_ids\"]\n",
        "    return model_inputs\n",
        "\n",
        "tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "train = dataset[\"train\"].map(lambda ex: preprocess_svamp(ex, tokenizer), batched=True, remove_columns=dataset[\"train\"].column_names)\n",
        "dev = dataset[\"test\"].map(lambda ex: preprocess_svamp(ex, tokenizer), batched=True, remove_columns=dataset[\"test\"].column_names)\n",
        "collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding=\"max_length\", max_length=128)\n",
        "train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "# ===========================\n",
        "# 4. Knowledge Entropy Utilities\n",
        "# ===========================\n",
        "from functools import partial\n",
        "\n",
        "def register_ke_hooks_t5(model):\n",
        "    enc_layers = model.encoder.block\n",
        "    dec_layers = model.decoder.block\n",
        "    enc_acts = {i: None for i in range(len(enc_layers))}\n",
        "    dec_acts = {i: None for i in range(len(dec_layers))}\n",
        "    enc_hooks, dec_hooks = [], []\n",
        "\n",
        "    for i, layer in enumerate(enc_layers):\n",
        "        enc_hooks.append(\n",
        "            layer.layer[1].DenseReluDense.register_forward_hook(\n",
        "                partial(lambda acts, module, inp, out, idx: acts.__setitem__(idx, inp[0].detach()), enc_acts, idx=i)\n",
        "            )\n",
        "        )\n",
        "    for i, layer in enumerate(dec_layers):\n",
        "        dec_hooks.append(\n",
        "            layer.layer[2].DenseReluDense.register_forward_hook(\n",
        "                partial(lambda acts, module, inp, out, idx: acts.__setitem__(idx, inp[0].detach()), dec_acts, idx=i)\n",
        "            )\n",
        "        )\n",
        "    return (enc_hooks, enc_acts), (dec_hooks, dec_acts)\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "def compute_ke_batch(acts, act_fn=F.relu, eps=1e-8):\n",
        "    ke = {}\n",
        "    for idx, a in acts.items():\n",
        "        acts[idx] = None  # Always reset, even if skipping\n",
        "        if a is None:\n",
        "            continue\n",
        "        if not torch.isfinite(a).all():\n",
        "            continue\n",
        "        if a.numel() == 0 or a.abs().sum() == 0:\n",
        "            continue\n",
        "        act = act_fn(a)\n",
        "        denom = act.sum(dim=-1, keepdim=True)\n",
        "        denom = torch.where(denom == 0, torch.ones_like(denom), denom)\n",
        "        probs = act / (denom + eps)\n",
        "        probs = torch.clamp(probs, min=1e-8)\n",
        "        if not torch.isfinite(probs).all():\n",
        "            continue\n",
        "        entropy = -torch.sum(probs * torch.log(probs), dim=-1).mean()\n",
        "        if not torch.isfinite(entropy):\n",
        "            continue\n",
        "        ke[idx] = entropy.item()\n",
        "    return ke\n",
        "\n",
        "\n",
        "\n",
        "def remove_hooks(hook_sets):\n",
        "    for hooks, _ in hook_sets:\n",
        "        for h in hooks: h.remove()\n",
        "\n",
        "# ===========================\n",
        "# 5. Pruning Utilities\n",
        "# ===========================\n",
        "class SkipFFN(nn.Module):\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "    def forward(self, hidden_states):\n",
        "        return hidden_states\n",
        "\n",
        "def prune_high_ke_ffn(blocks, ke_scores, num_prune=4, hidden_size=768):\n",
        "    sorted_layers = sorted(ke_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx for idx, _ in sorted_layers[:num_prune]]\n",
        "    for idx in prune_idxs:\n",
        "        blocks[idx].layer[1].DenseReluDense = SkipFFN(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "# ===========================\n",
        "# 6. Eval Helper\n",
        "# ===========================\n",
        "def compute_accuracy(preds, refs):\n",
        "    correct = 0\n",
        "    for p, l in zip(preds, refs):\n",
        "        if p == l:\n",
        "            correct += 1\n",
        "    return correct / len(preds) if len(preds) > 0 else 0\n",
        "\n",
        "def evaluate_model(model, dl, tokenizer, device):\n",
        "    model.eval()\n",
        "    preds, refs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch[\"input_ids\"].to(device)\n",
        "            attention_mask = batch[\"attention_mask\"].to(device)\n",
        "            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=8)\n",
        "            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
        "            label_ids = batch[\"labels\"].clone()\n",
        "            label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
        "            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
        "            preds.extend([p.strip().lower() for p in pred_texts])\n",
        "            refs.extend([l.strip().lower() for l in ref_texts])\n",
        "    return compute_accuracy(preds, refs)\n",
        "\n",
        "# ===========================\n",
        "# 7. Training + KE Tracking + Pruning\n",
        "# ===========================\n",
        "def full_finetuning(train_loader, dev_loader, device, tokenizer):\n",
        "    print(\"=== Stage 1: Full Fine-Tuning & Knowledge Entropy Estimation ===\")\n",
        "    model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
        "    scaler = GradScaler()\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*3)\n",
        "    (enc_hooks, enc_acts), (dec_hooks, dec_acts) = register_ke_hooks_t5(model)\n",
        "    last_enc_ke, last_dec_ke = None, None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        enc_ke_sum, enc_ke_count = defaultdict(float), defaultdict(int)\n",
        "        dec_ke_sum, dec_ke_count = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                                attention_mask=batch['attention_mask'].to(device),\n",
        "                                labels=batch['labels'].to(device))\n",
        "                loss = outputs.loss\n",
        "                scaler.scale(loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "            batch_enc_ke = compute_ke_batch(enc_acts)\n",
        "            for idx, v in batch_enc_ke.items():\n",
        "                enc_ke_sum[idx] += v\n",
        "                enc_ke_count[idx] += 1\n",
        "            batch_dec_ke = compute_ke_batch(dec_acts)\n",
        "            for idx, v in batch_dec_ke.items():\n",
        "                dec_ke_sum[idx] += v\n",
        "                dec_ke_count[idx] += 1\n",
        "\n",
        "        epoch_enc_ke = {idx: enc_ke_sum[idx]/enc_ke_count[idx] for idx in enc_ke_sum if enc_ke_count[idx] > 0}\n",
        "        epoch_dec_ke = {idx: dec_ke_sum[idx]/dec_ke_count[idx] for idx in dec_ke_sum if dec_ke_count[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Encoder KE: {epoch_enc_ke}\")\n",
        "        print(f\"[Epoch {epoch+1}] approx Decoder KE: {epoch_dec_ke}\")\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "        print(f\"[Epoch {epoch+1}] Dev Acc: {acc:.4f}\")\n",
        "        last_enc_ke, last_dec_ke = epoch_enc_ke, epoch_dec_ke\n",
        "\n",
        "    remove_hooks([(enc_hooks, enc_acts), (dec_hooks, dec_acts)])\n",
        "    return model, last_enc_ke, last_dec_ke\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, enc_ke, dec_ke, tokenizer):\n",
        "    print(\"=== Stage 2: Prune (High-KE) & Fine-tuning ===\")\n",
        "#    enc_prune_idxs = prune_high_ke_ffn(model.encoder.block, enc_ke, num_prune=4, hidden_size=model.config.d_model)\n",
        "    dec_prune_idxs = prune_high_ke_ffn(model.decoder.block, dec_ke, num_prune=4, hidden_size=model.config.d_model)\n",
        "#    print(\"Pruned encoder layers (highest KE):\", enc_prune_idxs)\n",
        "    print(\"Pruned decoder layers (highest KE):\", dec_prune_idxs)\n",
        "\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=5e-4)\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*2)\n",
        "    for epoch in range(5):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                            attention_mask=batch['attention_mask'].to(device),\n",
        "                            labels=batch['labels'].to(device))\n",
        "            loss = outputs.loss\n",
        "            loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] SVAMP Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# ===========================\n",
        "# 8. Main Entrypoint\n",
        "# ===========================\n",
        "def main():\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    model, enc_ke, dec_ke = full_finetuning(train_loader, dev_loader, device, tokenizer)\n",
        "    model = prune_and_finetuning(\n",
        "        model, train_loader, dev_loader, device,\n",
        "        enc_ke, dec_ke, tokenizer\n",
        "    )\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "ZuzYVSG9gmFz"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Provided KE dicts per epoch\n",
        "enc_ke_epochs = [\n",
        "    {0: 5.154252872467041, 1: 5.415111303329468, 2: 5.469181814193726, 3: 5.475447826385498, 4: 5.4965128898620605, 5: 5.519717464447021, 6: 5.528017282485962, 7: 5.509926748275757, 8: 5.499344720840454, 9: 5.4804475879669186, 10: 5.460796346664429, 11: 5.484862985610962},\n",
        "    {0: 5.151100530624389, 1: 5.413294105529785, 2: 5.476608180999756, 3: 5.480278253555298, 4: 5.508586235046387, 5: 5.5349556255340575, 6: 5.54053879737854, 7: 5.51798532485962, 8: 5.511692543029785, 9: 5.501022787094116, 10: 5.478860750198364, 11: 5.49954044342041},\n",
        "    {0: 5.148409032821656, 1: 5.411761407852173, 2: 5.475491733551025, 3: 5.479352645874023, 4: 5.506951446533203, 5: 5.533953561782837, 6: 5.5362600994110105, 7: 5.5146334266662596, 8: 5.508514547348023, 9: 5.494693784713745, 10: 5.481513671875, 11: 5.50757402420044},\n",
        "    {0: 5.149230527877807, 1: 5.4128351974487305, 2: 5.476775159835816, 3: 5.480456304550171, 4: 5.5078802585601805, 5: 5.5339813232421875, 6: 5.535842866897583, 7: 5.514435157775879, 8: 5.508073215484619, 9: 5.49350022315979, 10: 5.481070413589477, 11: 5.507593402862549},\n",
        "    {0: 5.1491956615448, 1: 5.4127289581298825, 2: 5.4760276985168455, 3: 5.479673271179199, 4: 5.507365741729736, 5: 5.533624258041382, 6: 5.535600652694702, 7: 5.5140838718414305, 8: 5.508068170547485, 9: 5.494150772094726, 10: 5.482475533777354, 11: 5.509099026115573},\n",
        "    {0: 5.148624906539917, 1: 5.412059764862061, 2: 5.476006727218628, 3: 5.479845180511474, 4: 5.507486362457275, 5: 5.534078559875488, 6: 5.535918989181519, 7: 5.514234838485717, 8: 5.507962064743042, 9: 5.493731985092163, 10: 5.482590412606998, 11: 5.5096063711205305}\n",
        "]\n",
        "dec_ke_epochs = [\n",
        "    {0: 5.348595600128174, 1: 4.263790817260742, 2: 4.106602759361267, 3: 3.998788385391235, 4: 4.680020771026611, 5: 5.226834297180176, 6: 5.21035964012146, 7: 5.420144653320312, 8: 5.525527219772339, 9: 5.513277044296265, 10: 5.62511775970459, 11: 5.568444547653198},\n",
        "    {0: 5.303920812606812, 1: 4.223071842193604, 2: 4.090608291625976, 3: 3.994982166290283, 4: 4.663872623443604, 5: 5.226346616744995, 6: 5.2035440254211425, 7: 5.403556118011474, 8: 5.528957233428955, 9: 5.523081560134887, 10: 5.630525960922241, 11: 5.582965478897095},\n",
        "    {0: 5.310703840255737, 1: 4.263773174285888, 2: 4.116586332321167, 3: 4.017633848190307, 4: 4.676742868423462, 5: 5.224496612548828, 6: 5.196131429672241, 7: 5.402983207702636, 8: 5.525129270553589, 9: 5.520591449737549, 10: 5.6297976016998295, 11: 5.5853519535064695},\n",
        "    {0: 5.309150876998902, 1: 4.262697019577026, 2: 4.116720676422119, 3: 4.019005084037781, 4: 4.678860769271851, 5: 5.226104078292846, 6: 5.197585697174072, 7: 5.40531442642212, 8: 5.5262416648864745, 9: 5.5210045337677, 10: 5.630594682693482, 11: 5.586866397857666},\n",
        "    {0: 5.308763708387103, 1: 4.265620251091159, 2: 4.118220309821927, 3: 4.019376754760742, 4: 4.678985566509013, 5: 5.22458165032523, 6: 5.196559224809919, 7: 5.404032882379026, 8: 5.525934618346545, 9: 5.520692465256672, 10: 5.630939473911208, 11: 5.587003124003508},\n",
        "    {0: 5.309705539625519, 1: 4.265844345092773, 2: 4.117541984635956, 3: 4.019918616937131, 4: 4.680267382641228, 5: 5.225742593103526, 6: 5.197149743839186, 7: 5.405138190911741, 8: 5.52513566309092, 9: 5.520209808738864, 10: 5.630923173865494, 11: 5.586297920772007}\n",
        "]\n",
        "\n",
        "layers = list(range(1, 13))  # 1-based layers for x-axis\n",
        "epochs = [f\"Epoch {i+1}\" for i in range(6)]\n",
        "\n",
        "# Convert dicts to lists (layer 1-12)\n",
        "enc_ke_list = [[epoch_ke[i-1] for i in layers] for epoch_ke in enc_ke_epochs]\n",
        "dec_ke_list = [[epoch_ke[i-1] for i in layers] for epoch_ke in dec_ke_epochs]\n",
        "\n",
        "# --- Encoder KE Plot ---\n",
        "plt.figure(figsize=(10, 6))\n",
        "for i, epoch in enumerate(epochs):\n",
        "    plt.plot(layers, enc_ke_list[i], marker='o', label=epoch)\n",
        "plt.xlabel(\"Layer\", fontsize=16)\n",
        "plt.ylabel(\"Knowledge Entropy\", fontsize=16)\n",
        "plt.xticks(layers, fontsize=14)\n",
        "plt.yticks(fontsize=14)\n",
        "#plt.title(\"Encoder KE vs Layer\", fontsize=18)\n",
        "plt.grid(True)\n",
        "plt.legend(fontsize=12)\n",
        "plt.tight_layout()\n",
        "plt.show()\n",
        "\n",
        "# --- Decoder KE Plot ---\n",
        "plt.figure(figsize=(10, 6))\n",
        "for i, epoch in enumerate(epochs):\n",
        "    plt.plot(layers, dec_ke_list[i], marker='o', label=epoch)\n",
        "plt.xlabel(\"Layer\", fontsize=16)\n",
        "plt.ylabel(\"Knowledge Entropy\", fontsize=16)\n",
        "plt.xticks(layers, fontsize=14)\n",
        "plt.yticks(fontsize=14)\n",
        "#plt.title(\"Decoder KE vs Layer\", fontsize=18)\n",
        "plt.grid(True)\n",
        "plt.legend(fontsize=12)\n",
        "plt.tight_layout()\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "ulh52jrRgmh9"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "1ZZq3jLvLwBT"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "n0QmqCHqLrIR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "I4AA1mnXLro1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "tjsRqlY6LsFv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "l1mrTDVYLslQ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "-IuQUK0-LtFS"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 2) Knowledge Entropy / Hook Utilities\n",
        "# ========================================================\n",
        "def register_ke_hooks(model):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {i: {'pre_act': None} for i in range(len(layers))}\n",
        "    hooks = []\n",
        "    for i, layer in enumerate(layers):\n",
        "        def hook_ffn_input(module, input, output, idx=i):\n",
        "            activations[idx]['pre_act'] = input[0].detach()\n",
        "        hooks.append(layer.intermediate.dense.register_forward_hook(hook_ffn_input))\n",
        "    return hooks, activations\n",
        "\n",
        "def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):\n",
        "    ke_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        pre_act = buf['pre_act']\n",
        "        if pre_act is None:\n",
        "            continue\n",
        "        act = activation_fn(pre_act)\n",
        "        probs = act / (act.sum(dim=1, keepdim=True) + eps)\n",
        "        entropy = -torch.sum(probs * torch.log(probs + eps), dim=1).mean()\n",
        "        ke_scores[idx] = entropy.item()\n",
        "        buf['pre_act'] = None\n",
        "    return ke_scores\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_ke_layers(model, ke_scores, num_prune=4):\n",
        "    sorted_layers = sorted(ke_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx for idx, _ in sorted_layers[:num_prune]]\n",
        "    for idx in prune_idxs:\n",
        "        model.roberta.encoder.layer[idx].intermediate.dense = nn.Identity()\n",
        "        model.roberta.encoder.layer[idx].output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    loras = {}\n",
        "    for idx, layer in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(layer.output, 'dense'):\n",
        "            continue\n",
        "        W0 = layer.output.dense.weight.data\n",
        "        lora = LoRA(W0, r, alpha).to(W0.device)\n",
        "        def fwd(x, layer=layer, lora=lora):\n",
        "            return F.linear(x, lora(), layer.output.dense.bias)\n",
        "        layer.output.dense.forward = fwd\n",
        "        loras[idx] = lora\n",
        "    return loras\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tokenizer, max_length=128):\n",
        "    return tokenizer(examples[\"premise\"], examples[\"hypothesis\"],\n",
        "                     truncation=True, padding=\"max_length\", max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device):\n",
        "    print(\"=== Stage 1: Full Finetuning & KE Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=3\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_ke_hooks(model)\n",
        "    last_ke = None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        ke_sums, ke_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            batch_ke = compute_batch_knowledge_entropy(activations)\n",
        "            for idx, v in batch_ke.items():\n",
        "                ke_sums[idx] += v\n",
        "                ke_counts[idx] += 1\n",
        "\n",
        "        epoch_ke = {idx: ke_sums[idx]/ke_counts[idx]\n",
        "                    for idx in ke_sums if ke_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Knowledge Entropy:\", epoch_ke)\n",
        "        last_ke = epoch_ke\n",
        "\n",
        "    acc = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune MNLI Acc: {acc:.4f}\")\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_ke\n",
        "\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores):\n",
        "    print(\"=== Stage 2: Prune (High‐KE) & Finetuning ===\")\n",
        "    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐KE):\", prune_idxs)\n",
        "\n",
        "    opt = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*3)\n",
        "\n",
        "    for epoch in range(5):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] MNLI Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters(): p.requires_grad = False\n",
        "    for p in model.classifier.parameters(): p.requires_grad = True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad = True\n",
        "        l.B.requires_grad = True\n",
        "\n",
        "    opt = torch.optim.Adam(\n",
        "        list(model.classifier.parameters()) +\n",
        "        [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    for epoch in range(6):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] MNLI Acc: {acc:.4f}\")\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    dataset = load_dataset(\"glue\", \"mnli\")\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "\n",
        "    train_ds = dataset[\"train\"].shuffle(seed).select(range(10000))  # smaller for speed\n",
        "    dev_ds = dataset[\"validation_matched\"]\n",
        "\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True,\n",
        "                         remove_columns=[\"premise\", \"hypothesis\"])\\\n",
        "                    .rename_column(\"label\", \"labels\")\n",
        "\n",
        "    dev = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                     batched=True,\n",
        "                     remove_columns=[\"premise\", \"hypothesis\"])\\\n",
        "                .rename_column(\"label\", \"labels\")\n",
        "\n",
        "    collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=128)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    model, ke_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)\n",
        "    lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "8hICTgzJQ4Cj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# KE values per epoch\n",
        "ke_epochs = {\n",
        "    1: {0: 3.7403, 1: 3.8396, 2: 3.8184, 3: 3.7880, 4: 3.7816, 5: 3.7467, 6: 3.7648, 7: 3.7296, 8: 3.7192, 9: 3.6667, 10: 3.5540, 11: 3.5193},\n",
        "    2: {0: 3.7418, 1: 3.8355, 2: 3.8190, 3: 3.7905, 4: 3.7855, 5: 3.7613, 6: 3.7775, 7: 3.7418, 8: 3.7250, 9: 3.6640, 10: 3.5458, 11: 3.5028},\n",
        "    3: {0: 3.7434, 1: 3.8356, 2: 3.8193, 3: 3.7932, 4: 3.7921, 5: 3.7653, 6: 3.7784, 7: 3.7541, 8: 3.7352, 9: 3.6658, 10: 3.5442, 11: 3.5088},\n",
        "    4: {0: 3.7441, 1: 3.8372, 2: 3.8191, 3: 3.7979, 4: 3.7927, 5: 3.7678, 6: 3.7775, 7: 3.7542, 8: 3.7398, 9: 3.6725, 10: 3.5492, 11: 3.5134},\n",
        "    5: {0: 3.7451, 1: 3.8349, 2: 3.8234, 3: 3.8037, 4: 3.7976, 5: 3.7695, 6: 3.7790, 7: 3.7539, 8: 3.7457, 9: 3.6797, 10: 3.5499, 11: 3.5118},\n",
        "    6: {0: 3.7452, 1: 3.8372, 2: 3.8257, 3: 3.8062, 4: 3.8008, 5: 3.7704, 6: 3.7808, 7: 3.7537, 8: 3.7452, 9: 3.6729, 10: 3.5400, 11: 3.5013}\n",
        "}\n",
        "\n",
        "# Plot\n",
        "plt.figure(figsize=(10, 6))\n",
        "for epoch, ke in ke_epochs.items():\n",
        "    layers = [l + 1 for l in ke.keys()]  # shift layer indices to 1–12\n",
        "    values = list(ke.values())\n",
        "    plt.plot(layers, values, marker='o', label=f\"Epoch {epoch}\")\n",
        "\n",
        "#plt.title(\"Knowledge Entropy vs Layer Index\", fontsize=16)\n",
        "plt.xlabel(\"Layer Index\", fontsize=16)\n",
        "plt.ylabel(\"Knowledge Entropy\", fontsize=16)\n",
        "plt.xticks(fontsize=14)\n",
        "plt.yticks(fontsize=14)\n",
        "plt.legend()\n",
        "plt.grid(True)\n",
        "plt.tight_layout()\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "ryiCc-_TQ6CE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 2) Knowledge Entropy / Hook Utilities\n",
        "# ========================================================\n",
        "def register_ke_hooks(model):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {i: {'pre_act': None} for i in range(len(layers))}\n",
        "    hooks = []\n",
        "    for i, layer in enumerate(layers):\n",
        "        def hook_ffn_input(module, input, output, idx=i):\n",
        "            activations[idx]['pre_act'] = input[0].detach()\n",
        "        hooks.append(layer.intermediate.dense.register_forward_hook(hook_ffn_input))\n",
        "    return hooks, activations\n",
        "\n",
        "def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):\n",
        "    ke_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        pre_act = buf['pre_act']\n",
        "        if pre_act is None:\n",
        "            continue\n",
        "        act = activation_fn(pre_act)\n",
        "        probs = act / (act.sum(dim=1, keepdim=True) + eps)\n",
        "        entropy = -torch.sum(probs * torch.log(probs + eps), dim=1).mean()\n",
        "        ke_scores[idx] = entropy.item()\n",
        "        buf['pre_act'] = None\n",
        "    return ke_scores\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune high‐KE)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_ke_layers(model, ke_scores, num_prune=4):\n",
        "    sorted_layers = sorted(ke_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx for idx, _ in sorted_layers[:num_prune]]\n",
        "    for idx in prune_idxs:\n",
        "        model.roberta.encoder.layer[idx].intermediate.dense = nn.Identity()\n",
        "        model.roberta.encoder.layer[idx].output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    loras = {}\n",
        "    for idx, layer in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(layer.output, 'dense'):\n",
        "            continue\n",
        "        W0 = layer.output.dense.weight.data\n",
        "        lora = LoRA(W0, r, alpha).to(W0.device)\n",
        "        def fwd(x, layer=layer, lora=lora):\n",
        "            return F.linear(x, lora(), layer.output.dense.bias)\n",
        "        layer.output.dense.forward = fwd\n",
        "        loras[idx] = lora\n",
        "    return loras\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(examples['sentence1'],\n",
        "               examples['sentence2'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using KE instead of ER)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device):\n",
        "    print(\"=== Stage 1: Full Finetuning & KE Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_ke_hooks(model)\n",
        "    last_ke = None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        ke_sums, ke_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            # batch‐level KE\n",
        "            batch_ke = compute_batch_knowledge_entropy(activations)\n",
        "            for idx, v in batch_ke.items():\n",
        "                ke_sums[idx]   += v\n",
        "                ke_counts[idx] += 1\n",
        "\n",
        "        # epoch‐level KE\n",
        "        epoch_ke = {idx: ke_sums[idx]/ke_counts[idx]\n",
        "                    for idx in ke_sums if ke_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Knowledge Entropy:\", epoch_ke)\n",
        "        last_ke = epoch_ke\n",
        "\n",
        "    acc = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune MRPC Acc: {acc:.4f}\")\n",
        "\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_ke\n",
        "\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores):\n",
        "    print(\"=== Stage 2: Prune (High‐KE) & Finetuning ===\")\n",
        "    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐KE):\", prune_idxs)\n",
        "\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*3)\n",
        "\n",
        "    for epoch in range(5):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] MRPC Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters(): p.requires_grad=False\n",
        "    for p in model.classifier.parameters(): p.requires_grad=True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad=True\n",
        "        l.B.requires_grad=True\n",
        "\n",
        "    opt   = torch.optim.Adam(\n",
        "        list(model.classifier.parameters())\n",
        "        + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    for epoch in range(6):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] MRPC Acc: {acc:.4f}\")\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    train_ds = load_dataset(\"glue\", \"mrpc\", split=\"train\").shuffle(seed).select(range(1000))\n",
        "    dev_ds   = load_dataset(\"glue\", \"mrpc\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True,\n",
        "                         remove_columns=[\"sentence1\",\"sentence2\",\"idx\"])\\\n",
        "                    .rename_column(\"label\",\"labels\")\n",
        "    dev   = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                       batched=True,\n",
        "                       remove_columns=[\"sentence1\",\"sentence2\",\"idx\"])\\\n",
        "                  .rename_column(\"label\",\"labels\")\n",
        "\n",
        "    collator     = DataCollatorWithPadding(tokenizer,\n",
        "                                           padding=\"max_length\",\n",
        "                                           max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True,\n",
        "                              collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev,   batch_size=16, shuffle=False,\n",
        "                              collate_fn=collator)\n",
        "\n",
        "    model, ke_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)\n",
        "    lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "AFGs9h-f_Swa"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Knowledge Entropy values per epoch\n",
        "ke_epochs = [\n",
        "    {0: 3.239791645050049, 1: 3.2986808376312258, 2: 3.253356773376465, 3: 3.228304588317871, 4: 3.204816005706787, 5: 3.1601281089782716, 6: 3.167900978088379, 7: 3.14154211807251, 8: 3.136232151031494, 9: 3.108181224822998, 10: 3.070255741119385, 11: 3.150283061981201},\n",
        "    {0: 3.2390173606872557, 1: 3.2963816051483152, 2: 3.2457099170684813, 3: 3.224952474594116, 4: 3.199368993759155, 5: 3.1539372539520265, 6: 3.149820531845093, 7: 3.1081827564239504, 8: 3.0931982421875, 9: 3.0528202323913574, 10: 3.028363660812378, 11: 3.050001382827759},\n",
        "    {0: 3.239009340286255, 1: 3.2983984470367433, 2: 3.2463859825134276, 3: 3.2241379623413087, 4: 3.1976428546905518, 5: 3.1447392463684083, 6: 3.1279000110626223, 7: 3.0747772026062012, 8: 3.0223362197875976, 9: 2.9473570098876953, 10: 2.87524440574646, 11: 2.8674370727539062},\n",
        "    {0: 3.238827512741089, 1: 3.2974642314910887, 2: 3.2431242713928223, 3: 3.2216903343200682, 4: 3.199512315750122, 5: 3.1563696422576903, 6: 3.134328540802002, 7: 3.0738258819580078, 8: 2.99842280960083, 9: 2.912938259124756, 10: 2.8102419834136962, 11: 2.771777828216553},\n",
        "    {0: 3.2383913173675536, 1: 3.296243072509766, 2: 3.242586082458496, 3: 3.221456657409668, 4: 3.2004083824157714, 5: 3.152312059402466, 6: 3.1258499088287355, 7: 3.064895736694336, 8: 2.980252359390259, 9: 2.860691904067993, 10: 2.749246000289917, 11: 2.710493507385254},\n",
        "    {0: 3.238114908218384, 1: 3.2962895565032957, 2: 3.243076961517334, 3: 3.221760944366455, 4: 3.2005157108306883, 5: 3.1507289390563966, 6: 3.123993368148804, 7: 3.0618533477783205, 8: 2.9700901947021485, 9: 2.843407918930054, 10: 2.727255252838135, 11: 2.6856814765930177}\n",
        "]\n",
        "\n",
        "# Plotting\n",
        "plt.figure(figsize=(10, 6))\n",
        "for i, epoch_ke in enumerate(ke_epochs):\n",
        "    layers = [l + 1 for l in epoch_ke.keys()]  # shift layer indices to 1–12\n",
        "    values = list(epoch_ke.values())\n",
        "    plt.plot(layers, values, label=f\"Epoch {i+1}\", marker='o')\n",
        "\n",
        "plt.xlabel(\"Layer\", fontsize=16)\n",
        "plt.ylabel(\"Knowledge Entropy\", fontsize=16)\n",
        "#plt.title(\"Knowledge Entropy vs Layers\", fontsize=16)\n",
        "plt.xticks(fontsize=14)\n",
        "plt.yticks(fontsize=14)\n",
        "plt.legend(fontsize=12)\n",
        "plt.grid(True)\n",
        "plt.tight_layout()\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "xHDCg2DUDvH4"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 2) Knowledge Entropy / Hook Utilities\n",
        "# ========================================================\n",
        "def register_ke_hooks(model):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {i: {'pre_act': None} for i in range(len(layers))}\n",
        "    hooks = []\n",
        "    for i, layer in enumerate(layers):\n",
        "        def hook_ffn_input(module, input, output, idx=i):\n",
        "            activations[idx]['pre_act'] = input[0].detach()\n",
        "        hooks.append(layer.intermediate.dense.register_forward_hook(hook_ffn_input))\n",
        "    return hooks, activations\n",
        "\n",
        "def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):\n",
        "    ke_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        pre_act = buf['pre_act']\n",
        "        if pre_act is None:\n",
        "            continue\n",
        "        act = activation_fn(pre_act)\n",
        "        probs = act / (act.sum(dim=1, keepdim=True) + eps)\n",
        "        entropy = -torch.sum(probs * torch.log(probs + eps), dim=1).mean()\n",
        "        ke_scores[idx] = entropy.item()\n",
        "        buf['pre_act'] = None\n",
        "    return ke_scores\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune high‐KE)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_ke_layers(model, ke_scores, num_prune=4):\n",
        "    sorted_layers = sorted(ke_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx for idx, _ in sorted_layers[:num_prune]]\n",
        "    for idx in prune_idxs:\n",
        "        model.roberta.encoder.layer[idx].intermediate.dense = nn.Identity()\n",
        "        model.roberta.encoder.layer[idx].output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    loras = {}\n",
        "    for idx, layer in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(layer.output, 'dense'):\n",
        "            continue\n",
        "        W0 = layer.output.dense.weight.data\n",
        "        lora = LoRA(W0, r, alpha).to(W0.device)\n",
        "        def fwd(x, layer=layer, lora=lora):\n",
        "            return F.linear(x, lora(), layer.output.dense.bias)\n",
        "        layer.output.dense.forward = fwd\n",
        "        loras[idx] = lora\n",
        "    return loras\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(examples['sentence'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using KE instead of ER)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device):\n",
        "    print(\"=== Stage 1: Full Finetuning & KE Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_ke_hooks(model)\n",
        "    last_ke = None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        ke_sums, ke_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            batch_ke = compute_batch_knowledge_entropy(activations)\n",
        "            for idx, v in batch_ke.items():\n",
        "                ke_sums[idx]   += v\n",
        "                ke_counts[idx] += 1\n",
        "\n",
        "        epoch_ke = {idx: ke_sums[idx]/ke_counts[idx]\n",
        "                    for idx in ke_sums if ke_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Knowledge Entropy:\", epoch_ke)\n",
        "        last_ke = epoch_ke\n",
        "\n",
        "    acc = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune SST-2 Acc: {acc:.4f}\")\n",
        "\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_ke\n",
        "\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores):\n",
        "    print(\"=== Stage 2: Prune (High‐KE) & Finetuning ===\")\n",
        "    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐KE):\", prune_idxs)\n",
        "\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*3)\n",
        "\n",
        "    for epoch in range(5):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] SST-2 Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters(): p.requires_grad=False\n",
        "    for p in model.classifier.parameters(): p.requires_grad=True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad=True\n",
        "        l.B.requires_grad=True\n",
        "\n",
        "    opt   = torch.optim.Adam(\n",
        "        list(model.classifier.parameters())\n",
        "        + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    for epoch in range(6):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] SST-2 Acc: {acc:.4f}\")\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    train_ds = load_dataset(\"glue\", \"sst2\", split=\"train\").shuffle(seed).select(range(5000))\n",
        "    dev_ds   = load_dataset(\"glue\", \"sst2\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True,\n",
        "                         remove_columns=[\"sentence\",\"idx\"])\\\n",
        "                    .rename_column(\"label\",\"labels\")\n",
        "    dev   = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                       batched=True,\n",
        "                       remove_columns=[\"sentence\",\"idx\"])\\\n",
        "                  .rename_column(\"label\",\"labels\")\n",
        "\n",
        "    collator     = DataCollatorWithPadding(tokenizer,\n",
        "                                           padding=\"max_length\",\n",
        "                                           max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True,\n",
        "                              collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev,   batch_size=16, shuffle=False,\n",
        "                              collate_fn=collator)\n",
        "\n",
        "    model, ke_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)\n",
        "    lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "e0e5CB8l49O9"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# KE values for each epoch\n",
        "ke_by_epoch = {\n",
        "    1: [2.9651833793640137, 3.089046128463745, 3.0924095928192137, 3.0681774433135987, 3.0340894985198976, 2.9792041194915773, 2.9694000019073488, 2.9185178718566895, 2.9200110771179197, 2.82025255355835, 2.667954217147827, 2.6612180709838866],\n",
        "    2: [2.965893549346924, 3.085789897155762, 3.092639580154419, 3.078977996826172, 3.0649560813903807, 3.0321216072082517, 3.018477940368652, 2.9576354961395266, 2.911316421508789, 2.782052672576904, 2.6040575157165526, 2.5998096366882324],\n",
        "    3: [2.9650946979522703, 3.084409861755371, 3.0894264568328857, 3.077181778717041, 3.066806471252441, 3.0330316158294677, 3.026228702163696, 2.9687501598358152, 2.918505425262451, 2.781764380264282, 2.6073847118377684, 2.5956354915618896],\n",
        "    4: [2.9647007961273193, 3.0866590961456297, 3.088959024810791, 3.076308888244629, 3.0617130882263184, 3.0239853103637695, 3.0255811374664305, 2.9724936283111574, 2.91553058013916, 2.779527521133423, 2.608109250640869, 2.599828709793091],\n",
        "    5: [2.9640205013275147, 3.089625425720215, 3.0891447685241697, 3.0789716785430907, 3.066036195373535, 3.0240665218353273, 3.0201424140930175, 2.9597047756195067, 2.888725273895264, 2.7510012699127198, 2.584951690673828, 2.5864554821014405],\n",
        "    6: [2.963668116760254, 3.090019404220581, 3.089617932128906, 3.0806209617614746, 3.0694873622894288, 3.0306819889068604, 3.0209427997589113, 2.956924412536621, 2.8712899379730223, 2.7296667186737062, 2.57259878578186, 2.5810291679382322]\n",
        "}\n",
        "\n",
        "plt.figure(figsize=(10, 6))\n",
        "for epoch, ke_values in ke_by_epoch.items():\n",
        "    plt.plot(range(1,13), ke_values, marker='o', label=f'Epoch {epoch}')\n",
        "\n",
        "#plt.title(\"Knowledge Entropy vs. Layer\", fontsize=16)\n",
        "plt.xlabel(\"Layer Index\", fontsize=16)\n",
        "plt.ylabel(\"Approx. Knowledge Entropy\", fontsize=16)\n",
        "plt.xticks(fontsize=14)\n",
        "plt.yticks(fontsize=14)\n",
        "plt.legend(fontsize=12)\n",
        "plt.grid(True)\n",
        "plt.tight_layout()\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "lh8kP42oO0ZO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# ========================================================\n",
        "# 2) Knowledge Entropy / Hook Utilities\n",
        "# ========================================================\n",
        "def register_ke_hooks(model):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {i: {'pre_act': None} for i in range(len(layers))}\n",
        "    hooks = []\n",
        "    for i, layer in enumerate(layers):\n",
        "        def hook_ffn_input(module, input, output, idx=i):\n",
        "            activations[idx]['pre_act'] = input[0].detach()\n",
        "        hooks.append(layer.intermediate.dense.register_forward_hook(hook_ffn_input))\n",
        "    return hooks, activations\n",
        "\n",
        "def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):\n",
        "    ke_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        pre_act = buf['pre_act']\n",
        "        if pre_act is None:\n",
        "            continue\n",
        "        act = activation_fn(pre_act)\n",
        "        probs = act / (act.sum(dim=1, keepdim=True) + eps)\n",
        "        entropy = -torch.sum(probs * torch.log(probs + eps), dim=1).mean()\n",
        "        ke_scores[idx] = entropy.item()\n",
        "        buf['pre_act'] = None\n",
        "    return ke_scores\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune high‐KE)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_ke_layers(model, ke_scores, num_prune=4):\n",
        "    sorted_layers = sorted(ke_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx for idx, _ in sorted_layers[:num_prune]]\n",
        "    for idx in prune_idxs:\n",
        "        model.roberta.encoder.layer[idx].intermediate.dense = nn.Identity()\n",
        "        model.roberta.encoder.layer[idx].output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    \"\"\"Low-rank additive update ``W0 + (alpha / r) * (B @ A)`` around a fixed weight.\n",
        "\n",
        "    ``W0`` is kept as a buffer holding a detached copy of the given matrix.\n",
        "    ``B`` (rows x r, small random init) and ``A`` (r x cols, zeros) are the\n",
        "    parameters, so the update is zero right after construction.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        frozen_copy = W0.clone().detach()\n",
        "        self.register_buffer(\"W0\", frozen_copy)\n",
        "        n_rows, n_cols = W0.shape\n",
        "        self.B = nn.Parameter(0.01 * torch.randn(n_rows, r))\n",
        "        self.A = nn.Parameter(torch.zeros(r, n_cols))\n",
        "        self.scaling = alpha / r\n",
        "\n",
        "    def forward(self):\n",
        "        \"\"\"Return the effective weight matrix (same shape as ``W0``).\"\"\"\n",
        "        low_rank_delta = self.B @ self.A\n",
        "        return self.W0 + self.scaling * low_rank_delta\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    \"\"\"Route each layer's ``output.dense`` through a LoRA-adjusted weight.\n",
        "\n",
        "    Layers whose ``output`` has no ``dense`` attribute are skipped.  For the\n",
        "    others a ``LoRA`` module is built from the current weight and the\n",
        "    ``forward`` of ``output.dense`` is replaced by a function that applies\n",
        "    ``F.linear`` with the LoRA weight and the layer's existing bias.\n",
        "\n",
        "    Returns:\n",
        "        dict ``layer index -> LoRA module`` for every patched layer.\n",
        "    \"\"\"\n",
        "    adapters = {}\n",
        "    for layer_idx, block in enumerate(model.roberta.encoder.layer):\n",
        "        if hasattr(block.output, 'dense'):\n",
        "            base_weight = block.output.dense.weight.data\n",
        "            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)\n",
        "\n",
        "            # Defaults bind this iteration's objects to the closure.\n",
        "            def patched_forward(x, layer=block, lora=adapter):\n",
        "                return F.linear(x, lora(), layer.output.dense.bias)\n",
        "\n",
        "            block.output.dense.forward = patched_forward\n",
        "            adapters[layer_idx] = adapter\n",
        "    return adapters\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    \"\"\"Tokenise the ``sentence`` column, truncating and padding to ``max_length``.\"\"\"\n",
        "    tokenizer_kwargs = dict(\n",
        "        truncation=True,\n",
        "        padding='max_length',\n",
        "        max_length=max_length,\n",
        "    )\n",
        "    return tok(examples['sentence'], **tokenizer_kwargs)\n",
        "\n",
        "\n",
        "from sklearn.metrics import matthews_corrcoef\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    \"\"\"Score ``model`` on ``dl`` and return ``(accuracy, matthews_correlation)``.\n",
        "\n",
        "    The model is switched to eval mode and left there.  Predictions are the\n",
        "    argmax of the logits over the last axis.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    acc_metric = evaluate.load(\"accuracy\")\n",
        "    mcc_metric = evaluate.load(\"matthews_correlation\")\n",
        "    predictions, references = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            references.extend(batch['labels'].cpu().numpy())\n",
        "            logits = model(\n",
        "                input_ids=batch['input_ids'].to(device),\n",
        "                attention_mask=batch['attention_mask'].to(device),\n",
        "            ).logits\n",
        "            predictions.extend(logits.argmax(dim=-1).cpu().numpy())\n",
        "    acc_result = acc_metric.compute(predictions=predictions, references=references)\n",
        "    mcc_result = mcc_metric.compute(predictions=predictions, references=references)\n",
        "    return acc_result[\"accuracy\"], mcc_result[\"matthews_correlation\"]\n",
        "\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using KE instead of ER)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device, num_epochs=6):\n",
        "    \"\"\"Stage 1: fine-tune roberta-base and track per-layer knowledge entropy.\n",
        "\n",
        "    Trains a 2-label ``RobertaForSequenceClassification`` with Adam (lr 2e-5),\n",
        "    a linear LR schedule without warmup, mixed precision and gradient\n",
        "    checkpointing.  After every optimiser step the hooked activations are\n",
        "    reduced to per-layer entropies; their per-epoch means are printed and the\n",
        "    dev set is scored after each epoch.\n",
        "\n",
        "    Args:\n",
        "        train_loader: batches with ``input_ids``, ``attention_mask`` and\n",
        "            ``labels``; must support ``len()``.\n",
        "        dev_loader: batches handed to ``evaluate_model``.\n",
        "        device: device the model and batches are moved to.\n",
        "        num_epochs: number of epochs; also sizes the LR schedule.\n",
        "\n",
        "    Returns:\n",
        "        ``(model, last_ke)``; ``last_ke`` maps layer index to the mean entropy\n",
        "        of the final epoch (``None`` when no epoch ran).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Finetuning & KE Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*num_epochs)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_ke_hooks(model)\n",
        "    last_ke = None\n",
        "\n",
        "    # try/finally so the hooks never stay attached if training raises.\n",
        "    try:\n",
        "        for epoch in range(num_epochs):\n",
        "            ke_sums, ke_counts = defaultdict(float), defaultdict(int)\n",
        "            model.train()\n",
        "            for b in train_loader:\n",
        "                opt.zero_grad()\n",
        "                # Only the forward pass and loss belong under autocast.\n",
        "                with autocast():\n",
        "                    out = model(input_ids=b['input_ids'].to(device),\n",
        "                                attention_mask=b['attention_mask'].to(device),\n",
        "                                labels=b['labels'].to(device))\n",
        "                # Backward runs outside the autocast context, as recommended for AMP.\n",
        "                scaler.scale(out.loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "\n",
        "                # batch-level KE (also clears the hook buffers)\n",
        "                batch_ke = compute_batch_knowledge_entropy(activations)\n",
        "                for idx, v in batch_ke.items():\n",
        "                    ke_sums[idx]   += v\n",
        "                    ke_counts[idx] += 1\n",
        "\n",
        "            # epoch-level KE: mean of the batch values per layer\n",
        "            epoch_ke = {idx: ke_sums[idx]/ke_counts[idx]\n",
        "                        for idx in ke_sums if ke_counts[idx] > 0}\n",
        "            print(f\"[Epoch {epoch+1}] approx Knowledge Entropy:\", epoch_ke)\n",
        "            last_ke = epoch_ke\n",
        "\n",
        "            acc, mcc = evaluate_model(model, dev_loader, device)\n",
        "            print(f\"-> Full Finetune CoLA Acc: {acc:.4f} | MCC: {mcc:.4f}\")\n",
        "    finally:\n",
        "        remove_hooks(hooks)\n",
        "    return model, last_ke\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores, num_epochs=5):\n",
        "    \"\"\"Stage 2: bypass the FFN of the highest-KE layers, then fine-tune the model.\n",
        "\n",
        "    Args:\n",
        "        model: model returned by stage 1; modified in place.\n",
        "        train_loader: batches with ``input_ids``, ``attention_mask`` and\n",
        "            ``labels``; must support ``len()``.\n",
        "        dev_loader: batches handed to ``evaluate_model`` after every epoch.\n",
        "        device: device the batches are moved to.\n",
        "        ke_scores: mapping ``layer index -> knowledge entropy``.\n",
        "        num_epochs: number of epochs; also sizes the LR schedule.\n",
        "\n",
        "    Returns:\n",
        "        the same ``model`` object after pruning and fine-tuning.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High‐KE) & Finetuning ===\")\n",
        "    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐KE):\", prune_idxs)\n",
        "\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    # Size the schedule from the same epoch count as the loop: a shorter\n",
        "    # schedule drives the LR to zero before training ends, turning the\n",
        "    # remaining epochs into no-ops.\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*num_epochs)\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "\n",
        "        acc, mcc = evaluate_model(model, dev_loader, device)\n",
        "        # Label the log line with this stage, not with stage 1.\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] CoLA Acc: {acc:.4f} | MCC: {mcc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0, num_epochs=6):\n",
        "    \"\"\"Stage 3: freeze the backbone and train only LoRA factors plus the classifier.\n",
        "\n",
        "    Args:\n",
        "        model: model returned by stage 2; modified in place.\n",
        "        train_loader: batches with ``input_ids``, ``attention_mask`` and\n",
        "            ``labels``; must support ``len()``.\n",
        "        dev_loader: batches handed to ``evaluate_model`` after every epoch.\n",
        "        device: device the batches are moved to.\n",
        "        r: LoRA rank.\n",
        "        alpha: LoRA scaling numerator (the update is scaled by ``alpha / r``).\n",
        "        num_epochs: number of epochs; also sizes the LR schedule.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters():\n",
        "        p.requires_grad = False\n",
        "    for p in model.classifier.parameters():\n",
        "        p.requires_grad = True\n",
        "    lora_params = [p for l in loras.values() for p in (l.A, l.B)]\n",
        "    for p in lora_params:\n",
        "        p.requires_grad = True\n",
        "\n",
        "    # If gradient checkpointing is still enabled and every backbone weight is\n",
        "    # frozen, no input of a checkpointed layer requires grad, and reentrant\n",
        "    # checkpointing then passes no gradient to the LoRA factors inside it.\n",
        "    # Making the embedding output require grad restores the gradient path.\n",
        "    if getattr(model, \"is_gradient_checkpointing\", False):\n",
        "        model.enable_input_require_grads()\n",
        "\n",
        "    opt   = torch.optim.Adam(\n",
        "        list(model.classifier.parameters()) + lora_params,\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*num_epochs)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            # Only the forward pass and loss belong under autocast.\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "            # Backward runs outside the autocast context, as recommended for AMP.\n",
        "            scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "        acc, mcc = evaluate_model(model, dev_loader, device)\n",
        "        # Label the log line with this stage, not with stage 1.\n",
        "        print(f\"[LoRA Epoch {epoch+1}] CoLA Acc: {acc:.4f} | MCC: {mcc:.4f}\")\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    \"\"\"Seed the RNGs, build the CoLA data loaders and run the three stages.\"\"\"\n",
        "    seed = 42\n",
        "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
        "        seed_fn(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # 1000 shuffled training sentences; the full validation split.\n",
        "    train_ds = load_dataset(\"glue\", \"cola\", split=\"train\").shuffle(seed).select(range(1000))\n",
        "    dev_ds   = load_dataset(\"glue\", \"cola\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "\n",
        "    def encode(split):\n",
        "        # Tokenise, drop the raw columns, and rename the target column.\n",
        "        tokenised = split.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                              batched=True,\n",
        "                              remove_columns=[\"sentence\", \"idx\"])\n",
        "        return tokenised.rename_column(\"label\", \"labels\")\n",
        "\n",
        "    train = encode(train_ds)\n",
        "    dev = encode(dev_ds)\n",
        "\n",
        "    collator = DataCollatorWithPadding(tokenizer,\n",
        "                                       padding=\"max_length\",\n",
        "                                       max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    model, ke_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)\n",
        "    lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "# Run the whole three-stage pipeline when this cell is executed.\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "VLk75h1ionEm"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Knowledge Entropy values over 6 epochs\n",
        "ke_values = [\n",
        "    {0: 2.9448, 1: 3.0885, 2: 3.1181, 3: 3.0932, 4: 3.0441, 5: 2.9380, 6: 2.9606, 7: 2.9447, 8: 2.9687, 9: 2.9775, 10: 2.9016, 11: 2.9350},\n",
        "    {0: 2.9439, 1: 3.0961, 2: 3.1100, 3: 3.0816, 4: 3.0435, 5: 2.9906, 6: 3.0009, 7: 2.9850, 8: 2.9817, 9: 2.9741, 10: 2.8645, 11: 2.8777},\n",
        "    {0: 2.9447, 1: 3.0925, 2: 3.1075, 3: 3.0703, 4: 3.0372, 5: 2.9933, 6: 3.0018, 7: 2.9899, 8: 2.9857, 9: 2.9657, 10: 2.8410, 11: 2.8267},\n",
        "    {0: 2.9449, 1: 3.0911, 2: 3.1028, 3: 3.0704, 4: 3.0389, 5: 2.9948, 6: 2.9996, 7: 2.9699, 8: 2.9628, 9: 2.9423, 10: 2.8048, 11: 2.7881},\n",
        "    {0: 2.9457, 1: 3.0933, 2: 3.1071, 3: 3.0762, 4: 3.0477, 5: 3.0011, 6: 3.0038, 7: 2.9738, 8: 2.9548, 9: 2.9251, 10: 2.7719, 11: 2.7567},\n",
        "    {0: 2.9456, 1: 3.0939, 2: 3.1071, 3: 3.0768, 4: 3.0461, 5: 3.0022, 6: 3.0023, 7: 2.9735, 8: 2.9553, 9: 2.9216, 10: 2.7640, 11: 2.7494},\n",
        "]\n",
        "\n",
        "plt.figure(figsize=(10, 6))\n",
        "# One curve per epoch; dict keys are 0-based layer ids, drawn as 1-12.\n",
        "for i, epoch_ke in enumerate(ke_values, start=1):\n",
        "    layers = []\n",
        "    entropy = []\n",
        "    for layer_id, layer_ke in epoch_ke.items():\n",
        "        layers.append(layer_id + 1)\n",
        "        entropy.append(layer_ke)\n",
        "    plt.plot(layers, entropy, marker='o', label=f\"Epoch {i}\")\n",
        "\n",
        "plt.xlabel(\"Layer Index\", fontsize=16)\n",
        "plt.ylabel(\"Knowledge Entropy\", fontsize=16)\n",
        "#plt.title(\"Knowledge Entropy vs. Layer Index (CoLA)\", fontsize=16)\n",
        "plt.xticks(fontsize=14)\n",
        "plt.yticks(fontsize=14)\n",
        "plt.legend(fontsize=12)\n",
        "plt.grid(True)\n",
        "plt.tight_layout()\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "cVv85s_wpaN-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "DNR1BHeYpaq6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "# Suppress all UserWarning and FutureWarning messages for the rest of the session.\n",
        "# NOTE(review): this is process-wide, so warnings raised by the libraries used\n",
        "# below are hidden as well.\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 2) Knowledge Entropy / Hook Utilities\n",
        "# ========================================================\n",
        "def register_ke_hooks(model):\n",
        "    \"\"\"Attach a forward hook to ``intermediate.dense`` of every encoder layer.\n",
        "\n",
        "    Each hook stores the detached first positional input of that module under\n",
        "    ``activations[layer index]['pre_act']``.\n",
        "\n",
        "    Returns:\n",
        "        ``(hooks, activations)``: the hook handles and the per-layer buffers.\n",
        "    \"\"\"\n",
        "    # NOTE(review): despite the key name, the hook records the module's input,\n",
        "    # not its output -- confirm that is the tensor the metric should use.\n",
        "    activations = {}\n",
        "    hooks = []\n",
        "\n",
        "    def make_hook(layer_idx):\n",
        "        def capture(module, input, output):\n",
        "            activations[layer_idx]['pre_act'] = input[0].detach()\n",
        "        return capture\n",
        "\n",
        "    for layer_idx, block in enumerate(model.roberta.encoder.layer):\n",
        "        activations[layer_idx] = {'pre_act': None}\n",
        "        handle = block.intermediate.dense.register_forward_hook(make_hook(layer_idx))\n",
        "        hooks.append(handle)\n",
        "    return hooks, activations\n",
        "\n",
        "def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):\n",
        "    \"\"\"Reduce the tensors captured by the forward hooks to one entropy per layer.\n",
        "\n",
        "    Each non-empty buffer is passed through ``activation_fn``, normalised along\n",
        "    ``dim=1`` into a distribution, and the entropy along that axis is averaged\n",
        "    over the remaining positions.  Consumed buffers are reset to ``None``.\n",
        "\n",
        "    Args:\n",
        "        activations: mapping ``layer index -> {'pre_act': tensor or None}``.\n",
        "        activation_fn: non-linearity applied before normalising.\n",
        "        eps: constant added to the denominator and inside the logarithm.\n",
        "\n",
        "    Returns:\n",
        "        dict ``layer index -> float``; layers with an empty buffer are omitted.\n",
        "    \"\"\"\n",
        "    # NOTE(review): if the captured tensor is (batch, seq, hidden), dim=1 is the\n",
        "    # sequence axis rather than the feature axis -- confirm this is intended.\n",
        "    per_layer = {}\n",
        "    for layer_idx, slot in activations.items():\n",
        "        captured = slot['pre_act']\n",
        "        if captured is not None:\n",
        "            gated = activation_fn(captured)\n",
        "            dist = gated / (gated.sum(dim=1, keepdim=True) + eps)\n",
        "            plogp = dist * torch.log(dist + eps)\n",
        "            per_layer[layer_idx] = (-torch.sum(plogp, dim=1).mean()).item()\n",
        "            slot['pre_act'] = None\n",
        "    return per_layer\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    \"\"\"Call ``remove()`` on every hook handle in ``hooks``.\"\"\"\n",
        "    for handle in hooks:\n",
        "        handle.remove()\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune high‐KE)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    \"\"\"Stand-in for a layer's output block that passes ``input_tensor`` through.\"\"\"\n",
        "\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        # hidden_states is deliberately ignored; only the second argument is returned.\n",
        "        passthrough = input_tensor\n",
        "        return passthrough\n",
        "\n",
        "def prune_ke_layers(model, ke_scores, num_prune=4):\n",
        "    \"\"\"Bypass the feed-forward block of the ``num_prune`` highest-scoring layers.\n",
        "\n",
        "    For each selected layer, ``intermediate.dense`` becomes ``nn.Identity`` and\n",
        "    ``output`` becomes ``SkipFF``.  The model is modified in place.\n",
        "\n",
        "    Args:\n",
        "        model: object exposing ``roberta.encoder.layer``.\n",
        "        ke_scores: mapping ``layer index -> score``.\n",
        "        num_prune: how many of the highest-scoring layers to prune.\n",
        "\n",
        "    Returns:\n",
        "        list of pruned layer indices, highest score first.\n",
        "    \"\"\"\n",
        "    ranked = sorted(ke_scores, key=ke_scores.get, reverse=True)\n",
        "    prune_idxs = ranked[:num_prune]\n",
        "    for layer_idx in prune_idxs:\n",
        "        block = model.roberta.encoder.layer[layer_idx]\n",
        "        block.intermediate.dense = nn.Identity()\n",
        "        block.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    \"\"\"Low-rank additive update ``W0 + (alpha / r) * (B @ A)`` around a fixed weight.\n",
        "\n",
        "    ``W0`` is kept as a buffer holding a detached copy of the given matrix.\n",
        "    ``B`` (rows x r, small random init) and ``A`` (r x cols, zeros) are the\n",
        "    parameters, so the update is zero right after construction.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        frozen_copy = W0.clone().detach()\n",
        "        self.register_buffer(\"W0\", frozen_copy)\n",
        "        n_rows, n_cols = W0.shape\n",
        "        self.B = nn.Parameter(0.01 * torch.randn(n_rows, r))\n",
        "        self.A = nn.Parameter(torch.zeros(r, n_cols))\n",
        "        self.scaling = alpha / r\n",
        "\n",
        "    def forward(self):\n",
        "        \"\"\"Return the effective weight matrix (same shape as ``W0``).\"\"\"\n",
        "        low_rank_delta = self.B @ self.A\n",
        "        return self.W0 + self.scaling * low_rank_delta\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    \"\"\"Route each layer's ``output.dense`` through a LoRA-adjusted weight.\n",
        "\n",
        "    Layers whose ``output`` has no ``dense`` attribute are skipped.  For the\n",
        "    others a ``LoRA`` module is built from the current weight and the\n",
        "    ``forward`` of ``output.dense`` is replaced by a function that applies\n",
        "    ``F.linear`` with the LoRA weight and the layer's existing bias.\n",
        "\n",
        "    Returns:\n",
        "        dict ``layer index -> LoRA module`` for every patched layer.\n",
        "    \"\"\"\n",
        "    adapters = {}\n",
        "    for layer_idx, block in enumerate(model.roberta.encoder.layer):\n",
        "        if hasattr(block.output, 'dense'):\n",
        "            base_weight = block.output.dense.weight.data\n",
        "            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)\n",
        "\n",
        "            # Defaults bind this iteration's objects to the closure.\n",
        "            def patched_forward(x, layer=block, lora=adapter):\n",
        "                return F.linear(x, lora(), layer.output.dense.bias)\n",
        "\n",
        "            block.output.dense.forward = patched_forward\n",
        "            adapters[layer_idx] = adapter\n",
        "    return adapters\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    \"\"\"Tokenise ``question``/``sentence`` pairs, truncating and padding to ``max_length``.\"\"\"\n",
        "    tokenizer_kwargs = dict(\n",
        "        truncation=True,\n",
        "        padding='max_length',\n",
        "        max_length=max_length,\n",
        "    )\n",
        "    first, second = examples['question'], examples['sentence']\n",
        "    return tok(first, second, **tokenizer_kwargs)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    \"\"\"Return the accuracy of ``model`` over the batches of ``dl``.\n",
        "\n",
        "    The model is switched to eval mode and left there.  Predictions are the\n",
        "    argmax of the logits over the last axis.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    predictions, references = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            references.extend(batch['labels'].cpu().numpy())\n",
        "            logits = model(\n",
        "                input_ids=batch['input_ids'].to(device),\n",
        "                attention_mask=batch['attention_mask'].to(device),\n",
        "            ).logits\n",
        "            predictions.extend(logits.argmax(dim=-1).cpu().numpy())\n",
        "    result = metric.compute(predictions=predictions, references=references)\n",
        "    return result[\"accuracy\"]\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using KE instead of ER)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device, num_epochs=6):\n",
        "    \"\"\"Stage 1: fine-tune roberta-base and track per-layer knowledge entropy.\n",
        "\n",
        "    Trains a 2-label ``RobertaForSequenceClassification`` with Adam (lr 2e-5),\n",
        "    a linear LR schedule without warmup, mixed precision and gradient\n",
        "    checkpointing.  After every optimiser step the hooked activations are\n",
        "    reduced to per-layer entropies; their per-epoch means are printed.  The\n",
        "    dev set is scored once, after the last epoch.\n",
        "\n",
        "    Args:\n",
        "        train_loader: batches with ``input_ids``, ``attention_mask`` and\n",
        "            ``labels``; must support ``len()``.\n",
        "        dev_loader: batches handed to ``evaluate_model``.\n",
        "        device: device the model and batches are moved to.\n",
        "        num_epochs: number of epochs; also sizes the LR schedule.\n",
        "\n",
        "    Returns:\n",
        "        ``(model, last_ke)``; ``last_ke`` maps layer index to the mean entropy\n",
        "        of the final epoch (``None`` when no epoch ran).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Finetuning & KE Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*num_epochs)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_ke_hooks(model)\n",
        "    last_ke = None\n",
        "\n",
        "    # try/finally so the hooks never stay attached if training raises.\n",
        "    try:\n",
        "        for epoch in range(num_epochs):\n",
        "            ke_sums, ke_counts = defaultdict(float), defaultdict(int)\n",
        "            model.train()\n",
        "            for b in train_loader:\n",
        "                opt.zero_grad()\n",
        "                # Only the forward pass and loss belong under autocast.\n",
        "                with autocast():\n",
        "                    out = model(input_ids=b['input_ids'].to(device),\n",
        "                                attention_mask=b['attention_mask'].to(device),\n",
        "                                labels=b['labels'].to(device))\n",
        "                # Backward runs outside the autocast context, as recommended for AMP.\n",
        "                scaler.scale(out.loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "\n",
        "                # batch-level KE (also clears the hook buffers)\n",
        "                batch_ke = compute_batch_knowledge_entropy(activations)\n",
        "                for idx, v in batch_ke.items():\n",
        "                    ke_sums[idx]   += v\n",
        "                    ke_counts[idx] += 1\n",
        "\n",
        "            # epoch-level KE: mean of the batch values per layer\n",
        "            epoch_ke = {idx: ke_sums[idx]/ke_counts[idx]\n",
        "                        for idx in ke_sums if ke_counts[idx] > 0}\n",
        "            print(f\"[Epoch {epoch+1}] approx Knowledge Entropy:\", epoch_ke)\n",
        "            last_ke = epoch_ke\n",
        "\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"-> Full Finetune QNLI Acc: {acc:.4f}\")\n",
        "    finally:\n",
        "        remove_hooks(hooks)\n",
        "    return model, last_ke\n",
        "\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores, num_epochs=5):\n",
        "    \"\"\"Stage 2: bypass the FFN of the highest-KE layers, then fine-tune the model.\n",
        "\n",
        "    Args:\n",
        "        model: model returned by stage 1; modified in place.\n",
        "        train_loader: batches with ``input_ids``, ``attention_mask`` and\n",
        "            ``labels``; must support ``len()``.\n",
        "        dev_loader: batches handed to ``evaluate_model`` after every epoch.\n",
        "        device: device the batches are moved to.\n",
        "        ke_scores: mapping ``layer index -> knowledge entropy``.\n",
        "        num_epochs: number of epochs; also sizes the LR schedule.\n",
        "\n",
        "    Returns:\n",
        "        the same ``model`` object after pruning and fine-tuning.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High‐KE) & Finetuning ===\")\n",
        "    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐KE):\", prune_idxs)\n",
        "\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    # Size the schedule from the same epoch count as the loop: a shorter\n",
        "    # schedule drives the LR to zero before training ends, turning the\n",
        "    # remaining epochs into no-ops.\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*num_epochs)\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] QNLI Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0, num_epochs=6):\n",
        "    \"\"\"Stage 3: freeze the backbone and train only LoRA factors plus the classifier.\n",
        "\n",
        "    Args:\n",
        "        model: model returned by stage 2; modified in place.\n",
        "        train_loader: batches with ``input_ids``, ``attention_mask`` and\n",
        "            ``labels``; must support ``len()``.\n",
        "        dev_loader: batches handed to ``evaluate_model`` after every epoch.\n",
        "        device: device the batches are moved to.\n",
        "        r: LoRA rank.\n",
        "        alpha: LoRA scaling numerator (the update is scaled by ``alpha / r``).\n",
        "        num_epochs: number of epochs; also sizes the LR schedule.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters():\n",
        "        p.requires_grad = False\n",
        "    for p in model.classifier.parameters():\n",
        "        p.requires_grad = True\n",
        "    lora_params = [p for l in loras.values() for p in (l.A, l.B)]\n",
        "    for p in lora_params:\n",
        "        p.requires_grad = True\n",
        "\n",
        "    # If gradient checkpointing is still enabled and every backbone weight is\n",
        "    # frozen, no input of a checkpointed layer requires grad, and reentrant\n",
        "    # checkpointing then passes no gradient to the LoRA factors inside it.\n",
        "    # Making the embedding output require grad restores the gradient path.\n",
        "    if getattr(model, \"is_gradient_checkpointing\", False):\n",
        "        model.enable_input_require_grads()\n",
        "\n",
        "    opt   = torch.optim.Adam(\n",
        "        list(model.classifier.parameters()) + lora_params,\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*num_epochs)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            # Only the forward pass and loss belong under autocast.\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "            # Backward runs outside the autocast context, as recommended for AMP.\n",
        "            scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] QNLI Acc: {acc:.4f}\")\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    \"\"\"Seed the RNGs, build the QNLI data loaders and run the three stages.\"\"\"\n",
        "    seed = 42\n",
        "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
        "        seed_fn(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # 2000 shuffled training pairs; the full validation split.\n",
        "    train_ds = load_dataset(\"glue\", \"qnli\", split=\"train\").shuffle(seed).select(range(2000))\n",
        "    dev_ds   = load_dataset(\"glue\", \"qnli\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "\n",
        "    def encode(split):\n",
        "        # Tokenise, drop the raw columns, and rename the target column.\n",
        "        tokenised = split.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                              batched=True,\n",
        "                              remove_columns=[\"question\", \"sentence\", \"idx\"])\n",
        "        return tokenised.rename_column(\"label\", \"labels\")\n",
        "\n",
        "    train = encode(train_ds)\n",
        "    dev = encode(dev_ds)\n",
        "\n",
        "    collator = DataCollatorWithPadding(tokenizer,\n",
        "                                       padding=\"max_length\",\n",
        "                                       max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    model, ke_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)\n",
        "    lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "# Run the whole three-stage pipeline when this cell is executed.\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "M5lgTIkOonfn"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Knowledge Entropy (KE) values across layers for each epoch\n",
        "ke_epochs = {\n",
        "    1: {0: 3.2304, 1: 3.2924, 2: 3.2409, 3: 3.1981, 4: 3.1778, 5: 3.1335, 6: 3.1346, 7: 3.0527, 8: 3.0132, 9: 2.9415, 10: 2.8503, 11: 2.8969},\n",
        "    2: {0: 3.2314, 1: 3.2867, 2: 3.2390, 3: 3.2144, 4: 3.1975, 5: 3.1637, 6: 3.1601, 7: 3.0645, 8: 2.9672, 9: 2.8309, 10: 2.6973, 11: 2.6584},\n",
        "    3: {0: 3.2336, 1: 3.2851, 2: 3.2359, 3: 3.2133, 4: 3.2056, 5: 3.1733, 6: 3.1724, 7: 3.0849, 8: 2.9905, 9: 2.8596, 10: 2.7224, 11: 2.6792},\n",
        "    4: {0: 3.2347, 1: 3.2857, 2: 3.2389, 3: 3.2136, 4: 3.2055, 5: 3.1750, 6: 3.1752, 7: 3.0889, 8: 2.9981, 9: 2.8731, 10: 2.7439, 11: 2.6969},\n",
        "    5: {0: 3.2345, 1: 3.2861, 2: 3.2386, 3: 3.2143, 4: 3.2060, 5: 3.1730, 6: 3.1725, 7: 3.0911, 8: 2.9987, 9: 2.8678, 10: 2.7374, 11: 2.6870},\n",
        "    6: {0: 3.2349, 1: 3.2869, 2: 3.2387, 3: 3.2147, 4: 3.2044, 5: 3.1738, 6: 3.1718, 7: 3.0958, 8: 3.0050, 9: 2.8763, 10: 2.7463, 11: 2.6961}\n",
        "}\n",
        "\n",
        "# Plot\n",
        "plt.figure(figsize=(10, 6))\n",
        "# One curve per epoch; dict keys are 0-based layer ids, drawn as 1-12.\n",
        "for epoch, values in ke_epochs.items():\n",
        "    layers = []\n",
        "    entropy = []\n",
        "    for layer_id, layer_ke in values.items():\n",
        "        layers.append(layer_id + 1)\n",
        "        entropy.append(layer_ke)\n",
        "    plt.plot(layers, entropy, marker='o', label=f\"Epoch {epoch}\")\n",
        "\n",
        "plt.xlabel(\"Layer Index\", fontsize=16)\n",
        "plt.ylabel(\"Knowledge Entropy\", fontsize=16)\n",
        "#plt.title(\"Knowledge Entropy vs. Layer Index\", fontsize=16)\n",
        "plt.legend(fontsize=12)\n",
        "plt.grid(True)\n",
        "plt.xticks(fontsize=14)\n",
        "plt.yticks(fontsize=14)\n",
        "plt.tight_layout()\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "t35oIk4Cdoed"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "Q45PT1_AfhaE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "VCeLVFz-fh4i"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "# Suppress all UserWarning and FutureWarning messages for the rest of the session.\n",
        "# NOTE(review): this is process-wide, so warnings raised by the libraries used\n",
        "# below are hidden as well.\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# ========================================================\n",
        "# 2) Knowledge Entropy / Hook Utilities\n",
        "# ========================================================\n",
        "def register_ke_hooks(model):\n",
        "    \"\"\"Attach a forward hook to ``intermediate.dense`` of every encoder layer.\n",
        "\n",
        "    Each hook stores the detached first positional input of that module under\n",
        "    ``activations[layer index]['pre_act']``.\n",
        "\n",
        "    Returns:\n",
        "        ``(hooks, activations)``: the hook handles and the per-layer buffers.\n",
        "    \"\"\"\n",
        "    # NOTE(review): despite the key name, the hook records the module's input,\n",
        "    # not its output -- confirm that is the tensor the metric should use.\n",
        "    activations = {}\n",
        "    hooks = []\n",
        "\n",
        "    def make_hook(layer_idx):\n",
        "        def capture(module, input, output):\n",
        "            activations[layer_idx]['pre_act'] = input[0].detach()\n",
        "        return capture\n",
        "\n",
        "    for layer_idx, block in enumerate(model.roberta.encoder.layer):\n",
        "        activations[layer_idx] = {'pre_act': None}\n",
        "        handle = block.intermediate.dense.register_forward_hook(make_hook(layer_idx))\n",
        "        hooks.append(handle)\n",
        "    return hooks, activations\n",
        "\n",
        "def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):\n",
        "    \"\"\"Reduce the tensors captured by the forward hooks to one entropy per layer.\n",
        "\n",
        "    Each non-empty buffer is passed through ``activation_fn``, normalised along\n",
        "    ``dim=1`` into a distribution, and the entropy along that axis is averaged\n",
        "    over the remaining positions.  Consumed buffers are reset to ``None``.\n",
        "\n",
        "    Args:\n",
        "        activations: mapping ``layer index -> {'pre_act': tensor or None}``.\n",
        "        activation_fn: non-linearity applied before normalising.\n",
        "        eps: constant added to the denominator and inside the logarithm.\n",
        "\n",
        "    Returns:\n",
        "        dict ``layer index -> float``; layers with an empty buffer are omitted.\n",
        "    \"\"\"\n",
        "    # NOTE(review): if the captured tensor is (batch, seq, hidden), dim=1 is the\n",
        "    # sequence axis rather than the feature axis -- confirm this is intended.\n",
        "    per_layer = {}\n",
        "    for layer_idx, slot in activations.items():\n",
        "        captured = slot['pre_act']\n",
        "        if captured is not None:\n",
        "            gated = activation_fn(captured)\n",
        "            dist = gated / (gated.sum(dim=1, keepdim=True) + eps)\n",
        "            plogp = dist * torch.log(dist + eps)\n",
        "            per_layer[layer_idx] = (-torch.sum(plogp, dim=1).mean()).item()\n",
        "            slot['pre_act'] = None\n",
        "    return per_layer\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    \"\"\"Call ``remove()`` on every hook handle in ``hooks``, in order.\"\"\"\n",
        "    for handle in hooks:\n",
        "        handle.remove()\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune high‐KE)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    \"\"\"Module that ignores ``hidden_states`` and returns ``input_tensor`` as is.\"\"\"\n",
        "\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        # Only the second argument survives; the first is discarded.\n",
        "        passthrough = input_tensor\n",
        "        return passthrough\n",
        "\n",
        "def prune_ke_layers(model, ke_scores, num_prune=4):\n",
        "    \"\"\"Bypass the feed-forward modules of the ``num_prune`` highest-scoring layers.\n",
        "\n",
        "    Layers are ranked by their value in ``ke_scores``, largest first. For each\n",
        "    selected layer of ``model.roberta.encoder.layer``, ``intermediate.dense``\n",
        "    becomes ``nn.Identity()`` and ``output`` becomes ``SkipFF()``. The model is\n",
        "    modified in place.\n",
        "\n",
        "    Returns:\n",
        "        The indices of the pruned layers, highest score first.\n",
        "    \"\"\"\n",
        "    # Stable sort of the keys by score: ties keep their dict order.\n",
        "    ranked = sorted(ke_scores, key=ke_scores.get, reverse=True)\n",
        "    prune_idxs = ranked[:num_prune]\n",
        "    encoder_layers = model.roberta.encoder.layer\n",
        "    for layer_idx in prune_idxs:\n",
        "        target = encoder_layers[layer_idx]\n",
        "        target.intermediate.dense = nn.Identity()\n",
        "        target.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    \"\"\"Low-rank additive update ``W0 + (alpha / r) * (B @ A)`` of a fixed weight.\n",
        "\n",
        "    ``W0`` is kept as a detached buffer copy. ``B`` has shape\n",
        "    ``(W0.shape[0], r)`` and starts as Gaussian noise scaled by 0.01; ``A`` has\n",
        "    shape ``(r, W0.shape[1])`` and starts at zero, so the initial effective\n",
        "    weight equals ``W0``.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        frozen_copy = W0.clone().detach()\n",
        "        self.register_buffer(\"W0\", frozen_copy)\n",
        "        out_dim, in_dim = W0.shape\n",
        "        # B is registered before A, matching the original parameter order.\n",
        "        self.B = nn.Parameter(torch.randn(out_dim, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, in_dim))\n",
        "        self.scaling = alpha / r\n",
        "\n",
        "    def forward(self):\n",
        "        \"\"\"Return the effective weight: base plus the scaled low-rank delta.\"\"\"\n",
        "        low_rank_delta = self.B @ self.A\n",
        "        return self.W0 + self.scaling * low_rank_delta\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    \"\"\"Route each layer's ``output.dense`` through a LoRA-adjusted weight.\n",
        "\n",
        "    For every layer in ``model.roberta.encoder.layer`` whose ``output`` has a\n",
        "    ``dense`` attribute, a ``LoRA`` is built from that dense weight and the\n",
        "    dense module's ``forward`` is overridden on the instance so that it\n",
        "    computes ``F.linear`` with the LoRA weight and the dense module's bias.\n",
        "    Layers whose ``output`` has no ``dense`` attribute are skipped.\n",
        "\n",
        "    Returns:\n",
        "        dict ``{layer_index: LoRA}`` for the layers that were patched.\n",
        "    \"\"\"\n",
        "    def make_forward(block, adapter):\n",
        "        # Bind block/adapter per layer so each patched forward keeps its own.\n",
        "        def patched_forward(x):\n",
        "            return F.linear(x, adapter(), block.output.dense.bias)\n",
        "        return patched_forward\n",
        "\n",
        "    adapters = {}\n",
        "    for layer_idx, block in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(block.output, 'dense'):\n",
        "            continue\n",
        "        base_weight = block.output.dense.weight.data\n",
        "        adapter = LoRA(base_weight, r, alpha).to(base_weight.device)\n",
        "        block.output.dense.forward = make_forward(block, adapter)\n",
        "        adapters[layer_idx] = adapter\n",
        "    return adapters\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    \"\"\"Tokenize the ``question1`` / ``question2`` columns as pairs.\n",
        "\n",
        "    ``tok`` is called with both columns, truncation enabled and padding to\n",
        "    ``max_length``; whatever it returns is handed straight back.\n",
        "    \"\"\"\n",
        "    tokenizer_options = {\n",
        "        'truncation': True,\n",
        "        'padding': 'max_length',\n",
        "        'max_length': max_length,\n",
        "    }\n",
        "    first, second = examples['question1'], examples['question2']\n",
        "    return tok(first, second, **tokenizer_options)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    \"\"\"Return the accuracy of ``model`` over the batches in ``dl``.\n",
        "\n",
        "    The model is switched to eval mode and left there. Each batch must\n",
        "    provide ``input_ids``, ``attention_mask`` and ``labels``; predictions are\n",
        "    the argmax over the last axis of the logits and are scored with the\n",
        "    ``accuracy`` metric loaded through ``evaluate``.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    accuracy = evaluate.load(\"accuracy\")\n",
        "    predictions, references = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch['input_ids'].to(device)\n",
        "            attention_mask = batch['attention_mask'].to(device)\n",
        "            references.extend(batch['labels'].cpu().numpy())\n",
        "            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits\n",
        "            predictions.extend(torch.argmax(logits, -1).cpu().numpy())\n",
        "    result = accuracy.compute(predictions=predictions, references=references)\n",
        "    return result[\"accuracy\"]\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using KE instead of ER)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device, num_epochs=6):\n",
        "    \"\"\"Stage 1: fine-tune roberta-base while averaging a KE score per layer.\n",
        "\n",
        "    A 2-label ``RobertaForSequenceClassification`` is loaded from\n",
        "    ``roberta-base`` and trained with Adam (lr 2e-5), a linear-decay schedule\n",
        "    without warmup, gradient checkpointing and AMP loss scaling. After each\n",
        "    optimizer step the tensors captured by the KE hooks are scored; the\n",
        "    per-layer scores are averaged over the epoch and printed.\n",
        "\n",
        "    Args:\n",
        "        train_loader: batches with ``input_ids``, ``attention_mask``, ``labels``.\n",
        "        dev_loader: batches used for the accuracy evaluation after training.\n",
        "        device: device the model and the batches are moved to.\n",
        "        num_epochs: passes over ``train_loader``; also sizes the schedule.\n",
        "\n",
        "    Returns:\n",
        "        (model, last_ke): the trained model and the averaged KE dict of the\n",
        "        final epoch (``None`` when no epoch was run).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Finetuning & KE Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*num_epochs)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_ke_hooks(model)\n",
        "    last_ke = None\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        ke_sums, ke_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            # Autocast wraps only the forward pass and the loss computation.\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "            # Backward runs outside autocast, as the AMP documentation recommends.\n",
        "            scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            # batch-level KE\n",
        "            batch_ke = compute_batch_knowledge_entropy(activations)\n",
        "            for idx, v in batch_ke.items():\n",
        "                ke_sums[idx]   += v\n",
        "                ke_counts[idx] += 1\n",
        "\n",
        "        # epoch-level KE: mean of the batch scores per layer\n",
        "        epoch_ke = {idx: ke_sums[idx]/ke_counts[idx]\n",
        "                    for idx in ke_sums if ke_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Knowledge Entropy:\", epoch_ke)\n",
        "        last_ke = epoch_ke\n",
        "\n",
        "    acc = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune QQP Acc: {acc:.4f}\")\n",
        "\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_ke\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores, num_epochs=5, num_prune=4):\n",
        "    \"\"\"Stage 2: bypass the highest-KE feed-forward modules, then fine-tune.\n",
        "\n",
        "    Args:\n",
        "        model: model to prune; it is modified in place.\n",
        "        train_loader: batches with ``input_ids``, ``attention_mask``, ``labels``.\n",
        "        dev_loader: batches evaluated after every epoch.\n",
        "        device: device the batches are moved to.\n",
        "        ke_scores: dict ``{layer_index: score}`` used to rank the layers.\n",
        "        num_epochs: fine-tuning epochs; the learning-rate schedule is sized\n",
        "            from the same value.\n",
        "        num_prune: how many layers to bypass.\n",
        "\n",
        "    Returns:\n",
        "        The pruned, fine-tuned model (the same object that was passed in).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High‐KE) & Finetuning ===\")\n",
        "    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=num_prune)\n",
        "    print(\"Pruned layers (highest‐KE):\", prune_idxs)\n",
        "\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    # The schedule must span every step that is actually taken. Sized for\n",
        "    # fewer epochs than the loop runs, the linear decay reaches zero early\n",
        "    # and the remaining epochs train with a learning rate of 0.\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*num_epochs)\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] QQP Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0, num_epochs=6):\n",
        "    \"\"\"Stage 3: train only the LoRA matrices and the classifier head.\n",
        "\n",
        "    LoRA adapters are attached via ``apply_lora_to_all_layers``, every\n",
        "    parameter under ``model.roberta`` is frozen, and Adam (lr 2e-5) with a\n",
        "    linear-decay schedule and AMP loss scaling updates the classifier plus\n",
        "    each adapter's ``A`` and ``B``. Dev accuracy is printed after every epoch.\n",
        "\n",
        "    Args:\n",
        "        model: model to adapt; it is modified in place.\n",
        "        train_loader: batches with ``input_ids``, ``attention_mask``, ``labels``.\n",
        "        dev_loader: batches evaluated after every epoch.\n",
        "        device: device the batches are moved to.\n",
        "        r: LoRA rank.\n",
        "        alpha: LoRA scaling numerator (the delta is scaled by ``alpha / r``).\n",
        "        num_epochs: passes over ``train_loader``; also sizes the schedule.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters():\n",
        "        p.requires_grad = False\n",
        "    for p in model.classifier.parameters():\n",
        "        p.requires_grad = True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad = True\n",
        "        l.B.requires_grad = True\n",
        "\n",
        "    # With the backbone frozen, reentrant gradient checkpointing sees no layer\n",
        "    # input that requires grad, so the LoRA matrices inside the checkpointed\n",
        "    # layers would receive no gradients. Making the embedding output require\n",
        "    # grad keeps them trainable.\n",
        "    if getattr(model, \"is_gradient_checkpointing\", False) and hasattr(model, \"enable_input_require_grads\"):\n",
        "        model.enable_input_require_grads()\n",
        "\n",
        "    opt   = torch.optim.Adam(\n",
        "        list(model.classifier.parameters())\n",
        "        + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*num_epochs)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            # Autocast wraps only the forward pass and the loss computation.\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "            # Backward runs outside autocast, as the AMP documentation recommends.\n",
        "            scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] QQP Acc: {acc:.4f}\")\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    \"\"\"Run the three training stages on a 3000-example GLUE/QQP subset.\n",
        "\n",
        "    Seeds ``random``, NumPy and PyTorch with 42, tokenizes the question pairs\n",
        "    with the roberta-base tokenizer, builds the train/dev loaders, then chains\n",
        "    ``full_finetuning``, ``prune_and_finetuning`` and ``lora_only_finetuning``.\n",
        "    \"\"\"\n",
        "    seed = 42\n",
        "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
        "        seed_fn(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    train_ds = load_dataset(\"glue\", \"qqp\", split=\"train\").shuffle(seed).select(range(3000))\n",
        "    dev_ds   = load_dataset(\"glue\", \"qqp\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "\n",
        "    def encode(split):\n",
        "        # Tokenize in batches, drop the raw columns, rename the target column.\n",
        "        tokenized = split.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                              batched=True,\n",
        "                              remove_columns=[\"question1\",\"question2\",\"idx\"])\n",
        "        return tokenized.rename_column(\"label\",\"labels\")\n",
        "\n",
        "    train = encode(train_ds)\n",
        "    dev   = encode(dev_ds)\n",
        "\n",
        "    collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=64)\n",
        "    loader_options = {'collate_fn': collator}\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, **loader_options)\n",
        "    dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, **loader_options)\n",
        "\n",
        "    # Stage 1 -> Stage 2 -> Stage 3, each working on the previous stage's model.\n",
        "    model, ke_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)\n",
        "    lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "md4igvXakGTo"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# KE data per epoch\n",
        "ke_epochs = {\n",
        "    1: {0: 3.1468, 1: 3.2210, 2: 3.1869, 3: 3.1624, 4: 3.1201, 5: 3.0795, 6: 3.1073, 7: 3.0597, 8: 3.0786, 9: 3.0529, 10: 2.9826, 11: 2.9888},\n",
        "    2: {0: 3.1473, 1: 3.2196, 2: 3.1859, 3: 3.1655, 4: 3.1310, 5: 3.1000, 6: 3.1146, 7: 3.0777, 8: 3.0792, 9: 3.0537, 10: 2.9414, 11: 2.9214},\n",
        "    3: {0: 3.1469, 1: 3.2201, 2: 3.1880, 3: 3.1641, 4: 3.1386, 5: 3.1055, 6: 3.1073, 7: 3.0607, 8: 3.0420, 9: 2.9884, 10: 2.8447, 11: 2.8140},\n",
        "    4: {0: 3.1465, 1: 3.2205, 2: 3.1869, 3: 3.1643, 4: 3.1412, 5: 3.1092, 6: 3.1093, 7: 3.0606, 8: 3.0383, 9: 2.9692, 10: 2.8255, 11: 2.7930},\n",
        "    5: {0: 3.1463, 1: 3.2201, 2: 3.1878, 3: 3.1660, 4: 3.1404, 5: 3.1064, 6: 3.1102, 7: 3.0624, 8: 3.0388, 9: 2.9652, 10: 2.8183, 11: 2.7961},\n",
        "    6: {0: 3.1460, 1: 3.2191, 2: 3.1884, 3: 3.1685, 4: 3.1425, 5: 3.1084, 6: 3.1098, 7: 3.0626, 8: 3.0334, 9: 2.9510, 10: 2.8009, 11: 2.7761},\n",
        "}\n",
        "\n",
        "# Draw one KE-vs-layer curve per epoch on a shared figure\n",
        "plt.figure(figsize=(10, 6))\n",
        "for epoch, ke in ke_epochs.items():\n",
        "    # Stored keys are 0-based; display them as layers 1-12\n",
        "    layers = [idx + 1 for idx in ke]\n",
        "    values = [ke[idx] for idx in ke]\n",
        "    plt.plot(layers, values, marker='o', label=f\"Epoch {epoch}\")\n",
        "\n",
        "#plt.title(\"Knowledge Entropy vs Layers\", fontsize=16)\n",
        "# Axis labels and tick labels share the same font size\n",
        "plt.xlabel(\"Layer Index\", fontsize=16)\n",
        "plt.ylabel(\"Knowledge Entropy\", fontsize=16)\n",
        "plt.xticks(fontsize=16)\n",
        "plt.yticks(fontsize=16)\n",
        "plt.legend(fontsize=12)\n",
        "plt.grid(True)\n",
        "plt.tight_layout()\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "bq_y4H85ghyG"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# ========================================================\n",
        "# 2) Knowledge Entropy / Hook Utilities\n",
        "# ========================================================\n",
        "def register_ke_hooks(model):\n",
        "    \"\"\"Attach forward hooks that record what enters each layer's intermediate dense.\n",
        "\n",
        "    One hook is placed on ``intermediate.dense`` of every layer in\n",
        "    ``model.roberta.encoder.layer``. On each forward call the hook stores the\n",
        "    detached first positional input of that module (its input, not its\n",
        "    output) in the layer's ``'pre_act'`` slot, overwriting the previous one.\n",
        "\n",
        "    Returns:\n",
        "        (handles, store): the hook handles and the dict\n",
        "        ``{layer_index: {'pre_act': tensor or None}}`` the hooks write to.\n",
        "    \"\"\"\n",
        "    encoder_layers = model.roberta.encoder.layer\n",
        "    store = {layer_idx: {'pre_act': None} for layer_idx in range(len(encoder_layers))}\n",
        "\n",
        "    def make_recorder(layer_idx):\n",
        "        # Factory so that every hook closes over its own layer index.\n",
        "        def record_dense_input(module, inputs, output):\n",
        "            store[layer_idx]['pre_act'] = inputs[0].detach()\n",
        "        return record_dense_input\n",
        "\n",
        "    handles = [\n",
        "        block.intermediate.dense.register_forward_hook(make_recorder(layer_idx))\n",
        "        for layer_idx, block in enumerate(encoder_layers)\n",
        "    ]\n",
        "    return handles, store\n",
        "\n",
        "def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8, dim=1):\n",
        "    \"\"\"Turn the hooked tensors into one entropy score per layer.\n",
        "\n",
        "    For every layer whose ``'pre_act'`` slot is filled, the tensor is passed\n",
        "    through ``activation_fn``, normalised to sum to one along ``dim``, and the\n",
        "    entropy along ``dim`` is averaged over the remaining axes. The slot is\n",
        "    reset to ``None`` afterwards so a stale tensor is never scored twice.\n",
        "\n",
        "    Args:\n",
        "        activations: dict ``{layer_index: {'pre_act': tensor or None}}``.\n",
        "        activation_fn: activation applied before normalising.\n",
        "        eps: small constant guarding the division and the logarithm.\n",
        "        dim: axis to normalise and sum over; 1 keeps the original behaviour.\n",
        "            NOTE(review): for (batch, seq, hidden) inputs axis 1 is the token\n",
        "            axis, not the feature axis -- confirm that this is intended.\n",
        "\n",
        "    Returns:\n",
        "        dict ``{layer_index: float}`` for the layers that had a tensor.\n",
        "    \"\"\"\n",
        "    ke_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        pre_act = buf['pre_act']\n",
        "        if pre_act is None:\n",
        "            continue\n",
        "        # Half-precision tensors round eps (1e-8) to zero, which turns\n",
        "        # all-zero slices into 0/0 = NaN; score those in float32 instead.\n",
        "        if pre_act.dtype in (torch.float16, torch.bfloat16):\n",
        "            pre_act = pre_act.float()\n",
        "        act = activation_fn(pre_act)\n",
        "        probs = act / (act.sum(dim=dim, keepdim=True) + eps)\n",
        "        entropy = -torch.sum(probs * torch.log(probs + eps), dim=dim).mean()\n",
        "        ke_scores[idx] = entropy.item()\n",
        "        buf['pre_act'] = None\n",
        "    return ke_scores\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    \"\"\"Call ``remove()`` on every hook handle in ``hooks``, in order.\"\"\"\n",
        "    for handle in hooks:\n",
        "        handle.remove()\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune high‐KE)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    \"\"\"Module that ignores ``hidden_states`` and returns ``input_tensor`` as is.\"\"\"\n",
        "\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        # Only the second argument survives; the first is discarded.\n",
        "        passthrough = input_tensor\n",
        "        return passthrough\n",
        "\n",
        "def prune_ke_layers(model, ke_scores, num_prune=4):\n",
        "    \"\"\"Bypass the feed-forward modules of the ``num_prune`` highest-scoring layers.\n",
        "\n",
        "    Layers are ranked by their value in ``ke_scores``, largest first. For each\n",
        "    selected layer of ``model.roberta.encoder.layer``, ``intermediate.dense``\n",
        "    becomes ``nn.Identity()`` and ``output`` becomes ``SkipFF()``. The model is\n",
        "    modified in place.\n",
        "\n",
        "    Returns:\n",
        "        The indices of the pruned layers, highest score first.\n",
        "    \"\"\"\n",
        "    # Stable sort of the keys by score: ties keep their dict order.\n",
        "    ranked = sorted(ke_scores, key=ke_scores.get, reverse=True)\n",
        "    prune_idxs = ranked[:num_prune]\n",
        "    encoder_layers = model.roberta.encoder.layer\n",
        "    for layer_idx in prune_idxs:\n",
        "        target = encoder_layers[layer_idx]\n",
        "        target.intermediate.dense = nn.Identity()\n",
        "        target.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    \"\"\"Low-rank additive update ``W0 + (alpha / r) * (B @ A)`` of a fixed weight.\n",
        "\n",
        "    ``W0`` is kept as a detached buffer copy. ``B`` has shape\n",
        "    ``(W0.shape[0], r)`` and starts as Gaussian noise scaled by 0.01; ``A`` has\n",
        "    shape ``(r, W0.shape[1])`` and starts at zero, so the initial effective\n",
        "    weight equals ``W0``.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        frozen_copy = W0.clone().detach()\n",
        "        self.register_buffer(\"W0\", frozen_copy)\n",
        "        out_dim, in_dim = W0.shape\n",
        "        # B is registered before A, matching the original parameter order.\n",
        "        self.B = nn.Parameter(torch.randn(out_dim, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, in_dim))\n",
        "        self.scaling = alpha / r\n",
        "\n",
        "    def forward(self):\n",
        "        \"\"\"Return the effective weight: base plus the scaled low-rank delta.\"\"\"\n",
        "        low_rank_delta = self.B @ self.A\n",
        "        return self.W0 + self.scaling * low_rank_delta\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    \"\"\"Route each layer's ``output.dense`` through a LoRA-adjusted weight.\n",
        "\n",
        "    For every layer in ``model.roberta.encoder.layer`` whose ``output`` has a\n",
        "    ``dense`` attribute, a ``LoRA`` is built from that dense weight and the\n",
        "    dense module's ``forward`` is overridden on the instance so that it\n",
        "    computes ``F.linear`` with the LoRA weight and the dense module's bias.\n",
        "    Layers whose ``output`` has no ``dense`` attribute are skipped.\n",
        "\n",
        "    Returns:\n",
        "        dict ``{layer_index: LoRA}`` for the layers that were patched.\n",
        "    \"\"\"\n",
        "    def make_forward(block, adapter):\n",
        "        # Bind block/adapter per layer so each patched forward keeps its own.\n",
        "        def patched_forward(x):\n",
        "            return F.linear(x, adapter(), block.output.dense.bias)\n",
        "        return patched_forward\n",
        "\n",
        "    adapters = {}\n",
        "    for layer_idx, block in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(block.output, 'dense'):\n",
        "            continue\n",
        "        base_weight = block.output.dense.weight.data\n",
        "        adapter = LoRA(base_weight, r, alpha).to(base_weight.device)\n",
        "        block.output.dense.forward = make_forward(block, adapter)\n",
        "        adapters[layer_idx] = adapter\n",
        "    return adapters\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    \"\"\"Tokenize the ``sentence1`` / ``sentence2`` columns as pairs.\n",
        "\n",
        "    ``tok`` is called with both columns, truncation enabled and padding to\n",
        "    ``max_length``; whatever it returns is handed straight back.\n",
        "    \"\"\"\n",
        "    tokenizer_options = {\n",
        "        'truncation': True,\n",
        "        'padding': 'max_length',\n",
        "        'max_length': max_length,\n",
        "    }\n",
        "    first, second = examples['sentence1'], examples['sentence2']\n",
        "    return tok(first, second, **tokenizer_options)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    \"\"\"Return the accuracy of ``model`` over the batches in ``dl``.\n",
        "\n",
        "    The model is switched to eval mode and left there. Each batch must\n",
        "    provide ``input_ids``, ``attention_mask`` and ``labels``; predictions are\n",
        "    the argmax over the last axis of the logits and are scored with the\n",
        "    ``accuracy`` metric loaded through ``evaluate``.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    accuracy = evaluate.load(\"accuracy\")\n",
        "    predictions, references = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch['input_ids'].to(device)\n",
        "            attention_mask = batch['attention_mask'].to(device)\n",
        "            references.extend(batch['labels'].cpu().numpy())\n",
        "            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits\n",
        "            predictions.extend(torch.argmax(logits, -1).cpu().numpy())\n",
        "    result = accuracy.compute(predictions=predictions, references=references)\n",
        "    return result[\"accuracy\"]\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using KE instead of ER)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device, num_epochs=6):\n",
        "    \"\"\"Stage 1: fine-tune roberta-base while averaging a KE score per layer.\n",
        "\n",
        "    A 2-label ``RobertaForSequenceClassification`` is loaded from\n",
        "    ``roberta-base`` and trained with Adam (lr 2e-5), a linear-decay schedule\n",
        "    without warmup, gradient checkpointing and AMP loss scaling. After each\n",
        "    optimizer step the tensors captured by the KE hooks are scored; the\n",
        "    per-layer scores are averaged over the epoch and printed.\n",
        "\n",
        "    Args:\n",
        "        train_loader: batches with ``input_ids``, ``attention_mask``, ``labels``.\n",
        "        dev_loader: batches used for the accuracy evaluation after training.\n",
        "        device: device the model and the batches are moved to.\n",
        "        num_epochs: passes over ``train_loader``; also sizes the schedule.\n",
        "\n",
        "    Returns:\n",
        "        (model, last_ke): the trained model and the averaged KE dict of the\n",
        "        final epoch (``None`` when no epoch was run).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Finetuning & KE Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*num_epochs)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_ke_hooks(model)\n",
        "    last_ke = None\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        ke_sums, ke_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            # Autocast wraps only the forward pass and the loss computation.\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "            # Backward runs outside autocast, as the AMP documentation recommends.\n",
        "            scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            # batch-level KE\n",
        "            batch_ke = compute_batch_knowledge_entropy(activations)\n",
        "            for idx, v in batch_ke.items():\n",
        "                ke_sums[idx]   += v\n",
        "                ke_counts[idx] += 1\n",
        "\n",
        "        # epoch-level KE: mean of the batch scores per layer\n",
        "        epoch_ke = {idx: ke_sums[idx]/ke_counts[idx]\n",
        "                    for idx in ke_sums if ke_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Knowledge Entropy:\", epoch_ke)\n",
        "        last_ke = epoch_ke\n",
        "\n",
        "    acc = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune RTE Acc: {acc:.4f}\")\n",
        "\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_ke\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores, num_epochs=5, num_prune=4):\n",
        "    \"\"\"Stage 2: bypass the highest-KE feed-forward modules, then fine-tune.\n",
        "\n",
        "    Args:\n",
        "        model: model to prune; it is modified in place.\n",
        "        train_loader: batches with ``input_ids``, ``attention_mask``, ``labels``.\n",
        "        dev_loader: batches evaluated after every epoch.\n",
        "        device: device the batches are moved to.\n",
        "        ke_scores: dict ``{layer_index: score}`` used to rank the layers.\n",
        "        num_epochs: fine-tuning epochs; the learning-rate schedule is sized\n",
        "            from the same value.\n",
        "        num_prune: how many layers to bypass.\n",
        "\n",
        "    Returns:\n",
        "        The pruned, fine-tuned model (the same object that was passed in).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High‐KE) & Finetuning ===\")\n",
        "    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=num_prune)\n",
        "    print(\"Pruned layers (highest‐KE):\", prune_idxs)\n",
        "\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    # The schedule must span every step that is actually taken. Sized for\n",
        "    # fewer epochs than the loop runs, the linear decay reaches zero early\n",
        "    # and the remaining epochs train with a learning rate of 0.\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*num_epochs)\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] RTE Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0, num_epochs=6):\n",
        "    \"\"\"Stage 3: train only the LoRA matrices and the classifier head.\n",
        "\n",
        "    LoRA adapters are attached via ``apply_lora_to_all_layers``, every\n",
        "    parameter under ``model.roberta`` is frozen, and Adam (lr 2e-5) with a\n",
        "    linear-decay schedule and AMP loss scaling updates the classifier plus\n",
        "    each adapter's ``A`` and ``B``. Dev accuracy is printed after every epoch.\n",
        "\n",
        "    Args:\n",
        "        model: model to adapt; it is modified in place.\n",
        "        train_loader: batches with ``input_ids``, ``attention_mask``, ``labels``.\n",
        "        dev_loader: batches evaluated after every epoch.\n",
        "        device: device the batches are moved to.\n",
        "        r: LoRA rank.\n",
        "        alpha: LoRA scaling numerator (the delta is scaled by ``alpha / r``).\n",
        "        num_epochs: passes over ``train_loader``; also sizes the schedule.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters():\n",
        "        p.requires_grad = False\n",
        "    for p in model.classifier.parameters():\n",
        "        p.requires_grad = True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad = True\n",
        "        l.B.requires_grad = True\n",
        "\n",
        "    # With the backbone frozen, reentrant gradient checkpointing sees no layer\n",
        "    # input that requires grad, so the LoRA matrices inside the checkpointed\n",
        "    # layers would receive no gradients. Making the embedding output require\n",
        "    # grad keeps them trainable.\n",
        "    if getattr(model, \"is_gradient_checkpointing\", False) and hasattr(model, \"enable_input_require_grads\"):\n",
        "        model.enable_input_require_grads()\n",
        "\n",
        "    opt   = torch.optim.Adam(\n",
        "        list(model.classifier.parameters())\n",
        "        + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*num_epochs)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            # Autocast wraps only the forward pass and the loss computation.\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "            # Backward runs outside autocast, as the AMP documentation recommends.\n",
        "            scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] RTE Acc: {acc:.4f}\")\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    \"\"\"Run the three training stages on the GLUE/RTE training split.\n",
        "\n",
        "    Seeds ``random``, NumPy and PyTorch with 42, tokenizes the sentence pairs\n",
        "    with the roberta-base tokenizer, builds the train/dev loaders, then chains\n",
        "    ``full_finetuning``, ``prune_and_finetuning`` and ``lora_only_finetuning``.\n",
        "    \"\"\"\n",
        "    seed = 42\n",
        "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
        "        seed_fn(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    train_ds = load_dataset(\"glue\", \"rte\", split=\"train\").shuffle(seed)\n",
        "    dev_ds   = load_dataset(\"glue\", \"rte\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "\n",
        "    def encode(split):\n",
        "        # Tokenize in batches, drop the raw columns, rename the target column.\n",
        "        tokenized = split.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                              batched=True,\n",
        "                              remove_columns=[\"sentence1\",\"sentence2\",\"idx\"])\n",
        "        return tokenized.rename_column(\"label\",\"labels\")\n",
        "\n",
        "    train = encode(train_ds)\n",
        "    dev   = encode(dev_ds)\n",
        "\n",
        "    collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=64)\n",
        "    loader_options = {'collate_fn': collator}\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, **loader_options)\n",
        "    dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, **loader_options)\n",
        "\n",
        "    # Stage 1 -> Stage 2 -> Stage 3, each working on the previous stage's model.\n",
        "    model, ke_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)\n",
        "    lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "mT_gG5aqu50v"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# KE data for each epoch\n",
        "ke_data = {\n",
        "    1: {0: 3.234637494270618, 1: 3.2954297256775393, 2: 3.2472059474541592, 3: 3.2122703920572233, 4: 3.1996606771762552, 5: 3.166536705616193, 6: 3.192006678917469, 7: 3.1797323601368146, 8: 3.1996014416217804, 9: 3.198551667042268, 10: 3.1776862190319943, 11: 3.233222233179288},\n",
        "    2: {0: 3.2364144898377933, 1: 3.292912809512554, 2: 3.241034468779197, 3: 3.214976036395782, 4: 3.198220124611488, 5: 3.16093972325325, 6: 3.176389996057902, 7: 3.153661161661148, 8: 3.156559161650829, 9: 3.124322636769368, 10: 3.0624067340141687, 11: 3.0741044863676414},\n",
        "    3: {0: 3.2363225794755497, 1: 3.2945617039998374, 2: 3.2429233766519108, 3: 3.214366410023127, 4: 3.2025416669173117, 5: 3.16937870092881, 6: 3.180643222270868, 7: 3.16009650627772, 8: 3.1594576827990704, 9: 3.1243840715824027, 10: 3.053374951466536, 11: 3.0580264222927585},\n",
        "    4: {0: 3.2360405830236583, 1: 3.2945619431825786, 2: 3.2415615954460244, 3: 3.2135055791109037, 4: 3.2035804826479692, 5: 3.1748380936109104, 6: 3.1875738394566073, 7: 3.1665822023000474, 8: 3.1637751505925107, 9: 3.1298167407512665, 10: 3.052433254627081, 11: 3.038477917512258},\n",
        "    5: {0: 3.2361522951187234, 1: 3.294039017114884, 2: 3.241480989333911, 3: 3.2140066035282917, 4: 3.204068192304709, 5: 3.175252984731625, 6: 3.1887962214457684, 7: 3.1690831688734202, 8: 3.1620407494214864, 9: 3.1216949025789895, 10: 3.035130114127428, 11: 3.0117857922346163},\n",
        "    6: {0: 3.2358735440633235, 1: 3.2946181755799513, 2: 3.2427005072434745, 3: 3.213492139791831, 4: 3.204735367726057, 5: 3.174306563077829, 6: 3.1876255694108133, 7: 3.167302029255109, 8: 3.1590061837281938, 9: 3.115201425093871, 10: 3.0225293246599345, 11: 2.994033750051107}\n",
        "}\n",
        "\n",
        "# Draw one KE-vs-layer curve per epoch on a shared figure\n",
        "plt.figure(figsize=(10, 6))\n",
        "for epoch, layer_data in ke_data.items():\n",
        "    # Stored keys are 0-based; display them as layers 1-12\n",
        "    layers = [idx + 1 for idx in layer_data]\n",
        "    ke_values = [layer_data[idx] for idx in layer_data]\n",
        "    plt.plot(layers, ke_values, marker='o', label=f\"Epoch {epoch}\")\n",
        "\n",
        "# Axis labels are larger than the tick labels\n",
        "plt.xlabel(\"Layer Index\", fontsize=16)\n",
        "plt.ylabel(\"Knowledge Entropy\", fontsize=16)\n",
        "#plt.title(\"Knowledge Entropy vs. Layer\", fontsize=16)\n",
        "plt.xticks(fontsize=14)\n",
        "plt.yticks(fontsize=14)\n",
        "plt.legend(fontsize=12)\n",
        "plt.grid(True)\n",
        "plt.tight_layout()\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "OwQGPSMJhDZ_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# ========================================================\n",
        "# 2) Knowledge Entropy / Hook Utilities\n",
        "# ========================================================\n",
        "def register_ke_hooks(model):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {i: {'pre_act': None} for i in range(len(layers))}\n",
        "    hooks = []\n",
        "    for i, layer in enumerate(layers):\n",
        "        def hook_ffn_input(module, input, output, idx=i):\n",
        "            activations[idx]['pre_act'] = input[0].detach()\n",
        "        hooks.append(layer.intermediate.dense.register_forward_hook(hook_ffn_input))\n",
        "    return hooks, activations\n",
        "\n",
        "def compute_batch_knowledge_entropy(activations, activation_fn=F.relu, eps=1e-8):\n",
        "    ke_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        pre_act = buf['pre_act']\n",
        "        if pre_act is None:\n",
        "            continue\n",
        "        act = activation_fn(pre_act)\n",
        "        probs = act / (act.sum(dim=1, keepdim=True) + eps)\n",
        "        entropy = -torch.sum(probs * torch.log(probs + eps), dim=1).mean()\n",
        "        ke_scores[idx] = entropy.item()\n",
        "        buf['pre_act'] = None\n",
        "    return ke_scores\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune high‐KE)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_ke_layers(model, ke_scores, num_prune=4):\n",
        "    sorted_layers = sorted(ke_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx for idx, _ in sorted_layers[:num_prune]]\n",
        "    for idx in prune_idxs:\n",
        "        model.roberta.encoder.layer[idx].intermediate.dense = nn.Identity()\n",
        "        model.roberta.encoder.layer[idx].output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    loras = {}\n",
        "    for idx, layer in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(layer.output, 'dense'):\n",
        "            continue\n",
        "        W0 = layer.output.dense.weight.data\n",
        "        lora = LoRA(W0, r, alpha).to(W0.device)\n",
        "        def fwd(x, layer=layer, lora=lora):\n",
        "            return F.linear(x, lora(), layer.output.dense.bias)\n",
        "        layer.output.dense.forward = fwd\n",
        "        loras[idx] = lora\n",
        "    return loras\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(\n",
        "        examples['sentence1'],\n",
        "        examples['sentence2'],\n",
        "        truncation=True,\n",
        "        padding='max_length',\n",
        "        max_length=max_length\n",
        "    )\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"pearsonr\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            # Regression head: output is [B, 1]\n",
        "            pred = out.logits.view(-1).cpu().numpy()\n",
        "            preds.extend(pred)\n",
        "    return metric.compute(predictions=preds, references=labs)[\"pearsonr\"]\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using KE instead of ER)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device):\n",
        "    print(\"=== Stage 1: Full Finetuning & KE Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=1  # num_labels=1 for regression!\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_ke_hooks(model)\n",
        "    last_ke = None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        ke_sums, ke_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device, dtype=torch.float).unsqueeze(1))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            # batch‐level KE\n",
        "            batch_ke = compute_batch_knowledge_entropy(activations)\n",
        "            for idx, v in batch_ke.items():\n",
        "                ke_sums[idx]   += v\n",
        "                ke_counts[idx] += 1\n",
        "\n",
        "        # epoch‐level KE\n",
        "        epoch_ke = {idx: ke_sums[idx]/ke_counts[idx]\n",
        "                    for idx in ke_sums if ke_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Knowledge Entropy:\", epoch_ke)\n",
        "        last_ke = epoch_ke\n",
        "\n",
        "    pearson = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune STS-B Pearson: {pearson:.4f}\")\n",
        "\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_ke\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores):\n",
        "    print(\"=== Stage 2: Prune (High‐KE) & Finetuning ===\")\n",
        "    prune_idxs = prune_ke_layers(model, ke_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐KE):\", prune_idxs)\n",
        "\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*3)\n",
        "\n",
        "    for epoch in range(5):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device, dtype=torch.float).unsqueeze(1))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        pearson = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] STS-B Pearson: {pearson:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters(): p.requires_grad=False\n",
        "    for p in model.classifier.parameters(): p.requires_grad=True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad=True\n",
        "        l.B.requires_grad=True\n",
        "\n",
        "    opt   = torch.optim.Adam(\n",
        "        list(model.classifier.parameters())\n",
        "        + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    for epoch in range(6):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device, dtype=torch.float).unsqueeze(1))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        pearson = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] STS-B Pearson: {pearson:.4f}\")\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    train_ds = load_dataset(\"glue\", \"stsb\", split=\"train\").shuffle(seed)\n",
        "    dev_ds   = load_dataset(\"glue\", \"stsb\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True,\n",
        "                         remove_columns=[\"sentence1\",\"sentence2\",\"idx\"])\\\n",
        "                    .rename_column(\"label\",\"labels\")\n",
        "    dev   = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                       batched=True,\n",
        "                       remove_columns=[\"sentence1\",\"sentence2\",\"idx\"])\\\n",
        "                  .rename_column(\"label\",\"labels\")\n",
        "\n",
        "    collator     = DataCollatorWithPadding(tokenizer,\n",
        "                                           padding=\"max_length\",\n",
        "                                           max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True,\n",
        "                              collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev,   batch_size=16, shuffle=False,\n",
        "                              collate_fn=collator)\n",
        "\n",
        "    model, ke_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, ke_scores)\n",
        "    lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "oCixrJGeu6Qn"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Provided KE values\n",
        "ke_epochs = {\n",
        "    1: {0: 3.1216782207117624, 1: 3.205926341373831, 2: 3.1831439436725515, 3: 3.1425496576890164, 4: 3.108628351930451,\n",
        "        5: 3.0245226891879744, 6: 3.0113194900693085, 7: 2.96199728318481, 8: 2.9216685882032496, 9: 2.9011312645234386,\n",
        "        10: 2.7949637216056007, 11: 2.8036840522405337},\n",
        "    2: {0: 3.122385612946724, 1: 3.2053615688780255, 2: 3.1827276502431516, 3: 3.14758318588034, 4: 3.128111186047423,\n",
        "        5: 3.0808556003597083, 6: 3.0710045684527953, 7: 3.012045979334018, 8: 2.9473394376014634, 9: 2.890920449032737,\n",
        "        10: 2.706535710744632, 11: 2.694957201271959},\n",
        "    3: {0: 3.1233545473785824, 1: 3.2060421793119303, 2: 3.180608749721246, 3: 3.149725641096749, 4: 3.1337831318792944,\n",
        "        5: 3.087708413518022, 6: 3.0823230342175267, 7: 3.0231315847563978, 8: 2.9449381267908383, 9: 2.8736855148108513,\n",
        "        10: 2.6799341229636413, 11: 2.6526868296929624},\n",
        "    4: {0: 3.122849219697573, 1: 3.208289318190828, 2: 3.1786043458256836, 3: 3.152634641225547, 4: 3.140414782127518,\n",
        "        5: 3.0983164837695294, 6: 3.087455921942402, 7: 3.0256239744486164, 8: 2.950759364434509, 9: 2.883653296548899,\n",
        "        10: 2.686351388817205, 11: 2.6519999305130875},\n",
        "    5: {0: 3.1224169508969832, 1: 3.2059260482416696, 2: 3.17878786678606, 3: 3.156504187365069, 4: 3.1431414029860196,\n",
        "        5: 3.0961787657545403, 6: 3.0855182536951524, 7: 3.0326405927766844, 8: 2.9571537298353725, 9: 2.8834909653298877,\n",
        "        10: 2.6865991667348252, 11: 2.644667492788259},\n",
        "    6: {0: 3.122367330651953, 1: 3.20639664672512, 2: 3.179138040343644, 3: 3.157353766604491, 4: 3.1413243549755454,\n",
        "        5: 3.0952708057965954, 6: 3.0864931734612986, 7: 3.033865492267635, 8: 2.9587372055637324, 9: 2.8831359905725726,\n",
        "        10: 2.6869357043413524, 11: 2.6412745341140473}\n",
        "}\n",
        "\n",
        "# Plot\n",
        "plt.figure(figsize=(10, 6))\n",
        "for epoch, ke in ke_epochs.items():\n",
        "    layers = [l + 1 for l in ke.keys()]  # shift layer indices to 1–12\n",
        "    values = list(ke.values())\n",
        "    plt.plot(layers, values, marker='o', label=f'Epoch {epoch}')\n",
        "\n",
        "plt.xlabel('Layer Index', fontsize=16)\n",
        "plt.ylabel('Knowledge Entropy', fontsize=16)\n",
        "#plt.title('Knowledge Entropy vs Layer (fz=16)', fontsize=16)\n",
        "plt.xticks(fontsize=14)\n",
        "plt.yticks(fontsize=14)\n",
        "plt.legend(fontsize=12)\n",
        "plt.grid(True)\n",
        "plt.tight_layout()\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "xiNpM5vKvAxC"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "Etp_DiKHvBTa"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "Caj-xs3GQVEE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "Dq83WlRPQVf8"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}