{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uioigcUKQDOd"
      },
      "outputs": [],
      "source": [
        "!pip uninstall -y datasets\n",
        "!pip install datasets==2.18.0\n",
        "!pip install evaluate\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Fan, A., Grave, E., & Joulin, A. (2019). \"Reducing Transformer Depth on Demand with Structured Dropout\"."
      ],
      "metadata": {
        "id": "yM0_mVQKTBbU"
      }
    },
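    {
      "cell_type": "markdown",
      "source": [
        "The cells below use LayerDrop only as *post-hoc* pruning (Stage 2). The paper's central idea is to also drop layers stochastically *during training*, which makes the network robust to depth reduction at inference time. Below is a minimal sketch of that training-time variant, assuming the same `roberta-base` structure used in the cells that follow; it is illustrative only (it ignores the `output_attentions` plumbing and is not the paper's exact implementation)."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "class LayerDropEncoderLayer(nn.Module):\n",
        "    \"\"\"Wraps a transformer block; during training, skips it with probability p.\"\"\"\n",
        "    def __init__(self, layer, p=0.2):\n",
        "        super().__init__()\n",
        "        self.layer = layer\n",
        "        self.p = p\n",
        "\n",
        "    def forward(self, hidden_states, *args, **kwargs):\n",
        "        # Drop the whole block with probability p at train time only; the\n",
        "        # returned tuple mirrors what the encoder expects from a RobertaLayer.\n",
        "        if self.training and torch.rand(1).item() < self.p:\n",
        "            return (hidden_states,)\n",
        "        return self.layer(hidden_states, *args, **kwargs)\n",
        "\n",
        "def wrap_with_layerdrop(model, p=0.2):\n",
        "    # Hypothetical helper: wrap every encoder block of a RoBERTa model.\n",
        "    layers = model.roberta.encoder.layer\n",
        "    for i in range(len(layers)):\n",
        "        layers[i] = LayerDropEncoderLayer(layers[i], p=p)\n",
        "    return model\n"
      ]
    },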
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import numpy as np\n",
        "import random\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# --- LayerSkip utility for pruning entire transformer blocks ---\n",
        "class SkipLayer(nn.Module):\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "\n",
        "    def forward(\n",
        "        self,\n",
        "        hidden_states,\n",
        "        attention_mask=None,\n",
        "        head_mask=None,\n",
        "        encoder_hidden_states=None,\n",
        "        encoder_attention_mask=None,\n",
        "        past_key_value=None,\n",
        "        output_attentions=False,\n",
        "        output_hidden_states=False,\n",
        "        return_dict=True,\n",
        "    ):\n",
        "        if return_dict:\n",
        "            from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions\n",
        "            return BaseModelOutputWithPastAndCrossAttentions(\n",
        "                last_hidden_state=hidden_states,\n",
        "                hidden_states=None,\n",
        "                attentions=None,\n",
        "                past_key_values=None,\n",
        "                cross_attentions=None\n",
        "            )\n",
        "        else:\n",
        "            return (hidden_states, None, None)\n",
        "\n",
        "# --- Data/Eval helpers ---\n",
        "def preprocess_function(examples, tok, max_length=128):\n",
        "    return tok(examples['premise'],\n",
        "               examples['hypothesis'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "# --- Standard fine-tuning ---\n",
        "def finetune_model(model, train_loader, dev_loader, device, epochs):\n",
        "    model.train()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*epochs)\n",
        "    scaler = GradScaler()\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Epoch {epoch+1}] MNLI Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- LayerDrop-style layer pruning ---\n",
        "def layerdrop_prune(model, num_prune=4, seed=42):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    total_layers = len(layers)\n",
        "    hidden_size = layers[0].output.dense.out_features\n",
        "\n",
        "    # Randomly choose layers to drop\n",
        "    rng = np.random.default_rng(seed)\n",
        "    prune_idxs = rng.choice(total_layers, size=num_prune, replace=False)\n",
        "    prune_idxs = sorted(list(prune_idxs))\n",
        "    print(f\"LayerDrop: Pruning layers {prune_idxs} out of {total_layers}\")\n",
        "\n",
        "    for idx in prune_idxs:\n",
        "        # Replace the whole transformer block with identity\n",
        "        layers[idx] = SkipLayer(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "# --- Main ---\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # Load MNLI (subsample for speed)\n",
        "    train_ds = load_dataset(\"glue\", \"mnli\", split=\"train\").shuffle(seed).select(range(2000))\n",
        "    dev_ds   = load_dataset(\"glue\", \"mnli\", split=\"validation_matched\").select(range(1000))\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True,\n",
        "                         remove_columns=[\"premise\",\"hypothesis\",\"idx\"])\\\n",
        "                    .rename_column(\"label\",\"labels\")\n",
        "    dev   = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                       batched=True,\n",
        "                       remove_columns=[\"premise\",\"hypothesis\",\"idx\"])\\\n",
        "                  .rename_column(\"label\",\"labels\")\n",
        "\n",
        "    collator     = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=128)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    # --- Stage 1: Full fine-tuning ---\n",
        "    print(\"\\n=== Stage 1: Full Fine-Tuning (No Pruning) ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\"roberta-base\", num_labels=3).to(device)\n",
        "    model = finetune_model(model, train_loader, dev_loader, device, epochs=6)\n",
        "    acc_full = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"\\nAccuracy after 6-epoch full fine-tuning: {acc_full:.4f}\")\n",
        "\n",
        "    # --- Stage 2: LayerDrop Pruning ---\n",
        "    print(\"\\n=== Stage 2: LayerDrop Pruning (Remove 4 layers) ===\")\n",
        "    prune_idxs = layerdrop_prune(model, num_prune=4, seed=seed)\n",
        "\n",
        "    # --- Stage 3: Fine-tune pruned model ---\n",
        "    print(\"\\n=== Stage 3: Fine-Tune Pruned Model (5 epochs) ===\")\n",
        "    model = finetune_model(model, train_loader, dev_loader, device, epochs=5)\n",
        "    acc_pruned = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"\\nAccuracy after 5-epoch post-pruning fine-tuning: {acc_pruned:.4f}\")\n",
        "\n",
        "    print(f\"Pruned layer indices: {prune_idxs}\")\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "cXEzz9SrR7Mz"
      },
      "execution_count": null,
      "outputs": []
    },
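    {
      "cell_type": "markdown",
      "source": [
        "Because `SkipLayer` holds no weights, replacing a block removes its parameters from the module tree. A quick size check (a sketch; assumes you keep a reference to the `model` object, e.g. by returning it from `main()`):"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def count_params(m):\n",
        "    # Total parameter count; should drop by roughly four encoder blocks' worth\n",
        "    # of weights after layerdrop_prune replaces them with SkipLayer.\n",
        "    return sum(p.numel() for p in m.parameters())\n",
        "\n",
        "# Example usage, given a `model` reference:\n",
        "# print(f\"{count_params(model) / 1e6:.1f}M parameters\")\n"
      ]
    },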
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import numpy as np\n",
        "import random\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "# --- LayerSkip utility for pruning entire transformer blocks ---\n",
        "\n",
        "class SkipLayer(nn.Module):\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "\n",
        "    def forward(\n",
        "        self,\n",
        "        hidden_states,\n",
        "        attention_mask=None,\n",
        "        head_mask=None,\n",
        "        encoder_hidden_states=None,\n",
        "        encoder_attention_mask=None,\n",
        "        past_key_value=None,\n",
        "        output_attentions=False,\n",
        "        output_hidden_states=False,\n",
        "        return_dict=True,\n",
        "    ):\n",
        "        if return_dict:\n",
        "            from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions\n",
        "            return BaseModelOutputWithPastAndCrossAttentions(\n",
        "                last_hidden_state=hidden_states,\n",
        "                hidden_states=None,\n",
        "                attentions=None,\n",
        "                past_key_values=None,\n",
        "                cross_attentions=None\n",
        "            )\n",
        "        else:\n",
        "            return (hidden_states, None, None)\n",
        "\n",
        "\n",
        "# --- Data/Eval helpers ---\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(examples['sentence1'],\n",
        "               examples['sentence2'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "# --- Standard fine-tuning ---\n",
        "def finetune_model(model, train_loader, dev_loader, device, epochs):\n",
        "    model.train()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*epochs)\n",
        "    scaler = GradScaler()\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Epoch {epoch+1}] MRPC Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- LayerDrop-style layer pruning ---\n",
        "def layerdrop_prune(model, num_prune=4, seed=42):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    total_layers = len(layers)\n",
        "    hidden_size = layers[0].output.dense.out_features\n",
        "\n",
        "    # Randomly choose layers to drop\n",
        "    rng = np.random.default_rng(seed)\n",
        "    prune_idxs = rng.choice(total_layers, size=num_prune, replace=False)\n",
        "    prune_idxs = sorted(list(prune_idxs))\n",
        "    print(f\"LayerDrop: Pruning layers {prune_idxs} out of {total_layers}\")\n",
        "\n",
        "    for idx in prune_idxs:\n",
        "        # Replace the whole transformer block with identity\n",
        "        layers[idx] = SkipLayer(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "# --- Main ---\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # Load MRPC (subsample for speed)\n",
        "    train_ds = load_dataset(\"glue\", \"mrpc\", split=\"train\").shuffle(seed).select(range(1000))\n",
        "    dev_ds   = load_dataset(\"glue\", \"mrpc\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True,\n",
        "                         remove_columns=[\"sentence1\",\"sentence2\",\"idx\"])\\\n",
        "                    .rename_column(\"label\",\"labels\")\n",
        "    dev   = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                       batched=True,\n",
        "                       remove_columns=[\"sentence1\",\"sentence2\",\"idx\"])\\\n",
        "                  .rename_column(\"label\",\"labels\")\n",
        "\n",
        "    collator     = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    # --- Stage 1: Full fine-tuning ---\n",
        "    print(\"\\n=== Stage 1: Full Fine-Tuning (No Pruning) ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\"roberta-base\", num_labels=2).to(device)\n",
        "    model = finetune_model(model, train_loader, dev_loader, device, epochs=6)\n",
        "    acc_full = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"\\nAccuracy after 6-epoch full fine-tuning: {acc_full:.4f}\")\n",
        "\n",
        "    # --- Stage 2: LayerDrop Pruning ---\n",
        "    print(\"\\n=== Stage 2: LayerDrop Pruning (Remove 4 layers) ===\")\n",
        "    prune_idxs = layerdrop_prune(model, num_prune=4, seed=seed)\n",
        "\n",
        "    # --- Stage 3: Fine-tune pruned model ---\n",
        "    print(\"\\n=== Stage 3: Fine-Tune Pruned Model (5 epochs) ===\")\n",
        "    model = finetune_model(model, train_loader, dev_loader, device, epochs=5)\n",
        "    acc_pruned = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"\\nAccuracy after 5-epoch post-pruning fine-tuning: {acc_pruned:.4f}\")\n",
        "\n",
        "    print(f\"Pruned layer indices: {prune_idxs}\")\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "lh4bOVG8QD_5"
      },
      "execution_count": null,
      "outputs": []
    },
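    {
      "cell_type": "markdown",
      "source": [
        "Removing four of twelve encoder blocks should cut encoder compute by roughly a third. A rough latency probe (a sketch; assumes the `model`, `tokenizer`, and `device` objects from one of the cells above are still in scope):"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import time\n",
        "import torch\n",
        "\n",
        "@torch.no_grad()\n",
        "def time_forward(model, tokenizer, device, n_iters=50):\n",
        "    model.eval()\n",
        "    batch = tokenizer([\"a quick latency probe\"] * 16,\n",
        "                      padding=\"max_length\", max_length=64,\n",
        "                      return_tensors=\"pt\").to(device)\n",
        "    # Warm up, then time; synchronize so queued GPU work is counted.\n",
        "    for _ in range(5):\n",
        "        model(**batch)\n",
        "    if device.type == \"cuda\":\n",
        "        torch.cuda.synchronize()\n",
        "    t0 = time.perf_counter()\n",
        "    for _ in range(n_iters):\n",
        "        model(**batch)\n",
        "    if device.type == \"cuda\":\n",
        "        torch.cuda.synchronize()\n",
        "    return (time.perf_counter() - t0) / n_iters\n"
      ]
    },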
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import numpy as np\n",
        "import random\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "class SkipLayer(nn.Module):\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "\n",
        "    def forward(\n",
        "        self,\n",
        "        hidden_states,\n",
        "        attention_mask=None,\n",
        "        head_mask=None,\n",
        "        encoder_hidden_states=None,\n",
        "        encoder_attention_mask=None,\n",
        "        past_key_value=None,\n",
        "        output_attentions=False,\n",
        "        output_hidden_states=False,\n",
        "        return_dict=True,\n",
        "    ):\n",
        "        if return_dict:\n",
        "            from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions\n",
        "            return BaseModelOutputWithPastAndCrossAttentions(\n",
        "                last_hidden_state=hidden_states,\n",
        "                hidden_states=None,\n",
        "                attentions=None,\n",
        "                past_key_values=None,\n",
        "                cross_attentions=None\n",
        "            )\n",
        "        else:\n",
        "            return (hidden_states, None, None)\n",
        "\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    # SST-2 has only 'sentence'\n",
        "    return tok(examples['sentence'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "def finetune_model(model, train_loader, dev_loader, device, epochs):\n",
        "    model.train()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*epochs)\n",
        "    scaler = GradScaler()\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Epoch {epoch+1}] SST-2 Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "def layerdrop_prune(model, num_prune=4, seed=42):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    total_layers = len(layers)\n",
        "    hidden_size = layers[0].output.dense.out_features\n",
        "\n",
        "    rng = np.random.default_rng(seed)\n",
        "    prune_idxs = rng.choice(total_layers, size=num_prune, replace=False)\n",
        "    prune_idxs = sorted(list(prune_idxs))\n",
        "    print(f\"LayerDrop: Pruning layers {prune_idxs} out of {total_layers}\")\n",
        "\n",
        "    for idx in prune_idxs:\n",
        "        layers[idx] = SkipLayer(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "\n",
        "# load & preprocess SST-2 subset\n",
        "    train_ds = load_dataset(\"glue\", \"sst2\", split=\"train\").shuffle(seed).select(range(5000))\n",
        "    dev_ds   = load_dataset(\"glue\", \"sst2\", split=\"validation\")\n",
        "\n",
        "\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True,\n",
        "                         remove_columns=[\"sentence\"])\\\n",
        "                    .rename_column(\"label\",\"labels\")\n",
        "    dev   = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                       batched=True,\n",
        "                       remove_columns=[\"sentence\"])\\\n",
        "                  .rename_column(\"label\",\"labels\")\n",
        "\n",
        "    collator     = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    # --- Stage 1: Full fine-tuning ---\n",
        "    print(\"\\n=== Stage 1: Full Fine-Tuning (No Pruning) ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\"roberta-base\", num_labels=2).to(device)\n",
        "    model = finetune_model(model, train_loader, dev_loader, device, epochs=6)\n",
        "    acc_full = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"\\nAccuracy after 6-epoch full fine-tuning: {acc_full:.4f}\")\n",
        "\n",
        "    # --- Stage 2: LayerDrop Pruning ---\n",
        "    print(\"\\n=== Stage 2: LayerDrop Pruning (Remove 4 layers) ===\")\n",
        "    prune_idxs = layerdrop_prune(model, num_prune=4, seed=seed)\n",
        "\n",
        "    # --- Stage 3: Fine-tune pruned model ---\n",
        "    print(\"\\n=== Stage 3: Fine-Tune Pruned Model (5 epochs) ===\")\n",
        "    model = finetune_model(model, train_loader, dev_loader, device, epochs=5)\n",
        "    acc_pruned = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"\\nAccuracy after 5-epoch post-pruning fine-tuning: {acc_pruned:.4f}\")\n",
        "\n",
        "    print(f\"Pruned layer indices: {prune_idxs}\")\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "VtVihkTSQEiH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import numpy as np\n",
        "import random\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "class SkipLayer(nn.Module):\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "\n",
        "    def forward(\n",
        "        self,\n",
        "        hidden_states,\n",
        "        attention_mask=None,\n",
        "        head_mask=None,\n",
        "        encoder_hidden_states=None,\n",
        "        encoder_attention_mask=None,\n",
        "        past_key_value=None,\n",
        "        output_attentions=False,\n",
        "        output_hidden_states=False,\n",
        "        return_dict=True,\n",
        "    ):\n",
        "        if return_dict:\n",
        "            from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions\n",
        "            return BaseModelOutputWithPastAndCrossAttentions(\n",
        "                last_hidden_state=hidden_states,\n",
        "                hidden_states=None,\n",
        "                attentions=None,\n",
        "                past_key_values=None,\n",
        "                cross_attentions=None\n",
        "            )\n",
        "        else:\n",
        "            return (hidden_states, None, None)\n",
        "\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(examples['sentence'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "def finetune_model(model, train_loader, dev_loader, device, epochs):\n",
        "    model.train()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*epochs)\n",
        "    scaler = GradScaler()\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Epoch {epoch+1}] CoLA Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "def layerdrop_prune(model, num_prune=4, seed=42):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    total_layers = len(layers)\n",
        "    hidden_size = layers[0].output.dense.out_features\n",
        "\n",
        "    rng = np.random.default_rng(seed)\n",
        "    prune_idxs = rng.choice(total_layers, size=num_prune, replace=False)\n",
        "    prune_idxs = sorted(list(prune_idxs))\n",
        "    print(f\"LayerDrop: Pruning layers {prune_idxs} out of {total_layers}\")\n",
        "\n",
        "    for idx in prune_idxs:\n",
        "        layers[idx] = SkipLayer(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # Load CoLA\n",
        "    train_ds = load_dataset(\"glue\", \"cola\", split=\"train\")\n",
        "    dev_ds   = load_dataset(\"glue\", \"cola\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True,\n",
        "                         remove_columns=[\"sentence\"])\\\n",
        "                    .rename_column(\"label\",\"labels\")\n",
        "    dev   = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                       batched=True,\n",
        "                       remove_columns=[\"sentence\"])\\\n",
        "                  .rename_column(\"label\",\"labels\")\n",
        "\n",
        "    collator     = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    # --- Stage 1: Full fine-tuning ---\n",
        "    print(\"\\n=== Stage 1: Full Fine-Tuning (No Pruning) ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\"roberta-base\", num_labels=2).to(device)\n",
        "    model = finetune_model(model, train_loader, dev_loader, device, epochs=6)\n",
        "    acc_full = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"\\nAccuracy after 6-epoch full fine-tuning: {acc_full:.4f}\")\n",
        "\n",
        "    # --- Stage 2: LayerDrop Pruning ---\n",
        "    print(\"\\n=== Stage 2: LayerDrop Pruning (Remove 4 layers) ===\")\n",
        "    prune_idxs = layerdrop_prune(model, num_prune=4, seed=seed)\n",
        "\n",
        "    # --- Stage 3: Fine-tune pruned model ---\n",
        "    print(\"\\n=== Stage 3: Fine-Tune Pruned Model (5 epochs) ===\")\n",
        "    model = finetune_model(model, train_loader, dev_loader, device, epochs=5)\n",
        "    acc_pruned = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"\\nAccuracy after 5-epoch post-pruning fine-tuning: {acc_pruned:.4f}\")\n",
        "\n",
        "    print(f\"Pruned layer indices: {prune_idxs}\")\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "UqwgZfYITOao"
      },
      "execution_count": null,
      "outputs": []
    },
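    {
      "cell_type": "markdown",
      "source": [
        "Note: GLUE scores CoLA with Matthews correlation (MCC) rather than accuracy, since the skewed label distribution makes accuracy read flattering. A drop-in change for `evaluate_model` in the cell above (a sketch):"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import evaluate\n",
        "\n",
        "# Swap the accuracy metric for Matthews correlation when evaluating CoLA;\n",
        "# inside evaluate_model, replace the final line with:\n",
        "# return metric.compute(predictions=preds, references=labs)[\"matthews_correlation\"]\n",
        "metric = evaluate.load(\"matthews_correlation\")\n"
      ]
    },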
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import numpy as np\n",
        "import random\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "class SkipLayer(nn.Module):\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "\n",
        "    def forward(\n",
        "        self,\n",
        "        hidden_states,\n",
        "        attention_mask=None,\n",
        "        head_mask=None,\n",
        "        encoder_hidden_states=None,\n",
        "        encoder_attention_mask=None,\n",
        "        past_key_value=None,\n",
        "        output_attentions=False,\n",
        "        output_hidden_states=False,\n",
        "        return_dict=True,\n",
        "    ):\n",
        "        if return_dict:\n",
        "            from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions\n",
        "            return BaseModelOutputWithPastAndCrossAttentions(\n",
        "                last_hidden_state=hidden_states,\n",
        "                hidden_states=None,\n",
        "                attentions=None,\n",
        "                past_key_values=None,\n",
        "                cross_attentions=None\n",
        "            )\n",
        "        else:\n",
        "            return (hidden_states, None, None)\n",
        "\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(examples['question'],\n",
        "               examples['sentence'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "def finetune_model(model, train_loader, dev_loader, device, epochs):\n",
        "    model.train()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*epochs)\n",
        "    scaler = GradScaler()\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Epoch {epoch+1}] QNLI Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "def layerdrop_prune(model, num_prune=4, seed=42):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    total_layers = len(layers)\n",
        "    hidden_size = layers[0].output.dense.out_features\n",
        "\n",
        "    rng = np.random.default_rng(seed)\n",
        "    prune_idxs = rng.choice(total_layers, size=num_prune, replace=False)\n",
        "    prune_idxs = sorted(list(prune_idxs))\n",
        "    print(f\"LayerDrop: Pruning layers {prune_idxs} out of {total_layers}\")\n",
        "\n",
        "    for idx in prune_idxs:\n",
        "        layers[idx] = SkipLayer(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # Load QNLI\n",
        "\n",
        "    train_ds = load_dataset(\"glue\", \"qnli\", split=\"train\").shuffle(seed).select(range(5000))\n",
        "    dev_ds   = load_dataset(\"glue\", \"qnli\", split=\"validation\")\n",
        "\n",
        "\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True,\n",
        "                         remove_columns=[\"question\",\"sentence\"])\\\n",
        "                    .rename_column(\"label\",\"labels\")\n",
        "    dev   = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                       batched=True,\n",
        "                       remove_columns=[\"question\",\"sentence\"])\\\n",
        "                  .rename_column(\"label\",\"labels\")\n",
        "\n",
        "    collator     = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    # --- Stage 1: Full fine-tuning ---\n",
        "    print(\"\\n=== Stage 1: Full Fine-Tuning (No Pruning) ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\"roberta-base\", num_labels=2).to(device)\n",
        "    model = finetune_model(model, train_loader, dev_loader, device, epochs=6)\n",
        "    acc_full = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"\\nAccuracy after 6-epoch full fine-tuning: {acc_full:.4f}\")\n",
        "\n",
        "    # --- Stage 2: LayerDrop Pruning ---\n",
        "    print(\"\\n=== Stage 2: LayerDrop Pruning (Remove 4 layers) ===\")\n",
        "    prune_idxs = layerdrop_prune(model, num_prune=4, seed=seed)\n",
        "\n",
        "    # --- Stage 3: Fine-tune pruned model ---\n",
        "    print(\"\\n=== Stage 3: Fine-Tune Pruned Model (5 epochs) ===\")\n",
        "    model = finetune_model(model, train_loader, dev_loader, device, epochs=5)\n",
        "    acc_pruned = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"\\nAccuracy after 5-epoch post-pruning fine-tuning: {acc_pruned:.4f}\")\n",
        "\n",
        "    print(f\"Pruned layer indices: {prune_idxs}\")\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "_IGVk4G6TOzN"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import numpy as np\n",
        "import random\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "class SkipLayer(nn.Module):\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "\n",
        "    def forward(\n",
        "        self,\n",
        "        hidden_states,\n",
        "        attention_mask=None,\n",
        "        head_mask=None,\n",
        "        encoder_hidden_states=None,\n",
        "        encoder_attention_mask=None,\n",
        "        past_key_value=None,\n",
        "        output_attentions=False,\n",
        "        output_hidden_states=False,\n",
        "        return_dict=True,\n",
        "    ):\n",
        "        if return_dict:\n",
        "            from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions\n",
        "            return BaseModelOutputWithPastAndCrossAttentions(\n",
        "                last_hidden_state=hidden_states,\n",
        "                hidden_states=None,\n",
        "                attentions=None,\n",
        "                past_key_values=None,\n",
        "                cross_attentions=None\n",
        "            )\n",
        "        else:\n",
        "            return (hidden_states, None, None)\n",
        "\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(examples['question1'],\n",
        "               examples['question2'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "def finetune_model(model, train_loader, dev_loader, device, epochs):\n",
        "    model.train()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*epochs)\n",
        "    scaler = GradScaler()\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Epoch {epoch+1}] QQP Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "def layerdrop_prune(model, num_prune=4, seed=42):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    total_layers = len(layers)\n",
        "    hidden_size = layers[0].output.dense.out_features\n",
        "\n",
        "    rng = np.random.default_rng(seed)\n",
        "    prune_idxs = rng.choice(total_layers, size=num_prune, replace=False)\n",
        "    prune_idxs = sorted(list(prune_idxs))\n",
        "    print(f\"LayerDrop: Pruning layers {prune_idxs} out of {total_layers}\")\n",
        "\n",
        "    for idx in prune_idxs:\n",
        "        layers[idx] = SkipLayer(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # Load QQP (subsample for speed; remove .select(range(1000)) for full)\n",
        "    train_ds = load_dataset(\"glue\", \"qqp\", split=\"train\").select(range(2000))\n",
        "    dev_ds   = load_dataset(\"glue\", \"qqp\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True,\n",
        "                         remove_columns=[\"question1\",\"question2\",\"idx\"])\\\n",
        "                    .rename_column(\"label\",\"labels\")\n",
        "    dev   = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                       batched=True,\n",
        "                       remove_columns=[\"question1\",\"question2\",\"idx\"])\\\n",
        "                  .rename_column(\"label\",\"labels\")\n",
        "\n",
        "    collator     = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    # --- Stage 1: Full fine-tuning ---\n",
        "    print(\"\\n=== Stage 1: Full Fine-Tuning (No Pruning) ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\"roberta-base\", num_labels=2).to(device)\n",
        "    model = finetune_model(model, train_loader, dev_loader, device, epochs=6)\n",
        "    acc_full = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"\\nAccuracy after 6-epoch full fine-tuning: {acc_full:.4f}\")\n",
        "\n",
        "    # --- Stage 2: LayerDrop Pruning ---\n",
        "    print(\"\\n=== Stage 2: LayerDrop Pruning (Remove 4 layers) ===\")\n",
        "    prune_idxs = layerdrop_prune(model, num_prune=4, seed=seed)\n",
        "\n",
        "    # --- Stage 3: Fine-tune pruned model ---\n",
        "    print(\"\\n=== Stage 3: Fine-Tune Pruned Model (5 epochs) ===\")\n",
        "    model = finetune_model(model, train_loader, dev_loader, device, epochs=5)\n",
        "    acc_pruned = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"\\nAccuracy after 5-epoch post-pruning fine-tuning: {acc_pruned:.4f}\")\n",
        "\n",
        "    print(f\"Pruned layer indices: {prune_idxs}\")\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "RRptDLVXTPQB"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import numpy as np\n",
        "import random\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import evaluate\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# --- LayerSkip utility for pruning entire transformer blocks ---\n",
        "class SkipLayer(nn.Module):\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "\n",
        "    def forward(\n",
        "        self,\n",
        "        hidden_states,\n",
        "        attention_mask=None,\n",
        "        head_mask=None,\n",
        "        encoder_hidden_states=None,\n",
        "        encoder_attention_mask=None,\n",
        "        past_key_value=None,\n",
        "        output_attentions=False,\n",
        "        output_hidden_states=False,\n",
        "        return_dict=True,\n",
        "    ):\n",
        "        if return_dict:\n",
        "            from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions\n",
        "            return BaseModelOutputWithPastAndCrossAttentions(\n",
        "                last_hidden_state=hidden_states,\n",
        "                hidden_states=None,\n",
        "                attentions=None,\n",
        "                past_key_values=None,\n",
        "                cross_attentions=None\n",
        "            )\n",
        "        else:\n",
        "            return (hidden_states, None, None)\n",
        "\n",
        "# --- Data/Eval helpers ---\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(examples['sentence1'],\n",
        "               examples['sentence2'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "# --- Standard fine-tuning ---\n",
        "def finetune_model(model, train_loader, dev_loader, device, epochs):\n",
        "    model.train()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*epochs)\n",
        "    scaler = GradScaler()\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Epoch {epoch+1}] RTE Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- LayerDrop-style layer pruning ---\n",
        "def layerdrop_prune(model, num_prune=4, seed=42):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    total_layers = len(layers)\n",
        "    hidden_size = layers[0].output.dense.out_features\n",
        "\n",
        "    rng = np.random.default_rng(seed)\n",
        "    prune_idxs = rng.choice(total_layers, size=num_prune, replace=False)\n",
        "    prune_idxs = sorted(list(prune_idxs))\n",
        "    print(f\"LayerDrop: Pruning layers {prune_idxs} out of {total_layers}\")\n",
        "\n",
        "    for idx in prune_idxs:\n",
        "        layers[idx] = SkipLayer(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "# --- Main ---\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # Load RTE (no subsampling, small dataset)\n",
        "    train_ds = load_dataset(\"glue\", \"rte\", split=\"train\")\n",
        "    dev_ds   = load_dataset(\"glue\", \"rte\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True,\n",
        "                         remove_columns=[\"sentence1\",\"sentence2\",\"idx\"])\\\n",
        "                    .rename_column(\"label\",\"labels\")\n",
        "    dev   = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                       batched=True,\n",
        "                       remove_columns=[\"sentence1\",\"sentence2\",\"idx\"])\\\n",
        "                  .rename_column(\"label\",\"labels\")\n",
        "\n",
        "    collator     = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    # --- Stage 1: Full fine-tuning ---\n",
        "    print(\"\\n=== Stage 1: Full Fine-Tuning (No Pruning) ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\"roberta-base\", num_labels=2).to(device)\n",
        "    model = finetune_model(model, train_loader, dev_loader, device, epochs=6)\n",
        "    acc_full = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"\\nAccuracy after 6-epoch full fine-tuning: {acc_full:.4f}\")\n",
        "\n",
        "    # --- Stage 2: LayerDrop Pruning ---\n",
        "    print(\"\\n=== Stage 2: LayerDrop Pruning (Remove 4 layers) ===\")\n",
        "    prune_idxs = layerdrop_prune(model, num_prune=4, seed=seed)\n",
        "\n",
        "    # --- Stage 3: Fine-tune pruned model ---\n",
        "    print(\"\\n=== Stage 3: Fine-Tune Pruned Model (5 epochs) ===\")\n",
        "    model = finetune_model(model, train_loader, dev_loader, device, epochs=5)\n",
        "    acc_pruned = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"\\nAccuracy after 5-epoch post-pruning fine-tuning: {acc_pruned:.4f}\")\n",
        "\n",
        "    print(f\"Pruned layer indices: {prune_idxs}\")\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "iejePYsBTPqD"
      },
      "execution_count": null,
      "outputs": []
    },
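    {
      "cell_type": "markdown",
      "source": [
        "## Variant: FFN-only pruning on STS-B\n",
        "The next cell repeats the three-stage pipeline on GLUE STS-B, a regression task (`num_labels=1`, scored by Pearson/Spearman correlation). Instead of replacing whole transformer blocks, it prunes only the feed-forward sub-layers of fixed blocks [0, 5, 7, 10], and it defines an optional LoRA wrapper for the FFN output projections."
      ],
      "metadata": {}
    },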
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "import random\n",
        "import math\n",
        "import warnings\n",
        "\n",
        "# Monkey-patch numpy.array to ignore the copy argument (for NumPy 2.0 compatibility)\n",
        "_np_array = np.array\n",
        "def _patched_array(obj, *args, copy=False, **kwargs):\n",
        "    return _np_array(obj, *args, **kwargs)\n",
        "np.array = _patched_array\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from datasets import load_dataset\n",
        "\n",
        "from collections import defaultdict\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# 1) SKIP FFN LAYER (for pruning)\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_fixed_layers(model, prune_idxs=[0, 5, 7, 10]):\n",
        "    for idx in prune_idxs:\n",
        "        layer = model.roberta.encoder.layer[idx]\n",
        "        layer.intermediate.dense = nn.Identity()\n",
        "        layer.output = SkipFF()\n",
        "    print(\"Pruned layers:\", prune_idxs)\n",
        "\n",
        "# 2) LoRA block (optional, can be removed if not using LoRA)\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    loras = {}\n",
        "    for idx, layer in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(layer.output, 'dense'):\n",
        "            continue\n",
        "        W0 = layer.output.dense.weight.data\n",
        "        lora = LoRA(W0, r, alpha).to(W0.device)\n",
        "        def fwd(x, layer=layer, lora=lora):\n",
        "            return F.linear(x, lora(), layer.output.dense.bias)\n",
        "        layer.output.dense.forward = fwd\n",
        "        loras[idx] = lora\n",
        "    return loras\n",
        "\n",
        "# 3) STS-B Evaluation\n",
        "def evaluate_stsb(model, dataloader, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"glue\", \"stsb\")\n",
        "    preds, refs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dataloader:\n",
        "            out = model(\n",
        "                input_ids=batch[\"input_ids\"].to(device),\n",
        "                attention_mask=batch[\"attention_mask\"].to(device),\n",
        "            )\n",
        "            # flatten predictions\n",
        "            p = out.logits.squeeze(-1).cpu().tolist()\n",
        "            preds.extend(p if isinstance(p, list) else [p])\n",
        "            # flatten references (handle [[5.0], [4.75], ...])\n",
        "            r = batch[\"labels\"].cpu().tolist()\n",
        "            for x in r:\n",
        "                if isinstance(x, (list, tuple, np.ndarray)):\n",
        "                    refs.append(float(x[0]))\n",
        "                else:\n",
        "                    refs.append(float(x))\n",
        "    return metric.compute(predictions=preds, references=refs)\n",
        "\n",
        "# 4) Fine-tuning block\n",
        "def finetune(train_loader, dev_loader, device, model, epochs=6, lr=2e-5):\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=lr)\n",
        "    sched = get_linear_schedule_with_warmup(\n",
        "        opt, num_warmup_steps=0, num_training_steps=len(train_loader)*epochs\n",
        "    )\n",
        "    scaler = GradScaler()\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(\n",
        "                    input_ids=batch[\"input_ids\"].to(device),\n",
        "                    attention_mask=batch[\"attention_mask\"].to(device),\n",
        "                    labels=batch[\"labels\"].to(device),\n",
        "                )\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        metrics = evaluate_stsb(model, dev_loader, device)\n",
        "        print(f\"[Epoch {epoch+1}] STS-B Pearson: {metrics['pearson']:.4f}, Spearman: {metrics['spearmanr']:.4f}\")\n",
        "    return model\n",
        "\n",
        "# 5) Main Entrypoint\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    train_ds = load_dataset(\"glue\", \"stsb\", split=\"train\").shuffle(seed)\n",
        "    dev_ds   = load_dataset(\"glue\", \"stsb\", split=\"validation\")\n",
        "\n",
        "    def preprocess(ex):\n",
        "        return tokenizer(\n",
        "            ex[\"sentence1\"], ex[\"sentence2\"],\n",
        "            truncation=True, padding=\"max_length\", max_length=128\n",
        "        )\n",
        "\n",
        "    train_ds = train_ds.map(preprocess, batched=True)\n",
        "    dev_ds   = dev_ds.map(preprocess, batched=True)\n",
        "    train_ds = train_ds.map(lambda x: {\"labels\": float(x[\"label\"])}, batched=False)\n",
        "    dev_ds   = dev_ds.map(lambda x: {\"labels\": float(x[\"label\"])}, batched=False)\n",
        "    train_ds = train_ds.remove_columns([\"sentence1\", \"sentence2\", \"label\", \"idx\"])\n",
        "    dev_ds   = dev_ds.remove_columns([\"sentence1\", \"sentence2\", \"label\", \"idx\"])\n",
        "    train_ds.set_format(\"torch\", columns=[\"input_ids\", \"attention_mask\", \"labels\"])\n",
        "    dev_ds.set_format(\"torch\", columns=[\"input_ids\", \"attention_mask\", \"labels\"])\n",
        "    collator     = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=128)\n",
        "    train_loader = DataLoader(train_ds, batch_size=8, shuffle=True,  collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev_ds,   batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    # Stage 1: Full fine-tuning\n",
        "    print(\"=== Stage 1: Full Fine-Tuning (No Pruning) ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\"roberta-base\", num_labels=1).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    model = finetune(train_loader, dev_loader, device, model, epochs=6, lr=2e-5)\n",
        "    metrics = evaluate_stsb(model, dev_loader, device)\n",
        "    print(f\"Pearson after 6-epoch fine-tuning: {metrics['pearson']:.4f}\")\n",
        "\n",
        "    # Stage 2: Prune fixed layers\n",
        "    print(\"\\n=== Stage 2: Prune Layers [0,5,7,10] ===\")\n",
        "    prune_fixed_layers(model, prune_idxs=[0, 5, 7, 10])\n",
        "\n",
        "    # Stage 3: Fine-tune pruned model\n",
        "    print(\"\\n=== Stage 3: Fine-Tune Pruned Model (5 epochs) ===\")\n",
        "    model = finetune(train_loader, dev_loader, device, model, epochs=5, lr=1e-5)\n",
        "    metrics = evaluate_stsb(model, dev_loader, device)\n",
        "    print(f\"Pearson after 5-epoch post-pruning fine-tuning: {metrics['pearson']:.4f}\")\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "yEv7nfKJLxpR"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}