{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mAP9acr0Seey"
      },
      "outputs": [],
      "source": [
        "!pip uninstall -y datasets\n",
        "!pip install datasets==2.18.0\n",
        "!pip install evaluate\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n"
      ],
      "metadata": {
        "id": "Gp32nV5EVuKx"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import numpy as np\n",
        "import random\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import warnings\n",
        "from collections import defaultdict\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# --- SkipFFN for T5 FFN dropout ---\n",
        "class SkipFFN(nn.Module):\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "    def forward(self, hidden_states):\n",
        "        return hidden_states\n",
        "\n",
        "# --- LayerDrop utility for random pruning ---\n",
        "def layerdrop_prune_t5(model, num_prune_enc=4, num_prune_dec=4, seed=42):\n",
        "    enc_blocks = model.encoder.block\n",
        "    dec_blocks = model.decoder.block\n",
        "    total_enc = len(enc_blocks)\n",
        "    total_dec = len(dec_blocks)\n",
        "    d_model = model.config.d_model\n",
        "\n",
        "    rng = np.random.default_rng(seed)\n",
        "    enc_idxs = rng.choice(total_enc, size=num_prune_enc, replace=False)\n",
        "    dec_idxs = rng.choice(total_dec, size=num_prune_dec, replace=False)\n",
        "    enc_idxs = sorted(enc_idxs)\n",
        "    dec_idxs = sorted(dec_idxs)\n",
        "\n",
        "    for idx in enc_idxs:\n",
        "        enc_blocks[idx].layer[1].DenseReluDense = SkipFFN(d_model)\n",
        "    for idx in dec_idxs:\n",
        "        dec_blocks[idx].layer[2].DenseReluDense = SkipFFN(d_model)\n",
        "    print(f\"LayerDrop (Encoder FFN): pruned layers {enc_idxs}\")\n",
        "    print(f\"LayerDrop (Decoder FFN): pruned layers {dec_idxs}\")\n",
        "    return enc_idxs, dec_idxs\n",
        "\n",
        "# --- Data/Helper functions ---\n",
        "def make_t5_nli_prompt(premise, hypothesis):\n",
        "    return f\"nli premise: {premise} hypothesis: {hypothesis}\"\n",
        "\n",
        "def preprocess_function(batch, tokenizer, max_input_length=128, max_target_length=8):\n",
        "    inputs = [make_t5_nli_prompt(p, h) for p, h in zip(batch['premise'], batch['hypothesis'])]\n",
        "    model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=max_input_length)\n",
        "    label_list = [\"entailment\", \"neutral\", \"contradiction\"]\n",
        "    labels = [label_list[x] if isinstance(x, int) and x < len(label_list) else x for x in batch['label']]\n",
        "    target = tokenizer(labels, padding=\"max_length\", truncation=True, max_length=max_target_length)\n",
        "    model_inputs[\"labels\"] = target[\"input_ids\"]\n",
        "    return model_inputs\n",
        "\n",
        "def compute_accuracy(preds, refs):\n",
        "    correct = 0\n",
        "    for p, l in zip(preds, refs):\n",
        "        if p == l:\n",
        "            correct += 1\n",
        "    return correct / len(preds) if len(preds) > 0 else 0\n",
        "\n",
        "def evaluate_model(model, dl, tokenizer, device, label_texts):\n",
        "    model.eval()\n",
        "    preds, refs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch[\"input_ids\"].to(device)\n",
        "            attention_mask = batch[\"attention_mask\"].to(device)\n",
        "            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=2)\n",
        "            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
        "            label_ids = batch[\"labels\"].clone()\n",
        "            label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
        "            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
        "            preds.extend([p.strip().lower() for p in pred_texts])\n",
        "            refs.extend([l.strip().lower() for l in ref_texts])\n",
        "    return compute_accuracy(preds, refs)\n",
        "\n",
        "# --- Training/Fine-tuning Loop ---\n",
        "def finetune_t5(train_loader, dev_loader, device, tokenizer, label_texts, epochs=6, lr=3e-4, model=None):\n",
        "    if model is None:\n",
        "        model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=lr)\n",
        "    scaler = GradScaler()\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*epochs)\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                                attention_mask=batch['attention_mask'].to(device),\n",
        "                                labels=batch['labels'].to(device))\n",
        "                loss = outputs.loss\n",
        "                scaler.scale(loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)\n",
        "        print(f\"[Epoch {epoch+1}] e-SNLI Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- Entrypoint ---\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "\n",
        "    data_files = {\n",
        "        \"train\": \"/content/drive/MyDrive/NLP_datasets/esnli/esnli_train.json\",\n",
        "        \"validation\": \"/content/drive/MyDrive/NLP_datasets/esnli/esnli_valid.json\",\n",
        "        \"test\": \"/content/drive/MyDrive/NLP_datasets/esnli/esnli_test.json\"\n",
        "    }\n",
        "    raw_datasets = load_dataset(\"json\", data_files=data_files)\n",
        "    tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "    label_texts = [\"entailment\", \"neutral\", \"contradiction\"]\n",
        "\n",
        "    train_ds = raw_datasets[\"train\"].shuffle(seed=seed).select(range(10000))\n",
        "    dev_ds = raw_datasets[\"validation\"].shuffle(seed=seed).select(range(2000))\n",
        "\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True, remove_columns=train_ds.column_names)\n",
        "    dev = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                     batched=True, remove_columns=dev_ds.column_names)\n",
        "\n",
        "    collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding=\"max_length\", max_length=128)\n",
        "    train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # --- Stage 1: Full Fine-Tuning ---\n",
        "    print(\"\\n=== Stage 1: Full Fine-Tuning (No Pruning) ===\")\n",
        "    model = finetune_t5(train_loader, dev_loader, device, tokenizer, label_texts, epochs=6)\n",
        "    acc_full = evaluate_model(model, dev_loader, tokenizer, device, label_texts)\n",
        "    print(f\"\\nAccuracy after 6-epoch full fine-tuning: {acc_full:.4f}\")\n",
        "\n",
        "    # --- Stage 2: LayerDrop Pruning ---\n",
        "    print(\"\\n=== Stage 2: LayerDrop Pruning  ===\")\n",
        "    enc_pruned, dec_pruned = layerdrop_prune_t5(model, num_prune_enc=2, num_prune_dec=4, seed=seed)\n",
        "\n",
        "    # --- Stage 3: Fine-tune Pruned Model ---\n",
        "    print(\"\\n=== Stage 3: Fine-Tune Pruned Model (5 epochs) ===\")\n",
        "    model = finetune_t5(train_loader, dev_loader, device, tokenizer, label_texts, epochs=5, lr=5e-4, model=model)\n",
        "    acc_pruned = evaluate_model(model, dev_loader, tokenizer, device, label_texts)\n",
        "    print(f\"\\nAccuracy after 5-epoch post-pruning fine-tuning: {acc_pruned:.4f}\")\n",
        "    print(f\"Encoder FFN pruned indices: {enc_pruned}, Decoder FFN pruned indices: {dec_pruned}\")\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "voKwLw60Sfqi"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Mount Google Drive if on Colab\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import numpy as np\n",
        "import random\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# --- 1. Load CQA Data ---\n",
        "data_files = {\n",
        "    \"train\": \"/content/drive/MyDrive/NLP_datasets/cqa/cqa_train.json\",\n",
        "    \"test\":  \"/content/drive/MyDrive/NLP_datasets/cqa/cqa_test.json\"\n",
        "}\n",
        "dataset = load_dataset(\"json\", data_files=data_files)\n",
        "\n",
        "# --- 2. Preprocessing Functions ---\n",
        "def preprocess_cqa(batch, tokenizer, max_input_length=128, max_target_length=8, use_cot=False):\n",
        "    if use_cot and 'abstractive_explanation' in batch:\n",
        "        inputs = [\n",
        "            f\"question: {q} choices: {', '.join(choices)} rationale: {exp}\"\n",
        "            for q, choices, exp in zip(batch['question'], batch['choices'], batch['abstractive_explanation'])\n",
        "        ]\n",
        "    else:\n",
        "        inputs = [\n",
        "            f\"question: {q} choices: {', '.join(choices)}\"\n",
        "            for q, choices in zip(batch['question'], batch['choices'])\n",
        "        ]\n",
        "    targets = [str(ans).strip() for ans in batch['answer']]\n",
        "    model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=max_input_length)\n",
        "    target = tokenizer(targets, padding=\"max_length\", truncation=True, max_length=max_target_length)\n",
        "    model_inputs[\"labels\"] = target[\"input_ids\"]\n",
        "    return model_inputs\n",
        "\n",
        "tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "USE_COT = False\n",
        "\n",
        "train = dataset[\"train\"].map(lambda ex: preprocess_cqa(ex, tokenizer, use_cot=USE_COT),\n",
        "                            batched=True, remove_columns=dataset[\"train\"].column_names)\n",
        "dev   = dataset[\"test\"].map(lambda ex: preprocess_cqa(ex, tokenizer, use_cot=False),\n",
        "                            batched=True, remove_columns=dataset[\"test\"].column_names)\n",
        "\n",
        "collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding=\"max_length\", max_length=128)\n",
        "train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "# --- 3. SkipFFN utility for random dropout ---\n",
        "class SkipFFN(nn.Module):\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "    def forward(self, hidden_states):\n",
        "        return hidden_states\n",
        "\n",
        "def layerdrop_prune_t5(model, num_prune_enc=4, num_prune_dec=4, seed=42):\n",
        "    enc_blocks = model.encoder.block\n",
        "    dec_blocks = model.decoder.block\n",
        "    total_enc = len(enc_blocks)\n",
        "    total_dec = len(dec_blocks)\n",
        "    d_model = model.config.d_model\n",
        "\n",
        "    rng = np.random.default_rng(seed)\n",
        "    enc_idxs = rng.choice(total_enc, size=num_prune_enc, replace=False)\n",
        "    dec_idxs = rng.choice(total_dec, size=num_prune_dec, replace=False)\n",
        "    enc_idxs = sorted(enc_idxs)\n",
        "    dec_idxs = sorted(dec_idxs)\n",
        "\n",
        "    for idx in enc_idxs:\n",
        "        enc_blocks[idx].layer[1].DenseReluDense = SkipFFN(d_model)\n",
        "    for idx in dec_idxs:\n",
        "        dec_blocks[idx].layer[2].DenseReluDense = SkipFFN(d_model)\n",
        "    print(f\"LayerDrop (Encoder FFN): pruned layers {enc_idxs}\")\n",
        "    print(f\"LayerDrop (Decoder FFN): pruned layers {dec_idxs}\")\n",
        "    return enc_idxs, dec_idxs\n",
        "\n",
        "# --- 4. Training/Evaluation ---\n",
        "def compute_accuracy(preds, refs):\n",
        "    correct = 0\n",
        "    for p, l in zip(preds, refs):\n",
        "        if p == l:\n",
        "            correct += 1\n",
        "    return correct / len(preds) if len(preds) > 0 else 0\n",
        "\n",
        "def evaluate_model(model, dl, tokenizer, device):\n",
        "    model.eval()\n",
        "    preds, refs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch[\"input_ids\"].to(device)\n",
        "            attention_mask = batch[\"attention_mask\"].to(device)\n",
        "            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=4)\n",
        "            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
        "            label_ids = batch[\"labels\"].clone()\n",
        "            label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
        "            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
        "            preds.extend([p.strip().lower() for p in pred_texts])\n",
        "            refs.extend([l.strip().lower() for l in ref_texts])\n",
        "    return compute_accuracy(preds, refs)\n",
        "\n",
        "def finetune_t5(train_loader, dev_loader, device, tokenizer, epochs=6, lr=3e-4, model=None):\n",
        "    if model is None:\n",
        "        model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=lr)\n",
        "    scaler = GradScaler()\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*epochs)\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                                attention_mask=batch['attention_mask'].to(device),\n",
        "                                labels=batch['labels'].to(device))\n",
        "                loss = outputs.loss\n",
        "                scaler.scale(loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "        print(f\"[Epoch {epoch+1}] CQA Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 5. Entrypoint ---\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # --- Stage 1: Full Fine-Tuning ---\n",
        "    print(\"\\n=== Stage 1: Full Fine-Tuning (No Pruning) ===\")\n",
        "    model = finetune_t5(train_loader, dev_loader, device, tokenizer, epochs=6)\n",
        "    acc_full = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "    print(f\"\\nAccuracy after 6-epoch full fine-tuning: {acc_full:.4f}\")\n",
        "\n",
        "    # --- Stage 2: LayerDrop Pruning ---\n",
        "    print(\"\\n=== Stage 2: LayerDrop Pruning ===\")\n",
        "    enc_pruned, dec_pruned = layerdrop_prune_t5(model, num_prune_enc=2, num_prune_dec=4, seed=seed)\n",
        "\n",
        "    # --- Stage 3: Fine-tune Pruned Model ---\n",
        "    print(\"\\n=== Stage 3: Fine-Tune Pruned Model (5 epochs) ===\")\n",
        "    model = finetune_t5(train_loader, dev_loader, device, tokenizer, epochs=5, lr=5e-4, model=model)\n",
        "    acc_pruned = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "    print(f\"\\nAccuracy after 5-epoch post-pruning fine-tuning: {acc_pruned:.4f}\")\n",
        "    print(f\"Encoder FFN pruned indices: {enc_pruned}, Decoder FFN pruned indices: {dec_pruned}\")\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "NmDc7y4USgGs"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# --- Mount Google Drive if using Colab ---\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "# --- Standard Imports ---\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import numpy as np\n",
        "import random\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# --- 1. Load ANLI1 Dataset ---\n",
        "data_files = {\n",
        "    \"train\":      \"/content/drive/MyDrive/NLP_datasets/anli1/anli1_train.json\",\n",
        "    \"validation\": \"/content/drive/MyDrive/NLP_datasets/anli1/anli1_valid.json\",\n",
        "    \"test\":       \"/content/drive/MyDrive/NLP_datasets/anli1/anli1_test.json\"\n",
        "}\n",
        "dataset = load_dataset(\"json\", data_files=data_files)\n",
        "\n",
        "# --- 2. Preprocessing Function ---\n",
        "def make_t5_nli_prompt(premise, hypothesis):\n",
        "    return f\"nli premise: {premise} hypothesis: {hypothesis}\"\n",
        "\n",
        "def preprocess_anli(batch, tokenizer, max_input_length=128, max_target_length=8):\n",
        "    inputs = [make_t5_nli_prompt(p, h) for p, h in zip(batch['premise'], batch['hypothesis'])]\n",
        "    label_list = [\"entailment\", \"neutral\", \"contradiction\"]\n",
        "    labels = [label_list[int(x)] if isinstance(x, (int, float, str)) and str(x).isdigit() and int(x)<3 else str(x) for x in batch['label']]\n",
        "    model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=max_input_length)\n",
        "    target = tokenizer(labels, padding=\"max_length\", truncation=True, max_length=max_target_length)\n",
        "    model_inputs[\"labels\"] = target[\"input_ids\"]\n",
        "    return model_inputs\n",
        "\n",
        "tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "train = dataset[\"train\"].map(lambda ex: preprocess_anli(ex, tokenizer), batched=True, remove_columns=dataset[\"train\"].column_names)\n",
        "dev   = dataset[\"validation\"].map(lambda ex: preprocess_anli(ex, tokenizer), batched=True, remove_columns=dataset[\"validation\"].column_names)\n",
        "collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding=\"max_length\", max_length=128)\n",
        "train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "# --- 3. SkipFFN for random dropout ---\n",
        "class SkipFFN(nn.Module):\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "    def forward(self, hidden_states):\n",
        "        return hidden_states\n",
        "\n",
        "def layerdrop_prune_t5(model, num_prune_enc=4, num_prune_dec=4, seed=42):\n",
        "    enc_blocks = model.encoder.block\n",
        "    dec_blocks = model.decoder.block\n",
        "    total_enc = len(enc_blocks)\n",
        "    total_dec = len(dec_blocks)\n",
        "    d_model = model.config.d_model\n",
        "\n",
        "    rng = np.random.default_rng(seed)\n",
        "    enc_idxs = rng.choice(total_enc, size=num_prune_enc, replace=False)\n",
        "    dec_idxs = rng.choice(total_dec, size=num_prune_dec, replace=False)\n",
        "    enc_idxs = sorted(enc_idxs)\n",
        "    dec_idxs = sorted(dec_idxs)\n",
        "\n",
        "    for idx in enc_idxs:\n",
        "        enc_blocks[idx].layer[1].DenseReluDense = SkipFFN(d_model)\n",
        "    for idx in dec_idxs:\n",
        "        dec_blocks[idx].layer[2].DenseReluDense = SkipFFN(d_model)\n",
        "    print(f\"LayerDrop (Encoder FFN): pruned layers {enc_idxs}\")\n",
        "    print(f\"LayerDrop (Decoder FFN): pruned layers {dec_idxs}\")\n",
        "    return enc_idxs, dec_idxs\n",
        "\n",
        "# --- 4. Training/Evaluation ---\n",
        "def compute_accuracy(preds, refs):\n",
        "    correct = 0\n",
        "    for p, l in zip(preds, refs):\n",
        "        if p == l:\n",
        "            correct += 1\n",
        "    return correct / len(preds) if len(preds) > 0 else 0\n",
        "\n",
        "def evaluate_model(model, dl, tokenizer, device):\n",
        "    model.eval()\n",
        "    preds, refs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch[\"input_ids\"].to(device)\n",
        "            attention_mask = batch[\"attention_mask\"].to(device)\n",
        "            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=2)\n",
        "            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
        "            label_ids = batch[\"labels\"].clone()\n",
        "            label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
        "            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
        "            preds.extend([p.strip().lower() for p in pred_texts])\n",
        "            refs.extend([l.strip().lower() for l in ref_texts])\n",
        "    return compute_accuracy(preds, refs)\n",
        "\n",
        "def finetune_t5(train_loader, dev_loader, device, tokenizer, epochs=6, lr=3e-4, model=None):\n",
        "    if model is None:\n",
        "        model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=lr)\n",
        "    scaler = GradScaler()\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*epochs)\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                                attention_mask=batch['attention_mask'].to(device),\n",
        "                                labels=batch['labels'].to(device))\n",
        "                loss = outputs.loss\n",
        "                scaler.scale(loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "        print(f\"[Epoch {epoch+1}] ANLI1 Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 5. Entrypoint ---\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # --- Stage 1: Full Fine-Tuning ---\n",
        "    print(\"\\n=== Stage 1: Full Fine-Tuning (No Pruning) ===\")\n",
        "    model = finetune_t5(train_loader, dev_loader, device, tokenizer, epochs=6)\n",
        "    acc_full = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "    print(f\"\\nAccuracy after 6-epoch full fine-tuning: {acc_full:.4f}\")\n",
        "\n",
        "    # --- Stage 2: LayerDrop Pruning ---\n",
        "    print(\"\\n=== Stage 2: LayerDrop Pruning ===\")\n",
        "    enc_pruned, dec_pruned = layerdrop_prune_t5(model, num_prune_enc=2, num_prune_dec=4, seed=seed)\n",
        "\n",
        "    # --- Stage 3: Fine-tune Pruned Model ---\n",
        "    print(\"\\n=== Stage 3: Fine-Tune Pruned Model (5 epochs) ===\")\n",
        "    model = finetune_t5(train_loader, dev_loader, device, tokenizer, epochs=5, lr=5e-4, model=model)\n",
        "    acc_pruned = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "    print(f\"\\nAccuracy after 5-epoch post-pruning fine-tuning: {acc_pruned:.4f}\")\n",
        "    print(f\"Encoder FFN pruned indices: {enc_pruned}, Decoder FFN pruned indices: {dec_pruned}\")\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "Q_w3g02_Sght"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ===========================\n",
        "# 0. Google Drive Mount\n",
        "# ===========================\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "# ===========================\n",
        "# 1. Imports and Setup\n",
        "# ===========================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import numpy as np\n",
        "import random\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# ===========================\n",
        "# 2. Load SVAMP Dataset\n",
        "# ===========================\n",
        "data_files = {\n",
        "    \"train\": \"/content/drive/MyDrive/NLP_datasets/svamp/svamp_train.json\",\n",
        "    \"test\": \"/content/drive/MyDrive/NLP_datasets/svamp/svamp_test.json\"\n",
        "}\n",
        "dataset = load_dataset(\"json\", data_files=data_files)\n",
        "\n",
        "# ===========================\n",
        "# 3. Preprocessing\n",
        "# ===========================\n",
        "def preprocess_svamp(batch, tokenizer, max_input_length=128, max_target_length=8):\n",
        "    model_inputs = tokenizer(\n",
        "        batch[\"input\"], padding=\"max_length\", truncation=True, max_length=max_input_length\n",
        "    )\n",
        "    targets = [str(x) for x in batch[\"label\"]]\n",
        "    target_encodings = tokenizer(\n",
        "        targets, padding=\"max_length\", truncation=True, max_length=max_target_length\n",
        "    )\n",
        "    model_inputs[\"labels\"] = target_encodings[\"input_ids\"]\n",
        "    return model_inputs\n",
        "\n",
        "tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "train = dataset[\"train\"].map(lambda ex: preprocess_svamp(ex, tokenizer), batched=True, remove_columns=dataset[\"train\"].column_names)\n",
        "dev = dataset[\"test\"].map(lambda ex: preprocess_svamp(ex, tokenizer), batched=True, remove_columns=dataset[\"test\"].column_names)\n",
        "collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding=\"max_length\", max_length=128)\n",
        "train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "# ===========================\n",
        "# 4. SkipFFN & LayerDropout Utilities\n",
        "# ===========================\n",
        "class SkipFFN(nn.Module):\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "    def forward(self, hidden_states):\n",
        "        return hidden_states\n",
        "\n",
        "def layerdrop_prune_t5(model, num_prune_enc=4, num_prune_dec=4, seed=42):\n",
        "    enc_blocks = model.encoder.block\n",
        "    dec_blocks = model.decoder.block\n",
        "    total_enc = len(enc_blocks)\n",
        "    total_dec = len(dec_blocks)\n",
        "    d_model = model.config.d_model\n",
        "\n",
        "    rng = np.random.default_rng(seed)\n",
        "    enc_idxs = rng.choice(total_enc, size=num_prune_enc, replace=False)\n",
        "    dec_idxs = rng.choice(total_dec, size=num_prune_dec, replace=False)\n",
        "    enc_idxs = sorted(enc_idxs)\n",
        "    dec_idxs = sorted(dec_idxs)\n",
        "\n",
        "    for idx in enc_idxs:\n",
        "        enc_blocks[idx].layer[1].DenseReluDense = SkipFFN(d_model)\n",
        "    for idx in dec_idxs:\n",
        "        dec_blocks[idx].layer[2].DenseReluDense = SkipFFN(d_model)\n",
        "    print(f\"LayerDrop (Encoder FFN): pruned layers {enc_idxs}\")\n",
        "    print(f\"LayerDrop (Decoder FFN): pruned layers {dec_idxs}\")\n",
        "    return enc_idxs, dec_idxs\n",
        "\n",
        "# ===========================\n",
        "# 5. Training/Eval\n",
        "# ===========================\n",
        "def compute_accuracy(preds, refs):\n",
        "    correct = 0\n",
        "    for p, l in zip(preds, refs):\n",
        "        if p == l:\n",
        "            correct += 1\n",
        "    return correct / len(preds) if len(preds) > 0 else 0\n",
        "\n",
        "def evaluate_model(model, dl, tokenizer, device):\n",
        "    model.eval()\n",
        "    preds, refs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch[\"input_ids\"].to(device)\n",
        "            attention_mask = batch[\"attention_mask\"].to(device)\n",
        "            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=8)\n",
        "            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
        "            label_ids = batch[\"labels\"].clone()\n",
        "            label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
        "            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
        "            preds.extend([p.strip().lower() for p in pred_texts])\n",
        "            refs.extend([l.strip().lower() for l in ref_texts])\n",
        "    return compute_accuracy(preds, refs)\n",
        "\n",
        "def finetune_t5(train_loader, dev_loader, device, tokenizer, epochs=6, lr=3e-4, model=None):\n",
        "    if model is None:\n",
        "        model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=lr)\n",
        "    scaler = GradScaler()\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*epochs)\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                                attention_mask=batch['attention_mask'].to(device),\n",
        "                                labels=batch['labels'].to(device))\n",
        "                loss = outputs.loss\n",
        "                scaler.scale(loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "        print(f\"[Epoch {epoch+1}] SVAMP Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# ===========================\n",
        "# 6. Main Entrypoint\n",
        "# ===========================\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # --- Stage 1: Full Fine-Tuning ---\n",
        "    print(\"\\n=== Stage 1: Full Fine-Tuning (No Pruning) ===\")\n",
        "    model = finetune_t5(train_loader, dev_loader, device, tokenizer, epochs=6)\n",
        "    acc_full = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "    print(f\"\\nAccuracy after 6-epoch full fine-tuning: {acc_full:.4f}\")\n",
        "\n",
        "    # --- Stage 2: LayerDrop Pruning ---\n",
        "    print(\"\\n=== Stage 2: LayerDrop Pruning ===\")\n",
        "    enc_pruned, dec_pruned = layerdrop_prune_t5(model, num_prune_enc=2, num_prune_dec=4, seed=seed)\n",
        "\n",
        "    # --- Stage 3: Fine-tune Pruned Model ---\n",
        "    print(\"\\n=== Stage 3: Fine-Tune Pruned Model (5 epochs) ===\")\n",
        "    model = finetune_t5(train_loader, dev_loader, device, tokenizer, epochs=5, lr=5e-4, model=model)\n",
        "    acc_pruned = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "    print(f\"\\nAccuracy after 5-epoch post-pruning fine-tuning: {acc_pruned:.4f}\")\n",
        "    print(f\"Encoder FFN pruned indices: {enc_pruned}, Decoder FFN pruned indices: {dec_pruned}\")\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "-QLvXElCSg_1"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}