{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6ENR6EQ2b8R8"
      },
      "outputs": [],
      "source": [
        "# Colab setup: replace the preinstalled `datasets` with a pinned version and add `evaluate`.\n",
        "# `%pip` (rather than `!pip`) installs into the environment of the running kernel.\n",
        "%pip uninstall -y datasets\n",
        "%pip install datasets==2.18.0\n",
        "%pip install evaluate\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n"
      ],
      "metadata": {
        "id": "3BycbH80QnGO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ========================================================\n",
        "# 1) Imports, warnings, reproducibility\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import math\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# ========================================================\n",
        "# 2) Entropy Rate hooks for RoBERTa encoder layers\n",
        "# ========================================================\n",
        "def register_er_hooks(model):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {\n",
        "        i: {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}\n",
        "        for i in range(len(layers)-1)\n",
        "    }\n",
        "    hooks = []\n",
        "    for i in range(len(layers)-1):\n",
        "        def hook_x(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_X'] = out.detach()\n",
        "        def hook_y(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_Y'] = out.detach()\n",
        "        hooks.append(layers[i].output.dense.register_forward_hook(hook_x))\n",
        "        hooks.append(layers[i+1].output.dense.register_forward_hook(hook_y))\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "def compute_batch_entropy(activations, sigma2=1.0):\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "        # Only proceed if all present, and batch sizes match!\n",
        "        valid = (X_prev is not None and Y_prev is not None and\n",
        "                 X_curr is not None and Y_curr is not None and\n",
        "                 X_prev.shape[0] == X_curr.shape[0] and\n",
        "                 Y_prev.shape[0] == Y_curr.shape[0])\n",
        "        if not valid:\n",
        "            # Roll forward to avoid getting stuck\n",
        "            buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "            buf['curr_X'], buf['curr_Y'] = None, None\n",
        "            continue\n",
        "        B = X_curr.size(0)\n",
        "        dX = (X_curr - X_prev).view(B, -1)\n",
        "        dY = (Y_curr - Y_prev).view(B, -1)\n",
        "        if B < 2:\n",
        "            er = 0.0\n",
        "        else:\n",
        "            cos_squares = []\n",
        "            for i in range(1, B):\n",
        "                c2 = F.cosine_similarity(dY[i].unsqueeze(0), dX[i].unsqueeze(0), dim=1, eps=1e-8)**2\n",
        "                cos_squares.append(c2.item())\n",
        "            er = sum(cos_squares) / (2 * (B - 1))\n",
        "        er_scores[idx] = er\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "    return er_scores\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities (replace FFN with identity)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    \"\"\"Disable the feed-forward block of the layers that follow the highest-scoring pairs.\n",
        "\n",
        "    ``er_scores`` maps a pair index ``i`` to its score.  The ``num_prune`` highest\n",
        "    scores are selected and layer ``i + 1`` of each is pruned: its intermediate dense\n",
        "    becomes ``nn.Identity`` and its output block becomes ``SkipFF``.\n",
        "\n",
        "    Returns:\n",
        "        list of pruned layer indices, highest score first.\n",
        "    \"\"\"\n",
        "    encoder_layers = model.roberta.encoder.layer\n",
        "    ranked = sorted(er_scores.items(), key=lambda item: item[1], reverse=True)\n",
        "    prune_idxs = []\n",
        "    for pair_idx, _score in ranked[:num_prune]:\n",
        "        target = pair_idx + 1\n",
        "        if target < len(encoder_layers):\n",
        "            prune_idxs.append(target)\n",
        "    for target in prune_idxs:\n",
        "        block = encoder_layers[target]\n",
        "        block.intermediate.dense = nn.Identity()\n",
        "        block.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "# ========================================================\n",
        "# 4) Data Processing and Evaluation\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=128):\n",
        "    return tok(examples['premise'],\n",
        "               examples['hypothesis'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    \"\"\"Return the classification accuracy of ``model`` over the batches in ``dl``.\n",
        "\n",
        "    The model is put in eval mode (and left there), run without gradients on\n",
        "    ``input_ids``/``attention_mask``, and the argmax of its logits is compared with\n",
        "    each batch's ``labels`` through the ``evaluate`` accuracy metric.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    accuracy_metric = evaluate.load(\"accuracy\")\n",
        "    predictions = []\n",
        "    references = []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            references.extend(batch['labels'].cpu().numpy())\n",
        "            logits = model(\n",
        "                input_ids=batch['input_ids'].to(device),\n",
        "                attention_mask=batch['attention_mask'].to(device),\n",
        "            ).logits\n",
        "            predictions.extend(torch.argmax(logits, -1).cpu().numpy())\n",
        "    scores = accuracy_metric.compute(predictions=predictions, references=references)\n",
        "    return scores[\"accuracy\"]\n",
        "\n",
        "# ========================================================\n",
        "# 5) Training & Pruning\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device, num_epochs=6):\n",
        "    \"\"\"Stage 1: fine-tune roberta-base (3 labels) while estimating per-layer-pair ER scores.\n",
        "\n",
        "    Args:\n",
        "        train_loader: iterable of batches with ``input_ids``, ``attention_mask``, ``labels``.\n",
        "        dev_loader: batches used for the accuracy report after every epoch.\n",
        "        device: device the model and the batches are moved to.\n",
        "        num_epochs: number of training epochs; the LR schedule spans exactly this many.\n",
        "\n",
        "    Returns:\n",
        "        (model, last_er): the fine-tuned model and the averaged ER scores of the final epoch.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Finetuning & ER Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\"roberta-base\", num_labels=3).to(device)\n",
        "    opt = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    # The schedule has to cover every epoch that is actually run: with a shorter horizon\n",
        "    # the linear decay reaches zero early and the remaining epochs train with lr == 0.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    scaler = GradScaler()\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "    try:\n",
        "        for epoch in range(num_epochs):\n",
        "            er_sums, er_counts = defaultdict(float), defaultdict(int)\n",
        "            model.train()\n",
        "            for b in train_loader:\n",
        "                opt.zero_grad()\n",
        "                inputs = {k: b[k].to(device) for k in ('input_ids', 'attention_mask', 'labels')}\n",
        "                # Only the forward pass runs under autocast; backward is issued outside it.\n",
        "                with autocast():\n",
        "                    out = model(**inputs)\n",
        "                scaler.scale(out.loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "                # The hooks filled the buffers during the forward pass above.\n",
        "                batch_er = compute_batch_entropy(activations)\n",
        "                for idx, v in batch_er.items():\n",
        "                    er_sums[idx] += v\n",
        "                    er_counts[idx] += 1\n",
        "            epoch_er = {idx: er_sums[idx]/er_counts[idx] for idx in er_sums if er_counts[idx] > 0}\n",
        "            print(f\"[Epoch {epoch+1}] approx Entropy Rate:\", epoch_er)\n",
        "            last_er = epoch_er\n",
        "            acc = evaluate_model(model, dev_loader, device)\n",
        "            print(f\"-> Epoch {epoch+1} e-SNLI Dev Acc: {acc:.4f}\")\n",
        "    finally:\n",
        "        # Detach the hooks even if training is interrupted.\n",
        "        remove_hooks(hooks)\n",
        "    return model, last_er\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores):\n",
        "    \"\"\"Stage 2: prune the four highest-ER layers, then fine-tune the model for 3 epochs.\n",
        "\n",
        "    Prints the pruned layer indices and the dev accuracy after each epoch and\n",
        "    returns the model, which is modified in place.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High‐ER) & Finetuning ===\")\n",
        "    pruned = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐ER):\", pruned)\n",
        "    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    total_steps = len(train_loader) * 3\n",
        "    scheduler = get_linear_schedule_with_warmup(optimizer, 0, total_steps)\n",
        "    for epoch in range(3):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            optimizer.zero_grad()\n",
        "            inputs = {key: batch[key].to(device) for key in ('input_ids', 'attention_mask', 'labels')}\n",
        "            loss = model(**inputs).loss\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "            scheduler.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] e-SNLI Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# ========================================================\n",
        "# 6) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    \"\"\"Seed the RNGs, build the e-SNLI loaders from the Drive JSON files and run both stages.\"\"\"\n",
        "    seed = 42\n",
        "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
        "        seed_fn(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    # e-SNLI: train/validation paths must be set to your actual file locations\n",
        "    data_files = {\n",
        "        \"train\": \"/content/drive/MyDrive/NLP_datasets/esnli/esnli_train.json\",\n",
        "        \"validation\": \"/content/drive/MyDrive/NLP_datasets/esnli/esnli_valid.json\",\n",
        "        \"test\": \"/content/drive/MyDrive/NLP_datasets/esnli/esnli_test.json\"\n",
        "    }\n",
        "    raw_datasets = load_dataset(\"json\", data_files=data_files)\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    text_columns = [\"premise\", \"hypothesis\", \"explanation_1\", \"explanation_2\", \"explanation_3\"]\n",
        "\n",
        "    def encode(split_name, n_examples):\n",
        "        # Shuffle, subsample, tokenize, then expose the target under the name \"labels\".\n",
        "        subset = raw_datasets[split_name].shuffle(seed=seed).select(range(n_examples))\n",
        "        encoded = subset.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                             batched=True, remove_columns=text_columns)\n",
        "        return encoded.rename_column(\"label\", \"labels\")\n",
        "\n",
        "    train = encode(\"train\", 5000)\n",
        "    dev = encode(\"validation\", 1000)\n",
        "    collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=128)\n",
        "    loader_kwargs = dict(batch_size=16, collate_fn=collator)\n",
        "    train_loader = DataLoader(train, shuffle=True, **loader_kwargs)\n",
        "    dev_loader = DataLoader(dev, shuffle=False, **loader_kwargs)\n",
        "    model, er_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, er_scores)\n",
        "\n",
        "\n",
        "# Run the full pipeline when the cell is executed.\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "TrHStP01R_ml"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# 1. Entropy Rate / Hook Utilities\n",
        "def register_er_hooks(model):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {\n",
        "        i: {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}\n",
        "        for i in range(len(layers)-1)\n",
        "    }\n",
        "    hooks = []\n",
        "    for i in range(len(layers)-1):\n",
        "        def hook_x(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_X'] = out.detach()\n",
        "        def hook_y(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_Y'] = out.detach()\n",
        "\n",
        "        hooks.append(layers[i].output.dense.register_forward_hook(hook_x))\n",
        "        hooks.append(layers[i+1].output.dense.register_forward_hook(hook_y))\n",
        "\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "def compute_batch_entropy(activations, sigma2=1.0):\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "        # Only proceed if all present, and batch sizes match!\n",
        "        valid = (X_prev is not None and Y_prev is not None and\n",
        "                 X_curr is not None and Y_curr is not None and\n",
        "                 X_prev.shape[0] == X_curr.shape[0] and\n",
        "                 Y_prev.shape[0] == Y_curr.shape[0])\n",
        "        if not valid:\n",
        "            # Roll forward to avoid getting stuck\n",
        "            buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "            buf['curr_X'], buf['curr_Y'] = None, None\n",
        "            continue\n",
        "        B = X_curr.size(0)\n",
        "        dX = (X_curr - X_prev).view(B, -1)\n",
        "        dY = (Y_curr - Y_prev).view(B, -1)\n",
        "        if B < 2:\n",
        "            er = 0.0\n",
        "        else:\n",
        "            cos_squares = []\n",
        "            for i in range(1, B):\n",
        "                c2 = F.cosine_similarity(dY[i].unsqueeze(0), dX[i].unsqueeze(0), dim=1, eps=1e-8)**2\n",
        "                cos_squares.append(c2.item())\n",
        "            er = sum(cos_squares) / (2 * (B - 1))\n",
        "        er_scores[idx] = er\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "    return er_scores\n",
        "\n",
        "# 2. Pruning Utilities with SkipFF (prune high-ER)\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    \"\"\"Disable the feed-forward block of the layers that follow the highest-scoring pairs.\n",
        "\n",
        "    ``er_scores`` maps a pair index ``i`` to its score.  The ``num_prune`` highest\n",
        "    scores are selected and layer ``i + 1`` of each is pruned: its intermediate dense\n",
        "    becomes ``nn.Identity`` and its output block becomes ``SkipFF``.\n",
        "\n",
        "    Returns:\n",
        "        list of pruned layer indices, highest score first.\n",
        "    \"\"\"\n",
        "    encoder_layers = model.roberta.encoder.layer\n",
        "    ranked = sorted(er_scores.items(), key=lambda item: item[1], reverse=True)\n",
        "    prune_idxs = []\n",
        "    for pair_idx, _score in ranked[:num_prune]:\n",
        "        target = pair_idx + 1\n",
        "        if target < len(encoder_layers):\n",
        "            prune_idxs.append(target)\n",
        "    for target in prune_idxs:\n",
        "        block = encoder_layers[target]\n",
        "        block.intermediate.dense = nn.Identity()\n",
        "        block.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "# 3. Data Preprocessing/Eval\n",
        "def preprocess_function(examples, tok, max_length=128):\n",
        "    return tok(examples['premise'],\n",
        "               examples['hypothesis'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    \"\"\"Return the classification accuracy of ``model`` over the batches in ``dl``.\n",
        "\n",
        "    The model is put in eval mode (and left there), run without gradients on\n",
        "    ``input_ids``/``attention_mask``, and the argmax of its logits is compared with\n",
        "    each batch's ``labels`` through the ``evaluate`` accuracy metric.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    accuracy_metric = evaluate.load(\"accuracy\")\n",
        "    predictions = []\n",
        "    references = []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            references.extend(batch['labels'].cpu().numpy())\n",
        "            logits = model(\n",
        "                input_ids=batch['input_ids'].to(device),\n",
        "                attention_mask=batch['attention_mask'].to(device),\n",
        "            ).logits\n",
        "            predictions.extend(torch.argmax(logits, -1).cpu().numpy())\n",
        "    scores = accuracy_metric.compute(predictions=predictions, references=references)\n",
        "    return scores[\"accuracy\"]\n",
        "\n",
        "# 4. Training/Evaluation Loops\n",
        "def full_finetuning(train_loader, dev_loader, device, num_epochs=6):\n",
        "    \"\"\"Stage 1: fine-tune roberta-base (3 labels) while estimating per-layer-pair ER scores.\n",
        "\n",
        "    Gradient checkpointing is enabled on the model.  Dev accuracy is reported once,\n",
        "    after the last epoch.\n",
        "\n",
        "    Args:\n",
        "        train_loader: iterable of batches with ``input_ids``, ``attention_mask``, ``labels``.\n",
        "        dev_loader: batches used for the final accuracy report.\n",
        "        device: device the model and the batches are moved to.\n",
        "        num_epochs: number of training epochs; the LR schedule spans exactly this many.\n",
        "\n",
        "    Returns:\n",
        "        (model, last_er): the fine-tuned model and the averaged ER scores of the final epoch.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Finetuning & ER Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\"roberta-base\", num_labels=3).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    # The schedule has to cover every epoch that is actually run: with a shorter horizon\n",
        "    # the linear decay reaches zero early and the remaining epochs train with lr == 0.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    scaler = GradScaler()\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "\n",
        "    try:\n",
        "        for epoch in range(num_epochs):\n",
        "            er_sums, er_counts = defaultdict(float), defaultdict(int)\n",
        "            model.train()\n",
        "            for b in train_loader:\n",
        "                opt.zero_grad()\n",
        "                inputs = {k: b[k].to(device) for k in ('input_ids', 'attention_mask', 'labels')}\n",
        "                # Only the forward pass runs under autocast; backward is issued outside it.\n",
        "                with autocast():\n",
        "                    out = model(**inputs)\n",
        "                scaler.scale(out.loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "                # The hooks filled the buffers during the forward pass above.\n",
        "                batch_er = compute_batch_entropy(activations)\n",
        "                for idx, v in batch_er.items():\n",
        "                    er_sums[idx] += v\n",
        "                    er_counts[idx] += 1\n",
        "            epoch_er = {idx: er_sums[idx]/er_counts[idx] for idx in er_sums if er_counts[idx] > 0}\n",
        "            print(f\"[Epoch {epoch+1}] approx Entropy Rate:\", epoch_er)\n",
        "            last_er = epoch_er\n",
        "\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"-> Full Finetune Acc: {acc:.4f}\")\n",
        "    finally:\n",
        "        # Detach the hooks even if training or evaluation is interrupted.\n",
        "        remove_hooks(hooks)\n",
        "    return model, last_er\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores):\n",
        "    \"\"\"Stage 2: prune the four highest-ER layers, then fine-tune the model for 3 epochs.\n",
        "\n",
        "    Prints the pruned layer indices and the dev accuracy after each epoch and\n",
        "    returns the model, which is modified in place.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High‐ER) & Finetuning ===\")\n",
        "    pruned = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐ER):\", pruned)\n",
        "    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    total_steps = len(train_loader) * 3\n",
        "    scheduler = get_linear_schedule_with_warmup(optimizer, 0, total_steps)\n",
        "    for epoch in range(3):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            optimizer.zero_grad()\n",
        "            inputs = {key: batch[key].to(device) for key in ('input_ids', 'attention_mask', 'labels')}\n",
        "            loss = model(**inputs).loss\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "            scheduler.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# 5. Main Entrypoint (example: e-SNLI or MNLI format)\n",
        "def main():\n",
        "    \"\"\"Seed the RNGs, build the e-SNLI loaders from the Drive JSON files and run both stages.\"\"\"\n",
        "    # Seed every RNG in use so data shuffling and classifier-head init are repeatable.\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "\n",
        "    data_files = {\n",
        "        \"train\": \"/content/drive/MyDrive/NLP_datasets/esnli/esnli_train.json\",\n",
        "        \"validation\": \"/content/drive/MyDrive/NLP_datasets/esnli/esnli_valid.json\",\n",
        "        \"test\": \"/content/drive/MyDrive/NLP_datasets/esnli/esnli_test.json\"\n",
        "    }\n",
        "    raw_datasets = load_dataset(\"json\", data_files=data_files)\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    # For e-SNLI: 0=entailment, 1=neutral, 2=contradiction\n",
        "\n",
        "    def encode(ds):\n",
        "        # Drop every raw column except \"label\": removing it in map() would make the\n",
        "        # rename to \"labels\" fail because the column would no longer exist.\n",
        "        drop_columns = [name for name in ds.column_names if name != \"label\"]\n",
        "        encoded = ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True, remove_columns=drop_columns)\n",
        "        return encoded.rename_column(\"label\", \"labels\")\n",
        "\n",
        "    train_ds = raw_datasets[\"train\"].shuffle(seed=seed).select(range(5000))\n",
        "    dev_ds = raw_datasets[\"validation\"].shuffle(seed=seed).select(range(1000))\n",
        "    train = encode(train_ds)\n",
        "    dev = encode(dev_ds)\n",
        "\n",
        "    collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=128)\n",
        "    train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    model, er_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, er_scores)\n",
        "\n",
        "\n",
        "# Run the full pipeline when the cell is executed.\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "GLLXVWqXQngP"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "import math\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# ---- 1. Load CQA Data ----\n",
        "# Train/test JSON files of the CQA task, read from the mounted Drive.\n",
        "data_files = dict(\n",
        "    train=\"/content/drive/MyDrive/NLP_datasets/cqa/cqa_train.json\",\n",
        "    test=\"/content/drive/MyDrive/NLP_datasets/cqa/cqa_test.json\",\n",
        ")\n",
        "dataset = load_dataset(\"json\", data_files=data_files)\n",
        "\n",
        "# ---- 2. Preprocessing for RoBERTa ----\n",
        "def preprocess_cqa_roberta(batch, tokenizer, max_length=128):\n",
        "    # For CQA: input = question + choices, target = answer (classification index)\n",
        "    sentences = [f\"question: {q} choices: {', '.join(choices)}\"\n",
        "                 for q, choices in zip(batch['question'], batch['choices'])]\n",
        "    # Convert answer (text) to integer label (0, 1, 2, 3)\n",
        "    answer_map = {str(i): i for i in range(4)}\n",
        "    labels = [answer_map.get(str(a), 0) for a in batch[\"answer\"]]\n",
        "    toks = tokenizer(\n",
        "        sentences, padding=\"max_length\", truncation=True, max_length=max_length, return_tensors=\"pt\"\n",
        "    )\n",
        "    toks[\"labels\"] = torch.tensor(labels)\n",
        "    return toks\n",
        "\n",
        "tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "\n",
        "\n",
        "def _encode_split(split_name):\n",
        "    \"\"\"Tokenize one split of ``dataset`` and drop all of its original columns.\"\"\"\n",
        "    split = dataset[split_name]\n",
        "    return split.map(lambda ex: preprocess_cqa_roberta(ex, tokenizer),\n",
        "                     batched=True, remove_columns=split.column_names)\n",
        "\n",
        "\n",
        "# The \"test\" split is used as the dev set.\n",
        "train = _encode_split(\"train\")\n",
        "dev = _encode_split(\"test\")\n",
        "\n",
        "collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=128)\n",
        "loader_kwargs = dict(batch_size=16, collate_fn=collator)\n",
        "train_loader = DataLoader(train, shuffle=True, **loader_kwargs)\n",
        "dev_loader = DataLoader(dev, shuffle=False, **loader_kwargs)\n",
        "\n",
        "# ---- 3. ER Hook Utilities for RoBERTa Encoder ----\n",
        "def register_er_hooks(model):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {\n",
        "        i: {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}\n",
        "        for i in range(len(layers)-1)\n",
        "    }\n",
        "    hooks = []\n",
        "    for i in range(len(layers)-1):\n",
        "        def hook_x(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_X'] = out.detach()\n",
        "        def hook_y(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_Y'] = out.detach()\n",
        "        hooks.append(layers[i].output.dense.register_forward_hook(hook_x))\n",
        "        hooks.append(layers[i+1].output.dense.register_forward_hook(hook_y))\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "def compute_batch_entropy(activations):\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "        valid = (X_prev is not None and Y_prev is not None and\n",
        "                 X_curr is not None and Y_curr is not None and\n",
        "                 X_prev.shape[0] == X_curr.shape[0] and\n",
        "                 Y_prev.shape[0] == Y_curr.shape[0])\n",
        "        if not valid:\n",
        "            buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "            buf['curr_X'], buf['curr_Y'] = None, None\n",
        "            continue\n",
        "        B = X_curr.size(0)\n",
        "        dX = (X_curr - X_prev).view(B, -1)\n",
        "        dY = (Y_curr - Y_prev).view(B, -1)\n",
        "        if B < 2:\n",
        "            er = 0.0\n",
        "        else:\n",
        "            cos_squares = [\n",
        "                F.cosine_similarity(dY[i].unsqueeze(0), dX[i].unsqueeze(0), dim=1, eps=1e-8).item() ** 2\n",
        "                for i in range(1, B)\n",
        "            ]\n",
        "            er = sum(cos_squares) / (2 * (B - 1))\n",
        "        er_scores[idx] = er\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "    return er_scores\n",
        "\n",
        "# ---- 4. Pruning Utilities for Encoder ----\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    \"\"\"Disable the feed-forward block of the layers that follow the highest-scoring pairs.\n",
        "\n",
        "    ``er_scores`` maps a pair index ``i`` to its score.  The ``num_prune`` highest\n",
        "    scores are selected and layer ``i + 1`` of each is pruned: its intermediate dense\n",
        "    becomes ``nn.Identity`` and its output block becomes ``SkipFF``.\n",
        "\n",
        "    Returns:\n",
        "        list of pruned layer indices, highest score first.\n",
        "    \"\"\"\n",
        "    encoder_layers = model.roberta.encoder.layer\n",
        "    ranked = sorted(er_scores.items(), key=lambda item: item[1], reverse=True)\n",
        "    prune_idxs = []\n",
        "    for pair_idx, _score in ranked[:num_prune]:\n",
        "        target = pair_idx + 1\n",
        "        if target < len(encoder_layers):\n",
        "            prune_idxs.append(target)\n",
        "    for target in prune_idxs:\n",
        "        block = encoder_layers[target]\n",
        "        block.intermediate.dense = nn.Identity()\n",
        "        block.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "# ---- 5. Training/Evaluation Pipeline ----\n",
        "def compute_accuracy(preds, labels):\n",
        "    return (np.array(preds) == np.array(labels)).mean()\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    \"\"\"Return the classification accuracy of ``model`` over the batches in ``dl``.\n",
        "\n",
        "    The model is put in eval mode (and left there), run without gradients on\n",
        "    ``input_ids``/``attention_mask``, and the argmax of its logits is compared with\n",
        "    each batch's ``labels`` via ``compute_accuracy``.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    predictions = []\n",
        "    references = []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            references.extend(batch['labels'].cpu().numpy())\n",
        "            logits = model(\n",
        "                input_ids=batch['input_ids'].to(device),\n",
        "                attention_mask=batch['attention_mask'].to(device),\n",
        "            ).logits\n",
        "            predictions.extend(torch.argmax(logits, -1).cpu().numpy())\n",
        "    return compute_accuracy(predictions, references)\n",
        "\n",
        "def full_finetuning(train_loader, dev_loader, device, num_epochs=6):\n",
        "    \"\"\"Stage 1: fine-tune roberta-base (4 labels) while estimating per-layer-pair ER scores.\n",
        "\n",
        "    Args:\n",
        "        train_loader: iterable of batches with ``input_ids``, ``attention_mask``, ``labels``.\n",
        "        dev_loader: batches used for the accuracy report after every epoch.\n",
        "        device: device the model and the batches are moved to.\n",
        "        num_epochs: number of training epochs; the LR schedule spans exactly this many.\n",
        "\n",
        "    Returns:\n",
        "        (model, last_er): the fine-tuned model and the averaged ER scores of the final epoch.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Fine-Tuning & ER Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\"roberta-base\", num_labels=4).to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=2e-5)\n",
        "    # The schedule has to cover every epoch that is actually run: with a shorter horizon\n",
        "    # the linear decay reaches zero early and the remaining epochs train with lr == 0.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    scaler = GradScaler()\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "\n",
        "    try:\n",
        "        for epoch in range(num_epochs):\n",
        "            er_sums, er_counts = defaultdict(float), defaultdict(int)\n",
        "            model.train()\n",
        "            for batch in train_loader:\n",
        "                opt.zero_grad()\n",
        "                inputs = {k: batch[k].to(device) for k in ('input_ids', 'attention_mask', 'labels')}\n",
        "                # Only the forward pass runs under autocast; backward is issued outside it.\n",
        "                with autocast():\n",
        "                    out = model(**inputs)\n",
        "                scaler.scale(out.loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "                # The hooks filled the buffers during the forward pass above.\n",
        "                batch_er = compute_batch_entropy(activations)\n",
        "                for idx, v in batch_er.items():\n",
        "                    er_sums[idx] += v\n",
        "                    er_counts[idx] += 1\n",
        "            epoch_er = {idx: er_sums[idx]/er_counts[idx] for idx in er_sums if er_counts[idx] > 0}\n",
        "            print(f\"[Epoch {epoch+1}] approx Entropy Rate:\", epoch_er)\n",
        "            acc = evaluate_model(model, dev_loader, device)\n",
        "            print(f\"-> Epoch {epoch+1} CQA Dev Acc: {acc:.4f}\")\n",
        "            last_er = epoch_er\n",
        "    finally:\n",
        "        # Detach the hooks even if training is interrupted.\n",
        "        remove_hooks(hooks)\n",
        "    return model, last_er\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores):\n",
        "    \"\"\"Prune the layers ranked highest by `er_scores`, then fine-tune for three epochs.\n",
        "\n",
        "    The model is modified in place by prune_er_layers and returned after training;\n",
        "    dev accuracy is printed after every epoch.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High-ER) & Fine-tuning ===\")\n",
        "    pruned = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest-ER):\", pruned)\n",
        "    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)\n",
        "    scheduler = get_linear_schedule_with_warmup(optimizer, 0, len(train_loader)*3)\n",
        "    batch_keys = ('input_ids', 'attention_mask', 'labels')\n",
        "    for epoch_no in range(1, 4):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            optimizer.zero_grad()\n",
        "            inputs = {key: batch[key].to(device) for key in batch_keys}\n",
        "            loss = model(**inputs).loss\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "            scheduler.step()\n",
        "        dev_acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch_no}] CQA Dev Acc: {dev_acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# ---- 6. Entrypoint ----\n",
        "def main():\n",
        "    \"\"\"Run stage 1 (fine-tuning + ER estimation) followed by stage 2 (prune + fine-tune).\"\"\"\n",
        "    use_cuda = torch.cuda.is_available()\n",
        "    device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
        "    finetuned, er_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    prune_and_finetuning(finetuned, train_loader, dev_loader, device, er_scores)\n",
        "\n",
        "# Runs when the cell is executed directly.\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "gCQ6RnNZT3cf"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict\n",
        "import numpy as np\n",
        "import random\n",
        "import math\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# --- 1. Load ANLI1 Dataset ---\n",
        "# Split name -> JSON file on the mounted Drive.\n",
        "data_files = dict(\n",
        "    train=\"/content/drive/MyDrive/NLP_datasets/anli1/anli1_train.json\",\n",
        "    validation=\"/content/drive/MyDrive/NLP_datasets/anli1/anli1_valid.json\",\n",
        "    test=\"/content/drive/MyDrive/NLP_datasets/anli1/anli1_test.json\",\n",
        ")\n",
        "dataset = load_dataset(\"json\", data_files=data_files)\n",
        "\n",
        "# --- 2. Preprocessing Function for RoBERTa ---\n",
        "def preprocess_anli_roberta(batch, tokenizer, max_length=128):\n",
        "    \"\"\"Build prompt strings and integer labels for a batch, then tokenize them.\n",
        "\n",
        "    Args:\n",
        "        batch: mapping with parallel 'premise', 'hypothesis' and 'label' sequences.\n",
        "        tokenizer: callable applied to the list of prompt strings.\n",
        "        max_length: padding/truncation length forwarded to the tokenizer.\n",
        "\n",
        "    Returns:\n",
        "        The tokenizer output with a 'labels' tensor added.\n",
        "    \"\"\"\n",
        "    def to_label(raw):\n",
        "        # NOTE(review): anything that is not a digit string below 3 silently becomes\n",
        "        # class 0 -- confirm the label column really holds 0/1/2.\n",
        "        if str(raw).isdigit() and int(raw) < 3:\n",
        "            return int(raw)\n",
        "        return 0\n",
        "\n",
        "    sentences = []\n",
        "    for premise, hypothesis in zip(batch[\"premise\"], batch[\"hypothesis\"]):\n",
        "        sentences.append(f\"nli premise: {premise} hypothesis: {hypothesis}\")\n",
        "    labels = [to_label(raw) for raw in batch[\"label\"]]\n",
        "    toks = tokenizer(\n",
        "        sentences,\n",
        "        padding=\"max_length\",\n",
        "        truncation=True,\n",
        "        max_length=max_length,\n",
        "        return_tensors=\"pt\",\n",
        "    )\n",
        "    toks[\"labels\"] = torch.tensor(labels)\n",
        "    return toks\n",
        "\n",
        "tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "# Tokenize both splits the same way, dropping the raw columns.\n",
        "train, dev = (\n",
        "    dataset[split].map(\n",
        "        lambda ex: preprocess_anli_roberta(ex, tokenizer),\n",
        "        batched=True,\n",
        "        remove_columns=dataset[split].column_names,\n",
        "    )\n",
        "    for split in (\"train\", \"validation\")\n",
        ")\n",
        "\n",
        "collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=128)\n",
        "train_loader = DataLoader(\n",
        "    train, batch_size=16, shuffle=True, collate_fn=collator\n",
        ")\n",
        "dev_loader = DataLoader(\n",
        "    dev, batch_size=16, shuffle=False, collate_fn=collator\n",
        ")\n",
        "\n",
        "# --- 3. ER Hook Utilities for Encoder Only ---\n",
        "def register_er_hooks(model):\n",
        "    \"\"\"Attach forward hooks that record `output.dense` results of adjacent encoder layers.\n",
        "\n",
        "    For every pair (i, i+1) the output of layer i is stored under 'curr_X' and the\n",
        "    output of layer i+1 under 'curr_Y' of activations[i].\n",
        "\n",
        "    Returns:\n",
        "        (hooks, activations): the hook handles and the buffer dict the hooks fill.\n",
        "    \"\"\"\n",
        "    layers = model.roberta.encoder.layer\n",
        "    n_pairs = len(layers) - 1\n",
        "    activations = {}\n",
        "    for pair in range(n_pairs):\n",
        "        activations[pair] = dict.fromkeys(('prev_X', 'prev_Y', 'curr_X', 'curr_Y'))\n",
        "\n",
        "    def make_recorder(pair, slot):\n",
        "        # Binds pair/slot per hook so each closure writes to its own buffer entry.\n",
        "        def record(module, inp, out):\n",
        "            activations[pair][slot] = out.detach()\n",
        "        return record\n",
        "\n",
        "    hooks = []\n",
        "    for pair in range(n_pairs):\n",
        "        lower, upper = layers[pair], layers[pair + 1]\n",
        "        hooks.append(lower.output.dense.register_forward_hook(make_recorder(pair, 'curr_X')))\n",
        "        hooks.append(upper.output.dense.register_forward_hook(make_recorder(pair, 'curr_Y')))\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    \"\"\"Call `.remove()` on every hook handle in `hooks`.\"\"\"\n",
        "    for handle in hooks:\n",
        "        handle.remove()\n",
        "\n",
        "def compute_batch_entropy(activations):\n",
        "    \"\"\"Compute the per-layer ER proxy from the change between two consecutive batches.\n",
        "\n",
        "    For each buffer, the previous and current X/Y activations are subtracted and\n",
        "    flattened per sample; the squared cosine similarity between the Y-change and the\n",
        "    X-change is summed over samples 1..B-1 and divided by 2 * (B - 1).\n",
        "\n",
        "    Args:\n",
        "        activations: dict idx -> {'prev_X', 'prev_Y', 'curr_X', 'curr_Y'}. Updated in\n",
        "            place: the current tensors become the previous ones and 'curr_*' is cleared.\n",
        "\n",
        "    Returns:\n",
        "        dict idx -> float. Buffers without a previous batch of matching batch size are\n",
        "        omitted; a batch of size 1 yields 0.0.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "        # Every path ends by rolling the buffers forward, so do it once up front.\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "        if any(t is None for t in (X_prev, Y_prev, X_curr, Y_curr)):\n",
        "            continue\n",
        "        if X_prev.shape[0] != X_curr.shape[0] or Y_prev.shape[0] != Y_curr.shape[0]:\n",
        "            continue\n",
        "        B = X_curr.size(0)\n",
        "        # Upcast first: under autocast the hooked outputs can be fp16, where the norms\n",
        "        # over a flattened activation map may overflow.\n",
        "        dX = (X_curr.float() - X_prev.float()).reshape(B, -1)\n",
        "        dY = (Y_curr.float() - Y_prev.float()).reshape(B, -1)\n",
        "        if B < 2:\n",
        "            er_scores[idx] = 0.0\n",
        "            continue\n",
        "        # One vectorised call (single host sync) instead of one .item() per sample.\n",
        "        # NOTE(review): sample 0 is excluded to match the 2 * (B - 1) normaliser --\n",
        "        # confirm this is the intended estimator.\n",
        "        cos_sq = F.cosine_similarity(dY[1:], dX[1:], dim=1, eps=1e-8).pow(2)\n",
        "        er_scores[idx] = sum(cos_sq.tolist()) / (2 * (B - 1))\n",
        "    return er_scores\n",
        "\n",
        "# --- 4. Pruning Utilities for Encoder Only ---\n",
        "class SkipFF(nn.Module):\n",
        "    \"\"\"Stand-in output block that ignores `hidden_states` and passes `input_tensor` through.\"\"\"\n",
        "\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        # hidden_states is deliberately unused.\n",
        "        passthrough = input_tensor\n",
        "        return passthrough\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    \"\"\"Bypass the feed-forward path of the layers following the highest-ER pairs.\n",
        "\n",
        "    `er_scores` is keyed by pair index i; for the `num_prune` largest scores, encoder\n",
        "    layer i+1 gets an Identity intermediate projection and a SkipFF output block.\n",
        "\n",
        "    Returns:\n",
        "        The list of encoder layer indices that were modified, highest score first.\n",
        "    \"\"\"\n",
        "    encoder_layers = model.roberta.encoder.layer\n",
        "    ranked = sorted(er_scores, key=er_scores.get, reverse=True)\n",
        "    prune_idxs = []\n",
        "    for pair_idx in ranked[:num_prune]:\n",
        "        target = pair_idx + 1\n",
        "        if target < len(encoder_layers):\n",
        "            prune_idxs.append(target)\n",
        "    for target in prune_idxs:\n",
        "        block = encoder_layers[target]\n",
        "        block.intermediate.dense = nn.Identity()\n",
        "        block.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "# --- 5. Training/Eval/ER Pipeline ---\n",
        "def compute_accuracy(preds, labels):\n",
        "    \"\"\"Return the fraction of positions where `preds` equals `labels` (0 for empty `preds`).\"\"\"\n",
        "    if not len(preds):\n",
        "        return 0\n",
        "    hits = np.array(preds) == np.array(labels)\n",
        "    return hits.mean()\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    \"\"\"Put `model` in eval mode and return its argmax accuracy over all batches of `dl`.\"\"\"\n",
        "    model.eval()\n",
        "    predicted, gold = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch['input_ids'].to(device)\n",
        "            attention_mask = batch['attention_mask'].to(device)\n",
        "            gold += list(batch['labels'].cpu().numpy())\n",
        "            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits\n",
        "            predicted += list(torch.argmax(logits, -1).cpu().numpy())\n",
        "    return compute_accuracy(predicted, gold)\n",
        "\n",
        "def full_finetuning(train_loader, dev_loader, device, num_epochs=6):\n",
        "    \"\"\"Fine-tune roberta-base (3 labels) and average per-layer ER estimates per epoch.\n",
        "\n",
        "    Args:\n",
        "        train_loader: yields dict batches with 'input_ids', 'attention_mask', 'labels'.\n",
        "        dev_loader: handed to evaluate_model after every epoch.\n",
        "        device: torch device the model and each batch are moved to.\n",
        "        num_epochs: passes over train_loader; the linear LR decay spans all of them.\n",
        "\n",
        "    Returns:\n",
        "        (model, last_er): the trained model and the final epoch's ER averages\n",
        "        (None when no epoch ran).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Fine-Tuning & ER Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\"roberta-base\", num_labels=3).to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=2e-5)\n",
        "    # The decay horizon must equal the number of optimizer steps actually taken: with\n",
        "    # the old fixed 3-epoch horizon the learning rate was already 0 for epochs 4-6.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    scaler = GradScaler()\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "\n",
        "    try:\n",
        "        for epoch in range(num_epochs):\n",
        "            er_sums, er_counts = defaultdict(float), defaultdict(int)\n",
        "            model.train()\n",
        "            for batch in train_loader:\n",
        "                opt.zero_grad()\n",
        "                with autocast():\n",
        "                    out = model(\n",
        "                        input_ids=batch['input_ids'].to(device),\n",
        "                        attention_mask=batch['attention_mask'].to(device),\n",
        "                        labels=batch['labels'].to(device),\n",
        "                    )\n",
        "                # Only the forward pass belongs under autocast; backward runs outside it.\n",
        "                scaler.scale(out.loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "                batch_er = compute_batch_entropy(activations)\n",
        "                for idx, v in batch_er.items():\n",
        "                    er_sums[idx] += v\n",
        "                    er_counts[idx] += 1\n",
        "            epoch_er = {idx: er_sums[idx]/er_counts[idx] for idx in er_sums if er_counts[idx] > 0}\n",
        "            print(f\"[Epoch {epoch+1}] approx Entropy Rate:\", epoch_er)\n",
        "            acc = evaluate_model(model, dev_loader, device)\n",
        "            print(f\"-> Epoch {epoch+1} ANLI1 Dev Acc: {acc:.4f}\")\n",
        "            last_er = epoch_er\n",
        "    finally:\n",
        "        # Detach the forward hooks even when training is interrupted.\n",
        "        remove_hooks(hooks)\n",
        "    return model, last_er\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores):\n",
        "    \"\"\"Prune the layers ranked highest by `er_scores`, then fine-tune for three epochs.\n",
        "\n",
        "    The model is modified in place by prune_er_layers and returned after training;\n",
        "    dev accuracy is printed after every epoch.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High-ER) & Fine-tuning ===\")\n",
        "    pruned = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest-ER):\", pruned)\n",
        "    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)\n",
        "    scheduler = get_linear_schedule_with_warmup(optimizer, 0, len(train_loader)*3)\n",
        "    batch_keys = ('input_ids', 'attention_mask', 'labels')\n",
        "    for epoch_no in range(1, 4):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            optimizer.zero_grad()\n",
        "            inputs = {key: batch[key].to(device) for key in batch_keys}\n",
        "            loss = model(**inputs).loss\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "            scheduler.step()\n",
        "        dev_acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch_no}] ANLI1 Dev Acc: {dev_acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 6. Entrypoint ---\n",
        "def main():\n",
        "    \"\"\"Run stage 1 (fine-tuning + ER estimation) followed by stage 2 (prune + fine-tune).\"\"\"\n",
        "    use_cuda = torch.cuda.is_available()\n",
        "    device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
        "    finetuned, er_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    prune_and_finetuning(finetuned, train_loader, dev_loader, device, er_scores)\n",
        "\n",
        "# Runs when the cell is executed directly.\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "BXiqYqLsTq07"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# --- Mount Google Drive if using Colab ---\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "import random, numpy as np, torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from datasets import load_dataset\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict\n",
        "import warnings, math\n",
        "\n",
        "warnings.filterwarnings(\"ignore\")\n",
        "SEED = 42\n",
        "random.seed(SEED)\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "\n",
        "# --- 1. Load ANLI1 Dataset ---\n",
        "# Split name -> JSON file on the mounted Drive.\n",
        "data_files = dict(\n",
        "    train=\"/content/drive/MyDrive/NLP_datasets/anli1/anli1_train.json\",\n",
        "    validation=\"/content/drive/MyDrive/NLP_datasets/anli1/anli1_valid.json\",\n",
        "    test=\"/content/drive/MyDrive/NLP_datasets/anli1/anli1_test.json\",\n",
        ")\n",
        "dataset = load_dataset(\"json\", data_files=data_files)\n",
        "\n",
        "# --- 2. Preprocessing for RoBERTa ---\n",
        "tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "def preprocess_anli_roberta(batch, tokenizer, max_length=128):\n",
        "    \"\"\"Build prompt strings and integer labels for a batch, then tokenize them.\n",
        "\n",
        "    Args:\n",
        "        batch: mapping with parallel 'premise', 'hypothesis' and 'label' sequences.\n",
        "        tokenizer: callable applied to the list of prompt strings.\n",
        "        max_length: padding/truncation length forwarded to the tokenizer.\n",
        "\n",
        "    Returns:\n",
        "        The tokenizer output with a 'labels' list added.\n",
        "    \"\"\"\n",
        "    def to_label(raw):\n",
        "        # NOTE(review): anything that is not a digit string below 3 silently becomes\n",
        "        # class 0 -- confirm the label column really holds 0/1/2.\n",
        "        if str(raw).isdigit() and int(raw) < 3:\n",
        "            return int(raw)\n",
        "        return 0\n",
        "\n",
        "    sentences = []\n",
        "    for premise, hypothesis in zip(batch[\"premise\"], batch[\"hypothesis\"]):\n",
        "        sentences.append(f\"nli premise: {premise} hypothesis: {hypothesis}\")\n",
        "    labels = [to_label(raw) for raw in batch[\"label\"]]\n",
        "    toks = tokenizer(\n",
        "        sentences,\n",
        "        padding=\"max_length\",\n",
        "        truncation=True,\n",
        "        max_length=max_length,\n",
        "    )\n",
        "    toks[\"labels\"] = labels\n",
        "    return toks\n",
        "\n",
        "# Tokenize both splits the same way, dropping the raw columns.\n",
        "train, dev = (\n",
        "    dataset[split].map(\n",
        "        lambda ex: preprocess_anli_roberta(ex, tokenizer),\n",
        "        batched=True,\n",
        "        remove_columns=dataset[split].column_names,\n",
        "    )\n",
        "    for split in (\"train\", \"validation\")\n",
        ")\n",
        "\n",
        "collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=128)\n",
        "train_loader = DataLoader(\n",
        "    train, batch_size=16, shuffle=True, collate_fn=collator\n",
        ")\n",
        "dev_loader = DataLoader(\n",
        "    dev, batch_size=16, shuffle=False, collate_fn=collator\n",
        ")\n",
        "\n",
        "# --- 3. ER Hook Utilities (Encoder Only) ---\n",
        "\n",
        "def register_er_hooks(model):\n",
        "    \"\"\"Attach forward hooks that record `output.dense` results of adjacent encoder layers.\n",
        "\n",
        "    For every pair (i, i+1) the output of layer i is stored under 'curr_X' and the\n",
        "    output of layer i+1 under 'curr_Y' of activations[i].\n",
        "\n",
        "    Returns:\n",
        "        (hooks, activations): the hook handles and the buffer dict the hooks fill.\n",
        "    \"\"\"\n",
        "    layers = model.roberta.encoder.layer\n",
        "    n_pairs = len(layers) - 1\n",
        "    activations = {}\n",
        "    for pair in range(n_pairs):\n",
        "        activations[pair] = dict.fromkeys(('prev_X', 'prev_Y', 'curr_X', 'curr_Y'))\n",
        "\n",
        "    def make_recorder(pair, slot):\n",
        "        # Binds pair/slot per hook so each closure writes to its own buffer entry.\n",
        "        def record(module, inp, out):\n",
        "            activations[pair][slot] = out.detach()\n",
        "        return record\n",
        "\n",
        "    hooks = []\n",
        "    for pair in range(n_pairs):\n",
        "        lower, upper = layers[pair], layers[pair + 1]\n",
        "        hooks.append(lower.output.dense.register_forward_hook(make_recorder(pair, 'curr_X')))\n",
        "        hooks.append(upper.output.dense.register_forward_hook(make_recorder(pair, 'curr_Y')))\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    \"\"\"Call `.remove()` on every hook handle in `hooks`.\"\"\"\n",
        "    for handle in hooks:\n",
        "        handle.remove()\n",
        "\n",
        "\n",
        "\n",
        "def compute_batch_entropy(activations):\n",
        "    \"\"\"Compute the per-layer ER proxy from the change between two consecutive batches.\n",
        "\n",
        "    For each buffer, the previous and current X/Y activations are subtracted and\n",
        "    flattened per sample; the squared cosine similarity between the Y-change and the\n",
        "    X-change is summed over samples 1..B-1 and divided by 2 * (B - 1).\n",
        "\n",
        "    Args:\n",
        "        activations: dict idx -> {'prev_X', 'prev_Y', 'curr_X', 'curr_Y'}. Updated in\n",
        "            place: the current tensors become the previous ones and 'curr_*' is cleared.\n",
        "\n",
        "    Returns:\n",
        "        dict idx -> float. Buffers without a previous batch of matching batch size are\n",
        "        omitted; a batch of size 1 yields 0.0.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "        # Every path ends by rolling the buffers forward, so do it once up front.\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "        if any(t is None for t in (X_prev, Y_prev, X_curr, Y_curr)):\n",
        "            continue\n",
        "        if X_prev.shape[0] != X_curr.shape[0] or Y_prev.shape[0] != Y_curr.shape[0]:\n",
        "            continue\n",
        "        B = X_curr.size(0)\n",
        "        # Upcast first: under autocast the hooked outputs can be fp16, where the norms\n",
        "        # over a flattened activation map may overflow.\n",
        "        dX = (X_curr.float() - X_prev.float()).reshape(B, -1)\n",
        "        dY = (Y_curr.float() - Y_prev.float()).reshape(B, -1)\n",
        "        if B < 2:\n",
        "            er_scores[idx] = 0.0\n",
        "            continue\n",
        "        # One vectorised call (single host sync) instead of one .item() per sample.\n",
        "        # NOTE(review): sample 0 is excluded to match the 2 * (B - 1) normaliser --\n",
        "        # confirm this is the intended estimator.\n",
        "        cos_sq = F.cosine_similarity(dY[1:], dX[1:], dim=1, eps=1e-8).pow(2)\n",
        "        er_scores[idx] = sum(cos_sq.tolist()) / (2 * (B - 1))\n",
        "    return er_scores\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "# --- 4. Pruning Utilities (Last Layers) ---\n",
        "class SkipFF(nn.Module):\n",
        "    \"\"\"Stand-in output block that ignores `hidden_states` and passes `input_tensor` through.\"\"\"\n",
        "\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        # hidden_states is deliberately unused.\n",
        "        passthrough = input_tensor\n",
        "        return passthrough\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    \"\"\"Bypass the feed-forward path of the layers following the highest-ER pairs.\n",
        "\n",
        "    `er_scores` is keyed by pair index i; for the `num_prune` largest scores, encoder\n",
        "    layer i+1 gets an Identity intermediate projection and a SkipFF output block.\n",
        "\n",
        "    Returns:\n",
        "        The list of encoder layer indices that were modified, highest score first.\n",
        "    \"\"\"\n",
        "    encoder_layers = model.roberta.encoder.layer\n",
        "    # Pair indices ordered by descending ER score.\n",
        "    ranked = sorted(er_scores, key=er_scores.get, reverse=True)\n",
        "    prune_idxs = []\n",
        "    for pair_idx in ranked[:num_prune]:\n",
        "        target = pair_idx + 1\n",
        "        if target < len(encoder_layers):\n",
        "            prune_idxs.append(target)\n",
        "    for target in prune_idxs:\n",
        "        block = encoder_layers[target]\n",
        "        block.intermediate.dense = nn.Identity()\n",
        "        block.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "# --- 5. Eval / Accuracy Helper ---\n",
        "def evaluate_model(model, dl, device):\n",
        "    \"\"\"Put `model` in eval mode and return its argmax accuracy over all batches of `dl`.\n",
        "\n",
        "    Args:\n",
        "        model: callable taking input_ids/attention_mask and returning an object with `.logits`.\n",
        "        dl: iterable of dict batches with 'input_ids', 'attention_mask' and 'labels'.\n",
        "        device: torch device the inputs are moved to.\n",
        "\n",
        "    Returns:\n",
        "        Fraction of correct predictions, or 0.0 when `dl` yields no examples.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            ids  = batch[\"input_ids\"].to(device)\n",
        "            mask = batch[\"attention_mask\"].to(device)\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().tolist())\n",
        "            labs.extend(int(l) for l in batch[\"labels\"].cpu().tolist())\n",
        "    if not labs:\n",
        "        # An empty loader used to raise ZeroDivisionError here.\n",
        "        return 0.0\n",
        "    return sum(p == l for p, l in zip(preds, labs)) / len(labs)\n",
        "\n",
        "# --- 6. Training + ER Tracking + Pruning ---\n",
        "def full_finetuning(train_loader, dev_loader, device, num_epochs=6):\n",
        "    \"\"\"Fine-tune roberta-base (3 labels) and average per-layer ER estimates per epoch.\n",
        "\n",
        "    Args:\n",
        "        train_loader: yields dict batches with 'input_ids', 'attention_mask', 'labels'.\n",
        "        dev_loader: handed to evaluate_model after every epoch.\n",
        "        device: torch device the model and each batch are moved to.\n",
        "        num_epochs: passes over train_loader; the linear LR decay spans all of them.\n",
        "\n",
        "    Returns:\n",
        "        (model, last_er): the trained model and the final epoch's ER averages\n",
        "        (None when no epoch ran).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Finetuning & ER Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=3  # 3-way NLI\n",
        "    ).to(device)\n",
        "    opt    = torch.optim.AdamW(model.parameters(), lr=2e-5)\n",
        "    # The decay horizon must equal the number of optimizer steps actually taken: with\n",
        "    # the old fixed 3-epoch horizon the learning rate was already 0 for epochs 4-6.\n",
        "    sched  = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    scaler = GradScaler()\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "\n",
        "    try:\n",
        "        for epoch in range(num_epochs):\n",
        "            er_sums, er_cnts = defaultdict(float), defaultdict(int)\n",
        "            model.train()\n",
        "            for batch in train_loader:\n",
        "                opt.zero_grad()\n",
        "                with autocast():\n",
        "                    out = model(\n",
        "                        input_ids=batch[\"input_ids\"].to(device),\n",
        "                        attention_mask=batch[\"attention_mask\"].to(device),\n",
        "                        labels=batch[\"labels\"].to(device)\n",
        "                    )\n",
        "                # Only the forward pass belongs under autocast; backward runs outside it.\n",
        "                scaler.scale(out.loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "                batch_er = compute_batch_entropy(activations)\n",
        "                for i, v in batch_er.items():\n",
        "                    er_sums[i] += v\n",
        "                    er_cnts[i] += 1\n",
        "            last_er = {i: er_sums[i]/er_cnts[i] for i in er_sums if er_cnts[i]>0}\n",
        "            print(f\"[Epoch {epoch+1}] approx ER: {last_er}\")\n",
        "            acc = evaluate_model(model, dev_loader, device)\n",
        "            print(f\"[Epoch {epoch+1}] ANLI1 Dev Acc: {acc:.4f}\")\n",
        "    finally:\n",
        "        # Detach the forward hooks even when training is interrupted.\n",
        "        remove_hooks(hooks)\n",
        "    return model, last_er\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores, num_epochs=3):\n",
        "    \"\"\"Prune the layers ranked highest by `er_scores`, then fine-tune the pruned model.\n",
        "\n",
        "    Args:\n",
        "        model: model to prune in place (via prune_er_layers) and train.\n",
        "        train_loader: yields dict batches with 'input_ids', 'attention_mask', 'labels'.\n",
        "        dev_loader: handed to evaluate_model after every epoch.\n",
        "        device: torch device each batch is moved to.\n",
        "        er_scores: dict pair index -> ER score used to choose the layers.\n",
        "        num_epochs: passes over train_loader; the linear LR decay spans all of them.\n",
        "\n",
        "    Returns:\n",
        "        The pruned, fine-tuned model.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High-ER Decoder) & Finetune ===\")\n",
        "    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(\"Pruned last (top-ER) layers:\", prune_idxs)\n",
        "    opt   = torch.optim.AdamW(model.parameters(), lr=2e-5)\n",
        "    # The schedule used to cover 2 epochs while the loop ran 3, so the whole last\n",
        "    # epoch trained with a learning rate of 0.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(\n",
        "                input_ids=batch[\"input_ids\"].to(device),\n",
        "                attention_mask=batch[\"attention_mask\"].to(device),\n",
        "                labels=batch[\"labels\"].to(device)\n",
        "            )\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] ANLI1 Dev Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 7. Entrypoint ---\n",
        "def main():\n",
        "    \"\"\"Run stage 1 (fine-tuning + ER estimation) followed by stage 2 (prune + fine-tune).\"\"\"\n",
        "    use_cuda = torch.cuda.is_available()\n",
        "    device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
        "    finetuned, er_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    # The pruned model is not needed afterwards, so the return value is dropped.\n",
        "    prune_and_finetuning(finetuned, train_loader, dev_loader, device, er_scores)\n",
        "\n",
        "# Runs when the cell is executed directly.\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "E-JGDxYledma"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# --- Mount Google Drive if using Colab ---\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "import random, numpy as np, torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from datasets import load_dataset\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict, Counter\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\")\n",
        "SEED = 42\n",
        "random.seed(SEED)\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "\n",
        "# --- 1. Load SVAMP Dataset ---\n",
        "# Split name -> JSON file on the mounted Drive.\n",
        "data_files = dict(\n",
        "    train=\"/content/drive/MyDrive/NLP_datasets/svamp/svamp_train.json\",\n",
        "    test=\"/content/drive/MyDrive/NLP_datasets/svamp/svamp_test.json\",\n",
        ")\n",
        "dataset = load_dataset(\"json\", data_files=data_files)\n",
        "\n",
        "# --- 2. Inspect and Assign Column Names ---\n",
        "print(\"Train columns:\", dataset[\"train\"].column_names)\n",
        "print(\"Test columns:\", dataset[\"test\"].column_names)\n",
        "ques_key, ans_key = \"input\", \"label\"\n",
        "\n",
        "# --- 3. Top-N Answer Classification Setup ---\n",
        "TOP_N = 100  # Or 50, 100, 200 for ablation\n",
        "train_answers = list(map(str, dataset[\"train\"][ans_key]))\n",
        "\n",
        "counts = Counter(train_answers)\n",
        "most_common_answers = [answer for answer, _freq in counts.most_common(TOP_N)]\n",
        "# Index 0 is reserved for \"other\"; frequent answers are numbered from 1.\n",
        "label2idx = dict(zip(most_common_answers, range(1, len(most_common_answers) + 1)))\n",
        "num_labels = 1 + len(label2idx)\n",
        "\n",
        "def map_label(x):\n",
        "    \"\"\"Return the class index for answer `x` from `label2idx`, or 0 when it is not listed.\"\"\"\n",
        "    key = str(x)\n",
        "    if key in label2idx:\n",
        "        return label2idx[key]\n",
        "    return 0\n",
        "\n",
        "# --- 4. Preprocessing for RoBERTa ---\n",
        "tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "def preprocess_svamp(batch, tokenizer, max_length=128):\n",
        "    \"\"\"Tokenize the question column of `batch` and attach class-index labels.\n",
        "\n",
        "    Questions are read from `batch[ques_key]`, answers from `batch[ans_key]` and mapped\n",
        "    through map_label. Returns the tokenizer output with a 'labels' list added.\n",
        "    \"\"\"\n",
        "    enc = tokenizer(\n",
        "        batch[ques_key],\n",
        "        padding=\"max_length\",\n",
        "        truncation=True,\n",
        "        max_length=max_length,\n",
        "    )\n",
        "    enc[\"labels\"] = list(map(map_label, batch[ans_key]))\n",
        "    return enc\n",
        "\n",
        "# Tokenize both splits the same way, dropping the raw columns.\n",
        "train, dev = (\n",
        "    dataset[split].map(\n",
        "        lambda ex: preprocess_svamp(ex, tokenizer),\n",
        "        batched=True,\n",
        "        remove_columns=dataset[split].column_names,\n",
        "    )\n",
        "    for split in (\"train\", \"test\")\n",
        ")\n",
        "\n",
        "collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=128)\n",
        "train_loader = DataLoader(\n",
        "    train, batch_size=16, shuffle=True, collate_fn=collator\n",
        ")\n",
        "dev_loader = DataLoader(\n",
        "    dev, batch_size=16, shuffle=False, collate_fn=collator\n",
        ")\n",
        "\n",
        "# --- 5. ER Hook Utilities (Encoder Only) ---\n",
        "def register_er_hooks(model):\n",
        "    \"\"\"Attach forward hooks that record `output.dense` results of adjacent encoder layers.\n",
        "\n",
        "    For every pair (i, i+1) the output of layer i is stored under 'prev' and the\n",
        "    output of layer i+1 under 'curr' of activations[i].\n",
        "\n",
        "    Returns:\n",
        "        (hooks, activations): the hook handles and the buffer dict the hooks fill.\n",
        "    \"\"\"\n",
        "    layers = model.roberta.encoder.layer\n",
        "    n_pairs = len(layers) - 1\n",
        "    activations = {}\n",
        "    for pair in range(n_pairs):\n",
        "        activations[pair] = dict.fromkeys(('prev', 'curr'))\n",
        "\n",
        "    def make_recorder(pair, slot):\n",
        "        # Binds pair/slot per hook so each closure writes to its own buffer entry.\n",
        "        def record(module, inp, out):\n",
        "            activations[pair][slot] = out.detach()\n",
        "        return record\n",
        "\n",
        "    hooks = []\n",
        "    for pair in range(n_pairs):\n",
        "        lower, upper = layers[pair], layers[pair + 1]\n",
        "        hooks.append(lower.output.dense.register_forward_hook(make_recorder(pair, 'prev')))\n",
        "        hooks.append(upper.output.dense.register_forward_hook(make_recorder(pair, 'curr')))\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    \"\"\"Call `.remove()` on every hook handle in `hooks`.\"\"\"\n",
        "    for handle in hooks:\n",
        "        handle.remove()\n",
        "\n",
        "\n",
        "\n",
        "def compute_batch_entropy(activations):\n",
        "    \"\"\"Compute the per-layer ER proxy as the mean pairwise cosine similarity of deltas.\n",
        "\n",
        "    For each buffer the 'prev' tensor is subtracted from the 'curr' tensor, the result\n",
        "    is flattened per sample, and the cosine similarity is averaged over all unordered\n",
        "    sample pairs (i < j).\n",
        "\n",
        "    Args:\n",
        "        activations: dict idx -> {'prev', 'curr'}. Updated in place: 'prev' takes the\n",
        "            value of 'curr' and 'curr' is cleared.\n",
        "\n",
        "    Returns:\n",
        "        dict idx -> float. Buffers with a missing tensor or mismatched batch size are\n",
        "        omitted; a batch of size 1 yields 0.0.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, X_curr = buf[\"prev\"], buf[\"curr\"]\n",
        "        # Every path ends with the same buffer update, so do it once up front.\n",
        "        buf[\"prev\"], buf[\"curr\"] = X_curr, None\n",
        "        if X_prev is None or X_curr is None or X_prev.shape[0] != X_curr.shape[0]:\n",
        "            continue\n",
        "        B = X_curr.size(0)\n",
        "        # Upcast first: under autocast the hooked outputs can be fp16, where the norms\n",
        "        # over a flattened activation map may overflow.\n",
        "        dX = (X_curr.float() - X_prev.float()).reshape(B, -1)\n",
        "        if B < 2:\n",
        "            er_scores[idx] = 0.0\n",
        "            continue\n",
        "        # Unit-normalise once and read all pairwise cosines off the Gram matrix,\n",
        "        # replacing the O(B^2) Python loop with one .item() sync per pair.\n",
        "        unit = F.normalize(dX, dim=1, eps=1e-8)\n",
        "        gram = unit @ unit.t()\n",
        "        rows, cols = torch.triu_indices(B, B, offset=1, device=gram.device)\n",
        "        er_scores[idx] = gram[rows, cols].mean().item()\n",
        "    return er_scores\n",
        "\n",
        "\n",
        "\n",
        "# --- 6. Pruning Utilities (Last Layers) ---\n",
        "class SkipFF(nn.Module):\n",
        "    \"\"\"Stand-in output block that ignores `hidden_states` and passes `input_tensor` through.\"\"\"\n",
        "\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        # hidden_states is deliberately unused.\n",
        "        passthrough = input_tensor\n",
        "        return passthrough\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    \"\"\"Bypass the feed-forward path of the layers following the highest-ER pairs.\n",
        "\n",
        "    `er_scores` is keyed by pair index i; for the `num_prune` largest scores, encoder\n",
        "    layer i+1 gets an Identity intermediate projection and a SkipFF output block.\n",
        "\n",
        "    Returns:\n",
        "        The list of encoder layer indices that were modified, highest score first.\n",
        "    \"\"\"\n",
        "    encoder_layers = model.roberta.encoder.layer\n",
        "    # Pair indices ordered by descending ER score.\n",
        "    ranked = sorted(er_scores, key=er_scores.get, reverse=True)\n",
        "    prune_idxs = []\n",
        "    for pair_idx in ranked[:num_prune]:\n",
        "        target = pair_idx + 1\n",
        "        if target < len(encoder_layers):\n",
        "            prune_idxs.append(target)\n",
        "    for target in prune_idxs:\n",
        "        block = encoder_layers[target]\n",
        "        block.intermediate.dense = nn.Identity()\n",
        "        block.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "# --- 7. Eval / Accuracy Helper ---\n",
        "def evaluate_model(model, dl, device):\n",
        "    \"\"\"Put `model` in eval mode and return its argmax accuracy over all batches of `dl`.\n",
        "\n",
        "    Args:\n",
        "        model: callable taking input_ids/attention_mask and returning an object with `.logits`.\n",
        "        dl: iterable of dict batches with 'input_ids', 'attention_mask' and 'labels'.\n",
        "        device: torch device the inputs are moved to.\n",
        "\n",
        "    Returns:\n",
        "        Fraction of correct predictions, or 0.0 when `dl` yields no examples.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            ids  = batch[\"input_ids\"].to(device)\n",
        "            mask = batch[\"attention_mask\"].to(device)\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().tolist())\n",
        "            labs.extend(int(l) for l in batch[\"labels\"].cpu().tolist())\n",
        "    if not labs:\n",
        "        # An empty loader used to raise ZeroDivisionError here.\n",
        "        return 0.0\n",
        "    return sum(p == l for p, l in zip(preds, labs)) / len(labs)\n",
        "\n",
        "# --- 8. Training + ER Tracking + Pruning ---\n",
        "def full_finetuning(train_loader, dev_loader, device, num_epochs=6):\n",
        "    \"\"\"Fine-tune roberta-base (`num_labels` classes) and average per-layer ER estimates.\n",
        "\n",
        "    Args:\n",
        "        train_loader: yields dict batches with 'input_ids', 'attention_mask', 'labels'.\n",
        "        dev_loader: handed to evaluate_model after every epoch.\n",
        "        device: torch device the model and each batch are moved to.\n",
        "        num_epochs: passes over train_loader; the linear LR decay spans all of them.\n",
        "\n",
        "    Returns:\n",
        "        (model, last_er): the trained model and the final epoch's ER averages\n",
        "        (None when no epoch ran).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Finetuning & ER Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=num_labels\n",
        "    ).to(device)\n",
        "    opt    = torch.optim.AdamW(model.parameters(), lr=2e-5)\n",
        "    # The decay horizon must equal the number of optimizer steps actually taken: with\n",
        "    # the old fixed 3-epoch horizon the learning rate was already 0 for epochs 4-6.\n",
        "    sched  = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    scaler = GradScaler()\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "\n",
        "    try:\n",
        "        for epoch in range(num_epochs):\n",
        "            er_sums, er_cnts = defaultdict(float), defaultdict(int)\n",
        "            model.train()\n",
        "            for batch in train_loader:\n",
        "                opt.zero_grad()\n",
        "                with autocast():\n",
        "                    out = model(\n",
        "                        input_ids=batch[\"input_ids\"].to(device),\n",
        "                        attention_mask=batch[\"attention_mask\"].to(device),\n",
        "                        labels=batch[\"labels\"].to(device)\n",
        "                    )\n",
        "                # Only the forward pass belongs under autocast; backward runs outside it.\n",
        "                scaler.scale(out.loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "                batch_er = compute_batch_entropy(activations)\n",
        "                for i, v in batch_er.items():\n",
        "                    er_sums[i] += v\n",
        "                    er_cnts[i] += 1\n",
        "            last_er = {i: er_sums[i]/er_cnts[i] for i in er_sums if er_cnts[i]>0}\n",
        "            print(f\"[Epoch {epoch+1}] approx ER: {last_er}\")\n",
        "            acc = evaluate_model(model, dev_loader, device)\n",
        "            print(f\"[Epoch {epoch+1}] SVAMP Dev Acc: {acc:.4f}\")\n",
        "    finally:\n",
        "        # Detach the forward hooks even when training is interrupted.\n",
        "        remove_hooks(hooks)\n",
        "    return model, last_er\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores):\n",
        "    print(\"=== Stage 2: Prune (High-ER 'Decoder') & Finetune ===\")\n",
        "    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(\"Pruned last (top-ER) layers:\", prune_idxs)\n",
        "    opt   = torch.optim.AdamW(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*2)\n",
        "    for epoch in range(3):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(\n",
        "                input_ids=batch[\"input_ids\"].to(device),\n",
        "                attention_mask=batch[\"attention_mask\"].to(device),\n",
        "                labels=batch[\"labels\"].to(device)\n",
        "            )\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] SVAMP Dev Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 9. Entrypoint ---\n",
        "def main():\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    model, er_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    _ = prune_and_finetuning(model, train_loader, dev_loader, device, er_scores)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "k4KfLJJ6dpmg"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "daggMzy3D-bB"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# --- Mount Google Drive if using Colab ---\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "import random, numpy as np, torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from datasets import load_dataset\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict, Counter\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\")\n",
        "SEED = 42\n",
        "random.seed(SEED)\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "\n",
        "# --- 1. Load SVAMP Dataset ---\n",
        "data_files = {\n",
        "    \"train\": \"/content/drive/MyDrive/NLP_datasets/svamp/svamp_train.json\",\n",
        "    \"test\":  \"/content/drive/MyDrive/NLP_datasets/svamp/svamp_test.json\"\n",
        "}\n",
        "dataset = load_dataset(\"json\", data_files=data_files)\n",
        "\n",
        "# --- 2. Assign SVAMP keys ---\n",
        "ques_key = \"input\"   # SVAMP uses \"input\" for questions\n",
        "ans_key  = \"label\"   # and \"label\" for answers\n",
        "\n",
        "# --- 3. Top-N Answer Classification Setup ---\n",
        "TOP_N = 100  # Try 50, 100, 200 for ablation\n",
        "train_answers = [str(x) for x in dataset[\"train\"][ans_key]]\n",
        "counts = Counter(train_answers)\n",
        "most_common_answers = [a for a, _ in counts.most_common(TOP_N)]\n",
        "label2idx = {ans: i+1 for i, ans in enumerate(most_common_answers)}  # 0 is \"other\"\n",
        "num_labels = len(label2idx) + 1\n",
        "\n",
        "def map_label(x):\n",
        "    return label2idx.get(str(x), 0)\n",
        "\n",
        "# --- 4. Preprocessing for RoBERTa ---\n",
        "tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "def preprocess_svamp(batch, tokenizer, max_length=128):\n",
        "    questions = batch[ques_key]\n",
        "    enc = tokenizer(\n",
        "        questions, padding=\"max_length\", truncation=True, max_length=max_length\n",
        "    )\n",
        "    enc[\"labels\"] = [map_label(x) for x in batch[ans_key]]\n",
        "    return enc\n",
        "\n",
        "train = dataset[\"train\"].map(\n",
        "    lambda ex: preprocess_svamp(ex, tokenizer),\n",
        "    batched=True,\n",
        "    remove_columns=dataset[\"train\"].column_names\n",
        ")\n",
        "dev   = dataset[\"test\"].map(\n",
        "    lambda ex: preprocess_svamp(ex, tokenizer),\n",
        "    batched=True,\n",
        "    remove_columns=dataset[\"test\"].column_names\n",
        ")\n",
        "\n",
        "collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=128)\n",
        "train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "# --- 5. ER Hook Utilities (Encoder Only) ---\n",
        "\n",
        "def register_er_hooks(model):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {\n",
        "        i: {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}\n",
        "        for i in range(len(layers)-1)\n",
        "    }\n",
        "    hooks = []\n",
        "    for i in range(len(layers)-1):\n",
        "        def hook_x(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_X'] = out.detach()\n",
        "        def hook_y(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_Y'] = out.detach()\n",
        "        hooks.append(layers[i].output.dense.register_forward_hook(hook_x))\n",
        "        hooks.append(layers[i+1].output.dense.register_forward_hook(hook_y))\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks: h.remove()\n",
        "\n",
        "\n",
        "\n",
        "def compute_batch_entropy(activations):\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "        valid = (X_prev is not None and Y_prev is not None and\n",
        "                 X_curr is not None and Y_curr is not None and\n",
        "                 X_prev.shape[0] == X_curr.shape[0] and\n",
        "                 Y_prev.shape[0] == Y_curr.shape[0])\n",
        "        if not valid:\n",
        "            buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "            buf['curr_X'], buf['curr_Y'] = None, None\n",
        "            continue\n",
        "        B = X_curr.size(0)\n",
        "        dX = (X_curr - X_prev).view(B, -1)\n",
        "        dY = (Y_curr - Y_prev).view(B, -1)\n",
        "        if B < 2:\n",
        "            er = 0.0\n",
        "        else:\n",
        "            cos_squares = []\n",
        "            for i in range(1, B):\n",
        "                c2 = F.cosine_similarity(dY[i].unsqueeze(0), dX[i].unsqueeze(0), dim=1, eps=1e-8)**2\n",
        "                cos_squares.append(c2.item())\n",
        "            er = sum(cos_squares) / (2 * (B - 1))\n",
        "        er_scores[idx] = er\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "    return er_scores\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "# --- 6. Pruning Utilities (Last Layers) ---\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx+1 for idx, _ in sorted_layers[:num_prune] if idx+1 < len(layers)]\n",
        "    for idx in prune_idxs:\n",
        "        layer = layers[idx]\n",
        "        layer.intermediate.dense = nn.Identity()\n",
        "        layer.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "# --- 7. Eval / Accuracy Helper ---\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            ids  = batch[\"input_ids\"].to(device)\n",
        "            mask = batch[\"attention_mask\"].to(device)\n",
        "            labels = batch[\"labels\"].cpu().numpy().tolist()\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().tolist())\n",
        "            labs.extend([int(l) for l in labels])\n",
        "    acc = sum(p==l for p,l in zip(preds,labs)) / len(labs)\n",
        "    return acc\n",
        "\n",
        "# --- 8. Training + ER Tracking + Pruning ---\n",
        "def full_finetuning(train_loader, dev_loader, device):\n",
        "    print(\"=== Stage 1: Full Finetuning & ER Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=num_labels\n",
        "    ).to(device)\n",
        "    opt    = torch.optim.AdamW(model.parameters(), lr=2e-5)\n",
        "    sched  = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*3)\n",
        "    scaler = GradScaler()\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        er_sums, er_cnts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(\n",
        "                    input_ids=batch[\"input_ids\"].to(device),\n",
        "                    attention_mask=batch[\"attention_mask\"].to(device),\n",
        "                    labels=batch[\"labels\"].to(device)\n",
        "                )\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "            batch_er = compute_batch_entropy(activations)\n",
        "            for i,v in batch_er.items():\n",
        "                er_sums[i] += v\n",
        "                er_cnts[i] += 1\n",
        "        last_er = {i: er_sums[i]/er_cnts[i] for i in er_sums if er_cnts[i]>0}\n",
        "        print(f\"[Epoch {epoch+1}] approx ER: {last_er}\")\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Epoch {epoch+1}] SVAMP Dev Acc: {acc:.4f}\")\n",
        "\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_er\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores):\n",
        "    print(\"=== Stage 2: Prune (High-ER 'Decoder') & Finetune ===\")\n",
        "    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(\"Pruned last (top-ER) layers:\", prune_idxs)\n",
        "    opt   = torch.optim.AdamW(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*2)\n",
        "    for epoch in range(3):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(\n",
        "                input_ids=batch[\"input_ids\"].to(device),\n",
        "                attention_mask=batch[\"attention_mask\"].to(device),\n",
        "                labels=batch[\"labels\"].to(device)\n",
        "            )\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] SVAMP Dev Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 9. Entrypoint ---\n",
        "def main():\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    model, er_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    _ = prune_and_finetuning(model, train_loader, dev_loader, device, er_scores)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "SCiowJvZEAE4"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "SVkwVh2BIJ3W"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "24PGPF1OIKfz"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Ablation\n",
        "\n",
        "# --- Mount Google Drive if using Colab ---\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "import random, numpy as np, torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from datasets import load_dataset\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict, Counter\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\")\n",
        "SEED = 42\n",
        "random.seed(SEED)\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "\n",
        "# --- 1. Load SVAMP Dataset ---\n",
        "data_files = {\n",
        "    \"train\": \"/content/drive/MyDrive/NLP_datasets/svamp/svamp_train.json\",\n",
        "    \"test\":  \"/content/drive/MyDrive/NLP_datasets/svamp/svamp_test.json\"\n",
        "}\n",
        "dataset = load_dataset(\"json\", data_files=data_files)\n",
        "\n",
        "# --- 2. Assign SVAMP keys ---\n",
        "ques_key = \"input\"   # SVAMP uses \"input\" for questions\n",
        "ans_key  = \"label\"   # and \"label\" for answers\n",
        "\n",
        "# --- 3. Top-N Answer Classification Setup ---\n",
        "TOP_N = 100  # Try 50, 100, 200 for ablation\n",
        "train_answers = [str(x) for x in dataset[\"train\"][ans_key]]\n",
        "counts = Counter(train_answers)\n",
        "most_common_answers = [a for a, _ in counts.most_common(TOP_N)]\n",
        "label2idx = {ans: i+1 for i, ans in enumerate(most_common_answers)}  # 0 is \"other\"\n",
        "num_labels = len(label2idx) + 1\n",
        "\n",
        "def map_label(x):\n",
        "    return label2idx.get(str(x), 0)\n",
        "\n",
        "# --- 4. Preprocessing for RoBERTa ---\n",
        "tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "def preprocess_svamp(batch, tokenizer, max_length=128):\n",
        "    questions = batch[ques_key]\n",
        "    enc = tokenizer(\n",
        "        questions, padding=\"max_length\", truncation=True, max_length=max_length\n",
        "    )\n",
        "    enc[\"labels\"] = [map_label(x) for x in batch[ans_key]]\n",
        "    return enc\n",
        "\n",
        "train = dataset[\"train\"].map(\n",
        "    lambda ex: preprocess_svamp(ex, tokenizer),\n",
        "    batched=True,\n",
        "    remove_columns=dataset[\"train\"].column_names\n",
        ")\n",
        "dev   = dataset[\"test\"].map(\n",
        "    lambda ex: preprocess_svamp(ex, tokenizer),\n",
        "    batched=True,\n",
        "    remove_columns=dataset[\"test\"].column_names\n",
        ")\n",
        "\n",
        "collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=128)\n",
        "train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "# --- 5. ER Hook Utilities (Encoder Only) ---\n",
        "\n",
        "def register_er_hooks(model):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {\n",
        "        i: {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}\n",
        "        for i in range(len(layers)-1)\n",
        "    }\n",
        "    hooks = []\n",
        "    for i in range(len(layers)-1):\n",
        "        def hook_x(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_X'] = out.detach()\n",
        "        def hook_y(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_Y'] = out.detach()\n",
        "        hooks.append(layers[i].output.dense.register_forward_hook(hook_x))\n",
        "        hooks.append(layers[i+1].output.dense.register_forward_hook(hook_y))\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks: h.remove()\n",
        "\n",
        "\n",
        "\n",
        "def compute_batch_entropy(activations):\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "        valid = (X_prev is not None and Y_prev is not None and\n",
        "                 X_curr is not None and Y_curr is not None and\n",
        "                 X_prev.shape[0] == X_curr.shape[0] and\n",
        "                 Y_prev.shape[0] == Y_curr.shape[0])\n",
        "        if not valid:\n",
        "            buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "            buf['curr_X'], buf['curr_Y'] = None, None\n",
        "            continue\n",
        "        B = X_curr.size(0)\n",
        "        dX = (X_curr - X_prev).view(B, -1)\n",
        "        dY = (Y_curr - Y_prev).view(B, -1)\n",
        "        if B < 2:\n",
        "            er = 0.0\n",
        "        else:\n",
        "            cos_squares = []\n",
        "            for i in range(1, B):\n",
        "                c2 = F.cosine_similarity(dY[i].unsqueeze(0), dX[i].unsqueeze(0), dim=1, eps=1e-8)**2\n",
        "                cos_squares.append(c2.item())\n",
        "            er = sum(cos_squares) / (2 * (B - 1))\n",
        "        er_scores[idx] = er\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "    return er_scores\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "# --- 6. Pruning Utilities (Last Layers) ---\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx+1 for idx, _ in sorted_layers[:num_prune] if idx+1 < len(layers)]\n",
        "    for idx in prune_idxs:\n",
        "        layer = layers[idx]\n",
        "        layer.intermediate.dense = nn.Identity()\n",
        "        layer.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "# --- 7. Eval / Accuracy Helper ---\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            ids  = batch[\"input_ids\"].to(device)\n",
        "            mask = batch[\"attention_mask\"].to(device)\n",
        "            labels = batch[\"labels\"].cpu().numpy().tolist()\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().tolist())\n",
        "            labs.extend([int(l) for l in labels])\n",
        "    acc = sum(p==l for p,l in zip(preds,labs)) / len(labs)\n",
        "    return acc\n",
        "\n",
        "# --- 8. Training + ER Tracking + Pruning ---\n",
        "def full_finetuning(train_loader, dev_loader, device):\n",
        "    print(\"=== Stage 1: Full Finetuning & ER Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=num_labels\n",
        "    ).to(device)\n",
        "    opt    = torch.optim.AdamW(model.parameters(), lr=2e-2)\n",
        "    sched  = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*3)\n",
        "    scaler = GradScaler()\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        er_sums, er_cnts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(\n",
        "                    input_ids=batch[\"input_ids\"].to(device),\n",
        "                    attention_mask=batch[\"attention_mask\"].to(device),\n",
        "                    labels=batch[\"labels\"].to(device)\n",
        "                )\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "            batch_er = compute_batch_entropy(activations)\n",
        "            for i,v in batch_er.items():\n",
        "                er_sums[i] += v\n",
        "                er_cnts[i] += 1\n",
        "        last_er = {i: er_sums[i]/er_cnts[i] for i in er_sums if er_cnts[i]>0}\n",
        "        print(f\"[Epoch {epoch+1}] approx ER: {last_er}\")\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Epoch {epoch+1}] SVAMP Dev Acc: {acc:.4f}\")\n",
        "\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_er\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores):\n",
        "    print(\"=== Stage 2: Prune (High-ER 'Decoder') & Finetune ===\")\n",
        "    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(\"Pruned last (top-ER) layers:\", prune_idxs)\n",
        "    opt   = torch.optim.AdamW(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*2)\n",
        "    for epoch in range(3):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(\n",
        "                input_ids=batch[\"input_ids\"].to(device),\n",
        "                attention_mask=batch[\"attention_mask\"].to(device),\n",
        "                labels=batch[\"labels\"].to(device)\n",
        "            )\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] SVAMP Dev Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 9. Entrypoint ---\n",
        "def main():\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    model, er_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    _ = prune_and_finetuning(model, train_loader, dev_loader, device, er_scores)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "FqgHsNTYINey"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "wrQZcS9uJwy1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Ablation\n",
        "\n",
        "# --- Mount Google Drive if using Colab ---\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "import random, numpy as np, torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from datasets import load_dataset\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict, Counter\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\")\n",
        "SEED = 42\n",
        "random.seed(SEED)\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "\n",
        "# --- 1. Load SVAMP Dataset ---\n",
        "data_files = {\n",
        "    \"train\": \"/content/drive/MyDrive/NLP_datasets/svamp/svamp_train.json\",\n",
        "    \"test\":  \"/content/drive/MyDrive/NLP_datasets/svamp/svamp_test.json\"\n",
        "}\n",
        "dataset = load_dataset(\"json\", data_files=data_files)\n",
        "\n",
        "# --- 2. Assign SVAMP keys ---\n",
        "ques_key = \"input\"   # SVAMP uses \"input\" for questions\n",
        "ans_key  = \"label\"   # and \"label\" for answers\n",
        "\n",
        "# --- 3. Top-N Answer Classification Setup ---\n",
        "TOP_N = 100  # Try 50, 100, 200 for ablation\n",
        "train_answers = [str(x) for x in dataset[\"train\"][ans_key]]\n",
        "counts = Counter(train_answers)\n",
        "most_common_answers = [a for a, _ in counts.most_common(TOP_N)]\n",
        "label2idx = {ans: i+1 for i, ans in enumerate(most_common_answers)}  # 0 is \"other\"\n",
        "num_labels = len(label2idx) + 1\n",
        "\n",
        "def map_label(x):\n",
        "    return label2idx.get(str(x), 0)\n",
        "\n",
        "# --- 4. Preprocessing for RoBERTa ---\n",
        "tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "def preprocess_svamp(batch, tokenizer, max_length=128):\n",
        "    questions = batch[ques_key]\n",
        "    enc = tokenizer(\n",
        "        questions, padding=\"max_length\", truncation=True, max_length=max_length\n",
        "    )\n",
        "    enc[\"labels\"] = [map_label(x) for x in batch[ans_key]]\n",
        "    return enc\n",
        "\n",
        "train = dataset[\"train\"].map(\n",
        "    lambda ex: preprocess_svamp(ex, tokenizer),\n",
        "    batched=True,\n",
        "    remove_columns=dataset[\"train\"].column_names\n",
        ")\n",
        "dev   = dataset[\"test\"].map(\n",
        "    lambda ex: preprocess_svamp(ex, tokenizer),\n",
        "    batched=True,\n",
        "    remove_columns=dataset[\"test\"].column_names\n",
        ")\n",
        "\n",
        "collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=128)\n",
        "train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "# --- 5. ER Hook Utilities (Encoder Only) ---\n",
        "\n",
        "def register_er_hooks(model):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {\n",
        "        i: {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}\n",
        "        for i in range(len(layers)-1)\n",
        "    }\n",
        "    hooks = []\n",
        "    for i in range(len(layers)-1):\n",
        "        def hook_x(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_X'] = out.detach()\n",
        "        def hook_y(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_Y'] = out.detach()\n",
        "        hooks.append(layers[i].output.dense.register_forward_hook(hook_x))\n",
        "        hooks.append(layers[i+1].output.dense.register_forward_hook(hook_y))\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks: h.remove()\n",
        "\n",
        "\n",
        "\n",
        "def compute_batch_entropy(activations):\n",
        "    \"\"\"Score each hooked layer pair from the change between consecutive batches.\n",
        "\n",
        "    For every entry of ``activations`` the previous and current captures are\n",
        "    differenced (``dX``, ``dY``), flattened per sample, and the score is\n",
        "    ``sum_i cos^2(dY[i], dX[i]) / (2 * (B - 1))`` over samples ``i = 1..B-1``.\n",
        "    NOTE(review): sample 0 is excluded from the sum -- confirm this is intended.\n",
        "\n",
        "    After each entry is processed, its current captures become the previous\n",
        "    ones and the ``curr_X`` / ``curr_Y`` slots are reset to None.\n",
        "\n",
        "    Args:\n",
        "        activations: dict mapping pair index to a buffer dict with the keys\n",
        "            ``prev_X``, ``prev_Y``, ``curr_X``, ``curr_Y`` (tensors or None).\n",
        "            Mutated in place.\n",
        "\n",
        "    Returns:\n",
        "        dict mapping pair index to a float score. Pairs whose captures are\n",
        "        missing or whose shapes differ between the two batches are omitted.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "        have_all = all(t is not None for t in (X_prev, Y_prev, X_curr, Y_curr))\n",
        "        # Compare full shapes, not just the batch dimension: with dynamic\n",
        "        # padding two batches can have different sequence lengths, and the\n",
        "        # element-wise difference below would then fail (or silently broadcast).\n",
        "        comparable = (have_all and\n",
        "                      X_prev.shape == X_curr.shape and\n",
        "                      Y_prev.shape == Y_curr.shape)\n",
        "        if comparable:\n",
        "            B = X_curr.size(0)\n",
        "            if B < 2:\n",
        "                er = 0.0\n",
        "            else:\n",
        "                dX = (X_curr - X_prev).view(B, -1)\n",
        "                dY = (Y_curr - Y_prev).view(B, -1)\n",
        "                # One batched call and a single host transfer instead of a\n",
        "                # device sync per sample.\n",
        "                cos_sq = F.cosine_similarity(dY[1:], dX[1:], dim=1, eps=1e-8) ** 2\n",
        "                er = sum(cos_sq.tolist()) / (2 * (B - 1))\n",
        "            er_scores[idx] = er\n",
        "        # Rotate buffers whether or not a score was produced.\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "    return er_scores\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "# --- 6. Pruning Utilities (Last Layers) ---\n",
        "class SkipFF(nn.Module):\n",
        "    \"\"\"Pass-through module: ignores ``hidden_states`` and returns ``input_tensor``.\n",
        "\n",
        "    Returns None when ``input_tensor`` is not supplied.\n",
        "    \"\"\"\n",
        "\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        passthrough = input_tensor\n",
        "        return passthrough\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    \"\"\"Bypass the feed-forward block of the layers that follow the top-ER pairs.\n",
        "\n",
        "    ``er_scores`` maps a pair index ``i`` to a score; the ``num_prune`` highest\n",
        "    scoring pairs are selected and, for each, encoder layer ``i + 1`` gets its\n",
        "    ``intermediate.dense`` replaced by ``nn.Identity()`` and its ``output``\n",
        "    replaced by ``SkipFF()``. The model is modified in place.\n",
        "\n",
        "    Returns:\n",
        "        list of the pruned encoder-layer indices, highest score first.\n",
        "    \"\"\"\n",
        "    encoder_layers = model.roberta.encoder.layer\n",
        "    ranked = list(er_scores.items())\n",
        "    ranked.sort(key=lambda item: item[1], reverse=True)\n",
        "\n",
        "    prune_idxs = []\n",
        "    for pair_idx, _score in ranked[:num_prune]:\n",
        "        target = pair_idx + 1\n",
        "        if target < len(encoder_layers):\n",
        "            prune_idxs.append(target)\n",
        "\n",
        "    for target in prune_idxs:\n",
        "        block = encoder_layers[target]\n",
        "        block.intermediate.dense = nn.Identity()\n",
        "        block.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "# --- 7. Eval / Accuracy Helper ---\n",
        "def evaluate_model(model, dl, device):\n",
        "    \"\"\"Return the accuracy of ``model`` over all batches of ``dl``.\n",
        "\n",
        "    Switches the model to eval mode (it is not switched back), runs every\n",
        "    batch without gradient tracking, and compares the argmax of the logits\n",
        "    with ``batch[\"labels\"]``.\n",
        "\n",
        "    Args:\n",
        "        model: callable taking ``input_ids`` / ``attention_mask`` keyword\n",
        "            arguments and returning an object with a ``logits`` attribute.\n",
        "        dl: iterable of dict batches with \"input_ids\", \"attention_mask\"\n",
        "            and \"labels\" tensors.\n",
        "        device: device the inputs are moved to.\n",
        "\n",
        "    Returns:\n",
        "        Fraction of examples whose predicted class equals the label.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    predicted, gold = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch[\"input_ids\"].to(device)\n",
        "            attention_mask = batch[\"attention_mask\"].to(device)\n",
        "            batch_gold = batch[\"labels\"].cpu().numpy().tolist()\n",
        "            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits\n",
        "            predicted += torch.argmax(logits, -1).cpu().tolist()\n",
        "            gold += [int(g) for g in batch_gold]\n",
        "    num_correct = sum(1 for p, g in zip(predicted, gold) if p == g)\n",
        "    return num_correct / len(gold)\n",
        "\n",
        "# --- 8. Training + ER Tracking + Pruning ---\n",
        "def full_finetuning(train_loader, dev_loader, device, num_epochs=6, lr=2e-3):\n",
        "    \"\"\"Stage 1: fine-tune roberta-base while tracking per-layer-pair ER scores.\n",
        "\n",
        "    Loads ``roberta-base`` with the module-level ``num_labels``, trains it with\n",
        "    AdamW under autocast / GradScaler, and after every optimizer step passes\n",
        "    the hooked activations to ``compute_batch_entropy``. The per-epoch mean\n",
        "    scores and the dev accuracy are printed after each epoch.\n",
        "\n",
        "    Args:\n",
        "        train_loader: iterable of dict batches with \"input_ids\",\n",
        "            \"attention_mask\" and \"labels\" tensors; must support ``len()``.\n",
        "        dev_loader: batches handed to ``evaluate_model`` after each epoch.\n",
        "        device: torch device the model and batches are moved to.\n",
        "        num_epochs: number of passes over ``train_loader``.\n",
        "        lr: AdamW learning rate. NOTE(review): the 2e-3 default is 100x the\n",
        "            Stage 2 rate (2e-5) -- confirm it is intentional.\n",
        "\n",
        "    Returns:\n",
        "        (model, last_er): the fine-tuned model and a dict mapping pair index\n",
        "        to the mean score of the final epoch (empty if no epoch ran).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Finetuning & ER Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=num_labels\n",
        "    ).to(device)\n",
        "    opt    = torch.optim.AdamW(model.parameters(), lr=lr)\n",
        "    # The schedule must span every optimizer step of the loop below; a shorter\n",
        "    # horizon decays the learning rate to 0 and the remaining epochs learn nothing.\n",
        "    sched  = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    scaler = GradScaler()\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = {}\n",
        "\n",
        "    try:\n",
        "        for epoch in range(num_epochs):\n",
        "            er_sums, er_cnts = defaultdict(float), defaultdict(int)\n",
        "            model.train()\n",
        "            for batch in train_loader:\n",
        "                opt.zero_grad()\n",
        "                with autocast():\n",
        "                    out = model(\n",
        "                        input_ids=batch[\"input_ids\"].to(device),\n",
        "                        attention_mask=batch[\"attention_mask\"].to(device),\n",
        "                        labels=batch[\"labels\"].to(device)\n",
        "                    )\n",
        "                # Backward runs outside the autocast region, as recommended for AMP.\n",
        "                scaler.scale(out.loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "                batch_er = compute_batch_entropy(activations)\n",
        "                for i, v in batch_er.items():\n",
        "                    er_sums[i] += v\n",
        "                    er_cnts[i] += 1\n",
        "            last_er = {i: er_sums[i]/er_cnts[i] for i in er_sums if er_cnts[i]>0}\n",
        "            print(f\"[Epoch {epoch+1}] approx ER: {last_er}\")\n",
        "            acc = evaluate_model(model, dev_loader, device)\n",
        "            print(f\"[Epoch {epoch+1}] SVAMP Dev Acc: {acc:.4f}\")\n",
        "    finally:\n",
        "        # Always detach the hooks so a failed run does not leave them (and the\n",
        "        # captured activations) attached to the model.\n",
        "        remove_hooks(hooks)\n",
        "    return model, last_er\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores,\n",
        "                         num_epochs=3, num_prune=4, lr=2e-5):\n",
        "    \"\"\"Stage 2: prune the top-ER layers of ``model`` and fine-tune what remains.\n",
        "\n",
        "    Calls ``prune_er_layers`` (which modifies ``model`` in place), then trains\n",
        "    with AdamW and a linear decay schedule, printing dev accuracy per epoch.\n",
        "\n",
        "    Args:\n",
        "        model: model returned by Stage 1.\n",
        "        train_loader: iterable of dict batches with \"input_ids\",\n",
        "            \"attention_mask\" and \"labels\" tensors; must support ``len()``.\n",
        "        dev_loader: batches handed to ``evaluate_model`` after each epoch.\n",
        "        device: torch device the batches are moved to.\n",
        "        er_scores: dict mapping pair index to score; None is treated as empty.\n",
        "        num_epochs: number of passes over ``train_loader``.\n",
        "        num_prune: how many top-scoring pairs to prune.\n",
        "        lr: AdamW learning rate.\n",
        "\n",
        "    Returns:\n",
        "        The same ``model`` object, pruned and fine-tuned.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High-ER 'Decoder') & Finetune ===\")\n",
        "    prune_idxs = prune_er_layers(model, er_scores or {}, num_prune=num_prune)\n",
        "    print(\"Pruned last (top-ER) layers:\", prune_idxs)\n",
        "    opt   = torch.optim.AdamW(model.parameters(), lr=lr)\n",
        "    # The schedule must span every optimizer step of the loop below; a shorter\n",
        "    # horizon decays the learning rate to 0 before the final epoch.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(\n",
        "                input_ids=batch[\"input_ids\"].to(device),\n",
        "                attention_mask=batch[\"attention_mask\"].to(device),\n",
        "                labels=batch[\"labels\"].to(device)\n",
        "            )\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] SVAMP Dev Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 9. Entrypoint ---\n",
        "def main():\n",
        "    \"\"\"Run Stage 1 (fine-tune + ER estimation), then Stage 2 (prune + fine-tune).\n",
        "\n",
        "    Reads the module-level ``train_loader`` and ``dev_loader`` and uses CUDA\n",
        "    when ``torch.cuda.is_available()`` reports it, otherwise the CPU.\n",
        "    \"\"\"\n",
        "    use_cuda = torch.cuda.is_available()\n",
        "    device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
        "    finetuned, layer_er = full_finetuning(train_loader, dev_loader, device)\n",
        "    prune_and_finetuning(finetuned, train_loader, dev_loader, device, layer_er)\n",
        "\n",
        "# Start the two-stage pipeline only when this code runs as the main module.\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "MSloqrbrJ1N0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "4m3e5ZL8J0FH"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}