{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yNrNC-Ls62Ua"
      },
      "outputs": [],
      "source": [
        "# Pin `datasets` to the version this notebook was written against.\n",
        "# `%pip` (unlike `!pip`) installs into the environment of the running kernel.\n",
        "# NOTE: if `datasets` was already imported in this session, restart the runtime after this cell.\n",
        "%pip uninstall -y datasets\n",
        "%pip install -q datasets==2.18.0\n",
        "# TODO(review): pin `evaluate` as well; none of the code cells shown here import it.\n",
        "%pip install -q evaluate\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration,\n",
        "    T5TokenizerFast,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import warnings\n",
        "import math\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# --- 1. Conditional ER Hook Utilities ---\n",
        "def register_conditional_er_hooks(model):\n",
        "    enc_layers = model.encoder.block\n",
        "    enc_acts = {i: None for i in range(len(enc_layers))}\n",
        "    enc_hooks = []\n",
        "    for i, layer in enumerate(enc_layers):\n",
        "        def hook_fn_enc(module, inp, out, idx=i):\n",
        "            hs = out[0] if isinstance(out, tuple) else out\n",
        "            enc_acts[idx] = hs.detach()\n",
        "        enc_hooks.append(layer.register_forward_hook(hook_fn_enc))\n",
        "    dec_layers = model.decoder.block\n",
        "    dec_acts = {i: None for i in range(len(dec_layers))}\n",
        "    dec_hooks = []\n",
        "    for i, layer in enumerate(dec_layers):\n",
        "        def hook_fn_dec(module, inp, out, idx=i):\n",
        "            hs = out[0] if isinstance(out, tuple) else out\n",
        "            dec_acts[idx] = hs.detach()\n",
        "        dec_hooks.append(layer.register_forward_hook(hook_fn_dec))\n",
        "    cross_acts = {i: None for i in range(len(dec_layers))}\n",
        "    cross_hooks = []\n",
        "    for i, block in enumerate(dec_layers):\n",
        "        def hook_fn_cross(module, inp, out, idx=i):\n",
        "            hs = out[0] if isinstance(out, tuple) else out\n",
        "            cross_acts[idx] = hs.detach()\n",
        "        cross_attn = block.layer[1]\n",
        "        cross_hooks.append(cross_attn.register_forward_hook(hook_fn_cross))\n",
        "    return (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)\n",
        "\n",
        "def remove_hooks(hook_sets):\n",
        "    for hooks, _ in hook_sets:\n",
        "        for h in hooks:\n",
        "            h.remove()\n",
        "\n",
        "def compute_conditional_batch_entropy(prev_acts, curr_acts):\n",
        "    er_scores = {}\n",
        "    for i in range(len(curr_acts) - 1):\n",
        "        prev_X, prev_Y = prev_acts[i], prev_acts[i+1]\n",
        "        curr_X, curr_Y = curr_acts[i], curr_acts[i+1]\n",
        "        if (\n",
        "            prev_X is not None and prev_Y is not None and\n",
        "            curr_X is not None and curr_Y is not None and\n",
        "            prev_X.shape == curr_X.shape and\n",
        "            prev_Y.shape == curr_Y.shape\n",
        "        ):\n",
        "            B = curr_X.size(0)\n",
        "            dX = (curr_X - prev_X).view(B, -1)\n",
        "            dY = (curr_Y - prev_Y).view(B, -1)\n",
        "            if B >= 2 and not (torch.isnan(dX).any() or torch.isnan(dY).any()):\n",
        "                cos_squares = [\n",
        "                    F.cosine_similarity(dY[j].unsqueeze(0), dX[j].unsqueeze(0), dim=1, eps=1e-8).item()**2\n",
        "                    for j in range(1, B)\n",
        "                    if not (torch.isnan(dX[j]).any() or torch.isnan(dY[j]).any())\n",
        "                ]\n",
        "                if cos_squares:\n",
        "                    er = sum(cos_squares) / len(cos_squares)\n",
        "                    if not (math.isnan(er) or math.isinf(er)):\n",
        "                        er_scores[i] = er\n",
        "    return er_scores\n",
        "\n",
        "class SkipBlock(nn.Module):\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "    def forward(self, hidden_states, *args, **kwargs):\n",
        "        return (hidden_states, None, None, None, None, None)\n",
        "\n",
        "def prune_er_layers(blocks, er_scores, num_prune=2, hidden_size=768):\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx+1 for idx, _ in sorted_layers[:num_prune] if idx+1 < len(blocks)]\n",
        "    for idx in prune_idxs:\n",
        "        blocks[idx] = SkipBlock(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "# --- 2. Data Processing ---\n",
        "def make_t5_nli_prompt(premise, hypothesis):\n",
        "    return f\"mnli premise: {premise} hypothesis: {hypothesis}\"\n",
        "\n",
        "def preprocess_function(batch, tokenizer, max_input_length=128, max_target_length=8):\n",
        "    inputs = [make_t5_nli_prompt(p, h) for p, h in zip(batch[\"premise\"], batch[\"hypothesis\"])]\n",
        "    label_list = [\"entailment\", \"neutral\", \"contradiction\"]\n",
        "    labels = [label_list[x] if (isinstance(x, int) and x in {0,1,2}) else \"neutral\"\n",
        "              for x in batch[\"label\"]]\n",
        "    model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=max_input_length)\n",
        "    with tokenizer.as_target_tokenizer():\n",
        "        targets = tokenizer(labels, padding=\"max_length\", truncation=True, max_length=max_target_length)\n",
        "    model_inputs[\"labels\"] = targets[\"input_ids\"]\n",
        "    return model_inputs\n",
        "\n",
        "def map_pred_text_to_idx(pred, label_texts):\n",
        "    p = pred.strip().lower()\n",
        "    for i, lab in enumerate(label_texts):\n",
        "        if lab in p:\n",
        "            return i\n",
        "    return 1  # neutral fallback\n",
        "\n",
        "def compute_accuracy(preds, refs):\n",
        "    return sum(p==r for p,r in zip(preds,refs)) / len(preds) if preds else 0.0\n",
        "\n",
        "def evaluate_model(model, dl, tokenizer, label_texts, device):\n",
        "    model.eval()\n",
        "    preds, refs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            ids = batch[\"input_ids\"].to(device)\n",
        "            mask = batch[\"attention_mask\"].to(device)\n",
        "            outs = model.generate(input_ids=ids, attention_mask=mask, max_new_tokens=2)\n",
        "            pred_texts = tokenizer.batch_decode(outs, skip_special_tokens=True)\n",
        "            lab_ids = batch[\"labels\"].clone()\n",
        "            lab_ids[lab_ids==-100] = tokenizer.pad_token_id\n",
        "            ref_texts = tokenizer.batch_decode(lab_ids, skip_special_tokens=True)\n",
        "            preds += [map_pred_text_to_idx(p, label_texts) for p in pred_texts]\n",
        "            refs  += [map_pred_text_to_idx(r, label_texts) for r in ref_texts]\n",
        "    print(\"Sample preds:\", pred_texts[:5])\n",
        "    print(\"Sample refs:\",  ref_texts[:5])\n",
        "    return compute_accuracy(preds, refs)\n",
        "\n",
        "# --- 3. Custom collate_fn to stack already-padded lists into tensors ---\n",
        "def collate_fn(batch):\n",
        "    return {\n",
        "        \"input_ids\":      torch.tensor([ex[\"input_ids\"]      for ex in batch], dtype=torch.long),\n",
        "        \"attention_mask\": torch.tensor([ex[\"attention_mask\"] for ex in batch], dtype=torch.long),\n",
        "        \"labels\":         torch.tensor([ex[\"labels\"]         for ex in batch], dtype=torch.long),\n",
        "    }\n",
        "\n",
        "# --- 4. Training Loops ---\n",
        "def full_finetuning(train_loader, dev_loader, device, tokenizer, label_texts, num_epochs=6):\n",
        "    \"\"\"Stage 1: fine-tune t5-base while estimating per-layer-pair ER scores.\n",
        "\n",
        "    Forward hooks cache encoder-block, decoder-block and decoder ``layer[1]``\n",
        "    outputs. After every optimiser step each group is compared with the previous\n",
        "    step's snapshot via ``compute_conditional_batch_entropy`` and the scores are\n",
        "    averaged per epoch. Dev accuracy is printed after each epoch.\n",
        "\n",
        "    Args:\n",
        "        num_epochs: number of training epochs (default 6).\n",
        "\n",
        "    Returns:\n",
        "        ``(model, enc_er, dec_er, crs_er)``: the fine-tuned model and the ER\n",
        "        averages of the final epoch (empty dicts if no epoch ran).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Fine-Tuning & ER Estimation ===\")\n",
        "    model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "    opt   = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
        "    scaler= GradScaler()\n",
        "    # Decay linearly across every step of every epoch, so the learning rate\n",
        "    # reaches 0 only when training ends.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts) = register_conditional_er_hooks(model)\n",
        "    enc_er, dec_er, crs_er = {}, {}, {}\n",
        "\n",
        "    def snapshot(acts):\n",
        "        # Clone: the hooks overwrite these entries on the next forward pass.\n",
        "        return {i: (t.clone() if t is not None else None) for i, t in acts.items()}\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        enc_sums, enc_cnts = defaultdict(float), defaultdict(int)\n",
        "        dec_sums, dec_cnts = defaultdict(float), defaultdict(int)\n",
        "        crs_sums, crs_cnts = defaultdict(float), defaultdict(int)\n",
        "        trackers = (\n",
        "            (enc_acts, enc_sums, enc_cnts),\n",
        "            (dec_acts, dec_sums, dec_cnts),\n",
        "            (cross_acts, crs_sums, crs_cnts),\n",
        "        )\n",
        "        model.train()\n",
        "        prev = None  # snapshots of the three activation dicts from the previous step\n",
        "\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(\n",
        "                    input_ids=batch[\"input_ids\"].to(device),\n",
        "                    attention_mask=batch[\"attention_mask\"].to(device),\n",
        "                    labels=batch[\"labels\"].to(device),\n",
        "                )\n",
        "            # Backward runs outside the autocast region, as the PyTorch AMP docs recommend.\n",
        "            scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            if prev is not None:\n",
        "                for prev_acts, (acts, sums, cnts) in zip(prev, trackers):\n",
        "                    for idx, v in compute_conditional_batch_entropy(prev_acts, acts).items():\n",
        "                        sums[idx] += v\n",
        "                        cnts[idx] += 1\n",
        "            prev = [snapshot(acts) for acts, _, _ in trackers]\n",
        "\n",
        "        enc_er = {i: enc_sums[i]/enc_cnts[i] for i in enc_sums}\n",
        "        dec_er = {i: dec_sums[i]/dec_cnts[i] for i in dec_sums}\n",
        "        crs_er = {i: crs_sums[i]/crs_cnts[i] for i in crs_sums}\n",
        "        print(f\"[Epoch {epoch+1}] ER enc: {enc_er}\")\n",
        "        print(f\"[Epoch {epoch+1}] ER dec: {dec_er}\")\n",
        "        print(f\"[Epoch {epoch+1}] ER crs: {crs_er}\")\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, label_texts, device)\n",
        "        print(f\"[Epoch {epoch+1}] Dev Acc: {acc:.4f}\")\n",
        "\n",
        "    remove_hooks([(enc_hooks,enc_acts),(dec_hooks,dec_acts),(cross_hooks,cross_acts)])\n",
        "    return model, enc_er, dec_er, crs_er\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, enc_er, dec_er, crs_er, tokenizer, label_texts):\n",
        "    \"\"\"Stage 2: prune four decoder blocks by ER score, then fine-tune for 2 epochs.\n",
        "\n",
        "    Only ``dec_er`` drives pruning; ``enc_er`` and ``crs_er`` are accepted but\n",
        "    unused. Training runs in full precision with AdamW (lr 5e-4) and a linear\n",
        "    decay to 0 across both epochs; dev accuracy is printed after each epoch.\n",
        "\n",
        "    Returns:\n",
        "        The same model object, pruned and fine-tuned.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune & Fine-Tuning ===\")\n",
        "    decoder_blocks = model.decoder.block\n",
        "    pruned = prune_er_layers(decoder_blocks, dec_er, num_prune=4, hidden_size=model.config.d_model)\n",
        "    print(\"Pruned decoder layers:\", pruned)\n",
        "\n",
        "    n_epochs = 2\n",
        "    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)\n",
        "    scheduler = get_linear_schedule_with_warmup(optimizer, 0, len(train_loader) * n_epochs)\n",
        "\n",
        "    for epoch_no in range(1, n_epochs + 1):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            optimizer.zero_grad()\n",
        "            loss = model(\n",
        "                input_ids=batch[\"input_ids\"].to(device),\n",
        "                attention_mask=batch[\"attention_mask\"].to(device),\n",
        "                labels=batch[\"labels\"].to(device),\n",
        "            ).loss\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "            scheduler.step()\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, label_texts, device)\n",
        "        print(f\"[PruneFT Epoch {epoch_no}] Dev Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "def main():\n",
        "    \"\"\"Run the MNLI experiment end to end.\n",
        "\n",
        "    Seeds the Python/NumPy/torch RNGs, tokenises 5,000 shuffled training rows and\n",
        "    1,000 matched-validation rows, runs stage-1 fine-tuning with ER estimation,\n",
        "    then stage-2 decoder pruning and fine-tuning.\n",
        "    \"\"\"\n",
        "    seed = 42\n",
        "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
        "        seed_fn(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    label_texts = [\"entailment\",\"neutral\",\"contradiction\"]\n",
        "\n",
        "    ds_train = load_dataset(\"glue\",\"mnli\",split=\"train[:5000]\").shuffle(seed)\n",
        "    ds_dev   = load_dataset(\"glue\",\"mnli\",split=\"validation_matched[:1000]\")\n",
        "    tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "\n",
        "    def encode(split):\n",
        "        return split.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True, remove_columns=[\"premise\",\"hypothesis\",\"idx\"])\n",
        "\n",
        "    train = encode(ds_train)\n",
        "    dev   = encode(ds_dev)\n",
        "\n",
        "    train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collate_fn)\n",
        "    dev_loader   = DataLoader(dev,   batch_size=16, shuffle=False, collate_fn=collate_fn)\n",
        "\n",
        "    model, enc_er, dec_er, crs_er = full_finetuning(\n",
        "        train_loader, dev_loader, device, tokenizer, label_texts\n",
        "    )\n",
        "    prune_and_finetuning(\n",
        "        model, train_loader, dev_loader, device,\n",
        "        enc_er, dec_er, crs_er, tokenizer, label_texts\n",
        "    )\n",
        "\n",
        "# In a notebook kernel __name__ is \"__main__\", so running this cell runs the full pipeline.\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "rYqWNzKCBKLM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict\n",
        "import warnings\n",
        "import math\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# --- 1. Conditional ER Hook Utilities ---\n",
        "def register_conditional_er_hooks(model):\n",
        "    enc_layers = model.encoder.block\n",
        "    enc_acts = {i: None for i in range(len(enc_layers))}\n",
        "    enc_hooks = []\n",
        "    for i, layer in enumerate(enc_layers):\n",
        "        def hook_fn_enc(module, inp, out, idx=i):\n",
        "            hs = out[0] if isinstance(out, tuple) else out\n",
        "            enc_acts[idx] = hs.detach()\n",
        "        enc_hooks.append(layer.register_forward_hook(hook_fn_enc))\n",
        "    dec_layers = model.decoder.block\n",
        "    dec_acts = {i: None for i in range(len(dec_layers))}\n",
        "    dec_hooks = []\n",
        "    for i, layer in enumerate(dec_layers):\n",
        "        def hook_fn_dec(module, inp, out, idx=i):\n",
        "            hs = out[0] if isinstance(out, tuple) else out\n",
        "            dec_acts[idx] = hs.detach()\n",
        "        dec_hooks.append(layer.register_forward_hook(hook_fn_dec))\n",
        "    cross_acts = {i: None for i in range(len(dec_layers))}\n",
        "    cross_hooks = []\n",
        "    for i, block in enumerate(dec_layers):\n",
        "        def hook_fn_cross(module, inp, out, idx=i):\n",
        "            hs = out[0] if isinstance(out, tuple) else out\n",
        "            cross_acts[idx] = hs.detach()\n",
        "        cross_attn = block.layer[1]\n",
        "        cross_hooks.append(cross_attn.register_forward_hook(hook_fn_cross))\n",
        "    return (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)\n",
        "\n",
        "def remove_hooks(hook_sets):\n",
        "    for hooks, _ in hook_sets:\n",
        "        for h in hooks:\n",
        "            h.remove()\n",
        "\n",
        "def compute_conditional_batch_entropy(prev_acts, curr_acts):\n",
        "    er_scores = {}\n",
        "    for i in range(len(curr_acts) - 1):\n",
        "        prev_X, prev_Y = prev_acts[i], prev_acts[i+1]\n",
        "        curr_X, curr_Y = curr_acts[i], curr_acts[i+1]\n",
        "        if (\n",
        "            prev_X is not None and prev_Y is not None and\n",
        "            curr_X is not None and curr_Y is not None and\n",
        "            prev_X.shape == curr_X.shape and\n",
        "            prev_Y.shape == curr_Y.shape\n",
        "        ):\n",
        "            B = curr_X.size(0)\n",
        "            dX = (curr_X - prev_X).view(B, -1)\n",
        "            dY = (curr_Y - prev_Y).view(B, -1)\n",
        "            # Only compute if shapes are valid and not empty\n",
        "            if B >= 2 and not (torch.isnan(dX).any() or torch.isnan(dY).any()):\n",
        "                cos_squares = [\n",
        "                    F.cosine_similarity(dY[j].unsqueeze(0), dX[j].unsqueeze(0), dim=1, eps=1e-8).item() ** 2\n",
        "                    for j in range(1, B)\n",
        "                    if not (torch.isnan(dX[j]).any() or torch.isnan(dY[j]).any())\n",
        "                ]\n",
        "                if cos_squares:\n",
        "                    er = sum(cos_squares) / len(cos_squares)\n",
        "                    if not (math.isnan(er) or math.isinf(er)):\n",
        "                        er_scores[i] = er\n",
        "    return er_scores\n",
        "\n",
        "# --- 2. Pruning Utilities ---\n",
        "class SkipBlock(nn.Module):\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "    def forward(self, hidden_states, *args, **kwargs):\n",
        "        return (hidden_states, None, None, None, None, None)\n",
        "\n",
        "def prune_er_layers(blocks, er_scores, num_prune=2, hidden_size=768):\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx+1 for idx, _ in sorted_layers[:num_prune] if idx+1 < len(blocks)]\n",
        "    for idx in prune_idxs:\n",
        "        blocks[idx] = SkipBlock(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "# --- 3. Data Processing ---\n",
        "def make_t5_mrpc_prompt(sentence1, sentence2):\n",
        "    return f\"mrpc sentence1: {sentence1} sentence2: {sentence2}\"\n",
        "\n",
        "def preprocess_function(batch, tokenizer, max_input_length=128, max_target_length=8):\n",
        "    inputs = [make_t5_mrpc_prompt(s1, s2) for s1, s2 in zip(batch['sentence1'], batch['sentence2'])]\n",
        "    model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=max_input_length)\n",
        "    label_list = [\"not_equivalent\", \"equivalent\"]  # 0: not equivalent, 1: equivalent\n",
        "    labels = [label_list[x] if (isinstance(x, int) and x in {0, 1}) else \"not_equivalent\" for x in batch['label']]\n",
        "    target = tokenizer(labels, padding=\"max_length\", truncation=True, max_length=max_target_length)\n",
        "    model_inputs[\"labels\"] = target[\"input_ids\"]\n",
        "    return model_inputs\n",
        "\n",
        "def compute_accuracy(preds, refs):\n",
        "    correct = 0\n",
        "    for p, l in zip(preds, refs):\n",
        "        if p == l:\n",
        "            correct += 1\n",
        "    return correct / len(preds)\n",
        "\n",
        "def evaluate_model(model, dl, tokenizer, device, label_texts):\n",
        "    model.eval()\n",
        "    preds, refs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch[\"input_ids\"].to(device)\n",
        "            attention_mask = batch[\"attention_mask\"].to(device)\n",
        "            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=2)\n",
        "            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
        "            label_ids = batch[\"labels\"].clone()\n",
        "            label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
        "            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
        "            preds.extend([p.strip().lower() for p in pred_texts])\n",
        "            refs.extend([l.strip().lower() for l in ref_texts])\n",
        "    return compute_accuracy(preds, refs)\n",
        "\n",
        "# --- 4. Training Loops ---\n",
        "\n",
        "def full_finetuning(train_loader, dev_loader, device, tokenizer, label_texts, num_epochs=6):\n",
        "    \"\"\"Stage 1: fine-tune t5-base while tracking conditional ER scores.\n",
        "\n",
        "    Forward hooks cache encoder-block, decoder-block and decoder ``layer[1]``\n",
        "    outputs. After every optimiser step each group is compared with the previous\n",
        "    step's snapshot via ``compute_conditional_batch_entropy`` and the per-pair\n",
        "    scores are averaged per epoch. Dev accuracy is printed after each epoch.\n",
        "\n",
        "    Args:\n",
        "        num_epochs: number of training epochs (default 6).\n",
        "\n",
        "    Returns:\n",
        "        ``(model, enc_er, dec_er, cross_er)`` with the ER averages of the last\n",
        "        epoch (``None`` for each if no epoch ran).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Fine-Tuning & Conditional ER Estimation ===\")\n",
        "    model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
        "    scaler = GradScaler()\n",
        "    # Decay linearly across every step of every epoch, so the learning rate\n",
        "    # reaches 0 only when training ends.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts) = register_conditional_er_hooks(model)\n",
        "    last_enc_er, last_dec_er, last_cross_er = None, None, None\n",
        "\n",
        "    def snapshot(acts):\n",
        "        # Clone: the hooks overwrite these entries on the next forward pass.\n",
        "        return {i: acts[i].clone() if acts[i] is not None else None for i in acts}\n",
        "\n",
        "    def epoch_mean(sums, counts):\n",
        "        return {idx: sums[idx]/counts[idx] for idx in sums if counts[idx] > 0}\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        enc_er_sums, enc_er_counts = defaultdict(float), defaultdict(int)\n",
        "        dec_er_sums, dec_er_counts = defaultdict(float), defaultdict(int)\n",
        "        cross_er_sums, cross_er_counts = defaultdict(float), defaultdict(int)\n",
        "        trackers = (\n",
        "            (enc_acts, enc_er_sums, enc_er_counts),\n",
        "            (dec_acts, dec_er_sums, dec_er_counts),\n",
        "            (cross_acts, cross_er_sums, cross_er_counts),\n",
        "        )\n",
        "        model.train()\n",
        "        prev_snapshots = None\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                                attention_mask=batch['attention_mask'].to(device),\n",
        "                                labels=batch['labels'].to(device))\n",
        "                loss = outputs.loss\n",
        "            # Backward runs outside the autocast region, as the PyTorch AMP docs recommend.\n",
        "            scaler.scale(loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "            if prev_snapshots is not None:\n",
        "                for prev, (acts, sums, counts) in zip(prev_snapshots, trackers):\n",
        "                    for idx, v in compute_conditional_batch_entropy(prev, acts).items():\n",
        "                        sums[idx] += v\n",
        "                        counts[idx] += 1\n",
        "            prev_snapshots = [snapshot(acts) for acts, _, _ in trackers]\n",
        "        epoch_enc_er = epoch_mean(enc_er_sums, enc_er_counts)\n",
        "        epoch_dec_er = epoch_mean(dec_er_sums, dec_er_counts)\n",
        "        epoch_cross_er = epoch_mean(cross_er_sums, cross_er_counts)\n",
        "        print(f\"[Epoch {epoch+1}] approx Encoder Conditional ER: {epoch_enc_er}\")\n",
        "        print(f\"[Epoch {epoch+1}] approx Decoder Conditional ER: {epoch_dec_er}\")\n",
        "        print(f\"[Epoch {epoch+1}] approx Cross-Attention Conditional ER: {epoch_cross_er}\")\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)\n",
        "        print(f\"[Epoch {epoch+1}] Dev Acc: {acc:.4f}\")\n",
        "        last_enc_er, last_dec_er, last_cross_er = epoch_enc_er, epoch_dec_er, epoch_cross_er\n",
        "\n",
        "    remove_hooks([(enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)])\n",
        "    return model, last_enc_er, last_dec_er, last_cross_er\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, enc_er_scores, dec_er_scores, cross_er_scores, tokenizer, label_texts, num_epochs=5, num_prune=4):\n",
        "    \"\"\"Stage 2: replace the highest-ER decoder blocks with SkipBlocks, then fine-tune.\n",
        "\n",
        "    Only ``dec_er_scores`` drives pruning; ``enc_er_scores`` and\n",
        "    ``cross_er_scores`` are accepted but unused.\n",
        "\n",
        "    Args:\n",
        "        num_epochs: fine-tuning epochs after pruning (default 5).\n",
        "        num_prune: number of decoder blocks to replace (default 4).\n",
        "\n",
        "    Returns:\n",
        "        The same model object, pruned and fine-tuned.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High-ER) & Fine-tuning ===\")\n",
        "    dec_prune_idxs = prune_er_layers(model.decoder.block, dec_er_scores, num_prune=num_prune, hidden_size=model.config.d_model)\n",
        "    print(\"Pruned decoder layers (highest ER):\", dec_prune_idxs)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=5e-4)\n",
        "    # Size the linear schedule to all num_epochs so the learning rate is not\n",
        "    # already 0 during the later epochs.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                            attention_mask=batch['attention_mask'].to(device),\n",
        "                            labels=batch['labels'].to(device))\n",
        "            loss = outputs.loss\n",
        "            loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] MRPC Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 5. Main Entrypoint ---\n",
        "\n",
        "def main():\n",
        "    \"\"\"Run the MRPC experiment end to end.\n",
        "\n",
        "    Seeds the Python/NumPy/torch RNGs, tokenises the GLUE MRPC train and\n",
        "    validation splits, runs stage-1 fine-tuning with conditional ER estimation,\n",
        "    then stage-2 decoder pruning and fine-tuning.\n",
        "    \"\"\"\n",
        "    import random\n",
        "\n",
        "    import numpy as np\n",
        "\n",
        "    # Seed every RNG up front: dropout and DataLoader shuffling draw from torch's\n",
        "    # global generator, so an unseeded run is not reproducible.\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "\n",
        "    # Use Hugging Face's MRPC (GLUE)\n",
        "    raw_datasets = load_dataset(\"glue\", \"mrpc\")\n",
        "    tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "    label_texts = [\"not_equivalent\", \"equivalent\"]\n",
        "\n",
        "    # Use \"train\" and \"validation\" splits\n",
        "    train_ds = raw_datasets[\"train\"].shuffle(seed=seed)\n",
        "    dev_ds = raw_datasets[\"validation\"].shuffle(seed=seed)\n",
        "\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True, remove_columns=train_ds.column_names)\n",
        "    dev = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                     batched=True, remove_columns=dev_ds.column_names)\n",
        "\n",
        "    collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding=\"max_length\", max_length=128)\n",
        "    train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    model, enc_er_scores, dec_er_scores, cross_er_scores = full_finetuning(\n",
        "        train_loader, dev_loader, device, tokenizer, label_texts)\n",
        "    model = prune_and_finetuning(\n",
        "        model, train_loader, dev_loader, device,\n",
        "        enc_er_scores, dec_er_scores, cross_er_scores,\n",
        "        tokenizer, label_texts)\n",
        "\n",
        "# In a notebook kernel __name__ is \"__main__\", so running this cell runs the full pipeline.\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "NyL8K4X564A3"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict\n",
        "import warnings\n",
        "import math\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# --- 1. Conditional ER Hook Utilities ---\n",
        "def register_conditional_er_hooks(model):\n",
        "    \"\"\"Hook encoder blocks, decoder blocks and decoder cross-attention sub-layers.\n",
        "\n",
        "    Each forward hook stores the detached first output (or the output itself when it is\n",
        "    not a tuple) under the module's position index. Cross-attention hooks sit on\n",
        "    `block.layer[1]` of every decoder block.\n",
        "\n",
        "    Returns three (hook_handles, activations_by_index) pairs: encoder, decoder, cross.\n",
        "    Every activations dict starts with all indices mapped to None.\n",
        "    \"\"\"\n",
        "    def _capture_into(store, layer_idx):\n",
        "        # Factory binds store/layer_idx per hook, sidestepping late binding of a loop variable.\n",
        "        def _hook(module, inp, out):\n",
        "            tensor = out[0] if isinstance(out, tuple) else out\n",
        "            store[layer_idx] = tensor.detach()\n",
        "        return _hook\n",
        "\n",
        "    def _attach(modules):\n",
        "        store = dict.fromkeys(range(len(modules)))\n",
        "        handles = [mod.register_forward_hook(_capture_into(store, k)) for k, mod in enumerate(modules)]\n",
        "        return handles, store\n",
        "\n",
        "    encoder_pair = _attach(model.encoder.block)\n",
        "    decoder_blocks = model.decoder.block\n",
        "    decoder_pair = _attach(decoder_blocks)\n",
        "    cross_pair = _attach([blk.layer[1] for blk in decoder_blocks])\n",
        "    return encoder_pair, decoder_pair, cross_pair\n",
        "\n",
        "def remove_hooks(hook_sets):\n",
        "    \"\"\"Remove every hook handle in `hook_sets`, an iterable of (handles, activations) pairs.\"\"\"\n",
        "    pending = (handle for handles, _acts in hook_sets for handle in handles)\n",
        "    for handle in pending:\n",
        "        handle.remove()\n",
        "\n",
        "def reset_activation_dict(d):\n",
        "    \"\"\"Set every existing key of `d` to None in place (the keys themselves are kept).\"\"\"\n",
        "    d.update(dict.fromkeys(d))\n",
        "\n",
        "def compute_conditional_batch_entropy(prev_acts, curr_acts):\n",
        "    \"\"\"Per layer pair (i, i+1), mean squared cosine between activation changes.\n",
        "\n",
        "    For each sample, dX/dY are the flattened differences between the current and previous\n",
        "    activations of layers i and i+1; the score of pair i is the mean over samples of\n",
        "    cos(dY, dX)**2. Despite the name this is a cosine-alignment statistic, not an entropy.\n",
        "\n",
        "    Args:\n",
        "        prev_acts, curr_acts: dicts mapping layer index (0..n-1) to a tensor whose first\n",
        "            dimension is the batch, or None when that layer was not captured.\n",
        "\n",
        "    Returns:\n",
        "        {i: score} for every pair whose four tensors are present, shape-compatible,\n",
        "        NaN-free and yield a finite score.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for i in range(len(curr_acts) - 1):\n",
        "        prev_X, prev_Y = prev_acts[i], prev_acts[i+1]\n",
        "        curr_X, curr_Y = curr_acts[i], curr_acts[i+1]\n",
        "        if any(t is None for t in (prev_X, prev_Y, curr_X, curr_Y)):\n",
        "            continue\n",
        "        if prev_X.shape != curr_X.shape or prev_Y.shape != curr_Y.shape:\n",
        "            continue\n",
        "        B = curr_X.size(0)\n",
        "        if B < 1:\n",
        "            continue\n",
        "        # Upcast first: activations captured under fp16 autocast can overflow when the\n",
        "        # cosine norms are computed in half precision.\n",
        "        dX = (curr_X.float() - prev_X.float()).reshape(B, -1)\n",
        "        dY = (curr_Y.float() - prev_Y.float()).reshape(B, -1)\n",
        "        if torch.isnan(dX).any() or torch.isnan(dY).any():\n",
        "            continue\n",
        "        # Score every sample (index 0 included) in one vectorized call, avoiding a\n",
        "        # device-to-host sync per sample.\n",
        "        cos_sq = F.cosine_similarity(dY, dX, dim=1, eps=1e-8) ** 2\n",
        "        cos_sq = cos_sq[torch.isfinite(cos_sq)]\n",
        "        if cos_sq.numel() == 0:\n",
        "            continue\n",
        "        er = cos_sq.mean().item()\n",
        "        if math.isfinite(er):\n",
        "            er_scores[i] = er\n",
        "    return er_scores\n",
        "\n",
        "# --- 2. Pruning Utilities ---\n",
        "class SkipBlock(nn.Module):\n",
        "    \"\"\"Identity replacement for a pruned T5 block.\n",
        "\n",
        "    Returns `hidden_states` unchanged. The rest of the output tuple mirrors the layout a\n",
        "    transformers T5Block returns (NOTE(review): verify against the installed version):\n",
        "    hidden states, [cache slot if use_cache], self-attention position bias,\n",
        "    [attention weights if output_attentions], then, when encoder_hidden_states is given,\n",
        "    the cross-attention position bias [and its weights if output_attentions].\n",
        "    Incoming position biases are forwarded instead of being replaced by None, so blocks\n",
        "    after a pruned one keep receiving the relative position bias computed upstream.\n",
        "    \"\"\"\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size  # kept for introspection; unused in forward\n",
        "\n",
        "    def forward(self, hidden_states, *args, **kwargs):\n",
        "        use_cache = bool(kwargs.get(\"use_cache\", False))\n",
        "        output_attentions = bool(kwargs.get(\"output_attentions\", False))\n",
        "        outputs = (hidden_states,)\n",
        "        if use_cache:\n",
        "            # Nothing new to cache; hand back whatever cache entry/object was passed in.\n",
        "            outputs += (kwargs.get(\"past_key_value\"),)\n",
        "        outputs += (kwargs.get(\"position_bias\"),)\n",
        "        if output_attentions:\n",
        "            outputs += (None,)\n",
        "        if kwargs.get(\"encoder_hidden_states\") is not None:\n",
        "            outputs += (kwargs.get(\"encoder_decoder_position_bias\"),)\n",
        "            if output_attentions:\n",
        "                outputs += (None,)\n",
        "        return outputs\n",
        "\n",
        "def prune_er_layers(blocks, er_scores, num_prune=2, hidden_size=768):\n",
        "    \"\"\"Swap the downstream block of the highest-scoring layer pairs for SkipBlock.\n",
        "\n",
        "    `er_scores` maps a pair index i (layers i -> i+1) to its score. For each of the\n",
        "    `num_prune` highest-scoring pairs, block i+1 is replaced in place when it exists.\n",
        "    Returns the replaced block indices, highest score first.\n",
        "    \"\"\"\n",
        "    ranked = sorted(er_scores.items(), key=lambda item: item[1], reverse=True)\n",
        "    targets = []\n",
        "    for pair_idx, _score in ranked[:num_prune]:\n",
        "        downstream = pair_idx + 1\n",
        "        if downstream < len(blocks):\n",
        "            targets.append(downstream)\n",
        "    for block_idx in targets:\n",
        "        blocks[block_idx] = SkipBlock(hidden_size)\n",
        "    return targets\n",
        "\n",
        "# --- 3. Data Processing ---\n",
        "def make_t5_sst2_prompt(sentence):\n",
        "    \"\"\"Prefix `sentence` with the T5-style SST-2 task tag 'sst2 sentence: '.\"\"\"\n",
        "    return \"sst2 sentence: {}\".format(sentence)\n",
        "\n",
        "def preprocess_function(batch, tokenizer, max_input_length=128, max_target_length=8):\n",
        "    \"\"\"Tokenize a batched SST-2 slice into T5 text-to-text inputs and targets.\n",
        "\n",
        "    Args:\n",
        "        batch: mapping with 'sentence' and 'label' lists (as passed by a batched datasets.map).\n",
        "        tokenizer: callable T5 tokenizer exposing `pad_token_id`.\n",
        "        max_input_length / max_target_length: fixed padded lengths for inputs / targets.\n",
        "\n",
        "    Returns:\n",
        "        The tokenizer output for the prompts plus a 'labels' list whose padding positions\n",
        "        are -100, so the seq2seq loss ignores them.\n",
        "    \"\"\"\n",
        "    inputs = [make_t5_sst2_prompt(s) for s in batch['sentence']]\n",
        "    model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=max_input_length)\n",
        "    label_list = [\"negative\", \"positive\"]  # 0: negative, 1: positive\n",
        "    # Anything other than int 0/1 silently becomes \"negative\".\n",
        "    labels = [label_list[x] if (isinstance(x, int) and x in {0, 1}) else \"negative\" for x in batch['label']]\n",
        "    target = tokenizer(labels, padding=\"max_length\", truncation=True, max_length=max_target_length)\n",
        "    pad_id = tokenizer.pad_token_id\n",
        "    # Padded target positions become -100 (the label ignore index HF seq2seq models use);\n",
        "    # otherwise the model is also trained to emit pad tokens after the label.\n",
        "    model_inputs[\"labels\"] = [\n",
        "        [-100 if tok == pad_id else tok for tok in seq] for seq in target[\"input_ids\"]\n",
        "    ]\n",
        "    return model_inputs\n",
        "\n",
        "def compute_accuracy(preds, refs):\n",
        "    \"\"\"Fraction of positions where the prediction equals the reference.\n",
        "\n",
        "    Pairs are compared positionally (zip stops at the shorter list) and the match count\n",
        "    is divided by len(preds). An empty `preds` returns 0.0 instead of raising\n",
        "    ZeroDivisionError.\n",
        "    \"\"\"\n",
        "    if not preds:\n",
        "        return 0.0\n",
        "    correct = sum(1 for p, r in zip(preds, refs) if p == r)\n",
        "    return correct / len(preds)\n",
        "\n",
        "def evaluate_model(model, dl, tokenizer, device, label_texts, max_new_tokens=None):\n",
        "    \"\"\"Greedy-generate on every batch of `dl` and return exact-match accuracy.\n",
        "\n",
        "    Generated and reference texts are decoded without special tokens, stripped and\n",
        "    lower-cased, then compared with compute_accuracy. Reference positions equal to -100\n",
        "    are mapped to the pad id before decoding.\n",
        "\n",
        "    Args:\n",
        "        label_texts: verbalized class labels; used to size the generation budget.\n",
        "        max_new_tokens: generation budget. When None it is the tokenized length of the\n",
        "            longest entry of `label_texts` (for T5 tokenizers this includes the appended\n",
        "            </s>), so multi-token labels are not truncated; 2 if `label_texts` is empty.\n",
        "    \"\"\"\n",
        "    if max_new_tokens is None:\n",
        "        label_lengths = [len(tokenizer(text)[\"input_ids\"]) for text in (label_texts or [])]\n",
        "        max_new_tokens = max(label_lengths, default=2)\n",
        "    model.eval()\n",
        "    preds, refs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch[\"input_ids\"].to(device)\n",
        "            attention_mask = batch[\"attention_mask\"].to(device)\n",
        "            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens)\n",
        "            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
        "            label_ids = batch[\"labels\"].clone()\n",
        "            label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
        "            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
        "            preds.extend(p.strip().lower() for p in pred_texts)\n",
        "            refs.extend(r.strip().lower() for r in ref_texts)\n",
        "    return compute_accuracy(preds, refs)\n",
        "\n",
        "# --- 4. Training Loops ---\n",
        "def _accumulate_er(sums, counts, prev_acts, curr_acts):\n",
        "    # Fold one consecutive-batch comparison into running per-layer ER sums/counts.\n",
        "    if prev_acts is None:\n",
        "        return\n",
        "    for idx, value in compute_conditional_batch_entropy(prev_acts, curr_acts).items():\n",
        "        sums[idx] += value\n",
        "        counts[idx] += 1\n",
        "\n",
        "def _snapshot_acts(acts):\n",
        "    # Clone captured activations so the next forward pass cannot overwrite them.\n",
        "    return {i: (t.clone() if t is not None else None) for i, t in acts.items()}\n",
        "\n",
        "def full_finetuning(train_loader, dev_loader, device, tokenizer, label_texts, num_epochs=6):\n",
        "    \"\"\"Stage 1: fine-tune all of t5-base while estimating conditional ER per layer pair.\n",
        "\n",
        "    Forward hooks capture encoder-block, decoder-block and cross-attention outputs on each\n",
        "    training batch; activations of consecutive batches are compared with\n",
        "    compute_conditional_batch_entropy and averaged over the epoch. Dev accuracy is printed\n",
        "    after every epoch. Hooks are removed before returning, even if training raises.\n",
        "\n",
        "    Returns:\n",
        "        (model, enc_er, dec_er, cross_er): the ER dicts ({pair_index: score}) are the final\n",
        "        epoch's averages, or None when num_epochs is 0.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Fine-Tuning & Conditional ER Estimation ===\")\n",
        "    model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
        "    scaler = GradScaler()\n",
        "    # Decay across every epoch that actually runs, so no epoch trains at learning rate 0.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    hook_sets = register_conditional_er_hooks(model)\n",
        "    (_, enc_acts), (_, dec_acts), (_, cross_acts) = hook_sets\n",
        "    streams = {\"enc\": enc_acts, \"dec\": dec_acts, \"cross\": cross_acts}\n",
        "    last_er = dict.fromkeys(streams)\n",
        "\n",
        "    try:\n",
        "        for epoch in range(num_epochs):\n",
        "            er_sums = {name: defaultdict(float) for name in streams}\n",
        "            er_counts = {name: defaultdict(int) for name in streams}\n",
        "            prev = dict.fromkeys(streams)\n",
        "            model.train()\n",
        "            for batch in train_loader:\n",
        "                # Clear activations first so a hook that did not fire is detectable below.\n",
        "                for acts in streams.values():\n",
        "                    reset_activation_dict(acts)\n",
        "\n",
        "                opt.zero_grad()\n",
        "                with autocast():\n",
        "                    outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                                    attention_mask=batch['attention_mask'].to(device),\n",
        "                                    labels=batch['labels'].to(device))\n",
        "                    loss = outputs.loss\n",
        "                # Backward runs outside the autocast region, per torch AMP guidance.\n",
        "                scaler.scale(loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "\n",
        "                if any(v is None for acts in streams.values() for v in acts.values()):\n",
        "                    print(\"[WARN] Some activations not collected, skipping ER computation this batch.\")\n",
        "                    prev = dict.fromkeys(streams)\n",
        "                    continue\n",
        "\n",
        "                for name, acts in streams.items():\n",
        "                    _accumulate_er(er_sums[name], er_counts[name], prev[name], acts)\n",
        "                prev = {name: _snapshot_acts(acts) for name, acts in streams.items()}\n",
        "\n",
        "            epoch_er = {\n",
        "                name: {idx: er_sums[name][idx] / er_counts[name][idx]\n",
        "                       for idx in er_sums[name] if er_counts[name][idx] > 0}\n",
        "                for name in streams\n",
        "            }\n",
        "            print(f\"[Epoch {epoch+1}] approx Encoder Conditional ER: {epoch_er['enc']}\")\n",
        "            print(f\"[Epoch {epoch+1}] approx Decoder Conditional ER: {epoch_er['dec']}\")\n",
        "            print(f\"[Epoch {epoch+1}] approx Cross-Attention Conditional ER: {epoch_er['cross']}\")\n",
        "            acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)\n",
        "            print(f\"[Epoch {epoch+1}] Dev Acc: {acc:.4f}\")\n",
        "            last_er = epoch_er\n",
        "    finally:\n",
        "        remove_hooks(hook_sets)\n",
        "\n",
        "    return model, last_er[\"enc\"], last_er[\"dec\"], last_er[\"cross\"]\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, enc_er_scores, dec_er_scores, cross_er_scores, tokenizer, label_texts, num_epochs=5):\n",
        "    \"\"\"Stage 2: replace the 4 highest-ER decoder blocks with SkipBlock, then fine-tune.\n",
        "\n",
        "    Only `dec_er_scores` drives pruning; `enc_er_scores` and `cross_er_scores` are accepted\n",
        "    but unused. A None score dict prunes nothing. Trains without autocast for `num_epochs`\n",
        "    epochs, printing dev accuracy after each, and returns the pruned model.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High-ER) & Fine-tuning ===\")\n",
        "    dec_prune_idxs = prune_er_layers(model.decoder.block, dec_er_scores or {}, num_prune=4, hidden_size=model.config.d_model)\n",
        "    print(\"Pruned decoder layers (highest ER):\", dec_prune_idxs)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=5e-4)\n",
        "    # Decay horizon matches the epochs actually run, so no epoch trains at learning rate 0.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                            attention_mask=batch['attention_mask'].to(device),\n",
        "                            labels=batch['labels'].to(device))\n",
        "            outputs.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] SST-2 Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 5. Main Entrypoint ---\n",
        "def main():\n",
        "    \"\"\"Run stage 1 and stage 2 on a 10k-example SST-2 train subset, evaluating on validation.\"\"\"\n",
        "    # Seed torch's global RNG (DataLoader shuffling, dropout) so reruns are comparable.\n",
        "    torch.manual_seed(42)\n",
        "    raw_datasets = load_dataset(\"glue\", \"sst2\")\n",
        "    tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "    label_texts = [\"negative\", \"positive\"]\n",
        "\n",
        "    train_ds = raw_datasets[\"train\"].shuffle(seed=42).select(range(10000))\n",
        "    dev_ds = raw_datasets[\"validation\"].shuffle(seed=42)\n",
        "\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True, remove_columns=train_ds.column_names)\n",
        "    dev = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                     batched=True, remove_columns=dev_ds.column_names)\n",
        "\n",
        "    collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding=\"max_length\", max_length=128)\n",
        "    train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    model, enc_er_scores, dec_er_scores, cross_er_scores = full_finetuning(\n",
        "        train_loader, dev_loader, device, tokenizer, label_texts)\n",
        "    model = prune_and_finetuning(\n",
        "        model, train_loader, dev_loader, device,\n",
        "        enc_er_scores, dec_er_scores, cross_er_scores,\n",
        "        tokenizer, label_texts)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "OjW9THvtKY8X"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "3E0wZs4WKZfz"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict\n",
        "import warnings\n",
        "import math\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# --- 1. Conditional ER Hook Utilities ---\n",
        "def register_conditional_er_hooks(model):\n",
        "    \"\"\"Hook encoder blocks, decoder blocks and decoder cross-attention sub-layers.\n",
        "\n",
        "    Each forward hook stores the detached first output (or the output itself when it is\n",
        "    not a tuple) under the module's position index. Cross-attention hooks sit on\n",
        "    `block.layer[1]` of every decoder block.\n",
        "\n",
        "    Returns three (hook_handles, activations_by_index) pairs: encoder, decoder, cross.\n",
        "    Every activations dict starts with all indices mapped to None.\n",
        "    \"\"\"\n",
        "    def _capture_into(store, layer_idx):\n",
        "        # Factory binds store/layer_idx per hook, sidestepping late binding of a loop variable.\n",
        "        def _hook(module, inp, out):\n",
        "            tensor = out[0] if isinstance(out, tuple) else out\n",
        "            store[layer_idx] = tensor.detach()\n",
        "        return _hook\n",
        "\n",
        "    def _attach(modules):\n",
        "        store = dict.fromkeys(range(len(modules)))\n",
        "        handles = [mod.register_forward_hook(_capture_into(store, k)) for k, mod in enumerate(modules)]\n",
        "        return handles, store\n",
        "\n",
        "    encoder_pair = _attach(model.encoder.block)\n",
        "    decoder_blocks = model.decoder.block\n",
        "    decoder_pair = _attach(decoder_blocks)\n",
        "    cross_pair = _attach([blk.layer[1] for blk in decoder_blocks])\n",
        "    return encoder_pair, decoder_pair, cross_pair\n",
        "\n",
        "def remove_hooks(hook_sets):\n",
        "    \"\"\"Remove every hook handle in `hook_sets`, an iterable of (handles, activations) pairs.\"\"\"\n",
        "    pending = (handle for handles, _acts in hook_sets for handle in handles)\n",
        "    for handle in pending:\n",
        "        handle.remove()\n",
        "\n",
        "def compute_conditional_batch_entropy(prev_acts, curr_acts):\n",
        "    \"\"\"Per layer pair (i, i+1), mean squared cosine between activation changes.\n",
        "\n",
        "    For each sample, dX/dY are the flattened differences between the current and previous\n",
        "    activations of layers i and i+1; the score of pair i is the mean over samples of\n",
        "    cos(dY, dX)**2. Despite the name this is a cosine-alignment statistic, not an entropy.\n",
        "\n",
        "    Args:\n",
        "        prev_acts, curr_acts: dicts mapping layer index (0..n-1) to a tensor whose first\n",
        "            dimension is the batch, or None when that layer was not captured.\n",
        "\n",
        "    Returns:\n",
        "        {i: score} for every pair whose four tensors are present, shape-compatible,\n",
        "        NaN-free and yield a finite score.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for i in range(len(curr_acts) - 1):\n",
        "        prev_X, prev_Y = prev_acts[i], prev_acts[i+1]\n",
        "        curr_X, curr_Y = curr_acts[i], curr_acts[i+1]\n",
        "        if any(t is None for t in (prev_X, prev_Y, curr_X, curr_Y)):\n",
        "            continue\n",
        "        if prev_X.shape != curr_X.shape or prev_Y.shape != curr_Y.shape:\n",
        "            continue\n",
        "        B = curr_X.size(0)\n",
        "        if B < 1:\n",
        "            continue\n",
        "        # Upcast first: activations captured under fp16 autocast can overflow when the\n",
        "        # cosine norms are computed in half precision.\n",
        "        dX = (curr_X.float() - prev_X.float()).reshape(B, -1)\n",
        "        dY = (curr_Y.float() - prev_Y.float()).reshape(B, -1)\n",
        "        if torch.isnan(dX).any() or torch.isnan(dY).any():\n",
        "            continue\n",
        "        # Score every sample (index 0 included) in one vectorized call, avoiding a\n",
        "        # device-to-host sync per sample.\n",
        "        cos_sq = F.cosine_similarity(dY, dX, dim=1, eps=1e-8) ** 2\n",
        "        cos_sq = cos_sq[torch.isfinite(cos_sq)]\n",
        "        if cos_sq.numel() == 0:\n",
        "            continue\n",
        "        er = cos_sq.mean().item()\n",
        "        if math.isfinite(er):\n",
        "            er_scores[i] = er\n",
        "    return er_scores\n",
        "\n",
        "# --- 2. Pruning Utilities ---\n",
        "class SkipBlock(nn.Module):\n",
        "    \"\"\"Identity replacement for a pruned T5 block.\n",
        "\n",
        "    Returns `hidden_states` unchanged. The rest of the output tuple mirrors the layout a\n",
        "    transformers T5Block returns (NOTE(review): verify against the installed version):\n",
        "    hidden states, [cache slot if use_cache], self-attention position bias,\n",
        "    [attention weights if output_attentions], then, when encoder_hidden_states is given,\n",
        "    the cross-attention position bias [and its weights if output_attentions].\n",
        "    Incoming position biases are forwarded instead of being replaced by None, so blocks\n",
        "    after a pruned one keep receiving the relative position bias computed upstream.\n",
        "    \"\"\"\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size  # kept for introspection; unused in forward\n",
        "\n",
        "    def forward(self, hidden_states, *args, **kwargs):\n",
        "        use_cache = bool(kwargs.get(\"use_cache\", False))\n",
        "        output_attentions = bool(kwargs.get(\"output_attentions\", False))\n",
        "        outputs = (hidden_states,)\n",
        "        if use_cache:\n",
        "            # Nothing new to cache; hand back whatever cache entry/object was passed in.\n",
        "            outputs += (kwargs.get(\"past_key_value\"),)\n",
        "        outputs += (kwargs.get(\"position_bias\"),)\n",
        "        if output_attentions:\n",
        "            outputs += (None,)\n",
        "        if kwargs.get(\"encoder_hidden_states\") is not None:\n",
        "            outputs += (kwargs.get(\"encoder_decoder_position_bias\"),)\n",
        "            if output_attentions:\n",
        "                outputs += (None,)\n",
        "        return outputs\n",
        "\n",
        "def prune_er_layers(blocks, er_scores, num_prune=2, hidden_size=768):\n",
        "    \"\"\"Swap the downstream block of the highest-scoring layer pairs for SkipBlock.\n",
        "\n",
        "    `er_scores` maps a pair index i (layers i -> i+1) to its score. For each of the\n",
        "    `num_prune` highest-scoring pairs, block i+1 is replaced in place when it exists.\n",
        "    Returns the replaced block indices, highest score first.\n",
        "    \"\"\"\n",
        "    ranked = sorted(er_scores.items(), key=lambda item: item[1], reverse=True)\n",
        "    targets = []\n",
        "    for pair_idx, _score in ranked[:num_prune]:\n",
        "        downstream = pair_idx + 1\n",
        "        if downstream < len(blocks):\n",
        "            targets.append(downstream)\n",
        "    for block_idx in targets:\n",
        "        blocks[block_idx] = SkipBlock(hidden_size)\n",
        "    return targets\n",
        "\n",
        "# --- 3. Data Processing ---\n",
        "def make_t5_cola_prompt(sentence):\n",
        "    \"\"\"Prefix `sentence` with the T5-style CoLA task tag 'cola sentence: '.\"\"\"\n",
        "    return \"cola sentence: {}\".format(sentence)\n",
        "\n",
        "def preprocess_function(batch, tokenizer, max_input_length=128, max_target_length=8):\n",
        "    \"\"\"Tokenize a batched CoLA slice into T5 text-to-text inputs and targets.\n",
        "\n",
        "    Args:\n",
        "        batch: mapping with 'sentence' and 'label' lists (as passed by a batched datasets.map).\n",
        "        tokenizer: callable T5 tokenizer exposing `pad_token_id`.\n",
        "        max_input_length / max_target_length: fixed padded lengths for inputs / targets.\n",
        "\n",
        "    Returns:\n",
        "        The tokenizer output for the prompts plus a 'labels' list whose padding positions\n",
        "        are -100, so the seq2seq loss ignores them.\n",
        "    \"\"\"\n",
        "    inputs = [make_t5_cola_prompt(s) for s in batch['sentence']]\n",
        "    model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=max_input_length)\n",
        "    label_list = [\"unacceptable\", \"acceptable\"]  # 0: unacceptable, 1: acceptable\n",
        "    # Anything other than int 0/1 silently becomes \"unacceptable\".\n",
        "    labels = [label_list[x] if (isinstance(x, int) and x in {0, 1}) else \"unacceptable\" for x in batch['label']]\n",
        "    target = tokenizer(labels, padding=\"max_length\", truncation=True, max_length=max_target_length)\n",
        "    pad_id = tokenizer.pad_token_id\n",
        "    # Padded target positions become -100 (the label ignore index HF seq2seq models use);\n",
        "    # otherwise the model is also trained to emit pad tokens after the label.\n",
        "    model_inputs[\"labels\"] = [\n",
        "        [-100 if tok == pad_id else tok for tok in seq] for seq in target[\"input_ids\"]\n",
        "    ]\n",
        "    return model_inputs\n",
        "\n",
        "def compute_accuracy(preds, refs):\n",
        "    \"\"\"Fraction of positions where the prediction equals the reference.\n",
        "\n",
        "    Pairs are compared positionally (zip stops at the shorter list) and the match count\n",
        "    is divided by len(preds). An empty `preds` returns 0.0 instead of raising\n",
        "    ZeroDivisionError.\n",
        "    \"\"\"\n",
        "    if not preds:\n",
        "        return 0.0\n",
        "    correct = sum(1 for p, r in zip(preds, refs) if p == r)\n",
        "    return correct / len(preds)\n",
        "\n",
        "def evaluate_model(model, dl, tokenizer, device, label_texts, max_new_tokens=None):\n",
        "    \"\"\"Greedy-generate on every batch of `dl` and return exact-match accuracy.\n",
        "\n",
        "    Generated and reference texts are decoded without special tokens, stripped and\n",
        "    lower-cased, then compared with compute_accuracy. Reference positions equal to -100\n",
        "    are mapped to the pad id before decoding.\n",
        "\n",
        "    Args:\n",
        "        label_texts: verbalized class labels; used to size the generation budget.\n",
        "        max_new_tokens: generation budget. When None it is the tokenized length of the\n",
        "            longest entry of `label_texts` (for T5 tokenizers this includes the appended\n",
        "            </s>), so multi-token labels are not truncated; 2 if `label_texts` is empty.\n",
        "    \"\"\"\n",
        "    if max_new_tokens is None:\n",
        "        label_lengths = [len(tokenizer(text)[\"input_ids\"]) for text in (label_texts or [])]\n",
        "        max_new_tokens = max(label_lengths, default=2)\n",
        "    model.eval()\n",
        "    preds, refs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch[\"input_ids\"].to(device)\n",
        "            attention_mask = batch[\"attention_mask\"].to(device)\n",
        "            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens)\n",
        "            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
        "            label_ids = batch[\"labels\"].clone()\n",
        "            label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
        "            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
        "            preds.extend(p.strip().lower() for p in pred_texts)\n",
        "            refs.extend(r.strip().lower() for r in ref_texts)\n",
        "    return compute_accuracy(preds, refs)\n",
        "\n",
        "# --- 4. Training Loops ---\n",
        "def _accumulate_er(sums, counts, prev_acts, curr_acts):\n",
        "    # Fold one consecutive-batch comparison into running per-layer ER sums/counts.\n",
        "    if prev_acts is None:\n",
        "        return\n",
        "    for idx, value in compute_conditional_batch_entropy(prev_acts, curr_acts).items():\n",
        "        sums[idx] += value\n",
        "        counts[idx] += 1\n",
        "\n",
        "def _snapshot_acts(acts):\n",
        "    # Clone captured activations so the next forward pass cannot overwrite them.\n",
        "    return {i: (t.clone() if t is not None else None) for i, t in acts.items()}\n",
        "\n",
        "def full_finetuning(train_loader, dev_loader, device, tokenizer, label_texts, num_epochs=6):\n",
        "    \"\"\"Stage 1: fine-tune all of t5-base while estimating conditional ER per layer pair.\n",
        "\n",
        "    Forward hooks capture encoder-block, decoder-block and cross-attention outputs on each\n",
        "    training batch; activations of consecutive batches are compared with\n",
        "    compute_conditional_batch_entropy and averaged over the epoch. Dev accuracy is printed\n",
        "    after every epoch. Hooks are removed before returning, even if training raises.\n",
        "\n",
        "    Returns:\n",
        "        (model, enc_er, dec_er, cross_er): the ER dicts ({pair_index: score}) are the final\n",
        "        epoch's averages, or None when num_epochs is 0.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Fine-Tuning & Conditional ER Estimation ===\")\n",
        "    model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
        "    scaler = GradScaler()\n",
        "    # Decay across every epoch that actually runs, so no epoch trains at learning rate 0.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    hook_sets = register_conditional_er_hooks(model)\n",
        "    (_, enc_acts), (_, dec_acts), (_, cross_acts) = hook_sets\n",
        "    streams = {\"enc\": enc_acts, \"dec\": dec_acts, \"cross\": cross_acts}\n",
        "    last_er = dict.fromkeys(streams)\n",
        "\n",
        "    try:\n",
        "        for epoch in range(num_epochs):\n",
        "            er_sums = {name: defaultdict(float) for name in streams}\n",
        "            er_counts = {name: defaultdict(int) for name in streams}\n",
        "            prev = dict.fromkeys(streams)\n",
        "            model.train()\n",
        "            for batch in train_loader:\n",
        "                # Clear activations first so a hook that did not fire is detectable below.\n",
        "                for acts in streams.values():\n",
        "                    for k in acts:\n",
        "                        acts[k] = None\n",
        "\n",
        "                opt.zero_grad()\n",
        "                with autocast():\n",
        "                    outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                                    attention_mask=batch['attention_mask'].to(device),\n",
        "                                    labels=batch['labels'].to(device))\n",
        "                    loss = outputs.loss\n",
        "                # Backward runs outside the autocast region, per torch AMP guidance.\n",
        "                scaler.scale(loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "\n",
        "                if any(v is None for acts in streams.values() for v in acts.values()):\n",
        "                    print(\"[WARN] Some activations not collected, skipping ER computation this batch.\")\n",
        "                    prev = dict.fromkeys(streams)\n",
        "                    continue\n",
        "\n",
        "                for name, acts in streams.items():\n",
        "                    _accumulate_er(er_sums[name], er_counts[name], prev[name], acts)\n",
        "                prev = {name: _snapshot_acts(acts) for name, acts in streams.items()}\n",
        "\n",
        "            epoch_er = {\n",
        "                name: {idx: er_sums[name][idx] / er_counts[name][idx]\n",
        "                       for idx in er_sums[name] if er_counts[name][idx] > 0}\n",
        "                for name in streams\n",
        "            }\n",
        "            print(f\"[Epoch {epoch+1}] approx Encoder Conditional ER: {epoch_er['enc']}\")\n",
        "            print(f\"[Epoch {epoch+1}] approx Decoder Conditional ER: {epoch_er['dec']}\")\n",
        "            print(f\"[Epoch {epoch+1}] approx Cross-Attention Conditional ER: {epoch_er['cross']}\")\n",
        "            acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)\n",
        "            print(f\"[Epoch {epoch+1}] Dev Acc: {acc:.4f}\")\n",
        "            last_er = epoch_er\n",
        "    finally:\n",
        "        remove_hooks(hook_sets)\n",
        "\n",
        "    return model, last_er[\"enc\"], last_er[\"dec\"], last_er[\"cross\"]\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, enc_er_scores, dec_er_scores, cross_er_scores, tokenizer, label_texts, num_epochs=5):\n",
        "    \"\"\"Stage 2: replace the 4 highest-ER decoder blocks with SkipBlock, then fine-tune.\n",
        "\n",
        "    Only `dec_er_scores` drives pruning; `enc_er_scores` and `cross_er_scores` are accepted\n",
        "    but unused. A None score dict prunes nothing. Trains without autocast for `num_epochs`\n",
        "    epochs, printing dev accuracy after each, and returns the pruned model.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High-ER) & Fine-tuning ===\")\n",
        "    dec_prune_idxs = prune_er_layers(model.decoder.block, dec_er_scores or {}, num_prune=4, hidden_size=model.config.d_model)\n",
        "    print(\"Pruned decoder layers (highest ER):\", dec_prune_idxs)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=5e-4)\n",
        "    # Decay horizon matches the epochs actually run, so no epoch trains at learning rate 0.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                            attention_mask=batch['attention_mask'].to(device),\n",
        "                            labels=batch['labels'].to(device))\n",
        "            outputs.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] CoLA Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 5. Main Entrypoint ---\n",
        "def main():\n",
        "    \"\"\"Run stage 1 and stage 2 on an 8k-example CoLA train subset, evaluating on validation.\"\"\"\n",
        "    # Seed torch's global RNG (DataLoader shuffling, dropout) so reruns are comparable.\n",
        "    torch.manual_seed(42)\n",
        "    raw_datasets = load_dataset(\"glue\", \"cola\")\n",
        "    tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "    label_texts = [\"unacceptable\", \"acceptable\"]\n",
        "\n",
        "    train_ds = raw_datasets[\"train\"].shuffle(seed=42).select(range(8000))\n",
        "    dev_ds = raw_datasets[\"validation\"].shuffle(seed=42)\n",
        "\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True, remove_columns=train_ds.column_names)\n",
        "    dev = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                     batched=True, remove_columns=dev_ds.column_names)\n",
        "\n",
        "    collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding=\"max_length\", max_length=128)\n",
        "    train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    model, enc_er_scores, dec_er_scores, cross_er_scores = full_finetuning(\n",
        "        train_loader, dev_loader, device, tokenizer, label_texts)\n",
        "    model = prune_and_finetuning(\n",
        "        model, train_loader, dev_loader, device,\n",
        "        enc_er_scores, dec_er_scores, cross_er_scores,\n",
        "        tokenizer, label_texts)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "nS89xojmCBwo"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict\n",
        "import warnings\n",
        "import math\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# --- 1. Conditional ER Hook Utilities ---\n",
        "def register_conditional_er_hooks(model):\n",
        "    \"\"\"Hook encoder blocks, decoder blocks and decoder cross-attention sub-layers.\n",
        "\n",
        "    Each forward hook stores the detached first output (or the output itself when it is\n",
        "    not a tuple) under the module's position index. Cross-attention hooks sit on\n",
        "    `block.layer[1]` of every decoder block.\n",
        "\n",
        "    Returns three (hook_handles, activations_by_index) pairs: encoder, decoder, cross.\n",
        "    Every activations dict starts with all indices mapped to None.\n",
        "    \"\"\"\n",
        "    def _capture_into(store, layer_idx):\n",
        "        # Factory binds store/layer_idx per hook, sidestepping late binding of a loop variable.\n",
        "        def _hook(module, inp, out):\n",
        "            tensor = out[0] if isinstance(out, tuple) else out\n",
        "            store[layer_idx] = tensor.detach()\n",
        "        return _hook\n",
        "\n",
        "    def _attach(modules):\n",
        "        store = dict.fromkeys(range(len(modules)))\n",
        "        handles = [mod.register_forward_hook(_capture_into(store, k)) for k, mod in enumerate(modules)]\n",
        "        return handles, store\n",
        "\n",
        "    encoder_pair = _attach(model.encoder.block)\n",
        "    decoder_blocks = model.decoder.block\n",
        "    decoder_pair = _attach(decoder_blocks)\n",
        "    cross_pair = _attach([blk.layer[1] for blk in decoder_blocks])\n",
        "    return encoder_pair, decoder_pair, cross_pair\n",
        "\n",
        "def remove_hooks(hook_sets):\n",
        "    \"\"\"Remove every hook handle in `hook_sets`, an iterable of (handles, activations) pairs.\"\"\"\n",
        "    pending = (handle for handles, _acts in hook_sets for handle in handles)\n",
        "    for handle in pending:\n",
        "        handle.remove()\n",
        "\n",
        "def compute_conditional_batch_entropy(prev_acts, curr_acts):\n",
        "    \"\"\"Per layer pair (i, i+1), mean squared cosine between activation changes.\n",
        "\n",
        "    For each sample, dX/dY are the flattened differences between the current and previous\n",
        "    activations of layers i and i+1; the score of pair i is the mean over samples of\n",
        "    cos(dY, dX)**2. Despite the name this is a cosine-alignment statistic, not an entropy.\n",
        "\n",
        "    Args:\n",
        "        prev_acts, curr_acts: dicts mapping layer index (0..n-1) to a tensor whose first\n",
        "            dimension is the batch, or None when that layer was not captured.\n",
        "\n",
        "    Returns:\n",
        "        {i: score} for every pair whose four tensors are present, shape-compatible,\n",
        "        NaN-free and yield a finite score.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for i in range(len(curr_acts) - 1):\n",
        "        prev_X, prev_Y = prev_acts[i], prev_acts[i+1]\n",
        "        curr_X, curr_Y = curr_acts[i], curr_acts[i+1]\n",
        "        if any(t is None for t in (prev_X, prev_Y, curr_X, curr_Y)):\n",
        "            continue\n",
        "        if prev_X.shape != curr_X.shape or prev_Y.shape != curr_Y.shape:\n",
        "            continue\n",
        "        B = curr_X.size(0)\n",
        "        if B < 1:\n",
        "            continue\n",
        "        # Upcast first: activations captured under fp16 autocast can overflow when the\n",
        "        # cosine norms are computed in half precision.\n",
        "        dX = (curr_X.float() - prev_X.float()).reshape(B, -1)\n",
        "        dY = (curr_Y.float() - prev_Y.float()).reshape(B, -1)\n",
        "        if torch.isnan(dX).any() or torch.isnan(dY).any():\n",
        "            continue\n",
        "        # Score every sample (index 0 included) in one vectorized call, avoiding a\n",
        "        # device-to-host sync per sample.\n",
        "        cos_sq = F.cosine_similarity(dY, dX, dim=1, eps=1e-8) ** 2\n",
        "        cos_sq = cos_sq[torch.isfinite(cos_sq)]\n",
        "        if cos_sq.numel() == 0:\n",
        "            continue\n",
        "        er = cos_sq.mean().item()\n",
        "        if math.isfinite(er):\n",
        "            er_scores[i] = er\n",
        "    return er_scores\n",
        "\n",
        "# --- 2. Pruning Utilities ---\n",
        "class SkipBlock(nn.Module):\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "    def forward(self, hidden_states, *args, **kwargs):\n",
        "        return (hidden_states, None, None, None, None, None)\n",
        "\n",
        "def prune_er_layers(blocks, er_scores, num_prune=2, hidden_size=768):\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx+1 for idx, _ in sorted_layers[:num_prune] if idx+1 < len(blocks)]\n",
        "    for idx in prune_idxs:\n",
        "        blocks[idx] = SkipBlock(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "# --- 3. Data Processing ---\n",
        "def make_t5_qnli_prompt(question, sentence):\n",
        "    return f\"qnli question: {question} sentence: {sentence}\"\n",
        "\n",
        "def preprocess_function(batch, tokenizer, max_input_length=128, max_target_length=8):\n",
        "    inputs = [make_t5_qnli_prompt(q, s) for q, s in zip(batch['question'], batch['sentence'])]\n",
        "    model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=max_input_length)\n",
        "    label_list = [\"entailment\", \"not_entailment\"]  # 0: entailment, 1: not_entailment\n",
        "    labels = [label_list[x] if (isinstance(x, int) and x in {0, 1}) else \"not_entailment\" for x in batch['label']]\n",
        "    with tokenizer.as_target_tokenizer():\n",
        "        targets = tokenizer(labels, padding=\"max_length\", truncation=True, max_length=max_target_length)\n",
        "    target_ids = targets[\"input_ids\"]\n",
        "    # Set padding tokens to -100 (ignored by loss)\n",
        "    target_ids = [\n",
        "        [tok if tok != tokenizer.pad_token_id else -100 for tok in label_ids]\n",
        "        for label_ids in target_ids\n",
        "    ]\n",
        "    model_inputs[\"labels\"] = target_ids\n",
        "    return model_inputs\n",
        "\n",
        "def evaluate_model(model, dl, tokenizer, device, label_texts):\n",
        "    model.eval()\n",
        "    preds, refs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch[\"input_ids\"].to(device)\n",
        "            attention_mask = batch[\"attention_mask\"].to(device)\n",
        "            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=2)\n",
        "            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
        "            label_ids = batch[\"labels\"].clone()\n",
        "            label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
        "            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
        "            pred_texts = [p.strip().lower() for p in pred_texts]\n",
        "            ref_texts = [r.strip().lower() for r in ref_texts]\n",
        "            for p, r in zip(pred_texts, ref_texts):\n",
        "                if p == \"entailment\":\n",
        "                    preds.append(0)\n",
        "                elif p == \"not_entailment\":\n",
        "                    preds.append(1)\n",
        "                else:\n",
        "                    preds.append(1)  # fallback\n",
        "                if r == \"entailment\":\n",
        "                    refs.append(0)\n",
        "                elif r == \"not_entailment\":\n",
        "                    refs.append(1)\n",
        "                else:\n",
        "                    refs.append(1)\n",
        "    acc = sum([int(p == r) for p, r in zip(preds, refs)]) / len(preds)\n",
        "    print(\"Sample model outputs:\", pred_texts[:5])\n",
        "    print(\"Sample true labels:\", ref_texts[:5])\n",
        "    return acc\n",
        "\n",
        "# --- 4. Training Loops ---\n",
        "def _accumulate_er(prev_acts, curr_acts, er_sums, er_counts):\n",
        "    \"\"\"Add one batch of conditional-ER scores to running sums; no-op without a previous snapshot.\"\"\"\n",
        "    if prev_acts is None:\n",
        "        return\n",
        "    for idx, v in compute_conditional_batch_entropy(prev_acts, curr_acts).items():\n",
        "        er_sums[idx] += v\n",
        "        er_counts[idx] += 1\n",
        "\n",
        "def _snapshot_acts(acts):\n",
        "    \"\"\"Copy the captured activations (tensors cloned) for comparison with the next batch.\"\"\"\n",
        "    return {i: t.clone() if t is not None else None for i, t in acts.items()}\n",
        "\n",
        "def _mean_er(er_sums, er_counts):\n",
        "    \"\"\"Average accumulated ER sums per layer index, skipping indices never counted.\"\"\"\n",
        "    return {idx: er_sums[idx] / er_counts[idx] for idx in er_sums if er_counts[idx] > 0}\n",
        "\n",
        "def full_finetuning(train_loader, dev_loader, device, tokenizer, label_texts, num_epochs=6):\n",
        "    \"\"\"Stage 1: fine-tune t5-base while estimating per-layer conditional ER.\n",
        "\n",
        "    Forward hooks record encoder-block, decoder-block and decoder cross-attention\n",
        "    outputs for every training batch. Consecutive batches are compared with\n",
        "    ``compute_conditional_batch_entropy`` and the scores are averaged per epoch;\n",
        "    dev accuracy is printed after each epoch.\n",
        "\n",
        "    Args:\n",
        "        train_loader, dev_loader: DataLoaders yielding ``input_ids``,\n",
        "            ``attention_mask`` and ``labels``.\n",
        "        device: torch device (or device string) to train on.\n",
        "        tokenizer, label_texts: forwarded to ``evaluate_model``.\n",
        "        num_epochs: number of training epochs; the LR schedule spans all of them.\n",
        "\n",
        "    Returns:\n",
        "        (model, enc_er, dec_er, cross_er); the ER dicts come from the last epoch.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Fine-Tuning & Conditional ER Estimation ===\")\n",
        "    model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
        "    # Mixed precision is only enabled on CUDA; elsewhere autocast/GradScaler are no-ops.\n",
        "    use_amp = torch.device(device).type == \"cuda\"\n",
        "    scaler = GradScaler(enabled=use_amp)\n",
        "    # The schedule must span every epoch: the former 3-epoch horizon drove the\n",
        "    # learning rate to zero halfway through the 6-epoch run.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts) = register_conditional_er_hooks(model)\n",
        "    last_enc_er, last_dec_er, last_cross_er = None, None, None\n",
        "\n",
        "    try:\n",
        "        for epoch in range(num_epochs):\n",
        "            enc_er_sums, enc_er_counts = defaultdict(float), defaultdict(int)\n",
        "            dec_er_sums, dec_er_counts = defaultdict(float), defaultdict(int)\n",
        "            cross_er_sums, cross_er_counts = defaultdict(float), defaultdict(int)\n",
        "            model.train()\n",
        "            prev_enc_acts, prev_dec_acts, prev_cross_acts = None, None, None\n",
        "            for batch in train_loader:\n",
        "                opt.zero_grad()\n",
        "                with autocast(enabled=use_amp):\n",
        "                    outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                                    attention_mask=batch['attention_mask'].to(device),\n",
        "                                    labels=batch['labels'].to(device))\n",
        "                    loss = outputs.loss\n",
        "                # Backward runs outside the autocast region, as the AMP docs recommend.\n",
        "                scaler.scale(loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "                _accumulate_er(prev_enc_acts, enc_acts, enc_er_sums, enc_er_counts)\n",
        "                _accumulate_er(prev_dec_acts, dec_acts, dec_er_sums, dec_er_counts)\n",
        "                _accumulate_er(prev_cross_acts, cross_acts, cross_er_sums, cross_er_counts)\n",
        "                prev_enc_acts = _snapshot_acts(enc_acts)\n",
        "                prev_dec_acts = _snapshot_acts(dec_acts)\n",
        "                prev_cross_acts = _snapshot_acts(cross_acts)\n",
        "            epoch_enc_er = _mean_er(enc_er_sums, enc_er_counts)\n",
        "            epoch_dec_er = _mean_er(dec_er_sums, dec_er_counts)\n",
        "            epoch_cross_er = _mean_er(cross_er_sums, cross_er_counts)\n",
        "            print(f\"[Epoch {epoch+1}] approx Encoder Conditional ER: {epoch_enc_er}\")\n",
        "            print(f\"[Epoch {epoch+1}] approx Decoder Conditional ER: {epoch_dec_er}\")\n",
        "            print(f\"[Epoch {epoch+1}] approx Cross-Attention Conditional ER: {epoch_cross_er}\")\n",
        "            acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)\n",
        "            print(f\"[Epoch {epoch+1}] Dev Acc: {acc:.4f}\")\n",
        "            last_enc_er, last_dec_er, last_cross_er = epoch_enc_er, epoch_dec_er, epoch_cross_er\n",
        "    finally:\n",
        "        # Detach the hooks even if training is interrupted, so they stop recording.\n",
        "        remove_hooks([(enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)])\n",
        "    return model, last_enc_er, last_dec_er, last_cross_er\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, enc_er_scores, dec_er_scores, cross_er_scores, tokenizer, label_texts, num_epochs=5):\n",
        "    \"\"\"Stage 2: replace the highest-ER decoder blocks with SkipBlocks and fine-tune.\n",
        "\n",
        "    Only ``dec_er_scores`` drives pruning (up to 4 decoder blocks are replaced);\n",
        "    ``enc_er_scores`` and ``cross_er_scores`` are accepted but unused. Dev\n",
        "    accuracy is printed after every epoch.\n",
        "\n",
        "    Args:\n",
        "        num_epochs: number of fine-tuning epochs; the LR schedule spans all of them.\n",
        "\n",
        "    Returns:\n",
        "        The pruned, fine-tuned model (also modified in place).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High-ER) & Fine-tuning ===\")\n",
        "    # `or {}` tolerates None scores (e.g. stage 1 ran zero epochs): nothing is pruned then.\n",
        "    dec_prune_idxs = prune_er_layers(model.decoder.block, dec_er_scores or {}, num_prune=4, hidden_size=model.config.d_model)\n",
        "    print(\"Pruned decoder layers (highest ER):\", dec_prune_idxs)\n",
        "    # Built after pruning, so the optimizer only tracks parameters still in the model.\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=5e-4)\n",
        "    # Span every epoch; the former 2-epoch horizon left epochs 3-5 at a zero learning rate.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                            attention_mask=batch['attention_mask'].to(device),\n",
        "                            labels=batch['labels'].to(device))\n",
        "            loss = outputs.loss\n",
        "            loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] QNLI Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 5. Main Entrypoint ---\n",
        "def main():\n",
        "    \"\"\"Run the QNLI pipeline: fine-tune with ER estimation, then prune and fine-tune again.\"\"\"\n",
        "    # DataCollatorForSeq2Seq is not in this cell's import list; without this\n",
        "    # local import main() stops with a NameError on a fresh kernel.\n",
        "    from transformers import DataCollatorForSeq2Seq\n",
        "\n",
        "    # Seed the RNGs behind dropout and DataLoader shuffling so runs are repeatable.\n",
        "    random.seed(42)\n",
        "    np.random.seed(42)\n",
        "    torch.manual_seed(42)\n",
        "\n",
        "    raw_datasets = load_dataset(\"glue\", \"qnli\")\n",
        "    tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "    label_texts = [\"entailment\", \"not_entailment\"]\n",
        "\n",
        "    train_ds = raw_datasets[\"train\"].shuffle(seed=42).select(range(5000))\n",
        "    dev_ds = raw_datasets[\"validation\"].shuffle(seed=42)\n",
        "\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True, remove_columns=train_ds.column_names)\n",
        "    dev = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                     batched=True, remove_columns=dev_ds.column_names)\n",
        "\n",
        "    collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding=\"max_length\", max_length=128)\n",
        "    train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    model, enc_er_scores, dec_er_scores, cross_er_scores = full_finetuning(\n",
        "        train_loader, dev_loader, device, tokenizer, label_texts)\n",
        "    model = prune_and_finetuning(\n",
        "        model, train_loader, dev_loader, device,\n",
        "        enc_er_scores, dec_er_scores, cross_er_scores,\n",
        "        tokenizer, label_texts)\n",
        "\n",
        "# In a notebook kernel __name__ is \"__main__\", so running this cell runs the pipeline.\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "Z2UJf5hTiAE_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict\n",
        "import warnings\n",
        "import math\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# --- 1. Conditional ER Hook Utilities ---\n",
        "def register_conditional_er_hooks(model):\n",
        "    enc_layers = model.encoder.block\n",
        "    enc_acts = {i: None for i in range(len(enc_layers))}\n",
        "    enc_hooks = []\n",
        "    for i, layer in enumerate(enc_layers):\n",
        "        def hook_fn_enc(module, inp, out, idx=i):\n",
        "            hs = out[0] if isinstance(out, tuple) else out\n",
        "            enc_acts[idx] = hs.detach()\n",
        "        enc_hooks.append(layer.register_forward_hook(hook_fn_enc))\n",
        "    dec_layers = model.decoder.block\n",
        "    dec_acts = {i: None for i in range(len(dec_layers))}\n",
        "    dec_hooks = []\n",
        "    for i, layer in enumerate(dec_layers):\n",
        "        def hook_fn_dec(module, inp, out, idx=i):\n",
        "            hs = out[0] if isinstance(out, tuple) else out\n",
        "            dec_acts[idx] = hs.detach()\n",
        "        dec_hooks.append(layer.register_forward_hook(hook_fn_dec))\n",
        "    cross_acts = {i: None for i in range(len(dec_layers))}\n",
        "    cross_hooks = []\n",
        "    for i, block in enumerate(dec_layers):\n",
        "        def hook_fn_cross(module, inp, out, idx=i):\n",
        "            hs = out[0] if isinstance(out, tuple) else out\n",
        "            cross_acts[idx] = hs.detach()\n",
        "        cross_attn = block.layer[1]\n",
        "        cross_hooks.append(cross_attn.register_forward_hook(hook_fn_cross))\n",
        "    return (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)\n",
        "\n",
        "def remove_hooks(hook_sets):\n",
        "    for hooks, _ in hook_sets:\n",
        "        for h in hooks:\n",
        "            h.remove()\n",
        "\n",
        "def compute_conditional_batch_entropy(prev_acts, curr_acts):\n",
        "    er_scores = {}\n",
        "    for i in range(len(curr_acts) - 1):\n",
        "        prev_X, prev_Y = prev_acts[i], prev_acts[i+1]\n",
        "        curr_X, curr_Y = curr_acts[i], curr_acts[i+1]\n",
        "        if (\n",
        "            prev_X is not None and prev_Y is not None and\n",
        "            curr_X is not None and curr_Y is not None and\n",
        "            prev_X.shape == curr_X.shape and\n",
        "            prev_Y.shape == curr_Y.shape\n",
        "        ):\n",
        "            B = curr_X.size(0)\n",
        "            dX = (curr_X - prev_X).view(B, -1)\n",
        "            dY = (curr_Y - prev_Y).view(B, -1)\n",
        "            if B >= 2 and not (torch.isnan(dX).any() or torch.isnan(dY).any()):\n",
        "                cos_squares = [\n",
        "                    F.cosine_similarity(dY[j].unsqueeze(0), dX[j].unsqueeze(0), dim=1, eps=1e-8).item() ** 2\n",
        "                    for j in range(1, B)\n",
        "                    if not (torch.isnan(dX[j]).any() or torch.isnan(dY[j]).any())\n",
        "                ]\n",
        "                if cos_squares:\n",
        "                    er = sum(cos_squares) / len(cos_squares)\n",
        "                    if not (math.isnan(er) or math.isinf(er)):\n",
        "                        er_scores[i] = er\n",
        "    return er_scores\n",
        "\n",
        "# --- 2. Pruning Utilities ---\n",
        "class SkipBlock(nn.Module):\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "    def forward(self, hidden_states, *args, **kwargs):\n",
        "        return (hidden_states, None, None, None, None, None)\n",
        "\n",
        "def prune_er_layers(blocks, er_scores, num_prune=2, hidden_size=768):\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx+1 for idx, _ in sorted_layers[:num_prune] if idx+1 < len(blocks)]\n",
        "    for idx in prune_idxs:\n",
        "        blocks[idx] = SkipBlock(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "# --- 3. Data Processing ---\n",
        "def make_t5_qqp_prompt(question1, question2):\n",
        "    return f\"qqp question1: {question1} question2: {question2}\"\n",
        "\n",
        "def preprocess_function(batch, tokenizer, max_input_length=128, max_target_length=8):\n",
        "    # QQP: label 0=not_duplicate, 1=duplicate\n",
        "    inputs = [make_t5_qqp_prompt(q1, q2) for q1, q2 in zip(batch['question1'], batch['question2'])]\n",
        "    model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=max_input_length)\n",
        "    label_list = [\"not_duplicate\", \"duplicate\"]\n",
        "    labels = [label_list[x] if (isinstance(x, int) and x in {0, 1}) else \"not_duplicate\" for x in batch['label']]\n",
        "    with tokenizer.as_target_tokenizer():\n",
        "        targets = tokenizer(labels, padding=\"max_length\", truncation=True, max_length=max_target_length)\n",
        "    # Mask out pad tokens as -100 for loss\n",
        "    target_ids = [\n",
        "        [tok if tok != tokenizer.pad_token_id else -100 for tok in label_ids]\n",
        "        for label_ids in targets[\"input_ids\"]\n",
        "    ]\n",
        "    model_inputs[\"labels\"] = target_ids\n",
        "    return model_inputs\n",
        "\n",
        "def compute_accuracy(preds, refs):\n",
        "    correct = 0\n",
        "    for p, l in zip(preds, refs):\n",
        "        if p == l:\n",
        "            correct += 1\n",
        "    return correct / len(preds)\n",
        "\n",
        "def evaluate_model(model, dl, tokenizer, device, label_texts):\n",
        "    model.eval()\n",
        "    preds, refs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch[\"input_ids\"].to(device)\n",
        "            attention_mask = batch[\"attention_mask\"].to(device)\n",
        "            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=2)\n",
        "            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
        "            label_ids = batch[\"labels\"].clone()\n",
        "            label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
        "            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
        "            pred_texts = [p.strip().lower() for p in pred_texts]\n",
        "            ref_texts = [r.strip().lower() for r in ref_texts]\n",
        "            # Map predictions and references to class index (0=not_duplicate, 1=duplicate)\n",
        "            for p, r in zip(pred_texts, ref_texts):\n",
        "                if p == \"duplicate\":\n",
        "                    preds.append(1)\n",
        "                elif p == \"not_duplicate\":\n",
        "                    preds.append(0)\n",
        "                else:\n",
        "                    preds.append(0)  # fallback\n",
        "                if r == \"duplicate\":\n",
        "                    refs.append(1)\n",
        "                elif r == \"not_duplicate\":\n",
        "                    refs.append(0)\n",
        "                else:\n",
        "                    refs.append(0)\n",
        "    acc = sum([int(p == r) for p, r in zip(preds, refs)]) / len(preds)\n",
        "    print(\"Sample model outputs:\", pred_texts[:5])\n",
        "    print(\"Sample true labels:\", ref_texts[:5])\n",
        "    return acc\n",
        "\n",
        "# --- 4. Training Loops ---\n",
        "def _accumulate_er(prev_acts, curr_acts, er_sums, er_counts):\n",
        "    \"\"\"Add one batch of conditional-ER scores to running sums; no-op without a previous snapshot.\"\"\"\n",
        "    if prev_acts is None:\n",
        "        return\n",
        "    for idx, v in compute_conditional_batch_entropy(prev_acts, curr_acts).items():\n",
        "        er_sums[idx] += v\n",
        "        er_counts[idx] += 1\n",
        "\n",
        "def _snapshot_acts(acts):\n",
        "    \"\"\"Copy the captured activations (tensors cloned) for comparison with the next batch.\"\"\"\n",
        "    return {i: t.clone() if t is not None else None for i, t in acts.items()}\n",
        "\n",
        "def _mean_er(er_sums, er_counts):\n",
        "    \"\"\"Average accumulated ER sums per layer index, skipping indices never counted.\"\"\"\n",
        "    return {idx: er_sums[idx] / er_counts[idx] for idx in er_sums if er_counts[idx] > 0}\n",
        "\n",
        "def _clear_acts(*act_dicts):\n",
        "    \"\"\"Reset every captured activation to None before the next forward pass.\"\"\"\n",
        "    for acts in act_dicts:\n",
        "        for k in acts:\n",
        "            acts[k] = None\n",
        "\n",
        "def full_finetuning(train_loader, dev_loader, device, tokenizer, label_texts, num_epochs=6):\n",
        "    \"\"\"Stage 1: fine-tune t5-base while estimating per-layer conditional ER.\n",
        "\n",
        "    Forward hooks record encoder-block, decoder-block and decoder cross-attention\n",
        "    outputs for every training batch. Consecutive complete batches are compared\n",
        "    with ``compute_conditional_batch_entropy`` and the scores are averaged per\n",
        "    epoch; dev accuracy is printed after each epoch.\n",
        "\n",
        "    Args:\n",
        "        train_loader, dev_loader: DataLoaders yielding ``input_ids``,\n",
        "            ``attention_mask`` and ``labels``.\n",
        "        device: torch device (or device string) to train on.\n",
        "        tokenizer, label_texts: forwarded to ``evaluate_model``.\n",
        "        num_epochs: number of training epochs; the LR schedule spans all of them.\n",
        "\n",
        "    Returns:\n",
        "        (model, enc_er, dec_er, cross_er); the ER dicts come from the last epoch.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Fine-Tuning & Conditional ER Estimation ===\")\n",
        "    model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
        "    # Mixed precision is only enabled on CUDA; elsewhere autocast/GradScaler are no-ops.\n",
        "    use_amp = torch.device(device).type == \"cuda\"\n",
        "    scaler = GradScaler(enabled=use_amp)\n",
        "    # The schedule must span every epoch: the former 3-epoch horizon drove the\n",
        "    # learning rate to zero halfway through the 6-epoch run.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts) = register_conditional_er_hooks(model)\n",
        "    last_enc_er, last_dec_er, last_cross_er = None, None, None\n",
        "\n",
        "    try:\n",
        "        for epoch in range(num_epochs):\n",
        "            enc_er_sums, enc_er_counts = defaultdict(float), defaultdict(int)\n",
        "            dec_er_sums, dec_er_counts = defaultdict(float), defaultdict(int)\n",
        "            cross_er_sums, cross_er_counts = defaultdict(float), defaultdict(int)\n",
        "            model.train()\n",
        "            prev_enc_acts, prev_dec_acts, prev_cross_acts = None, None, None\n",
        "            for batch in train_loader:\n",
        "                _clear_acts(enc_acts, dec_acts, cross_acts)\n",
        "                opt.zero_grad()\n",
        "                with autocast(enabled=use_amp):\n",
        "                    outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                                    attention_mask=batch['attention_mask'].to(device),\n",
        "                                    labels=batch['labels'].to(device))\n",
        "                    loss = outputs.loss\n",
        "                # Backward runs outside the autocast region, as the AMP docs recommend.\n",
        "                scaler.scale(loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "\n",
        "                # A hooked layer that did not fire leaves None: skip ER for this batch\n",
        "                # and restart the comparison chain from the next complete batch.\n",
        "                if any(v is None for acts in (enc_acts, dec_acts, cross_acts) for v in acts.values()):\n",
        "                    prev_enc_acts = prev_dec_acts = prev_cross_acts = None\n",
        "                    continue\n",
        "\n",
        "                _accumulate_er(prev_enc_acts, enc_acts, enc_er_sums, enc_er_counts)\n",
        "                _accumulate_er(prev_dec_acts, dec_acts, dec_er_sums, dec_er_counts)\n",
        "                _accumulate_er(prev_cross_acts, cross_acts, cross_er_sums, cross_er_counts)\n",
        "                prev_enc_acts = _snapshot_acts(enc_acts)\n",
        "                prev_dec_acts = _snapshot_acts(dec_acts)\n",
        "                prev_cross_acts = _snapshot_acts(cross_acts)\n",
        "\n",
        "            epoch_enc_er = _mean_er(enc_er_sums, enc_er_counts)\n",
        "            epoch_dec_er = _mean_er(dec_er_sums, dec_er_counts)\n",
        "            epoch_cross_er = _mean_er(cross_er_sums, cross_er_counts)\n",
        "            print(f\"[Epoch {epoch+1}] approx Encoder Conditional ER: {epoch_enc_er}\")\n",
        "            print(f\"[Epoch {epoch+1}] approx Decoder Conditional ER: {epoch_dec_er}\")\n",
        "            print(f\"[Epoch {epoch+1}] approx Cross-Attention Conditional ER: {epoch_cross_er}\")\n",
        "            acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)\n",
        "            print(f\"[Epoch {epoch+1}] Dev Acc: {acc:.4f}\")\n",
        "            last_enc_er, last_dec_er, last_cross_er = epoch_enc_er, epoch_dec_er, epoch_cross_er\n",
        "    finally:\n",
        "        # Detach the hooks even if training is interrupted, so they stop recording.\n",
        "        remove_hooks([(enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)])\n",
        "    return model, last_enc_er, last_dec_er, last_cross_er\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, enc_er_scores, dec_er_scores, cross_er_scores, tokenizer, label_texts, num_epochs=5):\n",
        "    \"\"\"Stage 2: replace the highest-ER decoder blocks with SkipBlocks and fine-tune.\n",
        "\n",
        "    Only ``dec_er_scores`` drives pruning (up to 4 decoder blocks are replaced);\n",
        "    ``enc_er_scores`` and ``cross_er_scores`` are accepted but unused. Dev\n",
        "    accuracy is printed after every epoch.\n",
        "\n",
        "    Args:\n",
        "        num_epochs: number of fine-tuning epochs; the LR schedule spans all of them.\n",
        "\n",
        "    Returns:\n",
        "        The pruned, fine-tuned model (also modified in place).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High-ER) & Fine-tuning ===\")\n",
        "    # `or {}` tolerates None scores (e.g. stage 1 ran zero epochs): nothing is pruned then.\n",
        "    dec_prune_idxs = prune_er_layers(model.decoder.block, dec_er_scores or {}, num_prune=4, hidden_size=model.config.d_model)\n",
        "    print(\"Pruned decoder layers (highest ER):\", dec_prune_idxs)\n",
        "    # Built after pruning, so the optimizer only tracks parameters still in the model.\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=5e-4)\n",
        "    # Span every epoch; the former 2-epoch horizon left epochs 3-5 at a zero learning rate.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                            attention_mask=batch['attention_mask'].to(device),\n",
        "                            labels=batch['labels'].to(device))\n",
        "            loss = outputs.loss\n",
        "            loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] QQP Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 5. Main Entrypoint ---\n",
        "\n",
        "def main():\n",
        "    \"\"\"Run the QQP pipeline: fine-tune with ER estimation, then prune and fine-tune again.\"\"\"\n",
        "    # Seed torch (dropout, DataLoader shuffling) so runs are repeatable.\n",
        "    torch.manual_seed(42)\n",
        "\n",
        "    # Use Hugging Face's QQP (GLUE)\n",
        "    raw_datasets = load_dataset(\"glue\", \"qqp\")\n",
        "    tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "    label_texts = [\"not_duplicate\", \"duplicate\"]\n",
        "\n",
        "    # Use \"train\" and \"validation\" splits\n",
        "    train_ds = raw_datasets[\"train\"].shuffle(seed=42).select(range(5000))\n",
        "    # NOTE(review): the full validation split is decoded after every epoch of both\n",
        "    # stages; subsample dev_ds here for faster iteration if needed.\n",
        "    dev_ds = raw_datasets[\"validation\"].shuffle(seed=42)\n",
        "\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True, remove_columns=train_ds.column_names)\n",
        "    dev = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                     batched=True, remove_columns=dev_ds.column_names)\n",
        "\n",
        "    collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding=\"max_length\", max_length=128)\n",
        "    train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    model, enc_er_scores, dec_er_scores, cross_er_scores = full_finetuning(\n",
        "        train_loader, dev_loader, device, tokenizer, label_texts)\n",
        "    model = prune_and_finetuning(\n",
        "        model, train_loader, dev_loader, device,\n",
        "        enc_er_scores, dec_er_scores, cross_er_scores,\n",
        "        tokenizer, label_texts)\n",
        "\n",
        "# In a notebook kernel __name__ is \"__main__\", so running this cell runs the pipeline.\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "4I6tZWUun9hV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict\n",
        "import warnings\n",
        "import math\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# --- 1. Conditional ER Hook Utilities ---\n",
        "def register_conditional_er_hooks(model):\n",
        "    enc_layers = model.encoder.block\n",
        "    enc_acts = {i: None for i in range(len(enc_layers))}\n",
        "    enc_hooks = []\n",
        "    for i, layer in enumerate(enc_layers):\n",
        "        def hook_fn_enc(module, inp, out, idx=i):\n",
        "            hs = out[0] if isinstance(out, tuple) else out\n",
        "            enc_acts[idx] = hs.detach()\n",
        "        enc_hooks.append(layer.register_forward_hook(hook_fn_enc))\n",
        "    dec_layers = model.decoder.block\n",
        "    dec_acts = {i: None for i in range(len(dec_layers))}\n",
        "    dec_hooks = []\n",
        "    for i, layer in enumerate(dec_layers):\n",
        "        def hook_fn_dec(module, inp, out, idx=i):\n",
        "            hs = out[0] if isinstance(out, tuple) else out\n",
        "            dec_acts[idx] = hs.detach()\n",
        "        dec_hooks.append(layer.register_forward_hook(hook_fn_dec))\n",
        "    cross_acts = {i: None for i in range(len(dec_layers))}\n",
        "    cross_hooks = []\n",
        "    for i, block in enumerate(dec_layers):\n",
        "        def hook_fn_cross(module, inp, out, idx=i):\n",
        "            hs = out[0] if isinstance(out, tuple) else out\n",
        "            cross_acts[idx] = hs.detach()\n",
        "        cross_attn = block.layer[1]\n",
        "        cross_hooks.append(cross_attn.register_forward_hook(hook_fn_cross))\n",
        "    return (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)\n",
        "\n",
        "def remove_hooks(hook_sets):\n",
        "    for hooks, _ in hook_sets:\n",
        "        for h in hooks:\n",
        "            h.remove()\n",
        "\n",
        "def reset_activation_dict(d):\n",
        "    for k in d:\n",
        "        d[k] = None\n",
        "\n",
        "def compute_conditional_batch_entropy(prev_acts, curr_acts):\n",
        "    er_scores = {}\n",
        "    for i in range(len(curr_acts) - 1):\n",
        "        prev_X, prev_Y = prev_acts[i], prev_acts[i+1]\n",
        "        curr_X, curr_Y = curr_acts[i], curr_acts[i+1]\n",
        "        if (\n",
        "            prev_X is not None and prev_Y is not None and\n",
        "            curr_X is not None and curr_Y is not None and\n",
        "            prev_X.shape == curr_X.shape and\n",
        "            prev_Y.shape == curr_Y.shape\n",
        "        ):\n",
        "            B = curr_X.size(0)\n",
        "            dX = (curr_X - prev_X).view(B, -1)\n",
        "            dY = (curr_Y - prev_Y).view(B, -1)\n",
        "            if B >= 2 and not (torch.isnan(dX).any() or torch.isnan(dY).any()):\n",
        "                cos_squares = [\n",
        "                    F.cosine_similarity(dY[j].unsqueeze(0), dX[j].unsqueeze(0), dim=1, eps=1e-8).item() ** 2\n",
        "                    for j in range(1, B)\n",
        "                    if not (torch.isnan(dX[j]).any() or torch.isnan(dY[j]).any())\n",
        "                ]\n",
        "                if cos_squares:\n",
        "                    er = sum(cos_squares) / len(cos_squares)\n",
        "                    if not (math.isnan(er) or math.isinf(er)):\n",
        "                        er_scores[i] = er\n",
        "    return er_scores\n",
        "\n",
        "# --- 2. Pruning Utilities ---\n",
        "class SkipBlock(nn.Module):\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "    def forward(self, hidden_states, *args, **kwargs):\n",
        "        return (hidden_states, None, None, None, None, None)\n",
        "\n",
        "def prune_er_layers(blocks, er_scores, num_prune=2, hidden_size=768):\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx+1 for idx, _ in sorted_layers[:num_prune] if idx+1 < len(blocks)]\n",
        "    for idx in prune_idxs:\n",
        "        blocks[idx] = SkipBlock(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "# --- 3. Data Processing ---\n",
        "def make_t5_rte_prompt(premise, hypothesis):\n",
        "    return f\"rte premise: {premise} hypothesis: {hypothesis}\"\n",
        "\n",
        "def preprocess_function(batch, tokenizer, max_input_length=128, max_target_length=8):\n",
        "    \"\"\"Tokenize a batch of GLUE RTE examples for T5 text-to-text training.\n",
        "\n",
        "    Inputs are rendered with ``make_t5_rte_prompt``. Integer labels map to\n",
        "    target words: 0 -> \"entailment\", 1 -> \"not_entailment\". Any other label\n",
        "    value falls back to \"not_entailment\".\n",
        "\n",
        "    Target padding tokens are replaced by -100, the index the seq2seq\n",
        "    cross-entropy ignores. Without this, every target padded to\n",
        "    ``max_target_length`` also trains the decoder to emit pad tokens.\n",
        "\n",
        "    Returns:\n",
        "        The tokenizer output for the inputs, plus a ``\"labels\"`` list of lists.\n",
        "    \"\"\"\n",
        "    inputs = [make_t5_rte_prompt(p, h) for p, h in zip(batch['sentence1'], batch['sentence2'])]\n",
        "    model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=max_input_length)\n",
        "    label_list = [\"entailment\", \"not_entailment\"]\n",
        "    labels = [label_list[x] if (isinstance(x, int) and x in {0, 1}) else \"not_entailment\" for x in batch['label']]\n",
        "    target = tokenizer(labels, padding=\"max_length\", truncation=True, max_length=max_target_length)\n",
        "    pad_id = tokenizer.pad_token_id\n",
        "    model_inputs[\"labels\"] = [\n",
        "        [tok if tok != pad_id else -100 for tok in seq] for seq in target[\"input_ids\"]\n",
        "    ]\n",
        "    return model_inputs\n",
        "\n",
        "def compute_accuracy(preds, refs):\n",
        "    correct = 0\n",
        "    for p, l in zip(preds, refs):\n",
        "        if p == l:\n",
        "            correct += 1\n",
        "    return correct / len(preds)\n",
        "\n",
        "def evaluate_model(model, dl, tokenizer, label_texts):\n",
        "    model.eval()\n",
        "    preds, refs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch[\"input_ids\"].to(model.device)\n",
        "            attention_mask = batch[\"attention_mask\"].to(model.device)\n",
        "            # Use more tokens for output\n",
        "            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=4)\n",
        "            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
        "            label_ids = batch[\"labels\"].clone()\n",
        "            label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
        "            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
        "            for p, r in zip(pred_texts, ref_texts):\n",
        "                p = p.strip().lower()\n",
        "                r = r.strip().lower()\n",
        "                # Map prediction to class\n",
        "                if p.startswith(\"entail\"):\n",
        "                    preds.append(0)\n",
        "                elif p.startswith(\"not\") or \"not\" in p:\n",
        "                    preds.append(1)\n",
        "                else:\n",
        "                    # Fallback: assign to most common class\n",
        "                    preds.append(1)\n",
        "                # Map reference to class\n",
        "                if r.startswith(\"entail\"):\n",
        "                    refs.append(0)\n",
        "                elif r.startswith(\"not\") or \"not\" in r:\n",
        "                    refs.append(1)\n",
        "                else:\n",
        "                    refs.append(1)\n",
        "    acc = sum([int(p == r) for p, r in zip(preds, refs)]) / len(preds)\n",
        "    print(\"Sample model outputs:\", pred_texts[:5])\n",
        "    print(\"Sample true labels:\", ref_texts[:5])\n",
        "    print(\"Sample mapped preds:\", preds[:5], \"refs:\", refs[:5])\n",
        "    return acc\n",
        "\n",
        "\n",
        "\n",
        "# --- 4. Training Loops ---\n",
        "def _accumulate_er(sums, counts, batch_scores):\n",
        "    \"\"\"Add one batch's per-pair ER scores into running ``sums`` / ``counts``.\"\"\"\n",
        "    for idx, value in batch_scores.items():\n",
        "        sums[idx] += value\n",
        "        counts[idx] += 1\n",
        "\n",
        "def _mean_er(sums, counts):\n",
        "    \"\"\"Per-pair mean of accumulated ER scores; pairs with a zero count are skipped.\"\"\"\n",
        "    return {idx: sums[idx] / counts[idx] for idx in sums if counts[idx] > 0}\n",
        "\n",
        "def full_finetuning(train_loader, dev_loader, device, tokenizer, label_texts, num_epochs=6):\n",
        "    \"\"\"Stage 1: fine-tune t5-base on RTE while estimating conditional ER per layer pair.\n",
        "\n",
        "    Hooks from ``register_conditional_er_hooks`` capture encoder, decoder and\n",
        "    cross-attention activations on every training step. Consecutive steps are\n",
        "    compared with ``compute_conditional_batch_entropy`` and the scores are\n",
        "    averaged per epoch. Dev accuracy is printed after every epoch.\n",
        "\n",
        "    Args:\n",
        "        train_loader, dev_loader: DataLoaders of tokenized RTE batches.\n",
        "        device: torch device (or device string) to train on.\n",
        "        tokenizer, label_texts: forwarded to ``evaluate_model``.\n",
        "        num_epochs: training epochs; the linear LR schedule spans exactly these.\n",
        "\n",
        "    Returns:\n",
        "        ``(model, enc_er, dec_er, cross_er)``. Each ``*_er`` dict holds the last\n",
        "        epoch's mean scores, or ``None`` if ``num_epochs`` is 0.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Fine-Tuning & Conditional ER Estimation ===\")\n",
        "    model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
        "    # AMP only applies on CUDA; disable it explicitly elsewhere instead of\n",
        "    # relying on autocast/GradScaler fallback warnings.\n",
        "    use_amp = torch.device(device).type == \"cuda\"\n",
        "    scaler = GradScaler(enabled=use_amp)\n",
        "    # The schedule must cover every optimizer step actually taken. A shorter\n",
        "    # horizon drives the LR to 0 early, and the remaining epochs learn nothing.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts) = register_conditional_er_hooks(model)\n",
        "    act_groups = (enc_acts, dec_acts, cross_acts)\n",
        "    last_enc_er, last_dec_er, last_cross_er = None, None, None\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        sums = [defaultdict(float) for _ in act_groups]\n",
        "        counts = [defaultdict(int) for _ in act_groups]\n",
        "        model.train()\n",
        "        # Snapshot of the (enc, dec, cross) activations from the previous step.\n",
        "        prev_groups = None\n",
        "        for batch in train_loader:\n",
        "            # Clear captures so a hook that did not fire is detected below.\n",
        "            for acts in act_groups:\n",
        "                for k in acts:\n",
        "                    acts[k] = None\n",
        "            opt.zero_grad()\n",
        "            with autocast(enabled=use_amp):\n",
        "                outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                                attention_mask=batch['attention_mask'].to(device),\n",
        "                                labels=batch['labels'].to(device))\n",
        "                loss = outputs.loss\n",
        "            # Backward runs outside the autocast region, as PyTorch AMP recommends.\n",
        "            scaler.scale(loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "            if any(v is None for acts in act_groups for v in acts.values()):\n",
        "                prev_groups = None\n",
        "                continue\n",
        "            if prev_groups is not None:\n",
        "                for acts, prev_acts, s, c in zip(act_groups, prev_groups, sums, counts):\n",
        "                    _accumulate_er(s, c, compute_conditional_batch_entropy(prev_acts, acts))\n",
        "            # Clone so the snapshot stays independent of the live capture dicts.\n",
        "            prev_groups = tuple({i: t.clone() for i, t in acts.items()} for acts in act_groups)\n",
        "        epoch_enc_er, epoch_dec_er, epoch_cross_er = (_mean_er(s, c) for s, c in zip(sums, counts))\n",
        "        print(f\"[Epoch {epoch+1}] approx Encoder Conditional ER: {epoch_enc_er}\")\n",
        "        print(f\"[Epoch {epoch+1}] approx Decoder Conditional ER: {epoch_dec_er}\")\n",
        "        print(f\"[Epoch {epoch+1}] approx Cross-Attention Conditional ER: {epoch_cross_er}\")\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, label_texts)\n",
        "        print(f\"[Epoch {epoch+1}] Dev Acc: {acc:.4f}\")\n",
        "        last_enc_er, last_dec_er, last_cross_er = epoch_enc_er, epoch_dec_er, epoch_cross_er\n",
        "\n",
        "    remove_hooks([(enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)])\n",
        "    return model, last_enc_er, last_dec_er, last_cross_er\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, enc_er_scores, dec_er_scores, cross_er_scores, tokenizer, label_texts, num_epochs=5, num_prune=4):\n",
        "    \"\"\"Stage 2: swap the highest-ER decoder blocks for SkipBlock, then fine-tune on RTE.\n",
        "\n",
        "    Only ``dec_er_scores`` drives pruning, through ``prune_er_layers`` with\n",
        "    ``num_prune``. ``enc_er_scores`` and ``cross_er_scores`` are accepted but\n",
        "    unused. A ``None`` score dict prunes nothing. Dev accuracy is printed\n",
        "    after every one of the ``num_epochs`` epochs.\n",
        "\n",
        "    Returns:\n",
        "        ``model``, pruned in place and fine-tuned.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High-ER) & Fine-tuning ===\")\n",
        "    dec_prune_idxs = prune_er_layers(model.decoder.block, dec_er_scores or {}, num_prune=num_prune, hidden_size=model.config.d_model)\n",
        "    print(\"Pruned decoder layers (highest ER):\", dec_prune_idxs)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=5e-4)\n",
        "    # Horizon = every step taken, so the LR does not reach 0 before training ends.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                            attention_mask=batch['attention_mask'].to(device),\n",
        "                            labels=batch['labels'].to(device))\n",
        "            loss = outputs.loss\n",
        "            loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, label_texts)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] RTE Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 5. Main Entrypoint ---\n",
        "\n",
        "def main(seed=42):\n",
        "    \"\"\"Run the RTE experiment: fine-tune with ER estimation, then prune and fine-tune.\n",
        "\n",
        "    Args:\n",
        "        seed: passed to ``transformers.set_seed``, which seeds Python ``random``,\n",
        "            NumPy and torch (these drive DataLoader shuffling and dropout).\n",
        "            Also used for the dataset shuffles.\n",
        "    \"\"\"\n",
        "    # Local import: keeps main independent of which names this cell's import\n",
        "    # section brings in.\n",
        "    from transformers import DataCollatorForSeq2Seq, set_seed\n",
        "\n",
        "    set_seed(seed)\n",
        "    raw_datasets = load_dataset(\"glue\", \"rte\")\n",
        "    tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "    label_texts = [\"entailment\", \"not_entailment\"]\n",
        "\n",
        "    train_ds = raw_datasets[\"train\"].shuffle(seed=seed)\n",
        "    dev_ds = raw_datasets[\"validation\"].shuffle(seed=seed)\n",
        "\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True, remove_columns=train_ds.column_names)\n",
        "    dev = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                     batched=True, remove_columns=dev_ds.column_names)\n",
        "\n",
        "    collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding=\"max_length\", max_length=128)\n",
        "    train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    model, enc_er_scores, dec_er_scores, cross_er_scores = full_finetuning(\n",
        "        train_loader, dev_loader, device, tokenizer, label_texts)\n",
        "    model = prune_and_finetuning(\n",
        "        model, train_loader, dev_loader, device,\n",
        "        enc_er_scores, dec_er_scores, cross_er_scores,\n",
        "        tokenizer, label_texts)\n",
        "\n",
        "# In a Jupyter/IPython kernel __name__ is \"__main__\", so running this cell\n",
        "# launches the whole two-stage experiment (dataset/model download + training).\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "CvZo5-P6CDxl"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict\n",
        "import warnings\n",
        "import math\n",
        "from scipy.stats import pearsonr\n",
        "import numpy as np\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# --- 1. Conditional ER Hook Utilities ---\n",
        "def register_conditional_er_hooks(model):\n",
        "    enc_layers = model.encoder.block\n",
        "    enc_acts = {i: None for i in range(len(enc_layers))}\n",
        "    enc_hooks = []\n",
        "    for i, layer in enumerate(enc_layers):\n",
        "        def hook_fn_enc(module, inp, out, idx=i):\n",
        "            hs = out[0] if isinstance(out, tuple) else out\n",
        "            enc_acts[idx] = hs.detach()\n",
        "        enc_hooks.append(layer.register_forward_hook(hook_fn_enc))\n",
        "    dec_layers = model.decoder.block\n",
        "    dec_acts = {i: None for i in range(len(dec_layers))}\n",
        "    dec_hooks = []\n",
        "    for i, layer in enumerate(dec_layers):\n",
        "        def hook_fn_dec(module, inp, out, idx=i):\n",
        "            hs = out[0] if isinstance(out, tuple) else out\n",
        "            dec_acts[idx] = hs.detach()\n",
        "        dec_hooks.append(layer.register_forward_hook(hook_fn_dec))\n",
        "    cross_acts = {i: None for i in range(len(dec_layers))}\n",
        "    cross_hooks = []\n",
        "    for i, block in enumerate(dec_layers):\n",
        "        def hook_fn_cross(module, inp, out, idx=i):\n",
        "            hs = out[0] if isinstance(out, tuple) else out\n",
        "            cross_acts[idx] = hs.detach()\n",
        "        cross_attn = block.layer[1]\n",
        "        cross_hooks.append(cross_attn.register_forward_hook(hook_fn_cross))\n",
        "    return (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)\n",
        "\n",
        "def remove_hooks(hook_sets):\n",
        "    for hooks, _ in hook_sets:\n",
        "        for h in hooks:\n",
        "            h.remove()\n",
        "\n",
        "def compute_conditional_batch_entropy(prev_acts, curr_acts):\n",
        "    er_scores = {}\n",
        "    for i in range(len(curr_acts) - 1):\n",
        "        prev_X, prev_Y = prev_acts[i], prev_acts[i+1]\n",
        "        curr_X, curr_Y = curr_acts[i], curr_acts[i+1]\n",
        "        if (\n",
        "            prev_X is not None and prev_Y is not None and\n",
        "            curr_X is not None and curr_Y is not None and\n",
        "            prev_X.shape == curr_X.shape and\n",
        "            prev_Y.shape == curr_Y.shape\n",
        "        ):\n",
        "            B = curr_X.size(0)\n",
        "            dX = (curr_X - prev_X).view(B, -1)\n",
        "            dY = (curr_Y - prev_Y).view(B, -1)\n",
        "            if B >= 2 and not (torch.isnan(dX).any() or torch.isnan(dY).any()):\n",
        "                cos_squares = [\n",
        "                    F.cosine_similarity(dY[j].unsqueeze(0), dX[j].unsqueeze(0), dim=1, eps=1e-8).item() ** 2\n",
        "                    for j in range(1, B)\n",
        "                    if not (torch.isnan(dX[j]).any() or torch.isnan(dY[j]).any())\n",
        "                ]\n",
        "                if cos_squares:\n",
        "                    er = sum(cos_squares) / len(cos_squares)\n",
        "                    if not (math.isnan(er) or math.isinf(er)):\n",
        "                        er_scores[i] = er\n",
        "    return er_scores\n",
        "\n",
        "# --- 2. Pruning Utilities ---\n",
        "class SkipBlock(nn.Module):\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "    def forward(self, hidden_states, *args, **kwargs):\n",
        "        return (hidden_states, None, None, None, None, None)\n",
        "\n",
        "def prune_er_layers(blocks, er_scores, num_prune=2, hidden_size=768):\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx+1 for idx, _ in sorted_layers[:num_prune] if idx+1 < len(blocks)]\n",
        "    for idx in prune_idxs:\n",
        "        blocks[idx] = SkipBlock(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "# --- 3. Data Processing ---\n",
        "def make_t5_stsb_prompt(s1, s2):\n",
        "    return f\"stsb sentence1: {s1} sentence2: {s2}\"\n",
        "\n",
        "def preprocess_function(batch, tokenizer, max_input_length=128, max_target_length=8):\n",
        "    \"\"\"Tokenize a batch of GLUE STS-B examples for T5 text-to-text regression.\n",
        "\n",
        "    Scores are rounded to the nearest 0.2 and written with one decimal\n",
        "    (e.g. 3.8 -> \"3.8\"), following T5's original STS-B preprocessing.\n",
        "    ``str(score)`` is avoided: a stored score need not be a short decimal\n",
        "    (e.g. a float32 3.8 read back as 3.799999952316284), and\n",
        "    ``max_target_length`` would truncate it to a digit prefix. Missing or NaN\n",
        "    scores become \"-1\". Target padding tokens are replaced by -100 so the\n",
        "    loss ignores them.\n",
        "\n",
        "    Returns:\n",
        "        The tokenizer output for the inputs, plus a ``\"labels\"`` list of lists.\n",
        "    \"\"\"\n",
        "    inputs = [make_t5_stsb_prompt(s1, s2) for s1, s2 in zip(batch['sentence1'], batch['sentence2'])]\n",
        "    model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=max_input_length)\n",
        "    labels = [\n",
        "        \"-1\" if (s is None or math.isnan(s)) else f\"{round(float(s) * 5) / 5:.1f}\"\n",
        "        for s in batch['label']\n",
        "    ]\n",
        "    target = tokenizer(labels, padding=\"max_length\", truncation=True, max_length=max_target_length)\n",
        "    pad_id = tokenizer.pad_token_id\n",
        "    model_inputs[\"labels\"] = [\n",
        "        [tok if tok != pad_id else -100 for tok in seq] for seq in target[\"input_ids\"]\n",
        "    ]\n",
        "    return model_inputs\n",
        "\n",
        "def compute_pearson(preds, refs):\n",
        "    try:\n",
        "        return pearsonr(preds, refs)[0]\n",
        "    except Exception:\n",
        "        return 0.0\n",
        "\n",
        "def evaluate_model(model, dl, tokenizer):\n",
        "    \"\"\"Greedy-decode STS-B scores over ``dl`` and return the Pearson correlation.\n",
        "\n",
        "    Each generated string and each decoded reference label is stripped and\n",
        "    parsed with ``float``; anything that fails to parse counts as 0.0. The\n",
        "    parsed values are passed to ``compute_pearson``.\n",
        "    \"\"\"\n",
        "    def parse_score(text):\n",
        "        try:\n",
        "            return float(text.strip())\n",
        "        except Exception:\n",
        "            return 0.0\n",
        "\n",
        "    model.eval()\n",
        "    predicted, gold = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            enc_ids = batch[\"input_ids\"].to(model.device)\n",
        "            enc_mask = batch[\"attention_mask\"].to(model.device)\n",
        "            generated = model.generate(input_ids=enc_ids, attention_mask=enc_mask, max_new_tokens=8)\n",
        "            decoded_preds = tokenizer.batch_decode(generated, skip_special_tokens=True)\n",
        "            target_ids = batch[\"labels\"].clone()\n",
        "            # -100 marks ignored positions; swap in pad so the ids are decodable.\n",
        "            target_ids[target_ids == -100] = tokenizer.pad_token_id\n",
        "            decoded_refs = tokenizer.batch_decode(target_ids, skip_special_tokens=True)\n",
        "            for pred_text, ref_text in zip(decoded_preds, decoded_refs):\n",
        "                predicted.append(parse_score(pred_text))\n",
        "                gold.append(parse_score(ref_text))\n",
        "    return compute_pearson(np.array(predicted), np.array(gold))\n",
        "\n",
        "# --- 4. Training Loops ---\n",
        "def _accumulate_er(sums, counts, batch_scores):\n",
        "    \"\"\"Add one batch's per-pair ER scores into running ``sums`` / ``counts``.\"\"\"\n",
        "    for idx, value in batch_scores.items():\n",
        "        sums[idx] += value\n",
        "        counts[idx] += 1\n",
        "\n",
        "def _mean_er(sums, counts):\n",
        "    \"\"\"Per-pair mean of accumulated ER scores; pairs with a zero count are skipped.\"\"\"\n",
        "    return {idx: sums[idx] / counts[idx] for idx in sums if counts[idx] > 0}\n",
        "\n",
        "def full_finetuning(train_loader, dev_loader, device, tokenizer, num_epochs=6):\n",
        "    \"\"\"Stage 1: fine-tune t5-base on STS-B while estimating conditional ER per layer pair.\n",
        "\n",
        "    Hooks from ``register_conditional_er_hooks`` capture encoder, decoder and\n",
        "    cross-attention activations on every training step. Consecutive steps are\n",
        "    compared with ``compute_conditional_batch_entropy`` and the scores are\n",
        "    averaged per epoch. Dev Pearson is printed after every epoch.\n",
        "\n",
        "    Args:\n",
        "        train_loader, dev_loader: DataLoaders of tokenized STS-B batches.\n",
        "        device: torch device (or device string) to train on.\n",
        "        tokenizer: forwarded to ``evaluate_model``.\n",
        "        num_epochs: training epochs; the linear LR schedule spans exactly these.\n",
        "\n",
        "    Returns:\n",
        "        ``(model, enc_er, dec_er, cross_er)``. Each ``*_er`` dict holds the last\n",
        "        epoch's mean scores, or ``None`` if ``num_epochs`` is 0.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Fine-Tuning & Conditional ER Estimation ===\")\n",
        "    model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
        "    # AMP only applies on CUDA; disable it explicitly elsewhere instead of\n",
        "    # relying on autocast/GradScaler fallback warnings.\n",
        "    use_amp = torch.device(device).type == \"cuda\"\n",
        "    scaler = GradScaler(enabled=use_amp)\n",
        "    # The schedule must cover every optimizer step actually taken. A shorter\n",
        "    # horizon drives the LR to 0 early, and the remaining epochs learn nothing.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts) = register_conditional_er_hooks(model)\n",
        "    act_groups = (enc_acts, dec_acts, cross_acts)\n",
        "    last_enc_er, last_dec_er, last_cross_er = None, None, None\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        sums = [defaultdict(float) for _ in act_groups]\n",
        "        counts = [defaultdict(int) for _ in act_groups]\n",
        "        model.train()\n",
        "        # Snapshot of the (enc, dec, cross) activations from the previous step.\n",
        "        prev_groups = None\n",
        "        for batch in train_loader:\n",
        "            # Clear captures so stale values (e.g. from evaluate_model's\n",
        "            # generate calls) are never compared, and a hook that did not fire\n",
        "            # is detected below.\n",
        "            for acts in act_groups:\n",
        "                for k in acts:\n",
        "                    acts[k] = None\n",
        "            opt.zero_grad()\n",
        "            with autocast(enabled=use_amp):\n",
        "                outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                                attention_mask=batch['attention_mask'].to(device),\n",
        "                                labels=batch['labels'].to(device))\n",
        "                loss = outputs.loss\n",
        "            # Backward runs outside the autocast region, as PyTorch AMP recommends.\n",
        "            scaler.scale(loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "            if any(v is None for acts in act_groups for v in acts.values()):\n",
        "                prev_groups = None\n",
        "                continue\n",
        "            if prev_groups is not None:\n",
        "                for acts, prev_acts, s, c in zip(act_groups, prev_groups, sums, counts):\n",
        "                    _accumulate_er(s, c, compute_conditional_batch_entropy(prev_acts, acts))\n",
        "            # Clone so the snapshot stays independent of the live capture dicts.\n",
        "            prev_groups = tuple({i: t.clone() for i, t in acts.items()} for acts in act_groups)\n",
        "        epoch_enc_er, epoch_dec_er, epoch_cross_er = (_mean_er(s, c) for s, c in zip(sums, counts))\n",
        "        print(f\"[Epoch {epoch+1}] approx Encoder Conditional ER: {epoch_enc_er}\")\n",
        "        print(f\"[Epoch {epoch+1}] approx Decoder Conditional ER: {epoch_dec_er}\")\n",
        "        print(f\"[Epoch {epoch+1}] approx Cross-Attention Conditional ER: {epoch_cross_er}\")\n",
        "        pearson = evaluate_model(model, dev_loader, tokenizer)\n",
        "        print(f\"[Epoch {epoch+1}] Dev Pearson: {pearson:.4f}\")\n",
        "        last_enc_er, last_dec_er, last_cross_er = epoch_enc_er, epoch_dec_er, epoch_cross_er\n",
        "\n",
        "    remove_hooks([(enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)])\n",
        "    return model, last_enc_er, last_dec_er, last_cross_er\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, enc_er_scores, dec_er_scores, cross_er_scores, tokenizer, num_epochs=5, num_prune=4):\n",
        "    \"\"\"Stage 2: swap the highest-ER decoder blocks for SkipBlock, then fine-tune on STS-B.\n",
        "\n",
        "    Only ``dec_er_scores`` drives pruning, through ``prune_er_layers`` with\n",
        "    ``num_prune``. ``enc_er_scores`` and ``cross_er_scores`` are accepted but\n",
        "    unused. A ``None`` score dict prunes nothing. Dev Pearson is printed\n",
        "    after every one of the ``num_epochs`` epochs.\n",
        "\n",
        "    Returns:\n",
        "        ``model``, pruned in place and fine-tuned.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High-ER) & Fine-tuning ===\")\n",
        "    dec_prune_idxs = prune_er_layers(model.decoder.block, dec_er_scores or {}, num_prune=num_prune, hidden_size=model.config.d_model)\n",
        "    print(\"Pruned decoder layers (highest ER):\", dec_prune_idxs)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=5e-4)\n",
        "    # Horizon = every step taken, so the LR does not reach 0 before training ends.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                            attention_mask=batch['attention_mask'].to(device),\n",
        "                            labels=batch['labels'].to(device))\n",
        "            loss = outputs.loss\n",
        "            loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        pearson = evaluate_model(model, dev_loader, tokenizer)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] STS-B Pearson: {pearson:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 5. Main Entrypoint ---\n",
        "\n",
        "def main(seed=42):\n",
        "    \"\"\"Run the STS-B experiment: fine-tune with ER estimation, then prune and fine-tune.\n",
        "\n",
        "    Args:\n",
        "        seed: passed to ``transformers.set_seed``, which seeds Python ``random``,\n",
        "            NumPy and torch (these drive DataLoader shuffling and dropout).\n",
        "            Also used for the dataset shuffles.\n",
        "    \"\"\"\n",
        "    # Local import: this cell's import block does not bring in set_seed.\n",
        "    from transformers import set_seed\n",
        "\n",
        "    set_seed(seed)\n",
        "    raw_datasets = load_dataset(\"glue\", \"stsb\")\n",
        "    tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "\n",
        "    train_ds = raw_datasets[\"train\"].shuffle(seed=seed)\n",
        "    dev_ds = raw_datasets[\"validation\"].shuffle(seed=seed)\n",
        "\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True, remove_columns=train_ds.column_names)\n",
        "    dev = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                     batched=True, remove_columns=dev_ds.column_names)\n",
        "\n",
        "    collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding=\"max_length\", max_length=128)\n",
        "    train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    model, enc_er_scores, dec_er_scores, cross_er_scores = full_finetuning(\n",
        "        train_loader, dev_loader, device, tokenizer)\n",
        "    model = prune_and_finetuning(\n",
        "        model, train_loader, dev_loader, device,\n",
        "        enc_er_scores, dec_er_scores, cross_er_scores,\n",
        "        tokenizer)\n",
        "\n",
        "# In a Jupyter/IPython kernel __name__ is \"__main__\", so running this cell\n",
        "# launches the whole two-stage experiment (dataset/model download + training).\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "wylUy4I2CEN-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "zYw1JkeJCEqW"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "t8dN9sibCFI1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "plh93jteCFuF"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}