{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6ENR6EQ2b8R8"
      },
      "outputs": [],
      "source": [
        "!pip uninstall -y datasets\n",
        "!pip install datasets==2.18.0\n",
        "!pip install evaluate\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n"
      ],
      "metadata": {
        "id": "d7JFdShlZwE4"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict\n",
        "import warnings\n",
        "import math\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# --- 1. Conditional ER Hook Utilities ---\n",
        "def register_conditional_er_hooks(model):\n",
        "    \"\"\"Attach forward hooks that capture the latest hidden states of a T5 model.\n",
        "\n",
        "    Hooks are placed on every encoder block, every decoder block, and the\n",
        "    ``layer[1]`` sub-module of every decoder block (called cross-attention here).\n",
        "\n",
        "    Returns:\n",
        "        Three ``(handles, activations)`` pairs ordered encoder, decoder,\n",
        "        cross-attention. ``activations`` maps a layer index to that layer's most\n",
        "        recent detached output, or ``None`` while the layer has not run yet.\n",
        "    \"\"\"\n",
        "    def attach(modules):\n",
        "        # One capture slot per module, filled in by that module's hook.\n",
        "        captured = dict.fromkeys(range(len(modules)))\n",
        "        handles = []\n",
        "        for position, module in enumerate(modules):\n",
        "            def record(_module, _inputs, output, slot=position):\n",
        "                # A tuple output carries the hidden state in its first item.\n",
        "                hidden = output[0] if isinstance(output, tuple) else output\n",
        "                captured[slot] = hidden.detach()\n",
        "            handles.append(module.register_forward_hook(record))\n",
        "        return handles, captured\n",
        "\n",
        "    decoder_blocks = model.decoder.block\n",
        "    encoder_pair = attach(model.encoder.block)\n",
        "    decoder_pair = attach(decoder_blocks)\n",
        "    cross_pair = attach([blk.layer[1] for blk in decoder_blocks])\n",
        "    return encoder_pair, decoder_pair, cross_pair\n",
        "\n",
        "def remove_hooks(hook_sets):\n",
        "    \"\"\"Detach every forward hook contained in ``hook_sets``.\n",
        "\n",
        "    Args:\n",
        "        hook_sets: Iterable of ``(handles, activations)`` pairs. Only the handles\n",
        "            are touched; each one is released through its ``remove()`` method.\n",
        "    \"\"\"\n",
        "    for pair in hook_sets:\n",
        "        handles, _activations = pair\n",
        "        for handle in handles:\n",
        "            handle.remove()\n",
        "\n",
        "def compute_conditional_batch_entropy(prev_acts, curr_acts):\n",
        "    \"\"\"Score how aligned the activation changes of adjacent layers are.\n",
        "\n",
        "    For every adjacent layer pair ``(i, i+1)`` the change between the previous\n",
        "    and the current batch is flattened per sample; the mean squared cosine\n",
        "    similarity between layer ``i``'s change and layer ``i+1``'s change is stored\n",
        "    under key ``i``.\n",
        "\n",
        "    Args:\n",
        "        prev_acts: dict of layer index -> tensor (or ``None``) from the previous batch.\n",
        "        curr_acts: dict with the same keys for the current batch.\n",
        "\n",
        "    Returns:\n",
        "        dict mapping ``i`` to a float score. Pairs with a missing activation,\n",
        "        mismatched shapes, an empty batch or non-finite values are left out.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for i in range(len(curr_acts) - 1):\n",
        "        prev_X, prev_Y = prev_acts[i], prev_acts[i + 1]\n",
        "        curr_X, curr_Y = curr_acts[i], curr_acts[i + 1]\n",
        "        if any(t is None for t in (prev_X, prev_Y, curr_X, curr_Y)):\n",
        "            continue\n",
        "        if prev_X.shape != curr_X.shape or prev_Y.shape != curr_Y.shape:\n",
        "            continue\n",
        "        B = curr_X.size(0)\n",
        "        if B == 0:\n",
        "            continue\n",
        "        # Work in float32 (captured activations may be half precision under\n",
        "        # autocast); reshape also accepts non-contiguous tensors.\n",
        "        dX = (curr_X.float() - prev_X.float()).reshape(B, -1)\n",
        "        dY = (curr_Y.float() - prev_Y.float()).reshape(B, -1)\n",
        "        if torch.isnan(dX).any() or torch.isnan(dY).any():\n",
        "            continue\n",
        "        # One vectorised call over all B samples. The earlier per-sample loop\n",
        "        # started at index 1 and therefore never scored the first sample.\n",
        "        er = F.cosine_similarity(dY, dX, dim=1, eps=1e-8).pow(2).mean().item()\n",
        "        if math.isfinite(er):\n",
        "            er_scores[i] = er\n",
        "    return er_scores\n",
        "\n",
        "\n",
        "# --- 2. Pruning Utilities ---\n",
        "class SkipBlock(nn.Module):\n",
        "    \"\"\"Parameter-free stand-in for a pruned transformer block.\n",
        "\n",
        "    ``forward`` returns its input hidden states untouched as the first item of\n",
        "    a 6-tuple whose other five items are ``None``.\n",
        "    \"\"\"\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        # Stored for reference only; forward() never reads it.\n",
        "        self.hidden_size = hidden_size\n",
        "\n",
        "    def forward(self, hidden_states, *args, **kwargs):\n",
        "        # NOTE(review): the None slots presumably mirror the cache / position-bias /\n",
        "        # attention entries a T5 block returns - confirm the enclosing stack does\n",
        "        # not need a real position bias handed on by this layer.\n",
        "        placeholders = (None,) * 5\n",
        "        return (hidden_states,) + placeholders\n",
        "\n",
        "def prune_er_layers(blocks, er_scores, num_prune=2, hidden_size=768):\n",
        "    \"\"\"Replace the layers that follow the highest-scoring ER entries with ``SkipBlock``.\n",
        "\n",
        "    Args:\n",
        "        blocks: Index-assignable sequence of layers; modified in place.\n",
        "        er_scores: dict of index ``i`` -> score.\n",
        "        num_prune: How many of the top-scoring entries are considered.\n",
        "        hidden_size: Passed to every ``SkipBlock`` that is created.\n",
        "\n",
        "    Returns:\n",
        "        The replaced layer indices, ordered by descending score.\n",
        "    \"\"\"\n",
        "    ranked = sorted(er_scores.items(), key=lambda item: item[1], reverse=True)\n",
        "    prune_idxs = []\n",
        "    for scored_idx, _score in ranked[:num_prune]:\n",
        "        # A score keyed by i removes layer i + 1, provided that layer exists.\n",
        "        target = scored_idx + 1\n",
        "        if target < len(blocks):\n",
        "            prune_idxs.append(target)\n",
        "    for target in prune_idxs:\n",
        "        blocks[target] = SkipBlock(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "# --- 3. Data Processing ---\n",
        "def make_t5_nli_prompt(premise, hypothesis):\n",
        "    \"\"\"Format one premise/hypothesis pair as the text prompt fed to the model.\"\"\"\n",
        "    return f\"nli premise: {premise} hypothesis: {hypothesis}\"\n",
        "\n",
        "def preprocess_function(batch, tokenizer, max_input_length=128, max_target_length=8):\n",
        "    \"\"\"Tokenise a batch of NLI examples into model inputs and label ids.\n",
        "\n",
        "    Args:\n",
        "        batch: Mapping with 'premise', 'hypothesis' and 'label' columns.\n",
        "        tokenizer: Callable tokenizer that exposes ``pad_token_id``.\n",
        "        max_input_length: Padded/truncated length of the prompts.\n",
        "        max_target_length: Padded/truncated length of the label text.\n",
        "\n",
        "    Returns:\n",
        "        The tokenizer output for the prompts plus a \"labels\" entry in which\n",
        "        padding positions are -100, so the loss skips them.\n",
        "    \"\"\"\n",
        "    inputs = [make_t5_nli_prompt(p, h) for p, h in zip(batch['premise'], batch['hypothesis'])]\n",
        "    model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=max_input_length)\n",
        "    label_list = [\"entailment\", \"neutral\", \"contradiction\"]\n",
        "    # Only ids 0..2 map to a class name. Without the lower bound a negative id\n",
        "    # such as -1 would index from the end and be read as \"contradiction\".\n",
        "    labels = [\n",
        "        label_list[x] if isinstance(x, int) and 0 <= x < len(label_list) else str(x)\n",
        "        for x in batch['label']\n",
        "    ]\n",
        "    target = tokenizer(labels, padding=\"max_length\", truncation=True, max_length=max_target_length)\n",
        "    pad_id = tokenizer.pad_token_id\n",
        "    # Mask label padding with -100, the index the cross-entropy loss ignores.\n",
        "    model_inputs[\"labels\"] = [\n",
        "        [tok if tok != pad_id else -100 for tok in seq]\n",
        "        for seq in target[\"input_ids\"]\n",
        "    ]\n",
        "    return model_inputs\n",
        "\n",
        "def compute_accuracy(preds, refs):\n",
        "    \"\"\"Return the fraction of predictions equal to their paired reference.\n",
        "\n",
        "    Args:\n",
        "        preds: Sequence of predicted values.\n",
        "        refs: Sequence of reference values, paired with ``preds`` by position.\n",
        "\n",
        "    Returns:\n",
        "        Matches divided by ``len(preds)``; 0.0 when ``preds`` is empty.\n",
        "    \"\"\"\n",
        "    total = len(preds)\n",
        "    if total == 0:\n",
        "        # Nothing was evaluated; report zero rather than divide by zero.\n",
        "        return 0.0\n",
        "    correct = sum(1 for p, l in zip(preds, refs) if p == l)\n",
        "    return correct / total\n",
        "\n",
        "def evaluate_model(model, dl, tokenizer, device, label_texts, max_new_tokens=8):\n",
        "    \"\"\"Generate a prediction for every batch in ``dl`` and return exact-match accuracy.\n",
        "\n",
        "    Args:\n",
        "        model: Model exposing ``eval()`` and ``generate()``; left in eval mode.\n",
        "        dl: Iterable of batches with 'input_ids', 'attention_mask' and 'labels'.\n",
        "        tokenizer: Used to decode generated ids and label ids.\n",
        "        device: Device the inputs are moved to.\n",
        "        label_texts: Not used; kept so existing call sites keep working.\n",
        "        max_new_tokens: Generation budget. It was fixed at 2 although labels are\n",
        "            tokenised to up to 8 positions, so a label that splits into more\n",
        "            pieces could never be reproduced in full.\n",
        "\n",
        "    Returns:\n",
        "        Accuracy over stripped, lower-cased decoded strings.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    preds, refs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch[\"input_ids\"].to(device)\n",
        "            attention_mask = batch[\"attention_mask\"].to(device)\n",
        "            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask,\n",
        "                                     max_new_tokens=max_new_tokens)\n",
        "            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
        "            # -100 marks ignored label positions; map it back to a decodable id.\n",
        "            label_ids = batch[\"labels\"].clone()\n",
        "            label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
        "            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
        "            preds.extend(p.strip().lower() for p in pred_texts)\n",
        "            refs.extend(l.strip().lower() for l in ref_texts)\n",
        "    return compute_accuracy(preds, refs)\n",
        "\n",
        "# --- 4. Training Loops ---\n",
        "\n",
        "def full_finetuning(train_loader, dev_loader, device, tokenizer, label_texts, num_epochs=6):\n",
        "    \"\"\"Fine-tune a fresh t5-base and estimate per-layer conditional ER scores.\n",
        "\n",
        "    Forward hooks capture encoder, decoder and cross-attention activations.\n",
        "    After each optimiser step the current batch's activations are compared with\n",
        "    the previous batch's through ``compute_conditional_batch_entropy`` and the\n",
        "    scores are averaged over the epoch.\n",
        "\n",
        "    Args:\n",
        "        train_loader: Loader yielding dicts with 'input_ids', 'attention_mask'\n",
        "            and 'labels' tensors.\n",
        "        dev_loader: Loader evaluated after every epoch.\n",
        "        device: Device the model and the batches are moved to.\n",
        "        tokenizer: Passed on to ``evaluate_model``.\n",
        "        label_texts: Passed on to ``evaluate_model``.\n",
        "        num_epochs: Number of training epochs (6, as before, by default).\n",
        "\n",
        "    Returns:\n",
        "        ``(model, enc_er, dec_er, cross_er)``; the three dicts hold the last\n",
        "        epoch's mean score per index (``None`` if no epoch ran).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Fine-Tuning & Conditional ER Estimation ===\")\n",
        "    model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
        "    scaler = GradScaler()\n",
        "    # The linear decay must span every optimiser step. It used to be sized for\n",
        "    # 3 epochs while 6 were run, which left the learning rate at zero for the\n",
        "    # second half of training.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    hook_sets = register_conditional_er_hooks(model)\n",
        "    (_, enc_acts), (_, dec_acts), (_, cross_acts) = hook_sets\n",
        "    live_acts = {\"enc\": enc_acts, \"dec\": dec_acts, \"cross\": cross_acts}\n",
        "    last_er = dict.fromkeys(live_acts)\n",
        "\n",
        "    try:\n",
        "        for epoch in range(num_epochs):\n",
        "            er_sums = {name: defaultdict(float) for name in live_acts}\n",
        "            er_counts = {name: defaultdict(int) for name in live_acts}\n",
        "            # Every epoch starts without a previous batch to compare against.\n",
        "            prev_acts = dict.fromkeys(live_acts)\n",
        "            model.train()\n",
        "            for batch in train_loader:\n",
        "                opt.zero_grad()\n",
        "                with autocast():\n",
        "                    outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                                    attention_mask=batch['attention_mask'].to(device),\n",
        "                                    labels=batch['labels'].to(device))\n",
        "                    loss = outputs.loss\n",
        "                # Only the forward pass belongs inside autocast.\n",
        "                scaler.scale(loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "                for name, acts in live_acts.items():\n",
        "                    if prev_acts[name] is not None:\n",
        "                        batch_er = compute_conditional_batch_entropy(prev_acts[name], acts)\n",
        "                        for idx, v in batch_er.items():\n",
        "                            er_sums[name][idx] += v\n",
        "                            er_counts[name][idx] += 1\n",
        "                    # Snapshot this batch so the next one can be compared with it.\n",
        "                    prev_acts[name] = {\n",
        "                        i: (a.clone() if a is not None else None) for i, a in acts.items()\n",
        "                    }\n",
        "            epoch_er = {\n",
        "                name: {idx: er_sums[name][idx] / er_counts[name][idx]\n",
        "                       for idx in er_sums[name] if er_counts[name][idx] > 0}\n",
        "                for name in live_acts\n",
        "            }\n",
        "            print(f\"[Epoch {epoch+1}] approx Encoder Conditional ER: {epoch_er['enc']}\")\n",
        "            print(f\"[Epoch {epoch+1}] approx Decoder Conditional ER: {epoch_er['dec']}\")\n",
        "            print(f\"[Epoch {epoch+1}] approx Cross-Attention Conditional ER: {epoch_er['cross']}\")\n",
        "            acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)\n",
        "            print(f\"[Epoch {epoch+1}] Dev Acc: {acc:.4f}\")\n",
        "            last_er = epoch_er\n",
        "    finally:\n",
        "        # Release the hooks even when training is interrupted.\n",
        "        remove_hooks(hook_sets)\n",
        "    return model, last_er[\"enc\"], last_er[\"dec\"], last_er[\"cross\"]\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, enc_er_scores, dec_er_scores, cross_er_scores, tokenizer, label_texts, num_epochs=5, num_prune=4):\n",
        "    \"\"\"Swap the highest-ER decoder layers for skip blocks, then fine-tune again.\n",
        "\n",
        "    Only the decoder is pruned; ``enc_er_scores`` and ``cross_er_scores`` are\n",
        "    accepted but not used.\n",
        "\n",
        "    Args:\n",
        "        model: Model whose ``decoder.block`` is pruned in place.\n",
        "        train_loader: Loader yielding dicts with 'input_ids', 'attention_mask'\n",
        "            and 'labels' tensors.\n",
        "        dev_loader: Loader evaluated after every epoch.\n",
        "        device: Device the batches are moved to.\n",
        "        enc_er_scores: Unused.\n",
        "        dec_er_scores: dict of index -> score used to pick decoder layers;\n",
        "            ``None`` is treated as empty (nothing is pruned).\n",
        "        cross_er_scores: Unused.\n",
        "        tokenizer: Passed on to ``evaluate_model``.\n",
        "        label_texts: Passed on to ``evaluate_model``.\n",
        "        num_epochs: Number of fine-tuning epochs (5, as before, by default).\n",
        "        num_prune: Number of top-scoring entries to prune (4, as before).\n",
        "\n",
        "    Returns:\n",
        "        The pruned and fine-tuned model (the same object that was passed in).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High-ER) & Fine-tuning ===\")\n",
        "    dec_prune_idxs = prune_er_layers(model.decoder.block, dec_er_scores or {},\n",
        "                                     num_prune=num_prune, hidden_size=model.config.d_model)\n",
        "    print(\"Pruned decoder layers (highest ER):\", dec_prune_idxs)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=5e-4)\n",
        "    # Decay over all steps actually taken. The schedule used to cover 2 epochs\n",
        "    # of a 5-epoch run, so the last 3 epochs trained with a zero learning rate.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                            attention_mask=batch['attention_mask'].to(device),\n",
        "                            labels=batch['labels'].to(device))\n",
        "            loss = outputs.loss\n",
        "            loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device, label_texts)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] e-SNLI Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 5. Main Entrypoint ---\n",
        "\n",
        "def main():\n",
        "    \"\"\"Run the e-SNLI pipeline: load data, fine-tune with ER tracking, prune, fine-tune.\"\"\"\n",
        "    raw_datasets = load_dataset(\"json\", data_files={\n",
        "        \"train\": \"/content/drive/MyDrive/NLP_datasets/esnli/esnli_train.json\",\n",
        "        \"validation\": \"/content/drive/MyDrive/NLP_datasets/esnli/esnli_valid.json\",\n",
        "        \"test\": \"/content/drive/MyDrive/NLP_datasets/esnli/esnli_test.json\",\n",
        "    })\n",
        "    tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "    label_texts = [\"entailment\", \"neutral\", \"contradiction\"]\n",
        "\n",
        "    def prepare(split, size):\n",
        "        # Fixed-seed shuffle, keep the first `size` rows, then tokenise them.\n",
        "        subset = raw_datasets[split].shuffle(seed=42).select(range(size))\n",
        "        return subset.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                          batched=True, remove_columns=subset.column_names)\n",
        "\n",
        "    train = prepare(\"train\", 10000)\n",
        "    dev = prepare(\"validation\", 2000)\n",
        "\n",
        "    collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding=\"max_length\", max_length=128)\n",
        "    train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    use_gpu = torch.cuda.is_available()\n",
        "    device = torch.device(\"cuda\" if use_gpu else \"cpu\")\n",
        "    # Stage 1 returns the model together with the three ER score dicts.\n",
        "    stage_one = full_finetuning(train_loader, dev_loader, device, tokenizer, label_texts)\n",
        "    model, enc_er_scores, dec_er_scores, cross_er_scores = stage_one\n",
        "    model = prune_and_finetuning(\n",
        "        model, train_loader, dev_loader, device,\n",
        "        enc_er_scores, dec_er_scores, cross_er_scores,\n",
        "        tokenizer, label_texts)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "tB4eSzBIFEsg"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "Pa2LlXfR-cus"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Prune the decoder\n",
        "\n",
        "# Mount Google Drive if on Colab\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict\n",
        "import warnings\n",
        "import math\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# --- 1. Load CQA Data ---\n",
        "# JSON splits read from the mounted Drive folder; only \"train\" and \"test\" exist here.\n",
        "data_files = {\n",
        "    \"train\": \"/content/drive/MyDrive/NLP_datasets/cqa/cqa_train.json\",\n",
        "    \"test\":  \"/content/drive/MyDrive/NLP_datasets/cqa/cqa_test.json\"\n",
        "}\n",
        "# The splits are addressed below as dataset[\"train\"] and dataset[\"test\"].\n",
        "dataset = load_dataset(\"json\", data_files=data_files)\n",
        "\n",
        "# --- 2. Preprocessing Functions ---\n",
        "\n",
        "def preprocess_cqa(batch, tokenizer, max_input_length=128, max_target_length=8, use_cot=False):\n",
        "    \"\"\"Tokenise a batch of multiple-choice questions into model inputs and label ids.\n",
        "\n",
        "    Args:\n",
        "        batch: Mapping with 'question', 'choices' and 'answer' columns and,\n",
        "            optionally, 'abstractive_explanation'.\n",
        "        tokenizer: Callable tokenizer that exposes ``pad_token_id``.\n",
        "        max_input_length: Padded/truncated length of the prompts.\n",
        "        max_target_length: Padded/truncated length of the answer text.\n",
        "        use_cot: Append the explanation to the prompt when the column exists.\n",
        "\n",
        "    Returns:\n",
        "        The tokenizer output for the prompts plus a \"labels\" entry in which\n",
        "        padding positions are -100, so the loss skips them.\n",
        "    \"\"\"\n",
        "    if use_cot and 'abstractive_explanation' in batch:\n",
        "        # Question, choices and the explanation as a rationale.\n",
        "        inputs = [\n",
        "            f\"question: {q} choices: {', '.join(choices)} rationale: {exp}\"\n",
        "            for q, choices, exp in zip(batch['question'], batch['choices'], batch['abstractive_explanation'])\n",
        "        ]\n",
        "    else:\n",
        "        inputs = [\n",
        "            f\"question: {q} choices: {', '.join(choices)}\"\n",
        "            for q, choices in zip(batch['question'], batch['choices'])\n",
        "        ]\n",
        "    targets = [str(ans).strip() for ans in batch['answer']]\n",
        "    model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=max_input_length)\n",
        "    target = tokenizer(targets, padding=\"max_length\", truncation=True, max_length=max_target_length)\n",
        "    pad_id = tokenizer.pad_token_id\n",
        "    # Mask label padding with -100, the index the cross-entropy loss ignores.\n",
        "    model_inputs[\"labels\"] = [\n",
        "        [tok if tok != pad_id else -100 for tok in seq]\n",
        "        for seq in target[\"input_ids\"]\n",
        "    ]\n",
        "    return model_inputs\n",
        "\n",
        "tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "USE_COT = False  # Set to True to include abstractive_explanation\n",
        "\n",
        "# Training prompts follow USE_COT; the held-out prompts always omit the rationale.\n",
        "# The \"test\" split is what the rest of this cell uses as its dev set.\n",
        "train = dataset[\"train\"].map(lambda ex: preprocess_cqa(ex, tokenizer, use_cot=USE_COT),\n",
        "                            batched=True, remove_columns=dataset[\"train\"].column_names)\n",
        "dev   = dataset[\"test\"].map(lambda ex: preprocess_cqa(ex, tokenizer, use_cot=False),\n",
        "                            batched=True, remove_columns=dataset[\"test\"].column_names)\n",
        "\n",
        "# Module-level loaders: main() further down reads these globals directly.\n",
        "collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding=\"max_length\", max_length=128)\n",
        "train_loader = DataLoader(train, batch_size=32, shuffle=True, collate_fn=collator)\n",
        "dev_loader   = DataLoader(dev, batch_size=32, shuffle=False, collate_fn=collator)\n",
        "\n",
        "# --- 3. Conditional ER Hook Utilities ---\n",
        "def register_conditional_er_hooks(model):\n",
        "    \"\"\"Attach forward hooks that capture the latest hidden states of a T5 model.\n",
        "\n",
        "    Hooks are placed on every encoder block, every decoder block, and the\n",
        "    ``layer[1]`` sub-module of every decoder block (called cross-attention here).\n",
        "\n",
        "    Returns:\n",
        "        Three ``(handles, activations)`` pairs ordered encoder, decoder,\n",
        "        cross-attention. ``activations`` maps a layer index to that layer's most\n",
        "        recent detached output, or ``None`` while the layer has not run yet.\n",
        "    \"\"\"\n",
        "    def attach(modules):\n",
        "        # One capture slot per module, filled in by that module's hook.\n",
        "        captured = dict.fromkeys(range(len(modules)))\n",
        "        handles = []\n",
        "        for position, module in enumerate(modules):\n",
        "            def record(_module, _inputs, output, slot=position):\n",
        "                # A tuple output carries the hidden state in its first item.\n",
        "                hidden = output[0] if isinstance(output, tuple) else output\n",
        "                captured[slot] = hidden.detach()\n",
        "            handles.append(module.register_forward_hook(record))\n",
        "        return handles, captured\n",
        "\n",
        "    decoder_blocks = model.decoder.block\n",
        "    encoder_pair = attach(model.encoder.block)\n",
        "    decoder_pair = attach(decoder_blocks)\n",
        "    cross_pair = attach([blk.layer[1] for blk in decoder_blocks])\n",
        "    return encoder_pair, decoder_pair, cross_pair\n",
        "\n",
        "def remove_hooks(hook_sets):\n",
        "    \"\"\"Detach every forward hook contained in ``hook_sets``.\n",
        "\n",
        "    Args:\n",
        "        hook_sets: Iterable of ``(handles, activations)`` pairs. Only the handles\n",
        "            are touched; each one is released through its ``remove()`` method.\n",
        "    \"\"\"\n",
        "    for pair in hook_sets:\n",
        "        handles, _activations = pair\n",
        "        for handle in handles:\n",
        "            handle.remove()\n",
        "\n",
        "def compute_conditional_batch_entropy(prev_acts, curr_acts):\n",
        "    \"\"\"Score how aligned the activation changes of adjacent layers are.\n",
        "\n",
        "    For every adjacent layer pair ``(i, i+1)`` the change between the previous\n",
        "    and the current batch is flattened per sample; the mean squared cosine\n",
        "    similarity between layer ``i``'s change and layer ``i+1``'s change is stored\n",
        "    under key ``i``.\n",
        "\n",
        "    Args:\n",
        "        prev_acts: dict of layer index -> tensor (or ``None``) from the previous batch.\n",
        "        curr_acts: dict with the same keys for the current batch.\n",
        "\n",
        "    Returns:\n",
        "        dict mapping ``i`` to a float score. Pairs with a missing activation,\n",
        "        mismatched shapes, an empty batch or non-finite values are left out.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for i in range(len(curr_acts) - 1):\n",
        "        prev_X, prev_Y = prev_acts[i], prev_acts[i + 1]\n",
        "        curr_X, curr_Y = curr_acts[i], curr_acts[i + 1]\n",
        "        if any(t is None for t in (prev_X, prev_Y, curr_X, curr_Y)):\n",
        "            continue\n",
        "        if prev_X.shape != curr_X.shape or prev_Y.shape != curr_Y.shape:\n",
        "            continue\n",
        "        B = curr_X.size(0)\n",
        "        if B == 0:\n",
        "            continue\n",
        "        # Work in float32 (captured activations may be half precision under\n",
        "        # autocast); reshape also accepts non-contiguous tensors.\n",
        "        dX = (curr_X.float() - prev_X.float()).reshape(B, -1)\n",
        "        dY = (curr_Y.float() - prev_Y.float()).reshape(B, -1)\n",
        "        if torch.isnan(dX).any() or torch.isnan(dY).any():\n",
        "            continue\n",
        "        # One vectorised call over all B samples. The earlier per-sample loop\n",
        "        # started at index 1 and therefore never scored the first sample.\n",
        "        er = F.cosine_similarity(dY, dX, dim=1, eps=1e-8).pow(2).mean().item()\n",
        "        if math.isfinite(er):\n",
        "            er_scores[i] = er\n",
        "    return er_scores\n",
        "\n",
        "# --- 4. Pruning Utilities ---\n",
        "class SkipBlock(nn.Module):\n",
        "    \"\"\"Parameter-free stand-in for a pruned transformer block.\n",
        "\n",
        "    ``forward`` returns its input hidden states untouched as the first item of\n",
        "    a 6-tuple whose other five items are ``None``.\n",
        "    \"\"\"\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        # Stored for reference only; forward() never reads it.\n",
        "        self.hidden_size = hidden_size\n",
        "\n",
        "    def forward(self, hidden_states, *args, **kwargs):\n",
        "        # NOTE(review): the None slots presumably mirror the cache / position-bias /\n",
        "        # attention entries a T5 block returns - confirm the enclosing stack does\n",
        "        # not need a real position bias handed on by this layer.\n",
        "        placeholders = (None,) * 5\n",
        "        return (hidden_states,) + placeholders\n",
        "\n",
        "def prune_er_layers(blocks, er_scores, num_prune=4, hidden_size=768):\n",
        "    \"\"\"Replace the layers that follow the highest-scoring ER entries with ``SkipBlock``.\n",
        "\n",
        "    Args:\n",
        "        blocks: Index-assignable sequence of layers; modified in place.\n",
        "        er_scores: dict of index ``i`` -> score.\n",
        "        num_prune: How many of the top-scoring entries are considered.\n",
        "        hidden_size: Passed to every ``SkipBlock`` that is created.\n",
        "\n",
        "    Returns:\n",
        "        The replaced layer indices, ordered by descending score.\n",
        "    \"\"\"\n",
        "    ranked = sorted(er_scores.items(), key=lambda item: item[1], reverse=True)\n",
        "    prune_idxs = []\n",
        "    for scored_idx, _score in ranked[:num_prune]:\n",
        "        # A score keyed by i removes layer i + 1, provided that layer exists.\n",
        "        target = scored_idx + 1\n",
        "        if target < len(blocks):\n",
        "            prune_idxs.append(target)\n",
        "    for target in prune_idxs:\n",
        "        blocks[target] = SkipBlock(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "# --- 5. Training/Eval/ER Pipeline ---\n",
        "def compute_accuracy(preds, refs):\n",
        "    \"\"\"Return the fraction of predictions equal to their paired reference.\n",
        "\n",
        "    Args:\n",
        "        preds: Sequence of predicted values.\n",
        "        refs: Sequence of reference values, paired with ``preds`` by position.\n",
        "\n",
        "    Returns:\n",
        "        Matches divided by ``len(preds)``, or 0 when ``preds`` is empty.\n",
        "    \"\"\"\n",
        "    total = len(preds)\n",
        "    if total == 0:\n",
        "        return 0\n",
        "    matches = sum(1 for guess, truth in zip(preds, refs) if guess == truth)\n",
        "    return matches / total\n",
        "\n",
        "def evaluate_model(model, dl, tokenizer, device, max_new_tokens=8):\n",
        "    \"\"\"Generate a prediction for every batch in ``dl`` and return exact-match accuracy.\n",
        "\n",
        "    Args:\n",
        "        model: Model exposing ``eval()`` and ``generate()``; left in eval mode.\n",
        "        dl: Iterable of batches with 'input_ids', 'attention_mask' and 'labels'.\n",
        "        tokenizer: Used to decode generated ids and label ids.\n",
        "        device: Device the inputs are moved to.\n",
        "        max_new_tokens: Generation budget. It was fixed at 4 although answers are\n",
        "            tokenised to up to 8 positions, so a longer answer could never be\n",
        "            reproduced in full.\n",
        "\n",
        "    Returns:\n",
        "        Accuracy over stripped, lower-cased decoded strings.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    preds, refs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch[\"input_ids\"].to(device)\n",
        "            attention_mask = batch[\"attention_mask\"].to(device)\n",
        "            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask,\n",
        "                                     max_new_tokens=max_new_tokens)\n",
        "            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
        "            # -100 marks ignored label positions; map it back to a decodable id.\n",
        "            label_ids = batch[\"labels\"].clone()\n",
        "            label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
        "            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
        "            preds.extend(p.strip().lower() for p in pred_texts)\n",
        "            refs.extend(l.strip().lower() for l in ref_texts)\n",
        "    return compute_accuracy(preds, refs)\n",
        "\n",
        "def full_finetuning(train_loader, dev_loader, device, tokenizer, num_epochs=6):\n",
        "    \"\"\"Fine-tune a fresh t5-base and estimate per-layer conditional ER scores.\n",
        "\n",
        "    Forward hooks capture encoder, decoder and cross-attention activations.\n",
        "    After each optimiser step the current batch's activations are compared with\n",
        "    the previous batch's through ``compute_conditional_batch_entropy`` and the\n",
        "    scores are averaged over the epoch.\n",
        "\n",
        "    Args:\n",
        "        train_loader: Loader yielding dicts with 'input_ids', 'attention_mask'\n",
        "            and 'labels' tensors.\n",
        "        dev_loader: Loader evaluated after every epoch.\n",
        "        device: Device the model and the batches are moved to.\n",
        "        tokenizer: Passed on to ``evaluate_model``.\n",
        "        num_epochs: Number of training epochs (6, as before, by default).\n",
        "\n",
        "    Returns:\n",
        "        ``(model, enc_er, dec_er, cross_er)``; the three dicts hold the last\n",
        "        epoch's mean score per index (``None`` if no epoch ran).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Fine-Tuning & Conditional ER Estimation ===\")\n",
        "    model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
        "    scaler = GradScaler()\n",
        "    # The linear decay must span every optimiser step. It used to be sized for\n",
        "    # 3 epochs while 6 were run, which left the learning rate at zero for the\n",
        "    # second half of training.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    hook_sets = register_conditional_er_hooks(model)\n",
        "    (_, enc_acts), (_, dec_acts), (_, cross_acts) = hook_sets\n",
        "    live_acts = {\"enc\": enc_acts, \"dec\": dec_acts, \"cross\": cross_acts}\n",
        "    last_er = dict.fromkeys(live_acts)\n",
        "\n",
        "    try:\n",
        "        for epoch in range(num_epochs):\n",
        "            er_sums = {name: defaultdict(float) for name in live_acts}\n",
        "            er_counts = {name: defaultdict(int) for name in live_acts}\n",
        "            # Every epoch starts without a previous batch to compare against.\n",
        "            prev_acts = dict.fromkeys(live_acts)\n",
        "            model.train()\n",
        "            for batch in train_loader:\n",
        "                opt.zero_grad()\n",
        "                with autocast():\n",
        "                    outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                                    attention_mask=batch['attention_mask'].to(device),\n",
        "                                    labels=batch['labels'].to(device))\n",
        "                    loss = outputs.loss\n",
        "                # Only the forward pass belongs inside autocast.\n",
        "                scaler.scale(loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "                for name, acts in live_acts.items():\n",
        "                    if prev_acts[name] is not None:\n",
        "                        batch_er = compute_conditional_batch_entropy(prev_acts[name], acts)\n",
        "                        for idx, v in batch_er.items():\n",
        "                            er_sums[name][idx] += v\n",
        "                            er_counts[name][idx] += 1\n",
        "                    # Snapshot this batch so the next one can be compared with it.\n",
        "                    prev_acts[name] = {\n",
        "                        i: (a.clone() if a is not None else None) for i, a in acts.items()\n",
        "                    }\n",
        "            epoch_er = {\n",
        "                name: {idx: er_sums[name][idx] / er_counts[name][idx]\n",
        "                       for idx in er_sums[name] if er_counts[name][idx] > 0}\n",
        "                for name in live_acts\n",
        "            }\n",
        "            print(f\"[Epoch {epoch+1}] approx Encoder Conditional ER: {epoch_er['enc']}\")\n",
        "            print(f\"[Epoch {epoch+1}] approx Decoder Conditional ER: {epoch_er['dec']}\")\n",
        "            print(f\"[Epoch {epoch+1}] approx Cross-Attention Conditional ER: {epoch_er['cross']}\")\n",
        "            acc = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "            print(f\"[Epoch {epoch+1}] Dev Acc: {acc:.4f}\")\n",
        "            last_er = epoch_er\n",
        "    finally:\n",
        "        # Release the hooks even when training is interrupted.\n",
        "        remove_hooks(hook_sets)\n",
        "    return model, last_er[\"enc\"], last_er[\"dec\"], last_er[\"cross\"]\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, enc_er_scores, dec_er_scores, tokenizer, num_epochs=5, num_prune=4):\n",
        "    \"\"\"Swap the highest-ER decoder layers for skip blocks, then fine-tune again.\n",
        "\n",
        "    Only the decoder is pruned; ``enc_er_scores`` is accepted but not used.\n",
        "\n",
        "    Args:\n",
        "        model: Model whose ``decoder.block`` is pruned in place.\n",
        "        train_loader: Loader yielding dicts with 'input_ids', 'attention_mask'\n",
        "            and 'labels' tensors.\n",
        "        dev_loader: Loader evaluated after every epoch.\n",
        "        device: Device the batches are moved to.\n",
        "        enc_er_scores: Unused.\n",
        "        dec_er_scores: dict of index -> score used to pick decoder layers;\n",
        "            ``None`` is treated as empty (nothing is pruned).\n",
        "        tokenizer: Passed on to ``evaluate_model``.\n",
        "        num_epochs: Number of fine-tuning epochs (5, as before, by default).\n",
        "        num_prune: Number of top-scoring entries to prune (4, as before).\n",
        "\n",
        "    Returns:\n",
        "        The pruned and fine-tuned model (the same object that was passed in).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High-ER) & Fine-tuning ===\")\n",
        "    dec_prune_idxs = prune_er_layers(model.decoder.block, dec_er_scores or {},\n",
        "                                     num_prune=num_prune, hidden_size=model.config.d_model)\n",
        "    print(\"Pruned decoder layers (highest ER):\", dec_prune_idxs)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
        "    # Decay over all steps actually taken. The schedule used to cover 2 epochs\n",
        "    # of a 5-epoch run, so the last 3 epochs trained with a zero learning rate.\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader) * num_epochs)\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                            attention_mask=batch['attention_mask'].to(device),\n",
        "                            labels=batch['labels'].to(device))\n",
        "            loss = outputs.loss\n",
        "            loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] CQA Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 6. Entrypoint ---\n",
        "def main():\n",
        "    \"\"\"Run the CQA pipeline on the module-level loaders: fine-tune, prune, fine-tune.\"\"\"\n",
        "    use_gpu = torch.cuda.is_available()\n",
        "    device = torch.device(\"cuda\" if use_gpu else \"cpu\")\n",
        "    stage_one = full_finetuning(train_loader, dev_loader, device, tokenizer)\n",
        "    # cross_er_scores is unpacked here but stage 2 does not take it.\n",
        "    model, enc_er_scores, dec_er_scores, cross_er_scores = stage_one\n",
        "    # --- PRUNING AND CONTINUED FINETUNING ---\n",
        "    model = prune_and_finetuning(\n",
        "        model, train_loader, dev_loader, device,\n",
        "        enc_er_scores, dec_er_scores, tokenizer\n",
        "    )\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "wm8aQwKE-eKF"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "-nEqET8995OX"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Only prune decoder\n",
        "\n",
        "# --- Mount Google Drive if using Colab ---\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "# --- Standard Imports ---\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict\n",
        "import warnings\n",
        "import math\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# --- 1. Load ANLI1 Dataset ---\n",
        "data_files = {\n",
        "    \"train\":      \"/content/drive/MyDrive/NLP_datasets/anli1/anli1_train.json\",\n",
        "    \"validation\": \"/content/drive/MyDrive/NLP_datasets/anli1/anli1_valid.json\",\n",
        "    \"test\":       \"/content/drive/MyDrive/NLP_datasets/anli1/anli1_test.json\"\n",
        "}\n",
        "dataset = load_dataset(\"json\", data_files=data_files)\n",
        "\n",
        "# --- 2. Preprocessing Function ---\n",
        "def make_t5_nli_prompt(premise, hypothesis):\n",
        "    return f\"nli premise: {premise} hypothesis: {hypothesis}\"\n",
        "\n",
        "def preprocess_anli(batch, tokenizer, max_input_length=128, max_target_length=8):\n",
        "    inputs = [make_t5_nli_prompt(p, h) for p, h in zip(batch['premise'], batch['hypothesis'])]\n",
        "    label_list = [\"entailment\", \"neutral\", \"contradiction\"]\n",
        "    labels = [label_list[int(x)] if isinstance(x, (int, float, str)) and str(x).isdigit() and int(x)<3 else str(x) for x in batch['label']]\n",
        "    model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=max_input_length)\n",
        "    target = tokenizer(labels, padding=\"max_length\", truncation=True, max_length=max_target_length)\n",
        "    model_inputs[\"labels\"] = target[\"input_ids\"]\n",
        "    return model_inputs\n",
        "\n",
        "tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "train = dataset[\"train\"].map(lambda ex: preprocess_anli(ex, tokenizer), batched=True, remove_columns=dataset[\"train\"].column_names)\n",
        "dev   = dataset[\"validation\"].map(lambda ex: preprocess_anli(ex, tokenizer), batched=True, remove_columns=dataset[\"validation\"].column_names)\n",
        "collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding=\"max_length\", max_length=128)\n",
        "train_loader = DataLoader(train, batch_size=16, shuffle=True, collate_fn=collator)\n",
        "dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "# --- 3. Conditional ER Hook Utilities ---\n",
        "def register_conditional_er_hooks(model):\n",
        "    enc_layers = model.encoder.block\n",
        "    enc_acts = {i: None for i in range(len(enc_layers))}\n",
        "    enc_hooks = []\n",
        "    for i, layer in enumerate(enc_layers):\n",
        "        def hook_fn_enc(module, inp, out, idx=i):\n",
        "            hs = out[0] if isinstance(out, tuple) else out\n",
        "            enc_acts[idx] = hs.detach()\n",
        "        enc_hooks.append(layer.register_forward_hook(hook_fn_enc))\n",
        "    dec_layers = model.decoder.block\n",
        "    dec_acts = {i: None for i in range(len(dec_layers))}\n",
        "    dec_hooks = []\n",
        "    for i, layer in enumerate(dec_layers):\n",
        "        def hook_fn_dec(module, inp, out, idx=i):\n",
        "            hs = out[0] if isinstance(out, tuple) else out\n",
        "            dec_acts[idx] = hs.detach()\n",
        "        dec_hooks.append(layer.register_forward_hook(hook_fn_dec))\n",
        "    cross_acts = {i: None for i in range(len(dec_layers))}\n",
        "    cross_hooks = []\n",
        "    for i, block in enumerate(dec_layers):\n",
        "        def hook_fn_cross(module, inp, out, idx=i):\n",
        "            hs = out[0] if isinstance(out, tuple) else out\n",
        "            cross_acts[idx] = hs.detach()\n",
        "        cross_attn = block.layer[1]\n",
        "        cross_hooks.append(cross_attn.register_forward_hook(hook_fn_cross))\n",
        "    return (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)\n",
        "\n",
        "def remove_hooks(hook_sets):\n",
        "    for hooks, _ in hook_sets:\n",
        "        for h in hooks:\n",
        "            h.remove()\n",
        "\n",
        "def compute_conditional_batch_entropy(prev_acts, curr_acts):\n",
        "    er_scores = {}\n",
        "    for i in range(len(curr_acts) - 1):\n",
        "        prev_X, prev_Y = prev_acts[i], prev_acts[i+1]\n",
        "        curr_X, curr_Y = curr_acts[i], curr_acts[i+1]\n",
        "        if (\n",
        "            prev_X is not None and prev_Y is not None and\n",
        "            curr_X is not None and curr_Y is not None and\n",
        "            prev_X.shape == curr_X.shape and\n",
        "            prev_Y.shape == curr_Y.shape\n",
        "        ):\n",
        "            B = curr_X.size(0)\n",
        "            dX = (curr_X - prev_X).view(B, -1)\n",
        "            dY = (curr_Y - prev_Y).view(B, -1)\n",
        "            if B >= 2 and not (torch.isnan(dX).any() or torch.isnan(dY).any()):\n",
        "                cos_squares = [\n",
        "                    F.cosine_similarity(dY[j].unsqueeze(0), dX[j].unsqueeze(0), dim=1, eps=1e-8).item() ** 2\n",
        "                    for j in range(1, B)\n",
        "                    if not (torch.isnan(dX[j]).any() or torch.isnan(dY[j]).any())\n",
        "                ]\n",
        "                if cos_squares:\n",
        "                    er = sum(cos_squares) / len(cos_squares)\n",
        "                    if not (math.isnan(er) or math.isinf(er)):\n",
        "                        er_scores[i] = er\n",
        "    return er_scores\n",
        "\n",
        "# --- 4. Pruning Utilities ---\n",
        "class SkipBlock(nn.Module):\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "    def forward(self, hidden_states, *args, **kwargs):\n",
        "        return (hidden_states, None, None, None, None, None)\n",
        "\n",
        "def prune_er_layers(blocks, er_scores, num_prune=4, hidden_size=768):\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx+1 for idx, _ in sorted_layers[:num_prune] if idx+1 < len(blocks)]\n",
        "    for idx in prune_idxs:\n",
        "        blocks[idx] = SkipBlock(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "# --- 5. Training/Eval/ER Pipeline ---\n",
        "def compute_accuracy(preds, refs):\n",
        "    correct = 0\n",
        "    for p, l in zip(preds, refs):\n",
        "        if p == l:\n",
        "            correct += 1\n",
        "    return correct / len(preds) if len(preds) > 0 else 0\n",
        "\n",
        "def evaluate_model(model, dl, tokenizer, device):\n",
        "    model.eval()\n",
        "    preds, refs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch[\"input_ids\"].to(device)\n",
        "            attention_mask = batch[\"attention_mask\"].to(device)\n",
        "            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=2)\n",
        "            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
        "            label_ids = batch[\"labels\"].clone()\n",
        "            label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
        "            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
        "            preds.extend([p.strip().lower() for p in pred_texts])\n",
        "            refs.extend([l.strip().lower() for l in ref_texts])\n",
        "    return compute_accuracy(preds, refs)\n",
        "\n",
        "def full_finetuning(train_loader, dev_loader, device, tokenizer):\n",
        "    print(\"=== Stage 1: Full Fine-Tuning & Conditional ER Estimation ===\")\n",
        "    model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
        "    scaler = GradScaler()\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*3)\n",
        "    (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts) = register_conditional_er_hooks(model)\n",
        "    last_enc_er, last_dec_er, last_cross_er = None, None, None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        enc_er_sums, enc_er_counts = defaultdict(float), defaultdict(int)\n",
        "        dec_er_sums, dec_er_counts = defaultdict(float), defaultdict(int)\n",
        "        cross_er_sums, cross_er_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        prev_enc_acts, prev_dec_acts, prev_cross_acts = None, None, None\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                                attention_mask=batch['attention_mask'].to(device),\n",
        "                                labels=batch['labels'].to(device))\n",
        "                loss = outputs.loss\n",
        "                scaler.scale(loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "            if prev_enc_acts is not None:\n",
        "                enc_batch_er = compute_conditional_batch_entropy(prev_enc_acts, enc_acts)\n",
        "                for idx, v in enc_batch_er.items():\n",
        "                    enc_er_sums[idx] += v\n",
        "                    enc_er_counts[idx] += 1\n",
        "            if prev_dec_acts is not None:\n",
        "                dec_batch_er = compute_conditional_batch_entropy(prev_dec_acts, dec_acts)\n",
        "                for idx, v in dec_batch_er.items():\n",
        "                    dec_er_sums[idx] += v\n",
        "                    dec_er_counts[idx] += 1\n",
        "            if prev_cross_acts is not None:\n",
        "                cross_batch_er = compute_conditional_batch_entropy(prev_cross_acts, cross_acts)\n",
        "                for idx, v in cross_batch_er.items():\n",
        "                    cross_er_sums[idx] += v\n",
        "                    cross_er_counts[idx] += 1\n",
        "            prev_enc_acts = {i: enc_acts[i].clone() if enc_acts[i] is not None else None for i in enc_acts}\n",
        "            prev_dec_acts = {i: dec_acts[i].clone() if dec_acts[i] is not None else None for i in dec_acts}\n",
        "            prev_cross_acts = {i: cross_acts[i].clone() if cross_acts[i] is not None else None for i in cross_acts}\n",
        "        epoch_enc_er = {idx: enc_er_sums[idx]/enc_er_counts[idx] for idx in enc_er_sums if enc_er_counts[idx] > 0}\n",
        "        epoch_dec_er = {idx: dec_er_sums[idx]/dec_er_counts[idx] for idx in dec_er_sums if dec_er_counts[idx] > 0}\n",
        "        epoch_cross_er = {idx: cross_er_sums[idx]/cross_er_counts[idx] for idx in cross_er_sums if cross_er_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Encoder Conditional ER: {epoch_enc_er}\")\n",
        "        print(f\"[Epoch {epoch+1}] approx Decoder Conditional ER: {epoch_dec_er}\")\n",
        "        print(f\"[Epoch {epoch+1}] approx Cross-Attention Conditional ER: {epoch_cross_er}\")\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "        print(f\"[Epoch {epoch+1}] Dev Acc: {acc:.4f}\")\n",
        "        last_enc_er, last_dec_er, last_cross_er = epoch_enc_er, epoch_dec_er, epoch_cross_er\n",
        "\n",
        "    remove_hooks([(enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)])\n",
        "    return model, last_enc_er, last_dec_er, last_cross_er\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, enc_er_scores, dec_er_scores, tokenizer):\n",
        "    print(\"=== Stage 2: Prune (High-ER) & Fine-tuning ===\")\n",
        "#    enc_prune_idxs = prune_er_layers(model.encoder.block, enc_er_scores, num_prune=4, hidden_size=model.config.d_model)\n",
        "    dec_prune_idxs = prune_er_layers(model.decoder.block, dec_er_scores, num_prune=4, hidden_size=model.config.d_model)\n",
        "#    print(\"Pruned encoder layers (highest ER):\", enc_prune_idxs)\n",
        "    print(\"Pruned decoder layers (highest ER):\", dec_prune_idxs)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=2e-4)\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*2)\n",
        "    for epoch in range(5):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                            attention_mask=batch['attention_mask'].to(device),\n",
        "                            labels=batch['labels'].to(device))\n",
        "            loss = outputs.loss\n",
        "            loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] ANLI1 Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# --- 6. Entrypoint ---\n",
        "def main():\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    model, enc_er_scores, dec_er_scores, cross_er_scores = full_finetuning(\n",
        "        train_loader, dev_loader, device, tokenizer)\n",
        "    # --- PRUNING AND CONTINUED FINETUNING ---\n",
        "    model = prune_and_finetuning(\n",
        "        model, train_loader, dev_loader, device,\n",
        "        enc_er_scores, dec_er_scores, tokenizer\n",
        "    )\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "WV0cr0Wi966u"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "yM4MGcHI21qy"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Only prune decoder\n",
        "# ===========================\n",
        "# 0. Google Drive Mount\n",
        "# ===========================\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "# ===========================\n",
        "# 1. Imports and Setup\n",
        "# ===========================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader\n",
        "from transformers import (\n",
        "    T5ForConditionalGeneration, T5TokenizerFast,\n",
        "    DataCollatorForSeq2Seq, get_linear_schedule_with_warmup\n",
        ")\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from collections import defaultdict\n",
        "import warnings\n",
        "import math\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "# ===========================\n",
        "# 2. Load SVAMP Dataset\n",
        "# ===========================\n",
        "data_files = {\n",
        "    \"train\": \"/content/drive/MyDrive/NLP_datasets/svamp/svamp_train.json\",\n",
        "    \"test\": \"/content/drive/MyDrive/NLP_datasets/svamp/svamp_test.json\"\n",
        "}\n",
        "dataset = load_dataset(\"json\", data_files=data_files)\n",
        "\n",
        "# ===========================\n",
        "# 3. Preprocessing\n",
        "# ===========================\n",
        "def preprocess_svamp(batch, tokenizer, max_input_length=128, max_target_length=8):\n",
        "    model_inputs = tokenizer(\n",
        "        batch[\"input\"], padding=\"max_length\", truncation=True, max_length=max_input_length\n",
        "    )\n",
        "    targets = [str(x) for x in batch[\"label\"]]\n",
        "    target_encodings = tokenizer(\n",
        "        targets, padding=\"max_length\", truncation=True, max_length=max_target_length\n",
        "    )\n",
        "    model_inputs[\"labels\"] = target_encodings[\"input_ids\"]\n",
        "    return model_inputs\n",
        "\n",
        "tokenizer = T5TokenizerFast.from_pretrained(\"t5-base\")\n",
        "train = dataset[\"train\"].map(lambda ex: preprocess_svamp(ex, tokenizer), batched=True, remove_columns=dataset[\"train\"].column_names)\n",
        "dev = dataset[\"test\"].map(lambda ex: preprocess_svamp(ex, tokenizer), batched=True, remove_columns=dataset[\"test\"].column_names)\n",
        "collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding=\"max_length\", max_length=128)\n",
        "train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "dev_loader = DataLoader(dev, batch_size=8, shuffle=False, collate_fn=collator)\n",
        "\n",
        "# ===========================\n",
        "# 4. Conditional ER Utilities\n",
        "# ===========================\n",
        "def register_conditional_er_hooks(model):\n",
        "    enc_layers = model.encoder.block\n",
        "    enc_acts = {i: None for i in range(len(enc_layers))}\n",
        "    enc_hooks = []\n",
        "    for i, layer in enumerate(enc_layers):\n",
        "        def hook_fn_enc(module, inp, out, idx=i):\n",
        "            hs = out[0] if isinstance(out, tuple) else out\n",
        "            enc_acts[idx] = hs.detach()\n",
        "        enc_hooks.append(layer.register_forward_hook(hook_fn_enc))\n",
        "    dec_layers = model.decoder.block\n",
        "    dec_acts = {i: None for i in range(len(dec_layers))}\n",
        "    dec_hooks = []\n",
        "    for i, layer in enumerate(dec_layers):\n",
        "        def hook_fn_dec(module, inp, out, idx=i):\n",
        "            hs = out[0] if isinstance(out, tuple) else out\n",
        "            dec_acts[idx] = hs.detach()\n",
        "        dec_hooks.append(layer.register_forward_hook(hook_fn_dec))\n",
        "    cross_acts = {i: None for i in range(len(dec_layers))}\n",
        "    cross_hooks = []\n",
        "    for i, block in enumerate(dec_layers):\n",
        "        def hook_fn_cross(module, inp, out, idx=i):\n",
        "            hs = out[0] if isinstance(out, tuple) else out\n",
        "            cross_acts[idx] = hs.detach()\n",
        "        cross_attn = block.layer[1]\n",
        "        cross_hooks.append(cross_attn.register_forward_hook(hook_fn_cross))\n",
        "    return (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)\n",
        "\n",
        "def remove_hooks(hook_sets):\n",
        "    for hooks, _ in hook_sets:\n",
        "        for h in hooks:\n",
        "            h.remove()\n",
        "\n",
        "def compute_conditional_batch_entropy(prev_acts, curr_acts):\n",
        "    er_scores = {}\n",
        "    for i in range(len(curr_acts) - 1):\n",
        "        prev_X, prev_Y = prev_acts[i], prev_acts[i+1]\n",
        "        curr_X, curr_Y = curr_acts[i], curr_acts[i+1]\n",
        "        if (\n",
        "            prev_X is not None and prev_Y is not None and\n",
        "            curr_X is not None and curr_Y is not None and\n",
        "            prev_X.shape == curr_X.shape and\n",
        "            prev_Y.shape == curr_Y.shape\n",
        "        ):\n",
        "            B = curr_X.size(0)\n",
        "            dX = (curr_X - prev_X).view(B, -1)\n",
        "            dY = (curr_Y - prev_Y).view(B, -1)\n",
        "            if B >= 2 and not (torch.isnan(dX).any() or torch.isnan(dY).any()):\n",
        "                cos_squares = [\n",
        "                    F.cosine_similarity(dY[j].unsqueeze(0), dX[j].unsqueeze(0), dim=1, eps=1e-8).item() ** 2\n",
        "                    for j in range(1, B)\n",
        "                    if not (torch.isnan(dX[j]).any() or torch.isnan(dY[j]).any())\n",
        "                ]\n",
        "                if cos_squares:\n",
        "                    er = sum(cos_squares) / len(cos_squares)\n",
        "                    if not (math.isnan(er) or math.isinf(er)):\n",
        "                        er_scores[i] = er\n",
        "    return er_scores\n",
        "\n",
        "# ===========================\n",
        "# 5. Pruning Utilities\n",
        "# ===========================\n",
        "class SkipBlock(nn.Module):\n",
        "    def __init__(self, hidden_size):\n",
        "        super().__init__()\n",
        "        self.hidden_size = hidden_size\n",
        "    def forward(self, hidden_states, *args, **kwargs):\n",
        "        return (hidden_states, None, None, None, None, None)\n",
        "\n",
        "def prune_er_layers(blocks, er_scores, num_prune=2, hidden_size=768):\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx+1 for idx, _ in sorted_layers[:num_prune] if idx+1 < len(blocks)]\n",
        "    for idx in prune_idxs:\n",
        "        blocks[idx] = SkipBlock(hidden_size)\n",
        "    return prune_idxs\n",
        "\n",
        "# ===========================\n",
        "# 6. Eval Helper\n",
        "# ===========================\n",
        "def compute_accuracy(preds, refs):\n",
        "    correct = 0\n",
        "    for p, l in zip(preds, refs):\n",
        "        if p == l:\n",
        "            correct += 1\n",
        "    return correct / len(preds) if len(preds) > 0 else 0\n",
        "\n",
        "def evaluate_model(model, dl, tokenizer, device):\n",
        "    model.eval()\n",
        "    preds, refs = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch[\"input_ids\"].to(device)\n",
        "            attention_mask = batch[\"attention_mask\"].to(device)\n",
        "            outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=8)\n",
        "            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
        "            label_ids = batch[\"labels\"].clone()\n",
        "            label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
        "            ref_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
        "            preds.extend([p.strip().lower() for p in pred_texts])\n",
        "            refs.extend([l.strip().lower() for l in ref_texts])\n",
        "    return compute_accuracy(preds, refs)\n",
        "\n",
        "# ===========================\n",
        "# 7. Training + ER Tracking + Pruning\n",
        "# ===========================\n",
        "def full_finetuning(train_loader, dev_loader, device, tokenizer):\n",
        "    print(\"=== Stage 1: Full Fine-Tuning & Conditional ER Estimation ===\")\n",
        "    model = T5ForConditionalGeneration.from_pretrained(\"t5-base\").to(device)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
        "    scaler = GradScaler()\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*3)\n",
        "    (enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts) = register_conditional_er_hooks(model)\n",
        "    last_enc_er, last_dec_er, last_cross_er = None, None, None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        enc_er_sums, enc_er_counts = defaultdict(float), defaultdict(int)\n",
        "        dec_er_sums, dec_er_counts = defaultdict(float), defaultdict(int)\n",
        "        cross_er_sums, cross_er_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        prev_enc_acts, prev_dec_acts, prev_cross_acts = None, None, None\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                                attention_mask=batch['attention_mask'].to(device),\n",
        "                                labels=batch['labels'].to(device))\n",
        "                loss = outputs.loss\n",
        "                scaler.scale(loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "            if prev_enc_acts is not None:\n",
        "                enc_batch_er = compute_conditional_batch_entropy(prev_enc_acts, enc_acts)\n",
        "                for idx, v in enc_batch_er.items():\n",
        "                    enc_er_sums[idx] += v\n",
        "                    enc_er_counts[idx] += 1\n",
        "            if prev_dec_acts is not None:\n",
        "                dec_batch_er = compute_conditional_batch_entropy(prev_dec_acts, dec_acts)\n",
        "                for idx, v in dec_batch_er.items():\n",
        "                    dec_er_sums[idx] += v\n",
        "                    dec_er_counts[idx] += 1\n",
        "            if prev_cross_acts is not None:\n",
        "                cross_batch_er = compute_conditional_batch_entropy(prev_cross_acts, cross_acts)\n",
        "                for idx, v in cross_batch_er.items():\n",
        "                    cross_er_sums[idx] += v\n",
        "                    cross_er_counts[idx] += 1\n",
        "            prev_enc_acts = {i: enc_acts[i].clone() if enc_acts[i] is not None else None for i in enc_acts}\n",
        "            prev_dec_acts = {i: dec_acts[i].clone() if dec_acts[i] is not None else None for i in dec_acts}\n",
        "            prev_cross_acts = {i: cross_acts[i].clone() if cross_acts[i] is not None else None for i in cross_acts}\n",
        "        epoch_enc_er = {idx: enc_er_sums[idx]/enc_er_counts[idx] for idx in enc_er_sums if enc_er_counts[idx] > 0}\n",
        "        epoch_dec_er = {idx: dec_er_sums[idx]/dec_er_counts[idx] for idx in dec_er_sums if dec_er_counts[idx] > 0}\n",
        "        epoch_cross_er = {idx: cross_er_sums[idx]/cross_er_counts[idx] for idx in cross_er_sums if cross_er_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Encoder Conditional ER: {epoch_enc_er}\")\n",
        "        print(f\"[Epoch {epoch+1}] approx Decoder Conditional ER: {epoch_dec_er}\")\n",
        "        print(f\"[Epoch {epoch+1}] approx Cross-Attention Conditional ER: {epoch_cross_er}\")\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "        print(f\"[Epoch {epoch+1}] Dev Acc: {acc:.4f}\")\n",
        "        last_enc_er, last_dec_er, last_cross_er = epoch_enc_er, epoch_dec_er, epoch_cross_er\n",
        "\n",
        "    remove_hooks([(enc_hooks, enc_acts), (dec_hooks, dec_acts), (cross_hooks, cross_acts)])\n",
        "    return model, last_enc_er, last_dec_er, last_cross_er\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, enc_er_scores, dec_er_scores, tokenizer):\n",
        "    print(\"=== Stage 2: Prune (High-ER) & Fine-tuning ===\")\n",
        " #   enc_prune_idxs = prune_er_layers(model.encoder.block, enc_er_scores, num_prune=2, hidden_size=model.config.d_model)\n",
        "    dec_prune_idxs = prune_er_layers(model.decoder.block, dec_er_scores, num_prune=4, hidden_size=model.config.d_model)\n",
        " #   print(\"Pruned encoder layers (highest ER):\", enc_prune_idxs)\n",
        "    print(\"Pruned decoder layers (highest ER):\", dec_prune_idxs)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=5e-4)\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*2)\n",
        "    for epoch in range(5):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            outputs = model(input_ids=batch['input_ids'].to(device),\n",
        "                            attention_mask=batch['attention_mask'].to(device),\n",
        "                            labels=batch['labels'].to(device))\n",
        "            loss = outputs.loss\n",
        "            loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, tokenizer, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] SVAMP Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "# ===========================\n",
        "# 8. Main Entrypoint\n",
        "# ===========================\n",
        "def main():\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    model, enc_er_scores, dec_er_scores, cross_er_scores = full_finetuning(\n",
        "        train_loader, dev_loader, device, tokenizer)\n",
        "    model = prune_and_finetuning(\n",
        "        model, train_loader, dev_loader, device,\n",
        "        enc_er_scores, dec_er_scores, tokenizer\n",
        "    )\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "BQ322_FLvamj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "gX0GHph222Ka"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "qGXGpf-0nyNa"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "U3wDNkD86y-D"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}