{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6ENR6EQ2b8R8",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "8b71613c-bb7d-4a97-dd4d-1764381fe81b"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Found existing installation: datasets 4.0.0\n",
            "Uninstalling datasets-4.0.0:\n",
            "  Successfully uninstalled datasets-4.0.0\n",
            "Collecting datasets==2.18.0\n",
            "  Downloading datasets-2.18.0-py3-none-any.whl.metadata (20 kB)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (3.19.1)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (2.0.2)\n",
            "Requirement already satisfied: pyarrow>=12.0.0 in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (18.1.0)\n",
            "Collecting pyarrow-hotfix (from datasets==2.18.0)\n",
            "  Downloading pyarrow_hotfix-0.7-py3-none-any.whl.metadata (3.6 kB)\n",
            "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (0.3.8)\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (2.2.2)\n",
            "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (2.32.4)\n",
            "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (4.67.1)\n",
            "Requirement already satisfied: xxhash in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (3.5.0)\n",
            "Requirement already satisfied: multiprocess in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (0.70.16)\n",
            "Collecting fsspec<=2024.2.0,>=2023.1.0 (from fsspec[http]<=2024.2.0,>=2023.1.0->datasets==2.18.0)\n",
            "  Downloading fsspec-2024.2.0-py3-none-any.whl.metadata (6.8 kB)\n",
            "Requirement already satisfied: aiohttp in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (3.12.15)\n",
            "Requirement already satisfied: huggingface-hub>=0.19.4 in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (0.35.0)\n",
            "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (25.0)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from datasets==2.18.0) (6.0.2)\n",
            "Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.18.0) (2.6.1)\n",
            "Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.18.0) (1.4.0)\n",
            "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.18.0) (25.3.0)\n",
            "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.18.0) (1.7.0)\n",
            "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.18.0) (6.6.4)\n",
            "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.18.0) (0.3.2)\n",
            "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.18.0) (1.20.1)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.19.4->datasets==2.18.0) (4.15.0)\n",
            "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.19.4->datasets==2.18.0) (1.1.10)\n"
          ]
        }
      ],
      "source": [
        "# Swap the runtime's preinstalled `datasets` package for the pinned 2.18.0 release.\n",
        "# NOTE(review): the reason for the pin is not recorded in this notebook -- confirm it is still needed.\n",
        "!pip uninstall -y datasets\n",
        "!pip install datasets==2.18.0\n",
        "# NOTE(review): `evaluate` is installed unpinned and is not imported by the cells visible in this review -- confirm it is required.\n",
        "!pip install evaluate\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Mount Google Drive at /content/drive (Colab only); the dataset paths used by the\n",
        "# experiment cells point below /content/drive/MyDrive.\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n"
      ],
      "metadata": {
        "id": "d7JFdShlZwE4"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict\n",
        "import warnings\n",
        "import math\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# --- 1. Conditional ER Hook Utilities ---\n",
        "def register_conditional_er_hooks(model):\n",
        "    \"\"\"Attach forward hooks that cache the latest output of the model's blocks.\n",
        "\n",
        "    Hooks are placed on every encoder block, every decoder block, and on\n",
        "    ``layer[1]`` of every decoder block (treated here as cross-attention).\n",
        "    Each hook stores the detached output, or its first element when the\n",
        "    output is a tuple, under the block's index.\n",
        "\n",
        "    Returns:\n",
        "        Three ``(handles, activations)`` pairs ordered encoder, decoder,\n",
        "        cross-attention. ``activations`` maps block index to the most recent\n",
        "        tensor and holds ``None`` until that block has run.\n",
        "    \"\"\"\n",
        "    def _attach(modules):\n",
        "        # One slot per module, empty until its first forward pass.\n",
        "        store = dict.fromkeys(range(len(modules)))\n",
        "        handles = []\n",
        "        for pos, mod in enumerate(modules):\n",
        "            # idx is bound as a default so each hook keeps its own index.\n",
        "            def _capture(module, inp, out, idx=pos):\n",
        "                first = out[0] if isinstance(out, tuple) else out\n",
        "                store[idx] = first.detach()\n",
        "            handles.append(mod.register_forward_hook(_capture))\n",
        "        return handles, store\n",
        "\n",
        "    encoder_blocks = model.encoder.block\n",
        "    decoder_blocks = model.decoder.block\n",
        "    enc_pair = _attach(encoder_blocks)\n",
        "    dec_pair = _attach(decoder_blocks)\n",
        "    cross_pair = _attach([blk.layer[1] for blk in decoder_blocks])\n",
        "    return enc_pair, dec_pair, cross_pair\n",
        "\n",
        "def remove_hooks(hook_sets):\n",
        "    \"\"\"Detach every hook handle contained in ``hook_sets``.\n",
        "\n",
        "    Args:\n",
        "        hook_sets: iterable of ``(handles, activations)`` pairs; only the\n",
        "            handles are used, each through its ``remove()`` method.\n",
        "    \"\"\"\n",
        "    for pair in hook_sets:\n",
        "        handles, _unused_acts = pair\n",
        "        for handle in handles:\n",
        "            handle.remove()\n",
        "\n",
        "def compute_conditional_batch_entropy(prev_acts, curr_acts):\n",
        "    \"\"\"Score each adjacent layer pair by how aligned their activation changes are.\n",
        "\n",
        "    For layer ``i`` (X) and layer ``i + 1`` (Y), the change of each activation\n",
        "    between the previous and the current batch is flattened per sample, and\n",
        "    the cosine similarity between the two changes is averaged over the batch.\n",
        "\n",
        "    Args:\n",
        "        prev_acts: dict mapping layer index -> activation tensor (or None)\n",
        "            captured on the previous batch.\n",
        "        curr_acts: same structure for the current batch.\n",
        "\n",
        "    Returns:\n",
        "        dict mapping ``i`` to the mean cosine similarity for the pair\n",
        "        ``(i, i + 1)``. A pair is omitted when an activation is missing, shapes\n",
        "        differ between the two batches, the batch holds fewer than two\n",
        "        samples, a difference contains NaN, or the mean is not finite.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for i in range(len(curr_acts) - 1):\n",
        "        prev_X, prev_Y = prev_acts[i], prev_acts[i + 1]\n",
        "        curr_X, curr_Y = curr_acts[i], curr_acts[i + 1]\n",
        "        if prev_X is None or prev_Y is None or curr_X is None or curr_Y is None:\n",
        "            continue\n",
        "        if prev_X.shape != curr_X.shape or prev_Y.shape != curr_Y.shape:\n",
        "            continue\n",
        "        B = curr_X.size(0)\n",
        "        if B < 2:\n",
        "            continue\n",
        "        # float32 keeps the long flattened dot products from overflowing when the\n",
        "        # captured activations are half precision.\n",
        "        dX = (curr_X - prev_X).float().reshape(B, -1)\n",
        "        dY = (curr_Y - prev_Y).float().reshape(B, -1)\n",
        "        if torch.isnan(dX).any() or torch.isnan(dY).any():\n",
        "            continue\n",
        "        # One vectorized call replaces a Python loop that synchronised with the\n",
        "        # device once per sample.\n",
        "        # NOTE(review): sample 0 is excluded, exactly as the earlier range(1, B)\n",
        "        # loop did -- confirm this is intended.\n",
        "        cos = F.cosine_similarity(dY[1:], dX[1:], dim=1, eps=1e-8)\n",
        "        er = cos.double().mean().item()\n",
        "        if math.isfinite(er):\n",
        "            er_scores[i] = er\n",
        "    return er_scores\n",
        "\n",
        "\n",
        "# --- 2. Pruning Utilities ---\n",
        "class SkipBlock(nn.Module):\n",
        "    \"\"\"Identity stand-in used to replace a pruned transformer block.\n",
        "\n",
        "    ``forward`` hands back the incoming hidden states untouched, followed by\n",
        "    five ``None`` placeholders, and ignores every other argument.\n",
        "    \"\"\"\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        # Stored for reference only; forward() never reads it.\n",
        "        self.hidden_size = hidden_size\n",
        "\n",
        "    def forward(self, hidden_states, *args, **kwargs):\n",
        "        # NOTE(review): the placeholders presumably land in the slots a real T5 block\n",
        "        # uses for cache / position-bias outputs, so blocks after this one may receive\n",
        "        # None there -- confirm against the installed transformers T5Stack.\n",
        "        placeholders = (None,) * 5\n",
        "        return (hidden_states,) + placeholders\n",
        "\n",
        "def prune_er_layers(blocks, er_scores, num_prune=2, hidden_size=768):\n",
        "    \"\"\"Replace the blocks that follow the highest-scoring layer pairs with SkipBlock.\n",
        "\n",
        "    Args:\n",
        "        blocks: mutable, indexable sequence of blocks; modified in place.\n",
        "        er_scores: dict mapping layer index -> score; the block replaced for a\n",
        "            key ``k`` is ``blocks[k + 1]``.\n",
        "        num_prune: number of top-scoring keys to consider.\n",
        "        hidden_size: forwarded to ``SkipBlock``.\n",
        "\n",
        "    Returns:\n",
        "        The replaced block indices, ordered by descending score. Keys whose\n",
        "        ``k + 1`` falls outside ``blocks`` are dropped.\n",
        "    \"\"\"\n",
        "    ranked = sorted(er_scores, key=er_scores.get, reverse=True)\n",
        "    prune_idxs = []\n",
        "    for layer_idx in ranked[:num_prune]:\n",
        "        target = layer_idx + 1\n",
        "        if target < len(blocks):\n",
        "            prune_idxs.append(target)\n",
        "    for target in prune_idxs:\n",
        "        blocks[target] = SkipBlock(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "# --- 3. Data Processing ---\n",
        "def make_t5_nli_prompt(premise, hypothesis):\n",
        "    \"\"\"Build the text prompt for one premise/hypothesis pair.\"\"\"\n",
        "    prompt = f\"nli premise: {premise} hypothesis: {hypothesis}\"\n",
        "    return prompt\n",
        "\n",
        "def preprocess_function(batch, tokenizer, max_input_length=128, max_target_length=8):\n",
        "    \"\"\"Tokenize a batch of NLI examples into model inputs and label ids.\n",
        "\n",
        "    Args:\n",
        "        batch: mapping with parallel ``premise``, ``hypothesis`` and ``label`` lists.\n",
        "        tokenizer: callable tokenizer exposing ``pad_token_id``.\n",
        "        max_input_length: fixed length the prompts are padded/truncated to.\n",
        "        max_target_length: fixed length the label texts are padded/truncated to.\n",
        "\n",
        "    Returns:\n",
        "        The tokenizer output for the prompts with an added ``labels`` entry in\n",
        "        which padding positions are -100.\n",
        "    \"\"\"\n",
        "    inputs = [make_t5_nli_prompt(p, h) for p, h in zip(batch['premise'], batch['hypothesis'])]\n",
        "    model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=max_input_length)\n",
        "    label_list = [\"entailment\", \"neutral\", \"contradiction\"]\n",
        "    # Only in-range class ids are mapped to text; a negative id must not wrap\n",
        "    # around to the last label. Anything else is passed through unchanged.\n",
        "    labels = [label_list[x] if isinstance(x, int) and 0 <= x < len(label_list) else x for x in batch['label']]\n",
        "    target = tokenizer(labels, padding=\"max_length\", truncation=True, max_length=max_target_length)\n",
        "    pad_id = tokenizer.pad_token_id\n",
        "    # Padding positions become -100 so the loss skips them; evaluate_model maps\n",
        "    # -100 back to the pad id before decoding.\n",
        "    model_inputs[\"labels\"] = [\n",
        "        [tok if tok != pad_id else -100 for tok in seq]\n",
        "        for seq in target[\"input_ids\"]\n",
        "    ]\n",
        "    return model_inputs\n",
        "\n",
        "def compute_accuracy(preds, refs):\n",
        "    \"\"\"Return the fraction of ``preds`` that equal the paired entry of ``refs``.\n",
        "\n",
        "    Pairs are formed with ``zip``; the denominator is ``len(preds)``. An empty\n",
        "    ``preds`` yields 0.0 instead of raising ZeroDivisionError.\n",
        "    \"\"\"\n",
        "    if len(preds) == 0:\n",
        "        return 0.0\n",
        "    correct = sum(1 for p, l in zip(preds, refs) if p == l)\n",
        "    return correct / len(preds)\n",
        "\n",
        "def evaluate_model(model, dl, tokenizer, device, label_texts, max_new_tokens=None):\n",
        "    \"\"\"Generate a prediction for every batch in ``dl`` and return exact-match accuracy.\n",
        "\n",
        "    Args:\n",
        "        model: seq2seq model exposing ``eval()`` and ``generate()``.\n",
        "        dl: iterable of batches holding ``input_ids``, ``attention_mask`` and\n",
        "            ``labels`` tensors.\n",
        "        tokenizer: decodes predictions/references and sizes the generation budget.\n",
        "        device: device the input tensors are moved to.\n",
        "        label_texts: the label words; when non-empty, the generation budget is\n",
        "            the tokenized length of the longest one.\n",
        "        max_new_tokens: explicit generation budget; overrides the derived value.\n",
        "\n",
        "    Returns:\n",
        "        Fraction of predictions equal to their reference after strip and\n",
        "        lower-casing. The model is left in eval mode.\n",
        "    \"\"\"\n",
        "    if max_new_tokens is None:\n",
        "        if label_texts:\n",
        "            # A fixed budget of 2 cannot emit a label word that tokenizes into\n",
        "            # more pieces, so size the budget from the label words themselves.\n",
        "            max_new_tokens = max(len(ids) for ids in tokenizer(list(label_texts))[\"input_ids\"])\n",
        "        else:\n",
        "            max_new_tokens = 8\n",
        "    model.eval()\n",
        "    preds, refs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch[\"input_ids\"].to(device)\n",
        "            attention_mask = batch[\"attention_mask\"].to(device)\n",
        "            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens)\n",
        "            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
        "            # -100 marks ignored label positions and cannot be decoded; map it to the pad id.\n",
        "            label_ids = batch[\"labels\"].clone()\n",
        "            label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
        "            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
        "            preds.extend([p.strip().lower() for p in pred_texts])\n",
        "            refs.extend([l.strip().lower() for l in ref_texts])\n",
        "    return compute_accuracy(preds, refs)\n",
        "\n",
        "# --- 4. Training Loops ---\n",
        "\n",
        "def full_finetuning(train_loader, dev_loader, device, tokenizer, label_texts):\n",
        "    \"\"\"Stage 1: fine-tune t5-base while estimating conditional ER per layer pair.\n",
        "\n",
        "    After every optimizer step the hooked activations of the current batch are\n",
        "    compared with those of the previous batch via\n",
        "    ``compute_conditional_batch_entropy``; the scores are averaged per epoch\n",
        "    for the encoder, decoder and cross-attention groups and printed together\n",
        "    with the dev accuracy.\n",
        "\n",
        "    Args:\n",
        "        train_loader: training batches with ``input_ids``, ``attention_mask``, ``labels``.\n",
        "        dev_loader: batches passed to ``evaluate_model`` after each epoch.\n",
        "        device: device the model and batches are moved to.\n",
        "        tokenizer: forwarded to ``evaluate_model``.\n",
        "        label_texts: forwarded to ``evaluate_model``.\n",
        "\n",
        "    Returns:\n",
        "        ``(model, enc_er, dec_er, cross_er)`` where the three dicts are the\n",
        "        averages from the final epoch.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Fine-Tuning & Conditional ER Estimation ===\")\n",
        "    num_epochs = 6\n",
        "    model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
        "    scaler = GradScaler()\n",
        "    # The linear decay must span every epoch that is run; a shorter horizon\n",
        "    # leaves the learning rate at 0 for the remaining epochs.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    hook_sets = register_conditional_er_hooks(model)\n",
        "    group_names = (\"Encoder\", \"Decoder\", \"Cross-Attention\")\n",
        "    live_acts = [acts for _, acts in hook_sets]\n",
        "    last_er = [None, None, None]\n",
        "\n",
        "    try:\n",
        "        for epoch in range(num_epochs):\n",
        "            er_sums = [defaultdict(float) for _ in group_names]\n",
        "            er_counts = [defaultdict(int) for _ in group_names]\n",
        "            model.train()\n",
        "            # Reset per epoch so the first batch is never compared with\n",
        "            # activations left over from evaluation.\n",
        "            prev_acts = [None, None, None]\n",
        "            for batch in train_loader:\n",
        "                opt.zero_grad()\n",
        "                with autocast():\n",
        "                    outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                                    attention_mask=batch['attention_mask'].to(device),\n",
        "                                    labels=batch['labels'].to(device))\n",
        "                    loss = outputs.loss\n",
        "                # backward runs outside the autocast region, as the AMP docs recommend\n",
        "                scaler.scale(loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "                for g, acts in enumerate(live_acts):\n",
        "                    if prev_acts[g] is not None:\n",
        "                        batch_er = compute_conditional_batch_entropy(prev_acts[g], acts)\n",
        "                        for idx, v in batch_er.items():\n",
        "                            er_sums[g][idx] += v\n",
        "                            er_counts[g][idx] += 1\n",
        "                    # Snapshot for the next step; the hooks overwrite `acts` in place.\n",
        "                    prev_acts[g] = {i: (t.clone() if t is not None else None) for i, t in acts.items()}\n",
        "            epoch_er = [\n",
        "                {idx: er_sums[g][idx] / er_counts[g][idx] for idx in er_sums[g] if er_counts[g][idx] > 0}\n",
        "                for g in range(len(group_names))\n",
        "            ]\n",
        "            for name, er in zip(group_names, epoch_er):\n",
        "                print(f\"[Epoch {epoch+1}] approx {name} Conditional ER: {er}\")\n",
        "            acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)\n",
        "            print(f\"[Epoch {epoch+1}] Dev Acc: {acc:.4f}\")\n",
        "            last_er = epoch_er\n",
        "    finally:\n",
        "        # Always detach the hooks, even when training is interrupted.\n",
        "        remove_hooks(hook_sets)\n",
        "    return model, last_er[0], last_er[1], last_er[2]\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, enc_er_scores, dec_er_scores, cross_er_scores, tokenizer, label_texts):\n",
        "    \"\"\"Stage 2: replace the highest-ER decoder blocks with SkipBlock and fine-tune.\n",
        "\n",
        "    Args:\n",
        "        model: the model returned by stage 1; its decoder blocks are modified in place.\n",
        "        train_loader: training batches with ``input_ids``, ``attention_mask``, ``labels``.\n",
        "        dev_loader: batches passed to ``evaluate_model`` after each epoch.\n",
        "        device: device the batches are moved to.\n",
        "        enc_er_scores: encoder scores; unused while encoder pruning is disabled.\n",
        "        dec_er_scores: decoder scores used to choose the blocks to prune.\n",
        "        cross_er_scores: cross-attention scores; currently unused.\n",
        "        tokenizer: forwarded to ``evaluate_model``.\n",
        "        label_texts: forwarded to ``evaluate_model``.\n",
        "\n",
        "    Returns:\n",
        "        The pruned, fine-tuned model (same object as ``model``).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High-ER) & Fine-tuning ===\")\n",
        "    num_epochs = 5\n",
        "    # Encoder pruning is disabled; kept for reference:\n",
        "    # enc_prune_idxs = prune_er_layers(model.encoder.block, enc_er_scores, num_prune=4, hidden_size=model.config.d_model)\n",
        "    dec_prune_idxs = prune_er_layers(model.decoder.block, dec_er_scores, num_prune=4, hidden_size=model.config.d_model)\n",
        "    # print(\"Pruned encoder layers (highest ER):\", enc_prune_idxs)\n",
        "    print(\"Pruned decoder layers (highest ER):\", dec_prune_idxs)\n",
        "    # The optimizer is built after pruning so it only tracks the remaining parameters.\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=5e-4)\n",
        "    # The linear decay must span every epoch that is run; a shorter horizon\n",
        "    # leaves the learning rate at 0 for the remaining epochs.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                            attention_mask=batch['attention_mask'].to(device),\n",
        "                            labels=batch['labels'].to(device))\n",
        "            loss = outputs.loss\n",
        "            loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] e-SNLI Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 5. Main Entrypoint ---\n",
        "\n",
        "def main():\n",
        "    \"\"\"Run the e-SNLI experiment end to end.\n",
        "\n",
        "    Loads the JSON splits from Drive, tokenizes a shuffled 10000-example train\n",
        "    subset and a 2000-example validation subset, then runs full fine-tuning\n",
        "    followed by pruning and fine-tuning.\n",
        "    \"\"\"\n",
        "    esnli_files = {\n",
        "        \"train\": \"/content/drive/MyDrive/NLP_datasets/esnli/esnli_train.json\",\n",
        "        \"validation\": \"/content/drive/MyDrive/NLP_datasets/esnli/esnli_valid.json\",\n",
        "        \"test\": \"/content/drive/MyDrive/NLP_datasets/esnli/esnli_test.json\"\n",
        "    }\n",
        "    raw_datasets = load_dataset(\"json\", data_files=esnli_files)\n",
        "    tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "    label_texts = [\"entailment\", \"neutral\", \"contradiction\"]\n",
        "\n",
        "    def encode(examples):\n",
        "        # Batched map callback: prompts + labels -> token ids.\n",
        "        return preprocess_function(examples, tokenizer)\n",
        "\n",
        "    train_subset = raw_datasets[\"train\"].shuffle(seed=42).select(range(10000))\n",
        "    dev_subset = raw_datasets[\"validation\"].shuffle(seed=42).select(range(2000))\n",
        "    train_encoded = train_subset.map(\n",
        "        encode, batched=True, remove_columns=train_subset.column_names)\n",
        "    dev_encoded = dev_subset.map(\n",
        "        encode, batched=True, remove_columns=dev_subset.column_names)\n",
        "\n",
        "    collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding=\"max_length\", max_length=128)\n",
        "    train_loader = DataLoader(train_encoded, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev_encoded, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    if torch.cuda.is_available():\n",
        "        device = torch.device(\"cuda\")\n",
        "    else:\n",
        "        device = torch.device(\"cpu\")\n",
        "\n",
        "    stage1 = full_finetuning(train_loader, dev_loader, device, tokenizer, label_texts)\n",
        "    model, enc_er_scores, dec_er_scores, cross_er_scores = stage1\n",
        "    model = prune_and_finetuning(\n",
        "        model, train_loader, dev_loader, device,\n",
        "        enc_er_scores, dec_er_scores, cross_er_scores,\n",
        "        tokenizer, label_texts)\n",
        "\n",
        "# Run the experiment when this cell is executed.\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "tB4eSzBIFEsg"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "Pa2LlXfR-cus"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Prune the decoder\n",
        "\n",
        "# Mount Google Drive if on Colab\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict\n",
        "import warnings\n",
        "import math\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# --- 1. Load CQA Data ---\n",
        "# Paths of the CQA JSON splits on the mounted Drive.\n",
        "data_files = {}\n",
        "data_files[\"train\"] = \"/content/drive/MyDrive/NLP_datasets/cqa/cqa_train.json\"\n",
        "data_files[\"test\"] = \"/content/drive/MyDrive/NLP_datasets/cqa/cqa_test.json\"\n",
        "dataset = load_dataset(\"json\", data_files=data_files)\n",
        "\n",
        "# --- 2. Preprocessing Functions ---\n",
        "\n",
        "def preprocess_cqa(batch, tokenizer, max_input_length=128, max_target_length=8, use_cot=False):\n",
        "    \"\"\"Tokenize a batch of CQA examples into model inputs and label ids.\n",
        "\n",
        "    Args:\n",
        "        batch: mapping with parallel ``question``, ``choices`` and ``answer``\n",
        "            lists, plus ``abstractive_explanation`` when rationales are used.\n",
        "        tokenizer: callable tokenizer exposing ``pad_token_id``.\n",
        "        max_input_length: fixed length the prompts are padded/truncated to.\n",
        "        max_target_length: fixed length the answers are padded/truncated to.\n",
        "        use_cot: append the rationale to the prompt when the batch provides one.\n",
        "\n",
        "    Returns:\n",
        "        The tokenizer output for the prompts with an added ``labels`` entry in\n",
        "        which padding positions are -100.\n",
        "    \"\"\"\n",
        "    # Prompt with or without CoT\n",
        "    if use_cot and 'abstractive_explanation' in batch:\n",
        "        # Use question, choices, and abstractive explanation for reasoning\n",
        "        inputs = [\n",
        "            f\"question: {q} choices: {', '.join(choices)} rationale: {exp}\"\n",
        "            for q, choices, exp in zip(batch['question'], batch['choices'], batch['abstractive_explanation'])\n",
        "        ]\n",
        "    else:\n",
        "        inputs = [\n",
        "            f\"question: {q} choices: {', '.join(choices)}\"\n",
        "            for q, choices in zip(batch['question'], batch['choices'])\n",
        "        ]\n",
        "    targets = [str(ans).strip() for ans in batch['answer']]\n",
        "    model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=max_input_length)\n",
        "    target = tokenizer(targets, padding=\"max_length\", truncation=True, max_length=max_target_length)\n",
        "    pad_id = tokenizer.pad_token_id\n",
        "    # Padding positions become -100 so the loss skips them; evaluate_model maps\n",
        "    # -100 back to the pad id before decoding.\n",
        "    model_inputs[\"labels\"] = [\n",
        "        [tok if tok != pad_id else -100 for tok in seq]\n",
        "        for seq in target[\"input_ids\"]\n",
        "    ]\n",
        "    return model_inputs\n",
        "\n",
        "USE_COT = False  # Set to True to include abstractive_explanation\n",
        "tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "\n",
        "# Tokenize both splits; the rationale is never added to the dev prompts.\n",
        "train = dataset[\"train\"].map(\n",
        "    lambda examples: preprocess_cqa(examples, tokenizer, use_cot=USE_COT),\n",
        "    batched=True,\n",
        "    remove_columns=dataset[\"train\"].column_names,\n",
        ")\n",
        "dev = dataset[\"test\"].map(\n",
        "    lambda examples: preprocess_cqa(examples, tokenizer, use_cot=False),\n",
        "    batched=True,\n",
        "    remove_columns=dataset[\"test\"].column_names,\n",
        ")\n",
        "\n",
        "# The test split doubles as the dev set for the per-epoch accuracy.\n",
        "collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding=\"max_length\", max_length=128)\n",
        "train_loader = DataLoader(dataset=train, batch_size=32, shuffle=True, collate_fn=collator)\n",
        "dev_loader = DataLoader(dataset=dev, batch_size=32, shuffle=False, collate_fn=collator)\n",
        "\n",
        "# --- 3. Conditional ER Hook Utilities ---\n",
        "def register_conditional_er_hooks(model):\n",
        "    \"\"\"Attach forward hooks that cache the latest output of the model's blocks.\n",
        "\n",
        "    Hooks are placed on every encoder block, every decoder block, and on\n",
        "    ``layer[1]`` of every decoder block (treated here as cross-attention).\n",
        "    Each hook stores the detached output, or its first element when the\n",
        "    output is a tuple, under the block's index.\n",
        "\n",
        "    Returns:\n",
        "        Three ``(handles, activations)`` pairs ordered encoder, decoder,\n",
        "        cross-attention. ``activations`` maps block index to the most recent\n",
        "        tensor and holds ``None`` until that block has run.\n",
        "    \"\"\"\n",
        "    def _attach(modules):\n",
        "        # One slot per module, empty until its first forward pass.\n",
        "        store = dict.fromkeys(range(len(modules)))\n",
        "        handles = []\n",
        "        for pos, mod in enumerate(modules):\n",
        "            # idx is bound as a default so each hook keeps its own index.\n",
        "            def _capture(module, inp, out, idx=pos):\n",
        "                first = out[0] if isinstance(out, tuple) else out\n",
        "                store[idx] = first.detach()\n",
        "            handles.append(mod.register_forward_hook(_capture))\n",
        "        return handles, store\n",
        "\n",
        "    encoder_blocks = model.encoder.block\n",
        "    decoder_blocks = model.decoder.block\n",
        "    enc_pair = _attach(encoder_blocks)\n",
        "    dec_pair = _attach(decoder_blocks)\n",
        "    cross_pair = _attach([blk.layer[1] for blk in decoder_blocks])\n",
        "    return enc_pair, dec_pair, cross_pair\n",
        "\n",
        "def remove_hooks(hook_sets):\n",
        "    \"\"\"Detach every hook handle contained in ``hook_sets``.\n",
        "\n",
        "    Args:\n",
        "        hook_sets: iterable of ``(handles, activations)`` pairs; only the\n",
        "            handles are used, each through its ``remove()`` method.\n",
        "    \"\"\"\n",
        "    for pair in hook_sets:\n",
        "        handles, _unused_acts = pair\n",
        "        for handle in handles:\n",
        "            handle.remove()\n",
        "\n",
        "def compute_conditional_batch_entropy(prev_acts, curr_acts):\n",
        "    \"\"\"Score each adjacent layer pair by how aligned their activation changes are.\n",
        "\n",
        "    For layer ``i`` (X) and layer ``i + 1`` (Y), the change of each activation\n",
        "    between the previous and the current batch is flattened per sample, and\n",
        "    the cosine similarity between the two changes is averaged over the batch.\n",
        "\n",
        "    Args:\n",
        "        prev_acts: dict mapping layer index -> activation tensor (or None)\n",
        "            captured on the previous batch.\n",
        "        curr_acts: same structure for the current batch.\n",
        "\n",
        "    Returns:\n",
        "        dict mapping ``i`` to the mean cosine similarity for the pair\n",
        "        ``(i, i + 1)``. A pair is omitted when an activation is missing, shapes\n",
        "        differ between the two batches, the batch holds fewer than two\n",
        "        samples, a difference contains NaN, or the mean is not finite.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for i in range(len(curr_acts) - 1):\n",
        "        prev_X, prev_Y = prev_acts[i], prev_acts[i + 1]\n",
        "        curr_X, curr_Y = curr_acts[i], curr_acts[i + 1]\n",
        "        if prev_X is None or prev_Y is None or curr_X is None or curr_Y is None:\n",
        "            continue\n",
        "        if prev_X.shape != curr_X.shape or prev_Y.shape != curr_Y.shape:\n",
        "            continue\n",
        "        B = curr_X.size(0)\n",
        "        if B < 2:\n",
        "            continue\n",
        "        # float32 keeps the long flattened dot products from overflowing when the\n",
        "        # captured activations are half precision.\n",
        "        dX = (curr_X - prev_X).float().reshape(B, -1)\n",
        "        dY = (curr_Y - prev_Y).float().reshape(B, -1)\n",
        "        if torch.isnan(dX).any() or torch.isnan(dY).any():\n",
        "            continue\n",
        "        # One vectorized call replaces a Python loop that synchronised with the\n",
        "        # device once per sample.\n",
        "        # NOTE(review): sample 0 is excluded, exactly as the earlier range(1, B)\n",
        "        # loop did -- confirm this is intended.\n",
        "        cos = F.cosine_similarity(dY[1:], dX[1:], dim=1, eps=1e-8)\n",
        "        er = cos.double().mean().item()\n",
        "        if math.isfinite(er):\n",
        "            er_scores[i] = er\n",
        "    return er_scores\n",
        "\n",
        "# --- 4. Pruning Utilities ---\n",
        "class SkipBlock(nn.Module):\n",
        "    \"\"\"Identity stand-in used to replace a pruned transformer block.\n",
        "\n",
        "    ``forward`` hands back the incoming hidden states untouched, followed by\n",
        "    five ``None`` placeholders, and ignores every other argument.\n",
        "    \"\"\"\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        # Stored for reference only; forward() never reads it.\n",
        "        self.hidden_size = hidden_size\n",
        "\n",
        "    def forward(self, hidden_states, *args, **kwargs):\n",
        "        # NOTE(review): the placeholders presumably land in the slots a real T5 block\n",
        "        # uses for cache / position-bias outputs, so blocks after this one may receive\n",
        "        # None there -- confirm against the installed transformers T5Stack.\n",
        "        placeholders = (None,) * 5\n",
        "        return (hidden_states,) + placeholders\n",
        "\n",
        "def prune_er_layers(blocks, er_scores, num_prune=4, hidden_size=768):\n",
        "    \"\"\"Replace the blocks that follow the highest-scoring layer pairs with SkipBlock.\n",
        "\n",
        "    Args:\n",
        "        blocks: mutable, indexable sequence of blocks; modified in place.\n",
        "        er_scores: dict mapping layer index -> score; the block replaced for a\n",
        "            key ``k`` is ``blocks[k + 1]``.\n",
        "        num_prune: number of top-scoring keys to consider.\n",
        "        hidden_size: forwarded to ``SkipBlock``.\n",
        "\n",
        "    Returns:\n",
        "        The replaced block indices, ordered by descending score. Keys whose\n",
        "        ``k + 1`` falls outside ``blocks`` are dropped.\n",
        "    \"\"\"\n",
        "    ranked = sorted(er_scores, key=er_scores.get, reverse=True)\n",
        "    prune_idxs = []\n",
        "    for layer_idx in ranked[:num_prune]:\n",
        "        target = layer_idx + 1\n",
        "        if target < len(blocks):\n",
        "            prune_idxs.append(target)\n",
        "    for target in prune_idxs:\n",
        "        blocks[target] = SkipBlock(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "# --- 5. Training/Eval/ER Pipeline ---\n",
        "def compute_accuracy(preds, refs):\n",
        "    \"\"\"Return the fraction of ``preds`` that equal the paired entry of ``refs``.\n",
        "\n",
        "    Pairs are formed with ``zip``; the denominator is ``len(preds)``, and an\n",
        "    empty ``preds`` yields 0.\n",
        "    \"\"\"\n",
        "    if not len(preds) > 0:\n",
        "        return 0\n",
        "    hits = sum(1 for p, l in zip(preds, refs) if p == l)\n",
        "    return hits / len(preds)\n",
        "\n",
        "def evaluate_model(model, dl, tokenizer, device, max_new_tokens=8):\n",
        "    \"\"\"Generate a prediction for every batch in ``dl`` and return exact-match accuracy.\n",
        "\n",
        "    Args:\n",
        "        model: seq2seq model exposing ``eval()`` and ``generate()``.\n",
        "        dl: iterable of batches holding ``input_ids``, ``attention_mask`` and\n",
        "            ``labels`` tensors.\n",
        "        tokenizer: decodes predictions and references.\n",
        "        device: device the input tensors are moved to.\n",
        "        max_new_tokens: generation budget. The default of 8 matches the default\n",
        "            ``max_target_length`` of ``preprocess_cqa``; the earlier fixed\n",
        "            budget of 4 could not reproduce references longer than 4 tokens.\n",
        "\n",
        "    Returns:\n",
        "        Fraction of predictions equal to their reference after strip and\n",
        "        lower-casing. The model is left in eval mode.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    preds, refs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch[\"input_ids\"].to(device)\n",
        "            attention_mask = batch[\"attention_mask\"].to(device)\n",
        "            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens)\n",
        "            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
        "            # -100 marks ignored label positions and cannot be decoded; map it to the pad id.\n",
        "            label_ids = batch[\"labels\"].clone()\n",
        "            label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
        "            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
        "            preds.extend([p.strip().lower() for p in pred_texts])\n",
        "            refs.extend([l.strip().lower() for l in ref_texts])\n",
        "    return compute_accuracy(preds, refs)\n",
        "\n",
        "def full_finetuning(train_loader, dev_loader, device, tokenizer):\n",
        "    \"\"\"Stage 1: fine-tune t5-base while estimating conditional ER per layer pair.\n",
        "\n",
        "    After every optimizer step the hooked activations of the current batch are\n",
        "    compared with those of the previous batch via\n",
        "    ``compute_conditional_batch_entropy``; the scores are averaged per epoch\n",
        "    for the encoder, decoder and cross-attention groups and printed together\n",
        "    with the dev accuracy.\n",
        "\n",
        "    Args:\n",
        "        train_loader: training batches with ``input_ids``, ``attention_mask``, ``labels``.\n",
        "        dev_loader: batches passed to ``evaluate_model`` after each epoch.\n",
        "        device: device the model and batches are moved to.\n",
        "        tokenizer: forwarded to ``evaluate_model``.\n",
        "\n",
        "    Returns:\n",
        "        ``(model, enc_er, dec_er, cross_er)`` where the three dicts are the\n",
        "        averages from the final epoch.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Fine-Tuning & Conditional ER Estimation ===\")\n",
        "    num_epochs = 6\n",
        "    model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
        "    scaler = GradScaler()\n",
        "    # The linear decay must span every epoch that is run; a shorter horizon\n",
        "    # leaves the learning rate at 0 for the remaining epochs.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    hook_sets = register_conditional_er_hooks(model)\n",
        "    group_names = (\"Encoder\", \"Decoder\", \"Cross-Attention\")\n",
        "    live_acts = [acts for _, acts in hook_sets]\n",
        "    last_er = [None, None, None]\n",
        "\n",
        "    try:\n",
        "        for epoch in range(num_epochs):\n",
        "            er_sums = [defaultdict(float) for _ in group_names]\n",
        "            er_counts = [defaultdict(int) for _ in group_names]\n",
        "            model.train()\n",
        "            # Reset per epoch so the first batch is never compared with\n",
        "            # activations left over from evaluation.\n",
        "            prev_acts = [None, None, None]\n",
        "            for batch in train_loader:\n",
        "                opt.zero_grad()\n",
        "                with autocast():\n",
        "                    outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                                    attention_mask=batch['attention_mask'].to(device),\n",
        "                                    labels=batch['labels'].to(device))\n",
        "                    loss = outputs.loss\n",
        "                # backward runs outside the autocast region, as the AMP docs recommend\n",
        "                scaler.scale(loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "                for g, acts in enumerate(live_acts):\n",
        "                    if prev_acts[g] is not None:\n",
        "                        batch_er = compute_conditional_batch_entropy(prev_acts[g], acts)\n",
        "                        for idx, v in batch_er.items():\n",
        "                            er_sums[g][idx] += v\n",
        "                            er_counts[g][idx] += 1\n",
        "                    # Snapshot for the next step; the hooks overwrite `acts` in place.\n",
        "                    prev_acts[g] = {i: (t.clone() if t is not None else None) for i, t in acts.items()}\n",
        "            epoch_er = [\n",
        "                {idx: er_sums[g][idx] / er_counts[g][idx] for idx in er_sums[g] if er_counts[g][idx] > 0}\n",
        "                for g in range(len(group_names))\n",
        "            ]\n",
        "            for name, er in zip(group_names, epoch_er):\n",
        "                print(f\"[Epoch {epoch+1}] approx {name} Conditional ER: {er}\")\n",
        "            acc = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "            print(f\"[Epoch {epoch+1}] Dev Acc: {acc:.4f}\")\n",
        "            last_er = epoch_er\n",
        "    finally:\n",
        "        # Always detach the hooks, even when training is interrupted.\n",
        "        remove_hooks(hook_sets)\n",
        "    return model, last_er[0], last_er[1], last_er[2]\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, enc_er_scores, dec_er_scores, tokenizer):\n",
        "    \"\"\"\n",
        "    Stage 2: prune decoder blocks via prune_er_layers, then fine-tune the model.\n",
        "\n",
        "    Encoder pruning is disabled, so `enc_er_scores` is accepted but not used.\n",
        "    Dev accuracy is printed after every epoch and the same `model` object is returned.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High-ER) & Fine-tuning ===\")\n",
        "    dec_prune_idxs = prune_er_layers(model.decoder.block, dec_er_scores, num_prune=4, hidden_size=model.config.d_model)\n",
        "    print(\"Pruned decoder layers (highest ER):\", dec_prune_idxs)\n",
        "    num_epochs = 5\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
        "    # The linear decay has to span every optimizer step that is actually taken:\n",
        "    # with a horizon shorter than the run, the LR reaches 0 early and the\n",
        "    # remaining epochs no longer update the weights.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                            attention_mask=batch['attention_mask'].to(device),\n",
        "                            labels=batch['labels'].to(device))\n",
        "            loss = outputs.loss\n",
        "            loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] CQA Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 6. Entrypoint ---\n",
        "def main():\n",
        "    \"\"\"Run stage 1 (full_finetuning) and then stage 2 (prune_and_finetuning) on the selected device.\"\"\"\n",
        "    use_cuda = torch.cuda.is_available()\n",
        "    device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
        "    # The cross-attention scores come back from stage 1 but are not passed on to stage 2.\n",
        "    model, enc_scores, dec_scores, _cross_scores = full_finetuning(\n",
        "        train_loader, dev_loader, device, tokenizer\n",
        "    )\n",
        "    # --- PRUNING AND CONTINUED FINETUNING ---\n",
        "    model = prune_and_finetuning(\n",
        "        model, train_loader, dev_loader, device, enc_scores, dec_scores, tokenizer\n",
        "    )\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "wm8aQwKE-eKF"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "-nEqET8995OX"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Only prune decoder\n",
        "\n",
        "# --- Mount Google Drive if using Colab ---\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "# --- Standard Imports ---\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup, Adafactor\n",
        ")\n",
        "from torch.cuda.amp import autocast\n",
        "from collections import defaultdict\n",
        "import warnings\n",
        "import math\n",
        "import random\n",
        "import numpy as np\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# ------------- Repro -------------\n",
        "def set_seed(seed=42):\n",
        "    \"\"\"Seed Python's `random`, NumPy and PyTorch (CPU, plus all CUDA devices when CUDA is available).\"\"\"\n",
        "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
        "        seed_fn(seed)\n",
        "    if not torch.cuda.is_available():\n",
        "        return\n",
        "    torch.cuda.manual_seed_all(seed)\n",
        "\n",
        "set_seed(1234)\n",
        "\n",
        "# --- 1. Load ANLI1 Dataset ---\n",
        "data_files = {\n",
        "    \"train\":      \"/content/drive/MyDrive/NLP_datasets/anli1/anli1_train.json\",\n",
        "    \"validation\": \"/content/drive/MyDrive/NLP_datasets/anli1/anli1_valid.json\",\n",
        "    \"test\":       \"/content/drive/MyDrive/NLP_datasets/anli1/anli1_test.json\"\n",
        "}\n",
        "dataset = load_dataset(\"json\", data_files=data_files)\n",
        "\n",
        "# --- 2. Preprocessing Function ---\n",
        "def make_t5_nli_prompt(premise, hypothesis):\n",
        "    \"\"\"Build the text prompt fed to T5 for one premise/hypothesis pair.\"\"\"\n",
        "    prompt = f\"nli premise: {premise} hypothesis: {hypothesis}\"\n",
        "    return prompt\n",
        "\n",
        "def preprocess_anli(batch, tokenizer, max_input_length=128, max_target_length=8):\n",
        "    \"\"\"\n",
        "    Turn a batch of NLI rows into padded T5 inputs and label ids.\n",
        "\n",
        "    Args:\n",
        "        batch: mapping with parallel 'premise', 'hypothesis' and 'label' sequences.\n",
        "        tokenizer: callable tokenizer; must also expose `pad_token_id`.\n",
        "        max_input_length: fixed length the prompts are padded/truncated to.\n",
        "        max_target_length: fixed length the label strings are padded/truncated to.\n",
        "\n",
        "    Returns:\n",
        "        The tokenizer output for the prompts with an added \"labels\" entry in\n",
        "        which every padding position is -100.\n",
        "    \"\"\"\n",
        "    inputs = [make_t5_nli_prompt(p, h) for p, h in zip(batch['premise'], batch['hypothesis'])]\n",
        "    label_list = [\"entailment\", \"neutral\", \"contradiction\"]\n",
        "\n",
        "    # Labels may arrive as class ids (0/1/2, as int or digit string) or as label text.\n",
        "    labels_str = []\n",
        "    for x in batch['label']:\n",
        "        sx = str(x)\n",
        "        if sx.isdigit() and int(sx) < 3:\n",
        "            labels_str.append(label_list[int(sx)])\n",
        "        else:\n",
        "            labels_str.append(sx.strip().lower())\n",
        "\n",
        "    # Fixed padding to keep hook tensor shapes consistent across steps\n",
        "    model_inputs = tokenizer(\n",
        "        inputs, padding=\"max_length\", truncation=True, max_length=max_input_length\n",
        "    )\n",
        "    target = tokenizer(\n",
        "        text_target=labels_str, padding=\"max_length\", truncation=True, max_length=max_target_length\n",
        "    )\n",
        "    # The targets are padded to a fixed length with the tokenizer's pad id, so\n",
        "    # those positions are masked here; -100 is the value evaluate_model maps\n",
        "    # back to the pad id before decoding.\n",
        "    # NOTE(review): relies on the loss ignoring -100 and on the collator not\n",
        "    # rewriting pad ids that are already present - confirm for the installed\n",
        "    # transformers version.\n",
        "    pad_id = tokenizer.pad_token_id\n",
        "    model_inputs[\"labels\"] = [\n",
        "        [(tok if tok != pad_id else -100) for tok in seq]\n",
        "        for seq in target[\"input_ids\"]\n",
        "    ]\n",
        "    return model_inputs\n",
        "\n",
        "# Tokenizer\n",
        "tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "\n",
        "# Map datasets\n",
        "train = dataset[\"train\"].map(\n",
        "    lambda ex: preprocess_anli(ex, tokenizer),\n",
        "    batched=True, remove_columns=dataset[\"train\"].column_names\n",
        ")\n",
        "dev = dataset[\"validation\"].map(\n",
        "    lambda ex: preprocess_anli(ex, tokenizer),\n",
        "    batched=True, remove_columns=dataset[\"validation\"].column_names\n",
        ")\n",
        "\n",
        "# --- Load model before creating the collator (so collator can mask label pads -> -100) ---\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "# Avoid dealing with past_key_value in custom blocks\n",
        "model.config.use_cache = False\n",
        "\n",
        "# Collator pads labels to a common length with -100; pad ids already stored in the labels are left as-is\n",
        "collator = DataCollatorForSeq2Seq(\n",
        "    tokenizer, model=model, label_pad_token_id=-100\n",
        ")\n",
        "\n",
        "train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "dev_loader   = DataLoader(dev,   batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "# --- 3. Conditional ER Hook Utilities ---\n",
        "def register_conditional_er_hooks(model):\n",
        "    \"\"\"\n",
        "    Attach forward hooks that record the latest output of every encoder block,\n",
        "    every decoder block and every decoder block's `layer[1]` sub-module.\n",
        "\n",
        "    Returns:\n",
        "        Three (handles, activations) pairs in the order encoder, decoder,\n",
        "        cross-attention. Each activations dict maps a layer index to the most\n",
        "        recent detached output (None until that layer has run). The\n",
        "        cross-attention handle list holds None for blocks without `layer[1]`.\n",
        "    \"\"\"\n",
        "    def _recorder(store, slot):\n",
        "        # Build a hook that writes the (first) output tensor into store[slot].\n",
        "        def _hook(module, inp, out):\n",
        "            hidden = out[0] if isinstance(out, tuple) else out\n",
        "            store[slot] = hidden.detach()\n",
        "        return _hook\n",
        "\n",
        "    enc_blocks = model.encoder.block\n",
        "    dec_blocks = model.decoder.block\n",
        "\n",
        "    # Encoder hooks\n",
        "    enc_acts = dict.fromkeys(range(len(enc_blocks)))\n",
        "    enc_hooks = [\n",
        "        blk.register_forward_hook(_recorder(enc_acts, n))\n",
        "        for n, blk in enumerate(enc_blocks)\n",
        "    ]\n",
        "\n",
        "    # Decoder hooks\n",
        "    dec_acts = dict.fromkeys(range(len(dec_blocks)))\n",
        "    dec_hooks = [\n",
        "        blk.register_forward_hook(_recorder(dec_acts, n))\n",
        "        for n, blk in enumerate(dec_blocks)\n",
        "    ]\n",
        "\n",
        "    # Cross-attention hooks (decoder.layer[1] is cross-attn in T5 decoder blocks)\n",
        "    cross_acts = dict.fromkeys(range(len(dec_blocks)))\n",
        "    cross_hooks = []\n",
        "    for n, blk in enumerate(dec_blocks):\n",
        "        has_cross = hasattr(blk, \"layer\") and len(blk.layer) > 1\n",
        "        if not has_cross:\n",
        "            cross_hooks.append(None)\n",
        "            continue\n",
        "        cross_hooks.append(blk.layer[1].register_forward_hook(_recorder(cross_acts, n)))\n",
        "\n",
        "    return (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)\n",
        "\n",
        "def remove_hooks(hook_sets):\n",
        "    \"\"\"Call `.remove()` on every hook handle in `hook_sets`; None placeholders are skipped.\"\"\"\n",
        "    for handles, _acts in hook_sets:\n",
        "        for handle in filter(lambda h: h is not None, handles):\n",
        "            handle.remove()\n",
        "\n",
        "def compute_conditional_batch_entropy(prev_acts, curr_acts):\n",
        "    \"\"\"\n",
        "    Cos^2 between step-wise deltas, averaged over batch.\n",
        "\n",
        "    For every adjacent index pair (i, i+1) the change of each activation\n",
        "    between the previous and the current step is flattened per sample, and the\n",
        "    squared cosine similarity between the two deltas is averaged over the batch.\n",
        "\n",
        "    Args:\n",
        "        prev_acts: activations of the previous step, indexable by 0..len(curr_acts)-1.\n",
        "        curr_acts: activations of the current step, same indexing; entries may be None.\n",
        "\n",
        "    Returns:\n",
        "        dict {i: float}. A pair is left out when an activation is missing, when\n",
        "        a shape changed between the two steps, when the batch is empty, when a\n",
        "        delta contains NaN, or when no sample yields a finite cosine.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for i in range(len(curr_acts) - 1):\n",
        "        prev_X, prev_Y = prev_acts[i], prev_acts[i + 1]\n",
        "        curr_X, curr_Y = curr_acts[i], curr_acts[i + 1]\n",
        "        if prev_X is None or prev_Y is None or curr_X is None or curr_Y is None:\n",
        "            continue\n",
        "        if prev_X.shape != curr_X.shape or prev_Y.shape != curr_Y.shape:\n",
        "            continue\n",
        "        B = curr_X.size(0)\n",
        "        if B < 1:\n",
        "            # Checked before flattening: reshaping an empty batch to (0, -1) raises.\n",
        "            continue\n",
        "        # reshape (not view) so non-contiguous deltas are flattened as well\n",
        "        dX = (curr_X - prev_X).reshape(B, -1)\n",
        "        dY = (curr_Y - prev_Y).reshape(B, -1)\n",
        "        if torch.isnan(dX).any() or torch.isnan(dY).any():\n",
        "            continue\n",
        "        # One batched call and one host transfer instead of one .item() per sample\n",
        "        cos_values = F.cosine_similarity(dY, dX, dim=1, eps=1e-8).tolist()\n",
        "        cos_squares = [cs * cs for cs in cos_values if math.isfinite(cs)]\n",
        "        if not cos_squares:\n",
        "            continue\n",
        "        er = sum(cos_squares) / len(cos_squares)\n",
        "        if math.isfinite(er):\n",
        "            er_scores[i] = er\n",
        "    return er_scores\n",
        "\n",
        "# --- 4. Pruning Utilities ---\n",
        "class SkipBlock(nn.Module):\n",
        "    \"\"\"\n",
        "    Identity stand-in for a pruned T5 block: hidden_states pass through untouched.\n",
        "\n",
        "    The returned tuple uses the slot order a T5 block produces when caching is\n",
        "    off, which is what the surrounding T5Stack indexes into:\n",
        "        (hidden_states, position_bias, [self_attn_weights],\n",
        "         encoder_decoder_position_bias, [cross_attn_weights])\n",
        "    The attention-weight slots exist only when output_attentions is True and\n",
        "    are always None here. Forwarding both position biases in these slots lets\n",
        "    the blocks after a skipped one keep using the biases computed earlier in\n",
        "    the stack instead of receiving None.\n",
        "\n",
        "    NOTE(review): with use_cache=True the incoming past_key_value is inserted\n",
        "    right after hidden_states (pre-cache-refactor layout). This notebook always\n",
        "    runs with caching disabled; confirm the layout against the installed\n",
        "    transformers version before enabling the cache.\n",
        "    \"\"\"\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "\n",
        "    def forward(\n",
        "        self,\n",
        "        hidden_states,\n",
        "        attention_mask=None,                 # (in decoder this is 'causal_mask' positionally)\n",
        "        position_bias=None,\n",
        "        encoder_hidden_states=None,\n",
        "        encoder_attention_mask=None,\n",
        "        encoder_decoder_position_bias=None,\n",
        "        layer_head_mask=None,\n",
        "        cross_attn_layer_head_mask=None,\n",
        "        past_key_value=None,\n",
        "        use_cache=False,\n",
        "        output_attentions=False,\n",
        "        return_dict=False,\n",
        "        cache_position=None,                 # accepted because newer HF versions pass it\n",
        "        **kwargs,\n",
        "    ):\n",
        "        # Pass hidden_states through and propagate both positional biases.\n",
        "        outputs = (hidden_states,)\n",
        "        if use_cache:\n",
        "            outputs = outputs + (past_key_value,)\n",
        "        outputs = outputs + (position_bias,)\n",
        "        if output_attentions:\n",
        "            outputs = outputs + (None,)      # no self-attention weights to report\n",
        "        outputs = outputs + (encoder_decoder_position_bias,)\n",
        "        if output_attentions:\n",
        "            outputs = outputs + (None,)      # no cross-attention weights to report\n",
        "        return outputs\n",
        "\n",
        "def prune_er_layers(blocks, er_scores, num_prune=4):\n",
        "    \"\"\"\n",
        "    Replace the blocks that follow the highest-scoring layer pairs with SkipBlock.\n",
        "\n",
        "    A score stored under key i belongs to the pair (i, i+1); for the\n",
        "    `num_prune` largest scores, block i+1 is replaced in place (indices past\n",
        "    the end of `blocks` are ignored). Returns the sorted replaced indices.\n",
        "    \"\"\"\n",
        "    # Highest ER first (the 'redundancy' convention used here)\n",
        "    ranked = sorted(er_scores.items(), key=lambda item: item[1], reverse=True)\n",
        "    targets = set()\n",
        "    for pair_idx, _score in ranked[:num_prune]:\n",
        "        block_idx = pair_idx + 1  # pair (i, i+1) -> prune i+1\n",
        "        if block_idx < len(blocks):\n",
        "            targets.add(block_idx)\n",
        "    prune_idxs = sorted(targets)\n",
        "    for block_idx in prune_idxs:\n",
        "        blocks[block_idx] = SkipBlock()\n",
        "    return prune_idxs\n",
        "\n",
        "# --- 5. Training/Eval/ER Pipeline ---\n",
        "def canonicalize_label(s: str):\n",
        "    \"\"\"\n",
        "    Reduce a label string to a canonical NLI label.\n",
        "\n",
        "    Only the first whitespace-separated token of the stripped, lower-cased\n",
        "    text is considered. Known variants are folded onto \"entailment\" or\n",
        "    \"contradiction\"; any other token is returned as-is, and an empty/None\n",
        "    input yields \"\".\n",
        "    \"\"\"\n",
        "    text = (s or \"\").strip().lower()\n",
        "    if not text:\n",
        "        return text\n",
        "    head = text.split()[0]\n",
        "    # Canonical names map to themselves through the .get() default below.\n",
        "    variants = {\n",
        "        \"entailed\": \"entailment\",\n",
        "        \"contradict\": \"contradiction\",\n",
        "        \"contradictory\": \"contradiction\",\n",
        "        \"contradicted\": \"contradiction\",\n",
        "    }\n",
        "    return variants.get(head, head)\n",
        "\n",
        "def compute_accuracy(preds, refs):\n",
        "    \"\"\"\n",
        "    Share of predictions whose canonical label equals the reference's.\n",
        "\n",
        "    Pairs are formed with zip(preds, refs); the hit count is divided by\n",
        "    len(preds). Returns 0 when `preds` is empty.\n",
        "    \"\"\"\n",
        "    hits = sum(\n",
        "        1 for pred, ref in zip(preds, refs)\n",
        "        if canonicalize_label(pred) == canonicalize_label(ref)\n",
        "    )\n",
        "    total = len(preds)\n",
        "    if total > 0:\n",
        "        return hits / total\n",
        "    return 0\n",
        "\n",
        "@torch.no_grad()\n",
        "def evaluate_model(model, dl, tokenizer, device):\n",
        "    \"\"\"\n",
        "    Generate a label for every batch in `dl` and return the accuracy.\n",
        "\n",
        "    The model is switched to eval mode (and left there). Gold labels are\n",
        "    decoded after mapping -100 back to the pad id; predictions come from\n",
        "    `model.generate` with the cache disabled. Both sides are stripped and\n",
        "    lower-cased before being scored by compute_accuracy.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    preds, refs = [], []\n",
        "    for batch in dl:\n",
        "        # decode gold labels to text\n",
        "        gold_ids = batch[\"labels\"].clone()\n",
        "        gold_ids[gold_ids == -100] = tokenizer.pad_token_id\n",
        "        gold_texts = tokenizer.batch_decode(gold_ids, skip_special_tokens=True)\n",
        "\n",
        "        # let the model spell out the label; cache stays off to match SkipBlock\n",
        "        generated = model.generate(\n",
        "            input_ids=batch[\"input_ids\"].to(device),\n",
        "            attention_mask=batch[\"attention_mask\"].to(device),\n",
        "            max_new_tokens=4,\n",
        "            use_cache=False,\n",
        "        )\n",
        "        generated_texts = tokenizer.batch_decode(generated, skip_special_tokens=True)\n",
        "\n",
        "        preds += [text.strip().lower() for text in generated_texts]\n",
        "        refs += [text.strip().lower() for text in gold_texts]\n",
        "\n",
        "    return compute_accuracy(preds, refs)\n",
        "\n",
        "def build_optimizer_and_scheduler(model, train_steps):\n",
        "    \"\"\"\n",
        "    Create the optimizer/scheduler pair used by the training loops.\n",
        "\n",
        "    Returns (optimizer, None): an Adafactor over all model parameters with\n",
        "    relative_step, scale_parameter and warmup_init enabled and no explicit\n",
        "    learning rate. No external scheduler is built, and `train_steps` is\n",
        "    currently unused.\n",
        "\n",
        "    NOTE(review): check these Adafactor settings against the Hugging Face\n",
        "    Adafactor documentation for T5 fine-tuning before relying on them.\n",
        "    \"\"\"\n",
        "    adafactor_options = dict(relative_step=True, scale_parameter=True, warmup_init=True)\n",
        "    optimizer = Adafactor(model.parameters(), **adafactor_options)\n",
        "    # No scheduler object: callers skip sched.step() when it is None.\n",
        "    return optimizer, None\n",
        "\n",
        "def full_finetuning(model, train_loader, dev_loader, device, tokenizer):\n",
        "    \"\"\"\n",
        "    Stage 1: fine-tune the whole model for 6 epochs while estimating conditional ER.\n",
        "\n",
        "    Forward hooks record encoder-block, decoder-block and cross-attention\n",
        "    outputs. After each optimizer step the current activations are compared\n",
        "    with the snapshot taken at the previous step\n",
        "    (compute_conditional_batch_entropy) and the per-pair scores are averaged\n",
        "    over the epoch. Batches with a non-finite loss are skipped. ER means and\n",
        "    dev accuracy are printed after every epoch.\n",
        "\n",
        "    Returns:\n",
        "        (model, enc_er, dec_er, cross_er); the three dicts hold the per-pair\n",
        "        means of the last epoch.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Fine-Tuning & Conditional ER Estimation ===\")\n",
        "\n",
        "    def _accumulate(prev_acts, curr_acts, sums, counts):\n",
        "        # Fold this step's per-pair scores into the running epoch totals.\n",
        "        if prev_acts is None:\n",
        "            return\n",
        "        for idx, v in compute_conditional_batch_entropy(prev_acts, curr_acts).items():\n",
        "            sums[idx] += v\n",
        "            counts[idx] += 1\n",
        "\n",
        "    def _snapshot(acts):\n",
        "        # Copy the recorded tensors; the hooks overwrite `acts` on every forward pass.\n",
        "        return {i: (acts[i].clone() if acts[i] is not None else None) for i in acts}\n",
        "\n",
        "    def _epoch_mean(sums, counts):\n",
        "        # Mean score per layer pair over the steps that produced one.\n",
        "        return {idx: sums[idx] / counts[idx] for idx in sums if counts[idx] > 0}\n",
        "\n",
        "    num_epochs = 6\n",
        "    opt, sched = build_optimizer_and_scheduler(model, len(train_loader) * num_epochs)\n",
        "    hook_sets = register_conditional_er_hooks(model)\n",
        "    (_, enc_acts), (_, dec_acts), (_, cross_acts) = hook_sets\n",
        "    last_enc_er, last_dec_er, last_cross_er = None, None, None\n",
        "\n",
        "    # try/finally: the hooks live on the caller's model, so they are removed\n",
        "    # even when training raises or is interrupted.\n",
        "    try:\n",
        "        for epoch in range(num_epochs):\n",
        "            enc_er_sums, enc_er_counts = defaultdict(float), defaultdict(int)\n",
        "            dec_er_sums, dec_er_counts = defaultdict(float), defaultdict(int)\n",
        "            cross_er_sums, cross_er_counts = defaultdict(float), defaultdict(int)\n",
        "\n",
        "            model.train()\n",
        "            prev_enc_acts, prev_dec_acts, prev_cross_acts = None, None, None\n",
        "\n",
        "            for batch in train_loader:\n",
        "                opt.zero_grad()\n",
        "\n",
        "                with autocast(enabled=False):  # turn off AMP while debugging stability\n",
        "                    outputs = model(\n",
        "                        input_ids=batch['input_ids'].to(device),\n",
        "                        attention_mask=batch['attention_mask'].to(device),\n",
        "                        labels=batch['labels'].to(device)\n",
        "                    )\n",
        "                    loss = outputs.loss\n",
        "\n",
        "                if not torch.isfinite(loss):\n",
        "                    print(\"Loss is NaN/Inf — skipping this batch.\")\n",
        "                    continue\n",
        "\n",
        "                loss.backward()\n",
        "                # Grad clipping (helps avoid LN/attn blowups)\n",
        "                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
        "\n",
        "                # Adafactor: step without scheduler\n",
        "                opt.step()\n",
        "                if sched is not None:\n",
        "                    sched.step()\n",
        "\n",
        "                # --- ER accumulations ---\n",
        "                _accumulate(prev_enc_acts, enc_acts, enc_er_sums, enc_er_counts)\n",
        "                _accumulate(prev_dec_acts, dec_acts, dec_er_sums, dec_er_counts)\n",
        "                _accumulate(prev_cross_acts, cross_acts, cross_er_sums, cross_er_counts)\n",
        "\n",
        "                # snapshot current acts\n",
        "                prev_enc_acts = _snapshot(enc_acts)\n",
        "                prev_dec_acts = _snapshot(dec_acts)\n",
        "                prev_cross_acts = _snapshot(cross_acts)\n",
        "\n",
        "            # epoch-level ER (means)\n",
        "            epoch_enc_er = _epoch_mean(enc_er_sums, enc_er_counts)\n",
        "            epoch_dec_er = _epoch_mean(dec_er_sums, dec_er_counts)\n",
        "            epoch_cross_er = _epoch_mean(cross_er_sums, cross_er_counts)\n",
        "\n",
        "            print(f\"[Epoch {epoch+1}] approx Encoder Conditional ER: {epoch_enc_er}\")\n",
        "            print(f\"[Epoch {epoch+1}] approx Decoder Conditional ER: {epoch_dec_er}\")\n",
        "            print(f\"[Epoch {epoch+1}] approx Cross-Attention Conditional ER: {epoch_cross_er}\")\n",
        "\n",
        "            acc = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "            print(f\"[Epoch {epoch+1}] Dev Acc: {acc:.4f}\")\n",
        "\n",
        "            last_enc_er, last_dec_er, last_cross_er = epoch_enc_er, epoch_dec_er, epoch_cross_er\n",
        "    finally:\n",
        "        remove_hooks(hook_sets)\n",
        "\n",
        "    return model, last_enc_er, last_dec_er, last_cross_er\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, enc_er_scores, dec_er_scores, tokenizer):\n",
        "    \"\"\"\n",
        "    Stage 2: prune decoder blocks by ER score, then fine-tune for 5 epochs.\n",
        "\n",
        "    Only `dec_er_scores` is used; `enc_er_scores` is accepted but ignored.\n",
        "    A fresh optimizer is built after pruning, batches with a non-finite loss\n",
        "    are skipped, and dev accuracy is printed after each epoch. Returns `model`.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High-ER) & Fine-tuning ===\")\n",
        "    # Decoder-only pruning\n",
        "    pruned_layers = prune_er_layers(model.decoder.block, dec_er_scores, num_prune=4)\n",
        "    print(\"Pruned decoder layers (highest ER -> next index):\", pruned_layers)\n",
        "\n",
        "    # New optimizer after structural change\n",
        "    optimizer, scheduler = build_optimizer_and_scheduler(model, len(train_loader) * 2)\n",
        "\n",
        "    for epoch in range(5):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            optimizer.zero_grad()\n",
        "            tensors = {\n",
        "                key: batch[key].to(device)\n",
        "                for key in ('input_ids', 'attention_mask', 'labels')\n",
        "            }\n",
        "            loss = model(**tensors).loss\n",
        "            if torch.isfinite(loss):\n",
        "                loss.backward()\n",
        "                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
        "                optimizer.step()\n",
        "                if scheduler is not None:\n",
        "                    scheduler.step()\n",
        "            else:\n",
        "                print(\"Loss is NaN/Inf — skipping this batch.\")\n",
        "\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] ANLI1 Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 6. Entrypoint ---\n",
        "def main():\n",
        "    \"\"\"Run stage 1 (full_finetuning) and stage 2 (prune_and_finetuning) on the module-level model.\"\"\"\n",
        "    global model  # using the earlier-loaded model\n",
        "    # The cross-attention scores are returned by stage 1 but not passed on.\n",
        "    model, enc_scores, dec_scores, _cross_scores = full_finetuning(\n",
        "        model=model,\n",
        "        train_loader=train_loader,\n",
        "        dev_loader=dev_loader,\n",
        "        device=device,\n",
        "        tokenizer=tokenizer,\n",
        "    )\n",
        "    # --- PRUNING AND CONTINUED FINETUNING (decoder only) ---\n",
        "    model = prune_and_finetuning(\n",
        "        model=model,\n",
        "        train_loader=train_loader,\n",
        "        dev_loader=dev_loader,\n",
        "        device=device,\n",
        "        enc_er_scores=enc_scores,\n",
        "        dec_er_scores=dec_scores,\n",
        "        tokenizer=tokenizer,\n",
        "    )\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "VP5ieuyXTqwG"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# ER\n",
        "\n",
        "# --- Mount Google Drive if using Colab ---\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "# --- Standard Imports ---\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup, Adafactor\n",
        ")\n",
        "from torch.cuda.amp import autocast\n",
        "from collections import defaultdict\n",
        "import warnings\n",
        "import math\n",
        "import random\n",
        "import numpy as np\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# ------------- Repro -------------\n",
        "def set_seed(seed=42):\n",
        "    \"\"\"Seed Python's `random`, NumPy and PyTorch (CPU, plus all CUDA devices when CUDA is available).\"\"\"\n",
        "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
        "        seed_fn(seed)\n",
        "    if not torch.cuda.is_available():\n",
        "        return\n",
        "    torch.cuda.manual_seed_all(seed)\n",
        "\n",
        "set_seed(1234)\n",
        "\n",
        "# --- 1. Load ANLI1 Dataset ---\n",
        "data_files = {\n",
        "    \"train\":      \"/content/drive/MyDrive/NLP_datasets/anli1/anli1_train.json\",\n",
        "    \"validation\": \"/content/drive/MyDrive/NLP_datasets/anli1/anli1_valid.json\",\n",
        "    \"test\":       \"/content/drive/MyDrive/NLP_datasets/anli1/anli1_test.json\"\n",
        "}\n",
        "dataset = load_dataset(\"json\", data_files=data_files)\n",
        "\n",
        "# --- 2. Preprocessing Function ---\n",
        "def make_t5_nli_prompt(premise, hypothesis):\n",
        "    \"\"\"Build the text prompt fed to T5 for one premise/hypothesis pair.\"\"\"\n",
        "    prompt = f\"nli premise: {premise} hypothesis: {hypothesis}\"\n",
        "    return prompt\n",
        "\n",
        "def preprocess_anli(batch, tokenizer, max_input_length=128, max_target_length=8):\n",
        "    \"\"\"\n",
        "    Turn a batch of NLI rows into padded T5 inputs and label ids.\n",
        "\n",
        "    Args:\n",
        "        batch: mapping with parallel 'premise', 'hypothesis' and 'label' sequences.\n",
        "        tokenizer: callable tokenizer; must also expose `pad_token_id`.\n",
        "        max_input_length: fixed length the prompts are padded/truncated to.\n",
        "        max_target_length: fixed length the label strings are padded/truncated to.\n",
        "\n",
        "    Returns:\n",
        "        The tokenizer output for the prompts with an added \"labels\" entry in\n",
        "        which every padding position is -100.\n",
        "    \"\"\"\n",
        "    inputs = [make_t5_nli_prompt(p, h) for p, h in zip(batch['premise'], batch['hypothesis'])]\n",
        "    label_list = [\"entailment\", \"neutral\", \"contradiction\"]\n",
        "\n",
        "    # Labels may arrive as class ids (0/1/2, as int or digit string) or as label text.\n",
        "    labels_str = []\n",
        "    for x in batch['label']:\n",
        "        sx = str(x)\n",
        "        if sx.isdigit() and int(sx) < 3:\n",
        "            labels_str.append(label_list[int(sx)])\n",
        "        else:\n",
        "            labels_str.append(sx.strip().lower())\n",
        "\n",
        "    # Fixed padding to keep hook tensor shapes consistent across steps\n",
        "    model_inputs = tokenizer(\n",
        "        inputs, padding=\"max_length\", truncation=True, max_length=max_input_length\n",
        "    )\n",
        "    target = tokenizer(\n",
        "        text_target=labels_str, padding=\"max_length\", truncation=True, max_length=max_target_length\n",
        "    )\n",
        "    # The targets are padded to a fixed length with the tokenizer's pad id, so\n",
        "    # those positions are masked to -100 here.\n",
        "    # NOTE(review): relies on the loss ignoring -100 and on the collator not\n",
        "    # rewriting pad ids that are already present - confirm for the installed\n",
        "    # transformers version.\n",
        "    pad_id = tokenizer.pad_token_id\n",
        "    model_inputs[\"labels\"] = [\n",
        "        [(tok if tok != pad_id else -100) for tok in seq]\n",
        "        for seq in target[\"input_ids\"]\n",
        "    ]\n",
        "    return model_inputs\n",
        "\n",
        "# Tokenizer\n",
        "tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "\n",
        "# Map datasets\n",
        "train = dataset[\"train\"].map(\n",
        "    lambda ex: preprocess_anli(ex, tokenizer),\n",
        "    batched=True, remove_columns=dataset[\"train\"].column_names\n",
        ")\n",
        "dev = dataset[\"validation\"].map(\n",
        "    lambda ex: preprocess_anli(ex, tokenizer),\n",
        "    batched=True, remove_columns=dataset[\"validation\"].column_names\n",
        ")\n",
        "\n",
        "# --- Load model before creating the collator (so collator can mask label pads -> -100) ---\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "# Avoid dealing with past_key_value in custom blocks\n",
        "model.config.use_cache = False\n",
        "\n",
        "# Collator pads labels to a common length with -100; pad ids already stored in the labels are left as-is\n",
        "collator = DataCollatorForSeq2Seq(\n",
        "    tokenizer, model=model, label_pad_token_id=-100\n",
        ")\n",
        "\n",
        "train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "dev_loader   = DataLoader(dev,   batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "# --- 3. Conditional ER Hook Utilities ---\n",
        "def register_conditional_er_hooks(model):\n",
        "    \"\"\"\n",
        "    Attach forward hooks that record the latest output of every encoder block,\n",
        "    every decoder block and every decoder block's `layer[1]` sub-module.\n",
        "\n",
        "    Returns:\n",
        "        Three (handles, activations) pairs in the order encoder, decoder,\n",
        "        cross-attention. Each activations dict maps a layer index to the most\n",
        "        recent detached output (None until that layer has run). The\n",
        "        cross-attention handle list holds None for blocks without `layer[1]`.\n",
        "    \"\"\"\n",
        "    def _recorder(store, slot):\n",
        "        # Build a hook that writes the (first) output tensor into store[slot].\n",
        "        def _hook(module, inp, out):\n",
        "            hidden = out[0] if isinstance(out, tuple) else out\n",
        "            store[slot] = hidden.detach()\n",
        "        return _hook\n",
        "\n",
        "    enc_blocks = model.encoder.block\n",
        "    dec_blocks = model.decoder.block\n",
        "\n",
        "    # Encoder hooks\n",
        "    enc_acts = dict.fromkeys(range(len(enc_blocks)))\n",
        "    enc_hooks = [\n",
        "        blk.register_forward_hook(_recorder(enc_acts, n))\n",
        "        for n, blk in enumerate(enc_blocks)\n",
        "    ]\n",
        "\n",
        "    # Decoder hooks\n",
        "    dec_acts = dict.fromkeys(range(len(dec_blocks)))\n",
        "    dec_hooks = [\n",
        "        blk.register_forward_hook(_recorder(dec_acts, n))\n",
        "        for n, blk in enumerate(dec_blocks)\n",
        "    ]\n",
        "\n",
        "    # Cross-attention hooks (decoder.layer[1] is cross-attn in T5 decoder blocks)\n",
        "    cross_acts = dict.fromkeys(range(len(dec_blocks)))\n",
        "    cross_hooks = []\n",
        "    for n, blk in enumerate(dec_blocks):\n",
        "        has_cross = hasattr(blk, \"layer\") and len(blk.layer) > 1\n",
        "        if not has_cross:\n",
        "            cross_hooks.append(None)\n",
        "            continue\n",
        "        cross_hooks.append(blk.layer[1].register_forward_hook(_recorder(cross_acts, n)))\n",
        "\n",
        "    return (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)\n",
        "\n",
        "def remove_hooks(hook_sets):\n",
        "    \"\"\"Detach every hook handle in hook_sets, an iterable of (handles, acts) pairs; None handles are skipped.\"\"\"\n",
        "    for handles, _recorded in hook_sets:\n",
        "        live_handles = [handle for handle in handles if handle is not None]\n",
        "        for handle in live_handles:\n",
        "            handle.remove()\n",
        "\n",
        "def compute_conditional_batch_entropy(prev_acts, curr_acts):\n",
        "    \"\"\"\n",
        "    Mean squared cosine similarity between step-wise deltas of adjacent layers.\n",
        "\n",
        "    For each adjacent pair (i, i+1), dX = curr[i] - prev[i] and\n",
        "    dY = curr[i+1] - prev[i+1] are flattened per sample; cos(dY, dX) is taken\n",
        "    per sample, squared exactly once, and averaged over the batch.\n",
        "\n",
        "    Args:\n",
        "        prev_acts: dict index -> activation tensor (or None) from the previous step.\n",
        "        curr_acts: dict index -> activation tensor (or None) from the current step.\n",
        "\n",
        "    Returns:\n",
        "        dict mapping i -> mean cos^2 for pair (i, i+1). A pair is omitted when an\n",
        "        activation is missing, shapes changed between the two steps, the batch is\n",
        "        empty, or the deltas contain NaN.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for i in range(len(curr_acts) - 1):\n",
        "        prev_X, prev_Y = prev_acts[i], prev_acts[i + 1]\n",
        "        curr_X, curr_Y = curr_acts[i], curr_acts[i + 1]\n",
        "        if prev_X is None or prev_Y is None or curr_X is None or curr_Y is None:\n",
        "            continue\n",
        "        if prev_X.shape != curr_X.shape or prev_Y.shape != curr_Y.shape:\n",
        "            continue\n",
        "        B = curr_X.size(0)\n",
        "        if B < 1:\n",
        "            continue\n",
        "        # reshape (not view) tolerates non-contiguous activations; float() keeps the\n",
        "        # norm computation out of half precision\n",
        "        dX = (curr_X - prev_X).reshape(B, -1).float()\n",
        "        dY = (curr_Y - prev_Y).reshape(B, -1).float()\n",
        "        if torch.isnan(dX).any() or torch.isnan(dY).any():\n",
        "            continue\n",
        "        # One cosine per sample, squared once (cos^2, not cos^4)\n",
        "        cos_sq = F.cosine_similarity(dY, dX, dim=1, eps=1e-8) ** 2\n",
        "        cos_sq = cos_sq[torch.isfinite(cos_sq)]\n",
        "        if cos_sq.numel() > 0:\n",
        "            er = cos_sq.mean().item()\n",
        "            if math.isfinite(er):\n",
        "                er_scores[i] = er\n",
        "    return er_scores\n",
        "\n",
        "# --- 4. Pruning Utilities ---\n",
        "class SkipBlock(nn.Module):\n",
        "    \"\"\"\n",
        "    Identity stand-in for a pruned T5 decoder block.\n",
        "\n",
        "    forward() accepts the keyword arguments a T5Block is called with and hands\n",
        "    back a six-slot tuple:\n",
        "    (hidden_states, present_key_value, self_attn_weights, cross_attn_weights,\n",
        "     position_bias, encoder_decoder_position_bias)\n",
        "    in which the three middle slots are always None.\n",
        "\n",
        "    NOTE(review): which tuple slots T5Stack reads the position biases from depends\n",
        "    on the installed transformers version -- confirm that this layout really\n",
        "    propagates the biases to the blocks that follow a skipped one.\n",
        "    \"\"\"\n",
        "\n",
        "    def forward(\n",
        "        self,\n",
        "        hidden_states,\n",
        "        attention_mask=None,                 # in the decoder this positional slot carries the causal mask\n",
        "        position_bias=None,\n",
        "        encoder_hidden_states=None,\n",
        "        encoder_attention_mask=None,\n",
        "        encoder_decoder_position_bias=None,\n",
        "        layer_head_mask=None,\n",
        "        cross_attn_layer_head_mask=None,\n",
        "        past_key_value=None,\n",
        "        use_cache=False,\n",
        "        output_attentions=False,\n",
        "        return_dict=False,\n",
        "        cache_position=None,                 # accepted so newer HF call sites do not fail\n",
        "        **kwargs,\n",
        "    ):\n",
        "        # No computation: hidden states and both bias tensors pass through untouched;\n",
        "        # cache and attention-weight slots are left empty.\n",
        "        passthrough = (hidden_states, None, None, None)\n",
        "        return passthrough + (position_bias, encoder_decoder_position_bias)\n",
        "\n",
        "def prune_er_layers(blocks, er_scores, num_prune=4):\n",
        "    \"\"\"\n",
        "    Replace the blocks that follow the highest-ER layer pairs with SkipBlock.\n",
        "\n",
        "    er_scores maps pair index i (pair (i, i+1)) to its score. The num_prune\n",
        "    highest-scoring pairs are selected and block i+1 of each is swapped in\n",
        "    place, provided i+1 is a valid index into blocks.\n",
        "\n",
        "    Returns:\n",
        "        Sorted list of the replaced block indices.\n",
        "    \"\"\"\n",
        "    ranked_pairs = sorted(er_scores.items(), key=lambda kv: kv[1], reverse=True)\n",
        "    targets = set()\n",
        "    for pair_idx, _score in ranked_pairs[:num_prune]:\n",
        "        successor = pair_idx + 1  # pair (i, i+1) -> drop the second member\n",
        "        if successor < len(blocks):\n",
        "            targets.add(successor)\n",
        "    prune_idxs = sorted(targets)\n",
        "    for position in prune_idxs:\n",
        "        blocks[position] = SkipBlock()\n",
        "    return prune_idxs\n",
        "\n",
        "# --- 5. Training/Eval/ER Pipeline ---\n",
        "def canonicalize_label(s: str):\n",
        "    \"\"\"\n",
        "    Reduce a predicted or gold label string to a canonical NLI class name.\n",
        "\n",
        "    The text is stripped and lower-cased, only its first whitespace-separated\n",
        "    token is kept, and known variants are mapped to \"entailment\", \"neutral\" or\n",
        "    \"contradiction\". Unknown tokens are returned as-is; None or blank input gives \"\".\n",
        "    \"\"\"\n",
        "    CANON = {\n",
        "        \"entailment\": \"entailment\",\n",
        "        \"entailed\": \"entailment\",\n",
        "        \"neutral\": \"neutral\",\n",
        "        \"contradiction\": \"contradiction\",\n",
        "        \"contradict\": \"contradiction\",\n",
        "        \"contradictory\": \"contradiction\",\n",
        "        \"contradicted\": \"contradiction\",\n",
        "    }\n",
        "    cleaned = (s or \"\").strip().lower()\n",
        "    tokens = cleaned.split()\n",
        "    if tokens:\n",
        "        head = tokens[0]\n",
        "    else:\n",
        "        head = cleaned  # empty string\n",
        "    return CANON.get(head, head)\n",
        "\n",
        "def compute_accuracy(preds, refs):\n",
        "    \"\"\"\n",
        "    Fraction of predictions whose canonical label equals the reference's.\n",
        "\n",
        "    Pairs are formed with zip (extra items on either side are never compared);\n",
        "    the denominator is len(preds). Returns 0 when preds is empty.\n",
        "    \"\"\"\n",
        "    total = len(preds)\n",
        "    if total <= 0:\n",
        "        return 0\n",
        "    hits = sum(\n",
        "        1\n",
        "        for guess, gold in zip(preds, refs)\n",
        "        if canonicalize_label(guess) == canonicalize_label(gold)\n",
        "    )\n",
        "    return hits / total\n",
        "\n",
        "@torch.no_grad()\n",
        "def evaluate_model(model, dl, tokenizer, device):\n",
        "    \"\"\"\n",
        "    Generate a label string for every batch in dl and score it against the gold labels.\n",
        "\n",
        "    The model is put in eval mode (callers re-enable train mode themselves).\n",
        "    Gold label ids have -100 replaced by the pad id before decoding; both\n",
        "    predictions and references are stripped and lower-cased, then passed to\n",
        "    compute_accuracy, whose result is returned.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    predictions, references = [], []\n",
        "    for batch in dl:\n",
        "        generation_inputs = {\n",
        "            \"input_ids\": batch[\"input_ids\"].to(device),\n",
        "            \"attention_mask\": batch[\"attention_mask\"].to(device),\n",
        "        }\n",
        "\n",
        "        # Gold side: restore pad ids where the labels hold -100, then decode to text\n",
        "        gold_ids = batch[\"labels\"].masked_fill(batch[\"labels\"] == -100, tokenizer.pad_token_id)\n",
        "        gold_texts = tokenizer.batch_decode(gold_ids, skip_special_tokens=True)\n",
        "\n",
        "        # Prediction side: short free-form generation with caching disabled\n",
        "        # (kept off so decoding behaves the same once SkipBlock is in place)\n",
        "        generated = model.generate(\n",
        "            max_new_tokens=4,\n",
        "            use_cache=False,\n",
        "            **generation_inputs,\n",
        "        )\n",
        "        guess_texts = tokenizer.batch_decode(generated, skip_special_tokens=True)\n",
        "\n",
        "        for text in guess_texts:\n",
        "            predictions.append(text.strip().lower())\n",
        "        for text in gold_texts:\n",
        "            references.append(text.strip().lower())\n",
        "\n",
        "    return compute_accuracy(predictions, references)\n",
        "\n",
        "def build_optimizer_and_scheduler(model, train_steps):\n",
        "    \"\"\"\n",
        "    Build the optimizer used by both training stages.\n",
        "\n",
        "    Returns an Adafactor optimizer over all model parameters with\n",
        "    relative_step, scale_parameter and warmup_init enabled, paired with a\n",
        "    scheduler slot that is always None. train_steps is accepted for interface\n",
        "    stability but is not used.\n",
        "    \"\"\"\n",
        "    adafactor_options = dict(relative_step=True, scale_parameter=True, warmup_init=True)\n",
        "    optimizer = Adafactor(model.parameters(), **adafactor_options)\n",
        "    # No external scheduler: with relative_step the step size is handled inside Adafactor\n",
        "    scheduler = None\n",
        "    return optimizer, scheduler\n",
        "\n",
        "def full_finetuning(model, train_loader, dev_loader, device, tokenizer):\n",
        "    \"\"\"\n",
        "    Stage 1: fine-tune for 6 epochs while tracking conditional ER per layer pair.\n",
        "\n",
        "    Forward hooks record encoder-block, decoder-block and cross-attention\n",
        "    outputs. After each optimizer step the recorded activations are compared\n",
        "    with those of the previous step via compute_conditional_batch_entropy, and\n",
        "    the per-pair scores are averaged over the epoch. Dev accuracy is printed\n",
        "    after every epoch. Batches with a non-finite loss are skipped.\n",
        "\n",
        "    Returns:\n",
        "        (model, enc_er, dec_er, cross_er) -- the ER dicts are the per-pair means\n",
        "        of the final epoch.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Fine-Tuning & Conditional ER Estimation ===\")\n",
        "\n",
        "    opt, sched = build_optimizer_and_scheduler(model, len(train_loader) * 3)\n",
        "    hook_sets = register_conditional_er_hooks(model)\n",
        "    (_, enc_acts), (_, dec_acts), (_, cross_acts) = hook_sets\n",
        "    # One (display name, live activation dict) entry per tracked stream\n",
        "    streams = (\n",
        "        (\"Encoder\", enc_acts),\n",
        "        (\"Decoder\", dec_acts),\n",
        "        (\"Cross-Attention\", cross_acts),\n",
        "    )\n",
        "    last_er = {name: None for name, _ in streams}\n",
        "\n",
        "    # try/finally: the hooks must come off even when training is interrupted,\n",
        "    # otherwise re-running the cell stacks a second set on the same model.\n",
        "    try:\n",
        "        for epoch in range(6):\n",
        "            er_sums = {name: defaultdict(float) for name, _ in streams}\n",
        "            er_counts = {name: defaultdict(int) for name, _ in streams}\n",
        "            # Reset per epoch so the first step never compares against activations\n",
        "            # recorded during the previous epoch's evaluation pass\n",
        "            prev_acts = {name: None for name, _ in streams}\n",
        "\n",
        "            model.train()\n",
        "            for batch in train_loader:\n",
        "                opt.zero_grad()\n",
        "\n",
        "                with autocast(enabled=False):  # AMP kept off while debugging stability\n",
        "                    outputs = model(\n",
        "                        input_ids=batch['input_ids'].to(device),\n",
        "                        attention_mask=batch['attention_mask'].to(device),\n",
        "                        labels=batch['labels'].to(device)\n",
        "                    )\n",
        "                    loss = outputs.loss\n",
        "\n",
        "                if not torch.isfinite(loss):\n",
        "                    print(\"Loss is NaN/Inf — skipping this batch.\")\n",
        "                    continue\n",
        "\n",
        "                loss.backward()\n",
        "                # Grad clipping (helps avoid LN/attn blowups)\n",
        "                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
        "\n",
        "                opt.step()\n",
        "                if sched is not None:\n",
        "                    sched.step()\n",
        "\n",
        "                # --- ER accumulation + snapshot, identical for all three streams ---\n",
        "                for name, acts in streams:\n",
        "                    if prev_acts[name] is not None:\n",
        "                        batch_er = compute_conditional_batch_entropy(prev_acts[name], acts)\n",
        "                        for idx, v in batch_er.items():\n",
        "                            er_sums[name][idx] += v\n",
        "                            er_counts[name][idx] += 1\n",
        "                    prev_acts[name] = {\n",
        "                        i: (t.clone() if t is not None else None) for i, t in acts.items()\n",
        "                    }\n",
        "\n",
        "            # epoch-level ER (means)\n",
        "            epoch_er = {\n",
        "                name: {\n",
        "                    idx: er_sums[name][idx] / er_counts[name][idx]\n",
        "                    for idx in er_sums[name] if er_counts[name][idx] > 0\n",
        "                }\n",
        "                for name, _ in streams\n",
        "            }\n",
        "            for name, _ in streams:\n",
        "                print(f\"[Epoch {epoch+1}] approx {name} Conditional ER: {epoch_er[name]}\")\n",
        "\n",
        "            acc = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "            print(f\"[Epoch {epoch+1}] Dev Acc: {acc:.4f}\")\n",
        "\n",
        "            last_er = epoch_er\n",
        "    finally:\n",
        "        remove_hooks(hook_sets)\n",
        "\n",
        "    return model, last_er[\"Encoder\"], last_er[\"Decoder\"], last_er[\"Cross-Attention\"]\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, enc_er_scores, dec_er_scores, tokenizer):\n",
        "    \"\"\"\n",
        "    Stage 2: replace the highest-ER decoder blocks with SkipBlock, then fine-tune for 5 epochs.\n",
        "\n",
        "    Only the decoder is pruned (num_prune=4); enc_er_scores is accepted but not\n",
        "    used. A fresh optimizer is built after the structural change, batches with a\n",
        "    non-finite loss are skipped, and dev accuracy is printed after each epoch.\n",
        "\n",
        "    Returns:\n",
        "        The same model object, pruned in place and fine-tuned.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High-ER) & Fine-tuning ===\")\n",
        "    pruned_positions = prune_er_layers(model.decoder.block, dec_er_scores, num_prune=4)\n",
        "    print(\"Pruned decoder layers (highest ER -> next index):\", pruned_positions)\n",
        "\n",
        "    optimizer, scheduler = build_optimizer_and_scheduler(model, len(train_loader) * 2)\n",
        "\n",
        "    n_epochs = 5\n",
        "    for epoch_no in range(1, n_epochs + 1):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            optimizer.zero_grad()\n",
        "            step_inputs = {\n",
        "                field: batch[field].to(device)\n",
        "                for field in ('input_ids', 'attention_mask', 'labels')\n",
        "            }\n",
        "            loss = model(**step_inputs).loss\n",
        "            if torch.isfinite(loss):\n",
        "                loss.backward()\n",
        "                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
        "                optimizer.step()\n",
        "                if scheduler is not None:\n",
        "                    scheduler.step()\n",
        "            else:\n",
        "                print(\"Loss is NaN/Inf — skipping this batch.\")\n",
        "\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "        print(f\"[Prune FT Epoch {epoch_no}] ANLI1 Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 6. Entrypoint ---\n",
        "def main():\n",
        "    \"\"\"Run stage 1 and stage 2 on the notebook-level model, rebinding the global after each stage.\"\"\"\n",
        "    global model  # the model loaded earlier in the notebook is reused and rebound\n",
        "\n",
        "    # Stage 1: full fine-tuning; the cross-attention scores are returned but unused here\n",
        "    model, enc_er_scores, dec_er_scores, cross_er_scores = full_finetuning(\n",
        "        model=model,\n",
        "        train_loader=train_loader,\n",
        "        dev_loader=dev_loader,\n",
        "        device=device,\n",
        "        tokenizer=tokenizer,\n",
        "    )\n",
        "\n",
        "    # Stage 2: decoder-only pruning followed by continued fine-tuning\n",
        "    model = prune_and_finetuning(\n",
        "        model=model,\n",
        "        train_loader=train_loader,\n",
        "        dev_loader=dev_loader,\n",
        "        device=device,\n",
        "        enc_er_scores=enc_er_scores,\n",
        "        dec_er_scores=dec_er_scores,\n",
        "        tokenizer=tokenizer,\n",
        "    )\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "NQJVMcgwpLGG"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Only prune decoder\n",
        "# ===========================\n",
        "# 0. Google Drive Mount\n",
        "# ===========================\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "# ===========================\n",
        "# 1. Imports and Setup\n",
        "# ===========================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict\n",
        "import warnings\n",
        "import math\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# ===========================\n",
        "# 2. Load SVAMP Dataset\n",
        "# ===========================\n",
        "# Split name -> JSON file on the mounted Drive. The \"test\" split is what the\n",
        "# rest of this cell uses as the dev/evaluation set.\n",
        "# NOTE(review): absolute Drive paths -- this cell only runs with that Drive layout mounted.\n",
        "data_files = {\n",
        "    \"train\": \"/content/drive/MyDrive/NLP_datasets/svamp/svamp_train.json\",\n",
        "    \"test\": \"/content/drive/MyDrive/NLP_datasets/svamp/svamp_test.json\"\n",
        "}\n",
        "# Load both splits with the generic \"json\" dataset builder\n",
        "dataset = load_dataset(\"json\", data_files=data_files)\n",
        "\n",
        "# ===========================\n",
        "# 3. Preprocessing\n",
        "# ===========================\n",
        "def preprocess_svamp(batch, tokenizer, max_input_length=128, max_target_length=8):\n",
        "    \"\"\"\n",
        "    Tokenize a batch of SVAMP examples for seq2seq training.\n",
        "\n",
        "    Args:\n",
        "        batch: mapping with an \"input\" list (question text) and a \"label\" list\n",
        "            (answers; each is converted with str()).\n",
        "        tokenizer: callable tokenizer that also exposes pad_token_id.\n",
        "        max_input_length: fixed padded/truncated length of the inputs.\n",
        "        max_target_length: fixed padded/truncated length of the targets.\n",
        "\n",
        "    Returns:\n",
        "        The tokenizer output for the inputs with a \"labels\" entry added. Padding\n",
        "        positions in the labels are set to -100 so the loss ignores them.\n",
        "    \"\"\"\n",
        "    model_inputs = tokenizer(\n",
        "        batch[\"input\"], padding=\"max_length\", truncation=True, max_length=max_input_length\n",
        "    )\n",
        "    targets = [str(x) for x in batch[\"label\"]]\n",
        "    target_encodings = tokenizer(\n",
        "        targets, padding=\"max_length\", truncation=True, max_length=max_target_length\n",
        "    )\n",
        "    # Targets are already padded to a fixed length here, so the collator has no\n",
        "    # padding left to fill with -100; mask the pad ids now or they are trained on.\n",
        "    pad_id = tokenizer.pad_token_id\n",
        "    model_inputs[\"labels\"] = [\n",
        "        [(tok if tok != pad_id else -100) for tok in seq]\n",
        "        for seq in target_encodings[\"input_ids\"]\n",
        "    ]\n",
        "    return model_inputs\n",
        "\n",
        "# Tokenizer for t5-base; the same instance is used for preprocessing and decoding\n",
        "tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "# Tokenize both splits in batched mode, dropping the original columns so only\n",
        "# the fields produced by preprocess_svamp remain\n",
        "train = dataset[\"train\"].map(lambda ex: preprocess_svamp(ex, tokenizer), batched=True, remove_columns=dataset[\"train\"].column_names)\n",
        "dev = dataset[\"test\"].map(lambda ex: preprocess_svamp(ex, tokenizer), batched=True, remove_columns=dataset[\"test\"].column_names)\n",
        "# Collator configured for fixed-length padding (max_length=128); no model is passed\n",
        "collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding=\"max_length\", max_length=128)\n",
        "# Training batches are shuffled, dev batches keep dataset order; both use batch_size=8\n",
        "train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "dev_loader = DataLoader(dev, batch_size=8, shuffle=False, collate_fn=collator)\n",
        "\n",
        "# ===========================\n",
        "# 4. Conditional ER Utilities\n",
        "# ===========================\n",
        "def register_conditional_er_hooks(model):\n",
        "    \"\"\"\n",
        "    Attach activation-recording forward hooks to an encoder/decoder model.\n",
        "\n",
        "    Watches every module in model.encoder.block, every module in\n",
        "    model.decoder.block, and layer[1] of every decoder block. Each hook stores\n",
        "    the detached output (first element when the output is a tuple) under the\n",
        "    block's index.\n",
        "\n",
        "    Returns:\n",
        "        ((enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts));\n",
        "        hooks are the registered handles, acts map index -> latest tensor (None\n",
        "        until the first forward pass).\n",
        "    \"\"\"\n",
        "    def _recorder(store, slot):\n",
        "        # store/slot are bound per call so every hook writes to its own entry\n",
        "        def _hook(module, inp, out):\n",
        "            first = out[0] if isinstance(out, tuple) else out\n",
        "            store[slot] = first.detach()\n",
        "        return _hook\n",
        "\n",
        "    def _watch(targets):\n",
        "        store = {slot: None for slot in range(len(targets))}\n",
        "        handles = [\n",
        "            target.register_forward_hook(_recorder(store, slot))\n",
        "            for slot, target in enumerate(targets)\n",
        "        ]\n",
        "        return handles, store\n",
        "\n",
        "    decoder_blocks = model.decoder.block\n",
        "    encoder_pair = _watch(model.encoder.block)\n",
        "    decoder_pair = _watch(decoder_blocks)\n",
        "    cross_pair = _watch([blk.layer[1] for blk in decoder_blocks])\n",
        "    return encoder_pair, decoder_pair, cross_pair\n",
        "\n",
        "def remove_hooks(hook_sets):\n",
        "    \"\"\"Detach every hook handle found in hook_sets, an iterable of (handles, acts) pairs.\"\"\"\n",
        "    every_handle = [handle for handles, _recorded in hook_sets for handle in handles]\n",
        "    for handle in every_handle:\n",
        "        handle.remove()\n",
        "\n",
        "def compute_conditional_batch_entropy(prev_acts, curr_acts):\n",
        "    \"\"\"\n",
        "    Mean squared cosine similarity between step-wise deltas of adjacent layers.\n",
        "\n",
        "    For each adjacent pair (i, i+1), dX = curr[i] - prev[i] and\n",
        "    dY = curr[i+1] - prev[i+1] are flattened per sample; cos(dY, dX) is taken\n",
        "    for every sample in the batch (including sample 0), squared, and averaged.\n",
        "\n",
        "    Args:\n",
        "        prev_acts: dict index -> activation tensor (or None) from the previous step.\n",
        "        curr_acts: dict index -> activation tensor (or None) from the current step.\n",
        "\n",
        "    Returns:\n",
        "        dict mapping i -> mean cos^2 for pair (i, i+1). A pair is omitted when an\n",
        "        activation is missing, shapes changed between the two steps, the batch is\n",
        "        empty, or the deltas contain NaN.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for i in range(len(curr_acts) - 1):\n",
        "        prev_X, prev_Y = prev_acts[i], prev_acts[i+1]\n",
        "        curr_X, curr_Y = curr_acts[i], curr_acts[i+1]\n",
        "        if prev_X is None or prev_Y is None or curr_X is None or curr_Y is None:\n",
        "            continue\n",
        "        if prev_X.shape != curr_X.shape or prev_Y.shape != curr_Y.shape:\n",
        "            continue\n",
        "        B = curr_X.size(0)\n",
        "        if B < 1:\n",
        "            continue\n",
        "        # reshape (not view) tolerates non-contiguous activations; float() matters\n",
        "        # because the hooks may capture half-precision tensors under autocast\n",
        "        dX = (curr_X - prev_X).reshape(B, -1).float()\n",
        "        dY = (curr_Y - prev_Y).reshape(B, -1).float()\n",
        "        if torch.isnan(dX).any() or torch.isnan(dY).any():\n",
        "            continue\n",
        "        # Squaring removes the sign, so opposite-direction deltas cannot cancel\n",
        "        cos_squares = F.cosine_similarity(dY, dX, dim=1, eps=1e-8) ** 2\n",
        "        cos_squares = cos_squares[torch.isfinite(cos_squares)]\n",
        "        if cos_squares.numel() > 0:\n",
        "            er = cos_squares.mean().item()\n",
        "            if math.isfinite(er):\n",
        "                er_scores[i] = er\n",
        "    return er_scores\n",
        "\n",
        "# ===========================\n",
        "# 5. Pruning Utilities\n",
        "# ===========================\n",
        "class SkipBlock(nn.Module):\n",
        "    \"\"\"\n",
        "    Identity replacement for a pruned block.\n",
        "\n",
        "    forward() ignores everything except hidden_states and returns a six-slot\n",
        "    tuple whose first element is hidden_states and whose remaining five are None.\n",
        "\n",
        "    NOTE(review): the position-bias arguments are not passed on, so blocks after\n",
        "    a skipped one presumably receive None for them -- confirm against the\n",
        "    installed transformers T5Stack that this is intended.\n",
        "    \"\"\"\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        # Kept as an attribute only; forward() never reads it\n",
        "        self.hidden_size = hidden_size\n",
        "\n",
        "    def forward(self, hidden_states, *args, **kwargs):\n",
        "        empty_slots = (None,) * 5\n",
        "        return (hidden_states,) + empty_slots\n",
        "\n",
        "def prune_er_layers(blocks, er_scores, num_prune=2, hidden_size=768):\n",
        "    \"\"\"\n",
        "    Replace the blocks that follow the highest-ER layer pairs with SkipBlock.\n",
        "\n",
        "    er_scores maps pair index i (pair (i, i+1)) to its score. For the num_prune\n",
        "    highest-scoring pairs, block i+1 is swapped in place for\n",
        "    SkipBlock(hidden_size) whenever i+1 is a valid index into blocks.\n",
        "\n",
        "    Returns:\n",
        "        The replaced block indices, ordered by descending score.\n",
        "    \"\"\"\n",
        "    ranked_pairs = sorted(er_scores, key=er_scores.get, reverse=True)\n",
        "    prune_idxs = []\n",
        "    for pair_idx in ranked_pairs[:num_prune]:\n",
        "        successor = pair_idx + 1  # pair (i, i+1) -> drop the second member\n",
        "        if successor < len(blocks):\n",
        "            prune_idxs.append(successor)\n",
        "    for position in prune_idxs:\n",
        "        blocks[position] = SkipBlock(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "# ===========================\n",
        "# 6. Eval Helper\n",
        "# ===========================\n",
        "def compute_accuracy(preds, refs):\n",
        "    \"\"\"\n",
        "    Exact-match accuracy of preds against refs.\n",
        "\n",
        "    Pairs are formed with zip (extra items on either side are never compared);\n",
        "    the denominator is len(preds). Returns 0 when preds is empty.\n",
        "    \"\"\"\n",
        "    total = len(preds)\n",
        "    if total <= 0:\n",
        "        return 0\n",
        "    matches = sum(1 for guess, gold in zip(preds, refs) if guess == gold)\n",
        "    return matches / total\n",
        "\n",
        "def evaluate_model(model, dl, tokenizer, device):\n",
        "    \"\"\"\n",
        "    Generate an answer for every batch in dl and return exact-match accuracy.\n",
        "\n",
        "    The model is put in eval mode (callers re-enable train mode themselves) and\n",
        "    generation runs without gradients. Label ids equal to -100 are replaced by\n",
        "    the pad id before decoding; predictions and references are stripped and\n",
        "    lower-cased before being compared by compute_accuracy.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    predictions, references = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            generated = model.generate(\n",
        "                input_ids=batch[\"input_ids\"].to(device),\n",
        "                attention_mask=batch[\"attention_mask\"].to(device),\n",
        "                max_new_tokens=8,\n",
        "            )\n",
        "            guess_texts = tokenizer.batch_decode(generated, skip_special_tokens=True)\n",
        "\n",
        "            # Work on a copy so the batch's own label tensor is left unchanged\n",
        "            gold_ids = batch[\"labels\"].clone()\n",
        "            ignored = gold_ids == -100\n",
        "            gold_ids[ignored] = tokenizer.pad_token_id\n",
        "            gold_texts = tokenizer.batch_decode(gold_ids, skip_special_tokens=True)\n",
        "\n",
        "            for text in guess_texts:\n",
        "                predictions.append(text.strip().lower())\n",
        "            for text in gold_texts:\n",
        "                references.append(text.strip().lower())\n",
        "    return compute_accuracy(predictions, references)\n",
        "\n",
        "# ===========================\n",
        "# 7. Training + ER Tracking + Pruning\n",
        "# ===========================\n",
        "def full_finetuning(train_loader, dev_loader, device, tokenizer):\n",
        "    \"\"\"\n",
        "    Stage 1: fine-tune a freshly loaded t5-base for 6 epochs while tracking conditional ER.\n",
        "\n",
        "    Forward hooks record encoder-block, decoder-block and cross-attention\n",
        "    outputs. After each optimizer step the recorded activations are compared\n",
        "    with those of the previous step via compute_conditional_batch_entropy, and\n",
        "    the per-pair scores are averaged over the epoch. Training uses AdamW with\n",
        "    mixed precision (autocast + GradScaler) and a linear LR decay; dev accuracy\n",
        "    is printed after every epoch.\n",
        "\n",
        "    Returns:\n",
        "        (model, enc_er, dec_er, cross_er) -- the ER dicts are the per-pair means\n",
        "        of the final epoch.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Fine-Tuning & Conditional ER Estimation ===\")\n",
        "    num_epochs = 6\n",
        "    model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
        "    scaler = GradScaler()\n",
        "    # The linear decay has to span every step of the run: with a shorter horizon\n",
        "    # the learning rate sits at 0 for all steps past it.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    hook_sets = register_conditional_er_hooks(model)\n",
        "    (_, enc_acts), (_, dec_acts), (_, cross_acts) = hook_sets\n",
        "    # One (display name, live activation dict) entry per tracked stream\n",
        "    streams = (\n",
        "        (\"Encoder\", enc_acts),\n",
        "        (\"Decoder\", dec_acts),\n",
        "        (\"Cross-Attention\", cross_acts),\n",
        "    )\n",
        "    last_er = {name: None for name, _ in streams}\n",
        "\n",
        "    # try/finally: the hooks must come off even when training is interrupted\n",
        "    try:\n",
        "        for epoch in range(num_epochs):\n",
        "            er_sums = {name: defaultdict(float) for name, _ in streams}\n",
        "            er_counts = {name: defaultdict(int) for name, _ in streams}\n",
        "            # Reset per epoch so the first step never compares against activations\n",
        "            # recorded during the previous epoch's evaluation pass\n",
        "            prev_acts = {name: None for name, _ in streams}\n",
        "            model.train()\n",
        "            for batch in train_loader:\n",
        "                opt.zero_grad()\n",
        "                with autocast():\n",
        "                    outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                                    attention_mask=batch['attention_mask'].to(device),\n",
        "                                    labels=batch['labels'].to(device))\n",
        "                    loss = outputs.loss\n",
        "                # backward runs outside the autocast region, as the PyTorch AMP guide recommends\n",
        "                scaler.scale(loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "\n",
        "                # ER accumulation + snapshot, identical for all three streams\n",
        "                for name, acts in streams:\n",
        "                    if prev_acts[name] is not None:\n",
        "                        batch_er = compute_conditional_batch_entropy(prev_acts[name], acts)\n",
        "                        for idx, v in batch_er.items():\n",
        "                            er_sums[name][idx] += v\n",
        "                            er_counts[name][idx] += 1\n",
        "                    prev_acts[name] = {\n",
        "                        i: (t.clone() if t is not None else None) for i, t in acts.items()\n",
        "                    }\n",
        "\n",
        "            epoch_er = {\n",
        "                name: {\n",
        "                    idx: er_sums[name][idx] / er_counts[name][idx]\n",
        "                    for idx in er_sums[name] if er_counts[name][idx] > 0\n",
        "                }\n",
        "                for name, _ in streams\n",
        "            }\n",
        "            for name, _ in streams:\n",
        "                print(f\"[Epoch {epoch+1}] approx {name} Conditional ER: {epoch_er[name]}\")\n",
        "            acc = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "            print(f\"[Epoch {epoch+1}] Dev Acc: {acc:.4f}\")\n",
        "            last_er = epoch_er\n",
        "    finally:\n",
        "        remove_hooks(hook_sets)\n",
        "\n",
        "    return model, last_er[\"Encoder\"], last_er[\"Decoder\"], last_er[\"Cross-Attention\"]\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, enc_er_scores, dec_er_scores, tokenizer):\n",
        "    \"\"\"\n",
        "    Stage 2: replace the highest-ER decoder blocks with SkipBlock, then fine-tune for 5 epochs.\n",
        "\n",
        "    Only the decoder is pruned (num_prune=4); enc_er_scores is accepted but not\n",
        "    used. Training uses a fresh AdamW optimizer with a linear LR decay spread\n",
        "    over all epochs, and dev accuracy is printed after each epoch.\n",
        "\n",
        "    Returns:\n",
        "        The same model object, pruned in place and fine-tuned.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High-ER) & Fine-tuning ===\")\n",
        "    num_epochs = 5\n",
        "    dec_prune_idxs = prune_er_layers(model.decoder.block, dec_er_scores, num_prune=4, hidden_size=model.config.d_model)\n",
        "    print(\"Pruned decoder layers (highest ER):\", dec_prune_idxs)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=5e-4)\n",
        "    # The linear decay has to span every step of the run: with a shorter horizon\n",
        "    # the learning rate sits at 0 for all steps past it.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                            attention_mask=batch['attention_mask'].to(device),\n",
        "                            labels=batch['labels'].to(device))\n",
        "            loss = outputs.loss\n",
        "            loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] SVAMP Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# ===========================\n",
        "# 8. Main Entrypoint\n",
        "# ===========================\n",
        "def main():\n",
        "    \"\"\"Pick the device, run stage 1 (full fine-tuning) and then stage 2 (decoder pruning + fine-tuning).\"\"\"\n",
        "    device_name = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "    device = torch.device(device_name)\n",
        "\n",
        "    # Stage 1 returns the tuned model and three ER dicts; the cross-attention one is unused here\n",
        "    model, enc_er_scores, dec_er_scores, cross_er_scores = full_finetuning(\n",
        "        train_loader=train_loader,\n",
        "        dev_loader=dev_loader,\n",
        "        device=device,\n",
        "        tokenizer=tokenizer,\n",
        "    )\n",
        "\n",
        "    model = prune_and_finetuning(\n",
        "        model=model,\n",
        "        train_loader=train_loader,\n",
        "        dev_loader=dev_loader,\n",
        "        device=device,\n",
        "        enc_er_scores=enc_er_scores,\n",
        "        dec_er_scores=dec_er_scores,\n",
        "        tokenizer=tokenizer,\n",
        "    )\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "BQ322_FLvamj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "gX0GHph222Ka"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "qGXGpf-0nyNa"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "U3wDNkD86y-D"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}